use crate::error::{RuntimeError, RuntimeResult};
/// Tunable parameters for [`beam_generate`].
#[derive(Debug, Clone)]
pub struct BeamSearchConfig {
    /// Maximum number of hypotheses kept after each decoding step, and the
    /// maximum number of hypotheses returned. Must be at least 1.
    pub beam_width: usize,
    /// Upper bound on the number of decoding steps, i.e. on the number of
    /// tokens appended to the prompt.
    pub max_new_tokens: usize,
    /// Exponent applied to the generated length when a hypothesis score is
    /// normalised (see [`BeamHypothesis::score`]).
    pub length_penalty: f32,
    /// When `true`, the search stops as soon as no active beam currently
    /// scores above the best finished hypothesis.
    pub early_stopping: bool,
}

impl Default for BeamSearchConfig {
    /// Four beams, at most 256 new tokens, a length penalty of 1.0 and early
    /// stopping enabled.
    fn default() -> Self {
        BeamSearchConfig {
            beam_width: 4,
            max_new_tokens: 256,
            length_penalty: 1.0,
            early_stopping: true,
        }
    }
}
/// One candidate sequence tracked by the beam search.
#[derive(Debug, Clone)]
pub struct BeamHypothesis {
    /// Prompt tokens followed by every token generated so far.
    pub tokens: Vec<u32>,
    /// Sum of the log-probabilities of the generated tokens.
    pub logprob_sum: f32,
    /// `true` once the end-of-sequence token has been appended.
    pub finished: bool,
}

impl BeamHypothesis {
    /// Length-normalised score: `logprob_sum / n_gen.powf(length_penalty)`,
    /// where `n_gen` is the number of tokens beyond the first `prompt_len`.
    ///
    /// Returns `0.0` when nothing has been generated yet, and
    /// `f32::NEG_INFINITY` when the normaliser is not strictly positive
    /// (which includes a NaN normaliser).
    pub fn score(&self, length_penalty: f32, prompt_len: usize) -> f32 {
        match self.tokens.len().saturating_sub(prompt_len) {
            0 => 0.0,
            generated => {
                let normaliser = (generated as f32).powf(length_penalty);
                // `>` is false for NaN, so a NaN normaliser also lands in the
                // NEG_INFINITY branch.
                if normaliser > 0.0 {
                    self.logprob_sum / normaliser
                } else {
                    f32::NEG_INFINITY
                }
            }
        }
    }
}
/// Minimal model interface that [`beam_generate`] requires from a backend.
pub trait BeamForwardPass {
/// Runs the model over `tokens` and returns its output scores.
///
/// [`beam_generate`] passes a hypothesis' full token sequence and treats
/// element `i` of the returned vector as the logit of token id `i` for the
/// next position. Errors are propagated to the caller of the search.
fn forward_tokens(&mut self, tokens: &[u32]) -> RuntimeResult<Vec<f32>>;
/// Hook for clearing whatever state the implementation keeps between calls.
///
/// [`beam_generate`] invokes this immediately before every
/// `forward_tokens` call.
fn reset(&mut self);
}
/// Adapter that implements [`BeamForwardPass`] on top of a mutable borrow of
/// a `crate::engine::InferenceEngine`, so the engine can drive
/// [`beam_generate`].
pub struct EngineBeamAdapter<'a> {
    /// Engine that performs the forward passes; borrowed exclusively for `'a`.
    engine: &'a mut crate::engine::InferenceEngine,
}

impl<'a> EngineBeamAdapter<'a> {
    /// Wraps `engine` for as long as the mutable borrow lasts.
    pub fn new(engine: &'a mut crate::engine::InferenceEngine) -> Self {
        EngineBeamAdapter { engine }
    }
}
impl BeamForwardPass for EngineBeamAdapter<'_> {
    /// Feeds `tokens` to the wrapped engine and returns the result of the
    /// final token's forward pass.
    ///
    /// Every token except the last is handed to the engine's `prefill`; the
    /// last one goes through `forward_one`, whose output is returned.
    ///
    /// # Errors
    ///
    /// Returns `RuntimeError::ModelLoadError` when `tokens` is empty, and
    /// propagates any error reported by `prefill` or `forward_one`.
    fn forward_tokens(&mut self, tokens: &[u32]) -> RuntimeResult<Vec<f32>> {
        // `split_last` yields the final token and its prefix in one step, so
        // a single emptiness check covers both.
        let (&last, prefix) = tokens
            .split_last()
            .ok_or_else(|| RuntimeError::ModelLoadError {
                message: "beam search: forward_tokens called with empty token slice".to_string(),
            })?;
        if !prefix.is_empty() {
            self.engine.prefill(prefix)?;
        }
        self.engine.forward_one(last)
    }

    /// Delegates to the wrapped engine's `reset`.
    fn reset(&mut self) {
        self.engine.reset();
    }
}
/// Orders two scores from best (highest) to worst.
///
/// NaN is ranked together with negative infinity, i.e. last. With NaN mapped
/// away the comparison is a consistent ordering, which the slice sorting
/// routines require.
fn cmp_scores_desc(a: f32, b: f32) -> std::cmp::Ordering {
    let key = |x: f32| if x.is_nan() { f32::NEG_INFINITY } else { x };
    key(b)
        .partial_cmp(&key(a))
        .unwrap_or(std::cmp::Ordering::Equal)
}

/// Returns the (at most) `k` best `(token_id, log_prob)` pairs, best first.
///
/// Uses a partial selection so the whole vocabulary is never fully sorted.
fn top_k_log_probs(log_probs: &[f32], k: usize) -> Vec<(u32, f32)> {
    if k == 0 {
        return Vec::new();
    }
    // The position in the logits vector is the token id.
    let mut pairs: Vec<(u32, f32)> = log_probs
        .iter()
        .enumerate()
        .map(|(i, &lp)| (i as u32, lp))
        .collect();
    if k < pairs.len() {
        // Moves the k best entries (in arbitrary order) to the front.
        pairs.select_nth_unstable_by(k - 1, |a, b| cmp_scores_desc(a.1, b.1));
        pairs.truncate(k);
    }
    pairs.sort_unstable_by(|a, b| cmp_scores_desc(a.1, b.1));
    pairs
}

/// Runs beam search from `prompt_tokens` and returns up to
/// `config.beam_width` hypotheses, best score first.
///
/// For at most `config.max_new_tokens` steps, every active beam is expanded
/// with its `beam_width` most likely next tokens; the `beam_width` best
/// expansions (by [`BeamHypothesis::score`]) survive. Expansions ending in
/// `eos_token_id` are moved to the finished set, the rest stay active.
/// Before each expansion the engine is reset and the beam's whole token
/// sequence is run through `forward_tokens` again.
///
/// With `config.early_stopping`, the loop ends once no active beam currently
/// scores above the best finished hypothesis. The comparison uses the current
/// score of the active beams, not an upper bound on their future score.
///
/// Candidates whose score is NaN are ranked last.
///
/// # Errors
///
/// Returns `RuntimeError::ModelLoadError` when `config.beam_width` is zero,
/// when `prompt_tokens` is empty, or when a forward pass yields no logits.
/// Errors from `engine.forward_tokens` are propagated unchanged.
pub fn beam_generate<F: BeamForwardPass>(
    engine: &mut F,
    prompt_tokens: &[u32],
    config: &BeamSearchConfig,
    eos_token_id: u32,
) -> RuntimeResult<Vec<BeamHypothesis>> {
    if config.beam_width == 0 {
        return Err(RuntimeError::ModelLoadError {
            message: "beam_width must be >= 1".to_string(),
        });
    }
    if prompt_tokens.is_empty() {
        return Err(RuntimeError::ModelLoadError {
            message: "beam search: prompt_tokens must not be empty".to_string(),
        });
    }

    let prompt_len = prompt_tokens.len();
    let score_of = |hyp: &BeamHypothesis| hyp.score(config.length_penalty, prompt_len);

    let mut active_beams: Vec<BeamHypothesis> = vec![BeamHypothesis {
        tokens: prompt_tokens.to_vec(),
        logprob_sum: 0.0,
        finished: false,
    }];
    let mut finished_beams: Vec<BeamHypothesis> = Vec::new();

    for _step in 0..config.max_new_tokens {
        if active_beams.is_empty() {
            break;
        }

        // Each candidate carries its score so ranking does not recompute it
        // on every comparison.
        let mut candidates: Vec<(f32, BeamHypothesis)> = Vec::new();
        for beam in &active_beams {
            engine.reset();
            let logits = engine.forward_tokens(&beam.tokens)?;
            if logits.is_empty() {
                // Without logits the beam cannot be extended; continuing
                // would silently drop every hypothesis.
                return Err(RuntimeError::ModelLoadError {
                    message: "beam search: forward pass returned no logits".to_string(),
                });
            }
            let log_probs = log_softmax(&logits);
            for (token, lp) in top_k_log_probs(&log_probs, config.beam_width) {
                let mut tokens = Vec::with_capacity(beam.tokens.len() + 1);
                tokens.extend_from_slice(&beam.tokens);
                tokens.push(token);
                let hyp = BeamHypothesis {
                    tokens,
                    logprob_sum: beam.logprob_sum + lp,
                    finished: token == eos_token_id,
                };
                candidates.push((score_of(&hyp), hyp));
            }
        }

        candidates.sort_unstable_by(|a, b| cmp_scores_desc(a.0, b.0));
        candidates.truncate(config.beam_width);

        active_beams.clear();
        for (_, hyp) in candidates {
            if hyp.finished {
                finished_beams.push(hyp);
            } else {
                active_beams.push(hyp);
            }
        }

        if config.early_stopping && !finished_beams.is_empty() {
            let best_finished = finished_beams
                .iter()
                .map(|h| score_of(h))
                .fold(f32::NEG_INFINITY, f32::max);
            // Current score of the active beams, used as a heuristic bound.
            let best_active = active_beams
                .iter()
                .map(|h| score_of(h))
                .fold(f32::NEG_INFINITY, f32::max);
            if best_active <= best_finished {
                break;
            }
        }
    }

    let mut ranked: Vec<(f32, BeamHypothesis)> = finished_beams
        .into_iter()
        .chain(active_beams)
        .map(|hyp| (score_of(&hyp), hyp))
        .collect();
    ranked.sort_unstable_by(|a, b| cmp_scores_desc(a.0, b.0));
    ranked.truncate(config.beam_width);
    Ok(ranked.into_iter().map(|(_, hyp)| hyp).collect())
}
/// Converts raw logits into natural-log probabilities.
///
/// The largest logit is subtracted before exponentiating, so the largest
/// exponent is `exp(0)`. An empty slice yields an empty vector.
fn log_softmax(logits: &[f32]) -> Vec<f32> {
    if logits.is_empty() {
        return Vec::new();
    }

    let peak = logits
        .iter()
        .copied()
        .fold(f32::NEG_INFINITY, |acc, v| acc.max(v));

    // Shift once and reuse the shifted values for both the normaliser and the
    // final result.
    let shifted: Vec<f32> = logits.iter().map(|&v| v - peak).collect();
    let log_norm = shifted.iter().map(|s| s.exp()).sum::<f32>().ln();

    shifted.into_iter().map(|s| s - log_norm).collect()
}
impl crate::engine::InferenceEngine {
    /// Runs the module-level [`beam_generate`] search with this engine as the
    /// forward-pass backend.
    ///
    /// # Errors
    ///
    /// Returns `RuntimeError::ModelNotLoaded` when `is_loaded()` is false;
    /// otherwise returns whatever the underlying search returns.
    pub fn beam_generate(
        &mut self,
        prompt_tokens: &[u32],
        config: &BeamSearchConfig,
        eos_token_id: u32,
    ) -> RuntimeResult<Vec<BeamHypothesis>> {
        if self.is_loaded() {
            // The adapter holds the exclusive borrow of `self` for the search.
            let mut adapter = EngineBeamAdapter::new(self);
            beam_generate(&mut adapter, prompt_tokens, config, eos_token_id)
        } else {
            Err(RuntimeError::ModelNotLoaded)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic stand-in for a model.
    ///
    /// Returns `logit_seq[step]`, where `step` is the number of tokens beyond
    /// `prompt_len`, clamped to the last entry of `logit_seq`.
    /// `logit_seq` must not be empty.
    struct StubEngine {
        logit_seq: Vec<Vec<f32>>,
        prompt_len: usize,
    }

    impl StubEngine {
        fn new(prompt_len: usize, logit_seq: Vec<Vec<f32>>) -> Self {
            Self {
                logit_seq,
                prompt_len,
            }
        }
    }

    impl BeamForwardPass for StubEngine {
        fn forward_tokens(&mut self, tokens: &[u32]) -> RuntimeResult<Vec<f32>> {
            let step = tokens.len().saturating_sub(self.prompt_len);
            let idx = step.min(self.logit_seq.len().saturating_sub(1));
            Ok(self.logit_seq[idx].clone())
        }

        // The stub keeps no state between calls.
        fn reset(&mut self) {}
    }

    #[test]
    fn beam_hypothesis_score_applies_length_penalty() {
        // Two generated tokens, penalty 2.0 -> divide by 2^2.
        let hyp = BeamHypothesis {
            tokens: vec![10u32, 20, 30],
            logprob_sum: -4.0,
            finished: false,
        };
        let score = hyp.score(2.0, 1);
        let expected = -4.0f32 / 4.0f32;
        assert!(
            (score - expected).abs() < 1e-5,
            "score with penalty=2.0 should be {expected}, got {score}"
        );
    }

    #[test]
    fn beam_hypothesis_score_neutral_length_penalty() {
        // Two generated tokens, penalty 1.0 -> plain average.
        let hyp = BeamHypothesis {
            tokens: vec![1u32, 2, 3, 4],
            logprob_sum: -6.0,
            finished: false,
        };
        let score = hyp.score(1.0, 2);
        let expected = -6.0f32 / 2.0f32;
        assert!(
            (score - expected).abs() < 1e-5,
            "neutral score should be {expected}, got {score}"
        );
    }

    #[test]
    fn beam_hypothesis_score_zero_when_no_generated_tokens() {
        let hyp = BeamHypothesis {
            tokens: vec![1u32, 2],
            logprob_sum: -99.0,
            finished: false,
        };
        let score = hyp.score(1.0, 2);
        assert_eq!(score, 0.0, "score must be 0.0 when no tokens are generated");
    }

    #[test]
    fn beam_search_width_one_matches_greedy() {
        // Token 1 has the highest logit at every step; EOS (3) is never chosen.
        let logits_per_step = vec![vec![0.0f32, 5.0, 2.0, -10.0]; 5];
        let prompt = vec![0u32];
        let eos = 3u32;
        let mut engine = StubEngine::new(prompt.len(), logits_per_step);
        let config = BeamSearchConfig {
            beam_width: 1,
            max_new_tokens: 3,
            length_penalty: 1.0,
            early_stopping: false,
        };
        let hyps =
            beam_generate(&mut engine, &prompt, &config, eos).expect("beam search must succeed");
        assert!(!hyps.is_empty(), "must produce at least one hypothesis");
        let best = &hyps[0];
        assert_eq!(
            &best.tokens[prompt.len()..],
            &[1u32, 1, 1],
            "beam_width=1 should match greedy decode (token 1 at each step)"
        );
    }

    #[test]
    fn beam_width_four_returns_four_hypotheses() {
        let logits: Vec<f32> = vec![10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, -100.0];
        let logit_seq = vec![logits; 4];
        let prompt = vec![100u32];
        let eos = 7u32;
        let mut engine = StubEngine::new(prompt.len(), logit_seq);
        let config = BeamSearchConfig {
            beam_width: 4,
            max_new_tokens: 2,
            length_penalty: 1.0,
            early_stopping: false,
        };
        let hyps =
            beam_generate(&mut engine, &prompt, &config, eos).expect("beam search must succeed");
        assert_eq!(
            hyps.len(),
            4,
            "beam_width=4 should return 4 hypotheses, got {}",
            hyps.len()
        );
    }

    #[test]
    fn beam_early_stopping_terminates() {
        // EOS (token 1) dominates, so the first step already yields a finished
        // hypothesis that outscores the remaining active beam.
        let logits_step0 = vec![0.0f32, 100.0, 0.0];
        let logit_seq = vec![logits_step0; 5];
        let prompt = vec![0u32];
        let eos = 1u32;
        let mut engine = StubEngine::new(prompt.len(), logit_seq);
        let config = BeamSearchConfig {
            beam_width: 2,
            max_new_tokens: 10,
            length_penalty: 1.0,
            early_stopping: true,
        };
        let hyps =
            beam_generate(&mut engine, &prompt, &config, eos).expect("beam search must succeed");
        assert!(!hyps.is_empty(), "must return at least one hypothesis");
        assert!(
            hyps[0].finished,
            "the best hypothesis should be the finished one"
        );
        // Without early stopping the active beam would keep growing for up to
        // ten steps; stopping early leaves every hypothesis at one new token.
        assert!(
            hyps.iter().all(|h| h.tokens.len() == prompt.len() + 1),
            "search should stop after the first step, got {:?}",
            hyps
        );
    }

    #[test]
    fn beam_search_with_zero_max_new_tokens_returns_prompt() {
        let prompt = vec![5u32, 6];
        let mut engine = StubEngine::new(prompt.len(), vec![vec![1.0, 2.0, 3.0]]);
        let config = BeamSearchConfig {
            max_new_tokens: 0,
            ..BeamSearchConfig::default()
        };
        let hyps =
            beam_generate(&mut engine, &prompt, &config, 0).expect("beam search must succeed");
        assert_eq!(hyps.len(), 1, "only the prompt hypothesis should exist");
        assert_eq!(hyps[0].tokens, prompt);
        assert!(!hyps[0].finished);
    }

    #[test]
    fn log_softmax_sums_to_one_in_prob_space() {
        let logits = vec![1.0f32, 2.0, 3.0, 4.0];
        let lps = log_softmax(&logits);
        let prob_sum: f32 = lps.iter().map(|&lp| lp.exp()).sum();
        assert!(
            (prob_sum - 1.0).abs() < 1e-5,
            "exp(log-softmax) must sum to 1, got {prob_sum}"
        );
    }

    #[test]
    fn log_softmax_empty_is_empty() {
        let lps = log_softmax(&[]);
        assert!(lps.is_empty());
    }

    #[test]
    fn log_softmax_single_element_is_zero() {
        let lps = log_softmax(&[5.0f32]);
        assert!(
            (lps[0] - 0.0).abs() < 1e-6,
            "log-softmax of a single element must be 0, got {}",
            lps[0]
        );
    }

    #[test]
    fn beam_search_errors_on_zero_beam_width() {
        let prompt = vec![1u32];
        let mut engine = StubEngine::new(1, vec![vec![1.0, 2.0, 3.0]]);
        let config = BeamSearchConfig {
            beam_width: 0,
            ..BeamSearchConfig::default()
        };
        let result = beam_generate(&mut engine, &prompt, &config, 0);
        assert!(result.is_err(), "beam_width=0 should return an error");
    }

    #[test]
    fn beam_search_errors_on_empty_prompt() {
        let mut engine = StubEngine::new(0, vec![vec![1.0, 2.0, 3.0]]);
        let config = BeamSearchConfig::default();
        let result = beam_generate(&mut engine, &[], &config, 0);
        assert!(result.is_err(), "empty prompt should return an error");
    }

    #[test]
    fn engine_beam_generate_errors_when_not_loaded() {
        let mut engine =
            crate::engine::InferenceEngine::new(crate::engine::EngineConfig::default());
        let config = BeamSearchConfig::default();
        let result = engine.beam_generate(&[1u32, 2], &config, 0);
        assert!(
            matches!(result, Err(RuntimeError::ModelNotLoaded)),
            "unloaded engine should return ModelNotLoaded, got {:?}",
            result
        );
    }
}