use std::fmt;
/// Errors produced by beam search configuration, stepping, and decoding.
#[derive(Debug, Clone, PartialEq)]
pub enum BeamError {
    /// The caller passed an empty `log_probs` slice to a step function.
    EmptyLogProbs,
    /// Row count or row length of `log_probs` disagrees with the active beams
    /// or with `BeamSearchConfig::vocab_size`.
    VocabSizeMismatch,
    /// The configuration failed validation; the message names the first
    /// violated rule.
    InvalidConfig(String),
    /// The caller-supplied scoring closure returned an error; the message
    /// carries its rendered text.
    ScoreFunctionError(String),
}
impl fmt::Display for BeamError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
BeamError::EmptyLogProbs => write!(f, "log_probs slice is empty"),
BeamError::VocabSizeMismatch => {
write!(f, "vocab size of log_probs does not match config")
}
BeamError::InvalidConfig(msg) => write!(f, "invalid beam search config: {msg}"),
BeamError::ScoreFunctionError(msg) => {
write!(f, "score function returned an error: {msg}")
}
}
}
}
// Marker impl: `Debug` + `Display` above satisfy the trait's default methods.
impl std::error::Error for BeamError {}
/// Configuration for beam search decoding.
#[derive(Debug, Clone)]
pub struct BeamSearchConfig {
    /// Number of hypotheses kept alive at each step (>= 1).
    pub num_beams: usize,
    /// Maximum number of tokens generated beyond the prompt.
    pub max_new_tokens: usize,
    /// EOS is masked out until a hypothesis has generated this many tokens.
    pub min_length: usize,
    /// Exponent for length normalization: score / len^length_penalty.
    /// < 1.0 favors shorter outputs, > 1.0 favors longer ones.
    pub length_penalty: f32,
    /// Stop as soon as `num_beams` hypotheses have completed.
    pub early_stopping: bool,
    /// Ban continuations that would repeat an n-gram of this size; 0 disables.
    pub no_repeat_ngram_size: usize,
    /// Penalty applied to already-generated tokens; 1.0 disables, must be > 0.
    pub repetition_penalty: f32,
    /// Per-group discount subtracted from tokens picked by earlier groups in
    /// diverse beam search; 0.0 disables.
    pub diversity_penalty: f32,
    /// Number of groups for diverse beam search; must divide `num_beams`.
    pub num_beam_groups: usize,
    /// End-of-sequence token id; `None` means no hypothesis ever completes.
    pub eos_token_id: Option<u32>,
    /// Padding token id. NOTE(review): not consumed anywhere in this file;
    /// presumably used by callers when batching — confirm.
    pub pad_token_id: Option<u32>,
    /// Expected length of every `log_probs` row.
    pub vocab_size: usize,
}
impl Default for BeamSearchConfig {
fn default() -> Self {
Self {
num_beams: 4,
max_new_tokens: 50,
min_length: 0,
length_penalty: 1.0,
early_stopping: true,
no_repeat_ngram_size: 0,
repetition_penalty: 1.0,
diversity_penalty: 0.0,
num_beam_groups: 1,
eos_token_id: None,
pad_token_id: None,
vocab_size: 32000,
}
}
}
impl BeamSearchConfig {
    /// Checks the configuration for internal consistency.
    ///
    /// # Errors
    /// Returns [`BeamError::InvalidConfig`] describing the first violated rule.
    pub fn validate(&self) -> Result<(), BeamError> {
        if self.num_beams == 0 {
            return Err(BeamError::InvalidConfig(
                "num_beams must be at least 1".to_string(),
            ));
        }
        if self.num_beam_groups == 0 {
            return Err(BeamError::InvalidConfig(
                "num_beam_groups must be at least 1".to_string(),
            ));
        }
        if self.num_beam_groups > self.num_beams {
            return Err(BeamError::InvalidConfig(
                "num_beam_groups must not exceed num_beams".to_string(),
            ));
        }
        if self.num_beams % self.num_beam_groups != 0 {
            return Err(BeamError::InvalidConfig(
                "num_beams must be divisible by num_beam_groups".to_string(),
            ));
        }
        if self.vocab_size == 0 {
            return Err(BeamError::InvalidConfig(
                "vocab_size must be at least 1".to_string(),
            ));
        }
        // The finiteness check also rejects NaN, which `<= 0.0` alone lets
        // through (NaN fails every comparison) and would poison all scores.
        if !self.repetition_penalty.is_finite() || self.repetition_penalty <= 0.0 {
            return Err(BeamError::InvalidConfig(
                "repetition_penalty must be positive".to_string(),
            ));
        }
        if !self.length_penalty.is_finite() {
            return Err(BeamError::InvalidConfig(
                "length_penalty must be finite".to_string(),
            ));
        }
        if !self.diversity_penalty.is_finite() || self.diversity_penalty < 0.0 {
            return Err(BeamError::InvalidConfig(
                "diversity_penalty must be a non-negative finite number".to_string(),
            ));
        }
        Ok(())
    }
}
/// A single beam: the tokens generated so far (prompt excluded) and their
/// cumulative log-probability.
#[derive(Debug, Clone)]
pub struct BeamHypothesis {
    /// Generated token ids, in order; does not include the prompt.
    pub tokens: Vec<u32>,
    /// Sum of the selected tokens' log-probabilities (plus any penalties).
    pub score: f32,
}
impl BeamHypothesis {
pub fn new() -> Self {
Self {
tokens: Vec::new(),
score: 0.0,
}
}
pub fn length_normalized_score(&self, length_penalty: f32) -> f32 {
let len = self.tokens.len().max(1) as f32;
self.score / len.powf(length_penalty)
}
}
// Default is the same empty/zero-score hypothesis as `new()`.
impl Default for BeamHypothesis {
    fn default() -> Self {
        Self::new()
    }
}
/// Mutable state of one beam search: active beams, finished beams, and the
/// shared prompt they all extend.
#[derive(Debug, Clone)]
pub struct BeamState {
    /// Beams still being extended.
    pub hypotheses: Vec<BeamHypothesis>,
    /// Beams that ended with EOS (and, after decoding, drained active beams).
    pub completed: Vec<BeamHypothesis>,
    /// Prompt token ids shared by every hypothesis.
    pub prompt_tokens: Vec<u32>,
}
impl BeamState {
pub fn new(prompt_tokens: Vec<u32>, hypotheses: Vec<BeamHypothesis>) -> Self {
Self {
hypotheses,
completed: Vec::new(),
prompt_tokens,
}
}
pub fn best_hypothesis(&self) -> Option<&BeamHypothesis> {
self.best_hypothesis_with_penalty(1.0)
}
pub fn best_hypothesis_with_penalty(&self, length_penalty: f32) -> Option<&BeamHypothesis> {
let best_completed = self
.completed
.iter()
.max_by(|a, b| {
a.length_normalized_score(length_penalty)
.partial_cmp(&b.length_normalized_score(length_penalty))
.unwrap_or(std::cmp::Ordering::Equal)
});
if best_completed.is_some() {
return best_completed;
}
self.hypotheses.iter().max_by(|a, b| {
a.length_normalized_score(length_penalty)
.partial_cmp(&b.length_normalized_score(length_penalty))
.unwrap_or(std::cmp::Ordering::Equal)
})
}
pub fn is_done(&self, num_beams: usize, early_stopping: bool) -> bool {
if self.hypotheses.is_empty() {
return true;
}
if early_stopping && self.completed.len() >= num_beams {
return true;
}
false
}
}
/// Returns the sorted, deduplicated set of tokens that would complete an
/// n-gram of size `ngram_size` already present in `tokens`.
///
/// The current context is the last `ngram_size - 1` tokens; every earlier
/// occurrence of that context contributes the token that followed it.
/// Returns an empty vec when `ngram_size` is 0 or the sequence is too short.
/// For `ngram_size == 1` the context is empty, so every seen token is banned.
pub fn get_forbidden_tokens_for_ngram(tokens: &[u32], ngram_size: usize) -> Vec<u32> {
    if ngram_size == 0 || tokens.len() + 1 < ngram_size {
        return Vec::new();
    }
    let prefix_len = ngram_size - 1;
    let current_prefix = &tokens[tokens.len() - prefix_len..];
    // Each full n-gram window whose first prefix_len tokens match the current
    // context bans the window's final token.
    let mut banned: Vec<u32> = tokens
        .windows(ngram_size)
        .filter(|window| &window[..prefix_len] == current_prefix)
        .map(|window| window[prefix_len])
        .collect();
    banned.sort_unstable();
    banned.dedup();
    banned
}
/// Advances all active beams by one token using `log_probs` (one row of
/// `vocab_size` log-probabilities per active hypothesis).
///
/// Per beam, the row is adjusted for the repetition penalty, banned n-gram
/// continuations, and the minimum-length EOS mask; then the highest-scoring
/// continuations across all beams are kept. Candidates ending in EOS move to
/// `beam_state.completed`; the rest form the new active set.
///
/// # Errors
/// * [`BeamError::EmptyLogProbs`] when `log_probs` is empty.
/// * [`BeamError::VocabSizeMismatch`] when the row count differs from the
///   number of active hypotheses, or a row's length differs from
///   `config.vocab_size`.
pub fn beam_search_step(
    beam_state: &mut BeamState,
    log_probs: &[Vec<f32>],
    config: &BeamSearchConfig,
) -> Result<(), BeamError> {
    if log_probs.is_empty() {
        return Err(BeamError::EmptyLogProbs);
    }
    let num_active = beam_state.hypotheses.len();
    if log_probs.len() != num_active {
        return Err(BeamError::VocabSizeMismatch);
    }
    for beam_log_probs in log_probs.iter() {
        if beam_log_probs.len() != config.vocab_size {
            return Err(BeamError::VocabSizeMismatch);
        }
    }
    let mut candidates: Vec<(f32, usize, u32)> = Vec::new();
    for (beam_idx, hyp) in beam_state.hypotheses.iter().enumerate() {
        let mut lp = log_probs[beam_idx].clone();
        // Penalties are computed over the full context: prompt + generated.
        let all_tokens: Vec<u32> = beam_state
            .prompt_tokens
            .iter()
            .chain(hyp.tokens.iter())
            .copied()
            .collect();
        if (config.repetition_penalty - 1.0).abs() > f32::EPSILON {
            for &tok in &all_tokens {
                if let Some(v) = lp.get_mut(tok as usize) {
                    // CTRL-style penalty: log-probs are <= 0, so a negative
                    // value must be *multiplied* by penalty > 1 to shrink it.
                    // (The previous unconditional division made already-seen
                    // tokens MORE likely, i.e. it rewarded repetition.)
                    *v = if *v < 0.0 {
                        *v * config.repetition_penalty
                    } else {
                        *v / config.repetition_penalty
                    };
                }
            }
        }
        if config.no_repeat_ngram_size > 0 {
            // Ban every token that would complete an already-seen n-gram.
            let forbidden =
                get_forbidden_tokens_for_ngram(&all_tokens, config.no_repeat_ngram_size);
            for tok in forbidden {
                if (tok as usize) < lp.len() {
                    lp[tok as usize] = f32::NEG_INFINITY;
                }
            }
        }
        if let Some(eos) = config.eos_token_id {
            // Mask EOS until the hypothesis reaches the minimum length.
            if hyp.tokens.len() < config.min_length {
                if (eos as usize) < lp.len() {
                    lp[eos as usize] = f32::NEG_INFINITY;
                }
            }
        }
        for (token_id, &lp_val) in lp.iter().enumerate() {
            candidates.push((hyp.score + lp_val, beam_idx, token_id as u32));
        }
    }
    candidates.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
    // Keep 2*num_beams candidates so EOS winners cannot starve the active set:
    // truncating to num_beams (the previous behavior) shrank — and could
    // empty — the beam whenever top continuations ended in EOS.
    candidates.truncate(2 * config.num_beams);
    let old_hypotheses = std::mem::take(&mut beam_state.hypotheses);
    let mut new_hypotheses: Vec<BeamHypothesis> = Vec::with_capacity(config.num_beams);
    for (rank, (new_score, beam_idx, token_id)) in candidates.into_iter().enumerate() {
        if new_hypotheses.len() == config.num_beams {
            break;
        }
        let parent = &old_hypotheses[beam_idx];
        let mut new_tokens = parent.tokens.clone();
        new_tokens.push(token_id);
        let new_hyp = BeamHypothesis {
            tokens: new_tokens,
            score: new_score,
        };
        if config.eos_token_id == Some(token_id) {
            // Only top-ranked EOS candidates finish a beam (standard beam
            // search); lower-ranked EOS continuations are discarded.
            if rank < config.num_beams {
                beam_state.completed.push(new_hyp);
            }
        } else {
            new_hypotheses.push(new_hyp);
        }
    }
    beam_state.hypotheses = new_hypotheses;
    Ok(())
}
/// High-level beam search driver that owns a validated configuration.
pub struct BeamSearchDecoder {
    /// Search configuration; validated by `BeamSearchDecoder::new`.
    pub config: BeamSearchConfig,
}
impl BeamSearchDecoder {
    /// Validates `config` and builds a decoder.
    ///
    /// # Errors
    /// Returns [`BeamError::InvalidConfig`] when the configuration is
    /// inconsistent.
    pub fn new(config: BeamSearchConfig) -> Result<Self, BeamError> {
        config.validate()?;
        Ok(Self { config })
    }

    /// Creates the initial state: `num_beams` empty hypotheses with score 0.
    ///
    /// NOTE(review): all beams start identical with equal scores, so the first
    /// expansion tends to select the same continuation for every beam (the
    /// usual remedy is seeding all but one beam with -inf). Left unchanged
    /// because callers/tests rely on the all-zero initialization — revisit.
    pub fn initialize_beams(&self, prompt_tokens: &[u32]) -> BeamState {
        let hypotheses: Vec<BeamHypothesis> = (0..self.config.num_beams)
            .map(|_| BeamHypothesis::new())
            .collect();
        BeamState::new(prompt_tokens.to_vec(), hypotheses)
    }

    /// Runs beam search for up to `max_new_tokens` steps and returns the
    /// generated tokens (prompt excluded) of the best hypothesis.
    ///
    /// `score_fn` receives the full sequences (prompt + generated so far),
    /// one per active beam, and must return one `vocab_size`-length row of
    /// log-probabilities per sequence.
    ///
    /// # Errors
    /// Propagates step/validation errors; failures from `score_fn` surface as
    /// [`BeamError::ScoreFunctionError`].
    pub fn decode(
        &self,
        prompt_tokens: &[u32],
        score_fn: impl Fn(&[Vec<u32>]) -> Result<Vec<Vec<f32>>, BeamError>,
    ) -> Result<Vec<u32>, BeamError> {
        let mut beam_state = self.initialize_beams(prompt_tokens);
        for _step in 0..self.config.max_new_tokens {
            // is_done also covers the no-active-beams case.
            if beam_state.is_done(self.config.num_beams, self.config.early_stopping) {
                break;
            }
            let sequences: Vec<Vec<u32>> = beam_state
                .hypotheses
                .iter()
                .map(|hyp| {
                    let mut seq = beam_state.prompt_tokens.clone();
                    seq.extend_from_slice(&hyp.tokens);
                    seq
                })
                .collect();
            let log_probs = score_fn(&sequences).map_err(|e| match e {
                // Don't re-wrap: wrapping a ScoreFunctionError in itself used
                // to produce nested "score function returned an error: ..."
                // messages.
                BeamError::ScoreFunctionError(_) => e,
                other => BeamError::ScoreFunctionError(other.to_string()),
            })?;
            beam_search_step(&mut beam_state, &log_probs, &self.config)?;
        }
        // Treat still-active hypotheses as finished so the best overall wins.
        let active: Vec<BeamHypothesis> = beam_state.hypotheses.drain(..).collect();
        beam_state.completed.extend(active);
        let best = beam_state
            .best_hypothesis_with_penalty(self.config.length_penalty)
            .ok_or(BeamError::EmptyLogProbs)?;
        Ok(best.tokens.clone())
    }

    /// One grouped (diverse) beam search step: tokens already chosen by
    /// earlier groups at this timestep are down-weighted by
    /// `diversity_penalty * group_idx` before the regular step runs.
    ///
    /// # Errors
    /// Same as [`beam_search_step`], which performs all shape validation.
    pub fn diverse_beam_search_step(
        beam_state: &mut BeamState,
        log_probs: &[Vec<f32>],
        group_idx: usize,
        previous_group_tokens: &[u32],
        config: &BeamSearchConfig,
    ) -> Result<(), BeamError> {
        let diversity_discount = config.diversity_penalty * (group_idx as f32);
        if diversity_discount > 0.0 {
            let mut penalized = log_probs.to_vec();
            for beam_lp in penalized.iter_mut() {
                for &tok in previous_group_tokens {
                    if let Some(v) = beam_lp.get_mut(tok as usize) {
                        *v -= diversity_discount;
                    }
                }
            }
            beam_search_step(beam_state, &penalized, config)
        } else {
            // Group 0 (or zero penalty): nothing to discount, skip the copy.
            beam_search_step(beam_state, log_probs, config)
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
fn single_token_log_probs(vocab_size: usize, best_token: usize, num_beams: usize) -> Vec<Vec<f32>> {
(0..num_beams)
.map(|_| {
let mut lp = vec![f32::NEG_INFINITY; vocab_size];
lp[best_token] = 0.0_f32; lp
})
.collect()
}
#[test]
fn test_config_defaults() {
let cfg = BeamSearchConfig::default();
assert_eq!(cfg.num_beams, 4);
assert_eq!(cfg.max_new_tokens, 50);
assert_eq!(cfg.min_length, 0);
assert!((cfg.length_penalty - 1.0).abs() < f32::EPSILON);
assert!(cfg.early_stopping);
assert_eq!(cfg.no_repeat_ngram_size, 0);
assert!((cfg.repetition_penalty - 1.0).abs() < f32::EPSILON);
assert!((cfg.diversity_penalty - 0.0).abs() < f32::EPSILON);
assert_eq!(cfg.num_beam_groups, 1);
assert_eq!(cfg.eos_token_id, None);
assert_eq!(cfg.pad_token_id, None);
assert_eq!(cfg.vocab_size, 32000);
}
#[test]
fn test_length_normalized_score_length_penalty_lt_1() {
let short_hyp = BeamHypothesis {
tokens: vec![1, 2],
score: -2.0,
};
let long_hyp = BeamHypothesis {
tokens: vec![1, 2, 3, 4],
score: -4.0,
};
let length_penalty = 0.5_f32;
let short_norm = short_hyp.length_normalized_score(length_penalty);
let long_norm = long_hyp.length_normalized_score(length_penalty);
assert!(
short_norm > long_norm,
"short={short_norm}, long={long_norm}"
);
}
#[test]
fn test_length_normalized_score_length_penalty_gt_1() {
let short_hyp = BeamHypothesis {
tokens: vec![1, 2],
score: -2.0,
};
let long_hyp = BeamHypothesis {
tokens: vec![1, 2, 3, 4],
score: -4.0,
};
let length_penalty = 2.0_f32;
let short_norm = short_hyp.length_normalized_score(length_penalty);
let long_norm = long_hyp.length_normalized_score(length_penalty);
assert!(
long_norm > short_norm,
"short={short_norm}, long={long_norm}"
);
}
#[test]
fn test_single_beam_is_greedy() {
let cfg = BeamSearchConfig {
num_beams: 1,
max_new_tokens: 3,
vocab_size: 5,
eos_token_id: Some(4),
..Default::default()
};
let decoder = BeamSearchDecoder::new(cfg).expect("valid config");
let result = decoder.decode(&[0], |seqs| {
let lp: Vec<Vec<f32>> = seqs
.iter()
.map(|_| {
let mut v = vec![-10.0_f32; 5];
v[2] = 0.0;
v
})
.collect();
Ok(lp)
});
let tokens = result.expect("decoding succeeded");
assert!(tokens.iter().all(|&t| t == 2), "tokens: {tokens:?}");
}
#[test]
fn test_forbidden_tokens_ngram_zero() {
let forbidden = get_forbidden_tokens_for_ngram(&[1, 2, 1], 0);
assert!(forbidden.is_empty());
}
#[test]
fn test_forbidden_tokens_ngram_bigram() {
let forbidden = get_forbidden_tokens_for_ngram(&[1, 2, 1], 2);
assert!(forbidden.contains(&2), "expected 2 in {forbidden:?}");
}
#[test]
fn test_forbidden_tokens_ngram_trigram() {
let forbidden = get_forbidden_tokens_for_ngram(&[1, 2, 3, 1, 2], 3);
assert!(forbidden.contains(&3), "expected 3 in {forbidden:?}");
}
#[test]
fn test_forbidden_tokens_no_repeat_yet() {
let forbidden = get_forbidden_tokens_for_ngram(&[1, 2, 3], 2);
assert!(!forbidden.contains(&2), "unexpected 2 in {forbidden:?}");
}
#[test]
fn test_repetition_penalty_applied() {
let cfg = BeamSearchConfig {
num_beams: 1,
vocab_size: 4,
repetition_penalty: 2.0,
max_new_tokens: 1,
..Default::default()
};
let mut state = BeamState::new(vec![0], vec![BeamHypothesis::new()]);
let log_probs = vec![vec![-0.1_f32, -0.5, -0.5, -0.5]];
beam_search_step(&mut state, &log_probs, &cfg).expect("step ok");
assert_eq!(state.hypotheses[0].tokens[0], 0, "token 0 wins after penalty");
}
#[test]
fn test_beam_initialization() {
let cfg = BeamSearchConfig {
num_beams: 4,
..Default::default()
};
let decoder = BeamSearchDecoder::new(cfg).expect("valid config");
let state = decoder.initialize_beams(&[10, 20, 30]);
assert_eq!(state.hypotheses.len(), 4);
assert_eq!(state.prompt_tokens, vec![10, 20, 30]);
for hyp in &state.hypotheses {
assert!(hyp.tokens.is_empty());
assert!((hyp.score - 0.0).abs() < f32::EPSILON);
}
assert!(state.completed.is_empty());
}
#[test]
fn test_beam_search_step_top_k_selection() {
let cfg = BeamSearchConfig {
num_beams: 2,
vocab_size: 4,
..Default::default()
};
let hypotheses = vec![BeamHypothesis::new(), BeamHypothesis::new()];
let mut state = BeamState::new(vec![], hypotheses);
let log_probs = vec![
vec![-1.0_f32, -1.0, -1.0, 0.0], vec![-0.1_f32, 0.0, -1.0, -1.0], ];
beam_search_step(&mut state, &log_probs, &cfg).expect("step ok");
assert_eq!(state.hypotheses.len(), 2);
}
#[test]
fn test_completed_hypothesis_on_eos() {
let cfg = BeamSearchConfig {
num_beams: 2,
vocab_size: 3,
eos_token_id: Some(2),
..Default::default()
};
let hypotheses = vec![BeamHypothesis::new(), BeamHypothesis::new()];
let mut state = BeamState::new(vec![], hypotheses);
let log_probs = vec![
vec![-1.0_f32, -1.0, 0.0],
vec![-1.0_f32, -1.0, 0.0],
];
beam_search_step(&mut state, &log_probs, &cfg).expect("step ok");
assert_eq!(state.completed.len(), 2);
}
#[test]
fn test_best_hypothesis_prefers_completed() {
let mut state = BeamState::new(vec![], vec![
BeamHypothesis { tokens: vec![1, 2, 3], score: -3.0 },
]);
state.completed.push(BeamHypothesis { tokens: vec![5, 6], score: -1.0 });
let best = state.best_hypothesis().expect("has a best");
assert_eq!(best.tokens, vec![5, 6]);
}
#[test]
fn test_diverse_beam_penalty_applied() {
let cfg = BeamSearchConfig {
num_beams: 1,
vocab_size: 4,
diversity_penalty: 1.0,
num_beam_groups: 1,
..Default::default()
};
let mut state = BeamState::new(vec![], vec![BeamHypothesis::new()]);
let log_probs = vec![vec![0.0_f32, -0.5, -1.0, -1.0]];
BeamSearchDecoder::diverse_beam_search_step(
&mut state,
&log_probs,
1, &[0], &cfg,
).expect("step ok");
assert_eq!(state.hypotheses[0].tokens[0], 1, "token 1 should win after diversity penalty");
}
#[test]
fn test_is_done_early_stopping() {
let mut state = BeamState::new(vec![], vec![BeamHypothesis::new(), BeamHypothesis::new()]);
assert!(!state.is_done(2, true));
state.completed.push(BeamHypothesis::new());
state.completed.push(BeamHypothesis::new());
assert!(state.is_done(2, true));
}
#[test]
fn test_is_done_no_early_stopping() {
let mut state = BeamState::new(vec![], vec![BeamHypothesis::new()]);
state.completed.push(BeamHypothesis::new());
assert!(!state.is_done(1, false));
}
#[test]
fn test_error_empty_log_probs() {
let cfg = BeamSearchConfig { num_beams: 1, vocab_size: 4, ..Default::default() };
let mut state = BeamState::new(vec![], vec![BeamHypothesis::new()]);
let result = beam_search_step(&mut state, &[], &cfg);
assert_eq!(result, Err(BeamError::EmptyLogProbs));
}
#[test]
fn test_error_vocab_size_mismatch() {
let cfg = BeamSearchConfig { num_beams: 1, vocab_size: 4, ..Default::default() };
let mut state = BeamState::new(vec![], vec![BeamHypothesis::new()]);
let log_probs = vec![vec![0.0_f32, 0.0, 0.0]];
let result = beam_search_step(&mut state, &log_probs, &cfg);
assert_eq!(result, Err(BeamError::VocabSizeMismatch));
}
#[test]
fn test_error_invalid_config_zero_beams() {
let cfg = BeamSearchConfig { num_beams: 0, ..Default::default() };
let result = BeamSearchDecoder::new(cfg);
assert!(matches!(result, Err(BeamError::InvalidConfig(_))));
}
#[test]
fn test_error_display() {
assert!(!BeamError::EmptyLogProbs.to_string().is_empty());
assert!(!BeamError::VocabSizeMismatch.to_string().is_empty());
assert!(!BeamError::InvalidConfig("x".to_string()).to_string().is_empty());
assert!(!BeamError::ScoreFunctionError("y".to_string()).to_string().is_empty());
}
}