1use zenu_layer::Parameters;
2use zenu_matrix::{device::Device, num::Num};
3
4use crate::Optimizer;
5
6pub struct SGD<T: Num, D: Device> {
7 pub learning_rate: T,
8 _device: std::marker::PhantomData<D>,
9}
10
11impl<T: Num, D: Device> SGD<T, D> {
12 pub fn new(learning_rate: T) -> Self {
13 Self {
14 learning_rate,
15 _device: std::marker::PhantomData,
16 }
17 }
18}
19
20impl<T: Num, D: Device, P: Parameters<T, D>> Optimizer<T, D, P> for SGD<T, D> {
21 fn update(&self, parameters: &P) {
22 for data in parameters.parameters().values() {
23 if let Some(grad) = data.get_grad() {
24 let update_data = grad.get_data().to_ref() * self.learning_rate;
25 let mut data = data.get_data_mut();
26 let mut data = data.to_ref_mut();
27 data -= update_data;
28 }
29 }
30 }
31}