use super::adamw::step_adamw;
use super::kind::{OptimizerError, OptimizerKind};
use super::sgd::step_sgd;

/// Returns the number of `f32` slots of per-parameter optimizer state that
/// `kind` requires: one per parameter for SGD (the momentum buffer), two per
/// parameter for AdamW (the first and second moment estimates). Returns
/// `None` if the AdamW state size would overflow `usize`.
pub fn optimizer_state_len(kind: OptimizerKind, param_count: usize) -> Option<usize> {
    match kind {
        OptimizerKind::Sgd { .. } => Some(param_count),
        OptimizerKind::AdamW { .. } => param_count.checked_mul(2),
    }
}
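
// A minimal sketch of how `optimizer_state_len` is meant to be consumed:
// callers allocate the zero-initialised state buffer once, before the
// training loop. `allocate_state` is a hypothetical helper shown only for
// illustration; it is not part of the existing API.
#[allow(dead_code)]
fn allocate_state(kind: OptimizerKind, param_count: usize) -> Option<Vec<f32>> {
    let len = optimizer_state_len(kind, param_count)?;
    Some(vec![0.0f32; len])
}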

/// Applies one in-place optimizer update to `params`, dispatching on `kind`.
///
/// `state` holds the per-parameter optimizer buffers sized according to
/// `optimizer_state_len`, and `step` is the current update count forwarded to
/// the AdamW update. Non-finite or non-positive learning rates and mismatched
/// buffer lengths are rejected up front.
pub fn apply_optimizer_step(
    kind: OptimizerKind,
    params: &mut [f32],
    grads: &[f32],
    state: &mut [f32],
    learning_rate: f32,
    step: u32,
) -> Result<(), OptimizerError> {
    if !learning_rate.is_finite() || learning_rate <= 0.0 {
        return Err(OptimizerError::InvalidHyperParams);
    }
    if params.len() != grads.len() {
        return Err(OptimizerError::ShapeMismatch);
    }
    match kind {
        OptimizerKind::Sgd { momentum, nesterov } => {
            // The SGD state (momentum buffer) is passed through unchanged;
            // any further length validation happens inside `step_sgd`.
            step_sgd(params, grads, state, learning_rate, momentum, nesterov)
        }
        OptimizerKind::AdamW { beta1, beta2, eps, weight_decay } => {
            // AdamW stores the first and second moment estimates back-to-back
            // in `state`, so it must hold exactly 2 * params.len() values.
            if state.len() != params.len().saturating_mul(2) {
                return Err(OptimizerError::ShapeMismatch);
            }
            let (m, v) = state.split_at_mut(params.len());
            let cfg = super::adamw::AdamwConfig {
                learning_rate,
                step,
                beta1,
                beta2,
                eps,
                weight_decay,
            };
            step_adamw(params, grads, m, v, cfg)
        }
    }
}
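
// A minimal usage sketch of the dispatch path, written as tests. It assumes
// the `OptimizerKind` field names and `f32`/`bool` field types visible in the
// match arms above; the hyper-parameter values are illustrative. Only coarse
// success/failure is asserted, since the numeric behaviour of `step_sgd` and
// `step_adamw` belongs to their own modules.
#[cfg(test)]
mod dispatch_tests {
    use super::*;

    #[test]
    fn adamw_step_accepts_correctly_sized_state() {
        let mut params = vec![0.5f32; 4];
        let grads = vec![0.1f32; 4];

        // AdamW needs 2 * params.len() slots: first moments, then second moments.
        let state_len = optimizer_state_len(
            OptimizerKind::AdamW { beta1: 0.9, beta2: 0.999, eps: 1e-8, weight_decay: 0.01 },
            params.len(),
        )
        .expect("state size overflowed usize");
        let mut state = vec![0.0f32; state_len];

        let result = apply_optimizer_step(
            OptimizerKind::AdamW { beta1: 0.9, beta2: 0.999, eps: 1e-8, weight_decay: 0.01 },
            &mut params,
            &grads,
            &mut state,
            1e-3,
            1,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn non_positive_learning_rate_is_rejected() {
        let mut params = vec![0.0f32; 2];
        let grads = vec![0.0f32; 2];
        let mut state = vec![0.0f32; 2];

        // A zero learning rate must be rejected before any dispatch happens.
        let result = apply_optimizer_step(
            OptimizerKind::Sgd { momentum: 0.9, nesterov: false },
            &mut params,
            &grads,
            &mut state,
            0.0,
            1,
        );
        assert!(result.is_err());
    }
}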