1use crate::error::{SeqError, SeqResult};
2
/// Settings consumed by [`DiverseBeam`].
///
/// `DiverseBeam::new` rejects a zero `beam_width`, `n_groups` or `vocab_size`, and a
/// `beam_width` that is not a multiple of `n_groups`.
#[derive(Debug, Clone, Copy)]
pub struct DiverseBeamConfig {
    /// Total number of sequences returned by `DiverseBeam::search`, summed over all
    /// groups; each group holds `beam_width / n_groups` of them.
    pub beam_width: usize,
    /// Number of groups the beam is split into.
    pub n_groups: usize,
    /// Upper bound on the number of decoding steps, and therefore on the length of
    /// every returned sequence.
    pub max_steps: usize,
    /// Required length of the score vector produced by the scoring callback; token
    /// ids are the indices `0..vocab_size` of that vector.
    pub vocab_size: usize,
    /// Token id that marks a hypothesis as finished once it has been appended.
    pub eos_id: usize,
    /// Amount subtracted from a candidate's score for every occurrence of its token
    /// among the tokens collected from earlier groups in the same step (`0.0`
    /// disables the penalty).
    pub diversity_strength: f32,
    /// Exponent `alpha` of the length normalisation `score / ((5 + len) / 6)^alpha`
    /// (`0.0` disables normalisation).
    pub length_norm_alpha: f32,
}
21
/// Group-wise beam search driven by a [`DiverseBeamConfig`].
///
/// Prefer `DiverseBeam::new`, which validates the configuration; building the struct
/// directly through the public field skips those checks.
#[derive(Debug, Clone)]
pub struct DiverseBeam {
    /// Configuration used by every call to `search`.
    pub cfg: DiverseBeamConfig,
}
27
28impl DiverseBeam {
29 pub fn new(cfg: DiverseBeamConfig) -> SeqResult<Self> {
31 if cfg.beam_width == 0 {
32 return Err(SeqError::InvalidConfiguration(
33 "beam_width must be > 0".to_string(),
34 ));
35 }
36 if cfg.n_groups == 0 {
37 return Err(SeqError::InvalidConfiguration(
38 "n_groups must be > 0".to_string(),
39 ));
40 }
41 if cfg.beam_width % cfg.n_groups != 0 {
42 return Err(SeqError::InvalidConfiguration(format!(
43 "beam_width ({}) must be divisible by n_groups ({})",
44 cfg.beam_width, cfg.n_groups
45 )));
46 }
47 if cfg.vocab_size == 0 {
48 return Err(SeqError::InvalidConfiguration(
49 "vocab_size must be > 0".to_string(),
50 ));
51 }
52 Ok(Self { cfg })
53 }
54
55 pub fn search<F>(&self, score_fn: F) -> SeqResult<Vec<Vec<usize>>>
64 where
65 F: Fn(&[usize]) -> Vec<f32>,
66 {
67 let cfg = &self.cfg;
68 let beam_per_group = cfg.beam_width / cfg.n_groups;
69
70 let mut groups: Vec<Vec<(Vec<usize>, f32)>> =
72 vec![vec![(vec![], 0.0f32); beam_per_group]; cfg.n_groups];
73
74 let mut finished: Vec<Vec<bool>> = vec![vec![false; beam_per_group]; cfg.n_groups];
76
77 for _step in 0..cfg.max_steps {
78 let mut prev_group_tokens: Vec<usize> = Vec::new();
80
81 for g in 0..cfg.n_groups {
82 if finished[g].iter().all(|&f| f) {
84 for hyp_idx in 0..beam_per_group {
86 if let Some(&last_tok) = groups[g][hyp_idx].0.last() {
87 prev_group_tokens.push(last_tok);
88 }
89 }
90 continue;
91 }
92
93 let mut candidates: Vec<(usize, usize, f32, Vec<usize>)> = Vec::new();
95
96 for hyp_idx in 0..beam_per_group {
97 let (ref tokens, cum_score) = groups[g][hyp_idx].clone();
98 if finished[g][hyp_idx] {
99 candidates.push((hyp_idx, cfg.eos_id, cum_score, tokens.clone()));
101 continue;
102 }
103
104 let log_probs = score_fn(tokens);
105 if log_probs.len() != cfg.vocab_size {
106 return Err(SeqError::ShapeMismatch {
107 expected: cfg.vocab_size,
108 got: log_probs.len(),
109 });
110 }
111
112 for tok in 0..cfg.vocab_size {
113 let raw = log_probs[tok];
114 let diversity_pen =
115 Self::hamming_penalty(tok, &prev_group_tokens, cfg.diversity_strength);
116 let candidate_score = cum_score + raw - diversity_pen;
117 let mut new_tokens = tokens.clone();
118 new_tokens.push(tok);
119 let norm_score = Self::length_norm(
120 candidate_score,
121 new_tokens.len(),
122 cfg.length_norm_alpha,
123 );
124 candidates.push((hyp_idx, tok, norm_score, new_tokens));
126 }
127 }
128
129 candidates
131 .sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
132
133 let mut new_beam: Vec<(Vec<usize>, f32)> = Vec::with_capacity(beam_per_group);
135 let mut new_finished: Vec<bool> = Vec::with_capacity(beam_per_group);
136 let mut tokens_chosen_by_group: Vec<usize> = Vec::new();
137
138 for (_, tok, _, new_tokens) in candidates.iter() {
139 if new_beam.len() >= beam_per_group {
140 break;
141 }
142 let is_eos = *tok == cfg.eos_id;
147 new_beam.push((new_tokens.clone(), 0.0));
148 new_finished.push(is_eos);
149 tokens_chosen_by_group.push(*tok);
150 }
151
152 let mut candidates2: Vec<(usize, usize, f32, f32, Vec<usize>)> = Vec::new();
156 for hyp_idx in 0..beam_per_group {
157 let (ref tokens, cum_score) = groups[g][hyp_idx].clone();
158 if finished[g][hyp_idx] {
159 candidates2.push((
160 hyp_idx,
161 cfg.eos_id,
162 cum_score,
163 cum_score,
164 tokens.clone(),
165 ));
166 continue;
167 }
168 let log_probs = score_fn(tokens);
169 for tok in 0..cfg.vocab_size {
170 let raw = log_probs[tok];
171 let diversity_pen =
172 Self::hamming_penalty(tok, &prev_group_tokens, cfg.diversity_strength);
173 let new_cum = cum_score + raw - diversity_pen;
174 let mut new_tokens = tokens.clone();
175 new_tokens.push(tok);
176 let norm_score =
177 Self::length_norm(new_cum, new_tokens.len(), cfg.length_norm_alpha);
178 candidates2.push((hyp_idx, tok, new_cum, norm_score, new_tokens));
179 }
180 }
181 candidates2
182 .sort_by(|a, b| b.3.partial_cmp(&a.3).unwrap_or(std::cmp::Ordering::Equal));
183
184 new_beam.clear();
185 new_finished.clear();
186 tokens_chosen_by_group.clear();
187
188 for (_, tok, new_cum, _, new_tokens) in candidates2.iter() {
189 if new_beam.len() >= beam_per_group {
190 break;
191 }
192 new_beam.push((new_tokens.clone(), *new_cum));
193 new_finished.push(*tok == cfg.eos_id);
194 if !finished[g]
195 .get(new_beam.len().saturating_sub(1))
196 .copied()
197 .unwrap_or(false)
198 {
199 tokens_chosen_by_group.push(*tok);
200 }
201 }
202
203 while new_beam.len() < beam_per_group {
205 if let Some(first) = new_beam.first().cloned() {
206 new_beam.push(first);
207 new_finished.push(true);
208 } else {
209 break;
210 }
211 }
212
213 prev_group_tokens.extend_from_slice(&tokens_chosen_by_group);
214 groups[g] = new_beam;
215 finished[g] = new_finished;
216 }
217
218 if finished.iter().all(|gf| gf.iter().all(|&f| f)) {
220 break;
221 }
222 }
223
224 let mut result: Vec<Vec<usize>> = Vec::with_capacity(cfg.beam_width);
226 for g in 0..cfg.n_groups {
227 for hyp_idx in 0..beam_per_group {
228 result.push(groups[g][hyp_idx].0.clone());
229 }
230 }
231 Ok(result)
232 }
233
234 #[inline]
240 pub fn hamming_penalty(token: usize, prev_group_tokens: &[usize], strength: f32) -> f32 {
241 if strength == 0.0 {
242 return 0.0;
243 }
244 let count = prev_group_tokens.iter().filter(|&&t| t == token).count();
245 strength * count as f32
246 }
247
248 #[inline]
252 pub fn length_norm(score: f32, len: usize, alpha: f32) -> f32 {
253 if alpha == 0.0 || len == 0 {
254 return score;
255 }
256 let denom = ((5.0 + len as f32) / 6.0).powf(alpha);
257 score / denom
258 }
259
260 pub fn top_k(log_probs: &[f32], k: usize) -> Vec<(usize, f32)> {
262 let k = k.min(log_probs.len());
263 let mut indexed: Vec<(usize, f32)> = log_probs.iter().copied().enumerate().collect();
264 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
265 indexed.truncate(k);
266 indexed
267 }
268}
269
/// Unit tests for [`DiverseBeam`]: configuration validation, the scoring helpers and
/// end-to-end searches driven by small deterministic scoring functions.
#[cfg(test)]
mod tests {
    use super::*;

    /// Two groups of two hypotheses over a five-token vocabulary; token 4 is EOS.
    fn default_cfg() -> DiverseBeamConfig {
        DiverseBeamConfig {
            beam_width: 4,
            n_groups: 2,
            max_steps: 8,
            vocab_size: 5,
            eos_id: 4,
            diversity_strength: 0.5,
            length_norm_alpha: 0.0,
        }
    }

    /// Prefix-independent scores that favour low token ids and make EOS unlikely.
    fn prefer_low(_prefix: &[usize]) -> Vec<f32> {
        vec![-0.1, -0.5, -1.0, -2.0, -10.0]
    }

    /// Prefix-independent scores that make EOS (token 4) by far the best choice.
    fn always_eos(_prefix: &[usize]) -> Vec<f32> {
        vec![-100.0, -100.0, -100.0, -100.0, 0.0]
    }

    #[test]
    fn diverse_beam_returns_b_sequences() {
        let db = DiverseBeam::new(default_cfg()).expect("ok");
        let seqs = db.search(prefer_low).expect("ok");
        assert_eq!(seqs.len(), 4, "expected beam_width sequences");
    }

    #[test]
    fn diverse_beam_sequences_differ() {
        let cfg_div = DiverseBeamConfig {
            diversity_strength: 1.0,
            ..default_cfg()
        };
        let cfg_nodiv = DiverseBeamConfig {
            diversity_strength: 0.0,
            ..default_cfg()
        };
        let db_div = DiverseBeam::new(cfg_div).expect("ok");
        let db_nodiv = DiverseBeam::new(cfg_nodiv).expect("ok");

        let seqs_div = db_div.search(prefer_low).expect("ok");
        let seqs_nodiv = db_nodiv.search(prefer_low).expect("ok");

        // Number of distinct sequences in a result set.
        let distinct = |seqs: &[Vec<usize>]| {
            let mut s: Vec<Vec<usize>> = seqs.to_vec();
            s.sort();
            s.dedup();
            s.len()
        };
        let div_count = distinct(&seqs_div);
        let nodiv_count = distinct(&seqs_nodiv);
        assert!(
            div_count >= nodiv_count,
            "diverse search should produce >= distinct seqs: div={div_count} nodiv={nodiv_count}"
        );
    }

    #[test]
    fn hamming_penalty_zero_when_no_overlap() {
        let penalty = DiverseBeam::hamming_penalty(3, &[0, 1, 2], 1.0);
        assert_eq!(penalty, 0.0);
    }

    #[test]
    fn hamming_penalty_proportional_to_count() {
        // Token 1 occurs three times, so the penalty is 3 * strength.
        let prev = vec![1usize, 1, 1, 2];
        let penalty = DiverseBeam::hamming_penalty(1, &prev, 2.0);
        assert!((penalty - 6.0).abs() < 1e-6, "penalty={penalty}");
    }

    #[test]
    fn length_norm_alpha_zero_is_identity() {
        let s = -3.0f32;
        assert!((DiverseBeam::length_norm(s, 5, 0.0) - s).abs() < 1e-6);
    }

    #[test]
    fn length_norm_longer_sequence_penalized() {
        // With a positive score the larger denominator of the longer sequence
        // yields the smaller normalised value.
        let alpha = 0.6f32;
        let score = 10.0f32;
        let short = DiverseBeam::length_norm(score, 3, alpha);
        let long = DiverseBeam::length_norm(score, 10, alpha);
        assert!(
            long < short,
            "longer seq should have lower normalised score: short={short} long={long}"
        );
    }

    #[test]
    fn top_k_returns_k_items() {
        let probs = vec![-1.0f32, -0.5, -2.0, -0.1, -3.0];
        let top = DiverseBeam::top_k(&probs, 3);
        assert_eq!(top.len(), 3);
    }

    #[test]
    fn top_k_sorted_desc() {
        let probs = vec![-1.0f32, -0.5, -2.0, -0.1, -3.0];
        let top = DiverseBeam::top_k(&probs, 4);
        for w in top.windows(2) {
            assert!(w[0].1 >= w[1].1, "not sorted desc: {:?}", top);
        }
    }

    #[test]
    fn top_k_selects_highest_scores() {
        let probs = vec![-3.0f32, -0.1, -2.0, -5.0];
        let top = DiverseBeam::top_k(&probs, 1);
        assert_eq!(top[0].0, 1, "token 1 has max log-prob");
    }

    #[test]
    fn diverse_beam_immediate_eos_yields_eos_terminated_sequences() {
        let cfg = DiverseBeamConfig {
            beam_width: 2,
            n_groups: 1,
            max_steps: 5,
            vocab_size: 5,
            eos_id: 4,
            diversity_strength: 0.0,
            length_norm_alpha: 0.0,
        };
        let db = DiverseBeam::new(cfg).expect("ok");
        let seqs = db.search(always_eos).expect("ok");
        assert_eq!(seqs.len(), 2);
        // EOS is appended to the sequence, so results are never empty here.
        for s in &seqs {
            assert!(!s.is_empty());
            assert_eq!(*s.last().expect("non-empty"), 4);
        }
    }

    #[test]
    fn diverse_beam_respects_max_steps() {
        // eos_id lies outside the vocabulary, so only max_steps can stop the search.
        let cfg = DiverseBeamConfig {
            beam_width: 2,
            n_groups: 1,
            max_steps: 3,
            vocab_size: 3,
            eos_id: 99,
            diversity_strength: 0.0,
            length_norm_alpha: 0.0,
        };
        let db = DiverseBeam::new(cfg).expect("ok");
        let score_no_eos = |_: &[usize]| vec![-1.0f32, -2.0, -3.0];
        let seqs = db.search(score_no_eos).expect("ok");
        for s in &seqs {
            assert!(
                s.len() <= 3,
                "sequence longer than max_steps: len={}",
                s.len()
            );
        }
    }

    #[test]
    fn new_err_beam_not_divisible() {
        let mut cfg = default_cfg();
        cfg.beam_width = 5;
        cfg.n_groups = 2;
        assert!(matches!(
            DiverseBeam::new(cfg),
            Err(SeqError::InvalidConfiguration(_))
        ));
    }

    #[test]
    fn new_err_zero_groups() {
        let mut cfg = default_cfg();
        cfg.n_groups = 0;
        assert!(matches!(
            DiverseBeam::new(cfg),
            Err(SeqError::InvalidConfiguration(_))
        ));
    }

    #[test]
    fn new_err_zero_beam() {
        let mut cfg = default_cfg();
        cfg.beam_width = 0;
        assert!(matches!(
            DiverseBeam::new(cfg),
            Err(SeqError::InvalidConfiguration(_))
        ));
    }

    #[test]
    fn new_err_zero_vocab() {
        let mut cfg = default_cfg();
        cfg.vocab_size = 0;
        assert!(matches!(
            DiverseBeam::new(cfg),
            Err(SeqError::InvalidConfiguration(_))
        ));
    }

    #[test]
    fn n_groups_1_matches_standard_beam_top_choice() {
        // One group and no diversity penalty: the top hypothesis is the repeated
        // best token.
        let cfg = DiverseBeamConfig {
            beam_width: 2,
            n_groups: 1,
            max_steps: 4,
            vocab_size: 3,
            eos_id: 99,
            diversity_strength: 0.0,
            length_norm_alpha: 0.0,
        };
        let db = DiverseBeam::new(cfg).expect("ok");
        let score_fn = |_: &[usize]| vec![-0.1f32, -1.0, -5.0];
        let seqs = db.search(score_fn).expect("ok");
        assert_eq!(seqs[0], vec![0, 0, 0, 0]);
    }

    #[test]
    fn diverse_beam_single_token_vocab() {
        // The only token is also EOS and the vocabulary is smaller than the beam.
        let cfg = DiverseBeamConfig {
            beam_width: 2,
            n_groups: 1,
            max_steps: 3,
            vocab_size: 1,
            eos_id: 0,
            diversity_strength: 0.0,
            length_norm_alpha: 0.0,
        };
        let db = DiverseBeam::new(cfg).expect("ok");
        let score_fn = |_: &[usize]| vec![0.0f32];
        let seqs = db.search(score_fn).expect("ok");
        assert_eq!(seqs.len(), 2);
        for s in &seqs {
            assert!(!s.is_empty(), "sequence must not be empty");
        }
    }

    #[test]
    fn diverse_beam_eos_in_group0_still_returns_full() {
        let cfg = DiverseBeamConfig {
            beam_width: 4,
            n_groups: 2,
            max_steps: 5,
            vocab_size: 5,
            eos_id: 4,
            diversity_strength: 0.3,
            length_norm_alpha: 0.0,
        };
        let db = DiverseBeam::new(cfg).expect("ok");

        // Counts scoring calls to prove the callback is actually exercised.
        let call_count = std::cell::Cell::new(0u32);
        let seqs = db
            .search(|prefix| {
                call_count.set(call_count.get() + 1);
                // From an empty prefix (or after EOS) only EOS has a finite score.
                if prefix.is_empty() || prefix.last().copied() == Some(4) {
                    vec![
                        f32::NEG_INFINITY,
                        f32::NEG_INFINITY,
                        f32::NEG_INFINITY,
                        f32::NEG_INFINITY,
                        0.0,
                    ]
                } else {
                    vec![-0.1, -0.5, -1.0, -2.0, -10.0]
                }
            })
            .expect("ok");

        assert_eq!(seqs.len(), 4, "must return beam_width sequences");
        assert!(call_count.get() > 0, "score_fn must be invoked");
    }
}