1use crate::error::{SeqError, SeqResult};
22use std::collections::HashMap;
23
/// Numerically stable `ln(exp(a) + exp(b))` for natural-log probabilities.
///
/// `f64::NEG_INFINITY` (probability zero) is the identity element: adding it
/// returns the other argument unchanged. NaN in either argument yields NaN.
#[inline]
fn log_add_exp(a: f64, b: f64) -> f64 {
    if a == f64::NEG_INFINITY {
        return b;
    }
    if b == f64::NEG_INFINITY {
        return a;
    }
    // Both +inf: the general formula below would compute `inf - inf = NaN`,
    // but the mathematically correct sum is +inf.
    if a == f64::INFINITY && b == f64::INFINITY {
        return f64::INFINITY;
    }
    let (hi, lo) = if a > b { (a, b) } else { (b, a) };
    // Factor out the larger term so `exp` cannot overflow: hi + ln(1 + e^(lo - hi)).
    hi + (lo - hi).exp().ln_1p()
}
36
37fn validate(log_probs: &[f64], t_len: usize, n_symbols: usize, blank: usize) -> SeqResult<()> {
39 if t_len == 0 || n_symbols == 0 {
40 return Err(SeqError::EmptyInput);
41 }
42 if log_probs.len() != t_len * n_symbols {
43 return Err(SeqError::ShapeMismatch {
44 expected: t_len * n_symbols,
45 got: log_probs.len(),
46 });
47 }
48 if blank >= n_symbols {
49 return Err(SeqError::IndexOutOfBounds {
50 index: blank,
51 len: n_symbols,
52 });
53 }
54 Ok(())
55}
56
57pub fn ctc_greedy_decode(
61 log_probs: &[f64],
62 t_len: usize,
63 n_symbols: usize,
64 blank: usize,
65) -> SeqResult<Vec<usize>> {
66 validate(log_probs, t_len, n_symbols, blank)?;
67 let mut raw = Vec::with_capacity(t_len);
68 for ti in 0..t_len {
69 let row = &log_probs[ti * n_symbols..ti * n_symbols + n_symbols];
70 let mut best = 0usize;
71 let mut best_val = row[0];
72 for (c, &v) in row.iter().enumerate() {
73 if v.is_nan() {
74 return Err(SeqError::NumericalInstability(
75 "NaN in CTC log-probs".into(),
76 ));
77 }
78 if v > best_val {
79 best_val = v;
80 best = c;
81 }
82 }
83 raw.push(best);
84 }
85 let mut out = Vec::new();
87 let mut prev = usize::MAX;
88 for &sym in &raw {
89 if sym != prev && sym != blank {
90 out.push(sym);
91 }
92 prev = sym;
93 }
94 Ok(out)
95}
96
97#[derive(Clone, Copy)]
99struct PrefixProb {
100 p_blank: f64,
102 p_non_blank: f64,
104}
105
106impl PrefixProb {
107 #[inline]
108 fn total(&self) -> f64 {
109 log_add_exp(self.p_blank, self.p_non_blank)
110 }
111}
112
/// A decoded label sequence together with its score, as returned by
/// [`ctc_prefix_beam_search`].
#[derive(Clone, Debug, PartialEq)]
pub struct CtcHypothesis {
    /// Collapsed output labels: no blank symbols, with repeats already merged.
    pub labels: Vec<usize>,
    /// Natural-log probability of `labels`, combining the blank-ending and
    /// non-blank-ending alignment mass the beam retained for this prefix.
    pub log_prob: f64,
}
121
122pub fn ctc_prefix_beam_search(
129 log_probs: &[f64],
130 t_len: usize,
131 n_symbols: usize,
132 blank: usize,
133 beam_width: usize,
134) -> SeqResult<Vec<CtcHypothesis>> {
135 validate(log_probs, t_len, n_symbols, blank)?;
136 if beam_width == 0 {
137 return Err(SeqError::InvalidParameter {
138 name: "beam_width".into(),
139 value: 0.0,
140 });
141 }
142 for &v in log_probs {
143 if v.is_nan() {
144 return Err(SeqError::NumericalInstability(
145 "NaN in CTC log-probs".into(),
146 ));
147 }
148 }
149
150 let mut beam: HashMap<Vec<usize>, PrefixProb> = HashMap::new();
152 beam.insert(
153 Vec::new(),
154 PrefixProb {
155 p_blank: 0.0,
156 p_non_blank: f64::NEG_INFINITY,
157 },
158 );
159
160 for ti in 0..t_len {
161 let row = &log_probs[ti * n_symbols..ti * n_symbols + n_symbols];
162 let mut next: HashMap<Vec<usize>, PrefixProb> = HashMap::new();
163
164 for (prefix, prob) in &beam {
165 let entry = next.entry(prefix.clone()).or_insert(PrefixProb {
167 p_blank: f64::NEG_INFINITY,
168 p_non_blank: f64::NEG_INFINITY,
169 });
170 entry.p_blank = log_add_exp(entry.p_blank, prob.total() + row[blank]);
171
172 for c in 0..n_symbols {
174 if c == blank {
175 continue;
176 }
177 let lp_c = row[c];
178 let last = prefix.last().copied();
179 if last == Some(c) {
180 let mut new_prefix = prefix.clone();
183 new_prefix.push(c);
184 let e = next.entry(new_prefix).or_insert(PrefixProb {
185 p_blank: f64::NEG_INFINITY,
186 p_non_blank: f64::NEG_INFINITY,
187 });
188 e.p_non_blank = log_add_exp(e.p_non_blank, prob.p_blank + lp_c);
189 let e_same = next.entry(prefix.clone()).or_insert(PrefixProb {
191 p_blank: f64::NEG_INFINITY,
192 p_non_blank: f64::NEG_INFINITY,
193 });
194 e_same.p_non_blank = log_add_exp(e_same.p_non_blank, prob.p_non_blank + lp_c);
195 } else {
196 let mut new_prefix = prefix.clone();
199 new_prefix.push(c);
200 let e = next.entry(new_prefix).or_insert(PrefixProb {
201 p_blank: f64::NEG_INFINITY,
202 p_non_blank: f64::NEG_INFINITY,
203 });
204 e.p_non_blank = log_add_exp(e.p_non_blank, prob.total() + lp_c);
205 }
206 }
207 }
208
209 let mut scored: Vec<(Vec<usize>, PrefixProb)> = next.into_iter().collect();
211 scored.sort_by(|a, b| {
212 b.1.total()
213 .partial_cmp(&a.1.total())
214 .unwrap_or(std::cmp::Ordering::Equal)
215 });
216 scored.truncate(beam_width);
217 beam = scored.into_iter().collect();
218 }
219
220 let mut hyps: Vec<CtcHypothesis> = beam
221 .into_iter()
222 .map(|(labels, prob)| CtcHypothesis {
223 labels,
224 log_prob: prob.total(),
225 })
226 .collect();
227 hyps.sort_by(|a, b| {
228 b.log_prob
229 .partial_cmp(&a.log_prob)
230 .unwrap_or(std::cmp::Ordering::Equal)
231 });
232 Ok(hyps)
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts probabilities to natural-log space. Zeros are clamped to 1e-30 so
    /// every entry stays finite.
    fn to_log(probs: &[f64]) -> Vec<f64> {
        probs.iter().map(|&p| p.max(1e-30).ln()).collect()
    }

    #[test]
    fn greedy_collapses_repeats_and_blanks() {
        let probs = vec![
            0.1, 0.8, 0.1, // t=0 -> 1
            0.1, 0.8, 0.1, // t=1 -> 1 (merged with t=0)
            0.8, 0.1, 0.1, // t=2 -> blank
            0.1, 0.1, 0.8, // t=3 -> 2
        ];
        let lp = to_log(&probs);
        let out = ctc_greedy_decode(&lp, 4, 3, 0).expect("decode");
        assert_eq!(out, vec![1, 2]);
    }

    #[test]
    fn greedy_all_blank_is_empty() {
        let probs = vec![
            0.9, 0.05, 0.05, // t=0 -> blank
            0.9, 0.05, 0.05, // t=1 -> blank
        ];
        let lp = to_log(&probs);
        let out = ctc_greedy_decode(&lp, 2, 3, 0).expect("decode");
        assert!(out.is_empty());
    }

    #[test]
    fn greedy_repeat_without_blank_merges() {
        let probs = vec![
            0.1, 0.8, 0.1, // t=0 -> 1
            0.1, 0.8, 0.1, // t=1 -> 1
        ];
        let lp = to_log(&probs);
        let out = ctc_greedy_decode(&lp, 2, 3, 0).expect("decode");
        assert_eq!(out, vec![1]);
    }

    #[test]
    fn greedy_blank_separates_repeats() {
        let probs = vec![
            0.1, 0.8, 0.1, // t=0 -> 1
            0.8, 0.1, 0.1, // t=1 -> blank
            0.1, 0.8, 0.1, // t=2 -> 1 (a new label, not a repeat)
        ];
        let lp = to_log(&probs);
        let out = ctc_greedy_decode(&lp, 3, 3, 0).expect("decode");
        assert_eq!(out, vec![1, 1]);
    }

    #[test]
    fn greedy_blank_at_last_index() {
        let probs = vec![
            0.8, 0.1, 0.1, // t=0 -> 0
            0.1, 0.8, 0.1, // t=1 -> 1
            0.1, 0.1, 0.8, // t=2 -> blank (index 2)
        ];
        let lp = to_log(&probs);
        let out = ctc_greedy_decode(&lp, 3, 3, 2).expect("decode");
        assert_eq!(out, vec![0, 1]);
    }

    #[test]
    fn beam_returns_sorted_hypotheses() {
        let probs = vec![
            0.2, 0.5, 0.3, //
            0.1, 0.6, 0.3, //
            0.3, 0.2, 0.5, //
            0.4, 0.3, 0.3, //
        ];
        let lp = to_log(&probs);
        let hyps = ctc_prefix_beam_search(&lp, 4, 3, 0, 8).expect("beam");
        assert!(!hyps.is_empty());
        for w in hyps.windows(2) {
            assert!(w[0].log_prob >= w[1].log_prob - 1e-12);
        }
    }

    #[test]
    fn beam_top1_matches_greedy_for_peaked_input() {
        let probs = vec![
            0.02, 0.96, 0.02, // t=0 -> 1
            0.96, 0.02, 0.02, // t=1 -> blank
            0.02, 0.02, 0.96, // t=2 -> 2
        ];
        let lp = to_log(&probs);
        let greedy = ctc_greedy_decode(&lp, 3, 3, 0).expect("greedy");
        let beam = ctc_prefix_beam_search(&lp, 3, 3, 0, 16).expect("beam");
        assert_eq!(beam[0].labels, greedy);
    }

    #[test]
    fn beam_blank_separates_repeats() {
        let probs = vec![
            0.02, 0.96, 0.02, // t=0 -> 1
            0.96, 0.02, 0.02, // t=1 -> blank
            0.02, 0.96, 0.02, // t=2 -> 1
        ];
        let lp = to_log(&probs);
        let hyps = ctc_prefix_beam_search(&lp, 3, 3, 0, 16).expect("beam");
        assert_eq!(hyps[0].labels, vec![1, 1]);
    }

    #[test]
    fn beam_total_probability_consistent_with_loss() {
        use crate::ctc::ctc_loss::ctc_loss;
        // The beam score of the best prefix should be close to the exact
        // log-likelihood of that label sequence, i.e. the negated CTC loss.
        let probs = vec![
            0.02, 0.96, 0.02, //
            0.96, 0.02, 0.02, //
            0.02, 0.02, 0.96, //
        ];
        let lp = to_log(&probs);
        let beam = ctc_prefix_beam_search(&lp, 3, 3, 0, 32).expect("beam");
        let best = &beam[0];
        let loss = ctc_loss(&lp, 3, 3, &best.labels, 0).expect("loss");
        assert!(
            (best.log_prob - (-loss)).abs() < 0.2,
            "beam={} loss={loss}",
            best.log_prob
        );
    }

    #[test]
    fn beam_wide_enough_keeps_all_probability_mass() {
        // 2 frames x 3 symbols give at most 7 distinct prefixes, so a width of 64
        // prunes nothing: every alignment lands in exactly one hypothesis and the
        // hypothesis probabilities must sum to one.
        let probs = vec![
            0.2, 0.5, 0.3, //
            0.4, 0.3, 0.3, //
        ];
        let lp = to_log(&probs);
        let hyps = ctc_prefix_beam_search(&lp, 2, 3, 0, 64).expect("beam");
        let mass: f64 = hyps.iter().map(|h| h.log_prob.exp()).sum();
        assert!((mass - 1.0).abs() < 1e-9, "mass={mass}");
    }

    #[test]
    fn beam_width_one_is_valid() {
        let probs = vec![
            0.2, 0.5, 0.3, //
            0.4, 0.3, 0.3, //
        ];
        let lp = to_log(&probs);
        let hyps = ctc_prefix_beam_search(&lp, 2, 3, 0, 1).expect("beam");
        assert_eq!(hyps.len(), 1);
    }

    #[test]
    fn beam_recovers_empty_when_blank_dominates() {
        let probs = vec![
            0.9, 0.05, 0.05, //
            0.9, 0.05, 0.05, //
        ];
        let lp = to_log(&probs);
        let hyps = ctc_prefix_beam_search(&lp, 2, 3, 0, 8).expect("beam");
        assert!(hyps[0].labels.is_empty());
    }

    #[test]
    fn greedy_empty_input_errors() {
        assert!(ctc_greedy_decode(&[], 0, 3, 0).is_err());
    }

    #[test]
    fn greedy_shape_mismatch_errors() {
        let lp = vec![0.0; 5];
        assert!(ctc_greedy_decode(&lp, 2, 3, 0).is_err());
    }

    #[test]
    fn greedy_blank_out_of_range_errors() {
        let lp = to_log(&[0.5, 0.5, 0.5, 0.5]);
        assert!(ctc_greedy_decode(&lp, 2, 2, 9).is_err());
    }

    #[test]
    fn beam_zero_width_errors() {
        let lp = to_log(&[0.5, 0.5, 0.5, 0.5]);
        assert!(ctc_prefix_beam_search(&lp, 2, 2, 0, 0).is_err());
    }

    #[test]
    fn beam_nan_errors() {
        let lp = vec![f64::NAN, 0.0, 0.0, 0.0];
        assert!(ctc_prefix_beam_search(&lp, 2, 2, 0, 4).is_err());
    }

    #[test]
    fn greedy_nan_errors() {
        let lp = vec![0.0, f64::NAN, 0.0, 0.0, 0.0, 0.0];
        assert!(ctc_greedy_decode(&lp, 2, 3, 0).is_err());
    }
}