#![doc = include_str!("README.md")]
use super::Optimizer;
use candle_core::{Result, Tensor};
/// Plain stochastic gradient descent optimizer.
///
/// Holds only a fixed learning rate; no other state (such as momentum buffers)
/// is kept between calls to `step`.
#[derive(Debug, Clone)]
pub struct SGD {
    /// Factor applied to each gradient before it is used to update its parameter.
    learning_rate: f64,
}
impl SGD {
pub fn new(learning_rate: f64) -> Self {
Self { learning_rate }
}
}
impl Optimizer for SGD {
    /// Applies one vanilla SGD update: `param <- param - learning_rate * grad`.
    ///
    /// `params[i]` is updated using `grads[i]`. Each parameter tensor is replaced
    /// by a newly computed tensor holding the updated values.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying candle ops. In particular, a
    /// gradient whose shape differs from its parameter's is rejected. The update
    /// is not broadcast, so a wrongly shaped gradient can no longer silently change
    /// a parameter's shape.
    ///
    /// # Panics
    ///
    /// Panics if `params` and `grads` have different lengths.
    fn step(&mut self, params: &mut [&mut Tensor], grads: &[Tensor]) -> Result<()> {
        assert_eq!(
            params.len(),
            grads.len(),
            "Number of parameters and gradients must match"
        );
        for (param, grad) in params.iter_mut().zip(grads) {
            let update = (grad * self.learning_rate)?;
            // `sub` requires identical shapes, unlike `broadcast_add`, which would
            // broadcast the parameter up to a larger gradient shape.
            **param = param.sub(&update)?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Device;

    /// A single step with learning rate 0.1 moves each weight against its gradient:
    /// 1.0 - 0.1 * 0.5 = 0.95 and 2.0 - 0.1 * (-1.0) = 2.1.
    #[test]
    fn test_sgd_step() -> Result<()> {
        let device = Device::Cpu;
        let candle_device = device.as_candle().unwrap();

        let mut weights = Tensor::new(&[1.0f32, 2.0], &candle_device)?;
        let gradient = Tensor::new(&[0.5f32, -1.0], &candle_device)?;

        let mut sgd = SGD::new(0.1);
        sgd.step(&mut [&mut weights], &[gradient])?;

        let updated = weights.to_vec1::<f32>()?;
        let expected = [0.95f32, 2.1];
        assert_eq!(updated.len(), expected.len());
        for (got, want) in updated.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-5);
        }
        Ok(())
    }
}