1use crate::error::{SeqError, SeqResult};
4
/// Tuning knobs for [`BeamSearch`].
///
/// The [`Default`] configuration is `beam_width = 4`, `max_steps = 32`, with
/// length normalisation and the diversity penalty both disabled.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BeamConfig {
    /// Maximum number of hypotheses kept after each expansion step. Must be
    /// greater than zero (checked by [`BeamSearch::new`]).
    pub beam_width: usize,
    /// Upper bound on the number of expansion steps performed.
    pub max_steps: usize,
    /// Length-normalisation exponent: hypotheses are ranked by
    /// `score / len.powf(length_alpha)`, where `len` counts every token of the
    /// path including the initial one. `0.0` disables normalisation.
    pub length_alpha: f64,
    /// Strength of the rank-based diversity penalty applied while ranking
    /// candidates; any value `<= 0.0` disables it.
    pub diversity: f64,
}
16
17impl Default for BeamConfig {
18 fn default() -> Self {
19 Self {
20 beam_width: 4,
21 max_steps: 32,
22 length_alpha: 0.0,
23 diversity: 0.0,
24 }
25 }
26}
27
28#[derive(Debug, Clone)]
34pub struct BeamSearch {
35 pub cfg: BeamConfig,
36}
37
38impl BeamSearch {
39 pub fn new(cfg: BeamConfig) -> SeqResult<Self> {
40 if cfg.beam_width == 0 {
41 return Err(SeqError::InvalidConfiguration(
42 "beam_width must be > 0".to_string(),
43 ));
44 }
45 Ok(Self { cfg })
46 }
47
48 pub fn search<F, G>(
50 &self,
51 init_token: usize,
52 mut successors: F,
53 mut is_terminal: G,
54 ) -> SeqResult<(Vec<usize>, f64)>
55 where
56 F: FnMut(&[usize]) -> Vec<(usize, f64)>,
57 G: FnMut(usize) -> bool,
58 {
59 let mut beam: Vec<(Vec<usize>, f64, bool)> = vec![(vec![init_token], 0.0, false)];
60 for _step in 0..self.cfg.max_steps {
61 if beam.iter().all(|(_, _, done)| *done) {
62 break;
63 }
64 let mut new_beam: Vec<(Vec<usize>, f64, bool)> = Vec::with_capacity(beam.len() * 4);
65 for (path, score, done) in &beam {
66 if *done {
67 new_beam.push((path.clone(), *score, true));
68 continue;
69 }
70 let next = successors(path);
71 if next.is_empty() {
72 new_beam.push((path.clone(), *score, true));
73 continue;
74 }
75 for (tok, logp) in next {
76 let mut p = path.clone();
77 p.push(tok);
78 let term = is_terminal(tok);
79 new_beam.push((p, score + logp, term));
80 }
81 }
82 new_beam.sort_by(|a, b| {
84 let sa = norm_score(a.1, a.0.len(), self.cfg.length_alpha);
85 let sb = norm_score(b.1, b.0.len(), self.cfg.length_alpha);
86 sb.partial_cmp(&sa).unwrap_or(std::cmp::Ordering::Equal)
87 });
88 if self.cfg.diversity > 0.0 {
90 for (i, item) in new_beam.iter_mut().enumerate() {
91 item.1 -= self.cfg.diversity * i as f64;
92 }
93 new_beam.sort_by(|a, b| {
94 let sa = norm_score(a.1, a.0.len(), self.cfg.length_alpha);
95 let sb = norm_score(b.1, b.0.len(), self.cfg.length_alpha);
96 sb.partial_cmp(&sa).unwrap_or(std::cmp::Ordering::Equal)
97 });
98 }
99 new_beam.truncate(self.cfg.beam_width);
100 beam = new_beam;
101 }
102 let (path, score, _) = beam
103 .into_iter()
104 .next()
105 .ok_or_else(|| SeqError::NumericalInstability("empty beam".to_string()))?;
106 Ok((path, score))
107 }
108}
109
/// Length-normalises an accumulated score as `score / length^alpha`.
///
/// An `alpha` of exactly `0.0` or a zero `length` returns the score untouched.
/// The zero-length case avoids dividing by `0^alpha`.
#[inline]
fn norm_score(score: f64, length: usize, alpha: f64) -> f64 {
    let normalise = alpha != 0.0 && length != 0;
    if !normalise {
        return score;
    }
    let denominator = (length as f64).powf(alpha);
    score / denominator
}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a search driver, failing the test if `cfg` is rejected.
    fn searcher(cfg: BeamConfig) -> BeamSearch {
        BeamSearch::new(cfg).expect("ok")
    }

    #[test]
    fn beam_trivial_chain() {
        let cfg = BeamConfig {
            beam_width: 1,
            max_steps: 3,
            length_alpha: 0.0,
            diversity: 0.0,
        };
        // Token 1 always beats token 2 and nothing is terminal, so a width-1
        // beam keeps appending token 1 for every step.
        let (tokens, total) = searcher(cfg)
            .search(0, |_prefix| vec![(1, -0.1), (2, -0.5)], |tok| tok == 9)
            .expect("ok");
        assert_eq!(tokens, vec![0, 1, 1, 1]);
        assert!((total + 0.3).abs() < 1e-9);
    }

    #[test]
    fn beam_terminates_on_end_token() {
        const END: usize = 9;
        let mut calls = 0;
        let outcome = searcher(BeamConfig::default()).search(
            0,
            |_prefix| {
                calls += 1;
                // From the third expansion onward only the end token is offered.
                if calls < 3 {
                    vec![(1, 0.0), (2, -0.1)]
                } else {
                    vec![(END, 0.0)]
                }
            },
            |tok| tok == END,
        );
        let (tokens, _) = outcome.expect("ok");
        assert!(tokens.contains(&END));
    }

    #[test]
    fn beam_zero_width_errors() {
        let cfg = BeamConfig {
            beam_width: 0,
            max_steps: 1,
            length_alpha: 0.0,
            diversity: 0.0,
        };
        let built = BeamSearch::new(cfg);
        assert!(built.is_err());
    }
}