Skip to main content

entrenar/optim/
optimizer.rs

//! The [`Optimizer`] trait implemented by optimization algorithms.
2
3use crate::Tensor;
4
5/// Trait for optimization algorithms
6pub trait Optimizer {
7    /// Perform a single optimization step
8    fn step(&mut self, params: &mut [Tensor]);
9
10    /// Perform optimization step on referenced parameters
11    ///
12    /// This is useful when parameters are borrowed from a model
13    fn step_refs(&mut self, params: &mut [&mut Tensor]) {
14        // Default implementation delegates to step via collecting
15        // Subclasses can override for efficiency
16        for param in params.iter_mut() {
17            if let Some(grad) = param.grad() {
18                // Apply simple SGD update as fallback
19                let lr = self.lr();
20                let grad_data = grad.to_vec();
21                let data = param.data_mut();
22                for (d, g) in data.iter_mut().zip(grad_data.iter()) {
23                    *d -= lr * g;
24                }
25            }
26        }
27    }
28
29    /// Zero out all gradients
30    fn zero_grad(&mut self, params: &mut [Tensor]) {
31        for param in params {
32            param.zero_grad();
33        }
34    }
35
36    /// Zero gradients on referenced parameters
37    fn zero_grad_refs(&mut self, params: &mut [&mut Tensor]) {
38        for param in params.iter_mut() {
39            param.zero_grad();
40        }
41    }
42
43    /// Get learning rate
44    fn lr(&self) -> f32;
45
46    /// Set learning rate
47    fn set_lr(&mut self, lr: f32);
48}
49
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr1;

    /// Minimal optimizer implementation for testing default trait methods.
    ///
    /// Its `step` applies plain SGD (`param -= lr * grad`), the same rule the
    /// default `step_refs` uses, so both entry points can be checked against
    /// the same expected values.
    struct TestOptimizer {
        learning_rate: f32,
    }

    impl TestOptimizer {
        fn new(lr: f32) -> Self {
            Self { learning_rate: lr }
        }
    }

    impl Optimizer for TestOptimizer {
        fn step(&mut self, params: &mut [Tensor]) {
            for param in params {
                if let Some(grad) = param.grad() {
                    let grad_data = grad.to_vec();
                    let data = param.data_mut();
                    for (d, g) in data.iter_mut().zip(grad_data.iter()) {
                        *d -= self.learning_rate * g;
                    }
                }
            }
        }

        fn lr(&self) -> f32 {
            self.learning_rate
        }

        fn set_lr(&mut self, lr: f32) {
            self.learning_rate = lr;
        }
    }

    /// Assert element-wise that `actual` matches `expected` within 1e-6.
    fn assert_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() < 1e-6, "index {i}: got {a}, expected {e}");
        }
    }

    /// True when `param` has no gradient or every gradient entry is zero.
    ///
    /// Accepting both outcomes keeps the zero_grad tests independent of
    /// whether `Tensor::zero_grad` clears the gradient or fills it with zeros.
    fn grad_is_zeroed(param: &Tensor) -> bool {
        match param.grad() {
            None => true,
            Some(grad) => grad.to_vec().iter().all(|&x| x == 0.0),
        }
    }

    #[test]
    fn test_optimizer_step() {
        let mut opt = TestOptimizer::new(0.1);
        // Keep the tensor in the slice handed to `step` so its updated data
        // can be inspected afterwards.
        let mut params = [Tensor::from_vec(vec![1.0, 2.0, 3.0], true)];
        params[0].set_grad(arr1(&[0.5, 1.0, 1.5]));

        opt.step(&mut params);

        // new = old - lr * grad
        assert_close(&params[0].data().to_vec(), &[0.95, 1.9, 2.85]);
        assert_eq!(opt.lr(), 0.1);
    }

    #[test]
    fn test_optimizer_step_refs() {
        let mut opt = TestOptimizer::new(0.1);
        let mut param = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        let grad = [0.5, 1.0, 1.5];
        param.set_grad(arr1(&grad));

        let original_data = param.data().to_vec();
        opt.step_refs(&mut [&mut param]);

        // Check values were updated: new = old - lr * grad
        let expected: Vec<f32> =
            original_data.iter().zip(&grad).map(|(old, g)| old - 0.1 * g).collect();
        assert_close(&param.data().to_vec(), &expected);
    }

    #[test]
    fn test_optimizer_step_refs_no_grad() {
        let mut opt = TestOptimizer::new(0.1);
        let mut param = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        // No gradient set

        let original_data = param.data().to_vec();
        opt.step_refs(&mut [&mut param]);

        // Values should be unchanged when no gradient
        assert_eq!(original_data, param.data().to_vec());
    }

    #[test]
    fn test_optimizer_zero_grad() {
        let mut opt = TestOptimizer::new(0.1);
        let mut params = [Tensor::from_vec(vec![1.0, 2.0, 3.0], true)];
        params[0].set_grad(arr1(&[0.5, 1.0, 1.5]));

        assert!(params[0].grad().is_some());
        opt.zero_grad(&mut params);
        assert!(grad_is_zeroed(&params[0]));
    }

    #[test]
    fn test_optimizer_zero_grad_refs() {
        let mut opt = TestOptimizer::new(0.1);
        let mut param = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        param.set_grad(arr1(&[0.5, 1.0, 1.5]));

        assert!(param.grad().is_some());
        opt.zero_grad_refs(&mut [&mut param]);
        assert!(grad_is_zeroed(&param));
    }

    #[test]
    fn test_optimizer_set_lr() {
        let mut opt = TestOptimizer::new(0.1);
        assert_eq!(opt.lr(), 0.1);

        opt.set_lr(0.01);
        assert_eq!(opt.lr(), 0.01);
    }

    #[test]
    fn test_optimizer_step_refs_multiple_params() {
        let mut opt = TestOptimizer::new(0.1);
        let mut param1 = Tensor::from_vec(vec![1.0, 2.0], true);
        let mut param2 = Tensor::from_vec(vec![3.0, 4.0], true);
        param1.set_grad(arr1(&[0.5, 1.0]));
        param2.set_grad(arr1(&[1.5, 2.0]));

        opt.step_refs(&mut [&mut param1, &mut param2]);

        // Both params should be updated: new = old - 0.1 * grad
        assert_close(&param1.data().to_vec(), &[0.95, 1.9]);
        assert_close(&param2.data().to_vec(), &[2.85, 3.8]);
    }

    #[test]
    fn test_optimizer_zero_grad_multiple_params() {
        let mut opt = TestOptimizer::new(0.1);
        let mut params =
            vec![Tensor::from_vec(vec![1.0, 2.0], true), Tensor::from_vec(vec![3.0, 4.0], true)];

        for p in &mut params {
            p.set_grad(arr1(&[0.5, 1.0]));
        }

        opt.zero_grad(&mut params);
        assert!(params.iter().all(grad_is_zeroed));
    }
}