use super::kind::OptimizerError;
/// Applies one SGD-with-momentum step in place to `f32` parameters.
///
/// For every element the update is:
///
/// ```text
/// v      = momentum * velocity + grad
/// update = if nesterov { momentum * v + grad } else { v }
/// param -= learning_rate * update
/// ```
///
/// and `velocity` is overwritten with `v`.
///
/// After validation the step is first offered to
/// `crate::engine::try_invoke_gpu_sgd_f32`; when that call returns `true` this
/// function returns `Ok(())` immediately and the CPU loop below is skipped.
/// NOTE(review): presumably the engine path applies the same update rule — confirm
/// against the engine implementation.
///
/// # Errors
///
/// * [`OptimizerError::InvalidHyperParams`] if `momentum` is not finite or lies
///   outside `[0, 1)`, or if `learning_rate` is not finite.
/// * [`OptimizerError::ShapeMismatch`] if `params`, `grads` and `velocity` do not
///   all have the same length.
///
/// All validation happens before any slice is written or handed to the engine.
pub fn step_sgd_f32(
    params: &mut [f32],
    grads: &[f32],
    velocity: &mut [f32],
    learning_rate: f32,
    momentum: f32,
    nesterov: bool,
) -> Result<(), OptimizerError> {
    // A NaN or infinite learning rate would silently turn every parameter into
    // NaN/inf, so it is rejected together with an out-of-range momentum.
    let momentum_ok = momentum.is_finite() && (0.0..1.0).contains(&momentum);
    if !momentum_ok || !learning_rate.is_finite() {
        return Err(OptimizerError::InvalidHyperParams);
    }
    if params.len() != grads.len() || velocity.len() != params.len() {
        return Err(OptimizerError::ShapeMismatch);
    }

    if crate::engine::try_invoke_gpu_sgd_f32(
        params,
        grads,
        velocity,
        learning_rate,
        momentum,
        nesterov,
    ) {
        return Ok(());
    }

    // CPU fallback. The lengths were checked above, so zipping visits every element.
    for ((param, vel), &grad) in params.iter_mut().zip(velocity.iter_mut()).zip(grads) {
        let next_vel = momentum * *vel + grad;
        *vel = next_vel;
        let update = if nesterov {
            momentum * next_vel + grad
        } else {
            next_vel
        };
        *param -= learning_rate * update;
    }
    Ok(())
}
/// Default-precision entry point for the SGD step.
///
/// Forwards every argument, unchanged and in the same order, to
/// [`step_sgd_f32`] and returns whatever that call returns — including any
/// [`OptimizerError`] it reports.
pub fn step_sgd(
    params: &mut [f32],
    grads: &[f32],
    velocity: &mut [f32],
    learning_rate: f32,
    momentum: f32,
    nesterov: bool,
) -> Result<(), OptimizerError> {
    // Pure delegation: no validation or arithmetic happens at this level.
    step_sgd_f32(
        params,
        grads,
        velocity,
        learning_rate,
        momentum,
        nesterov,
    )
}
/// Applies one SGD-with-momentum step in place to `f64` parameters.
///
/// For every element the update is:
///
/// ```text
/// v      = momentum * velocity + grad
/// update = if nesterov { momentum * v + grad } else { v }
/// param -= learning_rate * update
/// ```
///
/// and `velocity` is overwritten with `v`.
///
/// After validation the step is first offered to
/// `crate::engine::try_invoke_gpu_sgd_f64`; when that call returns `true` this
/// function returns `Ok(())` immediately and the CPU loop below is skipped.
/// NOTE(review): presumably the engine path applies the same update rule — confirm
/// against the engine implementation.
///
/// # Errors
///
/// * [`OptimizerError::InvalidHyperParams`] if `momentum` is not finite or lies
///   outside `[0, 1)`, or if `learning_rate` is not finite.
/// * [`OptimizerError::ShapeMismatch`] if `params`, `grads` and `velocity` do not
///   all have the same length.
///
/// All validation happens before any slice is written or handed to the engine.
pub fn step_sgd_f64(
    params: &mut [f64],
    grads: &[f64],
    velocity: &mut [f64],
    learning_rate: f64,
    momentum: f64,
    nesterov: bool,
) -> Result<(), OptimizerError> {
    // A NaN or infinite learning rate would silently turn every parameter into
    // NaN/inf, so it is rejected together with an out-of-range momentum.
    let momentum_ok = momentum.is_finite() && (0.0..1.0).contains(&momentum);
    if !momentum_ok || !learning_rate.is_finite() {
        return Err(OptimizerError::InvalidHyperParams);
    }
    if params.len() != grads.len() || velocity.len() != params.len() {
        return Err(OptimizerError::ShapeMismatch);
    }

    if crate::engine::try_invoke_gpu_sgd_f64(
        params,
        grads,
        velocity,
        learning_rate,
        momentum,
        nesterov,
    ) {
        return Ok(());
    }

    // CPU fallback. The lengths were checked above, so zipping visits every element.
    for ((param, vel), &grad) in params.iter_mut().zip(velocity.iter_mut()).zip(grads) {
        let next_vel = momentum * *vel + grad;
        *vel = next_vel;
        let update = if nesterov {
            momentum * next_vel + grad
        } else {
            next_vel
        };
        *param -= learning_rate * update;
    }
    Ok(())
}