use super::adamw::{step_adamw_f32, step_adamw_f64, AdamwConfigF64};
use super::kind::{OptimizerError, OptimizerKind};
use super::sgd::{step_sgd, step_sgd_f64};
/// Returns the number of `f32` state slots required for `param_count`
/// parameters under the given optimizer, or `None` on overflow.
///
/// SGD keeps one momentum slot per parameter; AdamW keeps two (first and
/// second moment estimates).
pub fn optimizer_state_len(kind: OptimizerKind, param_count: usize) -> Option<usize> {
    let slots_per_param: usize = match kind {
        OptimizerKind::Sgd { .. } => 1,
        OptimizerKind::AdamW { .. } => 2,
    };
    // Multiplying by 1 cannot overflow, so the SGD arm always yields
    // `Some(param_count)`, matching the AdamW arm's overflow behavior.
    param_count.checked_mul(slots_per_param)
}
/// Returns the number of `f64` state slots required for `param_count`
/// parameters under the given optimizer, or `None` on overflow.
///
/// Mirrors [`optimizer_state_len`]: one slot per parameter for SGD, two
/// for AdamW.
pub fn optimizer_state_len_f64(
    kind: super::kind::OptimizerKindF64,
    param_count: usize,
) -> Option<usize> {
    use super::kind::OptimizerKindF64;
    let slots_per_param: usize = match kind {
        OptimizerKindF64::Sgd { .. } => 1,
        OptimizerKindF64::AdamW { .. } => 2,
    };
    // `checked_mul(1)` is always `Some(param_count)`, so SGD never fails.
    param_count.checked_mul(slots_per_param)
}
/// Applies one f32 optimizer update to `params` in place.
///
/// # Errors
/// - [`OptimizerError::InvalidHyperParams`] when `learning_rate` is not a
///   finite, strictly positive number.
/// - [`OptimizerError::ShapeMismatch`] when `grads` does not match `params`
///   in length, or (for AdamW) when `state` is not exactly twice as long as
///   `params` (the packed `m`/`v` moment buffers).
pub fn apply_optimizer_step(
    kind: OptimizerKind,
    params: &mut [f32],
    grads: &[f32],
    state: &mut [f32],
    learning_rate: f32,
    step: u32,
) -> Result<(), OptimizerError> {
    // Validation common to every optimizer kind. NaN learning rates fail
    // the `is_finite` test, so they are rejected here as well.
    if learning_rate <= 0.0 || !learning_rate.is_finite() {
        return Err(OptimizerError::InvalidHyperParams);
    }
    if grads.len() != params.len() {
        return Err(OptimizerError::ShapeMismatch);
    }
    match kind {
        OptimizerKind::Sgd { momentum, nesterov } => {
            // SGD state validation is delegated to `step_sgd`.
            step_sgd(params, grads, state, learning_rate, momentum, nesterov)
        }
        OptimizerKind::AdamW {
            beta1,
            beta2,
            eps,
            weight_decay,
        } => {
            let n = params.len();
            // AdamW packs the first (m) and second (v) moment buffers
            // back-to-back; `saturating_mul` keeps the comparison safe if
            // `2 * n` would overflow.
            if state.len() != n.saturating_mul(2) {
                return Err(OptimizerError::ShapeMismatch);
            }
            let (m, v) = state.split_at_mut(n);
            step_adamw_f32(
                params,
                grads,
                m,
                v,
                super::adamw::AdamwConfig {
                    learning_rate,
                    step,
                    beta1,
                    beta2,
                    eps,
                    weight_decay,
                },
            )
        }
    }
}
/// Applies one f64 optimizer update to `params` in place.
///
/// Mirrors [`apply_optimizer_step`] for double-precision parameters.
///
/// # Errors
/// - [`OptimizerError::InvalidHyperParams`] when `learning_rate` is not a
///   finite, strictly positive number.
/// - [`OptimizerError::ShapeMismatch`] when `grads` does not match `params`
///   in length, or (for AdamW) when `state` is not exactly twice as long as
///   `params` (the packed `m`/`v` moment buffers).
pub fn apply_optimizer_step_f64(
    kind: super::kind::OptimizerKindF64,
    params: &mut [f64],
    grads: &[f64],
    state: &mut [f64],
    learning_rate: f64,
    step: u32,
) -> Result<(), OptimizerError> {
    use super::kind::OptimizerKindF64;
    // Shared validation; NaN fails `is_finite` and is rejected here too.
    if learning_rate <= 0.0 || !learning_rate.is_finite() {
        return Err(OptimizerError::InvalidHyperParams);
    }
    if grads.len() != params.len() {
        return Err(OptimizerError::ShapeMismatch);
    }
    match kind {
        OptimizerKindF64::Sgd { momentum, nesterov } => {
            // SGD state validation is delegated to `step_sgd_f64`.
            step_sgd_f64(params, grads, state, learning_rate, momentum, nesterov)
        }
        OptimizerKindF64::AdamW {
            beta1,
            beta2,
            eps,
            weight_decay,
        } => {
            let n = params.len();
            // `state` holds m then v, back-to-back; `saturating_mul` keeps
            // the length comparison safe near `usize::MAX`.
            if state.len() != n.saturating_mul(2) {
                return Err(OptimizerError::ShapeMismatch);
            }
            let (m, v) = state.split_at_mut(n);
            step_adamw_f64(
                params,
                grads,
                m,
                v,
                AdamwConfigF64 {
                    learning_rate,
                    step,
                    beta1,
                    beta2,
                    eps,
                    weight_decay,
                },
            )
        }
    }
}