use std::collections::HashMap;
use crate::tensor::DenseTensor;
use crate::tensor::traits::{TensorBase, TensorOps};
use super::compute_graph::TensorId;
/// Common interface for the gradient-based parameter optimizers in this module.
///
/// Implementations keep their per-parameter state (velocities, moment estimates, ...)
/// keyed by [`TensorId`].
pub trait Optimizer {
    /// Prepares the optimizer's state for a single parameter, using `param` for its shape.
    fn init_param(&mut self, param_id: TensorId, param: &DenseTensor);

    /// Updates one parameter in place from its gradient `grad`.
    fn step_param(&mut self, param_id: TensorId, param: &mut DenseTensor, grad: &DenseTensor);

    /// Runs one optimization step over every parameter in `params`.
    ///
    /// NOTE(review): the implementations in this module have no gradient source here and
    /// step each parameter with an all-zero gradient.
    fn step(&mut self, params: &mut HashMap<TensorId, DenseTensor>);
}
/// Stochastic gradient descent with optional momentum and L2 weight decay.
#[derive(Debug, Clone)]
pub struct Sgd {
    /// Step size applied to every update.
    pub lr: f64,
    /// Velocity decay factor; values `<= 0.0` disable momentum.
    pub momentum: f64,
    /// Coefficient of the L2 penalty added to the gradient; values `<= 0.0` disable it.
    pub weight_decay: f64,
    /// Per-parameter velocity buffers (only used when momentum is enabled).
    velocity: HashMap<TensorId, DenseTensor>,
}

impl Sgd {
    /// Builds an optimizer with the given hyper-parameters and no velocity state yet.
    pub fn new(lr: f64, momentum: f64, weight_decay: f64) -> Self {
        let velocity = HashMap::new();
        Sgd {
            lr,
            momentum,
            weight_decay,
            velocity,
        }
    }
}

impl Default for Sgd {
    /// `lr = 0.01`, no momentum, no weight decay.
    fn default() -> Self {
        Sgd::new(0.01, 0.0, 0.0)
    }
}
impl Optimizer for Sgd {
    /// Allocates a zero velocity buffer for `param_id` when momentum is enabled
    /// (`momentum > 0.0`); plain SGD keeps no per-parameter state.
    fn init_param(&mut self, param_id: TensorId, param: &DenseTensor) {
        if self.momentum > 0.0 {
            let zeros = DenseTensor::zeros(param.shape().to_vec());
            self.velocity.insert(param_id, zeros);
        }
    }

    /// Applies one SGD update to `param` in place.
    ///
    /// The effective gradient is `grad + weight_decay * param` when `weight_decay > 0.0`
    /// (L2 penalty), otherwise `grad`. With momentum enabled the velocity becomes
    /// `momentum * v + g` and the parameter moves by `lr * v`; without momentum it moves
    /// by `lr * g`.
    fn step_param(&mut self, param_id: TensorId, param: &mut DenseTensor, grad: &DenseTensor) {
        // Only materialise a new tensor when the L2 term actually changes the gradient.
        let decayed;
        let effective_grad = if self.weight_decay > 0.0 {
            decayed = grad.add(&param.scale(self.weight_decay));
            &decayed
        } else {
            grad
        };

        let update = if self.momentum > 0.0 {
            // Create the velocity lazily so a parameter that never went through `init_param`
            // still receives a momentum update instead of being silently skipped.
            let velocity = self
                .velocity
                .entry(param_id)
                .or_insert_with(|| DenseTensor::zeros(param.shape().to_vec()));
            *velocity = velocity.scale(self.momentum).add(effective_grad);
            velocity.scale(self.lr)
        } else {
            effective_grad.scale(self.lr)
        };
        *param = param.sub(&update);
    }

    /// Runs one step over every parameter in `params`.
    ///
    /// No gradients are available through this interface, so each parameter is stepped with
    /// an all-zero gradient: only weight decay and any accumulated velocity move the values.
    fn step(&mut self, params: &mut HashMap<TensorId, DenseTensor>) {
        for (&param_id, param) in params.iter_mut() {
            let zero_grad = DenseTensor::zeros(param.shape().to_vec());
            self.step_param(param_id, param, &zero_grad);
        }
    }
}
/// Adam optimizer with bias-corrected first and second moment estimates.
#[derive(Debug, Clone)]
pub struct Adam {
    /// Learning rate.
    pub lr: f64,
    /// Exponential decay rate of the first-moment (mean) estimate.
    pub beta1: f64,
    /// Exponential decay rate of the second-moment (squared-gradient) estimate.
    pub beta2: f64,
    /// Added to the denominator to avoid division by zero.
    pub epsilon: f64,
    /// First-moment estimate per parameter.
    m: HashMap<TensorId, DenseTensor>,
    /// Second-moment estimate per parameter.
    v: HashMap<TensorId, DenseTensor>,
    /// Number of updates applied to each parameter; drives the bias correction.
    steps: HashMap<TensorId, usize>,
}

impl Adam {
    /// Creates an Adam optimizer with no per-parameter state yet.
    pub fn new(lr: f64, beta1: f64, beta2: f64, epsilon: f64) -> Self {
        Self {
            lr,
            beta1,
            beta2,
            epsilon,
            m: HashMap::new(),
            v: HashMap::new(),
            steps: HashMap::new(),
        }
    }

    /// Replaces the learning rate; moment estimates and step counts are kept.
    pub fn set_lr(&mut self, lr: f64) {
        self.lr = lr;
    }
}

impl Default for Adam {
    /// `lr = 0.001`, `beta1 = 0.9`, `beta2 = 0.999`, `epsilon = 1e-8`.
    fn default() -> Self {
        Self::new(0.001, 0.9, 0.999, 1e-8)
    }
}

impl Optimizer for Adam {
    /// Resets the moment estimates of `param_id` to zeros shaped like `param` and its
    /// step count to zero.
    fn init_param(&mut self, param_id: TensorId, param: &DenseTensor) {
        let zeros = DenseTensor::zeros(param.shape().to_vec());
        self.m.insert(param_id, zeros.clone());
        self.v.insert(param_id, zeros);
        self.steps.insert(param_id, 0);
    }

    /// Applies one Adam update to `param` in place, initialising its state on first use.
    ///
    /// `m = beta1*m + (1-beta1)*g`, `v = beta2*v + (1-beta2)*g²`, then
    /// `param -= lr * m_hat / (sqrt(v_hat) + epsilon)` with `m_hat`/`v_hat` bias-corrected
    /// by this parameter's own step count.
    fn step_param(&mut self, param_id: TensorId, param: &mut DenseTensor, grad: &DenseTensor) {
        if !self.m.contains_key(&param_id) {
            self.init_param(param_id, param);
        }

        // Count steps per parameter so bias correction is valid even when `step_param` is
        // called directly: with a count of 0 the correction `1 - beta^0` is 0 and the update
        // would divide by zero (inf/NaN parameters).
        let step = self.steps.entry(param_id).or_insert(0);
        *step += 1;
        // `powi` takes an i32; saturate rather than wrap on absurdly long runs.
        let exponent = (*step).min(i32::MAX as usize) as i32;
        let bias_correction1 = 1.0 - self.beta1.powi(exponent);
        let bias_correction2 = 1.0 - self.beta2.powi(exponent);

        let m = self
            .m
            .get_mut(&param_id)
            .expect("first moment is initialised above");
        *m = m.scale(self.beta1).add(&grad.scale(1.0 - self.beta1));
        let m_hat = m.scale(1.0 / bias_correction1);

        let v = self
            .v
            .get_mut(&param_id)
            .expect("second moment is initialised alongside the first");
        *v = v.scale(self.beta2).add(&grad.mul(grad).scale(1.0 - self.beta2));
        let v_hat = v.scale(1.0 / bias_correction2);

        let denom = v_hat
            .sqrt()
            .add(&DenseTensor::full(v_hat.shape(), self.epsilon));
        let update = m_hat.div(&denom).scale(self.lr);
        *param = param.sub(&update);
    }

    /// Runs one step over every parameter in `params` using an all-zero gradient
    /// (no gradients are available through this interface).
    fn step(&mut self, params: &mut HashMap<TensorId, DenseTensor>) {
        for (&param_id, param) in params.iter_mut() {
            let zero_grad = DenseTensor::zeros(param.shape().to_vec());
            self.step_param(param_id, param, &zero_grad);
        }
    }
}
/// Adam with decoupled weight decay: `step_param` shrinks each parameter by
/// `lr * weight_decay` before the regular Adam update.
#[derive(Debug, Clone)]
pub struct AdamW {
    /// Wrapped Adam optimizer holding the hyper-parameters and moment estimates.
    pub adam: Adam,
    /// Decoupled weight-decay coefficient; values `<= 0.0` disable it.
    pub weight_decay: f64,
}

impl AdamW {
    /// Builds an AdamW optimizer; the first four arguments are forwarded to [`Adam::new`].
    pub fn new(lr: f64, beta1: f64, beta2: f64, epsilon: f64, weight_decay: f64) -> Self {
        let adam = Adam::new(lr, beta1, beta2, epsilon);
        AdamW { adam, weight_decay }
    }
}

impl Default for AdamW {
    /// Adam's defaults (`lr = 0.001`, `beta1 = 0.9`, `beta2 = 0.999`, `epsilon = 1e-8`)
    /// plus `weight_decay = 0.01`.
    fn default() -> Self {
        AdamW::new(0.001, 0.9, 0.999, 1e-8, 0.01)
    }
}
impl Optimizer for AdamW {
    /// Delegates state initialisation to the wrapped [`Adam`] optimizer.
    fn init_param(&mut self, param_id: TensorId, param: &DenseTensor) {
        self.adam.init_param(param_id, param);
    }

    /// Applies decoupled weight decay (`param -= lr * weight_decay * param`) and then a
    /// regular Adam update with `grad`; the decay never enters the moment estimates.
    fn step_param(&mut self, param_id: TensorId, param: &mut DenseTensor, grad: &DenseTensor) {
        if self.weight_decay > 0.0 {
            let decay = param.scale(self.weight_decay * self.adam.lr);
            *param = param.sub(&decay);
        }
        self.adam.step_param(param_id, param, grad);
    }

    /// Runs one step over every parameter in `params`, including the decoupled weight decay.
    fn step(&mut self, params: &mut HashMap<TensorId, DenseTensor>) {
        // Delegating straight to `Adam::step` would skip the decay and make this identical to
        // plain Adam. Each update only touches its own parameter, so decaying every parameter
        // first is equivalent to interleaving decay with each Adam update (as `step_param` does).
        if self.weight_decay > 0.0 {
            let shrink = self.weight_decay * self.adam.lr;
            for param in params.values_mut() {
                *param = param.sub(&param.scale(shrink));
            }
        }
        self.adam.step(params);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With a zero gradient and no weight decay, plain SGD must leave parameters unchanged.
    #[test]
    fn test_sgd_basic() {
        let mut optimizer = Sgd::new(0.01, 0.0, 0.0);
        let mut params = HashMap::new();
        let param_id = TensorId(0);
        let param = DenseTensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]);
        params.insert(param_id, param);
        optimizer.step(&mut params);
        let updated = params.get(&param_id).unwrap();
        assert_eq!(updated.data(), &vec![1.0, 2.0, 3.0]);
    }

    /// With a zero gradient the first Adam step produces a zero update.
    #[test]
    fn test_adam_basic() {
        let mut optimizer = Adam::default();
        let mut params = HashMap::new();
        let param_id = TensorId(0);
        let param = DenseTensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]);
        params.insert(param_id, param);
        optimizer.step(&mut params);
        let updated = params.get(&param_id).unwrap();
        assert_eq!(updated.data(), &vec![1.0, 2.0, 3.0]);
    }

    /// `init_param` allocates a velocity buffer when momentum is enabled.
    #[test]
    fn test_sgd_with_momentum() {
        let mut optimizer = Sgd::new(0.01, 0.9, 0.0);
        let param_id = TensorId(0);
        let param = DenseTensor::new(vec![1.0, 2.0], vec![1, 2]);
        optimizer.init_param(param_id, &param);
        assert!(optimizer.velocity.contains_key(&param_id));
    }
}