Skip to main content

oxicuda_seq/beam/
beam.rs

1//! Generic beam search with length-normalisation and diversity penalty.
2
3use crate::error::{SeqError, SeqResult};
4
/// Tunable parameters for a beam search run.
#[derive(Debug, Clone, Copy)]
pub struct BeamConfig {
    /// Number of hypotheses kept after each expansion step.
    /// `BeamSearch::new` rejects a value of zero.
    pub beam_width: usize,
    /// Upper bound on the number of expansion steps performed by a search.
    pub max_steps: usize,
    /// Exponent `alpha` of the `length^alpha` denominator used for length
    /// normalisation when ranking.  A value of `0.0` leaves scores untouched.
    pub length_alpha: f64,
    /// Per-rank penalty: the candidate at sorted position `i` has
    /// `i * diversity` subtracted from its score (diversity / Diverse Beam
    /// Search).  Only applied when strictly positive.
    pub diversity: f64,
}

impl Default for BeamConfig {
    /// Four-wide beam, at most 32 steps, with length normalisation and the
    /// diversity penalty both switched off.
    fn default() -> Self {
        let (beam_width, max_steps) = (4, 32);
        // Zero disables both scoring adjustments.
        let (length_alpha, diversity) = (0.0, 0.0);
        BeamConfig {
            beam_width,
            max_steps,
            length_alpha,
            diversity,
        }
    }
}
27
28/// Generic beam search over a score-function callback.
29///
30/// * `init_token` is the starting token used to seed all beam items.
31/// * `is_terminal(token) -> bool` halts a beam if the most recent token is terminal.
32/// * `successors(path) -> Vec<(token, log_prob)>` yields candidate extensions.
33#[derive(Debug, Clone)]
34pub struct BeamSearch {
35    pub cfg: BeamConfig,
36}
37
38impl BeamSearch {
39    pub fn new(cfg: BeamConfig) -> SeqResult<Self> {
40        if cfg.beam_width == 0 {
41            return Err(SeqError::InvalidConfiguration(
42                "beam_width must be > 0".to_string(),
43            ));
44        }
45        Ok(Self { cfg })
46    }
47
48    /// Run beam search.  `Token = usize` for simplicity.  Returns the top-1 path.
49    pub fn search<F, G>(
50        &self,
51        init_token: usize,
52        mut successors: F,
53        mut is_terminal: G,
54    ) -> SeqResult<(Vec<usize>, f64)>
55    where
56        F: FnMut(&[usize]) -> Vec<(usize, f64)>,
57        G: FnMut(usize) -> bool,
58    {
59        let mut beam: Vec<(Vec<usize>, f64, bool)> = vec![(vec![init_token], 0.0, false)];
60        for _step in 0..self.cfg.max_steps {
61            if beam.iter().all(|(_, _, done)| *done) {
62                break;
63            }
64            let mut new_beam: Vec<(Vec<usize>, f64, bool)> = Vec::with_capacity(beam.len() * 4);
65            for (path, score, done) in &beam {
66                if *done {
67                    new_beam.push((path.clone(), *score, true));
68                    continue;
69                }
70                let next = successors(path);
71                if next.is_empty() {
72                    new_beam.push((path.clone(), *score, true));
73                    continue;
74                }
75                for (tok, logp) in next {
76                    let mut p = path.clone();
77                    p.push(tok);
78                    let term = is_terminal(tok);
79                    new_beam.push((p, score + logp, term));
80                }
81            }
82            // Apply length normalisation when ranking.
83            new_beam.sort_by(|a, b| {
84                let sa = norm_score(a.1, a.0.len(), self.cfg.length_alpha);
85                let sb = norm_score(b.1, b.0.len(), self.cfg.length_alpha);
86                sb.partial_cmp(&sa).unwrap_or(std::cmp::Ordering::Equal)
87            });
88            // Diversity penalty: subtract `i * diversity` from rank-i candidate.
89            if self.cfg.diversity > 0.0 {
90                for (i, item) in new_beam.iter_mut().enumerate() {
91                    item.1 -= self.cfg.diversity * i as f64;
92                }
93                new_beam.sort_by(|a, b| {
94                    let sa = norm_score(a.1, a.0.len(), self.cfg.length_alpha);
95                    let sb = norm_score(b.1, b.0.len(), self.cfg.length_alpha);
96                    sb.partial_cmp(&sa).unwrap_or(std::cmp::Ordering::Equal)
97                });
98            }
99            new_beam.truncate(self.cfg.beam_width);
100            beam = new_beam;
101        }
102        let (path, score, _) = beam
103            .into_iter()
104            .next()
105            .ok_or_else(|| SeqError::NumericalInstability("empty beam".to_string()))?;
106        Ok((path, score))
107    }
108}
109
/// Length-normalises `score` by dividing it by `length^alpha`.
///
/// The score is returned unchanged when `alpha` is exactly zero or when
/// `length` is zero.
#[inline]
fn norm_score(score: f64, length: usize, alpha: f64) -> f64 {
    let normalise = alpha != 0.0 && length != 0;
    if !normalise {
        return score;
    }
    let denominator = (length as f64).powf(alpha);
    score / denominator
}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// The model always offers token 1 (log-prob -0.1) and token 2
    /// (log-prob -0.5); a single-slot beam must take token 1 at every step.
    #[test]
    fn beam_trivial_chain() {
        let cfg = BeamConfig {
            beam_width: 1,
            max_steps: 3,
            length_alpha: 0.0,
            diversity: 0.0,
        };
        let searcher = BeamSearch::new(cfg).expect("ok");
        let outcome = searcher.search(0, |_prefix| vec![(1, -0.1), (2, -0.5)], |tok| tok == 9);
        let (best_path, best_score) = outcome.expect("ok");
        assert_eq!(best_path, vec![0, 1, 1, 1]);
        assert!((best_score - (-0.3)).abs() < 1e-9);
    }

    /// From the third expansion call onwards only the end token 9 is offered,
    /// so the winning path has to contain it.
    #[test]
    fn beam_terminates_on_end_token() {
        let searcher = BeamSearch::new(BeamConfig::default()).expect("ok");
        let mut calls = 0;
        let outcome = searcher.search(
            0,
            |_prefix| {
                calls += 1;
                if calls < 3 {
                    vec![(1, 0.0), (2, -0.1)]
                } else {
                    vec![(9, 0.0)]
                }
            },
            |tok| tok == 9,
        );
        let (best_path, _best_score) = outcome.expect("ok");
        assert!(best_path.contains(&9));
    }

    /// A beam width of zero is rejected at construction time.
    #[test]
    fn beam_zero_width_errors() {
        let cfg = BeamConfig {
            beam_width: 0,
            max_steps: 1,
            length_alpha: 0.0,
            diversity: 0.0,
        };
        let built = BeamSearch::new(cfg);
        assert!(built.is_err());
    }
}
175}