native_neural_network 0.1.6

A `no_std` Rust library for native neural networks (`.rnn` format).
Documentation
use crate::engine::{forward_dense_plan_big_kernel, required_batch_scratch_len, ForwardError};
use crate::layers::LayerPlan;

/// Errors reported by the inference helpers (`softmax_stable`, `forward_dense_batch`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InferenceError {
    /// The layer plan is unusable: it could not report its input/output size, or the engine
    /// rejected it. `softmax_stable` also returns this when the sum of exponentials is not
    /// finite and positive.
    InvalidPlan,
    /// A slice length does not match the shape the operation expects (e.g. `softmax_stable`
    /// given an empty input or mismatched `out`), or the engine reported a shape mismatch.
    ShapeMismatch,
    /// A batch buffer's length differs from `batch_size` times the plan's per-sample size,
    /// or that product overflows `usize`.
    BatchMismatch,
    /// The scratch buffer is shorter than required, or the requirement could not be computed.
    ScratchTooSmall,
}

// Human-readable messages so the error can be reported without relying on `Debug` output.
// Uses `core::fmt` so it works in `no_std` builds.
impl core::fmt::Display for InferenceError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match self {
            InferenceError::InvalidPlan => "invalid plan",
            InferenceError::ShapeMismatch => "shape mismatch",
            InferenceError::BatchMismatch => "batch size mismatch",
            InferenceError::ScratchTooSmall => "scratch buffer too small",
        };
        f.write_str(message)
    }
}

/// Numerically stable softmax of `logits`, written into `out`.
///
/// Each output is `expf(x_i - m) / sum_j expf(x_j - m)`, where `m` is the largest logit
/// (found with `>`, so the first of several equal maxima is used). Subtracting `m` keeps
/// every exponent argument `<= 0` for finite inputs.
///
/// # Errors
///
/// - [`InferenceError::ShapeMismatch`] if `logits` is empty or `out.len() != logits.len()`.
///   `out` is not touched in this case.
/// - [`InferenceError::InvalidPlan`] if the sum of exponentials is not finite or not
///   positive (e.g. a NaN or infinite logit poisoned it). In this case `out` has already
///   been overwritten with the unnormalised exponentials.
pub fn softmax_stable(logits: &[f32], out: &mut [f32]) -> Result<(), InferenceError> {
    // Guard: need at least one logit and an output slot for each of them.
    let first = match logits.first() {
        Some(&v) if out.len() == logits.len() => v,
        _ => return Err(InferenceError::ShapeMismatch),
    };

    // Strict `>` keeps the earliest maximum and never lets a later NaN replace it.
    let peak = logits
        .iter()
        .skip(1)
        .copied()
        .fold(first, |best, x| if x > best { x } else { best });

    // Store shifted exponentials and accumulate their sum left to right.
    let mut total = 0.0f32;
    for (slot, &x) in out.iter_mut().zip(logits) {
        let e = crate::math::expf(x - peak);
        *slot = e;
        total += e;
    }

    // Refuse to normalise by a NaN/inf/non-positive denominator.
    if !(total.is_finite() && total > 0.0) {
        return Err(InferenceError::InvalidPlan);
    }

    let scale = 1.0 / total;
    out.iter_mut().for_each(|p| *p *= scale);

    Ok(())
}

/// Runs the dense layer plan over a whole batch in a single kernel call.
///
/// Buffer requirements:
/// - `input_batch.len()` must equal `batch_size * plan.input_size()`;
/// - `output_batch.len()` must equal `batch_size * plan.output_size()`;
/// - `scratch_batch.len()` must be at least `required_batch_scratch_len(plan, batch_size)`.
///
/// NOTE(review): samples are presumably stored back to back in the batch buffers;
/// confirm against `forward_dense_plan_big_kernel`.
///
/// # Errors
///
/// - `InvalidPlan` if the plan cannot report its input or output size.
/// - `BatchMismatch` if either `batch_size * size` product overflows `usize`, or a batch
///   buffer's length differs from it.
/// - `ScratchTooSmall` if the scratch requirement cannot be computed or exceeds
///   `scratch_batch.len()`.
/// - Any error from the kernel, translated by `map_forward_error`.
pub fn forward_dense_batch(
    plan: &LayerPlan<'_>,
    input_batch: &[f32],
    output_batch: &mut [f32],
    batch_size: usize,
    scratch_batch: &mut [f32],
) -> Result<(), InferenceError> {
    // Per-sample widths; a plan that cannot report them is unusable.
    let per_sample_in = plan.input_size().ok_or(InferenceError::InvalidPlan)?;
    let per_sample_out = plan.output_size().ok_or(InferenceError::InvalidPlan)?;

    // Whole-batch lengths, treating overflow as a batch-size problem.
    let total_in = match batch_size.checked_mul(per_sample_in) {
        Some(n) => n,
        None => return Err(InferenceError::BatchMismatch),
    };
    let total_out = match batch_size.checked_mul(per_sample_out) {
        Some(n) => n,
        None => return Err(InferenceError::BatchMismatch),
    };

    let io_lengths_match = input_batch.len() == total_in && output_batch.len() == total_out;
    if !io_lengths_match {
        return Err(InferenceError::BatchMismatch);
    }

    // Scratch must be computable and large enough before the kernel touches it.
    match required_batch_scratch_len(plan, batch_size) {
        Some(needed) if needed <= scratch_batch.len() => {}
        _ => return Err(InferenceError::ScratchTooSmall),
    }

    forward_dense_plan_big_kernel(plan, input_batch, output_batch, batch_size, scratch_batch)
        .map_err(map_forward_error)?;

    Ok(())
}

/// Translates engine-level [`ForwardError`]s into the matching [`InferenceError`] variant.
///
/// Implemented as `From` so `?` can convert a `ForwardError` automatically wherever an
/// `InferenceError` is expected.
impl From<ForwardError> for InferenceError {
    fn from(err: ForwardError) -> Self {
        match err {
            ForwardError::InvalidPlan => InferenceError::InvalidPlan,
            ForwardError::ShapeMismatch => InferenceError::ShapeMismatch,
            ForwardError::ScratchTooSmall => InferenceError::ScratchTooSmall,
        }
    }
}

/// Maps a kernel error to an `InferenceError`; kept for use with `Result::map_err`.
/// Delegates to the `From<ForwardError>` impl so there is a single mapping table.
fn map_forward_error(err: ForwardError) -> InferenceError {
    InferenceError::from(err)
}