Skip to main content

oxicuda_seq/ctc/
ctc_decode.rs

1//! CTC decoding: best-path (greedy) and prefix-beam search (Graves 2006; Hannun 2014).
2//!
3//! Given per-frame log-probabilities `[T, C]` over a blank-augmented alphabet, a
4//! CTC decoder produces the most probable *label* sequence after applying the
5//! CTC collapse `B` (merge repeats, then drop blanks).
6//!
7//! Two strategies are provided:
8//!
//! * [`ctc_greedy_decode`] — best-path decoding: take the arg-max symbol at each
//!   frame and collapse. Fast (`O(T·C)`), but not guaranteed to return the most
//!   probable label sequence: it scores only the single best alignment, whose
//!   probability is merely a lower bound on that sequence's total probability.
12//!
//! * [`ctc_prefix_beam_search`] — prefix-beam search: maintain a beam of label
//!   prefixes, tracking, for each prefix, the probability that it ends in a
//!   blank (`p_blank`) versus a non-blank (`p_non_blank`). This sums the
//!   probabilities of distinct alignments that collapse to the same prefix, so
//!   it can recover higher-probability sequences than greedy decoding (subject
//!   to beam pruning).
18//!
19//! All probabilities are accumulated in **log-space**.
20
21use crate::error::{SeqError, SeqResult};
22use std::collections::HashMap;
23
/// Numerically-stable `log(exp(a) + exp(b))`.
///
/// Computed as `hi + ln(1 + exp(lo - hi))` with `hi = max(a, b)`, so the
/// exponent is never positive and `exp` cannot overflow. `-∞` (log of zero
/// probability) is the identity element, two `+∞` operands give `+∞` instead
/// of the `NaN` that `∞ - ∞` would produce, and `NaN` inputs propagate.
#[inline]
fn log_add_exp(a: f64, b: f64) -> f64 {
    if a == f64::NEG_INFINITY {
        return b;
    }
    if b == f64::NEG_INFINITY {
        return a;
    }
    let (hi, lo) = if a > b { (a, b) } else { (b, a) };
    if lo == f64::INFINITY {
        // Both operands are `+∞` (or `hi` is NaN); `lo - hi` would be `∞ - ∞`.
        return hi;
    }
    hi + (lo - hi).exp().ln_1p()
}
36
37/// Validate the emission tensor shape and blank index.
38fn validate(log_probs: &[f64], t_len: usize, n_symbols: usize, blank: usize) -> SeqResult<()> {
39    if t_len == 0 || n_symbols == 0 {
40        return Err(SeqError::EmptyInput);
41    }
42    if log_probs.len() != t_len * n_symbols {
43        return Err(SeqError::ShapeMismatch {
44            expected: t_len * n_symbols,
45            got: log_probs.len(),
46        });
47    }
48    if blank >= n_symbols {
49        return Err(SeqError::IndexOutOfBounds {
50            index: blank,
51            len: n_symbols,
52        });
53    }
54    Ok(())
55}
56
57/// Best-path (greedy) CTC decode: arg-max per frame followed by CTC collapse.
58///
59/// Returns the decoded label sequence (blanks removed, repeats merged).
60pub fn ctc_greedy_decode(
61    log_probs: &[f64],
62    t_len: usize,
63    n_symbols: usize,
64    blank: usize,
65) -> SeqResult<Vec<usize>> {
66    validate(log_probs, t_len, n_symbols, blank)?;
67    let mut raw = Vec::with_capacity(t_len);
68    for ti in 0..t_len {
69        let row = &log_probs[ti * n_symbols..ti * n_symbols + n_symbols];
70        let mut best = 0usize;
71        let mut best_val = row[0];
72        for (c, &v) in row.iter().enumerate() {
73            if v.is_nan() {
74                return Err(SeqError::NumericalInstability(
75                    "NaN in CTC log-probs".into(),
76                ));
77            }
78            if v > best_val {
79                best_val = v;
80                best = c;
81            }
82        }
83        raw.push(best);
84    }
85    // Collapse: merge consecutive duplicates, then remove blanks.
86    let mut out = Vec::new();
87    let mut prev = usize::MAX;
88    for &sym in &raw {
89        if sym != prev && sym != blank {
90            out.push(sym);
91        }
92        prev = sym;
93    }
94    Ok(out)
95}
96
97/// Per-prefix log-probabilities split by trailing-blank vs trailing-non-blank.
98#[derive(Clone, Copy)]
99struct PrefixProb {
100    /// log P(prefix, last frame emitted blank).
101    p_blank: f64,
102    /// log P(prefix, last frame emitted a non-blank).
103    p_non_blank: f64,
104}
105
106impl PrefixProb {
107    #[inline]
108    fn total(&self) -> f64 {
109        log_add_exp(self.p_blank, self.p_non_blank)
110    }
111}
112
/// One ranked result of [`ctc_prefix_beam_search`].
///
/// `log_prob` is the natural-log probability of `labels`, summed over the
/// alignments tracked by the beam that collapse to it.
#[derive(Debug, Clone, PartialEq)]
pub struct CtcHypothesis {
    /// Collapsed label sequence (no blanks, repeats already merged).
    pub labels: Vec<usize>,
    /// Log-probability of `labels`, accumulated over alignments in log-space.
    pub log_prob: f64,
}
121
122/// Prefix-beam-search CTC decoding (Graves 2006; Hannun 2014).
123///
124/// * `beam_width` — maximum number of prefixes kept after each frame (`≥ 1`).
125///
126/// Returns the surviving hypotheses sorted by descending total log-probability;
127/// the first element is the most probable decoding.
128pub fn ctc_prefix_beam_search(
129    log_probs: &[f64],
130    t_len: usize,
131    n_symbols: usize,
132    blank: usize,
133    beam_width: usize,
134) -> SeqResult<Vec<CtcHypothesis>> {
135    validate(log_probs, t_len, n_symbols, blank)?;
136    if beam_width == 0 {
137        return Err(SeqError::InvalidParameter {
138            name: "beam_width".into(),
139            value: 0.0,
140        });
141    }
142    for &v in log_probs {
143        if v.is_nan() {
144            return Err(SeqError::NumericalInstability(
145                "NaN in CTC log-probs".into(),
146            ));
147        }
148    }
149
150    // beam maps prefix -> PrefixProb. The empty prefix starts with p_blank = 0.
151    let mut beam: HashMap<Vec<usize>, PrefixProb> = HashMap::new();
152    beam.insert(
153        Vec::new(),
154        PrefixProb {
155            p_blank: 0.0,
156            p_non_blank: f64::NEG_INFINITY,
157        },
158    );
159
160    for ti in 0..t_len {
161        let row = &log_probs[ti * n_symbols..ti * n_symbols + n_symbols];
162        let mut next: HashMap<Vec<usize>, PrefixProb> = HashMap::new();
163
164        for (prefix, prob) in &beam {
165            // 1) Emit blank: prefix is unchanged, accumulates into p_blank.
166            let entry = next.entry(prefix.clone()).or_insert(PrefixProb {
167                p_blank: f64::NEG_INFINITY,
168                p_non_blank: f64::NEG_INFINITY,
169            });
170            entry.p_blank = log_add_exp(entry.p_blank, prob.total() + row[blank]);
171
172            // 2) Emit a non-blank symbol c.
173            for c in 0..n_symbols {
174                if c == blank {
175                    continue;
176                }
177                let lp_c = row[c];
178                let last = prefix.last().copied();
179                if last == Some(c) {
180                    // Repeat of the current last label.
181                    // (a) extends to a NEW token only from a blank-terminated path.
182                    let mut new_prefix = prefix.clone();
183                    new_prefix.push(c);
184                    let e = next.entry(new_prefix).or_insert(PrefixProb {
185                        p_blank: f64::NEG_INFINITY,
186                        p_non_blank: f64::NEG_INFINITY,
187                    });
188                    e.p_non_blank = log_add_exp(e.p_non_blank, prob.p_blank + lp_c);
189                    // (b) merges into the SAME prefix from a non-blank path.
190                    let e_same = next.entry(prefix.clone()).or_insert(PrefixProb {
191                        p_blank: f64::NEG_INFINITY,
192                        p_non_blank: f64::NEG_INFINITY,
193                    });
194                    e_same.p_non_blank = log_add_exp(e_same.p_non_blank, prob.p_non_blank + lp_c);
195                } else {
196                    // Distinct from the last label: always extends the prefix,
197                    // from either blank- or non-blank-terminated paths.
198                    let mut new_prefix = prefix.clone();
199                    new_prefix.push(c);
200                    let e = next.entry(new_prefix).or_insert(PrefixProb {
201                        p_blank: f64::NEG_INFINITY,
202                        p_non_blank: f64::NEG_INFINITY,
203                    });
204                    e.p_non_blank = log_add_exp(e.p_non_blank, prob.total() + lp_c);
205                }
206            }
207        }
208
209        // Prune to the top `beam_width` prefixes by total probability.
210        let mut scored: Vec<(Vec<usize>, PrefixProb)> = next.into_iter().collect();
211        scored.sort_by(|a, b| {
212            b.1.total()
213                .partial_cmp(&a.1.total())
214                .unwrap_or(std::cmp::Ordering::Equal)
215        });
216        scored.truncate(beam_width);
217        beam = scored.into_iter().collect();
218    }
219
220    let mut hyps: Vec<CtcHypothesis> = beam
221        .into_iter()
222        .map(|(labels, prob)| CtcHypothesis {
223            labels,
224            log_prob: prob.total(),
225        })
226        .collect();
227    hyps.sort_by(|a, b| {
228        b.log_prob
229            .partial_cmp(&a.log_prob)
230            .unwrap_or(std::cmp::Ordering::Equal)
231    });
232    Ok(hyps)
233}
234
235// ─── Tests ───────────────────────────────────────────────────────────────────
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Convert probabilities to natural-log space, clamping zeros to avoid `-∞`.
    fn to_log(probs: &[f64]) -> Vec<f64> {
        probs.iter().map(|&p| p.max(1e-30).ln()).collect()
    }

    #[test]
    fn log_add_exp_treats_neg_infinity_as_zero_probability() {
        assert_eq!(log_add_exp(f64::NEG_INFINITY, -1.5), -1.5);
        assert_eq!(log_add_exp(-2.0, f64::NEG_INFINITY), -2.0);
        assert_eq!(
            log_add_exp(f64::NEG_INFINITY, f64::NEG_INFINITY),
            f64::NEG_INFINITY
        );
    }

    #[test]
    fn log_add_exp_matches_direct_sum() {
        let s = log_add_exp(0.25f64.ln(), 0.5f64.ln());
        assert!((s - 0.75f64.ln()).abs() < 1e-12);
        // Large magnitudes must not overflow `exp`.
        let big = log_add_exp(1000.0, 1000.0);
        assert!((big - (1000.0 + std::f64::consts::LN_2)).abs() < 1e-9);
    }

    #[test]
    fn greedy_collapses_repeats_and_blanks() {
        // argmax path per frame: [1, 1, 0(blank), 2] → collapse → [1, 2].
        let probs = vec![
            0.1, 0.8, 0.1, //
            0.1, 0.8, 0.1, //
            0.8, 0.1, 0.1, //
            0.1, 0.1, 0.8, //
        ];
        let lp = to_log(&probs);
        let out = ctc_greedy_decode(&lp, 4, 3, 0).expect("decode");
        assert_eq!(out, vec![1, 2]);
    }

    #[test]
    fn greedy_all_blank_is_empty() {
        let probs = vec![
            0.9, 0.05, 0.05, //
            0.9, 0.05, 0.05, //
        ];
        let lp = to_log(&probs);
        let out = ctc_greedy_decode(&lp, 2, 3, 0).expect("decode");
        assert!(out.is_empty());
    }

    #[test]
    fn greedy_repeat_without_blank_merges() {
        // [1,1] with no separating blank collapses to [1].
        let probs = vec![
            0.1, 0.8, 0.1, //
            0.1, 0.8, 0.1, //
        ];
        let lp = to_log(&probs);
        let out = ctc_greedy_decode(&lp, 2, 3, 0).expect("decode");
        assert_eq!(out, vec![1]);
    }

    #[test]
    fn greedy_keeps_blank_separated_repeats() {
        // argmax path [1, 0(blank), 1] → the blank separates the repeats → [1, 1].
        let probs = vec![
            0.1, 0.8, 0.1, //
            0.8, 0.1, 0.1, //
            0.1, 0.8, 0.1, //
        ];
        let lp = to_log(&probs);
        let out = ctc_greedy_decode(&lp, 3, 3, 0).expect("decode");
        assert_eq!(out, vec![1, 1]);
    }

    #[test]
    fn greedy_tie_prefers_lowest_index() {
        // Symbols 1 and 2 tie exactly; the lower index wins.
        let lp = to_log(&[0.2, 0.4, 0.4]);
        let out = ctc_greedy_decode(&lp, 1, 3, 0).expect("decode");
        assert_eq!(out, vec![1]);
    }

    #[test]
    fn greedy_blank_at_last_index() {
        // blank = C-1 = 2; argmax path [0, 1, 2(blank)] → [0, 1].
        let probs = vec![
            0.8, 0.1, 0.1, //
            0.1, 0.8, 0.1, //
            0.1, 0.1, 0.8, //
        ];
        let lp = to_log(&probs);
        let out = ctc_greedy_decode(&lp, 3, 3, 2).expect("decode");
        assert_eq!(out, vec![0, 1]);
    }

    #[test]
    fn beam_returns_sorted_hypotheses() {
        let probs = vec![
            0.2, 0.5, 0.3, //
            0.1, 0.6, 0.3, //
            0.3, 0.2, 0.5, //
            0.4, 0.3, 0.3, //
        ];
        let lp = to_log(&probs);
        let hyps = ctc_prefix_beam_search(&lp, 4, 3, 0, 8).expect("beam");
        assert!(!hyps.is_empty());
        for w in hyps.windows(2) {
            assert!(w[0].log_prob >= w[1].log_prob - 1e-12);
        }
    }

    #[test]
    fn beam_top1_matches_greedy_for_peaked_input() {
        // Sharply peaked emissions → beam best == greedy best.
        let probs = vec![
            0.02, 0.96, 0.02, //
            0.96, 0.02, 0.02, //
            0.02, 0.02, 0.96, //
        ];
        let lp = to_log(&probs);
        let greedy = ctc_greedy_decode(&lp, 3, 3, 0).expect("greedy");
        let beam = ctc_prefix_beam_search(&lp, 3, 3, 0, 16).expect("beam");
        assert_eq!(beam[0].labels, greedy);
    }

    #[test]
    fn beam_keeps_blank_separated_repeats() {
        // "1 _ 1" dominates: P([1, 1]) ≈ 0.885 vs P([1]) ≈ 0.056.
        let probs = vec![
            0.02, 0.96, 0.02, //
            0.96, 0.02, 0.02, //
            0.02, 0.96, 0.02, //
        ];
        let lp = to_log(&probs);
        let hyps = ctc_prefix_beam_search(&lp, 3, 3, 0, 16).expect("beam");
        assert_eq!(hyps[0].labels, vec![1, 1]);
    }

    #[test]
    fn exhaustive_beam_probabilities_sum_to_one() {
        // Every row sums to 1 and the beam is wide enough that nothing is
        // pruned (at most 27 candidates per frame for T = 3, C = 3), so every
        // alignment is counted exactly once and the probabilities sum to 1.
        let probs = vec![
            0.2, 0.5, 0.3, //
            0.1, 0.6, 0.3, //
            0.3, 0.2, 0.5, //
        ];
        let lp = to_log(&probs);
        let hyps = ctc_prefix_beam_search(&lp, 3, 3, 0, 64).expect("beam");
        let mass: f64 = hyps.iter().map(|h| h.log_prob.exp()).sum();
        assert!((mass - 1.0).abs() < 1e-9, "mass={mass}");
    }

    #[test]
    fn beam_total_probability_consistent_with_loss() {
        // For a peaked distribution where one sequence dominates, the beam's top
        // log-prob should be close to the negative CTC loss of that sequence.
        use crate::ctc::ctc_loss::ctc_loss;
        let probs = vec![
            0.02, 0.96, 0.02, //
            0.96, 0.02, 0.02, //
            0.02, 0.02, 0.96, //
        ];
        let lp = to_log(&probs);
        let beam = ctc_prefix_beam_search(&lp, 3, 3, 0, 32).expect("beam");
        let best = &beam[0];
        let loss = ctc_loss(&lp, 3, 3, &best.labels, 0).expect("loss");
        // With beam_width = 32 nothing is pruned here (at most 27 candidates
        // per frame), so the beam's log-prob should equal −loss up to rounding.
        // The loose tolerance is kept in case `ctc_loss` applies some
        // normalization — NOTE(review): confirm against ctc_loss.
        assert!(
            (best.log_prob - (-loss)).abs() < 0.2,
            "beam={} loss={loss}",
            best.log_prob
        );
    }

    #[test]
    fn beam_width_one_is_valid() {
        let probs = vec![
            0.2, 0.5, 0.3, //
            0.4, 0.3, 0.3, //
        ];
        let lp = to_log(&probs);
        let hyps = ctc_prefix_beam_search(&lp, 2, 3, 0, 1).expect("beam");
        assert_eq!(hyps.len(), 1);
    }

    #[test]
    fn beam_recovers_empty_when_blank_dominates() {
        let probs = vec![
            0.9, 0.05, 0.05, //
            0.9, 0.05, 0.05, //
        ];
        let lp = to_log(&probs);
        let hyps = ctc_prefix_beam_search(&lp, 2, 3, 0, 8).expect("beam");
        assert!(hyps[0].labels.is_empty());
    }

    #[test]
    fn greedy_shape_mismatch_errors() {
        let lp = vec![0.0; 5];
        assert!(ctc_greedy_decode(&lp, 2, 3, 0).is_err());
    }

    #[test]
    fn greedy_blank_out_of_range_errors() {
        let lp = to_log(&[0.5, 0.5, 0.5, 0.5]);
        assert!(ctc_greedy_decode(&lp, 2, 2, 9).is_err());
    }

    #[test]
    fn beam_zero_width_errors() {
        let lp = to_log(&[0.5, 0.5, 0.5, 0.5]);
        assert!(ctc_prefix_beam_search(&lp, 2, 2, 0, 0).is_err());
    }

    #[test]
    fn beam_nan_errors() {
        let lp = vec![f64::NAN, 0.0, 0.0, 0.0];
        assert!(ctc_prefix_beam_search(&lp, 2, 2, 0, 4).is_err());
    }

    #[test]
    fn greedy_nan_errors() {
        let lp = vec![0.0, f64::NAN, 0.0, 0.0, 0.0, 0.0];
        assert!(ctc_greedy_decode(&lp, 2, 3, 0).is_err());
    }
}