use scirs2_core::ndarray::{Array, Dimension, IxDyn, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use crate::error::Result;
use crate::optimizers::Optimizer;
/// Optimizer state for an SGD variant that keeps one learning rate per
/// parameter element and adapts those rates on every call to `step`.
///
/// The per-element rates are created lazily: they stay `None` until the first
/// `step` call, at which point they are filled with `base_lr`.
#[derive(Debug, Clone)]
pub struct MetaSGD<A: Float + ScalarOperand + Debug> {
/// Value used to fill the per-parameter learning rates when they are
/// (re)initialized; also what `get_learning_rate` reports.
base_lr: A,
/// Step size applied to the meta-gradient when the per-parameter learning
/// rates themselves are updated.
alpha_lr: A,
/// How many times the per-parameter update is applied within one `step`
/// call. The builder never lets this be zero.
inner_steps: usize,
/// Current learning rate for every parameter element, stored with dynamic
/// dimensionality. `None` until the first `step` or after a reset.
per_param_lr: Option<Array<A, IxDyn>>,
/// Number of `step` calls that have completed.
step_count: usize,
}
impl<A: Float + ScalarOperand + Debug> MetaSGD<A> {
    /// Builds an optimizer with the given base learning rate.
    ///
    /// Defaults: meta learning rate `0.001`, `5` inner steps, no
    /// per-parameter learning rates yet, and a step counter of zero.
    ///
    /// # Panics
    ///
    /// Panics if `0.001` cannot be converted into `A`.
    pub fn new(base_lr: A) -> Self {
        let alpha_lr = A::from(0.001).expect("MetaSGD: failed to convert alpha_lr constant");
        Self {
            base_lr,
            alpha_lr,
            inner_steps: 5,
            per_param_lr: None,
            step_count: 0,
        }
    }

    /// Returns the optimizer with its meta learning rate replaced by `lr`.
    pub fn with_alpha_lr(self, lr: A) -> Self {
        Self {
            alpha_lr: lr,
            ..self
        }
    }

    /// Returns the optimizer with its inner-step count replaced by `n`.
    ///
    /// A request for zero steps is raised to one.
    pub fn with_inner_steps(self, n: usize) -> Self {
        Self {
            inner_steps: n.max(1),
            ..self
        }
    }

    /// Base learning rate used to seed the per-parameter learning rates.
    pub fn get_base_lr(&self) -> A {
        self.base_lr
    }

    /// Learning rate applied when the per-parameter rates are updated.
    pub fn get_alpha_lr(&self) -> A {
        self.alpha_lr
    }

    /// Number of inner updates applied per `step` call.
    pub fn get_inner_steps(&self) -> usize {
        self.inner_steps
    }

    /// Number of `step` calls completed so far.
    pub fn get_step_count(&self) -> usize {
        self.step_count
    }

    /// Borrows the per-parameter learning rates, or `None` if they have not
    /// been initialized (or were reset).
    pub fn get_per_param_lr(&self) -> Option<&Array<A, IxDyn>> {
        match &self.per_param_lr {
            Some(rates) => Some(rates),
            None => None,
        }
    }

    /// Drops the per-parameter learning rates so the next `step` call
    /// re-creates them from the base learning rate.
    pub fn reset_per_param_lr(&mut self) {
        self.per_param_lr = None;
    }

    /// Restricts every element of `lr_array` to `[min_val, max_val]` in place.
    ///
    /// The lower bound is tested first; elements that compare neither below
    /// `min_val` nor above `max_val` (including NaN) are left untouched.
    fn clamp_lr_array(lr_array: &mut Array<A, IxDyn>, min_val: A, max_val: A) {
        for rate in lr_array.iter_mut() {
            if *rate < min_val {
                *rate = min_val;
            } else if *rate > max_val {
                *rate = max_val;
            }
        }
    }
}
impl<A, D> Optimizer<A, D> for MetaSGD<A>
where
A: Float + ScalarOperand + Debug,
D: Dimension,
{
fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
let params_dyn = params.to_owned().into_dyn();
let gradients_dyn = gradients.to_owned().into_dyn();
let min_lr = A::from(1e-8).expect("MetaSGD: failed to convert min_lr constant");
let max_lr = A::from(10.0).expect("MetaSGD: failed to convert max_lr constant");
if self.per_param_lr.is_none() {
let lr_init = Array::<A, IxDyn>::from_elem(params_dyn.raw_dim(), self.base_lr);
self.per_param_lr = Some(lr_init);
}
{
let current_lr = self
.per_param_lr
.as_ref()
.expect("MetaSGD: per_param_lr should be initialized");
if current_lr.raw_dim() != params_dyn.raw_dim() {
self.per_param_lr = Some(Array::<A, IxDyn>::from_elem(
params_dyn.raw_dim(),
self.base_lr,
));
}
}
let per_param_lr = self
.per_param_lr
.as_ref()
.expect("MetaSGD: per_param_lr should be initialized")
.clone();
let mut adapted_params = params_dyn.clone();
let mut cumulative_delta = Array::<A, IxDyn>::zeros(params_dyn.raw_dim());
for _ in 0..self.inner_steps {
let delta = &per_param_lr * &gradients_dyn;
cumulative_delta = &cumulative_delta + δ
adapted_params = &adapted_params - δ
}
let meta_gradient = &gradients_dyn * &cumulative_delta;
let mut updated_lr = &per_param_lr - &(&meta_gradient * self.alpha_lr);
Self::clamp_lr_array(&mut updated_lr, min_lr, max_lr);
self.per_param_lr = Some(updated_lr);
self.step_count += 1;
Ok(adapted_params
.into_dimensionality::<D>()
.expect("MetaSGD: failed to convert back to original dimensionality"))
}
fn get_learning_rate(&self) -> A {
self.base_lr
}
fn set_learning_rate(&mut self, learning_rate: A) {
self.base_lr = learning_rate;
self.per_param_lr = None;
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// A fresh optimizer exposes the documented defaults.
    #[test]
    fn test_meta_sgd_basic_creation() {
        let optimizer: MetaSGD<f64> = MetaSGD::new(0.01);
        assert!((optimizer.get_base_lr() - 0.01).abs() < 1e-10);
        assert!((optimizer.get_alpha_lr() - 0.001).abs() < 1e-10);
        assert_eq!(optimizer.get_inner_steps(), 5);
        assert_eq!(optimizer.get_step_count(), 0);
        assert!(optimizer.get_per_param_lr().is_none());
    }

    /// Builder methods override the defaults.
    #[test]
    fn test_meta_sgd_builder_pattern() {
        let optimizer: MetaSGD<f64> = MetaSGD::new(0.01)
            .with_alpha_lr(0.0001)
            .with_inner_steps(10);
        assert!((optimizer.get_alpha_lr() - 0.0001).abs() < 1e-10);
        assert_eq!(optimizer.get_inner_steps(), 10);
    }

    /// With one inner step the update is plain `params - lr * gradients`.
    #[test]
    fn test_meta_sgd_step_works() {
        let mut optimizer = MetaSGD::new(0.1_f64).with_inner_steps(1);
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.5, -0.5, 0.0]);
        let new_params = optimizer.step(&params, &gradients).expect("step failed");
        assert!((new_params[0] - 0.95).abs() < 1e-10);
        assert!((new_params[1] - 2.05).abs() < 1e-10);
        assert!((new_params[2] - 3.0).abs() < 1e-10);
        assert_eq!(optimizer.get_step_count(), 1);
        assert!(optimizer.get_per_param_lr().is_some());
    }

    /// Elements with larger gradients see a larger learning-rate change.
    #[test]
    fn test_meta_sgd_per_param_lr_adaptation() {
        let mut optimizer = MetaSGD::new(0.1_f64)
            .with_alpha_lr(0.01)
            .with_inner_steps(1);
        let params = Array1::from_vec(vec![1.0, 2.0]);
        let gradients = Array1::from_vec(vec![1.0, 0.001]);
        let _ = optimizer.step(&params, &gradients).expect("step failed");
        let lr_after_first = optimizer
            .get_per_param_lr()
            .expect("per_param_lr should exist")
            .clone();
        let lr_diff_0 = (lr_after_first[0] - 0.1_f64).abs();
        let lr_diff_1 = (lr_after_first[1] - 0.1_f64).abs();
        assert!(
            lr_diff_0 > lr_diff_1,
            "Larger gradient dimension should have more LR change: diff_0={lr_diff_0}, diff_1={lr_diff_1}"
        );
    }

    /// Minimizing `sum(p^2)` (gradient `2p`) drives the parameters toward zero.
    #[test]
    fn test_meta_sgd_convergence_toward_minimum() {
        let mut optimizer = MetaSGD::new(0.05_f64)
            .with_alpha_lr(0.0001)
            .with_inner_steps(1);
        let mut params = Array1::from_vec(vec![5.0, -3.0, 2.0]);
        for _ in 0..200 {
            let gradients = &params * 2.0;
            params = optimizer.step(&params, &gradients).expect("step failed");
        }
        for &val in params.iter() {
            assert!(
                val.abs() < 0.5,
                "Parameter {val} did not converge to near zero"
            );
        }
    }

    /// An extreme meta learning rate still leaves every rate inside the clamp range.
    #[test]
    fn test_meta_sgd_lr_clamping() {
        let mut optimizer = MetaSGD::new(0.1_f64)
            .with_alpha_lr(100.0)
            .with_inner_steps(1);
        let params = Array1::from_vec(vec![1.0, 2.0]);
        let gradients = Array1::from_vec(vec![1.0, -1.0]);
        let _ = optimizer.step(&params, &gradients).expect("step failed");
        let per_param_lr = optimizer
            .get_per_param_lr()
            .expect("per_param_lr should exist");
        for &lr in per_param_lr.iter() {
            assert!(
                (1e-8..=10.0).contains(&lr),
                "Per-param LR {lr} is out of clamped range [1e-8, 10.0]"
            );
        }
    }

    /// A zero gradient must not move the parameters.
    #[test]
    fn test_meta_sgd_zero_gradient() {
        let mut optimizer = MetaSGD::new(0.1_f64).with_inner_steps(3);
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.0, 0.0, 0.0]);
        let new_params = optimizer.step(&params, &gradients).expect("step failed");
        for (p, np) in params.iter().zip(new_params.iter()) {
            assert!(
                (*p - *np).abs() < 1e-12,
                "Params changed with zero gradient"
            );
        }
    }

    /// Setting the base learning rate discards the adapted per-parameter rates.
    #[test]
    fn test_meta_sgd_set_learning_rate_resets_per_param() {
        let mut optimizer = MetaSGD::new(0.1_f64);
        let params = Array1::from_vec(vec![1.0, 2.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2]);
        let _ = optimizer.step(&params, &gradients).expect("step failed");
        assert!(optimizer.get_per_param_lr().is_some());
        Optimizer::<f64, scirs2_core::ndarray::Ix1>::set_learning_rate(&mut optimizer, 0.05);
        assert!(optimizer.get_per_param_lr().is_none());
        assert!(
            (Optimizer::<f64, scirs2_core::ndarray::Ix1>::get_learning_rate(&optimizer) - 0.05)
                .abs()
                < 1e-10
        );
    }

    /// Requesting zero inner steps is raised to one.
    #[test]
    fn test_meta_sgd_inner_steps_zero_clamps_to_one() {
        let optimizer: MetaSGD<f64> = MetaSGD::new(0.01).with_inner_steps(0);
        assert_eq!(optimizer.get_inner_steps(), 1);
    }

    /// The step counter advances by one per `step` call.
    #[test]
    fn test_meta_sgd_multiple_steps_count() {
        let mut optimizer = MetaSGD::new(0.01_f64);
        let params = Array1::from_vec(vec![1.0, 2.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2]);
        for i in 0..5 {
            let _ = optimizer.step(&params, &gradients).expect("step failed");
            assert_eq!(optimizer.get_step_count(), i + 1);
        }
    }

    /// An explicit reset clears the per-parameter rates.
    #[test]
    fn test_meta_sgd_reset_per_param_lr() {
        let mut optimizer = MetaSGD::new(0.1_f64);
        let params = Array1::from_vec(vec![1.0]);
        let gradients = Array1::from_vec(vec![0.1]);
        let _ = optimizer.step(&params, &gradients).expect("step failed");
        assert!(optimizer.get_per_param_lr().is_some());
        optimizer.reset_per_param_lr();
        assert!(optimizer.get_per_param_lr().is_none());
    }
}