/// Error type returned by the optimizer wrappers in this module.
///
/// Each variant is produced from the identically named variant of
/// `native_neural_network::optimizers::OptimizerError` by the `From` conversion in this module.
/// The extra derives let callers copy errors and compare them directly
/// (e.g. `assert_eq!(err, OptimizerStdError::ShapeMismatch)`) or use them as map/set keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum OptimizerStdError {
    /// Mapped from the native `OptimizerError::InvalidHyperParams`.
    InvalidHyperParams,
    /// Mapped from the native `OptimizerError::ShapeMismatch`.
    ShapeMismatch,
    /// Mapped from the native `OptimizerError::StepOverflow`.
    StepOverflow,
}
impl From<native_neural_network::optimizers::OptimizerError> for OptimizerStdError {
    /// Translates a native optimizer error into the identically named wrapper variant.
    fn from(err: native_neural_network::optimizers::OptimizerError) -> Self {
        use native_neural_network::optimizers::OptimizerError as Native;

        match err {
            Native::InvalidHyperParams => Self::InvalidHyperParams,
            Native::ShapeMismatch => Self::ShapeMismatch,
            Native::StepOverflow => Self::StepOverflow,
        }
    }
}
/// Optimizer algorithm selection together with its hyper-parameters.
///
/// Mirrors `native_neural_network::optimizers::OptimizerKind` field for field. The `From`
/// conversions in this module forward every value unchanged, so this wrapper performs no
/// range checking. Validation, if any, is left to the native crate.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum OptimizerKind {
    /// Stochastic gradient descent, forwarded to the native `Sgd` variant.
    Sgd {
        /// Momentum coefficient; not range-checked by this wrapper.
        momentum: f32,
        /// Nesterov flag; its exact effect is defined by the native implementation.
        nesterov: bool,
    },
    /// AdamW, forwarded to the native `AdamW` variant.
    AdamW {
        /// AdamW `beta1` hyper-parameter, forwarded unchanged.
        beta1: f32,
        /// AdamW `beta2` hyper-parameter, forwarded unchanged.
        beta2: f32,
        /// AdamW epsilon term, forwarded unchanged.
        eps: f32,
        /// AdamW weight-decay coefficient, forwarded unchanged.
        weight_decay: f32,
    },
}
impl From<OptimizerKind> for native_neural_network::optimizers::OptimizerKind {
fn from(k: OptimizerKind) -> Self {
match k {
OptimizerKind::Sgd { momentum, nesterov } => {
native_neural_network::optimizers::OptimizerKind::Sgd { momentum, nesterov }
}
OptimizerKind::AdamW {
beta1,
beta2,
eps,
weight_decay,
} => native_neural_network::optimizers::OptimizerKind::AdamW {
beta1,
beta2,
eps,
weight_decay,
},
}
}
}
impl From<native_neural_network::optimizers::OptimizerKind> for OptimizerKind {
fn from(k: native_neural_network::optimizers::OptimizerKind) -> Self {
match k {
native_neural_network::optimizers::OptimizerKind::Sgd { momentum, nesterov } => {
OptimizerKind::Sgd { momentum, nesterov }
}
native_neural_network::optimizers::OptimizerKind::AdamW {
beta1,
beta2,
eps,
weight_decay,
} => OptimizerKind::AdamW {
beta1,
beta2,
eps,
weight_decay,
},
}
}
}
/// Asks the native backend how long the optimizer `state` buffer must be for an optimizer of
/// `kind` over `param_len` parameters.
///
/// The result is returned exactly as the native `optimizer_state_len` produces it. When that
/// yields `None` is defined by the native crate (presumably an unrepresentable length — confirm).
pub fn optimizer_state_len(kind: OptimizerKind, param_len: usize) -> Option<usize> {
    let native_kind: native_neural_network::optimizers::OptimizerKind = kind.into();
    native_neural_network::optimizers::optimizer_state_len(native_kind, param_len)
}
pub fn apply_optimizer_step(
kind: OptimizerKind,
params: &mut [f32],
grads: &[f32],
state: &mut [f32],
learning_rate: f32,
step: u32,
) -> Result<(), OptimizerStdError> {
native_neural_network::optimizers::apply_optimizer_step(
kind.into(),
params,
grads,
state,
learning_rate,
step,
)
.map_err(|e| e.into())
}
/// Convenience wrapper: performs one SGD step via [`apply_optimizer_step`], using `velocity` as
/// the optimizer state buffer.
///
/// The step counter is always forwarded as `1`.
/// NOTE(review): whether that matters depends on how the native SGD path uses `step`. If it
/// special-cases the first step (e.g. to seed the velocity), momentum would never carry over
/// between calls. Confirm against the native implementation.
///
/// # Errors
///
/// Propagates any [`OptimizerStdError`] returned by [`apply_optimizer_step`].
pub fn step_sgd(
    params: &mut [f32],
    grads: &[f32],
    velocity: &mut [f32],
    learning_rate: f32,
    momentum: f32,
    nesterov: bool,
) -> Result<(), OptimizerStdError> {
    const FIXED_STEP: u32 = 1;

    apply_optimizer_step(
        OptimizerKind::Sgd { momentum, nesterov },
        params,
        grads,
        velocity,
        learning_rate,
        FIXED_STEP,
    )
}
/// Bundle of AdamW hyper-parameters accepted by `step_adamw`.
///
/// Each field is copied unchanged into `OptimizerKind::AdamW`. No defaults or range checks are
/// applied here. `PartialEq` is derived so configurations can be compared and asserted on.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct AdamWParams {
    /// AdamW `beta1` hyper-parameter.
    pub beta1: f32,
    /// AdamW `beta2` hyper-parameter.
    pub beta2: f32,
    /// AdamW epsilon term.
    pub eps: f32,
    /// AdamW weight-decay coefficient.
    pub weight_decay: f32,
}
pub fn step_adamw(
params: &mut [f32],
grads: &[f32],
state: &mut [f32],
learning_rate: f32,
step: u32,
opts: AdamWParams,
) -> Result<(), OptimizerStdError> {
let kind = OptimizerKind::AdamW {
beta1: opts.beta1,
beta2: opts.beta2,
eps: opts.eps,
weight_decay: opts.weight_decay,
};
apply_optimizer_step(kind, params, grads, state, learning_rate, step)
}
impl core::fmt::Display for OptimizerStdError {
    /// Renders the error as `OptimizerStdError::<Variant>`, e.g. `OptimizerStdError::ShapeMismatch`.
    ///
    /// Width, fill, and alignment flags on the outer formatter are not applied.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let variant = match self {
            Self::InvalidHyperParams => "InvalidHyperParams",
            Self::ShapeMismatch => "ShapeMismatch",
            Self::StepOverflow => "StepOverflow",
        };
        f.write_str("OptimizerStdError::")?;
        f.write_str(variant)
    }
}
impl std::error::Error for OptimizerStdError {}