native_neural_network 0.3.1

A no_std Rust library for native neural networks (.rnn)
Documentation
use super::SamplingError;

/// Returns the index of the largest value in `probabilities`.
///
/// Ties resolve to the lowest index. Only the ordering of the values is used,
/// so the input need not be normalized: negative values and infinities are
/// accepted.
///
/// # Errors
///
/// * [`SamplingError::Empty`] if `probabilities` is empty.
/// * [`SamplingError::InvalidParameter`] if any element is NaN. NaN compares
///   false against everything, so without this check a NaN in the first slot
///   would silently be reported as the maximum.
pub fn argmax_sample_f32(probabilities: &[f32]) -> Result<usize, SamplingError> {
    let mut best_idx = 0usize;
    let mut best = *probabilities.first().ok_or(SamplingError::Empty)?;
    for (i, &p) in probabilities.iter().enumerate() {
        if p.is_nan() {
            return Err(SamplingError::InvalidParameter);
        }
        // Strict `>` keeps the earliest index on ties.
        if p > best {
            best = p;
            best_idx = i;
        }
    }
    Ok(best_idx)
}

/// Selects an index by accumulating `probabilities` in order and returning the
/// first entry at which the running sum reaches `threshold`.
///
/// The function is deterministic. It behaves as inverse-CDF sampling when the
/// caller supplies `threshold` as a uniform draw from `[0, 1]` (assumption about
/// callers, not enforced here). Weights need not sum exactly to 1.
///
/// Entries with zero probability are never selected. If the running sum never
/// reaches `threshold`, the last entry with positive probability is returned.
/// This happens with rounding shortfall, e.g. `threshold == 1.0` with weights
/// summing to 0.9999999.
///
/// # Errors
///
/// * [`SamplingError::Empty`] if `probabilities` is empty.
/// * [`SamplingError::InvalidParameter`] in any of these cases:
///   - `threshold` is not a finite value in `[0, 1]`.
///   - Any element is non-finite or negative. Every element is checked, not
///     just those before the selected index.
///   - All elements are zero, so there is nothing to select.
pub fn sample_from_cumulative_f32(
    probabilities: &[f32],
    threshold: f32,
) -> Result<usize, SamplingError> {
    if probabilities.is_empty() {
        return Err(SamplingError::Empty);
    }
    if !threshold.is_finite() || !(0.0..=1.0).contains(&threshold) {
        return Err(SamplingError::InvalidParameter);
    }
    // Validate the whole slice up front so acceptance does not depend on where
    // the threshold happens to land.
    if probabilities.iter().any(|&p| !p.is_finite() || p < 0.0) {
        return Err(SamplingError::InvalidParameter);
    }

    let mut cumulative = 0.0f32;
    let mut last_positive = None;
    for (i, &p) in probabilities.iter().enumerate() {
        // Zero-mass entries must never be chosen (e.g. leading zeros when
        // `threshold == 0.0`).
        if p <= 0.0 {
            continue;
        }
        cumulative += p;
        last_positive = Some(i);
        if cumulative >= threshold {
            return Ok(i);
        }
    }
    // Rounding shortfall: pick the last selectable entry; none means all zeros.
    last_positive.ok_or(SamplingError::InvalidParameter)
}

/// Returns the index of the largest value in `probabilities`.
///
/// Ties resolve to the lowest index. Only the ordering of the values is used,
/// so the input need not be normalized: negative values and infinities are
/// accepted.
///
/// # Errors
///
/// * [`SamplingError::Empty`] if `probabilities` is empty.
/// * [`SamplingError::InvalidParameter`] if any element is NaN. NaN compares
///   false against everything, so without this check a NaN in the first slot
///   would silently be reported as the maximum.
pub fn argmax_sample_f64(probabilities: &[f64]) -> Result<usize, SamplingError> {
    let mut best_idx = 0usize;
    let mut best = *probabilities.first().ok_or(SamplingError::Empty)?;
    for (i, &p) in probabilities.iter().enumerate() {
        if p.is_nan() {
            return Err(SamplingError::InvalidParameter);
        }
        // Strict `>` keeps the earliest index on ties.
        if p > best {
            best = p;
            best_idx = i;
        }
    }
    Ok(best_idx)
}

/// Selects an index by accumulating `probabilities` in order and returning the
/// first entry at which the running sum reaches `threshold`.
///
/// The function is deterministic. It behaves as inverse-CDF sampling when the
/// caller supplies `threshold` as a uniform draw from `[0, 1]` (assumption about
/// callers, not enforced here). Weights need not sum exactly to 1.
///
/// Entries with zero probability are never selected. If the running sum never
/// reaches `threshold`, the last entry with positive probability is returned.
/// This happens with rounding shortfall, e.g. `threshold == 1.0` with weights
/// summing to 0.9999999.
///
/// # Errors
///
/// * [`SamplingError::Empty`] if `probabilities` is empty.
/// * [`SamplingError::InvalidParameter`] in any of these cases:
///   - `threshold` is not a finite value in `[0, 1]`.
///   - Any element is non-finite or negative. Every element is checked, not
///     just those before the selected index.
///   - All elements are zero, so there is nothing to select.
pub fn sample_from_cumulative_f64(
    probabilities: &[f64],
    threshold: f64,
) -> Result<usize, SamplingError> {
    if probabilities.is_empty() {
        return Err(SamplingError::Empty);
    }
    if !threshold.is_finite() || !(0.0..=1.0).contains(&threshold) {
        return Err(SamplingError::InvalidParameter);
    }
    // Validate the whole slice up front so acceptance does not depend on where
    // the threshold happens to land.
    if probabilities.iter().any(|&p| !p.is_finite() || p < 0.0) {
        return Err(SamplingError::InvalidParameter);
    }

    let mut cumulative = 0.0f64;
    let mut last_positive = None;
    for (i, &p) in probabilities.iter().enumerate() {
        // Zero-mass entries must never be chosen (e.g. leading zeros when
        // `threshold == 0.0`).
        if p <= 0.0 {
            continue;
        }
        cumulative += p;
        last_positive = Some(i);
        if cumulative >= threshold {
            return Ok(i);
        }
    }
    // Rounding shortfall: pick the last selectable entry; none means all zeros.
    last_positive.ok_or(SamplingError::InvalidParameter)
}