// ferrite/network/optimizer/gd.rs

use std::{collections::HashMap, sync::{Arc, RwLock}};

use crate::tensor::*;
use super::optimizer::*;

5pub struct SGD {
6  model_params: HashMap<String, Arc<RwLock<Tensor>>>,
7  lr: f32,
8  momentum: f32
9}
10
11impl SGD {
12  pub fn new(model_params: HashMap<String, Arc<RwLock<Tensor>>>, lr: f32, momentum: f32) -> Self {
13    Self{ model_params, lr, momentum }
14  }
15}
16
17impl OptimizerTrait for SGD {
18  fn step(&self) {
19    for (key, value) in self.model_params.iter() {
20      let mut tensor = value.write().unwrap();
21
22      let mut temp = tensor.grad().unwrap();
23      let grad = temp.borrow();
24
25      let mut storage = tensor.tensor_mut();
26      
27      storage.sub_tensor_assign(&(&*grad * self.lr));
28    }
29  }
30}