//! native_neural_network_std 0.2.1
//!
//! Ergonomic `std` wrapper for the `no_std` crate `native_neural_network`:
//! std-friendly re-exports and utilities.
pub use native_neural_network::layers::LayerPlan;

/// Error returned by the inference helpers in this crate.
///
/// Each variant mirrors the identically named variant of
/// `native_neural_network::inference::InferenceError`, and the conversion from the
/// upstream error maps them one-to-one.
/// NOTE(review): the exact condition that triggers each variant is decided by the
/// upstream crate — confirm there before relying on the per-variant notes below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InferenceStdError {
    /// Upstream reported `InvalidPlan`: the layer plan was rejected.
    InvalidPlan,
    /// Upstream reported `ShapeMismatch` (presumably a buffer length that does not
    /// fit the plan — confirm upstream).
    ShapeMismatch,
    /// Upstream reported `BatchMismatch` (presumably buffers inconsistent with the
    /// requested batch size — confirm upstream).
    BatchMismatch,
    /// Upstream reported `ScratchTooSmall`: the supplied scratch buffer was too small.
    ScratchTooSmall,
}

/// Converts the upstream `no_std` error into [`InferenceStdError`], mapping every
/// variant to the variant of the same name.
///
/// The wrappers in this crate rely on this impl to turn upstream failures into
/// `InferenceStdError` via `.into()`.
impl From<native_neural_network::inference::InferenceError> for InferenceStdError {
    fn from(err: native_neural_network::inference::InferenceError) -> Self {
        // Short local alias for the upstream type. The match is deliberately
        // exhaustive (no `_` arm) so a new upstream variant is a compile error here.
        type Upstream = native_neural_network::inference::InferenceError;
        match err {
            Upstream::InvalidPlan => Self::InvalidPlan,
            Upstream::ShapeMismatch => Self::ShapeMismatch,
            Upstream::BatchMismatch => Self::BatchMismatch,
            Upstream::ScratchTooSmall => Self::ScratchTooSmall,
        }
    }
}

pub fn softmax(logits: &[f32], out: &mut [f32]) -> Result<(), InferenceStdError> {
    native_neural_network::inference::softmax_stable_f32(logits, out).map_err(|e| e.into())
}

/// Softmax of `logits`, written into `out`, computed by the upstream
/// `native_neural_network::inference::softmax_stable_f32` routine.
///
/// The upstream name indicates a numerically stable formulation.
/// NOTE(review): confirm the exact algorithm and the required `out` length upstream.
///
/// # Errors
///
/// Any upstream failure is converted into the matching [`InferenceStdError`] variant.
pub fn softmax_stable(logits: &[f32], out: &mut [f32]) -> Result<(), InferenceStdError> {
    native_neural_network::inference::softmax_stable_f32(logits, out)?;
    Ok(())
}

pub fn forward_batch(
    plan: &LayerPlan<'_>,
    input_batch: &[f32],
    output_batch: &mut [f32],
    batch_size: usize,
    scratch_batch: &mut [f32],
) -> Result<(), InferenceStdError> {
    native_neural_network::inference::forward_batch_f32(
        plan,
        input_batch,
        output_batch,
        batch_size,
        scratch_batch,
    )
    .map_err(|e| e.into())
}

pub fn normalize_logits_in_place(logits: &mut [f32]) -> Result<(), InferenceStdError> {
    native_neural_network::inference::normalize_logits_in_place_f32(logits).map_err(|e| e.into())
}

/// Index of the largest entry in `logits`, as computed by the upstream
/// `native_neural_network::inference::argmax_index_f32`.
///
/// Returns `None` when upstream finds no maximum (presumably for empty input —
/// confirm upstream, including how ties and NaN are handled).
#[must_use]
pub fn argmax_index(logits: &[f32]) -> Option<usize> {
    native_neural_network::inference::argmax_index_f32(logits)
}

impl core::fmt::Display for InferenceStdError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "InferenceStdError::{:?}", self)
    }
}

impl std::error::Error for InferenceStdError {}