zenu_optimizer/
sgd.rs

1use zenu_layer::Parameters;
2use zenu_matrix::{device::Device, num::Num};
3
4use crate::Optimizer;
5
/// Plain stochastic gradient descent (SGD) optimizer.
///
/// On each optimizer step, every parameter `w` that has a gradient `g` is
/// updated in place as `w -= learning_rate * g`.
///
/// The device type `D` is tracked only at the type level through
/// [`PhantomData`](std::marker::PhantomData); the optimizer stores no device
/// data at runtime.
///
/// Trait bounds (`T: Num`, `D: Device`) are deliberately kept on the `impl`
/// blocks instead of on this definition, so that merely naming `SGD<T, D>` in
/// generic code does not force those bounds onto every user.
pub struct SGD<T, D> {
    /// Step size the gradient is multiplied by before being subtracted from
    /// the parameter.
    pub learning_rate: T,
    /// Zero-sized marker tying the optimizer to a device type.
    _device: std::marker::PhantomData<D>,
}
10
11impl<T: Num, D: Device> SGD<T, D> {
12    pub fn new(learning_rate: T) -> Self {
13        Self {
14            learning_rate,
15            _device: std::marker::PhantomData,
16        }
17    }
18}
19
20impl<T: Num, D: Device, P: Parameters<T, D>> Optimizer<T, D, P> for SGD<T, D> {
21    fn update(&self, parameters: &P) {
22        for data in parameters.parameters().values() {
23            if let Some(grad) = data.get_grad() {
24                let update_data = grad.get_data().to_ref() * self.learning_rate;
25                let mut data = data.get_data_mut();
26                let mut data = data.to_ref_mut();
27                data -= update_data;
28            }
29        }
30    }
31}