Skip to main content

oxicuda_seq/beam/
diverse.rs

1use crate::error::{SeqError, SeqResult};
2
/// Configuration for Diverse Beam Search (Vijayakumar et al. 2018).
///
/// Validated by `DiverseBeam::new`.
#[derive(Debug, Clone, Copy)]
pub struct DiverseBeamConfig {
    /// Total beam width B: the number of sequences `DiverseBeam::search`
    /// returns. Must be > 0 and divisible by `n_groups`.
    pub beam_width: usize,
    /// G: number of diversity groups (> 0). Each group gets `beam_width / n_groups` beams.
    pub n_groups: usize,
    /// Maximum decoding steps. Each step appends at most one token, so no
    /// returned sequence is longer than this.
    pub max_steps: usize,
    /// Vocabulary size (number of tokens, > 0). Candidate tokens are drawn from
    /// `0..vocab_size`, and the score function must return exactly this many
    /// scores per call.
    pub vocab_size: usize,
    /// Token ID for end-of-sequence. A hypothesis that emits it is marked
    /// finished and carried forward unchanged. Because candidates come from
    /// `0..vocab_size`, an `eos_id >= vocab_size` means no hypothesis ever
    /// finishes early.
    pub eos_id: usize,
    /// λ: Hamming diversity penalty strength, subtracted once per matching
    /// token chosen by an earlier group. `0.0` disables the penalty.
    pub diversity_strength: f32,
    /// GNMT-style length normalisation exponent α used for ranking:
    /// `score / ((5 + len) / 6)^α`. 0.0 = no normalisation.
    pub length_norm_alpha: f32,
}
21
/// Diverse beam search decoder.
///
/// Construct it with `DiverseBeam::new`, which validates the configuration,
/// then call `search` with a scoring closure.
#[derive(Debug, Clone)]
pub struct DiverseBeam {
    // NOTE(review): `cfg` is `pub`, so callers can modify it after `new`
    // without going through `new`'s validation.
    pub cfg: DiverseBeamConfig,
}
27
28impl DiverseBeam {
29    /// Validate configuration and create a new `DiverseBeam`.
30    pub fn new(cfg: DiverseBeamConfig) -> SeqResult<Self> {
31        if cfg.beam_width == 0 {
32            return Err(SeqError::InvalidConfiguration(
33                "beam_width must be > 0".to_string(),
34            ));
35        }
36        if cfg.n_groups == 0 {
37            return Err(SeqError::InvalidConfiguration(
38                "n_groups must be > 0".to_string(),
39            ));
40        }
41        if cfg.beam_width % cfg.n_groups != 0 {
42            return Err(SeqError::InvalidConfiguration(format!(
43                "beam_width ({}) must be divisible by n_groups ({})",
44                cfg.beam_width, cfg.n_groups
45            )));
46        }
47        if cfg.vocab_size == 0 {
48            return Err(SeqError::InvalidConfiguration(
49                "vocab_size must be > 0".to_string(),
50            ));
51        }
52        Ok(Self { cfg })
53    }
54
55    /// Run diverse beam search.
56    ///
57    /// Groups decode sequentially at each time step.  Group g receives a
58    /// Hamming diversity penalty based on the tokens that groups 0..g-1
59    /// committed to at the current step.
60    ///
61    /// Returns `beam_width` sequences (all groups concatenated), each a
62    /// vector of token IDs.
63    pub fn search<F>(&self, score_fn: F) -> SeqResult<Vec<Vec<usize>>>
64    where
65        F: Fn(&[usize]) -> Vec<f32>,
66    {
67        let cfg = &self.cfg;
68        let beam_per_group = cfg.beam_width / cfg.n_groups;
69
70        // Each group maintains its own beam: Vec<(tokens, cumulative_score)>.
71        let mut groups: Vec<Vec<(Vec<usize>, f32)>> =
72            vec![vec![(vec![], 0.0f32); beam_per_group]; cfg.n_groups];
73
74        // Track which hypotheses are finished (emitted EOS).
75        let mut finished: Vec<Vec<bool>> = vec![vec![false; beam_per_group]; cfg.n_groups];
76
77        for _step in 0..cfg.max_steps {
78            // Tokens chosen by groups that have already decoded this step.
79            let mut prev_group_tokens: Vec<usize> = Vec::new();
80
81            for g in 0..cfg.n_groups {
82                // Skip group if all hypotheses are finished.
83                if finished[g].iter().all(|&f| f) {
84                    // Still register their current last tokens for later groups.
85                    for hyp_idx in 0..beam_per_group {
86                        if let Some(&last_tok) = groups[g][hyp_idx].0.last() {
87                            prev_group_tokens.push(last_tok);
88                        }
89                    }
90                    continue;
91                }
92
93                // Candidates: (hypothesis_idx, token, adjusted_score).
94                let mut candidates: Vec<(usize, usize, f32, Vec<usize>)> = Vec::new();
95
96                for hyp_idx in 0..beam_per_group {
97                    let (ref tokens, cum_score) = groups[g][hyp_idx].clone();
98                    if finished[g][hyp_idx] {
99                        // Carry finished hypotheses forward unchanged.
100                        candidates.push((hyp_idx, cfg.eos_id, cum_score, tokens.clone()));
101                        continue;
102                    }
103
104                    let log_probs = score_fn(tokens);
105                    if log_probs.len() != cfg.vocab_size {
106                        return Err(SeqError::ShapeMismatch {
107                            expected: cfg.vocab_size,
108                            got: log_probs.len(),
109                        });
110                    }
111
112                    for tok in 0..cfg.vocab_size {
113                        let raw = log_probs[tok];
114                        let diversity_pen =
115                            Self::hamming_penalty(tok, &prev_group_tokens, cfg.diversity_strength);
116                        let candidate_score = cum_score + raw - diversity_pen;
117                        let mut new_tokens = tokens.clone();
118                        new_tokens.push(tok);
119                        let norm_score = Self::length_norm(
120                            candidate_score,
121                            new_tokens.len(),
122                            cfg.length_norm_alpha,
123                        );
124                        // Store (hyp_idx, tok, norm_score for ranking, extended tokens).
125                        candidates.push((hyp_idx, tok, norm_score, new_tokens));
126                    }
127                }
128
129                // Sort by adjusted normalised score, descending.
130                candidates
131                    .sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
132
133                // Select top beam_per_group survivors, deduplicated by sequence.
134                let mut new_beam: Vec<(Vec<usize>, f32)> = Vec::with_capacity(beam_per_group);
135                let mut new_finished: Vec<bool> = Vec::with_capacity(beam_per_group);
136                let mut tokens_chosen_by_group: Vec<usize> = Vec::new();
137
138                for (_, tok, _, new_tokens) in candidates.iter() {
139                    if new_beam.len() >= beam_per_group {
140                        break;
141                    }
142                    // Recompute the actual cumulative (unnormalised) score for storage.
143                    // We need to recover cum_score + raw for proper accumulation.
144                    // Since we sorted on normalised+diversity-adjusted, extract raw score.
145                    // Easier: store unnormalised in a parallel structure.
146                    let is_eos = *tok == cfg.eos_id;
147                    new_beam.push((new_tokens.clone(), 0.0));
148                    new_finished.push(is_eos);
149                    tokens_chosen_by_group.push(*tok);
150                }
151
152                // We need unnormalised cumulative scores for correct accumulation.
153                // Redo: collect candidates with both norm_score (for ranking) and raw cumsum.
154                // Overwrite new_beam with proper cumulative scores by re-expanding.
155                let mut candidates2: Vec<(usize, usize, f32, f32, Vec<usize>)> = Vec::new();
156                for hyp_idx in 0..beam_per_group {
157                    let (ref tokens, cum_score) = groups[g][hyp_idx].clone();
158                    if finished[g][hyp_idx] {
159                        candidates2.push((
160                            hyp_idx,
161                            cfg.eos_id,
162                            cum_score,
163                            cum_score,
164                            tokens.clone(),
165                        ));
166                        continue;
167                    }
168                    let log_probs = score_fn(tokens);
169                    for tok in 0..cfg.vocab_size {
170                        let raw = log_probs[tok];
171                        let diversity_pen =
172                            Self::hamming_penalty(tok, &prev_group_tokens, cfg.diversity_strength);
173                        let new_cum = cum_score + raw - diversity_pen;
174                        let mut new_tokens = tokens.clone();
175                        new_tokens.push(tok);
176                        let norm_score =
177                            Self::length_norm(new_cum, new_tokens.len(), cfg.length_norm_alpha);
178                        candidates2.push((hyp_idx, tok, new_cum, norm_score, new_tokens));
179                    }
180                }
181                candidates2
182                    .sort_by(|a, b| b.3.partial_cmp(&a.3).unwrap_or(std::cmp::Ordering::Equal));
183
184                new_beam.clear();
185                new_finished.clear();
186                tokens_chosen_by_group.clear();
187
188                for (_, tok, new_cum, _, new_tokens) in candidates2.iter() {
189                    if new_beam.len() >= beam_per_group {
190                        break;
191                    }
192                    new_beam.push((new_tokens.clone(), *new_cum));
193                    new_finished.push(*tok == cfg.eos_id);
194                    if !finished[g]
195                        .get(new_beam.len().saturating_sub(1))
196                        .copied()
197                        .unwrap_or(false)
198                    {
199                        tokens_chosen_by_group.push(*tok);
200                    }
201                }
202
203                // Pad if score_fn returned fewer candidates than needed.
204                while new_beam.len() < beam_per_group {
205                    if let Some(first) = new_beam.first().cloned() {
206                        new_beam.push(first);
207                        new_finished.push(true);
208                    } else {
209                        break;
210                    }
211                }
212
213                prev_group_tokens.extend_from_slice(&tokens_chosen_by_group);
214                groups[g] = new_beam;
215                finished[g] = new_finished;
216            }
217
218            // Early exit when all beams in all groups are finished.
219            if finished.iter().all(|gf| gf.iter().all(|&f| f)) {
220                break;
221            }
222        }
223
224        // Collect all hypotheses across groups, in group order.
225        let mut result: Vec<Vec<usize>> = Vec::with_capacity(cfg.beam_width);
226        for g in 0..cfg.n_groups {
227            for hyp_idx in 0..beam_per_group {
228                result.push(groups[g][hyp_idx].0.clone());
229            }
230        }
231        Ok(result)
232    }
233
234    /// Hamming diversity penalty for `token`.
235    ///
236    /// Counts how many elements of `prev_group_tokens` equal `token`, then
237    /// multiplies by `strength`.  This penalises tokens already chosen by
238    /// earlier groups at the current decoding step.
239    #[inline]
240    pub fn hamming_penalty(token: usize, prev_group_tokens: &[usize], strength: f32) -> f32 {
241        if strength == 0.0 {
242            return 0.0;
243        }
244        let count = prev_group_tokens.iter().filter(|&&t| t == token).count();
245        strength * count as f32
246    }
247
248    /// GNMT-style length normalisation: score / ((5 + len) / 6)^alpha.
249    ///
250    /// When `alpha == 0.0` the denominator is 1.0 (no normalisation).
251    #[inline]
252    pub fn length_norm(score: f32, len: usize, alpha: f32) -> f32 {
253        if alpha == 0.0 || len == 0 {
254            return score;
255        }
256        let denom = ((5.0 + len as f32) / 6.0).powf(alpha);
257        score / denom
258    }
259
260    /// Return the top-k (token, score) pairs sorted by score descending.
261    pub fn top_k(log_probs: &[f32], k: usize) -> Vec<(usize, f32)> {
262        let k = k.min(log_probs.len());
263        let mut indexed: Vec<(usize, f32)> = log_probs.iter().copied().enumerate().collect();
264        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
265        indexed.truncate(k);
266        indexed
267    }
268}
269
270// ─── inline tests ────────────────────────────────────────────────────────────
271
// Unit tests for `DiverseBeam`: construction checks, the pure helpers
// (`hamming_penalty`, `length_norm`, `top_k`) and end-to-end `search` runs.
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline config: 2 groups of 2 beams, 5-token vocabulary, EOS = 4,
    /// moderate diversity, no length normalisation.
    fn default_cfg() -> DiverseBeamConfig {
        DiverseBeamConfig {
            beam_width: 4,
            n_groups: 2,
            max_steps: 8,
            vocab_size: 5,
            eos_id: 4,
            diversity_strength: 0.5,
            length_norm_alpha: 0.0,
        }
    }

    /// Simple score function: prefer lower token indices (token 0 is best).
    /// The prefix is ignored; EOS (token 4) is heavily disfavoured.
    fn prefer_low(prefix: &[usize]) -> Vec<f32> {
        let _ = prefix;
        vec![-0.1, -0.5, -1.0, -2.0, -10.0]
    }

    /// Score function where EOS (token 4) scores 0.0 and every other token
    /// scores -100.0, regardless of the prefix.
    fn always_eos(_prefix: &[usize]) -> Vec<f32> {
        vec![-100.0, -100.0, -100.0, -100.0, 0.0]
    }

    #[test]
    fn diverse_beam_returns_b_sequences() {
        let db = DiverseBeam::new(default_cfg()).expect("ok");
        let seqs = db.search(prefer_low).expect("ok");
        assert_eq!(seqs.len(), 4, "expected beam_width sequences");
    }

    #[test]
    fn diverse_beam_sequences_differ() {
        let cfg_div = DiverseBeamConfig {
            diversity_strength: 1.0,
            ..default_cfg()
        };
        let cfg_nodiv = DiverseBeamConfig {
            diversity_strength: 0.0,
            ..default_cfg()
        };
        let db_div = DiverseBeam::new(cfg_div).expect("ok");
        let db_nodiv = DiverseBeam::new(cfg_nodiv).expect("ok");

        let seqs_div = db_div.search(prefer_low).expect("ok");
        let seqs_nodiv = db_nodiv.search(prefer_low).expect("ok");

        // Count distinct sequences.
        let distinct = |seqs: &[Vec<usize>]| {
            let mut s: Vec<Vec<usize>> = seqs.to_vec();
            s.sort();
            s.dedup();
            s.len()
        };
        // With diversity > 0, expect at least as many distinct sequences.
        let div_count = distinct(&seqs_div);
        let nodiv_count = distinct(&seqs_nodiv);
        assert!(
            div_count >= nodiv_count,
            "diverse search should produce >= distinct seqs: div={div_count} nodiv={nodiv_count}"
        );
    }

    #[test]
    fn hamming_penalty_zero_when_no_overlap() {
        let penalty = DiverseBeam::hamming_penalty(3, &[0, 1, 2], 1.0);
        assert_eq!(penalty, 0.0);
    }

    #[test]
    fn hamming_penalty_proportional_to_count() {
        // Token 1 appears three times, so the penalty is 3 * 2.0.
        let prev = vec![1usize, 1, 1, 2];
        let penalty = DiverseBeam::hamming_penalty(1, &prev, 2.0);
        assert!((penalty - 6.0).abs() < 1e-6, "penalty={penalty}");
    }

    #[test]
    fn length_norm_alpha_zero_is_identity() {
        let s = -3.0f32;
        assert!((DiverseBeam::length_norm(s, 5, 0.0) - s).abs() < 1e-6);
    }

    #[test]
    fn length_norm_longer_sequence_penalized() {
        let alpha = 0.6f32;
        // Use a positive score to verify that a longer sequence gets a smaller
        // normalised score (the denominator grows with length, so score/denom shrinks).
        let score = 10.0f32;
        let short = DiverseBeam::length_norm(score, 3, alpha);
        let long = DiverseBeam::length_norm(score, 10, alpha);
        assert!(
            long < short,
            "longer seq should have lower normalised score: short={short} long={long}"
        );
    }

    #[test]
    fn top_k_returns_k_items() {
        let probs = vec![-1.0f32, -0.5, -2.0, -0.1, -3.0];
        let top = DiverseBeam::top_k(&probs, 3);
        assert_eq!(top.len(), 3);
    }

    #[test]
    fn top_k_sorted_desc() {
        let probs = vec![-1.0f32, -0.5, -2.0, -0.1, -3.0];
        let top = DiverseBeam::top_k(&probs, 4);
        for w in top.windows(2) {
            assert!(w[0].1 >= w[1].1, "not sorted desc: {:?}", top);
        }
    }

    #[test]
    fn top_k_selects_highest_scores() {
        let probs = vec![-3.0f32, -0.1, -2.0, -5.0];
        let top = DiverseBeam::top_k(&probs, 1);
        assert_eq!(top[0].0, 1, "token 1 has max log-prob");
    }

    #[test]
    fn diverse_beam_empty_sequences_on_immediate_eos() {
        // EOS (token 4) always gets the best score, so every returned
        // sequence must be non-empty and end in EOS.
        let cfg = DiverseBeamConfig {
            beam_width: 2,
            n_groups: 1,
            max_steps: 5,
            vocab_size: 5,
            eos_id: 4,
            diversity_strength: 0.0,
            length_norm_alpha: 0.0,
        };
        let db = DiverseBeam::new(cfg).expect("ok");
        let seqs = db.search(always_eos).expect("ok");
        assert_eq!(seqs.len(), 2);
        for s in &seqs {
            assert!(!s.is_empty());
            assert_eq!(*s.last().expect("non-empty"), 4);
        }
    }

    #[test]
    fn diverse_beam_respects_max_steps() {
        // eos_id 99 lies outside the 3-token vocabulary, so no hypothesis can
        // finish early and the search runs for all `max_steps` steps.
        let cfg = DiverseBeamConfig {
            beam_width: 2,
            n_groups: 1,
            max_steps: 3,
            vocab_size: 3,
            eos_id: 99,
            diversity_strength: 0.0,
            length_norm_alpha: 0.0,
        };
        let db = DiverseBeam::new(cfg).expect("ok");
        let score_no_eos = |_: &[usize]| vec![-1.0f32, -2.0, -3.0];
        let seqs = db.search(score_no_eos).expect("ok");
        for s in &seqs {
            assert!(
                s.len() <= 3,
                "sequence longer than max_steps: len={}",
                s.len()
            );
        }
    }

    #[test]
    fn new_err_beam_not_divisible() {
        let mut cfg = default_cfg();
        cfg.beam_width = 5;
        cfg.n_groups = 2;
        assert!(matches!(
            DiverseBeam::new(cfg),
            Err(SeqError::InvalidConfiguration(_))
        ));
    }

    #[test]
    fn new_err_zero_groups() {
        let mut cfg = default_cfg();
        cfg.n_groups = 0;
        assert!(matches!(
            DiverseBeam::new(cfg),
            Err(SeqError::InvalidConfiguration(_))
        ));
    }

    #[test]
    fn new_err_zero_beam() {
        let mut cfg = default_cfg();
        cfg.beam_width = 0;
        assert!(matches!(
            DiverseBeam::new(cfg),
            Err(SeqError::InvalidConfiguration(_))
        ));
    }

    #[test]
    fn new_err_zero_vocab() {
        let mut cfg = default_cfg();
        cfg.vocab_size = 0;
        assert!(matches!(
            DiverseBeam::new(cfg),
            Err(SeqError::InvalidConfiguration(_))
        ));
    }

    #[test]
    fn n_groups_1_matches_standard_beam_top_choice() {
        // With n_groups=1 and no diversity, the top beam should greedily prefer
        // the highest-probability token at each step.
        let cfg = DiverseBeamConfig {
            beam_width: 2,
            n_groups: 1,
            max_steps: 4,
            vocab_size: 3,
            eos_id: 99,
            diversity_strength: 0.0,
            length_norm_alpha: 0.0,
        };
        let db = DiverseBeam::new(cfg).expect("ok");
        let score_fn = |_: &[usize]| vec![-0.1f32, -1.0, -5.0];
        let seqs = db.search(score_fn).expect("ok");
        // Top beam should be all token-0.
        assert_eq!(seqs[0], vec![0, 0, 0, 0]);
    }

    #[test]
    fn diverse_beam_single_token_vocab() {
        // Token 0 is both the only token and EOS, so every hypothesis finishes
        // at the first step; the search must still return beam_width sequences.
        let cfg = DiverseBeamConfig {
            beam_width: 2,
            n_groups: 1,
            max_steps: 3,
            vocab_size: 1,
            eos_id: 0,
            diversity_strength: 0.0,
            length_norm_alpha: 0.0,
        };
        let db = DiverseBeam::new(cfg).expect("ok");
        let score_fn = |_: &[usize]| vec![0.0f32];
        let seqs = db.search(score_fn).expect("ok");
        assert_eq!(seqs.len(), 2);
        for s in &seqs {
            assert!(!s.is_empty(), "sequence must not be empty");
        }
    }

    #[test]
    fn diverse_beam_eos_in_group0_still_returns_full() {
        // After the empty prefix (or after EOS) only EOS has a finite score, so
        // group 0 finishes at the first step. Group 1 sees the same scores:
        // its Hamming penalty on EOS is finite, so EOS still beats the -inf
        // alternatives. The test only checks that beam_width sequences come back.
        let cfg = DiverseBeamConfig {
            beam_width: 4,
            n_groups: 2,
            max_steps: 5,
            vocab_size: 5,
            eos_id: 4,
            diversity_strength: 0.3,
            length_norm_alpha: 0.0,
        };
        let db = DiverseBeam::new(cfg).expect("ok");

        // NOTE(review): `call_count` is incremented but never asserted.
        let call_count = std::cell::Cell::new(0u32);
        let seqs = db
            .search(|prefix| {
                call_count.set(call_count.get() + 1);
                if prefix.is_empty() || prefix.last().copied() == Some(4) {
                    vec![
                        f32::NEG_INFINITY,
                        f32::NEG_INFINITY,
                        f32::NEG_INFINITY,
                        f32::NEG_INFINITY,
                        0.0,
                    ]
                } else {
                    vec![-0.1, -0.5, -1.0, -2.0, -10.0]
                }
            })
            .expect("ok");

        assert_eq!(seqs.len(), 4, "must return beam_width sequences");
    }
}