oxillama_runtime/beam_search.rs
1//! Beam search decoding for sequence generation.
2//!
3//! Implements a full beam search decoder over an abstract forward-pass
4//! interface. The engine's `forward()` call is abstracted behind
5//! [`BeamForwardPass`] so that both the real [`crate::engine::InferenceEngine`] and
6//! test-only stubs can drive the algorithm.
7//!
8//! # Algorithm
9//!
10//! 1. Start with a single beam containing the prompt tokens.
11//! 2. For each step up to `max_new_tokens`:
12//! a. For each active (unfinished) beam, call `forward(tokens)` to get logits.
13//! b. Compute log-softmax of the logits.
14//! c. Expand each beam to `beam_width` candidates (top-k log-probs).
15//! d. Keep the global top `beam_width` unique candidates across all expanded beams.
16//! e. If a candidate produces the EOS token, mark its beam as finished.
17//! f. If `early_stopping` is true and the best finished beam already scores
18//! higher than all active beams can possibly score, stop.
//! 3. Return at most `beam_width` hypotheses (finished + active), sorted by
//!    normalised score descending.
21//!
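//! # Normalised score
//!
//! `score = logprob_sum / (n_generated ^ length_penalty)`, where `n_generated`
//! counts only the tokens produced after the prompt.
//!
//! A `length_penalty` of 1.0 divides by the generated-token count (the mean
//! per-token log-probability); 0.0 disables normalisation. Because
//! `logprob_sum` is never positive, values > 1.0 favour longer sequences and
//! values < 1.0 favour shorter ones.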
28
29use crate::error::{RuntimeError, RuntimeResult};
30
31// ─── Public types ─────────────────────────────────────────────────────────────
32
/// Hyper-parameters for [`beam_generate`].
///
/// [`BeamSearchConfig::default`] gives a 4-beam search over up to 256 new
/// tokens with neutral length normalisation and early stopping enabled.
#[derive(Debug, Clone)]
pub struct BeamSearchConfig {
    /// How many hypotheses survive each decoding step (e.g. 4); must be at least 1.
    pub beam_width: usize,
    /// Upper limit on the number of tokens generated after the prompt.
    pub max_new_tokens: usize,
    /// Exponent `p` in the normalised score `logprob_sum / n_generated^p`.
    ///
    /// - `1.0` divides by length (neutral).
    /// - Values above `1.0` favour longer sequences.
    /// - Values below `1.0` favour shorter sequences.
    pub length_penalty: f32,
    /// When `true`, decoding may end before `max_new_tokens` once a finished
    /// hypothesis outscores the beams that are still active.
    pub early_stopping: bool,
}

impl Default for BeamSearchConfig {
    /// Width 4, up to 256 new tokens, `length_penalty` 1.0, early stopping on.
    fn default() -> Self {
        BeamSearchConfig {
            early_stopping: true,
            length_penalty: 1.0,
            max_new_tokens: 256,
            beam_width: 4,
        }
    }
}
60
/// One candidate sequence tracked by the beam search decoder.
#[derive(Debug, Clone)]
pub struct BeamHypothesis {
    /// Full token sequence: the prompt followed by every generated token.
    pub tokens: Vec<u32>,
    /// Accumulated log-probability of the generated (non-prompt) tokens only.
    pub logprob_sum: f32,
    /// `true` when this hypothesis ended with the EOS token.
    pub finished: bool,
}

impl BeamHypothesis {
    /// Length-normalised ranking score.
    ///
    /// Computes `logprob_sum / (n_generated as f32).powf(length_penalty)`,
    /// where `n_generated` is the number of tokens beyond the first
    /// `prompt_len` entries of `tokens`.
    ///
    /// Returns `0.0` when nothing has been generated yet, and
    /// `f32::NEG_INFINITY` when the divisor is not strictly positive (e.g. it
    /// underflows to zero or is NaN).
    pub fn score(&self, length_penalty: f32, prompt_len: usize) -> f32 {
        match self.tokens.len().saturating_sub(prompt_len) {
            0 => 0.0,
            generated => {
                let divisor = (generated as f32).powf(length_penalty);
                if divisor > 0.0 {
                    self.logprob_sum / divisor
                } else {
                    f32::NEG_INFINITY
                }
            }
        }
    }
}
91
92// ─── Forward-pass abstraction ─────────────────────────────────────────────────
93
94/// Abstraction over a forward pass that produces logits for a token sequence.
95///
96/// The real implementation is backed by [`crate::engine::InferenceEngine`]; test stubs
97/// can implement this trait with pre-computed logit sequences.
98pub trait BeamForwardPass {
99 /// Run the forward pass on `tokens` and return raw logits.
100 ///
101 /// The implementation is free to maintain internal state (KV cache, etc.)
102 /// but must be resettable via [`Self::reset`].
103 fn forward_tokens(&mut self, tokens: &[u32]) -> RuntimeResult<Vec<f32>>;
104
105 /// Reset the internal state (e.g. clear the KV cache) so a fresh
106 /// forward pass can be run for a different beam.
107 fn reset(&mut self);
108}
109
110// ─── Engine adapter ───────────────────────────────────────────────────────────
111
112/// Adapter that wraps [`crate::engine::InferenceEngine`] to implement [`BeamForwardPass`].
113///
114/// Each call to `forward_tokens` resets the KV cache, prefills the prompt
115/// tokens, and returns the logits for the last token.
116pub struct EngineBeamAdapter<'a> {
117 engine: &'a mut crate::engine::InferenceEngine,
118}
119
120impl<'a> EngineBeamAdapter<'a> {
121 /// Create an adapter over a loaded engine.
122 pub fn new(engine: &'a mut crate::engine::InferenceEngine) -> Self {
123 Self { engine }
124 }
125}
126
127impl BeamForwardPass for EngineBeamAdapter<'_> {
128 fn forward_tokens(&mut self, tokens: &[u32]) -> RuntimeResult<Vec<f32>> {
129 if tokens.is_empty() {
130 return Err(RuntimeError::ModelLoadError {
131 message: "beam search: forward_tokens called with empty token slice".to_string(),
132 });
133 }
134 // Use forward_one for the last token; the KV cache must already be
135 // primed for all preceding tokens. For beam search we re-run the
136 // whole sequence from scratch (reset happens between beams).
137 let last = *tokens.last().ok_or_else(|| RuntimeError::ModelLoadError {
138 message: "beam search: token slice was empty after guard".to_string(),
139 })?;
140 // Process all tokens except the last to prime the KV cache.
141 if tokens.len() > 1 {
142 self.engine.prefill(&tokens[..tokens.len() - 1])?;
143 }
144 self.engine.forward_one(last)
145 }
146
147 fn reset(&mut self) {
148 self.engine.reset();
149 }
150}
151
152// ─── Beam search algorithm ────────────────────────────────────────────────────
153
154/// Run beam search decoding.
155///
156/// `engine` — any type implementing [`BeamForwardPass`]
157/// `prompt_tokens` — initial token sequence (prompt)
158/// `config` — beam search hyper-parameters
159/// `eos_token_id` — token that signals end-of-sequence
160///
161/// Returns a list of [`BeamHypothesis`] sorted by normalised score descending.
162/// The list contains at most `config.beam_width` hypotheses.
163pub fn beam_generate<F: BeamForwardPass>(
164 engine: &mut F,
165 prompt_tokens: &[u32],
166 config: &BeamSearchConfig,
167 eos_token_id: u32,
168) -> RuntimeResult<Vec<BeamHypothesis>> {
169 if config.beam_width == 0 {
170 return Err(RuntimeError::ModelLoadError {
171 message: "beam_width must be >= 1".to_string(),
172 });
173 }
174 if prompt_tokens.is_empty() {
175 return Err(RuntimeError::ModelLoadError {
176 message: "beam search: prompt_tokens must not be empty".to_string(),
177 });
178 }
179
180 let prompt_len = prompt_tokens.len();
181
182 // ── Initialisation ────────────────────────────────────────────────────────
183 // Start with a single "beam" containing only the prompt.
184 let mut active_beams: Vec<BeamHypothesis> = vec![BeamHypothesis {
185 tokens: prompt_tokens.to_vec(),
186 logprob_sum: 0.0,
187 finished: false,
188 }];
189 let mut finished_beams: Vec<BeamHypothesis> = Vec::new();
190
191 // ── Decode loop ───────────────────────────────────────────────────────────
192 for _step in 0..config.max_new_tokens {
193 if active_beams.is_empty() {
194 break;
195 }
196
197 // For each active beam, expand to `beam_width` candidates.
198 // A candidate is a (hypothesis, new_token, added_logprob) triple.
199 let mut candidates: Vec<(BeamHypothesis, u32, f32)> = Vec::new();
200
201 for beam in &active_beams {
202 // Reset engine state, then run forward pass for this beam's tokens.
203 engine.reset();
204 let logits = engine.forward_tokens(&beam.tokens)?;
205
206 // Log-softmax to obtain per-token log-probabilities.
207 let log_probs = log_softmax(&logits);
208
209 // Pick the top `beam_width` tokens from this beam.
210 let mut token_logprob_pairs: Vec<(u32, f32)> = log_probs
211 .iter()
212 .enumerate()
213 .map(|(i, &lp)| (i as u32, lp))
214 .collect();
215 // Sort by log-probability descending (highest first).
216 token_logprob_pairs.sort_unstable_by(|a, b| {
217 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
218 });
219 token_logprob_pairs.truncate(config.beam_width);
220
221 for (token, lp) in token_logprob_pairs {
222 let mut new_tokens = beam.tokens.clone();
223 new_tokens.push(token);
224 let new_logprob_sum = beam.logprob_sum + lp;
225 let finished = token == eos_token_id;
226 candidates.push((
227 BeamHypothesis {
228 tokens: new_tokens,
229 logprob_sum: new_logprob_sum,
230 finished,
231 },
232 token,
233 lp,
234 ));
235 }
236 }
237
238 // ── Prune to beam_width global best ───────────────────────────────────
239 // Sort all candidates by their normalised score (descending).
240 candidates.sort_unstable_by(|(a, _, _), (b, _, _)| {
241 b.score(config.length_penalty, prompt_len)
242 .partial_cmp(&a.score(config.length_penalty, prompt_len))
243 .unwrap_or(std::cmp::Ordering::Equal)
244 });
245 candidates.truncate(config.beam_width);
246
247 // ── Separate finished from active ─────────────────────────────────────
248 active_beams.clear();
249 for (hyp, _token, _lp) in candidates {
250 if hyp.finished {
251 finished_beams.push(hyp);
252 } else {
253 active_beams.push(hyp);
254 }
255 }
256
257 // ── Early stopping ────────────────────────────────────────────────────
258 if config.early_stopping && !finished_beams.is_empty() {
259 // Compute the best finished beam score.
260 let best_finished_score = finished_beams
261 .iter()
262 .map(|h| h.score(config.length_penalty, prompt_len))
263 .fold(f32::NEG_INFINITY, f32::max);
264
265 // The best any active beam could ever score is its current logprob_sum
266 // divided by its current length (lower bound on future length → best
267 // possible score). If even that can't beat the best finished beam, stop.
268 let best_possible_active = active_beams
269 .iter()
270 .map(|h| {
271 // Optimistic: assume the beam stops right now.
272 h.score(config.length_penalty, prompt_len)
273 })
274 .fold(f32::NEG_INFINITY, f32::max);
275
276 if best_possible_active <= best_finished_score {
277 break;
278 }
279 }
280 }
281
282 // Collect all hypotheses.
283 let mut all_hyps: Vec<BeamHypothesis> = finished_beams;
284 all_hyps.extend(active_beams);
285
286 // Sort by normalised score descending.
287 all_hyps.sort_unstable_by(|a, b| {
288 b.score(config.length_penalty, prompt_len)
289 .partial_cmp(&a.score(config.length_penalty, prompt_len))
290 .unwrap_or(std::cmp::Ordering::Equal)
291 });
292
293 // Trim to at most beam_width results.
294 all_hyps.truncate(config.beam_width);
295
296 Ok(all_hyps)
297}
298
299// ─── Math helpers ─────────────────────────────────────────────────────────────
300
/// Convert raw logits into log-probabilities (numerically stable log-softmax).
///
/// Each output element is `x_i - max(x) - ln(Σ_j exp(x_j - max(x)))`. Shifting
/// by the maximum keeps every exponent at or below zero for finite logits, so
/// `exp` cannot overflow. An empty input produces an empty output.
fn log_softmax(logits: &[f32]) -> Vec<f32> {
    // For an empty slice the fold yields -inf and both maps below iterate over
    // nothing, so the result is an empty Vec without a special case.
    let shift = logits.iter().fold(f32::NEG_INFINITY, |acc, &x| acc.max(x));
    let log_normaliser = logits
        .iter()
        .map(|&x| (x - shift).exp())
        .sum::<f32>()
        .ln();
    logits.iter().map(|&x| x - shift - log_normaliser).collect()
}
315
316// ─── InferenceEngine integration ──────────────────────────────────────────────
317
318impl crate::engine::InferenceEngine {
319 /// Generate using beam search decoding.
320 ///
321 /// Wraps the engine in an [`EngineBeamAdapter`] and calls [`beam_generate`].
322 ///
323 /// Returns a list of [`BeamHypothesis`] sorted by normalised score
324 /// descending. The hypotheses include the original prompt tokens in
325 /// `tokens`.
326 ///
327 /// # Errors
328 ///
329 /// Returns [`RuntimeError::ModelNotLoaded`] if no model has been loaded.
330 pub fn beam_generate(
331 &mut self,
332 prompt_tokens: &[u32],
333 config: &BeamSearchConfig,
334 eos_token_id: u32,
335 ) -> RuntimeResult<Vec<BeamHypothesis>> {
336 if !self.is_loaded() {
337 return Err(RuntimeError::ModelNotLoaded);
338 }
339 let mut adapter = EngineBeamAdapter::new(self);
340 beam_generate(&mut adapter, prompt_tokens, config, eos_token_id)
341 }
342}
343
344// ─── Tests ────────────────────────────────────────────────────────────────────
345
#[cfg(test)]
mod tests {
    use super::*;

    // ── Test-only stub engine ─────────────────────────────────────────────────

    /// A stub `BeamForwardPass` backed by a fixed sequence of logit vectors.
    ///
    /// `forward_tokens` returns the vector for the current generation step,
    /// i.e. index `tokens.len() - prompt_len`; once the sequence is exhausted
    /// the last vector is repeated. The returned logits depend only on the
    /// input length, so `reset()` has nothing to clear. The stub counts
    /// `forward_tokens` calls so tests can check how much work was done.
    struct StubEngine {
        /// Logit vectors for step 0, 1, 2, … (indexed by `tokens.len() - prompt_len`).
        logit_seq: Vec<Vec<f32>>,
        /// Length of the prompt (so we can compute the step index).
        prompt_len: usize,
        /// Number of `forward_tokens` calls made so far.
        forward_calls: usize,
    }

    impl StubEngine {
        fn new(prompt_len: usize, logit_seq: Vec<Vec<f32>>) -> Self {
            Self {
                logit_seq,
                prompt_len,
                forward_calls: 0,
            }
        }
    }

    impl BeamForwardPass for StubEngine {
        fn forward_tokens(&mut self, tokens: &[u32]) -> RuntimeResult<Vec<f32>> {
            self.forward_calls += 1;
            // Step index = how many tokens beyond the prompt have been generated.
            let step = tokens.len().saturating_sub(self.prompt_len);
            let idx = step.min(self.logit_seq.len().saturating_sub(1));
            Ok(self.logit_seq[idx].clone())
        }

        fn reset(&mut self) {
            // Stateless stub — nothing to reset.
        }
    }

    // ── Score formula tests ───────────────────────────────────────────────────

    #[test]
    fn beam_hypothesis_score_applies_length_penalty() {
        // A hypothesis with 2 generated tokens (beyond the prompt of length 1).
        // logprob_sum = -4.0, n_gen = 2.
        // With length_penalty = 2.0: score = -4.0 / 2^2 = -4.0 / 4 = -1.0
        let hyp = BeamHypothesis {
            tokens: vec![10u32, 20, 30], // prompt_len = 1, so 2 generated
            logprob_sum: -4.0,
            finished: false,
        };
        let score = hyp.score(2.0, 1);
        let expected = -4.0f32 / 4.0f32;
        assert!(
            (score - expected).abs() < 1e-5,
            "score with penalty=2.0 should be {expected}, got {score}"
        );
    }

    #[test]
    fn beam_hypothesis_score_neutral_length_penalty() {
        // length_penalty = 1.0: score = logprob_sum / n_generated_tokens.
        let hyp = BeamHypothesis {
            tokens: vec![1u32, 2, 3, 4], // prompt_len = 2 → 2 generated tokens
            logprob_sum: -6.0,
            finished: false,
        };
        let score = hyp.score(1.0, 2);
        let expected = -6.0f32 / 2.0f32;
        assert!(
            (score - expected).abs() < 1e-5,
            "neutral score should be {expected}, got {score}"
        );
    }

    #[test]
    fn beam_hypothesis_score_zero_when_no_generated_tokens() {
        // No generated tokens beyond the prompt → score = 0.
        let hyp = BeamHypothesis {
            tokens: vec![1u32, 2],
            logprob_sum: -99.0,
            finished: false,
        };
        let score = hyp.score(1.0, 2); // prompt_len == tokens.len()
        assert_eq!(score, 0.0, "score must be 0.0 when no tokens are generated");
    }

    // ── Beam width one matches greedy ─────────────────────────────────────────

    #[test]
    fn beam_search_width_one_matches_greedy() {
        // With beam_width=1 and a deterministic stub that always returns the
        // same logits, beam search should produce the same sequence as greedy
        // (argmax at each step).
        //
        // Vocab size = 4; EOS = 3.
        // Logits at every step: [0.0, 5.0, 2.0, -10.0]
        // → argmax = token 1 every time.
        let logits_per_step = vec![vec![0.0f32, 5.0, 2.0, -10.0]; 5];
        let prompt = vec![0u32];
        let eos = 3u32;

        let mut engine = StubEngine::new(prompt.len(), logits_per_step.clone());
        let config = BeamSearchConfig {
            beam_width: 1,
            max_new_tokens: 3,
            length_penalty: 1.0,
            early_stopping: false,
        };
        let hyps =
            beam_generate(&mut engine, &prompt, &config, eos).expect("beam search must succeed");
        assert!(!hyps.is_empty(), "must produce at least one hypothesis");

        // The only hypothesis should contain [prompt, 1, 1, 1] (greedy picks token 1).
        let best = &hyps[0];
        assert_eq!(
            &best.tokens[prompt.len()..],
            &[1u32, 1, 1],
            "beam_width=1 should match greedy decode (token 1 at each step)"
        );
    }

    // ── Beam width four returns four hypotheses ───────────────────────────────

    #[test]
    fn beam_width_four_returns_four_hypotheses() {
        // Vocab size = 8, EOS = 7.
        // Logits spread so all 4 beams stay active (no EOS in top-4).
        // Logits: [10, 9, 8, 7, 6, 5, 4, -100] → top-4 = tokens 0,1,2,3
        let logits: Vec<f32> = vec![10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, -100.0];
        let logit_seq = vec![logits; 4];

        let prompt = vec![100u32];
        let eos = 7u32;

        let mut engine = StubEngine::new(prompt.len(), logit_seq);
        let config = BeamSearchConfig {
            beam_width: 4,
            max_new_tokens: 2,
            length_penalty: 1.0,
            early_stopping: false,
        };
        let hyps =
            beam_generate(&mut engine, &prompt, &config, eos).expect("beam search must succeed");
        assert_eq!(
            hyps.len(),
            4,
            "beam_width=4 should return 4 hypotheses, got {}",
            hyps.len()
        );
    }

    #[test]
    fn beam_width_larger_than_vocab_is_handled() {
        // Vocab size 2 but beam_width 5: each beam can only expand to 2 tokens.
        let prompt = vec![0u32];
        let mut engine = StubEngine::new(prompt.len(), vec![vec![1.0f32, 0.0]]);
        let config = BeamSearchConfig {
            beam_width: 5,
            max_new_tokens: 2,
            length_penalty: 1.0,
            early_stopping: false,
        };
        let hyps =
            beam_generate(&mut engine, &prompt, &config, 99).expect("beam search must succeed");
        // Step 0 yields 2 beams; step 1 yields 2 × 2 = 4 distinct sequences.
        assert_eq!(hyps.len(), 4, "expected 4 hypotheses, got {}", hyps.len());
        // One forward pass at step 0, then one per surviving beam at step 1.
        assert_eq!(engine.forward_calls, 3);
    }

    #[test]
    fn beam_results_sorted_and_prefixed_by_prompt() {
        // EOS (99) is outside the vocabulary, so every beam stays active.
        let logit_seq = vec![vec![3.0f32, 1.0, 2.0, 0.5], vec![0.0f32, 2.0, 1.0, 3.0]];
        let prompt = vec![7u32, 8];
        let mut engine = StubEngine::new(prompt.len(), logit_seq);
        let config = BeamSearchConfig {
            beam_width: 3,
            max_new_tokens: 2,
            length_penalty: 1.0,
            early_stopping: false,
        };
        let hyps =
            beam_generate(&mut engine, &prompt, &config, 99).expect("beam search must succeed");
        assert_eq!(hyps.len(), 3);
        for h in &hyps {
            assert!(h.tokens.starts_with(&prompt), "hypothesis must keep the prompt");
            assert_eq!(h.tokens.len(), prompt.len() + 2);
            assert!(!h.finished);
        }
        for pair in hyps.windows(2) {
            let first = pair[0].score(config.length_penalty, prompt.len());
            let second = pair[1].score(config.length_penalty, prompt.len());
            assert!(first >= second, "results must be sorted: {first} < {second}");
        }
        // The best path takes the argmax at both steps: token 0, then token 3.
        assert_eq!(&hyps[0].tokens[prompt.len()..], &[0u32, 3]);
    }

    // ── Termination ───────────────────────────────────────────────────────────

    #[test]
    fn beam_early_stopping_terminates() {
        // EOS = 1, vocab = 3. Logits [0.0, 100.0, 0.0] make EOS overwhelmingly likely.
        //
        // With beam_width=2, step 0 yields a finished EOS hypothesis (score ≈ 0)
        // and one active beam with log-prob ≈ -100. No continuation of that
        // beam can reach a score near 0 within max_new_tokens=10, so early
        // stopping must end the search after the very first forward pass.
        let logits_step0 = vec![0.0f32, 100.0, 0.0]; // EOS dominates
        let logit_seq = vec![logits_step0; 5];

        let prompt = vec![0u32];
        let eos = 1u32;

        let mut engine = StubEngine::new(prompt.len(), logit_seq);
        let config = BeamSearchConfig {
            beam_width: 2,
            max_new_tokens: 10,
            length_penalty: 1.0,
            early_stopping: true,
        };
        let hyps =
            beam_generate(&mut engine, &prompt, &config, eos).expect("beam search must succeed");

        assert_eq!(
            engine.forward_calls, 1,
            "early stopping should end decoding after step 0, but {} forward passes ran",
            engine.forward_calls
        );
        assert!(!hyps.is_empty(), "must return at least one hypothesis");
        assert!(
            hyps[0].finished,
            "the best hypothesis should be the finished EOS beam"
        );
        assert_eq!(hyps[0].tokens, vec![0u32, 1]);
    }

    #[test]
    fn beam_search_stops_when_all_beams_finish() {
        // EOS = 1 is the argmax, so the single beam finishes at step 0 and the
        // decoder must not keep calling the engine even without early stopping.
        let prompt = vec![0u32];
        let mut engine = StubEngine::new(prompt.len(), vec![vec![0.0f32, 10.0, 0.0]]);
        let config = BeamSearchConfig {
            beam_width: 1,
            max_new_tokens: 5,
            length_penalty: 1.0,
            early_stopping: false,
        };
        let hyps =
            beam_generate(&mut engine, &prompt, &config, 1).expect("beam search must succeed");
        assert_eq!(engine.forward_calls, 1);
        assert_eq!(hyps.len(), 1);
        assert!(hyps[0].finished);
        assert_eq!(hyps[0].tokens, vec![0u32, 1]);
    }

    #[test]
    fn beam_search_zero_max_new_tokens_returns_prompt() {
        let prompt = vec![5u32, 6];
        let mut engine = StubEngine::new(prompt.len(), vec![vec![1.0f32, 2.0]]);
        let config = BeamSearchConfig {
            max_new_tokens: 0,
            ..BeamSearchConfig::default()
        };
        let hyps =
            beam_generate(&mut engine, &prompt, &config, 1).expect("beam search must succeed");
        assert_eq!(engine.forward_calls, 0, "no forward pass should run");
        assert_eq!(hyps.len(), 1);
        assert_eq!(hyps[0].tokens, prompt);
        assert_eq!(hyps[0].logprob_sum, 0.0);
        assert!(!hyps[0].finished);
    }

    // ── log_softmax correctness ────────────────────────────────────────────────

    #[test]
    fn log_softmax_sums_to_one_in_prob_space() {
        let logits = vec![1.0f32, 2.0, 3.0, 4.0];
        let lps = log_softmax(&logits);
        let prob_sum: f32 = lps.iter().map(|&lp| lp.exp()).sum();
        assert!(
            (prob_sum - 1.0).abs() < 1e-5,
            "exp(log-softmax) must sum to 1, got {prob_sum}"
        );
    }

    #[test]
    fn log_softmax_empty_is_empty() {
        let lps = log_softmax(&[]);
        assert!(lps.is_empty());
    }

    #[test]
    fn log_softmax_single_element_is_zero() {
        let lps = log_softmax(&[5.0f32]);
        assert!(
            (lps[0] - 0.0).abs() < 1e-6,
            "log-softmax of a single element must be 0, got {}",
            lps[0]
        );
    }

    // ── Error-path tests ──────────────────────────────────────────────────────

    #[test]
    fn beam_search_errors_on_zero_beam_width() {
        let prompt = vec![1u32];
        let mut engine = StubEngine::new(1, vec![vec![1.0, 2.0, 3.0]]);
        let config = BeamSearchConfig {
            beam_width: 0,
            ..BeamSearchConfig::default()
        };
        let result = beam_generate(&mut engine, &prompt, &config, 0);
        assert!(result.is_err(), "beam_width=0 should return an error");
    }

    #[test]
    fn beam_search_errors_on_empty_prompt() {
        let mut engine = StubEngine::new(0, vec![vec![1.0, 2.0, 3.0]]);
        let config = BeamSearchConfig::default();
        let result = beam_generate(&mut engine, &[], &config, 0);
        assert!(result.is_err(), "empty prompt should return an error");
    }

    // ── Engine integration (no model loaded) ─────────────────────────────────

    #[test]
    fn engine_beam_generate_errors_when_not_loaded() {
        let mut engine =
            crate::engine::InferenceEngine::new(crate::engine::EngineConfig::default());
        let config = BeamSearchConfig::default();
        let result = engine.beam_generate(&[1u32, 2], &config, 0);
        assert!(
            matches!(result, Err(RuntimeError::ModelNotLoaded)),
            "unloaded engine should return ModelNotLoaded, got {:?}",
            result
        );
    }
}