use ghostflow_core::Tensor;

use crate::optimizer::Optimizer;

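/// Stochastic gradient descent with optional momentum, dampening,
/// weight decay (L2 penalty), and Nesterov momentum.
///
/// Per-element update performed by `step` (see the `Optimizer` impl below),
/// with the momentum lines applying only when `momentum != 0.0`:
///
/// ```text
/// g = g + weight_decay * p
/// v = momentum * v + (1 - dampening) * g
/// p = p - lr * (g + momentum * v)    (nesterov)
/// p = p - lr * v                     (otherwise)
/// ```
///
/// The velocity buffers start at zero, so dampening already applies on the
/// first step (PyTorch's SGD instead seeds the buffer with the raw gradient).
///
/// Minimal usage sketch (illustrative, not compiled as a doctest;
/// `model_params` stands in for a `Vec<Tensor>` gathered from a model):
///
/// ```ignore
/// let mut opt = SGD::new(model_params, 0.1).momentum(0.9).nesterov(true);
/// opt.step();
/// opt.zero_grad();
/// ```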
pub struct SGD {
    params: Vec<Tensor>,
    lr: f32,
    momentum: f32,
    weight_decay: f32,
    dampening: f32,
    nesterov: bool,
    velocity: Vec<Vec<f32>>,
}

impl SGD {
    pub fn new(params: Vec<Tensor>, lr: f32) -> Self {
        // One zero-initialized velocity buffer per parameter tensor.
        let velocity = params.iter().map(|p| vec![0.0f32; p.numel()]).collect();

        SGD {
            params,
            lr,
            momentum: 0.0,
            weight_decay: 0.0,
            dampening: 0.0,
            nesterov: false,
            velocity,
        }
    }

    /// Sets the momentum factor (default: 0.0).
    pub fn momentum(mut self, momentum: f32) -> Self {
        self.momentum = momentum;
        self
    }

    /// Sets the weight decay (L2 penalty) coefficient (default: 0.0).
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Sets the dampening applied to new gradients in the momentum
    /// update (default: 0.0).
    pub fn dampening(mut self, dampening: f32) -> Self {
        self.dampening = dampening;
        self
    }

    /// Enables or disables Nesterov momentum (default: false).
    pub fn nesterov(mut self, nesterov: bool) -> Self {
        self.nesterov = nesterov;
        self
    }
}

impl Optimizer for SGD {
    fn step(&mut self) {
        for (i, param) in self.params.iter_mut().enumerate() {
            if let Some(grad) = param.grad() {
                let mut grad_data = grad.data_f32();
                let param_data = param.data_f32();

                // Weight decay (L2 penalty): g = g + weight_decay * p
                if self.weight_decay != 0.0 {
                    for (g, &p) in grad_data.iter_mut().zip(param_data.iter()) {
                        *g += self.weight_decay * p;
                    }
                }

                if self.momentum != 0.0 {
                    let v = &mut self.velocity[i];

                    // v = momentum * v + (1 - dampening) * g
                    for (j, &g) in grad_data.iter().enumerate() {
                        v[j] = self.momentum * v[j] + (1.0 - self.dampening) * g;
                    }

                    if self.nesterov {
                        // Nesterov: step along g + momentum * v
                        for (j, g) in grad_data.iter_mut().enumerate() {
                            *g += self.momentum * self.velocity[i][j];
                        }
                    } else {
                        // Classic momentum: the velocity is the update direction.
                        grad_data = self.velocity[i].clone();
                    }
                }

                // p = p - lr * g
                let new_data: Vec<f32> = param_data
                    .iter()
                    .zip(grad_data.iter())
                    .map(|(&p, &g)| p - self.lr * g)
                    .collect();

                *param = Tensor::from_slice(&new_data, param.dims()).unwrap();
            }
        }
    }

    fn zero_grad(&mut self) {
        for param in &mut self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Tensor] {
        &self.params
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sgd_step() {
        let mut param = Tensor::ones(&[3]);
        param.set_requires_grad(true);
        param.set_grad(Tensor::full(&[3], 0.1f32));

        let mut sgd = SGD::new(vec![param], 0.1);
        sgd.step();

        // p = 1.0 - 0.1 * 0.1 = 0.99
        let updated = &sgd.params[0];
        assert!((updated.data_f32()[0] - 0.99).abs() < 1e-6);
    }
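
    // An additional, illustrative test for the momentum path. This is a
    // sketch against the same Tensor test API used above (ones / full /
    // set_requires_grad / set_grad); it assumes the gradient must be re-set
    // before the second step, since `step` rebuilds the parameter tensor
    // with `Tensor::from_slice`.
    #[test]
    fn test_sgd_momentum_step() {
        let mut param = Tensor::ones(&[2]);
        param.set_requires_grad(true);
        param.set_grad(Tensor::full(&[2], 1.0f32));

        // lr = 0.1, momentum = 0.9, constant gradient 1.0:
        //   step 1: v = 1.0,  p = 1.0 - 0.1 * 1.0 = 0.9
        //   step 2: v = 1.9,  p = 0.9 - 0.1 * 1.9 = 0.71
        let mut sgd = SGD::new(vec![param], 0.1).momentum(0.9);
        sgd.step();
        assert!((sgd.params[0].data_f32()[0] - 0.9).abs() < 1e-6);

        sgd.params[0].set_grad(Tensor::full(&[2], 1.0f32));
        sgd.step();
        assert!((sgd.params[0].data_f32()[0] - 0.71).abs() < 1e-6);
    }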
}