Skip to main content

oxillama_runtime/
beam_search.rs

1//! Beam search decoding for sequence generation.
2//!
3//! Implements a full beam search decoder over an abstract forward-pass
4//! interface.  The engine's `forward()` call is abstracted behind
5//! [`BeamForwardPass`] so that both the real [`crate::engine::InferenceEngine`] and
6//! test-only stubs can drive the algorithm.
7//!
8//! # Algorithm
9//!
10//! 1. Start with a single beam containing the prompt tokens.
11//! 2. For each step up to `max_new_tokens`:
12//!    a. For each active (unfinished) beam, call `forward(tokens)` to get logits.
13//!    b. Compute log-softmax of the logits.
14//!    c. Expand each beam to `beam_width` candidates (top-k log-probs).
15//!    d. Keep the global top `beam_width` unique candidates across all expanded beams.
16//!    e. If a candidate produces the EOS token, mark its beam as finished.
17//!    f. If `early_stopping` is true and the best finished beam already scores
18//!    at least as high as every active beam's current normalised score, stop.
19//! 3. Return the best `beam_width` hypotheses (finished + active), sorted by
20//!    normalised score descending.
21//!
22//! # Normalised score
23//!
24//! `score = logprob_sum / (n_generated_tokens ^ length_penalty)`
25//!
26//! A `length_penalty` of 1.0 divides by token count (balances short vs long).
27//! Values > 1.0 favour longer sequences.
28
29use crate::error::{RuntimeError, RuntimeResult};
30
31// ─── Public types ─────────────────────────────────────────────────────────────
32
/// Hyper-parameters controlling the beam search decoder.
#[derive(Debug, Clone)]
pub struct BeamSearchConfig {
    /// How many hypotheses survive each pruning step (e.g. 4).
    pub beam_width: usize,
    /// Upper bound on the number of tokens generated after the prompt.
    pub max_new_tokens: usize,
    /// Exponent of the length normalisation, `score = logprob_sum / len^length_penalty`.
    ///
    /// - `1.0` divides by length (neutral).
    /// - Values above `1.0` favour longer sequences.
    /// - Values below `1.0` favour shorter sequences.
    pub length_penalty: f32,
    /// When `true`, decoding halts once the best finished beam scores at
    /// least as high as every beam that is still active.
    pub early_stopping: bool,
}

impl Default for BeamSearchConfig {
    /// Four beams, at most 256 new tokens, neutral length penalty, early
    /// stopping enabled.
    fn default() -> Self {
        BeamSearchConfig {
            early_stopping: true,
            length_penalty: 1.0,
            max_new_tokens: 256,
            beam_width: 4,
        }
    }
}
60
/// One candidate sequence tracked by the beam search decoder.
#[derive(Debug, Clone)]
pub struct BeamHypothesis {
    /// Full token sequence so far: the prompt followed by generated tokens.
    pub tokens: Vec<u32>,
    /// Accumulated log-probability of the generated (non-prompt) tokens.
    pub logprob_sum: f32,
    /// Set when the most recently appended token was the EOS token.
    pub finished: bool,
}

impl BeamHypothesis {
    /// Length-normalised ranking score.
    ///
    /// `score = logprob_sum / n_generated_tokens ^ length_penalty`, where
    /// `n_generated_tokens` is `tokens.len() - prompt_len` (saturating at 0).
    ///
    /// Returns `0.0` when nothing has been generated beyond the prompt, and
    /// `f32::NEG_INFINITY` when the divisor is not strictly positive (which
    /// includes a NaN divisor).
    pub fn score(&self, length_penalty: f32, prompt_len: usize) -> f32 {
        match self.tokens.len().saturating_sub(prompt_len) {
            0 => 0.0,
            generated => {
                let divisor = (generated as f32).powf(length_penalty);
                if divisor > 0.0 {
                    self.logprob_sum / divisor
                } else {
                    f32::NEG_INFINITY
                }
            }
        }
    }
}
91
92// ─── Forward-pass abstraction ─────────────────────────────────────────────────
93
94/// Abstraction over a forward pass that produces logits for a token sequence.
95///
96/// The real implementation is backed by [`crate::engine::InferenceEngine`]; test stubs
97/// can implement this trait with pre-computed logit sequences.
98pub trait BeamForwardPass {
99    /// Run the forward pass on `tokens` and return raw logits.
100    ///
101    /// The implementation is free to maintain internal state (KV cache, etc.)
102    /// but must be resettable via [`Self::reset`].
103    fn forward_tokens(&mut self, tokens: &[u32]) -> RuntimeResult<Vec<f32>>;
104
105    /// Reset the internal state (e.g. clear the KV cache) so a fresh
106    /// forward pass can be run for a different beam.
107    fn reset(&mut self);
108}
109
110// ─── Engine adapter ───────────────────────────────────────────────────────────
111
112/// Adapter that wraps [`crate::engine::InferenceEngine`] to implement [`BeamForwardPass`].
113///
114/// Each call to `forward_tokens` resets the KV cache, prefills the prompt
115/// tokens, and returns the logits for the last token.
116pub struct EngineBeamAdapter<'a> {
117    engine: &'a mut crate::engine::InferenceEngine,
118}
119
120impl<'a> EngineBeamAdapter<'a> {
121    /// Create an adapter over a loaded engine.
122    pub fn new(engine: &'a mut crate::engine::InferenceEngine) -> Self {
123        Self { engine }
124    }
125}
126
127impl BeamForwardPass for EngineBeamAdapter<'_> {
128    fn forward_tokens(&mut self, tokens: &[u32]) -> RuntimeResult<Vec<f32>> {
129        if tokens.is_empty() {
130            return Err(RuntimeError::ModelLoadError {
131                message: "beam search: forward_tokens called with empty token slice".to_string(),
132            });
133        }
134        // Use forward_one for the last token; the KV cache must already be
135        // primed for all preceding tokens.  For beam search we re-run the
136        // whole sequence from scratch (reset happens between beams).
137        let last = *tokens.last().ok_or_else(|| RuntimeError::ModelLoadError {
138            message: "beam search: token slice was empty after guard".to_string(),
139        })?;
140        // Process all tokens except the last to prime the KV cache.
141        if tokens.len() > 1 {
142            self.engine.prefill(&tokens[..tokens.len() - 1])?;
143        }
144        self.engine.forward_one(last)
145    }
146
147    fn reset(&mut self) {
148        self.engine.reset();
149    }
150}
151
152// ─── Beam search algorithm ────────────────────────────────────────────────────
153
154/// Run beam search decoding.
155///
156/// `engine`        — any type implementing [`BeamForwardPass`]
157/// `prompt_tokens` — initial token sequence (prompt)
158/// `config`        — beam search hyper-parameters
159/// `eos_token_id`  — token that signals end-of-sequence
160///
161/// Returns a list of [`BeamHypothesis`] sorted by normalised score descending.
162/// The list contains at most `config.beam_width` hypotheses.
163pub fn beam_generate<F: BeamForwardPass>(
164    engine: &mut F,
165    prompt_tokens: &[u32],
166    config: &BeamSearchConfig,
167    eos_token_id: u32,
168) -> RuntimeResult<Vec<BeamHypothesis>> {
169    if config.beam_width == 0 {
170        return Err(RuntimeError::ModelLoadError {
171            message: "beam_width must be >= 1".to_string(),
172        });
173    }
174    if prompt_tokens.is_empty() {
175        return Err(RuntimeError::ModelLoadError {
176            message: "beam search: prompt_tokens must not be empty".to_string(),
177        });
178    }
179
180    let prompt_len = prompt_tokens.len();
181
182    // ── Initialisation ────────────────────────────────────────────────────────
183    // Start with a single "beam" containing only the prompt.
184    let mut active_beams: Vec<BeamHypothesis> = vec![BeamHypothesis {
185        tokens: prompt_tokens.to_vec(),
186        logprob_sum: 0.0,
187        finished: false,
188    }];
189    let mut finished_beams: Vec<BeamHypothesis> = Vec::new();
190
191    // ── Decode loop ───────────────────────────────────────────────────────────
192    for _step in 0..config.max_new_tokens {
193        if active_beams.is_empty() {
194            break;
195        }
196
197        // For each active beam, expand to `beam_width` candidates.
198        // A candidate is a (hypothesis, new_token, added_logprob) triple.
199        let mut candidates: Vec<(BeamHypothesis, u32, f32)> = Vec::new();
200
201        for beam in &active_beams {
202            // Reset engine state, then run forward pass for this beam's tokens.
203            engine.reset();
204            let logits = engine.forward_tokens(&beam.tokens)?;
205
206            // Log-softmax to obtain per-token log-probabilities.
207            let log_probs = log_softmax(&logits);
208
209            // Pick the top `beam_width` tokens from this beam.
210            let mut token_logprob_pairs: Vec<(u32, f32)> = log_probs
211                .iter()
212                .enumerate()
213                .map(|(i, &lp)| (i as u32, lp))
214                .collect();
215            // Sort by log-probability descending (highest first).
216            token_logprob_pairs.sort_unstable_by(|a, b| {
217                b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
218            });
219            token_logprob_pairs.truncate(config.beam_width);
220
221            for (token, lp) in token_logprob_pairs {
222                let mut new_tokens = beam.tokens.clone();
223                new_tokens.push(token);
224                let new_logprob_sum = beam.logprob_sum + lp;
225                let finished = token == eos_token_id;
226                candidates.push((
227                    BeamHypothesis {
228                        tokens: new_tokens,
229                        logprob_sum: new_logprob_sum,
230                        finished,
231                    },
232                    token,
233                    lp,
234                ));
235            }
236        }
237
238        // ── Prune to beam_width global best ───────────────────────────────────
239        // Sort all candidates by their normalised score (descending).
240        candidates.sort_unstable_by(|(a, _, _), (b, _, _)| {
241            b.score(config.length_penalty, prompt_len)
242                .partial_cmp(&a.score(config.length_penalty, prompt_len))
243                .unwrap_or(std::cmp::Ordering::Equal)
244        });
245        candidates.truncate(config.beam_width);
246
247        // ── Separate finished from active ─────────────────────────────────────
248        active_beams.clear();
249        for (hyp, _token, _lp) in candidates {
250            if hyp.finished {
251                finished_beams.push(hyp);
252            } else {
253                active_beams.push(hyp);
254            }
255        }
256
257        // ── Early stopping ────────────────────────────────────────────────────
258        if config.early_stopping && !finished_beams.is_empty() {
259            // Compute the best finished beam score.
260            let best_finished_score = finished_beams
261                .iter()
262                .map(|h| h.score(config.length_penalty, prompt_len))
263                .fold(f32::NEG_INFINITY, f32::max);
264
265            // The best any active beam could ever score is its current logprob_sum
266            // divided by its current length (lower bound on future length → best
267            // possible score). If even that can't beat the best finished beam, stop.
268            let best_possible_active = active_beams
269                .iter()
270                .map(|h| {
271                    // Optimistic: assume the beam stops right now.
272                    h.score(config.length_penalty, prompt_len)
273                })
274                .fold(f32::NEG_INFINITY, f32::max);
275
276            if best_possible_active <= best_finished_score {
277                break;
278            }
279        }
280    }
281
282    // Collect all hypotheses.
283    let mut all_hyps: Vec<BeamHypothesis> = finished_beams;
284    all_hyps.extend(active_beams);
285
286    // Sort by normalised score descending.
287    all_hyps.sort_unstable_by(|a, b| {
288        b.score(config.length_penalty, prompt_len)
289            .partial_cmp(&a.score(config.length_penalty, prompt_len))
290            .unwrap_or(std::cmp::Ordering::Equal)
291    });
292
293    // Trim to at most beam_width results.
294    all_hyps.truncate(config.beam_width);
295
296    Ok(all_hyps)
297}
298
299// ─── Math helpers ─────────────────────────────────────────────────────────────
300
/// Convert raw logits into log-probabilities (numerically stable log-softmax).
///
/// `log_softmax(x_i) = (x_i - x_max) - log(sum_j(exp(x_j - x_max)))`
///
/// Shifting by `x_max` keeps every exponent at or below zero, so `exp` cannot
/// overflow.  An empty input yields an empty output.
fn log_softmax(logits: &[f32]) -> Vec<f32> {
    if logits.is_empty() {
        return Vec::new();
    }
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let log_norm = logits
        .iter()
        .map(|&x| (x - peak).exp())
        .sum::<f32>()
        .ln();
    let mut log_probs = Vec::with_capacity(logits.len());
    log_probs.extend(logits.iter().map(|&x| (x - peak) - log_norm));
    log_probs
}
315
316// ─── InferenceEngine integration ──────────────────────────────────────────────
317
318impl crate::engine::InferenceEngine {
319    /// Generate using beam search decoding.
320    ///
321    /// Wraps the engine in an [`EngineBeamAdapter`] and calls [`beam_generate`].
322    ///
323    /// Returns a list of [`BeamHypothesis`] sorted by normalised score
324    /// descending.  The hypotheses include the original prompt tokens in
325    /// `tokens`.
326    ///
327    /// # Errors
328    ///
329    /// Returns [`RuntimeError::ModelNotLoaded`] if no model has been loaded.
330    pub fn beam_generate(
331        &mut self,
332        prompt_tokens: &[u32],
333        config: &BeamSearchConfig,
334        eos_token_id: u32,
335    ) -> RuntimeResult<Vec<BeamHypothesis>> {
336        if !self.is_loaded() {
337            return Err(RuntimeError::ModelNotLoaded);
338        }
339        let mut adapter = EngineBeamAdapter::new(self);
340        beam_generate(&mut adapter, prompt_tokens, config, eos_token_id)
341    }
342}
343
344// ─── Tests ────────────────────────────────────────────────────────────────────
345
#[cfg(test)]
mod tests {
    use super::*;

    // ── Test-only stub engine ─────────────────────────────────────────────────

    /// Stateless [`BeamForwardPass`] stub that replays canned logit vectors.
    ///
    /// The vector returned by `forward_tokens` is selected by generation step
    /// (`tokens.len() - prompt_len`); steps past the end of the table reuse
    /// the final entry.  `reset()` is a no-op because the stub keeps no
    /// per-call state.
    struct StubEngine {
        /// Logit vectors for step 0, 1, 2, …
        logit_seq: Vec<Vec<f32>>,
        /// Prompt length, needed to turn a sequence length into a step index.
        prompt_len: usize,
    }

    impl StubEngine {
        fn new(prompt_len: usize, logit_seq: Vec<Vec<f32>>) -> Self {
            StubEngine {
                prompt_len,
                logit_seq,
            }
        }
    }

    impl BeamForwardPass for StubEngine {
        fn forward_tokens(&mut self, tokens: &[u32]) -> RuntimeResult<Vec<f32>> {
            // Number of tokens generated beyond the prompt, clamped to the
            // last programmed step.
            let generated = tokens.len().saturating_sub(self.prompt_len);
            let last_step = self.logit_seq.len().saturating_sub(1);
            Ok(self.logit_seq[generated.min(last_step)].clone())
        }

        fn reset(&mut self) {
            // Nothing to clear: the stub derives everything from its input.
        }
    }

    /// Build a config with a neutral (`1.0`) length penalty.
    fn neutral_config(
        beam_width: usize,
        max_new_tokens: usize,
        early_stopping: bool,
    ) -> BeamSearchConfig {
        BeamSearchConfig {
            beam_width,
            max_new_tokens,
            length_penalty: 1.0,
            early_stopping,
        }
    }

    // ── Score formula tests ───────────────────────────────────────────────────

    #[test]
    fn beam_hypothesis_score_applies_length_penalty() {
        // Prompt of length 1 followed by 2 generated tokens, logprob_sum = -4.
        // With length_penalty = 2.0 the score is -4 / 2^2 = -1.
        let hyp = BeamHypothesis {
            tokens: vec![10u32, 20, 30],
            logprob_sum: -4.0,
            finished: false,
        };
        let expected = -4.0f32 / 4.0f32;
        let score = hyp.score(2.0, 1);
        assert!(
            (score - expected).abs() < 1e-5,
            "score with penalty=2.0 should be {expected}, got {score}"
        );
    }

    #[test]
    fn beam_hypothesis_score_neutral_length_penalty() {
        // Prompt of length 2 followed by 2 generated tokens; with a penalty of
        // 1.0 the score is simply logprob_sum / 2.
        let hyp = BeamHypothesis {
            tokens: vec![1u32, 2, 3, 4],
            logprob_sum: -6.0,
            finished: false,
        };
        let expected = -6.0f32 / 2.0f32;
        let score = hyp.score(1.0, 2);
        assert!(
            (score - expected).abs() < 1e-5,
            "neutral score should be {expected}, got {score}"
        );
    }

    #[test]
    fn beam_hypothesis_score_zero_when_no_generated_tokens() {
        // tokens.len() == prompt_len, so nothing was generated.
        let hyp = BeamHypothesis {
            tokens: vec![1u32, 2],
            logprob_sum: -99.0,
            finished: false,
        };
        let score = hyp.score(1.0, 2);
        assert_eq!(score, 0.0, "score must be 0.0 when no tokens are generated");
    }

    // ── Beam width one matches greedy ─────────────────────────────────────────

    #[test]
    fn beam_search_width_one_matches_greedy() {
        // Vocab size 4, EOS = 3.  Every step returns [0, 5, 2, -10], whose
        // argmax is token 1, so a single beam must emit token 1 three times.
        let prompt = vec![0u32];
        let eos = 3u32;
        let mut engine = StubEngine::new(prompt.len(), vec![vec![0.0f32, 5.0, 2.0, -10.0]; 5]);
        let config = neutral_config(1, 3, false);

        let hyps =
            beam_generate(&mut engine, &prompt, &config, eos).expect("beam search must succeed");
        assert!(!hyps.is_empty(), "must produce at least one hypothesis");

        let generated = &hyps[0].tokens[prompt.len()..];
        assert_eq!(
            generated,
            &[1u32, 1, 1],
            "beam_width=1 should match greedy decode (token 1 at each step)"
        );
    }

    // ── Beam width four returns four hypotheses ───────────────────────────────

    #[test]
    fn beam_width_four_returns_four_hypotheses() {
        // Vocab size 8, EOS = 7.  The EOS logit is -100, so the top-4 tokens
        // are 0..=3 and no beam finishes early.
        let step_logits: Vec<f32> = vec![10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, -100.0];
        let prompt = vec![100u32];
        let eos = 7u32;
        let mut engine = StubEngine::new(prompt.len(), vec![step_logits; 4]);
        let config = neutral_config(4, 2, false);

        let hyps =
            beam_generate(&mut engine, &prompt, &config, eos).expect("beam search must succeed");
        assert_eq!(
            hyps.len(),
            4,
            "beam_width=4 should return 4 hypotheses, got {}",
            hyps.len()
        );
    }

    // ── Early stopping terminates ─────────────────────────────────────────────

    #[test]
    fn beam_early_stopping_terminates() {
        // Vocab size 3, EOS = 1.  The EOS logit (100) dwarfs the others, so the
        // first step already yields a finished beam.  The test only requires
        // that decoding returns and that a finished hypothesis is reported.
        let prompt = vec![0u32];
        let eos = 1u32;
        let mut engine = StubEngine::new(prompt.len(), vec![vec![0.0f32, 100.0, 0.0]; 5]);
        let config = neutral_config(2, 10, true);

        let hyps =
            beam_generate(&mut engine, &prompt, &config, eos).expect("beam search must succeed");

        assert!(!hyps.is_empty(), "must return at least one hypothesis");
        let has_finished = hyps.iter().any(|h| h.finished);
        assert!(
            has_finished,
            "at least one finished hypothesis should exist"
        );
    }

    // ── log_softmax correctness ────────────────────────────────────────────────

    #[test]
    fn log_softmax_sums_to_one_in_prob_space() {
        let lps = log_softmax(&[1.0f32, 2.0, 3.0, 4.0]);
        let prob_sum: f32 = lps.iter().map(|&lp| lp.exp()).sum();
        assert!(
            (prob_sum - 1.0).abs() < 1e-5,
            "exp(log-softmax) must sum to 1, got {prob_sum}"
        );
    }

    #[test]
    fn log_softmax_empty_is_empty() {
        assert!(log_softmax(&[]).is_empty());
    }

    #[test]
    fn log_softmax_single_element_is_zero() {
        let lps = log_softmax(&[5.0f32]);
        assert!(
            (lps[0] - 0.0).abs() < 1e-6,
            "log-softmax of a single element must be 0, got {}",
            lps[0]
        );
    }

    // ── Error-path tests ──────────────────────────────────────────────────────

    #[test]
    fn beam_search_errors_on_zero_beam_width() {
        let mut engine = StubEngine::new(1, vec![vec![1.0, 2.0, 3.0]]);
        let config = BeamSearchConfig {
            beam_width: 0,
            ..BeamSearchConfig::default()
        };
        let result = beam_generate(&mut engine, &[1u32], &config, 0);
        assert!(result.is_err(), "beam_width=0 should return an error");
    }

    #[test]
    fn beam_search_errors_on_empty_prompt() {
        let mut engine = StubEngine::new(0, vec![vec![1.0, 2.0, 3.0]]);
        let result = beam_generate(&mut engine, &[], &BeamSearchConfig::default(), 0);
        assert!(result.is_err(), "empty prompt should return an error");
    }

    // ── Engine integration (no model loaded) ─────────────────────────────────

    #[test]
    fn engine_beam_generate_errors_when_not_loaded() {
        let mut engine =
            crate::engine::InferenceEngine::new(crate::engine::EngineConfig::default());
        let result = engine.beam_generate(&[1u32, 2], &BeamSearchConfig::default(), 0);
        assert!(
            matches!(result, Err(RuntimeError::ModelNotLoaded)),
            "unloaded engine should return ModelNotLoaded, got {:?}",
            result
        );
    }
}