ferrite/network/optimizer/
gd.rs1use super::optimizer::*;
2use crate::tensor::*;
3use std::{collections::HashMap, sync::{Arc, RwLock}};
4
5pub struct SGD {
6 model_params: HashMap<String, Arc<RwLock<Tensor>>>,
7 lr: f32,
8 momentum: f32
9}
10
11impl SGD {
12 pub fn new(model_params: HashMap<String, Arc<RwLock<Tensor>>>, lr: f32, momentum: f32) -> Self {
13 Self{ model_params, lr, momentum }
14 }
15}
16
17impl OptimizerTrait for SGD {
18 fn step(&self) {
19 for (key, value) in self.model_params.iter() {
20 let mut tensor = value.write().unwrap();
21
22 let mut temp = tensor.grad().unwrap();
23 let grad = temp.borrow();
24
25 let mut storage = tensor.tensor_mut();
26
27 storage.sub_tensor_assign(&(&*grad * self.lr));
28 }
29 }
30}