Skip to main content

rlx_ocr/
ctc.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! CTC decoding (greedy + beam search).
17
18use crate::config::DecodeMethod;
19use rten_tensor::NdTensorView;
20use rten_tensor::prelude::*;
21use std::collections::HashMap;
22use std::num::NonZeroU32;
23
/// One emitted label in a decoded CTC sequence.
///
/// The decoders in this module only emit non-zero labels; class 0 is treated as the blank.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct CtcStep {
    /// Class index of the emitted label.
    pub label: u32,
    /// Time step (row of the input matrix) at which the label was emitted.
    pub pos: u32,
}
30
31/// Decoded label sequence with log score.
32#[derive(Clone, Debug)]
33pub struct CtcHypothesis {
34    steps: Vec<CtcStep>,
35    score: f32,
36}
37
38impl CtcHypothesis {
39    pub fn steps(&self) -> &[CtcStep] {
40        &self.steps
41    }
42
43    pub fn score(&self) -> f32 {
44        self.score
45    }
46}
47
48/// Decode a `[seq, classes]` log-probability matrix.
49pub fn decode(input_seq: NdTensorView<f32, 2>, method: DecodeMethod) -> CtcHypothesis {
50    match method {
51        DecodeMethod::Greedy => decode_greedy(input_seq),
52        DecodeMethod::BeamSearch { width } => {
53            decode_beam(input_seq, width.max(1)).unwrap_or_else(|| decode_greedy(input_seq))
54        }
55    }
56}
57
58fn decode_greedy(prob_seq: NdTensorView<f32, 2>) -> CtcHypothesis {
59    let mut last_label = 0;
60    let mut steps = Vec::new();
61    let mut score = 0.;
62
63    for pos in 0..prob_seq.size(0) {
64        let mut best_label = 0usize;
65        let mut best_lp = prob_seq[[pos, 0]];
66        for label in 1..prob_seq.size(1) {
67            let lp = prob_seq[[pos, label]];
68            if lp > best_lp {
69                best_lp = lp;
70                best_label = label;
71            }
72        }
73        let label = best_label;
74        score += best_lp;
75        if label == last_label {
76            continue;
77        }
78        last_label = label;
79        if label > 0 {
80            steps.push(CtcStep {
81                label: label as u32,
82                pos: pos as u32,
83            });
84        }
85    }
86
87    CtcHypothesis { steps, score }
88}
89
/// Log-probabilities tracked for one beam-search prefix, split by whether the alignments
/// reaching it at the current time step end in a blank or in a non-blank label.
#[derive(Debug)]
struct BeamProbs {
    /// Log-probability mass of alignments ending in blank.
    prob_blank: f32,
    /// Log-probability mass of alignments ending in a non-blank label.
    prob_no_blank: f32,
}
95
/// Computes `ln(sum(exp(x)))` over `log_probs` in a numerically stable way.
///
/// Returns `-inf` when every input is `-inf`, and also for an empty array.
fn log_sum_exp<const N: usize>(log_probs: [f32; N]) -> f32 {
    let has_finite_mass = log_probs.iter().any(|&x| x != f32::NEG_INFINITY);
    if !has_finite_mass {
        return f32::NEG_INFINITY;
    }

    // Shift by the maximum so the largest term becomes exp(0) = 1, which avoids overflow.
    let max_lp = log_probs
        .iter()
        .copied()
        .reduce(f32::max)
        .unwrap_or(f32::NEG_INFINITY);
    let shifted_sum: f32 = log_probs.iter().map(|x| (x - max_lp).exp()).sum();
    max_lp + shifted_sum.ln()
}
112
113fn decode_beam(prob_seq: NdTensorView<f32, 2>, beam_size: u32) -> Option<CtcHypothesis> {
114    let beam_size = NonZeroU32::new(beam_size)?;
115    let mut states: HashMap<Vec<CtcStep>, BeamProbs> = HashMap::new();
116    states.insert(
117        Vec::new(),
118        BeamProbs {
119            prob_blank: 0.,
120            prob_no_blank: f32::NEG_INFINITY,
121        },
122    );
123
124    for t in 0..prob_seq.size(0) {
125        let mut next: HashMap<Vec<CtcStep>, BeamProbs> = HashMap::new();
126        let blank_lp = prob_seq[[t, 0]];
127        for (prefix, state) in &states {
128            let p_b = state.prob_blank;
129            let p_nb = state.prob_no_blank;
130            merge_beam(
131                &mut next,
132                prefix.clone(),
133                log_sum_exp([p_b + blank_lp, p_nb + blank_lp]),
134                f32::NEG_INFINITY,
135            );
136            for label in 1..prob_seq.size(1) {
137                let lp = prob_seq[[t, label]];
138                let mut new_prefix = prefix.clone();
139                let step = CtcStep {
140                    label: label as u32,
141                    pos: t as u32,
142                };
143                let last = new_prefix.last().map(|s| s.label);
144                if last != Some(step.label) {
145                    new_prefix.push(step);
146                }
147                let (nb_lp, b_lp) = if last == Some(step.label) {
148                    (p_nb + lp, p_b + lp)
149                } else {
150                    (log_sum_exp([p_b + lp, p_nb + lp]), f32::NEG_INFINITY)
151                };
152                merge_beam(&mut next, new_prefix, b_lp, nb_lp);
153            }
154        }
155        let mut ranked: Vec<_> = next.into_iter().collect();
156        ranked.sort_by(|(_, a), (_, b)| {
157            let sa = log_sum_exp([a.prob_blank, a.prob_no_blank]);
158            let sb = log_sum_exp([b.prob_blank, b.prob_no_blank]);
159            sb.partial_cmp(&sa).unwrap_or(std::cmp::Ordering::Equal)
160        });
161        ranked.truncate(beam_size.get() as usize);
162        states = ranked.into_iter().collect();
163    }
164
165    let (prefix, probs) = states.into_iter().max_by(|(_, a), (_, b)| {
166        let sa = log_sum_exp([a.prob_blank, a.prob_no_blank]);
167        let sb = log_sum_exp([b.prob_blank, b.prob_no_blank]);
168        sa.partial_cmp(&sb).unwrap_or(std::cmp::Ordering::Equal)
169    })?;
170    Some(CtcHypothesis {
171        steps: prefix,
172        score: log_sum_exp([probs.prob_blank, probs.prob_no_blank]),
173    })
174}
175
176fn merge_beam(map: &mut HashMap<Vec<CtcStep>, BeamProbs>, prefix: Vec<CtcStep>, pb: f32, pnb: f32) {
177    use std::collections::hash_map::Entry;
178    match map.entry(prefix) {
179        Entry::Vacant(e) => {
180            e.insert(BeamProbs {
181                prob_blank: pb,
182                prob_no_blank: pnb,
183            });
184        }
185        Entry::Occupied(mut e) => {
186            let s = e.get_mut();
187            s.prob_blank = log_sum_exp([s.prob_blank, pb]);
188            s.prob_no_blank = log_sum_exp([s.prob_no_blank, pnb]);
189        }
190    }
191}