use scirs2_core::ndarray::{Array, Dimension, IxDyn, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use crate::error::Result;
use crate::optimizers::Optimizer;
/// State and hyper-parameters for the Reptile-style optimizer defined in this file.
///
/// All scalar settings share the element type `A`. Instances are created with
/// `ReptileOptimizer::new` and adjusted through the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ReptileOptimizer<A>
where
    A: Float + ScalarOperand + Debug,
{
    /// Rate exposed through the `Optimizer` trait's `get_learning_rate` /
    /// `set_learning_rate`; it is not read by `step` itself.
    learning_rate: A,
    /// Factor the gradients are multiplied by for each inner iteration of `step`.
    inner_lr: A,
    /// How many inner iterations `step` performs (kept at 1 or more by the builder).
    inner_steps: usize,
    /// Factor applied to the (adapted - original) parameter difference in `step`.
    epsilon: A,
    /// Number of `step` calls completed so far.
    step_count: usize,
}
/// Construction, builder-style configuration and read-only accessors.
impl<A> ReptileOptimizer<A>
where
    A: Float + ScalarOperand + Debug,
{
    /// Builds an optimizer in which the outer learning rate, the inner
    /// learning rate and epsilon all start at `lr`.
    ///
    /// The inner step count defaults to 5 and the step counter starts at 0.
    pub fn new(lr: A) -> Self {
        let inner_steps = 5;
        let step_count = 0;
        Self {
            learning_rate: lr,
            inner_lr: lr,
            inner_steps,
            epsilon: lr,
            step_count,
        }
    }

    /// Sets the number of inner iterations; a request for 0 is raised to 1.
    pub fn with_inner_steps(self, n: usize) -> Self {
        Self {
            inner_steps: n.max(1),
            ..self
        }
    }

    /// Sets the factor applied to the (adapted - original) difference.
    pub fn with_epsilon(self, e: A) -> Self {
        Self { epsilon: e, ..self }
    }

    /// Sets the factor the gradients are scaled by in each inner iteration.
    pub fn with_inner_lr(self, lr: A) -> Self {
        Self {
            inner_lr: lr,
            ..self
        }
    }

    /// Returns the configured number of inner iterations.
    pub fn get_inner_steps(&self) -> usize {
        let Self { inner_steps, .. } = self;
        *inner_steps
    }

    /// Returns the current epsilon.
    pub fn get_epsilon(&self) -> A {
        let Self { epsilon, .. } = self;
        *epsilon
    }

    /// Returns the current inner learning rate.
    pub fn get_inner_lr(&self) -> A {
        let Self { inner_lr, .. } = self;
        *inner_lr
    }

    /// Returns how many times `step` has completed.
    pub fn get_step_count(&self) -> usize {
        let Self { step_count, .. } = self;
        *step_count
    }
}
/// `Optimizer` implementation: one call to `step` runs the inner adaptation
/// loop and then moves the parameters part of the way toward the adapted point.
impl<A, D> Optimizer<A, D> for ReptileOptimizer<A>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Computes updated parameters from `params` and a single `gradients` array.
    ///
    /// The procedure is:
    /// 1. start from a copy of `params`;
    /// 2. subtract `gradients * inner_lr` from it `inner_steps` times;
    /// 3. return `params + epsilon * (adapted - params)`.
    ///
    /// The internal step counter is incremented once per call. Neither input
    /// is modified.
    ///
    /// # Panics
    ///
    /// Panics if the result cannot be converted back to dimensionality `D`.
    /// NOTE(review): arithmetic on differently shaped `params` / `gradients`
    /// follows ndarray's operator rules (presumably a panic when the shapes
    /// are not broadcast-compatible) — no explicit shape check is done here.
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        let theta_original = params.to_owned().into_dyn();

        // The per-iteration update never changes inside the loop, so scale the
        // gradients once instead of re-allocating the product on every pass.
        let inner_update = (gradients * self.inner_lr).into_dyn();

        // Repeated subtraction (rather than a single multiply by `inner_steps`)
        // is kept deliberately so floating-point results stay identical.
        let mut theta_adapted = theta_original.clone();
        for _ in 0..self.inner_steps {
            theta_adapted = &theta_adapted - &inner_update;
        }

        let meta_direction = &theta_adapted - &theta_original;
        let updated_params = &theta_original + &(&meta_direction * self.epsilon);
        self.step_count += 1;

        // The dynamic array was derived from arrays of dimensionality `D`, so
        // converting back is expected to succeed.
        Ok(updated_params
            .into_dimensionality::<D>()
            .expect("Reptile: failed to convert back to original dimensionality"))
    }

    /// Returns the stored outer learning rate.
    fn get_learning_rate(&self) -> A {
        self.learning_rate
    }

    /// Stores a new outer learning rate and sets epsilon to the same value,
    /// overwriting anything previously configured via `with_epsilon`.
    fn set_learning_rate(&mut self, learning_rate: A) {
        self.learning_rate = learning_rate;
        self.epsilon = learning_rate;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// `new` seeds every rate from its argument and uses 5 inner steps.
    #[test]
    fn test_reptile_basic_creation() {
        let optimizer: ReptileOptimizer<f64> = ReptileOptimizer::new(0.01);
        assert!(
            (Optimizer::<f64, scirs2_core::ndarray::Ix1>::get_learning_rate(&optimizer) - 0.01)
                .abs()
                < 1e-10
        );
        assert_eq!(optimizer.get_inner_steps(), 5);
        assert!((optimizer.get_epsilon() - 0.01).abs() < 1e-10);
        assert!((optimizer.get_inner_lr() - 0.01).abs() < 1e-10);
        assert_eq!(optimizer.get_step_count(), 0);
    }

    /// Each `with_*` builder overrides exactly the field it names.
    #[test]
    fn test_reptile_builder_pattern() {
        let optimizer: ReptileOptimizer<f64> = ReptileOptimizer::new(0.01)
            .with_inner_steps(10)
            .with_epsilon(0.05)
            .with_inner_lr(0.001);
        assert_eq!(optimizer.get_inner_steps(), 10);
        assert!((optimizer.get_epsilon() - 0.05).abs() < 1e-10);
        assert!((optimizer.get_inner_lr() - 0.001).abs() < 1e-10);
    }

    /// With one inner step and epsilon 1.0 the update is `params - inner_lr * gradients`.
    #[test]
    fn test_reptile_step_works() {
        let mut optimizer = ReptileOptimizer::new(0.1_f64)
            .with_inner_steps(1)
            .with_epsilon(1.0)
            .with_inner_lr(0.1);
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.5, -0.5, 0.0]);
        let new_params = optimizer.step(&params, &gradients).expect("step failed");
        assert!((new_params[0] - 0.95).abs() < 1e-10);
        assert!((new_params[1] - 2.05).abs() < 1e-10);
        assert!((new_params[2] - 3.0).abs() < 1e-10);
        assert_eq!(optimizer.get_step_count(), 1);
    }

    /// Minimising `sum(x^2)` (gradient `2x`) drives every parameter toward zero.
    #[test]
    fn test_reptile_convergence_toward_minimum() {
        let mut optimizer = ReptileOptimizer::new(0.1_f64)
            .with_inner_steps(3)
            .with_epsilon(0.5)
            .with_inner_lr(0.1);
        let mut params = Array1::from_vec(vec![5.0, -3.0, 2.0]);
        for _ in 0..100 {
            let gradients = &params * 2.0;
            params = optimizer.step(&params, &gradients).expect("step failed");
        }
        for &val in params.iter() {
            assert!(
                val.abs() < 0.1,
                "Parameter {val} did not converge to near zero"
            );
        }
    }

    /// The step counter goes up by exactly one per `step` call.
    #[test]
    fn test_reptile_multiple_steps_increment_count() {
        let mut optimizer = ReptileOptimizer::new(0.01_f64);
        let params = Array1::from_vec(vec![1.0, 2.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2]);
        for i in 0..5 {
            let _new_params = optimizer.step(&params, &gradients).expect("step failed");
            assert_eq!(optimizer.get_step_count(), i + 1);
        }
        assert_eq!(optimizer.get_step_count(), 5);
    }

    /// A zero gradient must leave the parameters untouched.
    #[test]
    fn test_reptile_zero_gradient() {
        let mut optimizer = ReptileOptimizer::new(0.1_f64).with_inner_steps(5);
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.0, 0.0, 0.0]);
        let new_params = optimizer.step(&params, &gradients).expect("step failed");
        for (p, np) in params.iter().zip(new_params.iter()) {
            assert!(
                (*p - *np).abs() < 1e-12,
                "Params changed with zero gradient"
            );
        }
    }

    /// Requesting zero inner steps is clamped up to one.
    #[test]
    fn test_reptile_inner_steps_zero_clamps_to_one() {
        let optimizer: ReptileOptimizer<f64> = ReptileOptimizer::new(0.01).with_inner_steps(0);
        assert_eq!(optimizer.get_inner_steps(), 1);
    }

    /// `set_learning_rate` updates both the learning rate and epsilon.
    #[test]
    fn test_reptile_set_learning_rate() {
        let mut optimizer: ReptileOptimizer<f64> = ReptileOptimizer::new(0.01);
        Optimizer::<f64, scirs2_core::ndarray::Ix1>::set_learning_rate(&mut optimizer, 0.05);
        assert!(
            (Optimizer::<f64, scirs2_core::ndarray::Ix1>::get_learning_rate(&optimizer) - 0.05)
                .abs()
                < 1e-10
        );
        assert!((optimizer.get_epsilon() - 0.05).abs() < 1e-10);
    }

    /// More inner steps move the parameters further for the same gradient.
    #[test]
    fn test_reptile_multiple_inner_steps_effect() {
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let mut opt_1step = ReptileOptimizer::new(0.1_f64)
            .with_inner_steps(1)
            .with_epsilon(1.0)
            .with_inner_lr(0.1);
        let mut opt_5steps = ReptileOptimizer::new(0.1_f64)
            .with_inner_steps(5)
            .with_epsilon(1.0)
            .with_inner_lr(0.1);
        let result_1 = opt_1step.step(&params, &gradients).expect("step failed");
        let result_5 = opt_5steps.step(&params, &gradients).expect("step failed");
        // Squared Euclidean distance between the input and each result.
        let diff_1: f64 = params
            .iter()
            .zip(result_1.iter())
            .map(|(a, b)| (*a - *b).powi(2))
            .sum();
        let diff_5: f64 = params
            .iter()
            .zip(result_5.iter())
            .map(|(a, b)| (*a - *b).powi(2))
            .sum();
        assert!(
            diff_5 > diff_1,
            "More inner steps should cause larger displacement: diff_5={diff_5}, diff_1={diff_1}"
        );
    }
}