native_neural_network 0.3.1

A no_std Rust library for native neural networks (.rnn)
Documentation
use super::adamw::{step_adamw_f32, step_adamw_f64, AdamwConfig, AdamwConfigF64};
use super::kind::{OptimizerError, OptimizerKind, OptimizerKindF64};
use super::sgd::{step_sgd, step_sgd_f64};

/// Returns the required length of the optimizer `state` buffer for
/// `param_count` parameters: one momentum slot per parameter for SGD,
/// two slots per parameter (first and second moments) for AdamW.
/// Returns `None` if the AdamW state size would overflow `usize`.
pub fn optimizer_state_len(kind: OptimizerKind, param_count: usize) -> Option<usize> {
    match kind {
        OptimizerKind::Sgd { .. } => Some(param_count),
        OptimizerKind::AdamW { .. } => param_count.checked_mul(2),
    }
}
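// A minimal sizing sketch. The AdamW field values are illustrative
// (the field names come from the match arms in this module), and a
// fixed-size buffer stands in for heap allocation, since the crate
// is no_std:
//
//     let kind = OptimizerKind::AdamW {
//         beta1: 0.9, beta2: 0.999, eps: 1e-8, weight_decay: 0.01,
//     };
//     let needed = optimizer_state_len(kind, 128).expect("state size overflows usize");
//     assert_eq!(needed, 256); // two slots (m and v) per parameter
//     let mut state = [0.0f32; 256];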

/// `f64` counterpart of [`optimizer_state_len`].
pub fn optimizer_state_len_f64(kind: OptimizerKindF64, param_count: usize) -> Option<usize> {
    match kind {
        OptimizerKindF64::Sgd { .. } => Some(param_count),
        OptimizerKindF64::AdamW { .. } => param_count.checked_mul(2),
    }
}

/// Applies one optimizer update to `params` in place, reading `grads`
/// and updating the persistent `state` buffer (sized per
/// [`optimizer_state_len`]). `step` is the current optimization step,
/// forwarded to AdamW for bias correction.
pub fn apply_optimizer_step(
    kind: OptimizerKind,
    params: &mut [f32],
    grads: &[f32],
    state: &mut [f32],
    learning_rate: f32,
    step: u32,
) -> Result<(), OptimizerError> {
    // Reject non-finite or non-positive learning rates before touching any state.
    if !learning_rate.is_finite() || learning_rate <= 0.0 {
        return Err(OptimizerError::InvalidHyperParams);
    }
    if params.len() != grads.len() {
        return Err(OptimizerError::ShapeMismatch);
    }

    match kind {
        OptimizerKind::Sgd { momentum, nesterov } => {
            // SGD keeps one momentum slot per parameter in `state`.
            step_sgd(params, grads, state, learning_rate, momentum, nesterov)
        }
        OptimizerKind::AdamW {
            beta1,
            beta2,
            eps,
            weight_decay,
        } => {
            // AdamW needs two state slots per parameter.
            if state.len() != params.len().saturating_mul(2) {
                return Err(OptimizerError::ShapeMismatch);
            }
            // First half of `state` holds the first moments (m),
            // second half the second moments (v).
            let (m, v) = state.split_at_mut(params.len());
            let cfg = AdamwConfig {
                learning_rate,
                step,
                beta1,
                beta2,
                eps,
                weight_decay,
            };
            step_adamw_f32(params, grads, m, v, cfg)
        }
    }
}
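// A minimal training-loop sketch for the f32 path. All values are
// illustrative; gradient computation is elided, and `kind` is rebuilt
// each iteration to sidestep whether `OptimizerKind` is `Copy`.
// Starting `step` at 1 follows the usual Adam bias-correction
// convention, though nothing in this module enforces it:
//
//     let mut params = [0.5f32; 128];
//     let mut state = [0.0f32; 256]; // m in the first half, v in the second
//     for step in 1..=1000u32 {
//         let grads = [0.0f32; 128]; // ...filled in by backprop...
//         let kind = OptimizerKind::AdamW {
//             beta1: 0.9, beta2: 0.999, eps: 1e-8, weight_decay: 0.01,
//         };
//         apply_optimizer_step(kind, &mut params, &grads, &mut state, 1e-3, step)
//             .expect("valid hyperparameters and matching shapes");
//     }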

/// `f64` counterpart of [`apply_optimizer_step`].
pub fn apply_optimizer_step_f64(
    kind: OptimizerKindF64,
    params: &mut [f64],
    grads: &[f64],
    state: &mut [f64],
    learning_rate: f64,
    step: u32,
) -> Result<(), OptimizerError> {
    if !learning_rate.is_finite() || learning_rate <= 0.0 {
        return Err(OptimizerError::InvalidHyperParams);
    }
    if params.len() != grads.len() {
        return Err(OptimizerError::ShapeMismatch);
    }

    match kind {
        OptimizerKindF64::Sgd { momentum, nesterov } => {
            step_sgd_f64(params, grads, state, learning_rate, momentum, nesterov)
        }
        OptimizerKindF64::AdamW {
            beta1,
            beta2,
            eps,
            weight_decay,
        } => {
            if state.len() != params.len().saturating_mul(2) {
                return Err(OptimizerError::ShapeMismatch);
            }
            // Same state layout as the f32 path: m first, then v.
            let (m, v) = state.split_at_mut(params.len());
            let cfg = AdamwConfigF64 {
                learning_rate,
                step,
                beta1,
                beta2,
                eps,
                weight_decay,
            };
            step_adamw_f64(params, grads, m, v, cfg)
        }
    }
}
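
// Smoke-test sketch. It assumes `Sgd`'s fields are `momentum: f32` and
// `nesterov: bool` (the names come from the match patterns above; the
// types are an assumption), that `step_sgd` with zero momentum applies
// the plain update `p -= lr * g`, and that the crate enables `std`
// under `cfg(test)`, as is common for no_std libraries.
#[cfg(test)]
mod apply_step_sketch {
    use super::*;

    #[test]
    fn plain_sgd_moves_params_against_the_gradient() {
        let kind = OptimizerKind::Sgd { momentum: 0.0, nesterov: false };
        let mut params = [1.0f32, -2.0];
        let grads = [0.5f32, -0.5];
        let mut state = [0.0f32; 2]; // one momentum slot per parameter
        apply_optimizer_step(kind, &mut params, &grads, &mut state, 0.1, 1)
            .expect("valid hyperparameters and matching shapes");
        // With lr = 0.1 and zero momentum, each param moves by -lr * grad.
        assert!((params[0] - 0.95).abs() < 1e-6);
        assert!((params[1] + 1.95).abs() < 1e-6);
    }

    #[test]
    fn rejects_a_non_positive_learning_rate() {
        let kind = OptimizerKind::Sgd { momentum: 0.0, nesterov: false };
        let mut params = [1.0f32];
        let grads = [0.5f32];
        let mut state = [0.0f32];
        let result = apply_optimizer_step(kind, &mut params, &grads, &mut state, 0.0, 1);
        assert!(matches!(result, Err(OptimizerError::InvalidHyperParams)));
    }
}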