use crate::{Parameters, Optimizer};
use num_traits::Float;
/// Gradient-descent optimizer with a fixed learning rate.
///
/// Owns a parameter set `P` whose element type is `Scalar`; each optimization
/// step moves every parameter against its gradient, scaled by the learning
/// rate.
#[derive(Debug)]
pub struct SGD<P: Parameters<Scalar = Scalar>, Scalar> {
    /// Parameters being optimized; updated in place on every step.
    parameters: P,
    /// Factor the gradient is multiplied by before it is subtracted.
    learning_rate: Scalar,
}
impl<P, Scalar> SGD<P, Scalar>
where
    P: Parameters<Scalar = Scalar>,
    Scalar: Float,
{
    /// Builds an optimizer that takes ownership of `parameters` and scales
    /// every gradient by `learning_rate` when stepping.
    ///
    /// No validation is performed on `learning_rate`; it is stored as given.
    pub fn new(parameters: P, learning_rate: Scalar) -> Self {
        Self {
            parameters,
            learning_rate,
        }
    }
}
impl<P, Scalar> Optimizer for SGD<P, Scalar>
where
    P: Parameters<Scalar = Scalar>,
    Scalar: Float,
{
    type Para = P;

    /// Applies one gradient-descent update in place.
    ///
    /// For every (parameter, gradient) pair handed out by
    /// `Parameters::zip_mut_with`, the parameter becomes
    /// `parameter - learning_rate * gradient`.
    fn step(&mut self, gradients: &P) {
        // `Float` scalars are `Copy`, so take the rate by value up front; the
        // closure then does not need to reach back into `self` while the
        // parameters are mutably borrowed.
        let rate = self.learning_rate;
        self.parameters.zip_mut_with(&gradients, |param, &grad| {
            *param = *param - rate * grad
        });
    }

    /// Shared access to the current parameter values.
    fn parameters(&self) -> &P {
        &self.parameters
    }

    /// Exclusive access to the current parameter values.
    fn parameters_mut(&mut self) -> &mut P {
        &mut self.parameters
    }

    /// Consumes the optimizer and hands back the parameters it owned.
    fn into_parameters(self) -> P {
        self.parameters
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tch::COptimizer;

    /// Checks this SGD implementation against the SGD optimizer provided by
    /// `tch`, using the same learning rate for both.
    #[test]
    fn pytorch_compare() {
        let initial_parameters = vec![3.0, 1.0, 4.0, 1.0, 5.0];
        let optimizer = SGD::new(initial_parameters, 0.005);

        // NOTE(review): the arguments after the learning rate are presumably
        // momentum, dampening, weight decay and the nesterov flag, all
        // disabled here -- confirm against the `tch` documentation.
        let optimizer_torch = COptimizer::sgd(0.005, 0.0, 0.0, 0.0, false).unwrap();

        assert!(crate::test_utils::compare_optimizers(optimizer, optimizer_torch));
    }
}