use crate::error::{SeqError, SeqResult};
/// Tunable parameters read by [`BeamSearch`].
#[derive(Debug, Clone, Copy)]
pub struct BeamConfig {
    /// Maximum number of candidates kept after each expansion step.
    /// `BeamSearch::new` rejects a value of zero.
    pub beam_width: usize,
    /// Upper bound on the number of expansion steps a search performs.
    pub max_steps: usize,
    /// Exponent of the length normalization used for ranking
    /// (`score / len^alpha`); `0.0` ranks by the raw score.
    pub length_alpha: f64,
    /// Per-rank penalty: when greater than zero, the candidate at sorted
    /// position `i` has `diversity * i` subtracted from its score.
    pub diversity: f64,
}

impl Default for BeamConfig {
    /// Width 4, at most 32 steps, length normalization and diversity
    /// penalty both switched off.
    fn default() -> Self {
        let beam_width = 4;
        let max_steps = 32;
        let length_alpha = 0.0;
        let diversity = 0.0;
        BeamConfig {
            beam_width,
            max_steps,
            length_alpha,
            diversity,
        }
    }
}
/// Beam-search driver parameterized by a [`BeamConfig`].
///
/// `cfg` is a public field, so a value can be built with a struct literal
/// instead of `BeamSearch::new`; doing so skips that constructor's
/// `beam_width > 0` check.
#[derive(Debug, Clone)]
pub struct BeamSearch {
/// Settings read by `search` on every call.
pub cfg: BeamConfig,
}
impl BeamSearch {
    /// Builds a searcher from `cfg`.
    ///
    /// # Errors
    ///
    /// Returns `SeqError::InvalidConfiguration` when `cfg.beam_width` is zero.
    pub fn new(cfg: BeamConfig) -> SeqResult<Self> {
        if cfg.beam_width == 0 {
            return Err(SeqError::InvalidConfiguration(
                "beam_width must be > 0".to_string(),
            ));
        }
        Ok(Self { cfg })
    }

    /// Runs beam search starting from the single path `[init_token]`.
    ///
    /// Each step expands every unfinished path with `successors`, ranks all
    /// candidates by length-normalized score (best first) and keeps the top
    /// `cfg.beam_width`. The loop stops after `cfg.max_steps` steps or as soon
    /// as every kept path is finished.
    ///
    /// * `successors` receives the current path and returns
    ///   `(token, score_increment)` pairs; the increment is added to the
    ///   path's running score. An empty result marks the path as finished.
    /// * `is_terminal` is called once for every appended token; `true` marks
    ///   the extended path as finished. It is never called on `init_token`.
    ///
    /// Returns the top-ranked path together with its stored score. When
    /// `cfg.diversity > 0.0` the rank penalty is subtracted from the stored
    /// score itself, so it carries over into later steps and into the value
    /// returned here.
    /// NOTE(review): confirm callers expect the penalized score rather than
    /// the plain sum of increments.
    ///
    /// Candidates whose normalized score is NaN are ranked after all others.
    ///
    /// # Errors
    ///
    /// Returns `SeqError::NumericalInstability` if the beam ends up empty,
    /// which only happens when `cfg.beam_width` is zero (possible if `cfg`
    /// was set without going through [`BeamSearch::new`]).
    pub fn search<F, G>(
        &self,
        init_token: usize,
        mut successors: F,
        mut is_terminal: G,
    ) -> SeqResult<(Vec<usize>, f64)>
    where
        F: FnMut(&[usize]) -> Vec<(usize, f64)>,
        G: FnMut(usize) -> bool,
    {
        // Each entry is (path, running score, finished flag).
        let mut beam: Vec<(Vec<usize>, f64, bool)> = vec![(vec![init_token], 0.0, false)];

        for _ in 0..self.cfg.max_steps {
            if beam.iter().all(|(_, _, done)| *done) {
                break;
            }

            let mut candidates: Vec<(Vec<usize>, f64, bool)> =
                Vec::with_capacity(beam.len() * 4);

            // Consuming the old beam lets finished and dead-end paths move
            // into the candidate list without being cloned.
            for (path, score, done) in beam {
                if done {
                    candidates.push((path, score, true));
                    continue;
                }
                let next = successors(path.as_slice());
                if next.is_empty() {
                    // No continuation offered: freeze the path as it is.
                    candidates.push((path, score, true));
                    continue;
                }
                for (tok, logp) in next {
                    let mut extended = Vec::with_capacity(path.len() + 1);
                    extended.extend_from_slice(&path);
                    extended.push(tok);
                    let finished = is_terminal(tok);
                    candidates.push((extended, score + logp, finished));
                }
            }

            self.rank_candidates(&mut candidates);
            if self.cfg.diversity > 0.0 {
                // Penalize by sorted position, then re-rank because length
                // normalization can reorder candidates after the penalty.
                for (rank, item) in candidates.iter_mut().enumerate() {
                    item.1 -= self.cfg.diversity * rank as f64;
                }
                self.rank_candidates(&mut candidates);
            }

            candidates.truncate(self.cfg.beam_width);
            beam = candidates;
        }

        let (path, score, _) = beam
            .into_iter()
            .next()
            .ok_or_else(|| SeqError::NumericalInstability("empty beam".to_string()))?;
        Ok((path, score))
    }

    /// Stable-sorts `candidates` best-first by length-normalized score.
    ///
    /// NaN scores are placed after every comparable score so the comparator
    /// stays a consistent total order, which `sort_by` requires.
    fn rank_candidates(&self, candidates: &mut [(Vec<usize>, f64, bool)]) {
        let alpha = self.cfg.length_alpha;
        candidates.sort_by(|a, b| {
            let sa = norm_score(a.1, a.0.len(), alpha);
            let sb = norm_score(b.1, b.0.len(), alpha);
            // `partial_cmp` is `None` only when at least one side is NaN;
            // `false < true`, so the NaN side compares as "later".
            sb.partial_cmp(&sa)
                .unwrap_or_else(|| sa.is_nan().cmp(&sb.is_nan()))
        });
    }
}
/// Length-normalizes a running score for ranking.
///
/// Returns `score / length^alpha`. The score is returned untouched when
/// `alpha` is exactly zero (normalization disabled) or when `length` is zero
/// (which would otherwise divide by zero for positive `alpha`).
#[inline]
fn norm_score(score: f64, length: usize, alpha: f64) -> f64 {
    let normalization_off = alpha == 0.0 || length == 0;
    if normalization_off {
        return score;
    }
    let divisor = f64::powf(length as f64, alpha);
    score / divisor
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with length normalization and diversity penalty switched off.
    fn plain_config(beam_width: usize, max_steps: usize) -> BeamConfig {
        BeamConfig {
            beam_width,
            max_steps,
            length_alpha: 0.0,
            diversity: 0.0,
        }
    }

    /// With width 1 the search greedily follows the best-scoring successor.
    #[test]
    fn beam_trivial_chain() {
        let searcher = BeamSearch::new(plain_config(1, 3)).expect("ok");
        let outcome = searcher.search(0, |_prefix| vec![(1, -0.1), (2, -0.5)], |tok| tok == 9);
        let (best_path, best_score) = outcome.expect("ok");
        assert_eq!(best_path, vec![0, 1, 1, 1]);
        assert!((best_score + 0.3).abs() < 1e-9);
    }

    /// Once the expander only offers the end token, the winner contains it.
    #[test]
    fn beam_terminates_on_end_token() {
        let searcher = BeamSearch::new(BeamConfig::default()).expect("ok");
        let mut calls = 0;
        let outcome = searcher.search(
            0,
            |_prefix| {
                calls += 1;
                if calls < 3 {
                    vec![(1, 0.0), (2, -0.1)]
                } else {
                    vec![(9, 0.0)]
                }
            },
            |tok| tok == 9,
        );
        let (best_path, _best_score) = outcome.expect("ok");
        assert!(best_path.contains(&9));
    }

    /// A zero beam width is rejected by the constructor.
    #[test]
    fn beam_zero_width_errors() {
        let outcome = BeamSearch::new(plain_config(0, 1));
        assert!(outcome.is_err());
    }
}