// oxibonsai_runtime/nbest.rs
1//! N-best hypothesis tracking for diverse generation.
2//!
3//! Maintains a heap of the N best partial sequences seen during decoding,
4//! scored by cumulative log-probability. Used for:
5//! - Diverse beam search
6//! - Post-generation reranking
7//! - Multi-hypothesis output
8
9use std::cmp;
10use std::collections::BinaryHeap;
11
12use crate::beam_search::BeamSearchEngine;
13
14// ─── Hypothesis ────────────────────────────────────────────────────────────────
15
/// A single hypothesis (partial or complete token sequence).
#[derive(Debug, Clone)]
pub struct Hypothesis {
    /// All token IDs in this hypothesis.
    pub tokens: Vec<u32>,
    /// Cumulative log probability of the sequence.
    pub log_prob: f64,
    /// Length-normalised score: `log_prob / tokens.len().max(1)`.
    pub normalized_score: f64,
    /// Whether this hypothesis ended with an EOS token.
    pub is_complete: bool,
}

impl Hypothesis {
    /// Create a new (incomplete) hypothesis with the given tokens and
    /// cumulative log probability.
    ///
    /// The normalised score is `log_prob / tokens.len()`, treating an empty
    /// sequence as length 1 to avoid division by zero.
    pub fn new(tokens: Vec<u32>, log_prob: f64) -> Self {
        let len = tokens.len().max(1) as f64;
        Self {
            normalized_score: log_prob / len,
            tokens,
            log_prob,
            is_complete: false,
        }
    }

    /// Return the length-normalised score used for ranking.
    pub fn score(&self) -> f64 {
        self.normalized_score
    }

    /// Extend this hypothesis with one more token, accumulating its log
    /// probability.
    ///
    /// Delegates to [`Hypothesis::new`] so the normalisation rule lives in
    /// exactly one place (previously duplicated here).
    pub fn extend(&self, token: u32, token_log_prob: f32) -> Self {
        let mut tokens = self.tokens.clone();
        tokens.push(token);
        // f64::from is a lossless widening of the f32 step probability.
        Self::new(tokens, self.log_prob + f64::from(token_log_prob))
    }

    /// Mark this hypothesis as complete (ended with EOS).
    ///
    /// The EOS id is currently unused; it is kept in the signature for
    /// call-site symmetry with [`Hypothesis::extend`].
    pub fn complete(mut self, _eos_id: u32) -> Self {
        self.is_complete = true;
        self
    }

    /// Number of tokens in this hypothesis.
    pub fn len(&self) -> usize {
        self.tokens.len()
    }

    /// Whether the hypothesis has no tokens.
    pub fn is_empty(&self) -> bool {
        self.tokens.is_empty()
    }
}

// Ordering is by `normalized_score` only (higher is better); token contents
// are deliberately ignored, so two hypotheses with equal scores compare
// `Equal` even when their sequences differ. `f64::total_cmp` provides the
// total order required by `Eq`/`Ord` even in the presence of NaN.
// BinaryHeap is a max-heap, so `Reverse<Hypothesis>` gives the min-heap
// used by `NBestList`.

impl PartialEq for Hypothesis {
    fn eq(&self, other: &Self) -> bool {
        self.normalized_score.total_cmp(&other.normalized_score) == cmp::Ordering::Equal
    }
}

impl Eq for Hypothesis {}

impl PartialOrd for Hypothesis {
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Hypothesis {
    fn cmp(&self, other: &Self) -> cmp::Ordering {
        self.normalized_score.total_cmp(&other.normalized_score)
    }
}
101
102// ─── NBestList ─────────────────────────────────────────────────────────────────
103
104/// A fixed-capacity heap of the N best hypotheses.
105///
106/// Internally uses a min-heap (via `Reverse`) so that the worst hypothesis is
107/// always at the top and can be evicted when capacity is exceeded.
108pub struct NBestList {
109    capacity: usize,
110    /// Min-heap: the root is the *worst* kept hypothesis.
111    hypotheses: BinaryHeap<cmp::Reverse<Hypothesis>>,
112}
113
114impl NBestList {
115    /// Create an empty N-best list with the given capacity.
116    pub fn new(n: usize) -> Self {
117        Self {
118            capacity: n,
119            hypotheses: BinaryHeap::with_capacity(n + 1),
120        }
121    }
122
123    /// Push a hypothesis into the list.
124    ///
125    /// If the list is already at capacity and the new hypothesis scores better
126    /// than the current worst, the worst is evicted.
127    pub fn push(&mut self, hyp: Hypothesis) {
128        if self.capacity == 0 {
129            return;
130        }
131        if self.hypotheses.len() < self.capacity {
132            self.hypotheses.push(cmp::Reverse(hyp));
133        } else {
134            // Only replace worst if the new hypothesis is strictly better.
135            let should_insert = self
136                .hypotheses
137                .peek()
138                .map(|cmp::Reverse(worst)| hyp.score() > worst.score())
139                .unwrap_or(true);
140
141            if should_insert {
142                self.hypotheses.pop();
143                self.hypotheses.push(cmp::Reverse(hyp));
144            }
145        }
146    }
147
148    /// Return a reference to the best hypothesis (highest score) in the list.
149    pub fn top(&self) -> Option<&Hypothesis> {
150        // The min-heap's root is the worst; we need to scan for best.
151        self.hypotheses
152            .iter()
153            .map(|cmp::Reverse(h)| h)
154            .max_by(|a, b| a.score().total_cmp(&b.score()))
155    }
156
157    /// Number of hypotheses currently held.
158    pub fn len(&self) -> usize {
159        self.hypotheses.len()
160    }
161
162    /// Whether the list has no hypotheses.
163    pub fn is_empty(&self) -> bool {
164        self.hypotheses.is_empty()
165    }
166
167    /// Whether the list has reached its capacity.
168    pub fn is_full(&self) -> bool {
169        self.hypotheses.len() >= self.capacity
170    }
171
172    /// Score of the worst hypothesis currently kept, or `None` if empty.
173    pub fn worst_score(&self) -> Option<f64> {
174        self.hypotheses.peek().map(|cmp::Reverse(h)| h.score())
175    }
176
177    /// Consume the list and return hypotheses sorted best-first.
178    pub fn into_sorted(self) -> Vec<Hypothesis> {
179        let mut v: Vec<Hypothesis> = self
180            .hypotheses
181            .into_iter()
182            .map(|cmp::Reverse(h)| h)
183            .collect();
184        v.sort_by(|a, b| b.score().total_cmp(&a.score()));
185        v
186    }
187
188    /// Return references to all complete hypotheses.
189    pub fn complete_hypotheses(&self) -> Vec<&Hypothesis> {
190        self.hypotheses
191            .iter()
192            .map(|cmp::Reverse(h)| h)
193            .filter(|h| h.is_complete)
194            .collect()
195    }
196}
197
198// ─── NBestDecoder ──────────────────────────────────────────────────────────────
199
200/// Decoder that expands hypotheses by one step and maintains an N-best list.
201pub struct NBestDecoder {
202    /// Maximum number of hypotheses to track.
203    pub n: usize,
204    /// Token ID that marks end-of-sequence.
205    pub eos_id: u32,
206    /// Maximum generation length (inclusive).
207    pub max_len: usize,
208    /// Length-penalty exponent α used for normalised scoring.
209    pub length_penalty: f32,
210}
211
212impl NBestDecoder {
213    /// Create a new decoder.
214    pub fn new(n: usize, eos_id: u32, max_len: usize) -> Self {
215        Self {
216            n,
217            eos_id,
218            max_len,
219            length_penalty: 1.0,
220        }
221    }
222
223    /// Set the length-penalty exponent (builder pattern).
224    pub fn with_length_penalty(mut self, alpha: f32) -> Self {
225        self.length_penalty = alpha;
226        self
227    }
228
229    /// Expand a batch of hypotheses by one step.
230    ///
231    /// `logits_per_hyp[i]` must be the logit vector for `hypotheses[i]`.
232    /// Returns the flat list of expanded hypotheses (up to `top_k` per input).
233    pub fn step(
234        &self,
235        hypotheses: &[Hypothesis],
236        logits_per_hyp: &[Vec<f32>],
237        top_k: usize,
238    ) -> Vec<Hypothesis> {
239        let effective_k = top_k.max(1);
240        let mut expanded: Vec<Hypothesis> = Vec::new();
241
242        for (hyp, logits) in hypotheses.iter().zip(logits_per_hyp.iter()) {
243            if hyp.is_complete {
244                expanded.push(hyp.clone());
245                continue;
246            }
247
248            let top = BeamSearchEngine::top_k_log_probs(logits, effective_k);
249
250            for (token, log_prob) in top {
251                let new_hyp = hyp.extend(token, log_prob as f32);
252                let new_hyp = if token == self.eos_id {
253                    new_hyp.complete(self.eos_id)
254                } else {
255                    new_hyp
256                };
257                expanded.push(new_hyp);
258            }
259        }
260
261        expanded
262    }
263
264    /// Return an empty N-best list with this decoder's capacity.
265    pub fn init(&self) -> NBestList {
266        NBestList::new(self.n)
267    }
268
269    /// Partition hypotheses into (active, complete).
270    pub fn partition(hyps: Vec<Hypothesis>) -> (Vec<Hypothesis>, Vec<Hypothesis>) {
271        let mut active = Vec::new();
272        let mut complete = Vec::new();
273        for h in hyps {
274            if h.is_complete {
275                complete.push(h);
276            } else {
277                active.push(h);
278            }
279        }
280        (active, complete)
281    }
282}
283
284// ─── Tests ─────────────────────────────────────────────────────────────────────
285
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn hypothesis_new() {
        let h = Hypothesis::new(vec![1, 2, 3], -3.0);
        assert_eq!(h.tokens, vec![1, 2, 3]);
        assert!((h.log_prob - -3.0).abs() < f64::EPSILON);
        assert!(!h.is_complete);
    }

    #[test]
    fn hypothesis_extend() {
        let h = Hypothesis::new(vec![1, 2], -2.0);
        let h2 = h.extend(3, -1.0);
        assert_eq!(h2.tokens, vec![1, 2, 3]);
        assert!((h2.log_prob - -3.0).abs() < 1e-6);
    }

    #[test]
    fn hypothesis_complete() {
        let h = Hypothesis::new(vec![1, 2], -2.0);
        let h = h.complete(2);
        assert!(h.is_complete);
    }

    #[test]
    fn hypothesis_score_normalized() {
        // score() is log_prob / len: a longer sequence with proportionally
        // worse total log-probability must rank below a short good one.
        let short = Hypothesis::new(vec![1], -1.0); // -1.0 / 1 = -1.0
        let long_bad = Hypothesis::new(vec![1, 2, 3, 4, 5], -10.0); // -10.0 / 5 = -2.0
        assert!(long_bad.score() < short.score());
    }

    #[test]
    fn nbest_list_new() {
        let list = NBestList::new(5);
        assert_eq!(list.len(), 0);
        assert!(list.is_empty());
        assert!(!list.is_full());
    }

    #[test]
    fn nbest_list_push_under_capacity() {
        let mut list = NBestList::new(5);
        for i in 0..3u32 {
            list.push(Hypothesis::new(vec![i], -(i as f64)));
        }
        assert_eq!(list.len(), 3);
        assert!(!list.is_full());
    }

    #[test]
    fn nbest_list_push_over_capacity() {
        let mut list = NBestList::new(3);
        // Push 5 hypotheses with scores 0.0, -1.0, -2.0, -3.0, -4.0.
        for i in 0..5u32 {
            list.push(Hypothesis::new(vec![i], -(i as f64)));
        }
        assert_eq!(list.len(), 3);
        // Should keep the three best (i = 0, 1, 2 with scores 0, -1, -2).
        let sorted = list.into_sorted();
        assert_eq!(sorted.len(), 3);
        // Best should be token [0] with score 0.0.
        assert_eq!(sorted[0].tokens, vec![0]);
    }

    #[test]
    fn nbest_list_worst_score() {
        let mut list = NBestList::new(3);
        list.push(Hypothesis::new(vec![1], -1.0));
        list.push(Hypothesis::new(vec![2], -2.0));
        list.push(Hypothesis::new(vec![3], -3.0));
        let worst = list.worst_score().expect("should have worst score");
        assert!((worst - -3.0).abs() < 1e-9);
    }

    #[test]
    fn nbest_list_into_sorted_order() {
        let mut list = NBestList::new(5);
        list.push(Hypothesis::new(vec![1], -3.0));
        list.push(Hypothesis::new(vec![2], -1.0));
        list.push(Hypothesis::new(vec![3], -2.0));
        let sorted = list.into_sorted();
        assert_eq!(sorted.len(), 3);
        // Best first.
        assert!((sorted[0].log_prob - -1.0).abs() < 1e-9);
        assert!((sorted[1].log_prob - -2.0).abs() < 1e-9);
        assert!((sorted[2].log_prob - -3.0).abs() < 1e-9);
    }

    #[test]
    fn nbest_list_complete_hypotheses() {
        let mut list = NBestList::new(5);
        list.push(Hypothesis::new(vec![1], -1.0).complete(2));
        list.push(Hypothesis::new(vec![3], -2.0));
        let complete = list.complete_hypotheses();
        assert_eq!(complete.len(), 1);
        assert!(complete[0].is_complete);
    }

    #[test]
    fn nbest_decoder_step_expands() {
        let decoder = NBestDecoder::new(5, 99, 20);
        let hyps = vec![Hypothesis::new(vec![1], -0.5)];
        let logits = vec![vec![0.0f32; 10]];
        let expanded = decoder.step(&hyps, &logits, 3);
        assert!(expanded.len() >= 3);
    }

    #[test]
    fn nbest_decoder_partition() {
        let active_h = Hypothesis::new(vec![1], -1.0);
        let complete_h = Hypothesis::new(vec![2], -2.0).complete(2);
        let (active, complete) = NBestDecoder::partition(vec![active_h, complete_h]);
        assert_eq!(active.len(), 1);
        assert_eq!(complete.len(), 1);
        assert!(!active[0].is_complete);
        assert!(complete[0].is_complete);
    }

    #[test]
    fn nbest_decoder_eos_completes() {
        let eos = 2u32;
        let decoder = NBestDecoder::new(5, eos, 20);
        let hyps = vec![Hypothesis::new(vec![1], -0.5)];
        // Give the EOS token the highest logit so it must be selected.
        let mut logits = vec![f32::NEG_INFINITY; 5];
        logits[eos as usize] = 10.0;
        let expanded = decoder.step(&hyps, &[logits], 1);
        assert!(!expanded.is_empty());
        assert!(expanded[0].is_complete);
    }

    #[test]
    fn nbest_decoder_length_penalty() {
        // Longer sequences have lower normalised score when log_prob grows
        // faster than length: short = -1.0/1 = -1.0, long = -6.0/4 = -1.5.
        let h_short = Hypothesis::new(vec![1], -1.0);
        let h_long = Hypothesis::new(vec![1, 2, 3, 4], -6.0);
        assert!(h_short.score() > h_long.score());
    }
}