ghostflow_optim/
adam.rs

1//! Adam and AdamW optimizers
2
3use ghostflow_core::Tensor;
4use crate::optimizer::Optimizer;
5
6/// Adam optimizer
7pub struct Adam {
8    params: Vec<Tensor>,
9    lr: f32,
10    betas: (f32, f32),
11    eps: f32,
12    weight_decay: f32,
13    m: Vec<Vec<f32>>,  // First moment
14    v: Vec<Vec<f32>>,  // Second moment
15    t: usize,          // Time step
16}
17
18impl Adam {
19    pub fn new(params: Vec<Tensor>, lr: f32) -> Self {
20        let m = params.iter().map(|p| vec![0.0f32; p.numel()]).collect();
21        let v = params.iter().map(|p| vec![0.0f32; p.numel()]).collect();
22        
23        Adam {
24            params,
25            lr,
26            betas: (0.9, 0.999),
27            eps: 1e-8,
28            weight_decay: 0.0,
29            m,
30            v,
31            t: 0,
32        }
33    }
34
35    pub fn betas(mut self, beta1: f32, beta2: f32) -> Self {
36        self.betas = (beta1, beta2);
37        self
38    }
39
40    pub fn eps(mut self, eps: f32) -> Self {
41        self.eps = eps;
42        self
43    }
44
45    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
46        self.weight_decay = weight_decay;
47        self
48    }
49}
50
51impl Optimizer for Adam {
52    fn step(&mut self) {
53        self.t += 1;
54        let (beta1, beta2) = self.betas;
55        
56        // Bias correction
57        let bias_correction1 = 1.0 - beta1.powi(self.t as i32);
58        let bias_correction2 = 1.0 - beta2.powi(self.t as i32);
59        
60        for (i, param) in self.params.iter_mut().enumerate() {
61            if let Some(grad) = param.grad() {
62                let mut grad_data = grad.data_f32();
63                let param_data = param.data_f32();
64                
65                // L2 regularization (not decoupled)
66                if self.weight_decay != 0.0 {
67                    for (g, &p) in grad_data.iter_mut().zip(param_data.iter()) {
68                        *g += self.weight_decay * p;
69                    }
70                }
71                
72                // Update biased first moment estimate
73                for (j, &g) in grad_data.iter().enumerate() {
74                    self.m[i][j] = beta1 * self.m[i][j] + (1.0 - beta1) * g;
75                }
76                
77                // Update biased second moment estimate
78                for (j, &g) in grad_data.iter().enumerate() {
79                    self.v[i][j] = beta2 * self.v[i][j] + (1.0 - beta2) * g * g;
80                }
81                
82                // Compute bias-corrected estimates and update
83                let new_data: Vec<f32> = param_data.iter()
84                    .enumerate()
85                    .map(|(j, &p)| {
86                        let m_hat = self.m[i][j] / bias_correction1;
87                        let v_hat = self.v[i][j] / bias_correction2;
88                        p - self.lr * m_hat / (v_hat.sqrt() + self.eps)
89                    })
90                    .collect();
91                
92                *param = Tensor::from_slice(&new_data, param.dims()).unwrap();
93            }
94        }
95    }
96
97    fn zero_grad(&mut self) {
98        for param in &mut self.params {
99            param.zero_grad();
100        }
101    }
102
103    fn get_lr(&self) -> f32 {
104        self.lr
105    }
106
107    fn set_lr(&mut self, lr: f32) {
108        self.lr = lr;
109    }
110
111    fn parameters(&self) -> &[Tensor] {
112        &self.params
113    }
114}
115
116/// AdamW optimizer (decoupled weight decay)
117pub struct AdamW {
118    params: Vec<Tensor>,
119    lr: f32,
120    betas: (f32, f32),
121    eps: f32,
122    weight_decay: f32,
123    m: Vec<Vec<f32>>,
124    v: Vec<Vec<f32>>,
125    t: usize,
126}
127
128impl AdamW {
129    pub fn new(params: Vec<Tensor>, lr: f32) -> Self {
130        let m = params.iter().map(|p| vec![0.0f32; p.numel()]).collect();
131        let v = params.iter().map(|p| vec![0.0f32; p.numel()]).collect();
132        
133        AdamW {
134            params,
135            lr,
136            betas: (0.9, 0.999),
137            eps: 1e-8,
138            weight_decay: 0.01,  // Default for AdamW
139            m,
140            v,
141            t: 0,
142        }
143    }
144
145    pub fn betas(mut self, beta1: f32, beta2: f32) -> Self {
146        self.betas = (beta1, beta2);
147        self
148    }
149
150    pub fn eps(mut self, eps: f32) -> Self {
151        self.eps = eps;
152        self
153    }
154
155    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
156        self.weight_decay = weight_decay;
157        self
158    }
159}
160
161impl Optimizer for AdamW {
162    fn step(&mut self) {
163        self.t += 1;
164        let (beta1, beta2) = self.betas;
165        
166        let bias_correction1 = 1.0 - beta1.powi(self.t as i32);
167        let bias_correction2 = 1.0 - beta2.powi(self.t as i32);
168        
169        for (i, param) in self.params.iter_mut().enumerate() {
170            if let Some(grad) = param.grad() {
171                let grad_data = grad.data_f32();
172                let param_data = param.data_f32();
173                
174                // Update moments (without weight decay in gradient)
175                for (j, &g) in grad_data.iter().enumerate() {
176                    self.m[i][j] = beta1 * self.m[i][j] + (1.0 - beta1) * g;
177                    self.v[i][j] = beta2 * self.v[i][j] + (1.0 - beta2) * g * g;
178                }
179                
180                // Update with decoupled weight decay
181                let new_data: Vec<f32> = param_data.iter()
182                    .enumerate()
183                    .map(|(j, &p)| {
184                        let m_hat = self.m[i][j] / bias_correction1;
185                        let v_hat = self.v[i][j] / bias_correction2;
186                        
187                        // Decoupled weight decay
188                        let p_decayed = p * (1.0 - self.lr * self.weight_decay);
189                        
190                        p_decayed - self.lr * m_hat / (v_hat.sqrt() + self.eps)
191                    })
192                    .collect();
193                
194                *param = Tensor::from_slice(&new_data, param.dims()).unwrap();
195            }
196        }
197    }
198
199    fn zero_grad(&mut self) {
200        for param in &mut self.params {
201            param.zero_grad();
202        }
203    }
204
205    fn get_lr(&self) -> f32 {
206        self.lr
207    }
208
209    fn set_lr(&mut self, lr: f32) {
210        self.lr = lr;
211    }
212
213    fn parameters(&self) -> &[Tensor] {
214        &self.params
215    }
216}
217
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 1-D all-ones parameter of length `len` carrying a constant gradient.
    fn param_with_grad(len: usize, grad: f32) -> Tensor {
        let mut param = Tensor::ones(&[len]);
        param.set_requires_grad(true);
        param.set_grad(Tensor::full(&[len], grad));
        param
    }

    #[test]
    fn test_adam_step() {
        let mut adam = Adam::new(vec![param_with_grad(3, 0.1f32)], 0.001);

        // Multiple steps
        for _ in 0..10 {
            adam.step();
        }

        // With a constant gradient each Adam step moves a parameter by about `lr`,
        // so ten steps at lr = 0.001 keep every element strictly inside (0.98, 1.0).
        let data = adam.params[0].data_f32();
        assert_eq!(data.len(), 3);
        for &x in &data {
            assert!(x < 1.0 && x > 0.98, "unexpected parameter value {}", x);
        }
        // Identical inputs per element must give identical outputs.
        assert!(data.iter().all(|&x| x == data[0]));
    }

    #[test]
    fn test_adamw_single_step() {
        let mut adamw = AdamW::new(vec![param_with_grad(3, 0.1f32)], 0.001);
        adamw.step();

        // First step: m_hat = g and v_hat = g^2, so the Adam term is ~lr. The
        // default decoupled decay (0.01) scales p by (1 - 0.001 * 0.01) first.
        let expected = (1.0f32 - 0.001 * 0.01) - 0.001;
        for &x in &adamw.params[0].data_f32() {
            assert!((x - expected).abs() < 1e-4, "got {}, expected {}", x, expected);
        }
    }

    #[test]
    fn test_params_without_grad_are_untouched() {
        let mut adam = Adam::new(vec![Tensor::ones(&[2])], 0.1);
        adam.step();
        assert_eq!(adam.params[0].data_f32(), vec![1.0f32, 1.0]);
    }
}