rnn/beam_search/
/// Errors reported by [`select_top_beams`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BeamError {
    /// The output index buffer is shorter than the requested beam width.
    ShapeMismatch,
    /// Nothing to select: the score slice is empty or the beam width is zero.
    Empty,
}

impl core::fmt::Display for BeamError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let msg = match *self {
            BeamError::ShapeMismatch => "output index buffer is shorter than the beam width",
            BeamError::Empty => "scores are empty or beam width is zero",
        };
        f.write_str(msg)
    }
}
6
7pub fn select_top_beams(scores: &[f32], beam_width: usize, out_indices: &mut [usize]) -> Result<usize, BeamError> {
8 if scores.is_empty() || beam_width == 0 {
9 return Err(BeamError::Empty);
10 }
11 if out_indices.len() < beam_width {
12 return Err(BeamError::ShapeMismatch);
13 }
14
15 let filled = beam_width.min(scores.len());
16
17 for (slot, idx) in out_indices.iter_mut().take(filled).zip(0..filled) {
18 *slot = idx;
19 }
20
21 for i in filled..scores.len() {
22 let mut worst_slot = 0usize;
23 for s in 1..filled {
24 if scores[out_indices[s]] < scores[out_indices[worst_slot]] {
25 worst_slot = s;
26 }
27 }
28
29 if scores[i] > scores[out_indices[worst_slot]] {
30 out_indices[worst_slot] = i;
31 }
32 }
33
34 out_indices[..filled].sort_unstable_by(|&a, &b| {
35 scores[b].partial_cmp(&scores[a]).unwrap_or(core::cmp::Ordering::Equal)
36 });
37
38 Ok(filled)
39}