// rnn/beam_search/beam_search.rs

/// Reasons [`select_top_beams`] can reject its arguments.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BeamError {
    /// The output buffer has fewer slots than the requested beam width.
    ShapeMismatch,
    /// There are no scores to choose from, or the requested beam width is zero.
    Empty,
}
6
7pub fn select_top_beams(scores: &[f32], beam_width: usize, out_indices: &mut [usize]) -> Result<usize, BeamError> {
8    if scores.is_empty() || beam_width == 0 {
9        return Err(BeamError::Empty);
10    }
11    if out_indices.len() < beam_width {
12        return Err(BeamError::ShapeMismatch);
13    }
14
15    let filled = beam_width.min(scores.len());
16
17    for (slot, idx) in out_indices.iter_mut().take(filled).zip(0..filled) {
18        *slot = idx;
19    }
20
21    for i in filled..scores.len() {
22        let mut worst_slot = 0usize;
23        for s in 1..filled {
24            if scores[out_indices[s]] < scores[out_indices[worst_slot]] {
25                worst_slot = s;
26            }
27        }
28
29        if scores[i] > scores[out_indices[worst_slot]] {
30            out_indices[worst_slot] = i;
31        }
32    }
33
34    out_indices[..filled].sort_unstable_by(|&a, &b| {
35        scores[b].partial_cmp(&scores[a]).unwrap_or(core::cmp::Ordering::Equal)
36    });
37
38    Ok(filled)
39}