native_neural_network_std 0.2.1

Ergonomic `std` wrapper around the `no_std` `native_neural_network` crate, providing std-friendly re-exports and utilities.
Documentation
/// Error returned by the optimizer helpers in this crate.
///
/// Each variant mirrors the same-named variant of
/// `native_neural_network::optimizers::OptimizerError`; the upstream error is
/// converted into this type via `From`.
///
/// The type is a plain field-less enum, so it is cheap to copy and can be
/// compared directly (e.g. `assert_eq!(res, Err(OptimizerStdError::ShapeMismatch))`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OptimizerStdError {
    /// Mirrors upstream `OptimizerError::InvalidHyperParams`
    /// (presumably hyperparameters rejected by upstream validation — TODO confirm).
    InvalidHyperParams,
    /// Mirrors upstream `OptimizerError::ShapeMismatch`
    /// (presumably the params/grads/state slice lengths disagree — TODO confirm).
    ShapeMismatch,
    /// Mirrors upstream `OptimizerError::StepOverflow`
    /// (presumably a step-counter overflow upstream — TODO confirm).
    StepOverflow,
}

impl From<native_neural_network::optimizers::OptimizerError> for OptimizerStdError {
    /// Maps each upstream `OptimizerError` variant onto the identically named
    /// `OptimizerStdError` variant; no information is added or dropped.
    fn from(err: native_neural_network::optimizers::OptimizerError) -> Self {
        use native_neural_network::optimizers::OptimizerError as Upstream;

        match err {
            Upstream::InvalidHyperParams => Self::InvalidHyperParams,
            Upstream::ShapeMismatch => Self::ShapeMismatch,
            Upstream::StepOverflow => Self::StepOverflow,
        }
    }
}

/// Optimizer algorithm together with its hyperparameters.
///
/// Mirrors `native_neural_network::optimizers::OptimizerKind` field-for-field;
/// `From` conversions are provided in both directions. Every field is passed
/// through to the upstream optimizer unchanged.
///
/// `PartialEq` is derived so configurations can be compared. `Eq` is not,
/// because the fields are `f32`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum OptimizerKind {
    /// Stochastic gradient descent.
    Sgd {
        /// Momentum coefficient, forwarded to upstream as-is.
        momentum: f32,
        /// Nesterov flag, forwarded to upstream as-is.
        nesterov: bool,
    },
    /// AdamW optimizer.
    AdamW {
        /// Forwarded to upstream as-is (conventionally the first-moment decay rate).
        beta1: f32,
        /// Forwarded to upstream as-is (conventionally the second-moment decay rate).
        beta2: f32,
        /// Forwarded to upstream as-is (conventionally the denominator stabilizer).
        eps: f32,
        /// Forwarded to upstream as-is (conventionally the decoupled weight-decay factor).
        weight_decay: f32,
    },
}

impl From<OptimizerKind> for native_neural_network::optimizers::OptimizerKind {
    fn from(k: OptimizerKind) -> Self {
        match k {
            OptimizerKind::Sgd { momentum, nesterov } => {
                native_neural_network::optimizers::OptimizerKind::Sgd { momentum, nesterov }
            }
            OptimizerKind::AdamW {
                beta1,
                beta2,
                eps,
                weight_decay,
            } => native_neural_network::optimizers::OptimizerKind::AdamW {
                beta1,
                beta2,
                eps,
                weight_decay,
            },
        }
    }
}

impl From<native_neural_network::optimizers::OptimizerKind> for OptimizerKind {
    fn from(k: native_neural_network::optimizers::OptimizerKind) -> Self {
        match k {
            native_neural_network::optimizers::OptimizerKind::Sgd { momentum, nesterov } => {
                OptimizerKind::Sgd { momentum, nesterov }
            }
            native_neural_network::optimizers::OptimizerKind::AdamW {
                beta1,
                beta2,
                eps,
                weight_decay,
            } => OptimizerKind::AdamW {
                beta1,
                beta2,
                eps,
                weight_decay,
            },
        }
    }
}

/// Asks the upstream `native_neural_network::optimizers::optimizer_state_len`
/// how much optimizer state `kind` needs for `param_len` parameters.
///
/// Presumably the result is the length of the `state` slice to pass to
/// [`apply_optimizer_step`]. `None` is returned exactly when upstream returns
/// it (TODO confirm the upstream conditions, e.g. overflow).
pub fn optimizer_state_len(kind: OptimizerKind, param_len: usize) -> Option<usize> {
    let upstream_kind = native_neural_network::optimizers::OptimizerKind::from(kind);
    native_neural_network::optimizers::optimizer_state_len(upstream_kind, param_len)
}

/// Performs one optimizer update in place by delegating to the upstream
/// `native_neural_network::optimizers::apply_optimizer_step`.
///
/// `params` is updated using `grads`, and `state` is the optimizer's mutable
/// state buffer (presumably sized via [`optimizer_state_len`]). `learning_rate`
/// and `step` are forwarded to upstream unchanged.
///
/// # Errors
/// Any upstream `OptimizerError` is converted into the matching
/// [`OptimizerStdError`] variant.
pub fn apply_optimizer_step(
    kind: OptimizerKind,
    params: &mut [f32],
    grads: &[f32],
    state: &mut [f32],
    learning_rate: f32,
    step: u32,
) -> Result<(), OptimizerStdError> {
    let upstream_kind = native_neural_network::optimizers::OptimizerKind::from(kind);
    // `?` performs the OptimizerError -> OptimizerStdError conversion via `From`.
    native_neural_network::optimizers::apply_optimizer_step(
        upstream_kind,
        params,
        grads,
        state,
        learning_rate,
        step,
    )?;
    Ok(())
}

/// Applies one SGD update to `params` from `grads`, using `velocity` as the
/// optimizer state buffer.
///
/// This is a convenience over [`apply_optimizer_step`] with
/// `OptimizerKind::Sgd { momentum, nesterov }`.
///
/// # Errors
/// Propagates any [`OptimizerStdError`] from [`apply_optimizer_step`].
pub fn step_sgd(
    params: &mut [f32],
    grads: &[f32],
    velocity: &mut [f32],
    learning_rate: f32,
    momentum: f32,
    nesterov: bool,
) -> Result<(), OptimizerStdError> {
    // The step counter is always 1 here. Presumably upstream SGD does not
    // depend on it (TODO confirm).
    const SGD_STEP: u32 = 1;

    apply_optimizer_step(
        OptimizerKind::Sgd { momentum, nesterov },
        params,
        grads,
        velocity,
        learning_rate,
        SGD_STEP,
    )
}

/// Bundled AdamW hyperparameters accepted by [`step_adamw`].
///
/// Each field is copied unchanged into the corresponding field of
/// `OptimizerKind::AdamW`. `PartialEq` is derived so parameter sets can be
/// compared. `Eq` is not, because the fields are `f32`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct AdamWParams {
    /// Conventionally the first-moment decay rate; interpreted by upstream.
    pub beta1: f32,
    /// Conventionally the second-moment decay rate; interpreted by upstream.
    pub beta2: f32,
    /// Conventionally the numerical-stability epsilon; interpreted by upstream.
    pub eps: f32,
    /// Conventionally the decoupled weight-decay factor; interpreted by upstream.
    pub weight_decay: f32,
}

pub fn step_adamw(
    params: &mut [f32],
    grads: &[f32],
    state: &mut [f32],
    learning_rate: f32,
    step: u32,
    opts: AdamWParams,
) -> Result<(), OptimizerStdError> {
    let kind = OptimizerKind::AdamW {
        beta1: opts.beta1,
        beta2: opts.beta2,
        eps: opts.eps,
        weight_decay: opts.weight_decay,
    };
    apply_optimizer_step(kind, params, grads, state, learning_rate, step)
}

impl core::fmt::Display for OptimizerStdError {
    /// Renders the error as `OptimizerStdError::<Variant>`, where `<Variant>`
    /// is the derived `Debug` name of the variant.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("OptimizerStdError::")?;
        write!(f, "{:?}", self)
    }
}

/// Marks `OptimizerStdError` as a standard error type. The default
/// `source()` applies, so no underlying cause is reported.
impl std::error::Error for OptimizerStdError {}