// aprender/nn/optim.rs

//! Gradient-based optimizers for neural network training.
//!
//! These optimizers work with autograd Tensors to update parameters
//! based on computed gradients.
//!
//! # Example
//!
//! ```ignore
//! use aprender::nn::{Linear, Module, optim::SGD};
//! use aprender::nn::loss::MSELoss;
//! use aprender::autograd::Tensor;
//!
//! // Create model and optimizer
//! let mut model = Linear::new(10, 5);
//! let mut optimizer = SGD::new(model.parameters_mut(), 0.01);
//!
//! // Training loop
//! for epoch in 0..100 {
//!     let x = Tensor::randn(&[32, 10]);
//!     let y = Tensor::randn(&[32, 5]);
//!
//!     // Forward pass
//!     let pred = model.forward(&x);
//!     let loss = MSELoss::new().forward(&pred, &y);
//!
//!     // Backward pass and update. `step_with_params` performs the actual
//!     // parameter update; the trait-level `step` is currently a placeholder.
//!     optimizer.zero_grad();
//!     loss.backward();
//!     optimizer.step_with_params(&mut model.parameters_mut());
//! }
//! ```
//!
//! # References
//!
//! - Robbins, H., & Monro, S. (1951). A stochastic approximation method. The Annals of Mathematical Statistics.
//! - Kingma, D. P., & Ba, J. (2015). Adam: A method for stochastic optimization. ICLR.
//! - Loshchilov, I., & Hutter, F. (2019). Decoupled weight decay regularization. ICLR.

use crate::autograd::{get_grad, Tensor, TensorId};

/// Common trait for all optimizers.
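///
/// A minimal sketch of driving any optimizer generically through this trait,
/// e.g. for a simple learning-rate decay schedule (the `decay_lr` helper is
/// illustrative, not part of the crate):
///
/// ```ignore
/// fn decay_lr<O: Optimizer>(opt: &mut O, factor: f32) {
///     opt.set_lr(opt.lr() * factor);
/// }
/// ```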
pub trait Optimizer {
    /// Perform a single optimization step using computed gradients.
    fn step(&mut self);

    /// Zero all parameter gradients.
    fn zero_grad(&mut self);

    /// Get current learning rate.
    fn lr(&self) -> f32;

    /// Set learning rate (for schedulers).
    fn set_lr(&mut self, lr: f32);
}

/// Stochastic Gradient Descent optimizer with momentum.
///
/// Update rule:
/// ```text
/// v_t = momentum * v_{t-1} + grad
/// param = param - lr * v_t
/// ```
///
/// With Nesterov momentum:
/// ```text
/// v_t = momentum * v_{t-1} + grad
/// param = param - lr * (momentum * v_t + grad)
/// ```
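///
/// # Example
///
/// A minimal sketch using `step_with_params`, the recommended entry point
/// (`Tensor::from_slice` is assumed here, as in the tests below):
///
/// ```ignore
/// let mut w = Tensor::from_slice(&[1.0, 2.0]).requires_grad();
/// let mut sgd = SGD::with_momentum(vec![&mut w], 0.1, 0.9).nesterov();
///
/// let loss = w.pow(2.0).sum();
/// loss.backward();
/// sgd.step_with_params(&mut [&mut w]);
/// ```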
#[derive(Debug)]
pub struct SGD {
    /// Parameter tensor IDs to optimize
    param_ids: Vec<TensorId>,
    /// Learning rate
    lr: f32,
    /// Momentum factor (0 = no momentum)
    momentum: f32,
    /// Weight decay (L2 regularization)
    weight_decay: f32,
    /// Nesterov momentum
    nesterov: bool,
    /// Velocity buffers for momentum
    velocities: Vec<Vec<f32>>,
    /// Whether velocities have been initialized
    initialized: bool,
}

impl SGD {
    /// Create a new SGD optimizer.
    ///
    /// # Arguments
    ///
    /// * `params` - Mutable references to parameter tensors
    /// * `lr` - Learning rate
    #[allow(clippy::needless_pass_by_value)]
    #[must_use]
    pub fn new(params: Vec<&mut Tensor>, lr: f32) -> Self {
        let param_ids: Vec<TensorId> = params.iter().map(|p| p.id()).collect();
        Self {
            param_ids,
            lr,
            momentum: 0.0,
            weight_decay: 0.0,
            nesterov: false,
            velocities: Vec::new(),
            initialized: false,
        }
    }

    /// Create SGD with momentum.
    #[allow(clippy::needless_pass_by_value)]
    #[must_use]
    pub fn with_momentum(params: Vec<&mut Tensor>, lr: f32, momentum: f32) -> Self {
        let param_ids: Vec<TensorId> = params.iter().map(|p| p.id()).collect();
        Self {
            param_ids,
            lr,
            momentum,
            weight_decay: 0.0,
            nesterov: false,
            velocities: Vec::new(),
            initialized: false,
        }
    }

    /// Enable Nesterov momentum.
    #[must_use]
    pub fn nesterov(mut self) -> Self {
        self.nesterov = true;
        self
    }

    /// Set weight decay (L2 regularization).
    #[must_use]
    pub fn weight_decay(mut self, wd: f32) -> Self {
        self.weight_decay = wd;
        self
    }

    /// Update a single parameter tensor.
    #[allow(clippy::if_not_else)]
    fn update_param(&mut self, param: &mut Tensor, idx: usize) {
        let Some(grad) = get_grad(param.id()) else {
            return; // No gradient available
        };

        let grad_data = grad.data();
        let param_data = param.data_mut();

        // Initialize velocity if needed
        if !self.initialized || idx >= self.velocities.len() {
            if idx >= self.velocities.len() {
                self.velocities.resize(idx + 1, Vec::new());
            }
            self.velocities[idx] = vec![0.0; param_data.len()];
        }

        let velocity = &mut self.velocities[idx];

        for i in 0..param_data.len() {
            let mut g = grad_data[i];

            // Apply weight decay
            if self.weight_decay != 0.0 {
                g += self.weight_decay * param_data[i];
            }

            if self.momentum != 0.0 {
                // Update velocity
                velocity[i] = self.momentum * velocity[i] + g;

                if self.nesterov {
                    // Nesterov: look ahead
                    param_data[i] -= self.lr * (self.momentum * velocity[i] + g);
                } else {
                    // Standard momentum
                    param_data[i] -= self.lr * velocity[i];
                }
            } else {
                // Vanilla SGD
                param_data[i] -= self.lr * g;
            }
        }
    }
}

impl Optimizer for SGD {
    fn step(&mut self) {
        // Updating parameters requires mutable access to the tensors, which
        // is not available through the global graph here, so this trait
        // method is a placeholder. Call `step_with_params` to apply updates.
        self.initialized = true;
    }

    fn zero_grad(&mut self) {
        for &id in &self.param_ids {
            crate::autograd::clear_grad(id);
        }
    }

    fn lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }
}

impl SGD {
    /// Perform optimization step with direct tensor access.
    ///
    /// This is the recommended way to use SGD in a training loop.
    pub fn step_with_params(&mut self, params: &mut [&mut Tensor]) {
        for (idx, param) in params.iter_mut().enumerate() {
            self.update_param(param, idx);
        }
        self.initialized = true;
    }
}

/// Adam optimizer (Kingma & Ba, 2015).
///
/// Combines momentum with adaptive learning rates using first and second
/// moment estimates.
///
/// Update rule:
/// ```text
/// m_t = β₁ * m_{t-1} + (1 - β₁) * grad
/// v_t = β₂ * v_{t-1} + (1 - β₂) * grad²
/// m̂_t = m_t / (1 - β₁ᵗ)
/// v̂_t = v_t / (1 - β₂ᵗ)
/// param = param - lr * m̂_t / (√v̂_t + ε)
/// ```
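///
/// # Example
///
/// A minimal sketch mirroring `test_adam_convergence` below
/// (`Tensor::from_slice` and `clear_graph` assumed, as in the tests):
///
/// ```ignore
/// let mut w = Tensor::from_slice(&[5.0]).requires_grad();
/// let mut adam = Adam::new(vec![&mut w], 0.5).betas(0.9, 0.999).eps(1e-8);
///
/// // Minimize w², resetting the graph each iteration as the tests do
/// for _ in 0..100 {
///     clear_graph();
///     let loss = w.pow(2.0).sum();
///     loss.backward();
///     adam.step_with_params(&mut [&mut w]);
/// }
/// ```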
#[derive(Debug)]
pub struct Adam {
    param_ids: Vec<TensorId>,
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    /// First moment estimates
    m: Vec<Vec<f32>>,
    /// Second moment estimates
    v: Vec<Vec<f32>>,
    /// Current timestep for bias correction
    t: usize,
    initialized: bool,
}

impl Adam {
    /// Create a new Adam optimizer with default hyperparameters.
    ///
    /// Default: β₁=0.9, β₂=0.999, ε=1e-8
    #[allow(clippy::needless_pass_by_value)]
    #[must_use]
    pub fn new(params: Vec<&mut Tensor>, lr: f32) -> Self {
        let param_ids: Vec<TensorId> = params.iter().map(|p| p.id()).collect();
        Self {
            param_ids,
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: 0.0,
            m: Vec::new(),
            v: Vec::new(),
            t: 0,
            initialized: false,
        }
    }

    /// Set beta parameters.
    #[must_use]
    pub fn betas(mut self, beta1: f32, beta2: f32) -> Self {
        self.beta1 = beta1;
        self.beta2 = beta2;
        self
    }

    /// Set epsilon for numerical stability.
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Set weight decay (L2 regularization, applied to gradient).
    #[must_use]
    pub fn weight_decay(mut self, wd: f32) -> Self {
        self.weight_decay = wd;
        self
    }

    fn update_param(&mut self, param: &mut Tensor, idx: usize) {
        let Some(grad) = get_grad(param.id()) else {
            return;
        };

        let grad_data = grad.data();
        let param_data = param.data_mut();

        // Initialize state if needed
        if !self.initialized || idx >= self.m.len() {
            if idx >= self.m.len() {
                self.m.resize(idx + 1, Vec::new());
                self.v.resize(idx + 1, Vec::new());
            }
            self.m[idx] = vec![0.0; param_data.len()];
            self.v[idx] = vec![0.0; param_data.len()];
        }

        let m = &mut self.m[idx];
        let v = &mut self.v[idx];

        // Bias correction factors
        let bias_correction1 = 1.0 - self.beta1.powi(self.t as i32);
        let bias_correction2 = 1.0 - self.beta2.powi(self.t as i32);

        for i in 0..param_data.len() {
            let mut g = grad_data[i];

            // L2 regularization (applied to gradient, not decoupled)
            if self.weight_decay != 0.0 {
                g += self.weight_decay * param_data[i];
            }

            // Update biased first moment estimate
            m[i] = self.beta1 * m[i] + (1.0 - self.beta1) * g;

            // Update biased second moment estimate
            v[i] = self.beta2 * v[i] + (1.0 - self.beta2) * g * g;

            // Compute bias-corrected estimates
            let m_hat = m[i] / bias_correction1;
            let v_hat = v[i] / bias_correction2;

            // Update parameter
            param_data[i] -= self.lr * m_hat / (v_hat.sqrt() + self.eps);
        }
    }

    /// Perform optimization step with direct tensor access.
    pub fn step_with_params(&mut self, params: &mut [&mut Tensor]) {
        self.t += 1;
        for (idx, param) in params.iter_mut().enumerate() {
            self.update_param(param, idx);
        }
        self.initialized = true;
    }
}

impl Optimizer for Adam {
    fn step(&mut self) {
        // Placeholder, as in `SGD::step`: only the bias-correction timestep
        // advances here. Call `step_with_params` to apply updates.
        self.t += 1;
        self.initialized = true;
    }

    fn zero_grad(&mut self) {
        for &id in &self.param_ids {
            crate::autograd::clear_grad(id);
        }
    }

    fn lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }
}

/// `AdamW` optimizer (Loshchilov & Hutter, 2019).
///
/// Like Adam but with decoupled weight decay, which is more effective
/// for regularization.
///
/// The key difference from Adam:
/// ```text
/// param = param - lr * weight_decay * param  // Decoupled weight decay
/// param = param - lr * m̂_t / (√v̂_t + ε)      // Then Adam update
/// ```
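///
/// # Example
///
/// A minimal sketch (the default `weight_decay` is 0.01; `Tensor::from_slice`
/// is assumed, as in the tests below):
///
/// ```ignore
/// let mut w = Tensor::from_slice(&[5.0]).requires_grad();
/// let mut adamw = AdamW::new(vec![&mut w], 0.1).weight_decay(0.05);
///
/// let loss = w.pow(2.0).sum();
/// loss.backward();
/// adamw.step_with_params(&mut [&mut w]);
/// ```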
#[derive(Debug)]
pub struct AdamW {
    param_ids: Vec<TensorId>,
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    m: Vec<Vec<f32>>,
    v: Vec<Vec<f32>>,
    t: usize,
    initialized: bool,
}

impl AdamW {
    /// Create a new `AdamW` optimizer.
    ///
    /// Default: β₁=0.9, β₂=0.999, ε=1e-8, `weight_decay=0.01`
    #[allow(clippy::needless_pass_by_value)]
    #[must_use]
    pub fn new(params: Vec<&mut Tensor>, lr: f32) -> Self {
        let param_ids: Vec<TensorId> = params.iter().map(|p| p.id()).collect();
        Self {
            param_ids,
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: 0.01,
            m: Vec::new(),
            v: Vec::new(),
            t: 0,
            initialized: false,
        }
    }

    #[must_use]
    pub fn betas(mut self, beta1: f32, beta2: f32) -> Self {
        self.beta1 = beta1;
        self.beta2 = beta2;
        self
    }

    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    #[must_use]
    pub fn weight_decay(mut self, wd: f32) -> Self {
        self.weight_decay = wd;
        self
    }

    fn update_param(&mut self, param: &mut Tensor, idx: usize) {
        let Some(grad) = get_grad(param.id()) else {
            return;
        };

        let grad_data = grad.data();
        let param_data = param.data_mut();

        // Initialize state if needed
        if !self.initialized || idx >= self.m.len() {
            if idx >= self.m.len() {
                self.m.resize(idx + 1, Vec::new());
                self.v.resize(idx + 1, Vec::new());
            }
            self.m[idx] = vec![0.0; param_data.len()];
            self.v[idx] = vec![0.0; param_data.len()];
        }

        let m = &mut self.m[idx];
        let v = &mut self.v[idx];

        let bias_correction1 = 1.0 - self.beta1.powi(self.t as i32);
        let bias_correction2 = 1.0 - self.beta2.powi(self.t as i32);

        for i in 0..param_data.len() {
            let g = grad_data[i];

            // Update moment estimates (no weight decay in gradient)
            m[i] = self.beta1 * m[i] + (1.0 - self.beta1) * g;
            v[i] = self.beta2 * v[i] + (1.0 - self.beta2) * g * g;

            let m_hat = m[i] / bias_correction1;
            let v_hat = v[i] / bias_correction2;

            // Decoupled weight decay: applied directly to parameter
            param_data[i] -= self.lr * self.weight_decay * param_data[i];

            // Adam update
            param_data[i] -= self.lr * m_hat / (v_hat.sqrt() + self.eps);
        }
    }

    pub fn step_with_params(&mut self, params: &mut [&mut Tensor]) {
        self.t += 1;
        for (idx, param) in params.iter_mut().enumerate() {
            self.update_param(param, idx);
        }
        self.initialized = true;
    }
}

impl Optimizer for AdamW {
    fn step(&mut self) {
        // Placeholder, as in `SGD::step`; call `step_with_params` to apply updates.
        self.t += 1;
        self.initialized = true;
    }

    fn zero_grad(&mut self) {
        for &id in &self.param_ids {
            crate::autograd::clear_grad(id);
        }
    }

    fn lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }
}

/// `RMSprop` optimizer.
///
/// Maintains a moving average of squared gradients for adaptive learning rates.
///
/// Update rule:
/// ```text
/// v_t = α * v_{t-1} + (1 - α) * grad²
/// param = param - lr * grad / (√v_t + ε)
/// ```
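///
/// # Example
///
/// A minimal sketch with momentum enabled (`Tensor::from_slice` assumed, as
/// in the tests below):
///
/// ```ignore
/// let mut w = Tensor::from_slice(&[3.0]).requires_grad();
/// let mut rms = RMSprop::new(vec![&mut w], 0.01).alpha(0.99).momentum(0.9);
///
/// let loss = w.pow(2.0).sum();
/// loss.backward();
/// rms.step_with_params(&mut [&mut w]);
/// ```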
#[derive(Debug)]
pub struct RMSprop {
    param_ids: Vec<TensorId>,
    lr: f32,
    alpha: f32,
    eps: f32,
    weight_decay: f32,
    momentum: f32,
    /// Running average of squared gradients
    v: Vec<Vec<f32>>,
    /// Momentum buffer
    buffer: Vec<Vec<f32>>,
    initialized: bool,
}

impl RMSprop {
    /// Create a new `RMSprop` optimizer.
    ///
    /// Default: α=0.99, ε=1e-8
    #[allow(clippy::needless_pass_by_value)]
    #[must_use]
    pub fn new(params: Vec<&mut Tensor>, lr: f32) -> Self {
        let param_ids: Vec<TensorId> = params.iter().map(|p| p.id()).collect();
        Self {
            param_ids,
            lr,
            alpha: 0.99,
            eps: 1e-8,
            weight_decay: 0.0,
            momentum: 0.0,
            v: Vec::new(),
            buffer: Vec::new(),
            initialized: false,
        }
    }

    #[must_use]
    pub fn alpha(mut self, alpha: f32) -> Self {
        self.alpha = alpha;
        self
    }

    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    #[must_use]
    pub fn momentum(mut self, momentum: f32) -> Self {
        self.momentum = momentum;
        self
    }

    #[must_use]
    pub fn weight_decay(mut self, wd: f32) -> Self {
        self.weight_decay = wd;
        self
    }

    fn update_param(&mut self, param: &mut Tensor, idx: usize) {
        let Some(grad) = get_grad(param.id()) else {
            return;
        };

        let grad_data = grad.data();
        let param_data = param.data_mut();

        // Initialize state if needed
        if !self.initialized || idx >= self.v.len() {
            if idx >= self.v.len() {
                self.v.resize(idx + 1, Vec::new());
                self.buffer.resize(idx + 1, Vec::new());
            }
            self.v[idx] = vec![0.0; param_data.len()];
            self.buffer[idx] = vec![0.0; param_data.len()];
        }

        let v = &mut self.v[idx];
        let buffer = &mut self.buffer[idx];

        for i in 0..param_data.len() {
            let mut g = grad_data[i];

            // Weight decay
            if self.weight_decay != 0.0 {
                g += self.weight_decay * param_data[i];
            }

            // Update running average of squared gradients
            v[i] = self.alpha * v[i] + (1.0 - self.alpha) * g * g;

            // Compute update
            let update = g / (v[i].sqrt() + self.eps);

            if self.momentum > 0.0 {
                buffer[i] = self.momentum * buffer[i] + update;
                param_data[i] -= self.lr * buffer[i];
            } else {
                param_data[i] -= self.lr * update;
            }
        }
    }

    pub fn step_with_params(&mut self, params: &mut [&mut Tensor]) {
        for (idx, param) in params.iter_mut().enumerate() {
            self.update_param(param, idx);
        }
        self.initialized = true;
    }
}

impl Optimizer for RMSprop {
    fn step(&mut self) {
        // Placeholder; call `step_with_params` to apply updates.
        self.initialized = true;
    }

    fn zero_grad(&mut self) {
        for &id in &self.param_ids {
            crate::autograd::clear_grad(id);
        }
    }

    fn lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::autograd::clear_graph;

    #[test]
    fn test_sgd_basic() {
        clear_graph();

        // Create a simple tensor
        let mut param = Tensor::from_slice(&[1.0, 2.0, 3.0]).requires_grad();
        let param_id = param.id();

        // Simulate a loss: sum of squared elements
        let loss = param.pow(2.0).sum();
        loss.backward();

        // Check gradient exists
        let grad = get_grad(param_id).expect("Should have gradient");
        assert_eq!(grad.data(), &[2.0, 4.0, 6.0]); // d/dx(x²) = 2x

        // Create optimizer and step
        let mut sgd = SGD::new(vec![&mut param], 0.1);
        sgd.step_with_params(&mut [&mut param]);

        // param = param - lr * grad = [1, 2, 3] - 0.1 * [2, 4, 6] = [0.8, 1.6, 2.4]
        let expected = [0.8, 1.6, 2.4];
        for (p, e) in param.data().iter().zip(expected.iter()) {
            assert!((p - e).abs() < 1e-5, "Expected {e}, got {p}");
        }
    }

    #[test]
    fn test_sgd_with_momentum() {
        clear_graph();

        let mut param = Tensor::from_slice(&[1.0]).requires_grad();

        // First step
        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut sgd = SGD::with_momentum(vec![&mut param], 0.1, 0.9);
        sgd.step_with_params(&mut [&mut param]);

        // v = 0.9 * 0 + 2.0 = 2.0
        // param = 1.0 - 0.1 * 2.0 = 0.8
        assert!((param.data()[0] - 0.8).abs() < 1e-5);

        // Second step
        clear_graph();
        let loss = param.pow(2.0).sum();
        loss.backward();

        sgd.step_with_params(&mut [&mut param]);

        // grad = 2 * 0.8 = 1.6
        // v = 0.9 * 2.0 + 1.6 = 3.4
        // param = 0.8 - 0.1 * 3.4 = 0.46
        assert!((param.data()[0] - 0.46).abs() < 1e-5);
    }

    #[test]
    fn test_adam_basic() {
        clear_graph();

        let mut param = Tensor::from_slice(&[1.0, 2.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut adam = Adam::new(vec![&mut param], 0.1);
        adam.step_with_params(&mut [&mut param]);

        // After one step, params should decrease
        assert!(param.data()[0] < 1.0);
        assert!(param.data()[1] < 2.0);
    }

    #[test]
    fn test_adam_convergence() {
        // Test that Adam can minimize a simple quadratic
        clear_graph();

        let mut param = Tensor::from_slice(&[5.0]).requires_grad();
        let mut adam = Adam::new(vec![&mut param], 0.5);

        // Minimize x² (optimal at x=0)
        for _ in 0..100 {
            clear_graph();
            let loss = param.pow(2.0).sum();
            loss.backward();
            adam.step_with_params(&mut [&mut param]);
        }

        // Should be close to 0
        assert!(
            param.data()[0].abs() < 0.1,
            "Parameter should converge to 0, got {}",
            param.data()[0]
        );
    }

    #[test]
    fn test_adamw_weight_decay() {
        clear_graph();

        let mut param = Tensor::from_slice(&[10.0]).requires_grad();
        // Isolating the decoupled decay term would require a loss whose
        // gradient is zero at the current point; instead, verify that decay
        // plus the Adam update moves the parameter downward.

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut adamw = AdamW::new(vec![&mut param], 0.1).weight_decay(0.1);
        adamw.step_with_params(&mut [&mut param]);

        // With weight decay, param should decrease more
        assert!(param.data()[0] < 10.0);
    }

    #[test]
    fn test_rmsprop_basic() {
        clear_graph();

        let mut param = Tensor::from_slice(&[3.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut rmsprop = RMSprop::new(vec![&mut param], 0.1);
        rmsprop.step_with_params(&mut [&mut param]);

        // Param should decrease
        assert!(param.data()[0] < 3.0);
    }

    #[test]
    fn test_zero_grad() {
        clear_graph();

        let mut param = Tensor::from_slice(&[1.0, 2.0]).requires_grad();
        let param_id = param.id();

        let loss = param.pow(2.0).sum();
        loss.backward();

        // Gradient should exist
        assert!(get_grad(param_id).is_some());

        // Zero grad
        let mut sgd = SGD::new(vec![&mut param], 0.1);
        sgd.zero_grad();

        // Gradient should be cleared
        assert!(get_grad(param_id).is_none());
    }

    #[test]
    fn test_learning_rate_change() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let mut sgd = SGD::new(vec![&mut param], 0.1);

        assert!((sgd.lr() - 0.1).abs() < 1e-6);

        sgd.set_lr(0.01);
        assert!((sgd.lr() - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_sgd_nesterov() {
        clear_graph();

        let mut param = Tensor::from_slice(&[2.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut sgd = SGD::with_momentum(vec![&mut param], 0.1, 0.9).nesterov();
        sgd.step_with_params(&mut [&mut param]);

        // Nesterov should apply a "look ahead" update
        // With nesterov: param = param - lr * (momentum * velocity + grad)
        // v = 0.9 * 0 + 4 = 4 (grad = 2 * 2 = 4)
        // param = 2 - 0.1 * (0.9 * 4 + 4) = 2 - 0.1 * 7.6 = 1.24
        assert!(
            (param.data()[0] - 1.24).abs() < 1e-5,
            "Nesterov update failed: {}",
            param.data()[0]
        );
    }

    #[test]
    fn test_sgd_weight_decay() {
        clear_graph();

        let mut param = Tensor::from_slice(&[5.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut sgd = SGD::new(vec![&mut param], 0.1).weight_decay(0.1);
        sgd.step_with_params(&mut [&mut param]);

        // grad = 2 * 5 = 10, with weight_decay: g = 10 + 0.1 * 5 = 10.5
        // param = 5 - 0.1 * 10.5 = 3.95
        assert!(
            (param.data()[0] - 3.95).abs() < 1e-5,
            "Weight decay update failed: {}",
            param.data()[0]
        );
    }

    #[test]
    fn test_adam_with_custom_betas() {
        clear_graph();

        let mut param = Tensor::from_slice(&[1.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut adam = Adam::new(vec![&mut param], 0.1).betas(0.8, 0.99);
        adam.step_with_params(&mut [&mut param]);

        // Param should decrease with custom betas
        assert!(param.data()[0] < 1.0);
    }

    #[test]
    fn test_adam_with_eps() {
        clear_graph();

        let mut param = Tensor::from_slice(&[1.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut adam = Adam::new(vec![&mut param], 0.1).eps(1e-6);
        adam.step_with_params(&mut [&mut param]);

        assert!(param.data()[0] < 1.0);
    }

    #[test]
    fn test_adam_with_weight_decay() {
        clear_graph();

        let mut param = Tensor::from_slice(&[10.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        // Run Adam with weight decay folded into the gradient (L2-style)
        let mut adam_wd = Adam::new(vec![&mut param], 0.1).weight_decay(0.1);
        adam_wd.step_with_params(&mut [&mut param]);

        // Param should have decreased
        assert!(param.data()[0] < 10.0);
    }

    #[test]
    fn test_adamw_with_custom_betas_and_eps() {
        clear_graph();

        let mut param = Tensor::from_slice(&[3.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut adamw = AdamW::new(vec![&mut param], 0.1)
            .betas(0.85, 0.995)
            .eps(1e-7);
        adamw.step_with_params(&mut [&mut param]);

        assert!(param.data()[0] < 3.0);
    }

    #[test]
    fn test_adamw_lr_methods() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let mut adamw = AdamW::new(vec![&mut param], 0.01);

        assert!((adamw.lr() - 0.01).abs() < 1e-6);
        adamw.set_lr(0.001);
        assert!((adamw.lr() - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_adamw_zero_grad() {
        clear_graph();

        let mut param = Tensor::from_slice(&[2.0]).requires_grad();
        let param_id = param.id();

        let loss = param.pow(2.0).sum();
        loss.backward();

        assert!(get_grad(param_id).is_some());

        let mut adamw = AdamW::new(vec![&mut param], 0.1);
        adamw.zero_grad();

        assert!(get_grad(param_id).is_none());
    }

    #[test]
    fn test_adamw_step_trait() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let mut adamw = AdamW::new(vec![&mut param], 0.1);

        // Test the Optimizer trait step method
        adamw.step();
        assert!(adamw.initialized);
        assert_eq!(adamw.t, 1);
    }

    #[test]
    fn test_rmsprop_with_alpha() {
        clear_graph();

        let mut param = Tensor::from_slice(&[2.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut rmsprop = RMSprop::new(vec![&mut param], 0.1).alpha(0.9);
        rmsprop.step_with_params(&mut [&mut param]);

        assert!(param.data()[0] < 2.0);
    }

    #[test]
    fn test_rmsprop_with_eps() {
        clear_graph();

        let mut param = Tensor::from_slice(&[2.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut rmsprop = RMSprop::new(vec![&mut param], 0.1).eps(1e-6);
        rmsprop.step_with_params(&mut [&mut param]);

        assert!(param.data()[0] < 2.0);
    }

    #[test]
    fn test_rmsprop_with_momentum() {
        clear_graph();

        let mut param = Tensor::from_slice(&[3.0]).requires_grad();

        // First step
        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut rmsprop = RMSprop::new(vec![&mut param], 0.1).momentum(0.9);
        rmsprop.step_with_params(&mut [&mut param]);

        let after_first = param.data()[0];
        assert!(after_first < 3.0);

        // Second step with momentum accumulation
        clear_graph();
        let loss = param.pow(2.0).sum();
        loss.backward();

        rmsprop.step_with_params(&mut [&mut param]);

        assert!(param.data()[0] < after_first);
    }

    #[test]
    fn test_rmsprop_with_weight_decay() {
        clear_graph();

        let mut param = Tensor::from_slice(&[5.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut rmsprop = RMSprop::new(vec![&mut param], 0.1).weight_decay(0.1);
        rmsprop.step_with_params(&mut [&mut param]);

        assert!(param.data()[0] < 5.0);
    }

    #[test]
    fn test_rmsprop_lr_methods() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let mut rmsprop = RMSprop::new(vec![&mut param], 0.01);

        assert!((rmsprop.lr() - 0.01).abs() < 1e-6);
        rmsprop.set_lr(0.001);
        assert!((rmsprop.lr() - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_rmsprop_zero_grad() {
        clear_graph();

        let mut param = Tensor::from_slice(&[2.0]).requires_grad();
        let param_id = param.id();

        let loss = param.pow(2.0).sum();
        loss.backward();

        assert!(get_grad(param_id).is_some());

        let mut rmsprop = RMSprop::new(vec![&mut param], 0.1);
        rmsprop.zero_grad();

        assert!(get_grad(param_id).is_none());
    }

    #[test]
    fn test_rmsprop_step_trait() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let mut rmsprop = RMSprop::new(vec![&mut param], 0.1);

        rmsprop.step();
        assert!(rmsprop.initialized);
    }

    #[test]
    fn test_sgd_step_trait() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let mut sgd = SGD::new(vec![&mut param], 0.1);

        sgd.step();
        assert!(sgd.initialized);
    }

    #[test]
    fn test_adam_step_trait() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let mut adam = Adam::new(vec![&mut param], 0.1);

        adam.step();
        assert!(adam.initialized);
        assert_eq!(adam.t, 1);
    }

    #[test]
    fn test_adam_lr_methods() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let mut adam = Adam::new(vec![&mut param], 0.01);

        assert!((adam.lr() - 0.01).abs() < 1e-6);
        adam.set_lr(0.001);
        assert!((adam.lr() - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_adam_zero_grad() {
        clear_graph();

        let mut param = Tensor::from_slice(&[2.0]).requires_grad();
        let param_id = param.id();

        let loss = param.pow(2.0).sum();
        loss.backward();

        assert!(get_grad(param_id).is_some());

        let mut adam = Adam::new(vec![&mut param], 0.1);
        adam.zero_grad();

        assert!(get_grad(param_id).is_none());
    }

    #[test]
    fn test_sgd_multi_element_tensor() {
        clear_graph();

        let mut param = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut sgd = SGD::new(vec![&mut param], 0.1);
        sgd.step_with_params(&mut [&mut param]);

        // All elements should have decreased
        assert!(param.data()[0] < 1.0);
        assert!(param.data()[1] < 2.0);
        assert!(param.data()[2] < 3.0);
        assert!(param.data()[3] < 4.0);
    }

    #[test]
    fn test_adam_multi_element_tensor() {
        clear_graph();

        let mut param = Tensor::from_slice(&[1.0, 2.0, 3.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut adam = Adam::new(vec![&mut param], 0.1);
        adam.step_with_params(&mut [&mut param]);

        // All elements should have decreased
        assert!(param.data()[0] < 1.0);
        assert!(param.data()[1] < 2.0);
        assert!(param.data()[2] < 3.0);
    }

    #[test]
    fn test_adamw_multi_step() {
        clear_graph();

        let mut param = Tensor::from_slice(&[5.0]).requires_grad();
        let mut adamw = AdamW::new(vec![&mut param], 0.5).weight_decay(0.01);

        // Multiple steps to test convergence
        for _ in 0..10 {
            clear_graph();
            let loss = param.pow(2.0).sum();
            loss.backward();
            adamw.step_with_params(&mut [&mut param]);
        }

        // Should have decreased significantly
        assert!(param.data()[0] < 1.0);
    }

    #[test]
    fn test_rmsprop_convergence() {
        clear_graph();

        let mut param = Tensor::from_slice(&[5.0]).requires_grad();
        let mut rmsprop = RMSprop::new(vec![&mut param], 0.5);

        // Multiple steps to test convergence
        for _ in 0..10 {
            clear_graph();
            let loss = param.pow(2.0).sum();
            loss.backward();
            rmsprop.step_with_params(&mut [&mut param]);
        }

        // Should have decreased significantly
        assert!(param.data()[0] < 1.0);
    }

    // ========== Additional Coverage Tests ==========

    #[test]
    fn test_sgd_lr_accessor() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let sgd = SGD::new(vec![&mut param], 0.05);
        assert!((sgd.lr() - 0.05).abs() < 1e-6);
    }

    #[test]
    fn test_adam_lr_accessor() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let adam = Adam::new(vec![&mut param], 0.001);
        assert!((adam.lr() - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_adamw_lr_accessor() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let adamw = AdamW::new(vec![&mut param], 0.002);
        assert!((adamw.lr() - 0.002).abs() < 1e-6);
    }

    #[test]
    fn test_rmsprop_lr_accessor() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let rmsprop = RMSprop::new(vec![&mut param], 0.003);
        assert!((rmsprop.lr() - 0.003).abs() < 1e-6);
    }

    #[test]
    fn test_adam_set_lr() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let mut adam = Adam::new(vec![&mut param], 0.1);
        adam.set_lr(0.001);
        assert!((adam.lr() - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_adamw_set_lr() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let mut adamw = AdamW::new(vec![&mut param], 0.1);
        adamw.set_lr(0.001);
        assert!((adamw.lr() - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_rmsprop_set_lr() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let mut rmsprop = RMSprop::new(vec![&mut param], 0.1);
        rmsprop.set_lr(0.001);
        assert!((rmsprop.lr() - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_adam_zero_grad_clears() {
        clear_graph();

        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let param_id = param.id();

        let loss = param.pow(2.0).sum();
        loss.backward();

        assert!(get_grad(param_id).is_some());

        let mut adam = Adam::new(vec![&mut param], 0.1);
        adam.zero_grad();

        assert!(get_grad(param_id).is_none());
    }

    #[test]
    fn test_adamw_zero_grad_clears() {
        clear_graph();

        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let param_id = param.id();

        let loss = param.pow(2.0).sum();
        loss.backward();

        assert!(get_grad(param_id).is_some());

        let mut adamw = AdamW::new(vec![&mut param], 0.1);
        adamw.zero_grad();

        assert!(get_grad(param_id).is_none());
    }

    #[test]
    fn test_rmsprop_zero_grad_clears() {
        clear_graph();

        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let param_id = param.id();

        let loss = param.pow(2.0).sum();
        loss.backward();

        assert!(get_grad(param_id).is_some());

        let mut rmsprop = RMSprop::new(vec![&mut param], 0.1);
        rmsprop.zero_grad();

        assert!(get_grad(param_id).is_none());
    }

    #[test]
    fn test_sgd_multiple_params() {
        clear_graph();

        let mut param1 = Tensor::from_slice(&[1.0]).requires_grad();
        let mut param2 = Tensor::from_slice(&[2.0]).requires_grad();

        // Create loss using both params (use add method instead of +)
        let loss1 = param1.pow(2.0).sum();
        let loss2 = param2.pow(2.0).sum();
        let loss = loss1.add(&loss2);
        loss.backward();

        let mut sgd = SGD::new(vec![&mut param1, &mut param2], 0.1);
        sgd.step_with_params(&mut [&mut param1, &mut param2]);

        // Both params should have decreased
        assert!(param1.data()[0] < 1.0);
        assert!(param2.data()[0] < 2.0);
    }

    #[test]
    fn test_adam_multiple_params() {
        clear_graph();

        let mut param1 = Tensor::from_slice(&[1.0]).requires_grad();
        let mut param2 = Tensor::from_slice(&[2.0]).requires_grad();

        let loss1 = param1.pow(2.0).sum();
        let loss2 = param2.pow(2.0).sum();
        let loss = loss1.add(&loss2);
        loss.backward();

        let mut adam = Adam::new(vec![&mut param1, &mut param2], 0.1);
        adam.step_with_params(&mut [&mut param1, &mut param2]);

        assert!(param1.data()[0] < 1.0);
        assert!(param2.data()[0] < 2.0);
    }

    #[test]
    fn test_rmsprop_alpha_builder() {
        clear_graph();

        let mut param = Tensor::from_slice(&[5.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        // Create RMSprop with custom alpha using builder pattern
        let mut rmsprop = RMSprop::new(vec![&mut param], 0.1).alpha(0.9);
        rmsprop.step_with_params(&mut [&mut param]);

        assert!(param.data()[0] < 5.0);
    }

    #[test]
    fn test_sgd_debug_trait() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let sgd = SGD::new(vec![&mut param], 0.1);
        let debug_str = format!("{:?}", sgd);
        assert!(debug_str.contains("SGD"));
    }

    #[test]
    fn test_adam_debug_trait() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let adam = Adam::new(vec![&mut param], 0.1);
        let debug_str = format!("{:?}", adam);
        assert!(debug_str.contains("Adam"));
    }

    #[test]
    fn test_adamw_debug_trait() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let adamw = AdamW::new(vec![&mut param], 0.1);
        let debug_str = format!("{:?}", adamw);
        assert!(debug_str.contains("AdamW"));
    }

    #[test]
    fn test_rmsprop_debug_trait() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let rmsprop = RMSprop::new(vec![&mut param], 0.1);
        let debug_str = format!("{:?}", rmsprop);
        assert!(debug_str.contains("RMSprop"));
    }

    #[test]
    fn test_sgd_empty_params() {
        let sgd = SGD::new(vec![], 0.1);
        assert!((sgd.lr() - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_adam_empty_params() {
        let adam = Adam::new(vec![], 0.1);
        assert!((adam.lr() - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_sgd_momentum_initialization() {
        clear_graph();

        let mut param = Tensor::from_slice(&[3.0, 4.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        // First step initializes velocities
        let mut sgd = SGD::with_momentum(vec![&mut param], 0.1, 0.9);
        sgd.step_with_params(&mut [&mut param]);

        // After first step, momentum buffer should be initialized
        assert!(param.data()[0] < 3.0);
        assert!(param.data()[1] < 4.0);
    }

    #[test]
    fn test_adam_step_counter() {
        clear_graph();

        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let mut adam = Adam::new(vec![&mut param], 0.1);

        // Step multiple times
        for _ in 0..3 {
            clear_graph();
            let loss = param.pow(2.0).sum();
            loss.backward();
            adam.step_with_params(&mut [&mut param]);
        }

        // After 3 steps param should have decreased from 1.0
        assert!(param.data()[0] < 1.0);
    }
}