1use crate::config::DecodeMethod;
19use rten_tensor::NdTensorView;
20use rten_tensor::prelude::*;
21use std::collections::HashMap;
22use std::num::NonZeroU32;
23
/// A single emitted label and the input timestep at which it was emitted.
///
/// The decoders in this module only record non-blank labels here (label 0 is
/// never pushed as a step).
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct CtcStep {
    /// Index of the emitted label along the label (second) axis of the input.
    pub label: u32,
    /// Index of the timestep (first axis of the input) where the label was emitted.
    pub pos: u32,
}
30
31#[derive(Clone, Debug)]
33pub struct CtcHypothesis {
34 steps: Vec<CtcStep>,
35 score: f32,
36}
37
38impl CtcHypothesis {
39 pub fn steps(&self) -> &[CtcStep] {
40 &self.steps
41 }
42
43 pub fn score(&self) -> f32 {
44 self.score
45 }
46}
47
48pub fn decode(input_seq: NdTensorView<f32, 2>, method: DecodeMethod) -> CtcHypothesis {
50 match method {
51 DecodeMethod::Greedy => decode_greedy(input_seq),
52 DecodeMethod::BeamSearch { width } => {
53 decode_beam(input_seq, width.max(1)).unwrap_or_else(|| decode_greedy(input_seq))
54 }
55 }
56}
57
58fn decode_greedy(prob_seq: NdTensorView<f32, 2>) -> CtcHypothesis {
59 let mut last_label = 0;
60 let mut steps = Vec::new();
61 let mut score = 0.;
62
63 for pos in 0..prob_seq.size(0) {
64 let mut best_label = 0usize;
65 let mut best_lp = prob_seq[[pos, 0]];
66 for label in 1..prob_seq.size(1) {
67 let lp = prob_seq[[pos, label]];
68 if lp > best_lp {
69 best_lp = lp;
70 best_label = label;
71 }
72 }
73 let label = best_label;
74 score += best_lp;
75 if label == last_label {
76 continue;
77 }
78 last_label = label;
79 if label > 0 {
80 steps.push(CtcStep {
81 label: label as u32,
82 pos: pos as u32,
83 });
84 }
85 }
86
87 CtcHypothesis { steps, score }
88}
89
/// Scores kept for one prefix during beam search.
///
/// Both values are combined with `log_sum_exp`, so they are presumably
/// log-probabilities (TODO: confirm against the model output).
#[derive(Debug)]
struct BeamProbs {
    /// Score of paths for this prefix whose most recent symbol was a blank.
    prob_blank: f32,
    /// Score of paths for this prefix whose most recent symbol was its last label.
    prob_no_blank: f32,
}
95
/// Compute `ln(sum(exp(x)))` over `log_probs` in a numerically stable way.
///
/// The maximum is factored out before exponentiating to avoid overflow. If
/// every input is negative infinity (or the array is empty), the result is
/// negative infinity rather than NaN.
fn log_sum_exp<const N: usize>(log_probs: [f32; N]) -> f32 {
    let everything_impossible = log_probs.iter().all(|&lp| lp == f32::NEG_INFINITY);
    if everything_impossible {
        return f32::NEG_INFINITY;
    }

    let max_lp = log_probs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let shifted_sum: f32 = log_probs.iter().map(|&lp| (lp - max_lp).exp()).sum();
    max_lp + shifted_sum.ln()
}
112
113fn decode_beam(prob_seq: NdTensorView<f32, 2>, beam_size: u32) -> Option<CtcHypothesis> {
114 let beam_size = NonZeroU32::new(beam_size)?;
115 let mut states: HashMap<Vec<CtcStep>, BeamProbs> = HashMap::new();
116 states.insert(
117 Vec::new(),
118 BeamProbs {
119 prob_blank: 0.,
120 prob_no_blank: f32::NEG_INFINITY,
121 },
122 );
123
124 for t in 0..prob_seq.size(0) {
125 let mut next: HashMap<Vec<CtcStep>, BeamProbs> = HashMap::new();
126 let blank_lp = prob_seq[[t, 0]];
127 for (prefix, state) in &states {
128 let p_b = state.prob_blank;
129 let p_nb = state.prob_no_blank;
130 merge_beam(
131 &mut next,
132 prefix.clone(),
133 log_sum_exp([p_b + blank_lp, p_nb + blank_lp]),
134 f32::NEG_INFINITY,
135 );
136 for label in 1..prob_seq.size(1) {
137 let lp = prob_seq[[t, label]];
138 let mut new_prefix = prefix.clone();
139 let step = CtcStep {
140 label: label as u32,
141 pos: t as u32,
142 };
143 let last = new_prefix.last().map(|s| s.label);
144 if last != Some(step.label) {
145 new_prefix.push(step);
146 }
147 let (nb_lp, b_lp) = if last == Some(step.label) {
148 (p_nb + lp, p_b + lp)
149 } else {
150 (log_sum_exp([p_b + lp, p_nb + lp]), f32::NEG_INFINITY)
151 };
152 merge_beam(&mut next, new_prefix, b_lp, nb_lp);
153 }
154 }
155 let mut ranked: Vec<_> = next.into_iter().collect();
156 ranked.sort_by(|(_, a), (_, b)| {
157 let sa = log_sum_exp([a.prob_blank, a.prob_no_blank]);
158 let sb = log_sum_exp([b.prob_blank, b.prob_no_blank]);
159 sb.partial_cmp(&sa).unwrap_or(std::cmp::Ordering::Equal)
160 });
161 ranked.truncate(beam_size.get() as usize);
162 states = ranked.into_iter().collect();
163 }
164
165 let (prefix, probs) = states.into_iter().max_by(|(_, a), (_, b)| {
166 let sa = log_sum_exp([a.prob_blank, a.prob_no_blank]);
167 let sb = log_sum_exp([b.prob_blank, b.prob_no_blank]);
168 sa.partial_cmp(&sb).unwrap_or(std::cmp::Ordering::Equal)
169 })?;
170 Some(CtcHypothesis {
171 steps: prefix,
172 score: log_sum_exp([probs.prob_blank, probs.prob_no_blank]),
173 })
174}
175
176fn merge_beam(map: &mut HashMap<Vec<CtcStep>, BeamProbs>, prefix: Vec<CtcStep>, pb: f32, pnb: f32) {
177 use std::collections::hash_map::Entry;
178 match map.entry(prefix) {
179 Entry::Vacant(e) => {
180 e.insert(BeamProbs {
181 prob_blank: pb,
182 prob_no_blank: pnb,
183 });
184 }
185 Entry::Occupied(mut e) => {
186 let s = e.get_mut();
187 s.prob_blank = log_sum_exp([s.prob_blank, pb]);
188 s.prob_no_blank = log_sum_exp([s.prob_no_blank, pnb]);
189 }
190 }
191}