1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
use std::collections::HashMap;

use rai_core::Tensor;

use super::Optimizer;

pub struct SDG {
    params: Vec<Tensor>,
    lr: f32,
    momentum: Option<f32>,
    weight_decay: Option<f32>,
    dampening: Option<f32>,
    nesterov: Option<f32>,
    state: HashMap<usize, Tensor>,
}

impl SDG {
    pub fn new(params: Vec<Tensor>, lr: f32) -> Self {
        Self {
            params,
            lr,
            momentum: None,
            weight_decay: None,
            dampening: None,
            nesterov: None,
            state: HashMap::new(),
        }
    }

    pub fn with_momentum(mut self, momentum: f32) -> Self {
        self.momentum = Some(momentum);
        self
    }

    pub fn with_weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = Some(weight_decay);
        self
    }

    pub fn with_dampening(mut self, dampening: f32) -> Self {
        self.dampening = Some(dampening);
        self
    }

    pub fn with_nesterov(mut self, nesterov: f32) -> Self {
        self.nesterov = Some(nesterov);
        self
    }
}

impl Optimizer for SDG {
    fn step(&mut self, grads: &HashMap<usize, Tensor>) -> HashMap<usize, Tensor> {
        let mut new_params = HashMap::new();
        for p in self.params.iter() {
            let id = p.id();
            let mut g: Tensor = grads.get(&id).cloned().unwrap();
            let new_p = match self.momentum {
                Some(momentum) => {
                    let mut v: Tensor =
                        self.state.get(&id).cloned().unwrap_or(g.zeros_like()) * momentum;
                    if let Some(weight_decay) = self.weight_decay {
                        g = &g + p * weight_decay;
                    }

                    match self.dampening {
                        Some(dampening) => {
                            v = v + &g * (1.0 - dampening);
                        }
                        None => {
                            v = v + &g;
                        }
                    }

                    let new_p = match self.nesterov {
                        Some(nesterov) => p - &v * self.lr * nesterov,
                        None => p - &v * self.lr,
                    };

                    self.state.insert(id, v);
                    new_p
                }
                None => p - self.lr * g,
            };

            new_params.insert(id, new_p);
        }
        new_params
    }
}