use crate::error::{SeqError, SeqResult};
use std::collections::HashMap;
/// Numerically stable `ln(exp(a) + exp(b))`.
///
/// `-inf` (the log of probability zero) acts as the identity element. The
/// larger argument is factored out so `exp` never overflows, and `ln_1p`
/// keeps precision when the smaller term is tiny relative to the larger one.
#[inline]
fn log_add_exp(a: f64, b: f64) -> f64 {
    if a == f64::NEG_INFINITY {
        return b;
    }
    if b == f64::NEG_INFINITY {
        return a;
    }
    // Without this guard `lo - hi` would be `inf - inf = NaN`; the log of a
    // sum of two infinite terms is itself infinite.
    if a == f64::INFINITY && b == f64::INFINITY {
        return f64::INFINITY;
    }
    let (hi, lo) = if a > b { (a, b) } else { (b, a) };
    hi + (lo - hi).exp().ln_1p()
}
/// Checks that `log_probs` holds exactly `t_len` rows of `n_symbols` values
/// and that `blank` is a valid symbol index.
///
/// # Errors
///
/// - [`SeqError::EmptyInput`] if `t_len` or `n_symbols` is zero.
/// - [`SeqError::ShapeMismatch`] if `log_probs.len()` differs from
///   `t_len * n_symbols`.
/// - [`SeqError::IndexOutOfBounds`] if `blank >= n_symbols`.
fn validate(log_probs: &[f64], t_len: usize, n_symbols: usize, blank: usize) -> SeqResult<()> {
    if t_len == 0 || n_symbols == 0 {
        return Err(SeqError::EmptyInput);
    }
    // A plain `*` panics on overflow in debug builds and wraps in release,
    // where a wrapped product could spuriously match a short buffer. A
    // saturated `usize::MAX` can never equal the length of a real `f64`
    // slice, so oversized shapes are reported as a mismatch instead.
    let expected = t_len.saturating_mul(n_symbols);
    if log_probs.len() != expected {
        return Err(SeqError::ShapeMismatch {
            expected,
            got: log_probs.len(),
        });
    }
    if blank >= n_symbols {
        return Err(SeqError::IndexOutOfBounds {
            index: blank,
            len: n_symbols,
        });
    }
    Ok(())
}
/// Greedy (best-path) CTC decoding.
///
/// Picks the highest-scoring symbol at each of the `t_len` time steps (the
/// lowest index wins ties), then merges consecutive repeats and drops `blank`.
/// A repeat separated by a blank is kept as two labels. `log_probs` is a flat
/// buffer of `t_len` rows, each holding `n_symbols` log-probabilities.
///
/// # Errors
///
/// Returns the errors of `validate` ([`SeqError::EmptyInput`],
/// [`SeqError::ShapeMismatch`], [`SeqError::IndexOutOfBounds`]), and
/// [`SeqError::NumericalInstability`] if any log-probability is NaN.
pub fn ctc_greedy_decode(
    log_probs: &[f64],
    t_len: usize,
    n_symbols: usize,
    blank: usize,
) -> SeqResult<Vec<usize>> {
    validate(log_probs, t_len, n_symbols, blank)?;
    let mut out = Vec::new();
    // Best-path symbol of the previous step, blank included, so that a blank
    // between two equal symbols keeps them as separate labels.
    let mut prev: Option<usize> = None;
    // `validate` guarantees `n_symbols > 0` and exactly `t_len` full rows.
    for row in log_probs.chunks_exact(n_symbols) {
        let mut best = 0usize;
        let mut best_val = row[0];
        for (c, &v) in row.iter().enumerate() {
            if v.is_nan() {
                return Err(SeqError::NumericalInstability(
                    "NaN in CTC log-probs".into(),
                ));
            }
            if v > best_val {
                best_val = v;
                best = c;
            }
        }
        if prev != Some(best) && best != blank {
            out.push(best);
        }
        prev = Some(best);
    }
    Ok(out)
}
/// Log-space probability mass of one CTC prefix, split by the symbol that
/// the contributing paths end on.
#[derive(Clone, Copy)]
struct PrefixProb {
    /// Log-probability of paths that produce the prefix and end in blank.
    p_blank: f64,
    /// Log-probability of paths that produce the prefix and end in a label.
    p_non_blank: f64,
}

impl PrefixProb {
    /// Combined log-probability of the prefix, whatever its final symbol.
    #[inline]
    fn total(&self) -> f64 {
        let Self {
            p_blank,
            p_non_blank,
        } = *self;
        log_add_exp(p_blank, p_non_blank)
    }
}
/// One decoded hypothesis produced by [`ctc_prefix_beam_search`].
#[derive(Debug, Clone, PartialEq)]
pub struct CtcHypothesis {
    /// Collapsed label sequence: symbol indices with blanks removed.
    pub labels: Vec<usize>,
    /// Natural-log probability of `labels`, summed over the alignments that
    /// survived beam pruning.
    pub log_prob: f64,
}
/// Decodes CTC output with prefix beam search.
///
/// `log_probs` is a flat buffer of `t_len` rows, each holding the `n_symbols`
/// per-symbol log-probabilities of one time step. `blank` is the index of the
/// CTC blank symbol. At most `beam_width` prefixes survive each step. The
/// surviving prefixes are returned best-first as [`CtcHypothesis`] values,
/// each scored by its total log-probability over all retained alignments.
/// Equal scores are ordered by label sequence, so the result is
/// deterministic.
///
/// # Errors
///
/// - The errors of `validate`: [`SeqError::EmptyInput`],
///   [`SeqError::ShapeMismatch`], [`SeqError::IndexOutOfBounds`].
/// - [`SeqError::InvalidParameter`] if `beam_width` is zero.
/// - [`SeqError::NumericalInstability`] if any log-probability is NaN.
pub fn ctc_prefix_beam_search(
    log_probs: &[f64],
    t_len: usize,
    n_symbols: usize,
    blank: usize,
    beam_width: usize,
) -> SeqResult<Vec<CtcHypothesis>> {
    // A prefix no path has reached yet: probability zero on both branches.
    const UNREACHED: PrefixProb = PrefixProb {
        p_blank: f64::NEG_INFINITY,
        p_non_blank: f64::NEG_INFINITY,
    };

    validate(log_probs, t_len, n_symbols, blank)?;
    if beam_width == 0 {
        return Err(SeqError::InvalidParameter {
            name: "beam_width".into(),
            value: 0.0,
        });
    }
    if log_probs.iter().any(|v| v.is_nan()) {
        return Err(SeqError::NumericalInstability(
            "NaN in CTC log-probs".into(),
        ));
    }

    // The beam is a ranked Vec rather than a HashMap, so the order in which
    // contributions are accumulated and the pruning order are both
    // reproducible from run to run.
    let mut beam: Vec<(Vec<usize>, PrefixProb)> = vec![(
        Vec::new(),
        PrefixProb {
            p_blank: 0.0,
            p_non_blank: f64::NEG_INFINITY,
        },
    )];

    // `validate` guarantees `n_symbols > 0` and exactly `t_len` full rows.
    for row in log_probs.chunks_exact(n_symbols) {
        let lp_blank = row[blank];
        // Upper bound: each beam entry yields itself plus one extension per label.
        let mut next: HashMap<Vec<usize>, PrefixProb> =
            HashMap::with_capacity(beam.len() * n_symbols);

        for (prefix, prob) in &beam {
            let total = prob.total();
            let last = prefix.last().copied();

            // Emitting a blank leaves the prefix unchanged and ends the path in blank.
            let same = next.entry(prefix.clone()).or_insert(UNREACHED);
            same.p_blank = log_add_exp(same.p_blank, total + lp_blank);

            for (c, &lp_c) in row.iter().enumerate() {
                if c == blank {
                    continue;
                }
                let mut extended = prefix.clone();
                extended.push(c);
                let ext = next.entry(extended).or_insert(UNREACHED);
                if last == Some(c) {
                    // Repeating the final label starts a new label only when a
                    // blank separated the two emissions...
                    ext.p_non_blank = log_add_exp(ext.p_non_blank, prob.p_blank + lp_c);
                    // ...otherwise the repeat collapses into the existing label.
                    let same = next.entry(prefix.clone()).or_insert(UNREACHED);
                    same.p_non_blank = log_add_exp(same.p_non_blank, prob.p_non_blank + lp_c);
                } else {
                    ext.p_non_blank = log_add_exp(ext.p_non_blank, total + lp_c);
                }
            }
        }

        let mut ranked: Vec<(Vec<usize>, PrefixProb)> = next.into_iter().collect();
        // `total_cmp` is a genuine total order, unlike `partial_cmp(..).unwrap_or(Equal)`
        // once a NaN score appears. Labels are unique map keys, so the
        // tie-break makes the order fully independent of HashMap iteration.
        ranked.sort_unstable_by(|(la, pa), (lb, pb)| {
            pb.total().total_cmp(&pa.total()).then_with(|| la.cmp(lb))
        });
        ranked.truncate(beam_width);
        beam = ranked;
    }

    // `t_len >= 1`, so at least one pruning pass ran and `beam` is already ranked best-first.
    Ok(beam
        .into_iter()
        .map(|(labels, prob)| CtcHypothesis {
            labels,
            log_prob: prob.total(),
        })
        .collect())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts probabilities to natural-log space, flooring zeros at `1e-30`
    /// so `ln` never yields `-inf`.
    fn to_log(probs: &[f64]) -> Vec<f64> {
        probs.iter().map(|&p| p.max(1e-30).ln()).collect()
    }

    // Matrices below are written one time step per line; the columns are the
    // symbol indices 0, 1, 2.

    #[test]
    fn greedy_collapses_repeats_and_blanks() {
        let probs = vec![
            0.1, 0.8, 0.1, // t0
            0.1, 0.8, 0.1, // t1
            0.8, 0.1, 0.1, // t2
            0.1, 0.1, 0.8, // t3
        ];
        let lp = to_log(&probs);
        let out = ctc_greedy_decode(&lp, 4, 3, 0).expect("decode");
        assert_eq!(out, vec![1, 2]);
    }

    #[test]
    fn greedy_all_blank_is_empty() {
        let probs = vec![
            0.9, 0.05, 0.05, // t0
            0.9, 0.05, 0.05, // t1
        ];
        let lp = to_log(&probs);
        let out = ctc_greedy_decode(&lp, 2, 3, 0).expect("decode");
        assert!(out.is_empty());
    }

    #[test]
    fn greedy_repeat_without_blank_merges() {
        let probs = vec![
            0.1, 0.8, 0.1, // t0
            0.1, 0.8, 0.1, // t1
        ];
        let lp = to_log(&probs);
        let out = ctc_greedy_decode(&lp, 2, 3, 0).expect("decode");
        assert_eq!(out, vec![1]);
    }

    #[test]
    fn greedy_blank_at_last_index() {
        let probs = vec![
            0.8, 0.1, 0.1, // t0
            0.1, 0.8, 0.1, // t1
            0.1, 0.1, 0.8, // t2
        ];
        let lp = to_log(&probs);
        let out = ctc_greedy_decode(&lp, 3, 3, 2).expect("decode");
        assert_eq!(out, vec![0, 1]);
    }

    #[test]
    fn beam_returns_sorted_hypotheses() {
        let probs = vec![
            0.2, 0.5, 0.3, // t0
            0.1, 0.6, 0.3, // t1
            0.3, 0.2, 0.5, // t2
            0.4, 0.3, 0.3, // t3
        ];
        let lp = to_log(&probs);
        let hyps = ctc_prefix_beam_search(&lp, 4, 3, 0, 8).expect("beam");
        assert!(!hyps.is_empty());
        for w in hyps.windows(2) {
            assert!(w[0].log_prob >= w[1].log_prob - 1e-12);
        }
    }

    #[test]
    fn beam_top1_matches_greedy_for_peaked_input() {
        let probs = vec![
            0.02, 0.96, 0.02, // t0
            0.96, 0.02, 0.02, // t1
            0.02, 0.02, 0.96, // t2
        ];
        let lp = to_log(&probs);
        let greedy = ctc_greedy_decode(&lp, 3, 3, 0).expect("greedy");
        let beam = ctc_prefix_beam_search(&lp, 3, 3, 0, 16).expect("beam");
        assert_eq!(beam[0].labels, greedy);
    }

    #[test]
    fn beam_total_probability_consistent_with_loss() {
        use crate::ctc::ctc_loss::ctc_loss;
        let probs = vec![
            0.02, 0.96, 0.02, // t0
            0.96, 0.02, 0.02, // t1
            0.02, 0.02, 0.96, // t2
        ];
        let lp = to_log(&probs);
        let beam = ctc_prefix_beam_search(&lp, 3, 3, 0, 32).expect("beam");
        let best = &beam[0];
        let loss = ctc_loss(&lp, 3, 3, &best.labels, 0).expect("loss");
        assert!(
            (best.log_prob - (-loss)).abs() < 0.2,
            "beam={} loss={loss}",
            best.log_prob
        );
    }

    #[test]
    fn beam_width_one_is_valid() {
        let probs = vec![
            0.2, 0.5, 0.3, // t0
            0.4, 0.3, 0.3, // t1
        ];
        let lp = to_log(&probs);
        let hyps = ctc_prefix_beam_search(&lp, 2, 3, 0, 1).expect("beam");
        assert_eq!(hyps.len(), 1);
    }

    #[test]
    fn beam_recovers_empty_when_blank_dominates() {
        let probs = vec![
            0.9, 0.05, 0.05, // t0
            0.9, 0.05, 0.05, // t1
        ];
        let lp = to_log(&probs);
        let hyps = ctc_prefix_beam_search(&lp, 2, 3, 0, 8).expect("beam");
        assert!(hyps[0].labels.is_empty());
    }

    #[test]
    fn beam_respects_width_and_never_emits_blank() {
        let probs = vec![
            0.2, 0.5, 0.3, // t0
            0.1, 0.6, 0.3, // t1
            0.3, 0.2, 0.5, // t2
            0.4, 0.3, 0.3, // t3
        ];
        let lp = to_log(&probs);
        let hyps = ctc_prefix_beam_search(&lp, 4, 3, 0, 5).expect("beam");
        assert!(hyps.len() <= 5);
        for h in &hyps {
            assert!(!h.labels.contains(&0), "blank leaked into {:?}", h.labels);
        }
    }

    #[test]
    fn beam_hypotheses_have_distinct_labels() {
        let probs = vec![
            0.2, 0.5, 0.3, // t0
            0.1, 0.6, 0.3, // t1
            0.3, 0.2, 0.5, // t2
            0.4, 0.3, 0.3, // t3
        ];
        let lp = to_log(&probs);
        let hyps = ctc_prefix_beam_search(&lp, 4, 3, 0, 16).expect("beam");
        for (i, a) in hyps.iter().enumerate() {
            for b in &hyps[i + 1..] {
                assert_ne!(a.labels, b.labels);
            }
        }
    }

    #[test]
    fn beam_without_pruning_conserves_probability_mass() {
        // Each row sums to 1. With 2 labels and 4 steps there are at most
        // 1 + 2 + 4 + 8 + 16 = 31 prefixes, so width 64 never prunes and the
        // hypotheses must partition all alignments.
        let probs = vec![
            0.2, 0.5, 0.3, // t0
            0.1, 0.6, 0.3, // t1
            0.3, 0.2, 0.5, // t2
            0.4, 0.3, 0.3, // t3
        ];
        let lp = to_log(&probs);
        let hyps = ctc_prefix_beam_search(&lp, 4, 3, 0, 64).expect("beam");
        let mass: f64 = hyps.iter().map(|h| h.log_prob.exp()).sum();
        assert!((mass - 1.0).abs() < 1e-9, "mass={mass}");
    }

    #[test]
    fn log_add_exp_matches_linear_sum() {
        let got = log_add_exp(0.3f64.ln(), 0.2f64.ln());
        assert!((got - 0.5f64.ln()).abs() < 1e-12);
        assert_eq!(log_add_exp(f64::NEG_INFINITY, -1.5), -1.5);
        assert_eq!(log_add_exp(-1.5, f64::NEG_INFINITY), -1.5);
        assert_eq!(
            log_add_exp(f64::NEG_INFINITY, f64::NEG_INFINITY),
            f64::NEG_INFINITY
        );
    }

    #[test]
    fn empty_input_errors() {
        assert!(ctc_greedy_decode(&[], 0, 3, 0).is_err());
        assert!(ctc_greedy_decode(&[], 2, 0, 0).is_err());
        assert!(ctc_prefix_beam_search(&[], 0, 3, 0, 4).is_err());
    }

    #[test]
    fn greedy_shape_mismatch_errors() {
        let lp = vec![0.0; 5];
        assert!(ctc_greedy_decode(&lp, 2, 3, 0).is_err());
    }

    #[test]
    fn beam_shape_mismatch_errors() {
        let lp = vec![0.0; 5];
        assert!(ctc_prefix_beam_search(&lp, 2, 3, 0, 4).is_err());
    }

    #[test]
    fn greedy_blank_out_of_range_errors() {
        let lp = to_log(&[0.5, 0.5, 0.5, 0.5]);
        assert!(ctc_greedy_decode(&lp, 2, 2, 9).is_err());
    }

    #[test]
    fn beam_blank_out_of_range_errors() {
        let lp = to_log(&[0.5, 0.5, 0.5, 0.5]);
        assert!(ctc_prefix_beam_search(&lp, 2, 2, 9, 4).is_err());
    }

    #[test]
    fn beam_zero_width_errors() {
        let lp = to_log(&[0.5, 0.5, 0.5, 0.5]);
        assert!(ctc_prefix_beam_search(&lp, 2, 2, 0, 0).is_err());
    }

    #[test]
    fn beam_nan_errors() {
        let lp = vec![f64::NAN, 0.0, 0.0, 0.0];
        assert!(ctc_prefix_beam_search(&lp, 2, 2, 0, 4).is_err());
    }

    #[test]
    fn greedy_nan_errors() {
        let lp = vec![0.0, f64::NAN, 0.0, 0.0, 0.0, 0.0];
        assert!(ctc_greedy_decode(&lp, 2, 3, 0).is_err());
    }
}