native_neural_network 0.1.6

A `no_std` Rust library for native neural networks (`.rnn` format).
Documentation
use super::adamw::step_adamw;
use super::kind::{OptimizerError, OptimizerKind};
use super::sgd::step_sgd;

/// Returns how many `f32` state slots an optimizer needs for `param_count` parameters.
///
/// SGD reserves one slot per parameter; AdamW reserves two per parameter (the state slice
/// is later split in half by `apply_optimizer_step`).
///
/// Returns `None` if the required length overflows `usize`.
pub fn optimizer_state_len(kind: OptimizerKind, param_count: usize) -> Option<usize> {
    let slots_per_param: usize = match kind {
        OptimizerKind::Sgd { .. } => 1,
        OptimizerKind::AdamW { .. } => 2,
    };
    param_count.checked_mul(slots_per_param)
}

/// Applies a single in-place optimizer update to `params` using `grads`.
///
/// `state` carries the optimizer's buffers between calls; its expected length can be
/// computed with `optimizer_state_len`.
///
/// - For AdamW, `state` must be exactly `2 * params.len()` long. The first `params.len()`
///   entries are passed to `step_adamw` as `m` and the remaining entries as `v`.
/// - For SGD, `state` is forwarded to `step_sgd` as-is. Any length validation for SGD
///   happens inside `step_sgd` (not visible here).
///
/// `step` is only forwarded to the AdamW update via `AdamwConfig`.
/// NOTE(review): presumably the step counter used for bias correction; confirm against
/// `step_adamw`.
///
/// # Errors
///
/// - `OptimizerError::InvalidHyperParams` if `learning_rate` is NaN, infinite, or not
///   strictly positive.
/// - `OptimizerError::ShapeMismatch` if `params` and `grads` differ in length, or (AdamW
///   only) if `state.len() != 2 * params.len()`.
/// - Any error propagated from `step_sgd` / `step_adamw`.
pub fn apply_optimizer_step(
    kind: OptimizerKind,
    params: &mut [f32],
    grads: &[f32],
    state: &mut [f32],
    learning_rate: f32,
    step: u32,
) -> Result<(), OptimizerError> {
    // `is_finite` is false for NaN and +/-inf, so those are rejected along with lr <= 0.
    let lr_is_usable = learning_rate.is_finite() && learning_rate > 0.0;
    if !lr_is_usable {
        return Err(OptimizerError::InvalidHyperParams);
    }

    let n = params.len();
    if grads.len() != n {
        return Err(OptimizerError::ShapeMismatch);
    }

    match kind {
        OptimizerKind::AdamW { beta1, beta2, eps, weight_decay } => {
            // Two slots per parameter: the state slice is split into `m` then `v`.
            let expected = n.saturating_mul(2);
            if state.len() != expected {
                return Err(OptimizerError::ShapeMismatch);
            }
            let (m, v) = state.split_at_mut(n);
            step_adamw(
                params,
                grads,
                m,
                v,
                super::adamw::AdamwConfig {
                    learning_rate,
                    step,
                    beta1,
                    beta2,
                    eps,
                    weight_decay,
                },
            )
        }
        OptimizerKind::Sgd { momentum, nesterov } => {
            step_sgd(params, grads, state, learning_rate, momentum, nesterov)
        }
    }
}