//! native_neural_network 0.3.1
//!
//! A `no_std` Rust library for native neural networks (`.rnn`).
use crate::engine::{
    forward_plan_kernel, forward_plan_kernel_f64, required_batch_scratch_len,
    required_batch_scratch_len_f64, ForwardError,
};
use crate::layers::{LayerPlan, LayerPlanF64};

/// Errors produced by the inference entry points in this module.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InferenceError {
    /// The plan could not report its input/output sizes, or softmax
    /// normalization degenerated (non-finite or non-positive exponential sum).
    InvalidPlan,
    /// Input/output slice lengths disagree with the expected shape.
    ShapeMismatch,
    /// Batch buffer lengths do not match `batch_size * dimension`, or that
    /// product overflowed `usize`.
    BatchMismatch,
    /// The caller-provided scratch buffer is smaller than the kernel requires.
    ScratchTooSmall,
}

/// Numerically stable softmax over `logits`, writing probabilities to `out`.
///
/// Subtracts the maximum logit before exponentiating so large logits do not
/// overflow `expf`. Tries the GPU path first and falls back to the CPU loop.
///
/// # Errors
/// - [`InferenceError::ShapeMismatch`] if `logits` is empty or the slices
///   differ in length.
/// - [`InferenceError::InvalidPlan`] if the exponential sum is non-finite or
///   non-positive (e.g. NaN logits).
pub fn softmax_stable_f32(logits: &[f32], out: &mut [f32]) -> Result<(), InferenceError> {
    if logits.is_empty() || out.len() != logits.len() {
        return Err(InferenceError::ShapeMismatch);
    }

    // Fast path: the GPU kernel performs its own stabilization.
    if crate::engine::try_invoke_gpu_softmax_f32(logits, out) {
        return Ok(());
    }

    // Explicit comparison (not `f32::max`) keeps the original NaN behavior:
    // a NaN never replaces the running maximum.
    let mut max_v = logits[0];
    for &value in &logits[1..] {
        if value > max_v {
            max_v = value;
        }
    }

    // `zip` removes the per-element bounds checks of the indexed loop while
    // preserving accumulation order exactly.
    let mut sum = 0.0f32;
    for (dst, &logit) in out.iter_mut().zip(logits) {
        let e = crate::math::expf(logit - max_v);
        *dst = e;
        sum += e;
    }
    if !sum.is_finite() || sum <= 0.0 {
        return Err(InferenceError::InvalidPlan);
    }

    // Multiply by the reciprocal once instead of dividing per element.
    let inv_sum = 1.0f32 / sum;
    for value in out {
        *value *= inv_sum;
    }

    Ok(())
}

/// Numerically stable softmax over `logits`, writing probabilities to `out`.
///
/// `f64` counterpart of [`softmax_stable_f32`]: subtracts the maximum logit
/// before exponentiating, trying the GPU path first.
///
/// # Errors
/// - [`InferenceError::ShapeMismatch`] if `logits` is empty or the slices
///   differ in length.
/// - [`InferenceError::InvalidPlan`] if the exponential sum is non-finite or
///   non-positive (e.g. NaN logits).
pub fn softmax_stable_f64(logits: &[f64], out: &mut [f64]) -> Result<(), InferenceError> {
    if logits.is_empty() || out.len() != logits.len() {
        return Err(InferenceError::ShapeMismatch);
    }

    // Fast path: the GPU kernel performs its own stabilization.
    if crate::engine::try_invoke_gpu_softmax_f64(logits, out) {
        return Ok(());
    }

    // Explicit comparison (not `f64::max`) keeps the original NaN behavior:
    // a NaN never replaces the running maximum.
    let mut max_v = logits[0];
    for &value in &logits[1..] {
        if value > max_v {
            max_v = value;
        }
    }

    // `zip` removes the per-element bounds checks of the indexed loop while
    // preserving accumulation order exactly.
    let mut sum = 0.0f64;
    for (dst, &logit) in out.iter_mut().zip(logits) {
        let e = crate::math::expd(logit - max_v);
        *dst = e;
        sum += e;
    }
    if !sum.is_finite() || sum <= 0.0 {
        return Err(InferenceError::InvalidPlan);
    }

    // Multiply by the reciprocal once instead of dividing per element.
    let inv_sum = 1.0f64 / sum;
    for value in out {
        *value *= inv_sum;
    }

    Ok(())
}

/// Runs a batched `f32` forward pass through `plan`.
///
/// Validates batch geometry (with overflow-checked size products) and scratch
/// capacity before dispatching to the shared kernel; kernel failures are
/// translated into [`InferenceError`] via [`map_forward_error`].
pub fn forward_batch_f32(
    plan: &LayerPlan<'_>,
    input_batch: &[f32],
    output_batch: &mut [f32],
    batch_size: usize,
    scratch_batch: &mut [f32],
) -> Result<(), InferenceError> {
    // A plan that cannot report its shapes cannot be executed.
    let in_dim = plan.input_size().ok_or(InferenceError::InvalidPlan)?;
    let out_dim = plan.output_size().ok_or(InferenceError::InvalidPlan)?;

    // Overflowing size products and mismatched buffer lengths are both
    // reported as batch mismatches.
    let sizes = (batch_size.checked_mul(in_dim), batch_size.checked_mul(out_dim));
    match sizes {
        (Some(need_in), Some(need_out))
            if input_batch.len() == need_in && output_batch.len() == need_out => {}
        _ => return Err(InferenceError::BatchMismatch),
    }

    // Scratch must cover the kernel's working-set requirement.
    let needed =
        required_batch_scratch_len(plan, batch_size).ok_or(InferenceError::ScratchTooSmall)?;
    if scratch_batch.len() < needed {
        return Err(InferenceError::ScratchTooSmall);
    }

    forward_plan_kernel(plan, input_batch, output_batch, batch_size, scratch_batch)
        .map_err(map_forward_error)
}

/// Runs a batched `f64` forward pass through `plan`.
///
/// `f64` counterpart of [`forward_batch_f32`]: validates batch geometry (with
/// overflow-checked size products) and scratch capacity before dispatching to
/// the shared kernel; kernel failures map through [`map_forward_error`].
pub fn forward_batch_f64(
    plan: &LayerPlanF64<'_>,
    input_batch: &[f64],
    output_batch: &mut [f64],
    batch_size: usize,
    scratch_batch: &mut [f64],
) -> Result<(), InferenceError> {
    // A plan that cannot report its shapes cannot be executed.
    let in_dim = plan.input_size().ok_or(InferenceError::InvalidPlan)?;
    let out_dim = plan.output_size().ok_or(InferenceError::InvalidPlan)?;

    // Overflowing size products and mismatched buffer lengths are both
    // reported as batch mismatches.
    let sizes = (batch_size.checked_mul(in_dim), batch_size.checked_mul(out_dim));
    match sizes {
        (Some(need_in), Some(need_out))
            if input_batch.len() == need_in && output_batch.len() == need_out => {}
        _ => return Err(InferenceError::BatchMismatch),
    }

    // Scratch must cover the kernel's working-set requirement.
    let needed =
        required_batch_scratch_len_f64(plan, batch_size).ok_or(InferenceError::ScratchTooSmall)?;
    if scratch_batch.len() < needed {
        return Err(InferenceError::ScratchTooSmall);
    }

    forward_plan_kernel_f64(plan, input_batch, output_batch, batch_size, scratch_batch)
        .map_err(map_forward_error)
}

/// Translates engine-level [`ForwardError`] variants into the public
/// [`InferenceError`] type exposed by this module (a 1:1 mapping; the
/// exhaustive match forces an update here if `ForwardError` grows).
fn map_forward_error(err: ForwardError) -> InferenceError {
    match err {
        ForwardError::InvalidPlan => InferenceError::InvalidPlan,
        ForwardError::ShapeMismatch => InferenceError::ShapeMismatch,
        ForwardError::ScratchTooSmall => InferenceError::ScratchTooSmall,
    }
}