optirs_core/optimizers/meta_sgd.rs

// Meta-SGD optimizer with per-parameter learnable learning rates
//
// Meta-SGD extends MAML by learning not only the model initialization but also
// per-parameter learning rates. This allows the model to adapt more effectively
// to new tasks by using different learning rates for different parameters.
//
// Reference: Li, Z., Zhou, F., Chen, F., & Li, H. (2017).
// "Meta-SGD: Learning to Learn Quickly for Few-Shot Learning"

use scirs2_core::ndarray::{Array, Dimension, IxDyn, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

use crate::error::Result;
use crate::optimizers::Optimizer;

/// Meta-SGD optimizer with per-parameter learnable learning rates
///
/// Implements the Meta-SGD algorithm which learns per-parameter learning rates
/// alongside the model parameters. Each parameter gets its own adaptive learning
/// rate that is updated based on the meta-gradient.
///
/// # Algorithm
///
/// For each call to `step`:
/// 1. Initialize per-parameter learning rates alpha_i to base_lr (on the first step)
/// 2. Compute the parameter update: delta_i = alpha_i * grad_i
/// 3. Update parameters: theta_i = theta_i - delta_i (steps 2-3 repeat for each inner step)
/// 4. Update per-parameter LRs: alpha_i = alpha_i - alpha_lr * grad_i * delta_i,
///    where delta_i is the update accumulated over the inner steps
/// 5. Clamp alpha_i to [1e-8, 10.0]
///
/// The per-parameter learning rates evolve over time, allowing the optimizer to
/// automatically discover the best learning rate for each parameter dimension.
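///
/// As a small worked example with illustrative values: for a single parameter with
/// base_lr = 0.1, alpha_lr = 0.01, one inner step, and gradient g = 0.5, the update is
/// delta = 0.1 * 0.5 = 0.05, so the parameter decreases by 0.05 and the learning rate
/// becomes 0.1 - 0.01 * (0.5 * 0.05) = 0.09975.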
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizers::{MetaSGD, Optimizer};
///
/// let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
/// let gradients = Array1::from_vec(vec![0.1, -0.2, 0.3]);
///
/// let mut optimizer = MetaSGD::new(0.01);
/// let new_params = optimizer.step(&params, &gradients).expect("step failed");
/// ```
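///
/// The meta-learning rate and the number of inner steps can also be tuned through the
/// builder methods; the values below are illustrative, not recommended settings:
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizers::{MetaSGD, Optimizer};
///
/// let mut optimizer = MetaSGD::new(0.01)
///     .with_alpha_lr(0.0005)
///     .with_inner_steps(3);
///
/// let params = Array1::from_vec(vec![0.5, -1.5]);
/// let gradients = Array1::from_vec(vec![0.05, -0.1]);
/// let adapted = optimizer.step(&params, &gradients).expect("step failed");
/// assert_eq!(adapted.len(), params.len());
/// ```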
#[derive(Debug, Clone)]
pub struct MetaSGD<A: Float + ScalarOperand + Debug> {
    /// Base learning rate (used to initialize per-parameter LRs)
    base_lr: A,
    /// Learning rate for updating per-parameter learning rates (meta-learning rate)
    alpha_lr: A,
    /// Number of inner adaptation steps
    inner_steps: usize,
    /// Per-parameter learnable learning rates
    per_param_lr: Option<Array<A, IxDyn>>,
    /// Count of steps taken
    step_count: usize,
}

impl<A: Float + ScalarOperand + Debug> MetaSGD<A> {
    /// Creates a new Meta-SGD optimizer with the given base learning rate
    ///
    /// Defaults:
    /// - alpha_lr: 0.001
    /// - inner_steps: 5
    ///
    /// # Arguments
    ///
    /// * `base_lr` - Base learning rate for initializing per-parameter LRs
    pub fn new(base_lr: A) -> Self {
        Self {
            base_lr,
            alpha_lr: A::from(0.001).expect("MetaSGD: failed to convert alpha_lr constant"),
            inner_steps: 5,
            per_param_lr: None,
            step_count: 0,
        }
    }

    /// Sets the meta-learning rate for updating per-parameter learning rates
    ///
    /// # Arguments
    ///
    /// * `lr` - Learning rate for the per-parameter LR updates
    pub fn with_alpha_lr(mut self, lr: A) -> Self {
        self.alpha_lr = lr;
        self
    }

    /// Sets the number of inner adaptation steps
    ///
    /// # Arguments
    ///
    /// * `n` - Number of inner steps (must be >= 1)
    pub fn with_inner_steps(mut self, n: usize) -> Self {
        self.inner_steps = if n == 0 { 1 } else { n };
        self
    }

    /// Returns the base learning rate
    pub fn get_base_lr(&self) -> A {
        self.base_lr
    }

    /// Returns the meta-learning rate (alpha_lr)
    pub fn get_alpha_lr(&self) -> A {
        self.alpha_lr
    }

    /// Returns the number of inner adaptation steps
    pub fn get_inner_steps(&self) -> usize {
        self.inner_steps
    }

    /// Returns the number of steps taken so far
    pub fn get_step_count(&self) -> usize {
        self.step_count
    }

    /// Returns a reference to the current per-parameter learning rates, if initialized
    pub fn get_per_param_lr(&self) -> Option<&Array<A, IxDyn>> {
        self.per_param_lr.as_ref()
    }

    /// Resets the per-parameter learning rates (they will be re-initialized on next step)
    pub fn reset_per_param_lr(&mut self) {
        self.per_param_lr = None;
    }

    /// Clamp learning rate values to the valid range [min_val, max_val]
    fn clamp_lr_array(lr_array: &mut Array<A, IxDyn>, min_val: A, max_val: A) {
        lr_array.mapv_inplace(|v| {
            if v < min_val {
                min_val
            } else if v > max_val {
                max_val
            } else {
                v
            }
        });
    }
}

impl<A, D> Optimizer<A, D> for MetaSGD<A>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        let params_dyn = params.to_owned().into_dyn();
        let gradients_dyn = gradients.to_owned().into_dyn();

        let min_lr = A::from(1e-8).expect("MetaSGD: failed to convert min_lr constant");
        let max_lr = A::from(10.0).expect("MetaSGD: failed to convert max_lr constant");

        // Step 1: Initialize per-parameter learning rates if needed
        if self.per_param_lr.is_none() {
            let lr_init = Array::<A, IxDyn>::from_elem(params_dyn.raw_dim(), self.base_lr);
            self.per_param_lr = Some(lr_init);
        }

        // Handle shape mismatch (if params shape changed since last call)
        {
            let current_lr = self
                .per_param_lr
                .as_ref()
                .expect("MetaSGD: per_param_lr should be initialized");
            if current_lr.raw_dim() != params_dyn.raw_dim() {
                self.per_param_lr = Some(Array::<A, IxDyn>::from_elem(
                    params_dyn.raw_dim(),
                    self.base_lr,
                ));
            }
        }

        let per_param_lr = self
            .per_param_lr
            .as_ref()
            .expect("MetaSGD: per_param_lr should be initialized")
            .clone();

        // Step 2-3: Apply inner adaptation steps using per-parameter learning rates
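        // Note: the same `gradients` array is reused for every inner step, so with a
        // fixed gradient the total parameter change is inner_steps * per_param_lr * grad.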
        let mut adapted_params = params_dyn.clone();
        let mut cumulative_delta = Array::<A, IxDyn>::zeros(params_dyn.raw_dim());

        for _ in 0..self.inner_steps {
            // delta = per_param_lr * gradients
            let delta = &per_param_lr * &gradients_dyn;
            // Accumulate total parameter change for meta-gradient
            cumulative_delta = &cumulative_delta + &delta;
            // Update adapted params
            adapted_params = &adapted_params - &delta;
        }

        // Step 4: Update per-parameter learning rates using meta-gradient
        // The meta-gradient for alpha is: grad * cumulative_delta
        // This encourages learning rates that reduce the loss
        let meta_gradient = &gradients_dyn * &cumulative_delta;
        let mut updated_lr = &per_param_lr - &(&meta_gradient * self.alpha_lr);

        // Step 5: Clamp per-parameter learning rates
        Self::clamp_lr_array(&mut updated_lr, min_lr, max_lr);

        self.per_param_lr = Some(updated_lr);
        self.step_count += 1;

        // Convert back to original dimension
        Ok(adapted_params
            .into_dimensionality::<D>()
            .expect("MetaSGD: failed to convert back to original dimensionality"))
    }

    fn get_learning_rate(&self) -> A {
        self.base_lr
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        self.base_lr = learning_rate;
        // Reset per-param LRs so they re-initialize with new base_lr
        self.per_param_lr = None;
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_meta_sgd_basic_creation() {
        let optimizer: MetaSGD<f64> = MetaSGD::new(0.01);
        assert!((optimizer.get_base_lr() - 0.01).abs() < 1e-10);
        assert!((optimizer.get_alpha_lr() - 0.001).abs() < 1e-10);
        assert_eq!(optimizer.get_inner_steps(), 5);
        assert_eq!(optimizer.get_step_count(), 0);
        assert!(optimizer.get_per_param_lr().is_none());
    }

    #[test]
    fn test_meta_sgd_builder_pattern() {
        let optimizer: MetaSGD<f64> = MetaSGD::new(0.01)
            .with_alpha_lr(0.0001)
            .with_inner_steps(10);

        assert!((optimizer.get_alpha_lr() - 0.0001).abs() < 1e-10);
        assert_eq!(optimizer.get_inner_steps(), 10);
    }

    #[test]
    fn test_meta_sgd_step_works() {
        let mut optimizer = MetaSGD::new(0.1_f64).with_inner_steps(1);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.5, -0.5, 0.0]);

        let new_params = optimizer.step(&params, &gradients).expect("step failed");

        // With inner_steps=1, base_lr=0.1:
        // delta = per_param_lr * gradients = [0.1*0.5, 0.1*(-0.5), 0.1*0.0] = [0.05, -0.05, 0.0]
        // new_params = params - delta = [0.95, 2.05, 3.0]
        assert!((new_params[0] - 0.95).abs() < 1e-10);
        assert!((new_params[1] - 2.05).abs() < 1e-10);
        assert!((new_params[2] - 3.0).abs() < 1e-10);
        assert_eq!(optimizer.get_step_count(), 1);

        // Per-param LR should be initialized now
        assert!(optimizer.get_per_param_lr().is_some());
    }

    #[test]
    fn test_meta_sgd_per_param_lr_adaptation() {
        let mut optimizer = MetaSGD::new(0.1_f64)
            .with_alpha_lr(0.01)
            .with_inner_steps(1);

        let params = Array1::from_vec(vec![1.0, 2.0]);
        let gradients = Array1::from_vec(vec![1.0, 0.001]);

        // First step initializes per-param LRs
        let _ = optimizer.step(&params, &gradients).expect("step failed");

        let lr_after_first = optimizer
            .get_per_param_lr()
            .expect("per_param_lr should exist")
            .clone();

        // The parameter with larger gradient (dim 0) should have its LR adjusted more
        // than the parameter with smaller gradient (dim 1)
        // meta_gradient = grad * delta = grad * (lr * grad) = lr * grad^2
        // For dim 0: meta_grad = 0.1 * 1.0^2 = 0.1
        //   new_lr = 0.1 - 0.01 * 0.1 = 0.099
        // For dim 1: meta_grad = 0.1 * 0.001^2 = 0.0000001
        //   new_lr = 0.1 - 0.01 * 0.0000001 ≈ 0.1
        let lr_diff_0 = (lr_after_first[0] - 0.1_f64).abs();
        let lr_diff_1 = (lr_after_first[1] - 0.1_f64).abs();
        assert!(
            lr_diff_0 > lr_diff_1,
            "Larger gradient dimension should have more LR change: diff_0={lr_diff_0}, diff_1={lr_diff_1}"
        );
    }

    #[test]
    fn test_meta_sgd_convergence_toward_minimum() {
        // Optimize f(x) = x^2, gradient = 2x
        let mut optimizer = MetaSGD::new(0.05_f64)
            .with_alpha_lr(0.0001)
            .with_inner_steps(1);

        let mut params = Array1::from_vec(vec![5.0, -3.0, 2.0]);

        for _ in 0..200 {
            let gradients = &params * 2.0;
            params = optimizer.step(&params, &gradients).expect("step failed");
        }

        // After many steps, params should be close to zero
        for &val in params.iter() {
            assert!(
                val.abs() < 0.5,
                "Parameter {val} did not converge to near zero"
            );
        }
    }

    #[test]
    fn test_meta_sgd_lr_clamping() {
        // Use very large alpha_lr to force per-param LRs to be clamped
        let mut optimizer = MetaSGD::new(0.1_f64)
            .with_alpha_lr(100.0) // Extremely large meta-LR
            .with_inner_steps(1);

        let params = Array1::from_vec(vec![1.0, 2.0]);
        let gradients = Array1::from_vec(vec![1.0, -1.0]);

        // Run a step - the large alpha_lr should cause LRs to hit clamp bounds
        let _ = optimizer.step(&params, &gradients).expect("step failed");

        let per_param_lr = optimizer
            .get_per_param_lr()
            .expect("per_param_lr should exist");

        // All LR values should be within [1e-8, 10.0]
        for &lr in per_param_lr.iter() {
            assert!(
                (1e-8..=10.0).contains(&lr),
                "Per-param LR {lr} is out of clamped range [1e-8, 10.0]"
            );
        }
    }

    #[test]
    fn test_meta_sgd_zero_gradient() {
        let mut optimizer = MetaSGD::new(0.1_f64).with_inner_steps(3);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.0, 0.0, 0.0]);

        let new_params = optimizer.step(&params, &gradients).expect("step failed");

        // With zero gradients, params should not change
        for (p, np) in params.iter().zip(new_params.iter()) {
            assert!(
                (*p - *np).abs() < 1e-12,
                "Params changed with zero gradient"
            );
        }
    }

    #[test]
    fn test_meta_sgd_set_learning_rate_resets_per_param() {
        let mut optimizer = MetaSGD::new(0.1_f64);
        let params = Array1::from_vec(vec![1.0, 2.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2]);

        let _ = optimizer.step(&params, &gradients).expect("step failed");
        assert!(optimizer.get_per_param_lr().is_some());

        // Setting learning rate should reset per-param LRs
        Optimizer::<f64, scirs2_core::ndarray::Ix1>::set_learning_rate(&mut optimizer, 0.05);
        assert!(optimizer.get_per_param_lr().is_none());
        assert!(
            (Optimizer::<f64, scirs2_core::ndarray::Ix1>::get_learning_rate(&optimizer) - 0.05)
                .abs()
                < 1e-10
        );
    }

    #[test]
    fn test_meta_sgd_inner_steps_zero_clamps_to_one() {
        let optimizer: MetaSGD<f64> = MetaSGD::new(0.01).with_inner_steps(0);
        assert_eq!(optimizer.get_inner_steps(), 1);
    }

    #[test]
    fn test_meta_sgd_multiple_steps_count() {
        let mut optimizer = MetaSGD::new(0.01_f64);
        let params = Array1::from_vec(vec![1.0, 2.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2]);

        for i in 0..5 {
            let _ = optimizer.step(&params, &gradients).expect("step failed");
            assert_eq!(optimizer.get_step_count(), i + 1);
        }
    }

    #[test]
    fn test_meta_sgd_reset_per_param_lr() {
        let mut optimizer = MetaSGD::new(0.1_f64);
        let params = Array1::from_vec(vec![1.0]);
        let gradients = Array1::from_vec(vec![0.1]);

        let _ = optimizer.step(&params, &gradients).expect("step failed");
        assert!(optimizer.get_per_param_lr().is_some());

        optimizer.reset_per_param_lr();
        assert!(optimizer.get_per_param_lr().is_none());
    }
}