use crate::error::{OptimError, Result};
use crate::optimizers::Optimizer;
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use std::marker::PhantomData;
/// Lookahead optimizer wrapper.
///
/// Keeps two copies of the parameters: *fast* weights, which the wrapped
/// `inner_optimizer` advances on every step, and *slow* weights, which are
/// pulled towards the fast weights by a factor `alpha` every `k` steps, after
/// which the fast weights restart from the updated slow weights.
pub struct Lookahead<A, O, D>
where
    A: Float + ScalarOperand + Debug,
    O: Optimizer<A, D> + Clone,
    D: Dimension,
{
    // ---- configuration -------------------------------------------------
    /// Optimizer that advances the fast weights on every step.
    inner_optimizer: O,
    /// Interpolation factor used when moving the slow weights towards the fast weights.
    alpha: A,
    /// Number of inner steps between two slow/fast synchronisations.
    k: usize,
    /// When `true`, `step` returns the slow weights instead of the fast ones.
    use_slow_weights: bool,

    // ---- state ----------------------------------------------------------
    /// Inner steps taken since the last synchronisation.
    current_step: usize,
    /// Slow weights; `None` until the first `step` (and again after `reset`).
    slow_weights: Option<Array<A, D>>,
    /// Fast weights; `None` until the first `step` (and again after `reset`).
    fast_weights: Option<Array<A, D>>,
    /// Marker for the dimension parameter `D`.
    _phantom: PhantomData<D>,
}
impl<A, O, D> Lookahead<A, O, D>
where
    A: Float + ScalarOperand + Debug,
    O: Optimizer<A, D> + Clone,
    D: Dimension,
{
    /// Wraps `inner_optimizer` with the default Lookahead settings
    /// (`alpha = 0.5`, `k = 5`).
    pub fn new(inner_optimizer: O) -> Self {
        // 1 / (1 + 1) is exactly 0.5 for IEEE floats and, unlike a numeric
        // cast via `A::from(0.5)`, cannot fail, so no panic path is needed.
        let half = A::one() / (A::one() + A::one());
        Self::with_config(inner_optimizer, half, 5)
    }

    /// Wraps `inner_optimizer` with an explicit interpolation factor `alpha`
    /// and synchronisation period `k`.
    ///
    /// Neither value is validated. Because synchronisation happens once the
    /// inner step counter reaches `k`, a `k` of `0` behaves the same as `1`.
    pub fn with_config(inner_optimizer: O, alpha: A, k: usize) -> Self {
        Self {
            inner_optimizer,
            alpha,
            k,
            current_step: 0,
            slow_weights: None,
            fast_weights: None,
            use_slow_weights: false,
            _phantom: PhantomData,
        }
    }

    /// Builder-style setter for the interpolation factor `alpha`.
    #[must_use]
    pub fn with_alpha(mut self, alpha: A) -> Self {
        self.alpha = alpha;
        self
    }

    /// Builder-style setter for the synchronisation period `k`.
    #[must_use]
    pub fn with_k(mut self, k: usize) -> Self {
        self.k = k;
        self
    }

    /// Shared access to the wrapped optimizer.
    pub fn inner_optimizer(&self) -> &O {
        &self.inner_optimizer
    }

    /// Mutable access to the wrapped optimizer (e.g. to tweak its settings).
    pub fn inner_optimizer_mut(&mut self) -> &mut O {
        &mut self.inner_optimizer
    }

    /// The interpolation factor used when pulling slow weights towards fast weights.
    pub fn alpha(&self) -> A {
        self.alpha
    }

    /// The number of inner steps between synchronisations.
    pub fn k(&self) -> usize {
        self.k
    }

    /// Makes subsequent `step` calls return the slow weights.
    ///
    /// Note that `step` still advances the optimizer in this mode; only the
    /// returned array changes.
    pub fn use_slow_weights_for_eval(&mut self) {
        self.use_slow_weights = true;
    }

    /// Makes subsequent `step` calls return the fast weights (the default).
    pub fn use_fast_weights_for_train(&mut self) {
        self.use_slow_weights = false;
    }

    /// Clears the step counter and both weight buffers.
    ///
    /// The next `step` re-initialises the weights from its `params` argument.
    /// The inner optimizer's own state and the eval/train mode are left as-is.
    pub fn reset(&mut self) {
        self.current_step = 0;
        self.slow_weights = None;
        self.fast_weights = None;
    }
}
impl<A, O, D> Clone for Lookahead<A, O, D>
where
    A: Float + ScalarOperand + Debug,
    O: Optimizer<A, D> + Clone,
    D: Dimension,
{
    /// Produces an independent copy: the inner optimizer, the configuration,
    /// the step counter, the eval/train mode and deep copies of both weight
    /// buffers.
    fn clone(&self) -> Self {
        // Exhaustive destructuring: adding a field to `Lookahead` without
        // handling it here becomes a compile error instead of a silent omission.
        let Self {
            inner_optimizer,
            alpha,
            k,
            current_step,
            slow_weights,
            fast_weights,
            use_slow_weights,
            _phantom: _,
        } = self;

        Self {
            inner_optimizer: inner_optimizer.clone(),
            alpha: *alpha,
            k: *k,
            current_step: *current_step,
            slow_weights: slow_weights.clone(),
            fast_weights: fast_weights.clone(),
            use_slow_weights: *use_slow_weights,
            _phantom: PhantomData,
        }
    }
}
impl<A, O, D> Debug for Lookahead<A, O, D>
where
    A: Float + ScalarOperand + Debug,
    O: Optimizer<A, D> + Clone + Debug,
    D: Dimension,
{
    /// Formats the configuration and progress counters.
    ///
    /// The weight buffers are intentionally left out of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Lookahead");
        builder.field("inner_optimizer", &self.inner_optimizer);
        builder.field("alpha", &self.alpha);
        builder.field("k", &self.k);
        builder.field("current_step", &self.current_step);
        builder.field("use_slow_weights", &self.use_slow_weights);
        builder.finish()
    }
}
impl<A, O, D> Optimizer<A, D> for Lookahead<A, O, D>
where
    A: Float + ScalarOperand + Debug + Send + Sync,
    O: Optimizer<A, D> + Clone + Send + Sync,
    D: Dimension,
{
    /// Performs one Lookahead step.
    ///
    /// On the first call, and on the first call after `reset`, both the slow and
    /// fast weights are initialised from `params`. After that, `params` is only
    /// checked for a matching shape. The optimizer advances its own fast weights
    /// with the inner optimizer. Every `k` inner steps it moves the slow weights
    /// towards the fast weights (`slow += (fast - slow) * alpha`) and restarts
    /// the fast weights from the updated slow weights.
    ///
    /// Returns the slow weights when `use_slow_weights_for_eval` is active,
    /// otherwise the fast weights.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::OptimizationError` if `params` does not have the
    /// shape of the weights being tracked. Any error from the inner optimizer
    /// is propagated.
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        // The two buffers are only ever set or cleared together, so lazily
        // creating each one from `params` initialises them as a pair.
        let slow = self.slow_weights.get_or_insert_with(|| params.clone());
        let fast = self.fast_weights.get_or_insert_with(|| params.clone());

        // `params` is otherwise ignored after initialisation. A different shape
        // means the optimizer is being fed a different tensor, which would
        // silently return the wrong weights or panic inside the arithmetic.
        if params.shape() != slow.shape() {
            return Err(OptimError::OptimizationError(format!(
                "Lookahead: parameter shape {:?} does not match tracked weight shape {:?}",
                params.shape(),
                slow.shape()
            )));
        }

        *fast = self.inner_optimizer.step(fast, gradients)?;
        self.current_step += 1;

        if self.current_step >= self.k {
            let alpha = self.alpha;
            // In-place `slow + (fast - slow) * alpha`: same arithmetic as
            // building a diff array, without the temporaries.
            slow.zip_mut_with(&*fast, |s, &f| *s = *s + (f - *s) * alpha);
            // Reuses the fast buffer's allocation instead of a fresh clone.
            fast.clone_from(&*slow);
            self.current_step = 0;
        }

        Ok(if self.use_slow_weights {
            slow.clone()
        } else {
            fast.clone()
        })
    }

    /// Forwards the learning rate to the inner optimizer.
    fn set_learning_rate(&mut self, learning_rate: A) {
        self.inner_optimizer.set_learning_rate(learning_rate);
    }

    /// Reports the inner optimizer's learning rate.
    fn get_learning_rate(&self) -> A {
        self.inner_optimizer.get_learning_rate()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimizers::sgd::SGD;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{Array1, Ix1};

    /// Lookahead over plain SGD on 1-D parameters, as used by every test below.
    type LookaheadSgd = Lookahead<f64, SGD<f64>, Ix1>;

    #[test]
    fn test_lookahead_creation() {
        let sgd = SGD::new(0.01);
        let optimizer: LookaheadSgd = Lookahead::new(sgd);
        assert_abs_diff_eq!(optimizer.alpha(), 0.5);
        assert_eq!(optimizer.k(), 5);
        assert_abs_diff_eq!(optimizer.get_learning_rate(), 0.01);
    }

    #[test]
    fn test_lookahead_with_config() {
        let sgd = SGD::new(0.01);
        let optimizer: LookaheadSgd = Lookahead::with_config(sgd, 0.8, 10);
        assert_abs_diff_eq!(optimizer.alpha(), 0.8);
        assert_eq!(optimizer.k(), 10);
    }

    #[test]
    fn test_lookahead_step() {
        let mut sgd = SGD::new(0.1);
        sgd.set_momentum(0.0);
        let mut optimizer: LookaheadSgd = Lookahead::with_config(sgd, 0.5, 2);
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        // Step 1: plain inner SGD step on the fast weights.
        let updated_params = optimizer
            .step(&params, &gradients)
            .expect("first step should succeed");
        assert_abs_diff_eq!(updated_params[0], 0.99, epsilon = 1e-6);
        assert_abs_diff_eq!(updated_params[1], 1.98, epsilon = 1e-6);
        assert_abs_diff_eq!(updated_params[2], 2.97, epsilon = 1e-6);

        // Step 2 reaches k: slow weights move halfway towards fast, fast resyncs.
        let updated_params2 = optimizer
            .step(&updated_params, &gradients)
            .expect("second step should succeed");
        assert_abs_diff_eq!(updated_params2[0], 0.99, epsilon = 1e-6);
        assert_abs_diff_eq!(updated_params2[1], 1.98, epsilon = 1e-6);
        assert_abs_diff_eq!(updated_params2[2], 2.97, epsilon = 1e-6);

        // Step 3: inner step restarts from the synchronised weights.
        let updated_params3 = optimizer
            .step(&updated_params2, &gradients)
            .expect("third step should succeed");
        assert_abs_diff_eq!(updated_params3[0], 0.98, epsilon = 1e-6);
        assert_abs_diff_eq!(updated_params3[1], 1.96, epsilon = 1e-6);
        assert_abs_diff_eq!(updated_params3[2], 2.94, epsilon = 1e-6);
    }

    #[test]
    fn test_slow_weights_for_eval() {
        let mut sgd = SGD::new(0.1);
        sgd.set_momentum(0.0);
        let mut optimizer: LookaheadSgd = Lookahead::with_config(sgd, 0.5, 2);
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        let updated_params = optimizer
            .step(&params, &gradients)
            .expect("first step should succeed");

        optimizer.use_slow_weights_for_eval();
        let eval_params = optimizer
            .step(&updated_params, &gradients)
            .expect("eval-mode step should succeed");
        assert_abs_diff_eq!(eval_params[0], 0.99, epsilon = 1e-6);
        assert_abs_diff_eq!(eval_params[1], 1.98, epsilon = 1e-6);
        assert_abs_diff_eq!(eval_params[2], 2.97, epsilon = 1e-6);

        optimizer.use_fast_weights_for_train();
        let train_params = optimizer
            .step(&eval_params, &gradients)
            .expect("train-mode step should succeed");
        assert!(train_params[0] < 1.0);
    }

    #[test]
    fn test_reset() {
        let sgd = SGD::new(0.1);
        let mut optimizer: LookaheadSgd = Lookahead::new(sgd);
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        let _ = optimizer
            .step(&params, &gradients)
            .expect("step before reset should succeed");
        optimizer.reset();

        // After a reset the weights are re-initialised from `params`.
        let updated_params = optimizer
            .step(&params, &gradients)
            .expect("step after reset should succeed");
        assert_abs_diff_eq!(updated_params[0], 0.99, epsilon = 1e-6);
    }
}