native_neural_network 0.1.6

A `no_std` Rust library for native neural networks (.rnn).
Documentation
use super::{softmax_stable, InferenceError};

/// Overwrites `logits` with the output of [`softmax_stable`] (presumably a numerically
/// stable softmax — see that function for its exact semantics).
///
/// The result is first written into a local 4096-element `f32` scratch array (16 KiB).
/// It is copied back into `logits` only after [`softmax_stable`] succeeds, so no heap
/// allocation is needed.
///
/// # Errors
///
/// - [`InferenceError::ShapeMismatch`] if `logits` is empty or has more than 4096 elements.
/// - Any error returned by [`softmax_stable`] is propagated unchanged. In that case this
///   function performs no write to `logits` itself.
pub fn normalize_logits_in_place(logits: &mut [f32]) -> Result<(), InferenceError> {
    // Capacity of the fixed-size scratch array; inputs longer than this are rejected.
    const MAX_LOGITS: usize = 4096;

    let len = logits.len();
    if len == 0 || len > MAX_LOGITS {
        return Err(InferenceError::ShapeMismatch);
    }

    let mut scratch = [0.0f32; MAX_LOGITS];
    let probs = &mut scratch[..len];
    softmax_stable(logits, probs)?;
    logits.copy_from_slice(probs);
    Ok(())
}

/// Returns the index of the largest value in `logits`.
///
/// - Returns `None` only when `logits` is empty.
/// - NaN entries are ignored. A leading NaN would otherwise make every later `>`
///   comparison false and "win" by default.
/// - If every entry is NaN, `Some(0)` is returned, so a non-empty slice always yields `Some`.
/// - On ties, the lowest index is returned.
pub fn argmax_index(logits: &[f32]) -> Option<usize> {
    if logits.is_empty() {
        return None;
    }

    let mut best: Option<(usize, f32)> = None;
    for (i, &v) in logits.iter().enumerate() {
        // NaN compares false against everything; skip it instead of letting it poison `best`.
        if v.is_nan() {
            continue;
        }
        match best {
            // Strictly greater is required to replace, so the earliest maximum is kept.
            Some((_, b)) if v <= b => {}
            _ => best = Some((i, v)),
        }
    }

    Some(best.map_or(0, |(i, _)| i))
}