native_neural_network 0.3.1

A `no_std` Rust library for native neural networks (.rnn).
Documentation
/// Failure modes of beam selection (returned, for example, by [`select_top_beams`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BeamError {
    /// A buffer is too small for the request; `select_top_beams` returns this
    /// when `out_indices` is shorter than the beam width.
    ShapeMismatch,
    /// Nothing can be selected; `select_top_beams` returns this when the score
    /// slice is empty or the beam width is zero.
    Empty,
}

/// Selects the indices of the `beam_width` highest-scoring entries of `scores`.
///
/// On success, the first `filled` slots of `out_indices` hold the selected
/// indices ordered from best to worst, where
/// `filled = min(beam_width, scores.len())`. Slots at and beyond `filled` are
/// left untouched. `filled` is returned.
///
/// Ranking rules:
/// * higher scores rank first;
/// * `NaN` scores rank below every other score, so a `NaN` never blocks a
///   real score from entering the beam;
/// * equal scores are ordered by ascending index, so among ties the lowest
///   index is kept and the output order is deterministic.
///
/// The beam is seeded with the first `filled` indices. Each remaining score
/// is then compared against the current worst beam member, which is
/// `O(scores.len() * beam_width)` work.
///
/// # Errors
///
/// * [`BeamError::Empty`] if `scores` is empty or `beam_width` is zero.
/// * [`BeamError::ShapeMismatch`] if `out_indices` is shorter than `beam_width`.
pub fn select_top_beams(
    scores: &[f32],
    beam_width: usize,
    out_indices: &mut [usize],
) -> Result<usize, BeamError> {
    if scores.is_empty() || beam_width == 0 {
        return Err(BeamError::Empty);
    }
    if out_indices.len() < beam_width {
        return Err(BeamError::ShapeMismatch);
    }

    let filled = beam_width.min(scores.len());
    let beam = &mut out_indices[..filled];

    // Total order over score indices: `Less` means `a` is the better beam.
    // Both the selection loop and `sort_unstable_by` need a genuine total
    // order. Mapping NaN comparisons to `Equal` is not one, and it also let a
    // NaN in the beam block every replacement.
    let rank = |a: usize, b: usize| -> core::cmp::Ordering {
        use core::cmp::Ordering;
        let (sa, sb) = (scores[a], scores[b]);
        let by_score = match (sa.is_nan(), sb.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Neither value is NaN, so `partial_cmp` always yields `Some`.
            (false, false) => sb.partial_cmp(&sa).unwrap_or(Ordering::Equal),
        };
        // Break score ties by index so distinct indices never compare equal.
        by_score.then(a.cmp(&b))
    };

    // Seed the beam with the first `filled` indices.
    for (slot, idx) in beam.iter_mut().zip(0..) {
        *slot = idx;
    }

    for candidate in filled..scores.len() {
        // `filled >= 1` here (non-empty scores, non-zero width), so a worst
        // member always exists; `unwrap_or(0)` only avoids a panic path.
        let worst_slot = (0..filled)
            .max_by(|&x, &y| rank(beam[x], beam[y]))
            .unwrap_or(0);
        if rank(candidate, beam[worst_slot]) == core::cmp::Ordering::Less {
            beam[worst_slot] = candidate;
        }
    }

    beam.sort_unstable_by(|&a, &b| rank(a, b));

    Ok(filled)
}