ghostflow_optim/sgd.rs

//! Stochastic Gradient Descent optimizer

use ghostflow_core::Tensor;
use crate::optimizer::Optimizer;

/// SGD optimizer with optional momentum, dampening, weight decay, and Nesterov momentum.
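///
/// # Example
///
/// A minimal usage sketch (illustrative only; assumes `SGD`, `Tensor`, and the
/// `Optimizer` trait are in scope, and that gradients have already been
/// populated by a backward pass):
///
/// ```ignore
/// let mut w = Tensor::ones(&[3]);
/// w.set_requires_grad(true);
///
/// let mut opt = SGD::new(vec![w], 0.01)
///     .momentum(0.9)
///     .weight_decay(1e-4);
///
/// // once gradients are available for the parameters:
/// opt.step();      // apply one update
/// opt.zero_grad(); // clear gradients for the next iteration
/// ```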
pub struct SGD {
    params: Vec<Tensor>,
    lr: f32,
    momentum: f32,
    weight_decay: f32,
    dampening: f32,
    nesterov: bool,
    /// Per-parameter momentum buffers (one `Vec<f32>` per parameter, sized by `numel`).
    velocity: Vec<Vec<f32>>,
}

impl SGD {
    /// Creates a new SGD optimizer for `params` with learning rate `lr`.
    pub fn new(params: Vec<Tensor>, lr: f32) -> Self {
        let velocity = params.iter().map(|p| vec![0.0f32; p.numel()]).collect();

        SGD {
            params,
            lr,
            momentum: 0.0,
            weight_decay: 0.0,
            dampening: 0.0,
            nesterov: false,
            velocity,
        }
    }

    /// Sets the momentum factor (default: 0.0).
    pub fn momentum(mut self, momentum: f32) -> Self {
        self.momentum = momentum;
        self
    }

    /// Sets the weight decay (L2 penalty) coefficient (default: 0.0).
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Sets the dampening applied to incoming gradients in the momentum update (default: 0.0).
    pub fn dampening(mut self, dampening: f32) -> Self {
        self.dampening = dampening;
        self
    }

    /// Enables or disables Nesterov momentum (default: false).
    pub fn nesterov(mut self, nesterov: bool) -> Self {
        self.nesterov = nesterov;
        self
    }
}

impl Optimizer for SGD {
    fn step(&mut self) {
        for (i, param) in self.params.iter_mut().enumerate() {
            if let Some(grad) = param.grad() {
                let mut grad_data = grad.data_f32();
                let param_data = param.data_f32();

                // Weight decay (L2 regularization): g <- g + weight_decay * p
                if self.weight_decay != 0.0 {
                    for (g, &p) in grad_data.iter_mut().zip(param_data.iter()) {
                        *g += self.weight_decay * p;
                    }
                }

                // Momentum: v <- momentum * v + (1 - dampening) * g
                if self.momentum != 0.0 {
                    let v = &mut self.velocity[i];

                    for (j, &g) in grad_data.iter().enumerate() {
                        v[j] = self.momentum * v[j] + (1.0 - self.dampening) * g;
                    }

                    if self.nesterov {
                        // Nesterov look-ahead: step along g + momentum * v
                        for (j, g) in grad_data.iter_mut().enumerate() {
                            *g += self.momentum * v[j];
                        }
                    } else {
                        // Classic momentum: the velocity itself is the update direction
                        grad_data.copy_from_slice(v);
                    }
                }

                // Parameter update: p <- p - lr * g
                let new_data: Vec<f32> = param_data.iter()
                    .zip(grad_data.iter())
                    .map(|(&p, &g)| p - self.lr * g)
                    .collect();

                // Write back by rebuilding the tensor with the same shape
                *param = Tensor::from_slice(&new_data, param.dims()).unwrap();
            }
        }
    }

    fn zero_grad(&mut self) {
        for param in &mut self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Tensor] {
        &self.params
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sgd_step() {
        let mut param = Tensor::ones(&[3]);
        param.set_requires_grad(true);
        param.set_grad(Tensor::full(&[3], 0.1f32));

        let mut sgd = SGD::new(vec![param], 0.1);
        sgd.step();

        let updated = &sgd.params[0];
        // 1.0 - 0.1 * 0.1 = 0.99
        assert!((updated.data_f32()[0] - 0.99).abs() < 1e-6);
    }
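
    // Sketch of a momentum test (illustrative; reuses the same Tensor helpers
    // as `test_sgd_step` and assumes a gradient can be re-attached to the
    // parameter after `step()` rebuilds it from raw data).
    #[test]
    fn test_sgd_momentum_step() {
        let mut param = Tensor::ones(&[3]);
        param.set_requires_grad(true);
        param.set_grad(Tensor::full(&[3], 0.1f32));

        let mut sgd = SGD::new(vec![param], 0.1).momentum(0.9);

        // First step: v = 0.1, p = 1.0 - 0.1 * 0.1 = 0.99
        sgd.step();
        assert!((sgd.params[0].data_f32()[0] - 0.99).abs() < 1e-6);

        // Re-attach the same gradient and step again:
        // v = 0.9 * 0.1 + 0.1 = 0.19, p = 0.99 - 0.1 * 0.19 = 0.971
        sgd.params[0].set_requires_grad(true);
        sgd.params[0].set_grad(Tensor::full(&[3], 0.1f32));
        sgd.step();
        assert!((sgd.params[0].data_f32()[0] - 0.971).abs() < 1e-6);
    }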
}