optirs_core/optimizers/reptile.rs

// Reptile meta-learning optimizer
//
// Reptile is a meta-learning algorithm that learns a good initialization for
// model parameters. It works by:
// 1. Saving the initial parameters (theta)
// 2. Running N inner SGD steps on a task to get theta_adapted
// 3. Updating: theta += epsilon * (theta_adapted - theta)
//
// Reference: Nichol, A., Achiam, J., & Schulman, J. (2018).
// "On First-Order Meta-Learning Algorithms"
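//
// Worked example of one outer step (illustrative numbers, using the
// constant-gradient inner loop implemented in this file): with theta = 1.0,
// grad = 0.5, inner_lr = 0.1, inner_steps = 2, and epsilon = 0.5:
//   theta_adapted = 1.0 - 2 * 0.1 * 0.5 = 0.9
//   theta_new     = 1.0 + 0.5 * (0.9 - 1.0) = 0.95
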
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

use crate::error::Result;
use crate::optimizers::Optimizer;

/// Reptile meta-learning optimizer
///
/// Implements the Reptile algorithm for meta-learning. Reptile performs multiple
/// inner SGD steps on a task, then interpolates between the original parameters
/// and the adapted parameters using an interpolation factor epsilon.
///
/// # Algorithm
///
/// For each step:
/// 1. Save the initial parameters theta_0
/// 2. Perform `inner_steps` SGD updates: theta_k = theta_{k-1} - inner_lr * grad
/// 3. Compute the meta-update: theta_new = theta_0 + epsilon * (theta_K - theta_0)
///
/// This effectively moves the initialization point toward a region that is
/// beneficial for fast adaptation across tasks.
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizers::{ReptileOptimizer, Optimizer};
///
/// let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
/// let gradients = Array1::from_vec(vec![0.1, -0.2, 0.3]);
///
/// let mut optimizer = ReptileOptimizer::new(0.01);
/// let new_params = optimizer.step(&params, &gradients).expect("step failed");
/// ```
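///
/// Configuring the inner loop with the builder methods defined below
/// (the values here are illustrative, not tuned recommendations):
///
/// ```
/// use optirs_core::optimizers::ReptileOptimizer;
///
/// let optimizer: ReptileOptimizer<f64> = ReptileOptimizer::new(0.1)
///     .with_inner_steps(10)
///     .with_epsilon(0.5)
///     .with_inner_lr(0.01);
/// assert_eq!(optimizer.get_inner_steps(), 10);
/// ```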
#[derive(Debug, Clone)]
pub struct ReptileOptimizer<A: Float + ScalarOperand + Debug> {
    /// Outer learning rate (used as default for epsilon)
    learning_rate: A,
    /// Inner SGD learning rate for task adaptation
    inner_lr: A,
    /// Number of inner loop SGD steps
    inner_steps: usize,
    /// Interpolation factor between original and adapted parameters
    epsilon: A,
    /// Count of outer steps taken
    step_count: usize,
}

impl<A: Float + ScalarOperand + Debug> ReptileOptimizer<A> {
    /// Creates a new Reptile optimizer with the given outer learning rate
    ///
    /// Defaults:
    /// - inner_steps: 5
    /// - epsilon: same as learning_rate
    /// - inner_lr: same as learning_rate
    ///
    /// # Arguments
    ///
    /// * `lr` - The outer learning rate (also used as default epsilon and inner_lr)
    pub fn new(lr: A) -> Self {
        Self {
            learning_rate: lr,
            inner_lr: lr,
            inner_steps: 5,
            epsilon: lr,
            step_count: 0,
        }
    }

    /// Sets the number of inner SGD steps
    ///
    /// More inner steps allow better task adaptation but increase computation.
    ///
    /// # Arguments
    ///
    /// * `n` - Number of inner SGD steps; a value of 0 is clamped to 1
    pub fn with_inner_steps(mut self, n: usize) -> Self {
        self.inner_steps = if n == 0 { 1 } else { n };
        self
    }

    /// Sets the interpolation factor epsilon
    ///
    /// Controls how much the meta-update moves toward the adapted parameters.
    /// Smaller values mean more conservative updates.
    ///
    /// # Arguments
    ///
    /// * `e` - Interpolation factor (typically in [0, 1])
    pub fn with_epsilon(mut self, e: A) -> Self {
        self.epsilon = e;
        self
    }

    /// Sets the inner SGD learning rate
    ///
    /// This learning rate is used for the inner adaptation steps on each task.
    ///
    /// # Arguments
    ///
    /// * `lr` - Inner learning rate
    pub fn with_inner_lr(mut self, lr: A) -> Self {
        self.inner_lr = lr;
        self
    }

    /// Returns the number of inner steps configured
    pub fn get_inner_steps(&self) -> usize {
        self.inner_steps
    }

    /// Returns the current epsilon (interpolation factor)
    pub fn get_epsilon(&self) -> A {
        self.epsilon
    }

    /// Returns the inner learning rate
    pub fn get_inner_lr(&self) -> A {
        self.inner_lr
    }

    /// Returns the number of outer steps taken so far
    pub fn get_step_count(&self) -> usize {
        self.step_count
    }
}

impl<A, D> Optimizer<A, D> for ReptileOptimizer<A>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        // Convert to dynamic dimension for internal computation
        let params_dyn = params.to_owned().into_dyn();
        let gradients_dyn = gradients.to_owned().into_dyn();

        // Save original parameters (theta_0)
        let theta_original = params_dyn.clone();

        // Run inner SGD steps: theta_k = theta_{k-1} - inner_lr * gradients
        // True Reptile recomputes the gradient at the current parameters on
        // every inner step; since this method receives a single gradient, we
        // reuse it for all inner steps as a first-order approximation.
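        //
        // With a constant gradient g, K inner steps collapse to
        //   theta_adapted = theta_0 - K * inner_lr * g
        // so a single call to `step` behaves like plain SGD with an effective
        // learning rate of epsilon * inner_steps * inner_lr.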
        let mut theta_adapted = params_dyn;
        for _ in 0..self.inner_steps {
            theta_adapted = &theta_adapted - &(&gradients_dyn * self.inner_lr);
        }

        // Compute the meta-update direction: (theta_adapted - theta_original)
        let meta_direction = &theta_adapted - &theta_original;

        // Apply Reptile update: theta_new = theta_original + epsilon * meta_direction
        let updated_params = &theta_original + &(&meta_direction * self.epsilon);

        self.step_count += 1;

        // Convert back to original dimension
        Ok(updated_params
            .into_dimensionality::<D>()
            .expect("Reptile: failed to convert back to original dimensionality"))
    }

    fn get_learning_rate(&self) -> A {
        self.learning_rate
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
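        // Keep epsilon in sync with the outer learning rate, mirroring the
        // default chosen in `new`. Note that this overwrites any value
        // previously configured via `with_epsilon`.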
        self.learning_rate = learning_rate;
        self.epsilon = learning_rate;
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_reptile_basic_creation() {
        let optimizer: ReptileOptimizer<f64> = ReptileOptimizer::new(0.01);
        assert!(
            (Optimizer::<f64, scirs2_core::ndarray::Ix1>::get_learning_rate(&optimizer) - 0.01)
                .abs()
                < 1e-10
        );
        assert_eq!(optimizer.get_inner_steps(), 5);
        assert!((optimizer.get_epsilon() - 0.01).abs() < 1e-10);
        assert!((optimizer.get_inner_lr() - 0.01).abs() < 1e-10);
        assert_eq!(optimizer.get_step_count(), 0);
    }

    #[test]
    fn test_reptile_builder_pattern() {
        let optimizer: ReptileOptimizer<f64> = ReptileOptimizer::new(0.01)
            .with_inner_steps(10)
            .with_epsilon(0.05)
            .with_inner_lr(0.001);

        assert_eq!(optimizer.get_inner_steps(), 10);
        assert!((optimizer.get_epsilon() - 0.05).abs() < 1e-10);
        assert!((optimizer.get_inner_lr() - 0.001).abs() < 1e-10);
    }

    #[test]
    fn test_reptile_step_works() {
        let mut optimizer = ReptileOptimizer::new(0.1_f64)
            .with_inner_steps(1)
            .with_epsilon(1.0)
            .with_inner_lr(0.1);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.5, -0.5, 0.0]);

        let new_params = optimizer.step(&params, &gradients).expect("step failed");

        // With inner_steps=1, epsilon=1.0:
        // theta_adapted = params - inner_lr * gradients = [1.0 - 0.05, 2.0 + 0.05, 3.0]
        // meta_direction = theta_adapted - params = [-0.05, 0.05, 0.0]
        // result = params + 1.0 * meta_direction = [0.95, 2.05, 3.0]
        assert!((new_params[0] - 0.95).abs() < 1e-10);
        assert!((new_params[1] - 2.05).abs() < 1e-10);
        assert!((new_params[2] - 3.0).abs() < 1e-10);
        assert_eq!(optimizer.get_step_count(), 1);
    }

    #[test]
    fn test_reptile_convergence_toward_minimum() {
        // Optimize f(x) = x^2, gradient = 2x
        // Minimum is at x = 0
        let mut optimizer = ReptileOptimizer::new(0.1_f64)
            .with_inner_steps(3)
            .with_epsilon(0.5)
            .with_inner_lr(0.1);

        let mut params = Array1::from_vec(vec![5.0, -3.0, 2.0]);

        for _ in 0..100 {
            let gradients = &params * 2.0; // gradient of x^2
            params = optimizer.step(&params, &gradients).expect("step failed");
        }

        // After many steps, params should be close to zero
        for &val in params.iter() {
            assert!(
                val.abs() < 0.1,
                "Parameter {val} did not converge to near zero"
            );
        }
    }

    #[test]
    fn test_reptile_multiple_steps_increment_count() {
        let mut optimizer = ReptileOptimizer::new(0.01_f64);
        let params = Array1::from_vec(vec![1.0, 2.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2]);

        for i in 0..5 {
            let _new_params = optimizer.step(&params, &gradients).expect("step failed");
            assert_eq!(optimizer.get_step_count(), i + 1);
        }
        assert_eq!(optimizer.get_step_count(), 5);
    }

    #[test]
    fn test_reptile_zero_gradient() {
        let mut optimizer = ReptileOptimizer::new(0.1_f64).with_inner_steps(5);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.0, 0.0, 0.0]);

        let new_params = optimizer.step(&params, &gradients).expect("step failed");

        // With zero gradients, params should not change
        for (p, np) in params.iter().zip(new_params.iter()) {
            assert!(
                (*p - *np).abs() < 1e-12,
                "Params changed with zero gradient"
            );
        }
    }

    #[test]
    fn test_reptile_inner_steps_zero_clamps_to_one() {
        let optimizer: ReptileOptimizer<f64> = ReptileOptimizer::new(0.01).with_inner_steps(0);
        assert_eq!(optimizer.get_inner_steps(), 1);
    }

    #[test]
    fn test_reptile_set_learning_rate() {
        let mut optimizer: ReptileOptimizer<f64> = ReptileOptimizer::new(0.01);
        Optimizer::<f64, scirs2_core::ndarray::Ix1>::set_learning_rate(&mut optimizer, 0.05);
        assert!(
            (Optimizer::<f64, scirs2_core::ndarray::Ix1>::get_learning_rate(&optimizer) - 0.05)
                .abs()
                < 1e-10
        );
        assert!((optimizer.get_epsilon() - 0.05).abs() < 1e-10);
    }

    #[test]
    fn test_reptile_multiple_inner_steps_effect() {
        // More inner steps should result in a larger effective update
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        let mut opt_1step = ReptileOptimizer::new(0.1_f64)
            .with_inner_steps(1)
            .with_epsilon(1.0)
            .with_inner_lr(0.1);

        let mut opt_5steps = ReptileOptimizer::new(0.1_f64)
            .with_inner_steps(5)
            .with_epsilon(1.0)
            .with_inner_lr(0.1);

        let result_1 = opt_1step.step(&params, &gradients).expect("step failed");
        let result_5 = opt_5steps.step(&params, &gradients).expect("step failed");

        // The 5-step version should have moved further from the original params
        let diff_1: f64 = params
            .iter()
            .zip(result_1.iter())
            .map(|(a, b)| (*a - *b).powi(2))
            .sum();
        let diff_5: f64 = params
            .iter()
            .zip(result_5.iter())
            .map(|(a, b)| (*a - *b).powi(2))
            .sum();

        assert!(
            diff_5 > diff_1,
            "More inner steps should cause larger displacement: diff_5={diff_5}, diff_1={diff_1}"
        );
    }
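
    // A hedged sketch of meta-training across tasks (the synthetic tasks here
    // are invented for this test, not part of the library): cycling Reptile
    // steps over the quadratics f_i(x) = (x - c_i)^2 should pull the
    // initialization toward the mean of the task optima c_i.
    #[test]
    fn test_reptile_initialization_moves_toward_task_mean() {
        let mut optimizer = ReptileOptimizer::new(0.5_f64)
            .with_inner_steps(2)
            .with_epsilon(0.5)
            .with_inner_lr(0.1);

        // Two tasks with optima at 1.0 and 3.0; their mean is 2.0.
        let task_optima = [1.0_f64, 3.0_f64];
        let mut params = Array1::from_vec(vec![0.0]);

        for i in 0..200 {
            let c = task_optima[i % task_optima.len()];
            // Gradient of (x - c)^2 evaluated at the current parameters
            let gradients = (&params - c) * 2.0;
            params = optimizer.step(&params, &gradients).expect("step failed");
        }

        // The learned initialization should settle near the task mean, up to
        // the small oscillation induced by alternating between the two tasks.
        assert!(
            (params[0] - 2.0).abs() < 0.2,
            "initialization {} not near task mean 2.0",
            params[0]
        );
    }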
}