// optirs_core/optimizers/meta_sgd.rs

use scirs2_core::ndarray::{Array, Dimension, IxDyn, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

use crate::error::Result;
use crate::optimizers::Optimizer;

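// Overview (added sketch): Meta-SGD, as introduced by Li et al. (2017),
// "Meta-SGD: Learning to Learn Quickly for Few-Shot Learning", learns a
// per-parameter learning-rate vector `alpha` alongside the parameters `theta`.
// Each call to `step` below performs `inner_steps` elementwise inner updates
//
//     theta' = theta - alpha (*) grad
//
// followed by a first-order meta update of the rates themselves
//
//     alpha' = clamp(alpha - alpha_lr * (grad (*) cumulative_delta), min_lr, max_lr)
//
// where (*) denotes elementwise multiplication. This file implements a
// simplified single-task variant: the meta gradient is approximated from the
// same gradients used in the inner loop rather than from a held-out meta loss.
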
/// Meta-SGD optimizer that learns per-parameter learning rates.
#[derive(Debug, Clone)]
pub struct MetaSGD<A: Float + ScalarOperand + Debug> {
    /// Base learning rate used to seed the per-parameter rates.
    base_lr: A,
    /// Meta learning rate for adapting the per-parameter rates.
    alpha_lr: A,
    /// Number of inner-loop adaptation steps per call to `step`.
    inner_steps: usize,
    /// Learned per-parameter learning rates, lazily initialized on the first step.
    per_param_lr: Option<Array<A, IxDyn>>,
    /// Number of optimizer steps taken so far.
    step_count: usize,
}

impl<A: Float + ScalarOperand + Debug> MetaSGD<A> {
    /// Creates a new Meta-SGD optimizer with the given base learning rate.
    ///
    /// Defaults: `alpha_lr = 0.001`, `inner_steps = 5`.
    pub fn new(base_lr: A) -> Self {
        Self {
            base_lr,
            alpha_lr: A::from(0.001).expect("MetaSGD: failed to convert alpha_lr constant"),
            inner_steps: 5,
            per_param_lr: None,
            step_count: 0,
        }
    }

    /// Sets the meta learning rate used to adapt the per-parameter rates.
    pub fn with_alpha_lr(mut self, lr: A) -> Self {
        self.alpha_lr = lr;
        self
    }

    /// Sets the number of inner-loop steps (values of 0 are clamped to 1).
    pub fn with_inner_steps(mut self, n: usize) -> Self {
        self.inner_steps = if n == 0 { 1 } else { n };
        self
    }

    /// Returns the base learning rate.
    pub fn get_base_lr(&self) -> A {
        self.base_lr
    }

    /// Returns the meta learning rate.
    pub fn get_alpha_lr(&self) -> A {
        self.alpha_lr
    }

    /// Returns the number of inner-loop steps.
    pub fn get_inner_steps(&self) -> usize {
        self.inner_steps
    }

    /// Returns the number of optimizer steps taken so far.
    pub fn get_step_count(&self) -> usize {
        self.step_count
    }

    /// Returns the learned per-parameter learning rates, if initialized.
    pub fn get_per_param_lr(&self) -> Option<&Array<A, IxDyn>> {
        self.per_param_lr.as_ref()
    }

    /// Discards the learned per-parameter rates; they are re-seeded from
    /// `base_lr` on the next step.
    pub fn reset_per_param_lr(&mut self) {
        self.per_param_lr = None;
    }

    /// Clamps every per-parameter learning rate into `[min_val, max_val]`.
    fn clamp_lr_array(lr_array: &mut Array<A, IxDyn>, min_val: A, max_val: A) {
        lr_array.mapv_inplace(|v| {
            if v < min_val {
                min_val
            } else if v > max_val {
                max_val
            } else {
                v
            }
        });
    }
}

impl<A, D> Optimizer<A, D> for MetaSGD<A>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        // Work in dynamic dimensionality so one state array serves any shape.
        let params_dyn = params.to_owned().into_dyn();
        let gradients_dyn = gradients.to_owned().into_dyn();

        let min_lr = A::from(1e-8).expect("MetaSGD: failed to convert min_lr constant");
        let max_lr = A::from(10.0).expect("MetaSGD: failed to convert max_lr constant");

        // Lazily initialize the per-parameter rates from the base learning rate.
        if self.per_param_lr.is_none() {
            let lr_init = Array::<A, IxDyn>::from_elem(params_dyn.raw_dim(), self.base_lr);
            self.per_param_lr = Some(lr_init);
        }

        // Re-seed the rates if the parameter shape changed between steps.
        {
            let current_lr = self
                .per_param_lr
                .as_ref()
                .expect("MetaSGD: per_param_lr should be initialized");
            if current_lr.raw_dim() != params_dyn.raw_dim() {
                self.per_param_lr = Some(Array::<A, IxDyn>::from_elem(
                    params_dyn.raw_dim(),
                    self.base_lr,
                ));
            }
        }

        let per_param_lr = self
            .per_param_lr
            .as_ref()
            .expect("MetaSGD: per_param_lr should be initialized")
            .clone();

        // Inner loop: apply `inner_steps` SGD updates with per-parameter rates,
        // accumulating the total displacement for the meta update below.
        let mut adapted_params = params_dyn.clone();
        let mut cumulative_delta = Array::<A, IxDyn>::zeros(params_dyn.raw_dim());

        for _ in 0..self.inner_steps {
            let delta = &per_param_lr * &gradients_dyn;
            cumulative_delta = &cumulative_delta + &delta;
            adapted_params = &adapted_params - &delta;
        }

        // Meta update: components with larger gradient magnitudes receive
        // larger (here, shrinking) adjustments to their learning rates.
        let meta_gradient = &gradients_dyn * &cumulative_delta;
        let mut updated_lr = &per_param_lr - &(&meta_gradient * self.alpha_lr);

        // Keep the learned rates in a numerically safe range.
        Self::clamp_lr_array(&mut updated_lr, min_lr, max_lr);

        self.per_param_lr = Some(updated_lr);
        self.step_count += 1;

        Ok(adapted_params
            .into_dimensionality::<D>()
            .expect("MetaSGD: failed to convert back to original dimensionality"))
    }

    fn get_learning_rate(&self) -> A {
        self.base_lr
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        self.base_lr = learning_rate;
        // Invalidate the learned rates so they are re-seeded from the new base.
        self.per_param_lr = None;
    }
}
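
// A minimal usage sketch (illustrative only; it mirrors the tests below, and
// the trailing `?` assumes a caller returning `crate::error::Result`):
//
//     use scirs2_core::ndarray::Array1;
//
//     let mut opt = MetaSGD::new(0.1_f64).with_alpha_lr(0.01).with_inner_steps(1);
//     let params = Array1::from_vec(vec![1.0, 2.0]);
//     let grads = Array1::from_vec(vec![0.5, -0.5]);
//     let adapted = opt.step(&params, &grads)?; // adapted parameter vector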

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_meta_sgd_basic_creation() {
        let optimizer: MetaSGD<f64> = MetaSGD::new(0.01);
        assert!((optimizer.get_base_lr() - 0.01).abs() < 1e-10);
        assert!((optimizer.get_alpha_lr() - 0.001).abs() < 1e-10);
        assert_eq!(optimizer.get_inner_steps(), 5);
        assert_eq!(optimizer.get_step_count(), 0);
        assert!(optimizer.get_per_param_lr().is_none());
    }

    #[test]
    fn test_meta_sgd_builder_pattern() {
        let optimizer: MetaSGD<f64> = MetaSGD::new(0.01)
            .with_alpha_lr(0.0001)
            .with_inner_steps(10);

        assert!((optimizer.get_alpha_lr() - 0.0001).abs() < 1e-10);
        assert_eq!(optimizer.get_inner_steps(), 10);
    }

    #[test]
    fn test_meta_sgd_step_works() {
        let mut optimizer = MetaSGD::new(0.1_f64).with_inner_steps(1);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.5, -0.5, 0.0]);

        let new_params = optimizer.step(&params, &gradients).expect("step failed");

        // One inner step with lr = 0.1 gives p - 0.1 * g per element.
        assert!((new_params[0] - 0.95).abs() < 1e-10);
        assert!((new_params[1] - 2.05).abs() < 1e-10);
        assert!((new_params[2] - 3.0).abs() < 1e-10);
        assert_eq!(optimizer.get_step_count(), 1);

        // The per-parameter rates are initialized as a side effect of stepping.
        assert!(optimizer.get_per_param_lr().is_some());
    }

    #[test]
    fn test_meta_sgd_per_param_lr_adaptation() {
        let mut optimizer = MetaSGD::new(0.1_f64)
            .with_alpha_lr(0.01)
            .with_inner_steps(1);

        let params = Array1::from_vec(vec![1.0, 2.0]);
        let gradients = Array1::from_vec(vec![1.0, 0.001]);

        let _ = optimizer.step(&params, &gradients).expect("step failed");

        let lr_after_first = optimizer
            .get_per_param_lr()
            .expect("per_param_lr should exist")
            .clone();

        // The meta update scales with the squared gradient, so the dimension
        // with the larger gradient should see the larger change from 0.1.
        let lr_diff_0 = (lr_after_first[0] - 0.1_f64).abs();
        let lr_diff_1 = (lr_after_first[1] - 0.1_f64).abs();
        assert!(
            lr_diff_0 > lr_diff_1,
            "Larger gradient dimension should have more LR change: diff_0={lr_diff_0}, diff_1={lr_diff_1}"
        );
    }

    #[test]
    fn test_meta_sgd_convergence_toward_minimum() {
        let mut optimizer = MetaSGD::new(0.05_f64)
            .with_alpha_lr(0.0001)
            .with_inner_steps(1);

        let mut params = Array1::from_vec(vec![5.0, -3.0, 2.0]);

        // Minimize f(x) = sum(x^2); its gradient is 2x.
        for _ in 0..200 {
            let gradients = &params * 2.0;
            params = optimizer.step(&params, &gradients).expect("step failed");
        }

        for &val in params.iter() {
            assert!(
                val.abs() < 0.5,
                "Parameter {val} did not converge to near zero"
            );
        }
    }

    #[test]
    fn test_meta_sgd_lr_clamping() {
        // A deliberately huge meta learning rate forces the clamp to engage.
        let mut optimizer = MetaSGD::new(0.1_f64)
            .with_alpha_lr(100.0)
            .with_inner_steps(1);

        let params = Array1::from_vec(vec![1.0, 2.0]);
        let gradients = Array1::from_vec(vec![1.0, -1.0]);

        let _ = optimizer.step(&params, &gradients).expect("step failed");

        let per_param_lr = optimizer
            .get_per_param_lr()
            .expect("per_param_lr should exist");

        for &lr in per_param_lr.iter() {
            assert!(
                (1e-8..=10.0).contains(&lr),
                "Per-param LR {lr} is out of clamped range [1e-8, 10.0]"
            );
        }
    }

    #[test]
    fn test_meta_sgd_zero_gradient() {
        let mut optimizer = MetaSGD::new(0.1_f64).with_inner_steps(3);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.0, 0.0, 0.0]);

        let new_params = optimizer.step(&params, &gradients).expect("step failed");

        // With zero gradients every inner update is a no-op.
        for (p, np) in params.iter().zip(new_params.iter()) {
            assert!(
                (*p - *np).abs() < 1e-12,
                "Params changed with zero gradient"
            );
        }
    }

    #[test]
    fn test_meta_sgd_set_learning_rate_resets_per_param() {
        let mut optimizer = MetaSGD::new(0.1_f64);
        let params = Array1::from_vec(vec![1.0, 2.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2]);

        let _ = optimizer.step(&params, &gradients).expect("step failed");
        assert!(optimizer.get_per_param_lr().is_some());

        Optimizer::<f64, scirs2_core::ndarray::Ix1>::set_learning_rate(&mut optimizer, 0.05);
        assert!(optimizer.get_per_param_lr().is_none());
        assert!(
            (Optimizer::<f64, scirs2_core::ndarray::Ix1>::get_learning_rate(&optimizer) - 0.05)
                .abs()
                < 1e-10
        );
    }

    #[test]
    fn test_meta_sgd_inner_steps_zero_clamps_to_one() {
        let optimizer: MetaSGD<f64> = MetaSGD::new(0.01).with_inner_steps(0);
        assert_eq!(optimizer.get_inner_steps(), 1);
    }

    #[test]
    fn test_meta_sgd_multiple_steps_count() {
        let mut optimizer = MetaSGD::new(0.01_f64);
        let params = Array1::from_vec(vec![1.0, 2.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2]);

        for i in 0..5 {
            let _ = optimizer.step(&params, &gradients).expect("step failed");
            assert_eq!(optimizer.get_step_count(), i + 1);
        }
    }

    #[test]
    fn test_meta_sgd_reset_per_param_lr() {
        let mut optimizer = MetaSGD::new(0.1_f64);
        let params = Array1::from_vec(vec![1.0]);
        let gradients = Array1::from_vec(vec![0.1]);

        let _ = optimizer.step(&params, &gradients).expect("step failed");
        assert!(optimizer.get_per_param_lr().is_some());

        optimizer.reset_per_param_lr();
        assert!(optimizer.get_per_param_lr().is_none());
    }
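
    // Added illustrative test: exercises the shape-mismatch re-seeding path in
    // `step` above (the LR array is rebuilt when the parameter shape changes).
    #[test]
    fn test_meta_sgd_shape_change_reseeds_per_param_lr() {
        let mut optimizer = MetaSGD::new(0.1_f64);

        let params3 = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let grads3 = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let _ = optimizer.step(&params3, &grads3).expect("step failed");
        assert_eq!(optimizer.get_per_param_lr().expect("lr missing").len(), 3);

        // Stepping with a differently shaped parameter array rebuilds the rates.
        let params2 = Array1::from_vec(vec![1.0, 2.0]);
        let grads2 = Array1::from_vec(vec![0.1, 0.2]);
        let _ = optimizer.step(&params2, &grads2).expect("step failed");
        assert_eq!(optimizer.get_per_param_lr().expect("lr missing").len(), 2);
    }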
}