native_neural_network 0.1.6

A no_std Rust library for native neural networks (.rnn).
Documentation
use super::kind::OptimizerError;

/// Applies one in-place step of stochastic gradient descent with momentum
/// (optionally Nesterov momentum).
///
/// For every index `i`:
///
/// - the velocity buffer is updated as `velocity[i] = momentum * velocity[i] + grads[i]`;
/// - the parameter is then moved by `params[i] -= learning_rate * update`.
///
/// `update` is the freshly updated velocity for classic momentum. With `nesterov == true`, it
/// is the look-ahead `momentum * velocity[i] + grads[i]`.
///
/// # Errors
///
/// - [`OptimizerError::InvalidHyperParams`] if `learning_rate` is not finite, or if `momentum`
///   is not finite or lies outside the half-open range `[0.0, 1.0)`.
/// - [`OptimizerError::ShapeMismatch`] if `params`, `grads` and `velocity` do not all have the
///   same length.
///
/// Validation happens before any write, so on error neither `params` nor `velocity` is modified.
pub fn step_sgd(
    params: &mut [f32],
    grads: &[f32],
    velocity: &mut [f32],
    learning_rate: f32,
    momentum: f32,
    nesterov: bool,
) -> Result<(), OptimizerError> {
    // A NaN/inf learning rate would silently poison every parameter, so reject it up front
    // just like an invalid momentum.
    if !learning_rate.is_finite() || !momentum.is_finite() || !(0.0..1.0).contains(&momentum) {
        return Err(OptimizerError::InvalidHyperParams);
    }
    if params.len() != grads.len() || velocity.len() != params.len() {
        return Err(OptimizerError::ShapeMismatch);
    }

    // Lengths are known equal here, so `zip` visits every element exactly once.
    for ((param, &grad), vel) in params.iter_mut().zip(grads).zip(velocity.iter_mut()) {
        *vel = momentum * *vel + grad;
        let update = if nesterov { momentum * *vel + grad } else { *vel };
        *param -= learning_rate * update;
    }
    Ok(())
}