entrenar/optim/
optimizer.rs1use crate::Tensor;
4
5pub trait Optimizer {
7 fn step(&mut self, params: &mut [Tensor]);
9
10 fn step_refs(&mut self, params: &mut [&mut Tensor]) {
14 for param in params.iter_mut() {
17 if let Some(grad) = param.grad() {
18 let lr = self.lr();
20 let grad_data = grad.to_vec();
21 let data = param.data_mut();
22 for (d, g) in data.iter_mut().zip(grad_data.iter()) {
23 *d -= lr * g;
24 }
25 }
26 }
27 }
28
29 fn zero_grad(&mut self, params: &mut [Tensor]) {
31 for param in params {
32 param.zero_grad();
33 }
34 }
35
36 fn zero_grad_refs(&mut self, params: &mut [&mut Tensor]) {
38 for param in params.iter_mut() {
39 param.zero_grad();
40 }
41 }
42
43 fn lr(&self) -> f32;
45
46 fn set_lr(&mut self, lr: f32);
48}
49
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr1;

    /// Minimal SGD optimizer used to exercise the trait's default methods.
    struct TestOptimizer {
        learning_rate: f32,
    }

    impl TestOptimizer {
        fn new(lr: f32) -> Self {
            Self { learning_rate: lr }
        }
    }

    impl Optimizer for TestOptimizer {
        fn step(&mut self, params: &mut [Tensor]) {
            for param in params {
                if let Some(grad) = param.grad() {
                    let grad_data = grad.to_vec();
                    let data = param.data_mut();
                    for (d, g) in data.iter_mut().zip(grad_data.iter()) {
                        *d -= self.learning_rate * g;
                    }
                }
            }
        }

        fn lr(&self) -> f32 {
            self.learning_rate
        }

        fn set_lr(&mut self, lr: f32) {
            self.learning_rate = lr;
        }
    }

    /// Asserts that `actual` and `expected` have the same length and agree
    /// element-wise within 1e-6.
    fn assert_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() < 1e-6, "index {i}: got {a}, expected {e}");
        }
    }

    /// Returns true when the tensor has no gradient or only zero-valued entries.
    // NOTE(review): `Tensor::zero_grad` is not visible from this file, so this
    // accepts either outcome (gradient removed, or reset to zeros).
    fn grad_is_cleared(t: &Tensor) -> bool {
        match t.grad() {
            None => true,
            Some(g) => g.to_vec().iter().all(|&x| x == 0.0),
        }
    }

    #[test]
    fn test_optimizer_step() {
        let mut opt = TestOptimizer::new(0.1);
        let param = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        param.set_grad(arr1(&[0.5, 1.0, 1.5]));

        // Step the tensor in place (not a temporary clone) so the update can be
        // observed afterwards.
        let mut params = [param];
        opt.step(&mut params);

        assert_close(&params[0].data().to_vec(), &[0.95, 1.9, 2.85]);
        assert_eq!(opt.lr(), 0.1);
    }

    #[test]
    fn test_optimizer_step_refs() {
        let mut opt = TestOptimizer::new(0.1);
        let mut param = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        param.set_grad(arr1(&[0.5, 1.0, 1.5]));

        opt.step_refs(&mut [&mut param]);

        // Expected: original - 0.1 * grad.
        assert_close(&param.data().to_vec(), &[0.95, 1.9, 2.85]);
    }

    #[test]
    fn test_optimizer_step_refs_no_grad() {
        let mut opt = TestOptimizer::new(0.1);
        let mut param = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        let original_data = param.data().to_vec();

        opt.step_refs(&mut [&mut param]);

        let updated_data = param.data().to_vec();
        assert_eq!(original_data, updated_data);
    }

    #[test]
    fn test_optimizer_zero_grad() {
        let mut opt = TestOptimizer::new(0.1);
        let param = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        param.set_grad(arr1(&[0.5, 1.0, 1.5]));
        assert!(param.grad().is_some());

        let mut params = [param];
        opt.zero_grad(&mut params);

        assert!(grad_is_cleared(&params[0]));
    }

    #[test]
    fn test_optimizer_zero_grad_refs() {
        let mut opt = TestOptimizer::new(0.1);
        let mut param = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        param.set_grad(arr1(&[0.5, 1.0, 1.5]));
        assert!(param.grad().is_some());

        opt.zero_grad_refs(&mut [&mut param]);

        assert!(grad_is_cleared(&param));
    }

    #[test]
    fn test_optimizer_set_lr() {
        let mut opt = TestOptimizer::new(0.1);
        assert_eq!(opt.lr(), 0.1);

        opt.set_lr(0.01);
        assert_eq!(opt.lr(), 0.01);
    }

    #[test]
    fn test_optimizer_step_refs_multiple_params() {
        let mut opt = TestOptimizer::new(0.1);
        let mut param1 = Tensor::from_vec(vec![1.0, 2.0], true);
        let mut param2 = Tensor::from_vec(vec![3.0, 4.0], true);
        param1.set_grad(arr1(&[0.5, 1.0]));
        param2.set_grad(arr1(&[1.5, 2.0]));

        opt.step_refs(&mut [&mut param1, &mut param2]);

        assert_close(&param1.data().to_vec(), &[0.95, 1.9]);
        assert_close(&param2.data().to_vec(), &[2.85, 3.8]);
    }

    #[test]
    fn test_optimizer_zero_grad_multiple_params() {
        let mut opt = TestOptimizer::new(0.1);
        let mut params =
            vec![Tensor::from_vec(vec![1.0, 2.0], true), Tensor::from_vec(vec![3.0, 4.0], true)];

        for p in &params {
            p.set_grad(arr1(&[0.5, 1.0]));
        }

        opt.zero_grad(&mut params);

        assert!(params.iter().all(grad_is_cleared));
    }
}