1use std::cmp;
10use std::collections::BinaryHeap;
11
12use crate::beam_search::BeamSearchEngine;
13
/// A candidate token sequence together with its accumulated score.
///
/// The comparison traits implemented for this type (`PartialEq`, `Eq`,
/// `PartialOrd`, `Ord`) look only at `normalized_score`; two hypotheses with
/// different tokens compare equal when their normalized scores are identical.
#[derive(Debug, Clone)]
pub struct Hypothesis {
    /// Token ids that make up this sequence.
    pub tokens: Vec<u32>,
    /// Accumulated log-probability: the value given to [`Hypothesis::new`]
    /// plus every `token_log_prob` added by [`Hypothesis::extend`].
    pub log_prob: f64,
    /// `log_prob` divided by the token count, as computed by the constructors
    /// (an empty sequence divides by 1).
    pub normalized_score: f64,
    /// `false` on construction; set to `true` by [`Hypothesis::complete`].
    pub is_complete: bool,
}

impl Hypothesis {
    /// Builds an incomplete hypothesis from `tokens` and their total `log_prob`.
    ///
    /// The normalized score is `log_prob / max(tokens.len(), 1)`.
    pub fn new(tokens: Vec<u32>, log_prob: f64) -> Self {
        // `max(1)` keeps the empty sequence from dividing by zero.
        let divisor = tokens.len().max(1) as f64;
        Self {
            normalized_score: log_prob / divisor,
            tokens,
            log_prob,
            is_complete: false,
        }
    }

    /// Returns the score used for ranking, i.e. `normalized_score`.
    pub fn score(&self) -> f64 {
        self.normalized_score
    }

    /// Returns a new hypothesis with `token` appended and `token_log_prob`
    /// added to the accumulated log-probability.
    ///
    /// `self` is left untouched. The result is always incomplete, even when
    /// `self` is complete.
    pub fn extend(&self, token: u32, token_log_prob: f32) -> Self {
        // Size the child exactly once: cloning and then pushing would allocate
        // at the parent's length and immediately reallocate for the new token.
        let mut tokens = Vec::with_capacity(self.tokens.len() + 1);
        tokens.extend_from_slice(&self.tokens);
        tokens.push(token);
        // Delegating to `new` keeps a single definition of the normalization.
        Self::new(tokens, self.log_prob + f64::from(token_log_prob))
    }

    /// Marks the hypothesis as complete and returns it.
    ///
    /// `_eos_id` is accepted but not used; no token is appended.
    pub fn complete(mut self, _eos_id: u32) -> Self {
        self.is_complete = true;
        self
    }

    /// Number of tokens in the sequence.
    pub fn len(&self) -> usize {
        self.tokens.len()
    }

    /// `true` when the sequence holds no tokens.
    pub fn is_empty(&self) -> bool {
        self.tokens.is_empty()
    }
}

/// Equality is defined through [`Ord::cmp`], so it compares `normalized_score`
/// only (using `f64::total_cmp`) and stays consistent with the ordering.
impl PartialEq for Hypothesis {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == cmp::Ordering::Equal
    }
}

impl Eq for Hypothesis {}

impl PartialOrd for Hypothesis {
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
        Some(self.cmp(other))
    }
}

/// Total order over `normalized_score` via `f64::total_cmp`, which gives
/// `BinaryHeap` the `Ord` bound it needs despite the score being a float.
impl Ord for Hypothesis {
    fn cmp(&self, other: &Self) -> cmp::Ordering {
        self.normalized_score.total_cmp(&other.normalized_score)
    }
}
101
102pub struct NBestList {
109 capacity: usize,
110 hypotheses: BinaryHeap<cmp::Reverse<Hypothesis>>,
112}
113
114impl NBestList {
115 pub fn new(n: usize) -> Self {
117 Self {
118 capacity: n,
119 hypotheses: BinaryHeap::with_capacity(n + 1),
120 }
121 }
122
123 pub fn push(&mut self, hyp: Hypothesis) {
128 if self.capacity == 0 {
129 return;
130 }
131 if self.hypotheses.len() < self.capacity {
132 self.hypotheses.push(cmp::Reverse(hyp));
133 } else {
134 let should_insert = self
136 .hypotheses
137 .peek()
138 .map(|cmp::Reverse(worst)| hyp.score() > worst.score())
139 .unwrap_or(true);
140
141 if should_insert {
142 self.hypotheses.pop();
143 self.hypotheses.push(cmp::Reverse(hyp));
144 }
145 }
146 }
147
148 pub fn top(&self) -> Option<&Hypothesis> {
150 self.hypotheses
152 .iter()
153 .map(|cmp::Reverse(h)| h)
154 .max_by(|a, b| a.score().total_cmp(&b.score()))
155 }
156
157 pub fn len(&self) -> usize {
159 self.hypotheses.len()
160 }
161
162 pub fn is_empty(&self) -> bool {
164 self.hypotheses.is_empty()
165 }
166
167 pub fn is_full(&self) -> bool {
169 self.hypotheses.len() >= self.capacity
170 }
171
172 pub fn worst_score(&self) -> Option<f64> {
174 self.hypotheses.peek().map(|cmp::Reverse(h)| h.score())
175 }
176
177 pub fn into_sorted(self) -> Vec<Hypothesis> {
179 let mut v: Vec<Hypothesis> = self
180 .hypotheses
181 .into_iter()
182 .map(|cmp::Reverse(h)| h)
183 .collect();
184 v.sort_by(|a, b| b.score().total_cmp(&a.score()));
185 v
186 }
187
188 pub fn complete_hypotheses(&self) -> Vec<&Hypothesis> {
190 self.hypotheses
191 .iter()
192 .map(|cmp::Reverse(h)| h)
193 .filter(|h| h.is_complete)
194 .collect()
195 }
196}
197
198pub struct NBestDecoder {
202 pub n: usize,
204 pub eos_id: u32,
206 pub max_len: usize,
208 pub length_penalty: f32,
210}
211
212impl NBestDecoder {
213 pub fn new(n: usize, eos_id: u32, max_len: usize) -> Self {
215 Self {
216 n,
217 eos_id,
218 max_len,
219 length_penalty: 1.0,
220 }
221 }
222
223 pub fn with_length_penalty(mut self, alpha: f32) -> Self {
225 self.length_penalty = alpha;
226 self
227 }
228
229 pub fn step(
234 &self,
235 hypotheses: &[Hypothesis],
236 logits_per_hyp: &[Vec<f32>],
237 top_k: usize,
238 ) -> Vec<Hypothesis> {
239 let effective_k = top_k.max(1);
240 let mut expanded: Vec<Hypothesis> = Vec::new();
241
242 for (hyp, logits) in hypotheses.iter().zip(logits_per_hyp.iter()) {
243 if hyp.is_complete {
244 expanded.push(hyp.clone());
245 continue;
246 }
247
248 let top = BeamSearchEngine::top_k_log_probs(logits, effective_k);
249
250 for (token, log_prob) in top {
251 let new_hyp = hyp.extend(token, log_prob as f32);
252 let new_hyp = if token == self.eos_id {
253 new_hyp.complete(self.eos_id)
254 } else {
255 new_hyp
256 };
257 expanded.push(new_hyp);
258 }
259 }
260
261 expanded
262 }
263
264 pub fn init(&self) -> NBestList {
266 NBestList::new(self.n)
267 }
268
269 pub fn partition(hyps: Vec<Hypothesis>) -> (Vec<Hypothesis>, Vec<Hypothesis>) {
271 let mut active = Vec::new();
272 let mut complete = Vec::new();
273 for h in hyps {
274 if h.is_complete {
275 complete.push(h);
276 } else {
277 active.push(h);
278 }
279 }
280 (active, complete)
281 }
282}
283
/// Unit tests for `Hypothesis`, `NBestList` and `NBestDecoder`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing floating-point scores.
    const TOL: f64 = 1e-9;

    #[test]
    fn hypothesis_new() {
        let h = Hypothesis::new(vec![1, 2, 3], -3.0);
        assert_eq!(h.tokens, vec![1, 2, 3]);
        assert!((h.log_prob - -3.0).abs() < f64::EPSILON);
        // -3.0 over three tokens normalizes to -1.0 per token.
        assert!((h.score() - -1.0).abs() < TOL);
        assert!(!h.is_complete);
    }

    #[test]
    fn hypothesis_extend() {
        let h = Hypothesis::new(vec![1, 2], -2.0);
        let h2 = h.extend(3, -1.0);
        assert_eq!(h2.tokens, vec![1, 2, 3]);
        assert!((h2.log_prob - -3.0).abs() < 1e-6);
        assert!((h2.score() - -1.0).abs() < 1e-6);
        // The parent must be left untouched.
        assert_eq!(h.tokens, vec![1, 2]);
    }

    #[test]
    fn hypothesis_complete() {
        let h = Hypothesis::new(vec![1, 2], -2.0);
        let h = h.complete(2);
        assert!(h.is_complete);
        // Completing does not append a token.
        assert_eq!(h.tokens, vec![1, 2]);
    }

    #[test]
    fn hypothesis_score_normalized() {
        let short = Hypothesis::new(vec![1], -1.0);
        let long = Hypothesis::new(vec![1, 2, 3, 4], -4.0);
        let long_bad = Hypothesis::new(vec![1, 2, 3, 4, 5], -10.0);
        // Same per-token log-probability => same score regardless of length.
        assert!((long.score() - short.score()).abs() < TOL);
        assert!(long_bad.score() < short.score());
    }

    #[test]
    fn nbest_list_new() {
        let list = NBestList::new(5);
        assert_eq!(list.len(), 0);
        assert!(list.is_empty());
        assert!(!list.is_full());
    }

    #[test]
    fn nbest_list_push_under_capacity() {
        let mut list = NBestList::new(5);
        for i in 0..3u32 {
            list.push(Hypothesis::new(vec![i], -(i as f64)));
        }
        assert_eq!(list.len(), 3);
        assert!(!list.is_full());
    }

    #[test]
    fn nbest_list_push_over_capacity() {
        let mut list = NBestList::new(3);
        for i in 0..5u32 {
            list.push(Hypothesis::new(vec![i], -(i as f64)));
        }
        assert_eq!(list.len(), 3);
        assert!(list.is_full());
        let sorted = list.into_sorted();
        assert_eq!(sorted.len(), 3);
        // Only the three best (scores 0, -1, -2) survive, best first.
        assert_eq!(sorted[0].tokens, vec![0]);
        assert_eq!(sorted[1].tokens, vec![1]);
        assert_eq!(sorted[2].tokens, vec![2]);
    }

    #[test]
    fn nbest_list_worst_score() {
        let mut list = NBestList::new(3);
        list.push(Hypothesis::new(vec![1], -1.0));
        list.push(Hypothesis::new(vec![2], -2.0));
        list.push(Hypothesis::new(vec![3], -3.0));
        let worst = list.worst_score().expect("should have worst score");
        assert!((worst - -3.0).abs() < TOL);
    }

    #[test]
    fn nbest_list_into_sorted_order() {
        let mut list = NBestList::new(5);
        list.push(Hypothesis::new(vec![1], -3.0));
        list.push(Hypothesis::new(vec![2], -1.0));
        list.push(Hypothesis::new(vec![3], -2.0));
        let sorted = list.into_sorted();
        assert_eq!(sorted.len(), 3);
        assert!((sorted[0].log_prob - -1.0).abs() < TOL);
        assert!((sorted[1].log_prob - -2.0).abs() < TOL);
        assert!((sorted[2].log_prob - -3.0).abs() < TOL);
    }

    #[test]
    fn nbest_list_complete_hypotheses() {
        let mut list = NBestList::new(5);
        list.push(Hypothesis::new(vec![1], -1.0).complete(2));
        list.push(Hypothesis::new(vec![3], -2.0));
        let complete = list.complete_hypotheses();
        assert_eq!(complete.len(), 1);
        assert!(complete[0].is_complete);
        assert_eq!(complete[0].tokens, vec![1]);
    }

    #[test]
    fn nbest_decoder_step_expands() {
        let decoder = NBestDecoder::new(5, 99, 20);
        let hyps = vec![Hypothesis::new(vec![1], -0.5)];
        let logits = vec![vec![0.0f32; 10]];
        let expanded = decoder.step(&hyps, &logits, 3);
        assert!(expanded.len() >= 3);
    }

    #[test]
    fn nbest_decoder_partition() {
        let active_h = Hypothesis::new(vec![1], -1.0);
        let complete_h = Hypothesis::new(vec![2], -2.0).complete(2);
        let (active, complete) = NBestDecoder::partition(vec![active_h, complete_h]);
        assert_eq!(active.len(), 1);
        assert_eq!(complete.len(), 1);
        assert!(!active[0].is_complete);
        assert!(complete[0].is_complete);
    }

    #[test]
    fn nbest_decoder_eos_completes() {
        let eos = 2u32;
        let decoder = NBestDecoder::new(5, eos, 20);
        let hyps = vec![Hypothesis::new(vec![1], -0.5)];
        let mut logits = vec![f32::NEG_INFINITY; 5];
        // Make EOS the only plausible candidate.
        logits[eos as usize] = 10.0;
        let expanded = decoder.step(&hyps, &[logits], 1);
        assert!(!expanded.is_empty());
        assert!(expanded[0].is_complete);
    }

    #[test]
    fn nbest_decoder_length_penalty() {
        // The decoder starts at 1.0 and the builder overrides it.
        let decoder = NBestDecoder::new(5, 2, 20);
        assert!((decoder.length_penalty - 1.0).abs() < f32::EPSILON);
        let decoder = decoder.with_length_penalty(0.6);
        assert!((decoder.length_penalty - 0.6).abs() < f32::EPSILON);

        // Per-token normalization: -1.0 per token beats -1.5 per token.
        let h_short = Hypothesis::new(vec![1], -1.0);
        let h_long = Hypothesis::new(vec![1, 2, 3, 4], -6.0);
        assert!(h_short.score() > h_long.score());
    }
}