ferrite-dl 0.2.0

Deep learning library written in pure Rust
Documentation
use super::optimizer::*;
use crate::tensor::*;
use std::{collections::HashMap, sync::{Arc, RwLock}};

pub struct SGD {
  model_params: HashMap<String, Arc<RwLock<Tensor>>>,
  lr: f32,
  momentum: f32
}

impl SGD {
  pub fn new(model_params: HashMap<String, Arc<RwLock<Tensor>>>, lr: f32, momentum: f32) -> Self {
    Self{ model_params, lr, momentum }
  }
}

impl OptimizerTrait for SGD {
  fn step(&self) {
    for (key, value) in self.model_params.iter() {
      let mut tensor = value.write().unwrap();

      let mut temp = tensor.grad().unwrap();
      let grad = temp.borrow();

      let mut storage = tensor.tensor_mut();
      
      storage.sub_tensor_assign(&(&*grad * self.lr));
    }
  }
}