native_neural_network 0.1.6

A no_std Rust library for native neural networks (.rnn).
Documentation
use super::SamplingError;

/// Returns the index of the largest value in `probabilities`.
///
/// Ties resolve to the lowest index. The values are not required to sum to 1,
/// and infinities are compared like any other value.
///
/// # Errors
///
/// - [`SamplingError::Empty`] if `probabilities` is empty.
/// - [`SamplingError::InvalidParameter`] if any value is NaN. NaN compares
///   false against everything, so a NaN would otherwise silently decide the
///   result (e.g. a NaN in slot 0 would always "win").
pub fn argmax_sample(probabilities: &[f32]) -> Result<usize, SamplingError> {
    if probabilities.is_empty() {
        return Err(SamplingError::Empty);
    }
    if probabilities.iter().any(|p| p.is_nan()) {
        return Err(SamplingError::InvalidParameter);
    }
    let mut idx = 0usize;
    let mut best = probabilities[0];
    for (i, &p) in probabilities.iter().enumerate().skip(1) {
        // Strict `>` keeps the first index on ties.
        if p > best {
            best = p;
            idx = i;
        }
    }
    Ok(idx)
}

/// Selects an index by walking the running sum of `probabilities`.
///
/// Returns the first index with non-zero probability at which the running sum
/// reaches `threshold` (inclusive). Zero-probability entries are never
/// selected. If the total mass ends up below `threshold` (floating-point
/// rounding, or an input that does not quite sum to 1), the last index with
/// non-zero probability is returned.
///
/// # Errors
///
/// - [`SamplingError::Empty`] if `probabilities` is empty.
/// - [`SamplingError::InvalidParameter`] if `threshold` is not a finite value
///   in `[0, 1]`, if any entry is NaN, infinite or negative, or if every entry
///   is zero (there is nothing to sample).
pub fn sample_from_cumulative(probabilities: &[f32], threshold: f32) -> Result<usize, SamplingError> {
    if probabilities.is_empty() {
        return Err(SamplingError::Empty);
    }
    if !threshold.is_finite() || !(0.0..=1.0).contains(&threshold) {
        return Err(SamplingError::InvalidParameter);
    }
    // Validate the whole slice up front. This makes a malformed distribution
    // fail regardless of where the (typically random) threshold falls.
    if probabilities.iter().any(|&p| !p.is_finite() || p < 0.0) {
        return Err(SamplingError::InvalidParameter);
    }

    let mut cumulative = 0.0f32;
    let mut last_positive = None;
    for (i, &p) in probabilities.iter().enumerate() {
        // Skipping zero-mass entries keeps them from being chosen, e.g. a
        // leading zero when `threshold == 0.0`.
        if p > 0.0 {
            cumulative += p;
            if cumulative >= threshold {
                return Ok(i);
            }
            last_positive = Some(i);
        }
    }
    // Rounding left the sum short of `threshold`: fall back to the last
    // outcome that actually has probability mass, never a zero-mass tail entry.
    last_positive.ok_or(SamplingError::InvalidParameter)
}