// optirs_core/optimizers/reptile.rs

use scirs2_core::ndarray::{Array, Dimension, IxDyn, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

use crate::error::Result;
use crate::optimizers::Optimizer;
/// Reptile meta-learning optimizer.
///
/// Each outer `step` simulates a fixed number of inner-loop SGD steps and
/// then interpolates the original parameters toward the adapted parameters
/// by a factor `epsilon`. NOTE(review): this is a first-order simplification
/// — the `step` impl below reuses the single supplied gradient for every
/// inner step rather than recomputing gradients at the adapted parameters.
#[derive(Debug, Clone)]
pub struct ReptileOptimizer<A: Float + ScalarOperand + Debug> {
    // Meta learning rate, exposed via the `Optimizer` trait accessors.
    learning_rate: A,
    // Learning rate applied during the simulated inner-loop steps.
    inner_lr: A,
    // Number of simulated inner-loop SGD steps per outer step (clamped >= 1).
    inner_steps: usize,
    // Interpolation factor for the meta-update direction; kept in lockstep
    // with `learning_rate` by `new` and `set_learning_rate`.
    epsilon: A,
    // Count of outer (meta) steps performed so far.
    step_count: usize,
}
60
61impl<A: Float + ScalarOperand + Debug> ReptileOptimizer<A> {
62 pub fn new(lr: A) -> Self {
73 Self {
74 learning_rate: lr,
75 inner_lr: lr,
76 inner_steps: 5,
77 epsilon: lr,
78 step_count: 0,
79 }
80 }
81
82 pub fn with_inner_steps(mut self, n: usize) -> Self {
90 self.inner_steps = if n == 0 { 1 } else { n };
91 self
92 }
93
94 pub fn with_epsilon(mut self, e: A) -> Self {
103 self.epsilon = e;
104 self
105 }
106
107 pub fn with_inner_lr(mut self, lr: A) -> Self {
115 self.inner_lr = lr;
116 self
117 }
118
119 pub fn get_inner_steps(&self) -> usize {
121 self.inner_steps
122 }
123
124 pub fn get_epsilon(&self) -> A {
126 self.epsilon
127 }
128
129 pub fn get_inner_lr(&self) -> A {
131 self.inner_lr
132 }
133
134 pub fn get_step_count(&self) -> usize {
136 self.step_count
137 }
138}
139
140impl<A, D> Optimizer<A, D> for ReptileOptimizer<A>
141where
142 A: Float + ScalarOperand + Debug,
143 D: Dimension,
144{
145 fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
146 let params_dyn = params.to_owned().into_dyn();
148 let gradients_dyn = gradients.to_owned().into_dyn();
149
150 let theta_original = params_dyn.clone();
152
153 let mut theta_adapted = params_dyn;
158 for _ in 0..self.inner_steps {
159 theta_adapted = &theta_adapted - &(&gradients_dyn * self.inner_lr);
160 }
161
162 let meta_direction = &theta_adapted - &theta_original;
164
165 let updated_params = &theta_original + &(&meta_direction * self.epsilon);
167
168 self.step_count += 1;
169
170 Ok(updated_params
172 .into_dimensionality::<D>()
173 .expect("Reptile: failed to convert back to original dimensionality"))
174 }
175
176 fn get_learning_rate(&self) -> A {
177 self.learning_rate
178 }
179
180 fn set_learning_rate(&mut self, learning_rate: A) {
181 self.learning_rate = learning_rate;
182 self.epsilon = learning_rate;
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_reptile_basic_creation() {
        let optimizer: ReptileOptimizer<f64> = ReptileOptimizer::new(0.01);
        assert!(
            (Optimizer::<f64, scirs2_core::ndarray::Ix1>::get_learning_rate(&optimizer) - 0.01)
                .abs()
                < 1e-10
        );
        assert_eq!(optimizer.get_inner_steps(), 5);
        assert!((optimizer.get_epsilon() - 0.01).abs() < 1e-10);
        assert!((optimizer.get_inner_lr() - 0.01).abs() < 1e-10);
        assert_eq!(optimizer.get_step_count(), 0);
    }

    #[test]
    fn test_reptile_builder_pattern() {
        let optimizer: ReptileOptimizer<f64> = ReptileOptimizer::new(0.01)
            .with_inner_steps(10)
            .with_epsilon(0.05)
            .with_inner_lr(0.001);

        assert_eq!(optimizer.get_inner_steps(), 10);
        assert!((optimizer.get_epsilon() - 0.05).abs() < 1e-10);
        assert!((optimizer.get_inner_lr() - 0.001).abs() < 1e-10);
    }

    #[test]
    fn test_reptile_step_works() {
        // With 1 inner step, inner_lr = 0.1 and epsilon = 1.0, the update
        // reduces to plain SGD: new = p - 0.1 * g.
        let mut optimizer = ReptileOptimizer::new(0.1_f64)
            .with_inner_steps(1)
            .with_epsilon(1.0)
            .with_inner_lr(0.1);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.5, -0.5, 0.0]);

        let new_params = optimizer.step(&params, &gradients).expect("step failed");

        assert!((new_params[0] - 0.95).abs() < 1e-10);
        assert!((new_params[1] - 2.05).abs() < 1e-10);
        assert!((new_params[2] - 3.0).abs() < 1e-10);
        assert_eq!(optimizer.get_step_count(), 1);
    }

    #[test]
    fn test_reptile_convergence_toward_minimum() {
        // Quadratic loss f(x) = x^2 has gradient 2x; repeated Reptile steps
        // should shrink the parameters toward the minimum at zero.
        let mut optimizer = ReptileOptimizer::new(0.1_f64)
            .with_inner_steps(3)
            .with_epsilon(0.5)
            .with_inner_lr(0.1);

        let mut params = Array1::from_vec(vec![5.0, -3.0, 2.0]);

        for _ in 0..100 {
            let gradients = &params * 2.0;
            params = optimizer.step(&params, &gradients).expect("step failed");
        }

        for &val in params.iter() {
            assert!(
                val.abs() < 0.1,
                "Parameter {val} did not converge to near zero"
            );
        }
    }

    #[test]
    fn test_reptile_multiple_steps_increment_count() {
        // Renamed from `..._decrement_count`: the counter increments.
        let mut optimizer = ReptileOptimizer::new(0.01_f64);
        let params = Array1::from_vec(vec![1.0, 2.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2]);

        for i in 0..5 {
            let _new_params = optimizer.step(&params, &gradients).expect("step failed");
            assert_eq!(optimizer.get_step_count(), i + 1);
        }
        assert_eq!(optimizer.get_step_count(), 5);
    }

    #[test]
    fn test_reptile_zero_gradient() {
        let mut optimizer = ReptileOptimizer::new(0.1_f64).with_inner_steps(5);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.0, 0.0, 0.0]);

        let new_params = optimizer.step(&params, &gradients).expect("step failed");

        for (p, np) in params.iter().zip(new_params.iter()) {
            assert!(
                (*p - *np).abs() < 1e-12,
                "Params changed with zero gradient"
            );
        }
    }

    #[test]
    fn test_reptile_inner_steps_zero_clamps_to_one() {
        let optimizer: ReptileOptimizer<f64> = ReptileOptimizer::new(0.01).with_inner_steps(0);
        assert_eq!(optimizer.get_inner_steps(), 1);
    }

    #[test]
    fn test_reptile_set_learning_rate() {
        // Setting the learning rate also updates epsilon (documented coupling).
        let mut optimizer: ReptileOptimizer<f64> = ReptileOptimizer::new(0.01);
        Optimizer::<f64, scirs2_core::ndarray::Ix1>::set_learning_rate(&mut optimizer, 0.05);
        assert!(
            (Optimizer::<f64, scirs2_core::ndarray::Ix1>::get_learning_rate(&optimizer) - 0.05)
                .abs()
                < 1e-10
        );
        assert!((optimizer.get_epsilon() - 0.05).abs() < 1e-10);
    }

    #[test]
    fn test_reptile_multiple_inner_steps_effect() {
        // Since the same gradient is reused each inner step, more inner
        // steps must displace the parameters further.
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        let mut opt_1step = ReptileOptimizer::new(0.1_f64)
            .with_inner_steps(1)
            .with_epsilon(1.0)
            .with_inner_lr(0.1);

        let mut opt_5steps = ReptileOptimizer::new(0.1_f64)
            .with_inner_steps(5)
            .with_epsilon(1.0)
            .with_inner_lr(0.1);

        let result_1 = opt_1step.step(&params, &gradients).expect("step failed");
        let result_5 = opt_5steps.step(&params, &gradients).expect("step failed");

        let diff_1: f64 = params
            .iter()
            .zip(result_1.iter())
            .map(|(a, b)| (*a - *b).powi(2))
            .sum();
        let diff_5: f64 = params
            .iter()
            .zip(result_5.iter())
            .map(|(a, b)| (*a - *b).powi(2))
            .sum();

        assert!(
            diff_5 > diff_1,
            "More inner steps should cause larger displacement: diff_5={diff_5}, diff_1={diff_1}"
        );
    }
}