use crate::common::OptimizerState;
use std::collections::HashMap;
use trustformers_core::errors::{Result, TrustformersError};
use trustformers_core::tensor::Tensor;
use trustformers_core::traits::Optimizer;
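
/// Lookahead optimizer wrapper (Zhang et al., 2019: "Lookahead Optimizer:
/// k steps forward, 1 step back").
///
/// Wraps any inner ("fast") optimizer and keeps a second, slowly moving copy
/// of each parameter. Every `k` fast steps the slow weights are pulled toward
/// the fast weights by the interpolation factor `alpha`
/// (`slow += alpha * (fast - slow)`), and the fast weights are then reset to
/// the updated slow weights.
///
/// A minimal usage sketch (illustrative only; the `SGD::new` signature below
/// is assumed, not taken from this crate):
///
/// ```ignore
/// let mut opt = Lookahead::new(SGD::new(0.01), 5, 0.5);
/// for _ in 0..num_steps {
///     opt.update(&mut params, &grads)?; // fast update via the inner optimizer
///     opt.step();                       // advance the k-step counter
///     opt.slow_step(&mut params)?;      // sync slow weights at k-step boundaries
/// }
/// ```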
#[derive(Debug)]
pub struct Lookahead<T: Optimizer> {
    /// Inner "fast" optimizer that performs the per-step updates.
    base_optimizer: T,
    /// Number of fast steps between slow-weight synchronizations.
    k: usize,
    /// Interpolation factor for the slow-weight update, in (0, 1].
    alpha: f32,
    /// Shared optimizer bookkeeping (tracks the global step count).
    state: OptimizerState,
    /// Slow weights, keyed by the address of each parameter's data buffer.
    slow_weights: HashMap<String, Vec<f32>>,
    /// Fast steps taken since the last slow-weight synchronization.
    fast_step_count: usize,
}
impl<T: Optimizer> Lookahead<T> {
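    /// Creates a new Lookahead wrapper around `base_optimizer`.
    ///
    /// `k` is the number of fast steps between slow-weight synchronizations;
    /// `alpha` is the slow-weight interpolation factor in `(0, 1]`. The
    /// Lookahead paper commonly uses `k = 5` and `alpha = 0.5`.
    ///
    /// # Panics
    /// Panics if `k == 0` or `alpha` lies outside `(0, 1]`.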
    pub fn new(base_optimizer: T, k: usize, alpha: f32) -> Self {
        assert!(k > 0, "k must be positive");
        assert!(alpha > 0.0 && alpha <= 1.0, "alpha must be in (0, 1]");
        Self {
            base_optimizer,
            k,
            alpha,
            state: OptimizerState::new(),
            slow_weights: HashMap::new(),
            fast_step_count: 0,
        }
    }
    /// Returns a shared reference to the wrapped fast optimizer.
    pub fn base_optimizer(&self) -> &T {
        &self.base_optimizer
    }

    /// Returns a mutable reference to the wrapped fast optimizer.
    pub fn base_optimizer_mut(&mut self) -> &mut T {
        &mut self.base_optimizer
    }
    /// Lazily initializes the slow weights for `parameter` as a copy of its
    /// current values. Parameters are identified by the address of their data
    /// buffer, so this assumes a tensor's storage is not reallocated between
    /// calls.
    fn init_slow_weights(&mut self, parameter: &Tensor) -> Result<()> {
        match parameter {
            Tensor::F32(param) => {
                let param_id = format!("{:p}", param.as_ptr());
                self.slow_weights
                    .entry(param_id)
                    .or_insert_with(|| param.iter().copied().collect());
                Ok(())
            }
            _ => Err(TrustformersError::tensor_op_error(
                "Unsupported tensor type for Lookahead",
                "init_slow_weights",
            )),
        }
    }
    /// Performs the Lookahead synchronization for `parameter`:
    /// `slow += alpha * (fast - slow)`, then copies the updated slow weights
    /// back into the fast weights.
    fn update_slow_weights(&mut self, parameter: &mut Tensor) -> Result<()> {
        match parameter {
            Tensor::F32(param) => {
                let param_id = format!("{:p}", param.as_ptr());
                if let Some(slow_weights) = self.slow_weights.get_mut(&param_id) {
                    if slow_weights.len() != param.len() {
                        return Err(TrustformersError::tensor_op_error(
                            "Lookahead slow weights size mismatch",
                            "slow weights validation",
                        ));
                    }
                    // Interpolate the slow weights toward the fast weights and
                    // reset the fast weights to the result in a single pass.
                    for (fast_w, slow_w) in param.iter_mut().zip(slow_weights.iter_mut()) {
                        *slow_w += self.alpha * (*fast_w - *slow_w);
                        *fast_w = *slow_w;
                    }
                }
                Ok(())
            }
            _ => Err(TrustformersError::tensor_op_error(
                "Unsupported tensor type for Lookahead",
                "update_slow_weights",
            )),
        }
    }
}
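
// The `Optimizer` impl delegates the fast updates to the wrapped optimizer;
// the slow-weight synchronization is driven by the caller via `slow_step`.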
impl<T: Optimizer> Optimizer for Lookahead<T> {
    fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
        // Make sure slow weights exist for this parameter, then let the fast
        // optimizer apply its usual update.
        self.init_slow_weights(parameter)?;
        self.base_optimizer.update(parameter, grad)?;
        Ok(())
    }

    fn zero_grad(&mut self) {
        self.base_optimizer.zero_grad();
    }

    fn step(&mut self) {
        self.base_optimizer.step();
        self.fast_step_count += 1;
        // A count of zero marks a k-step boundary; `slow_step` checks for it
        // to decide when to synchronize.
        if self.fast_step_count >= self.k {
            self.fast_step_count = 0;
        }
        self.state.step += 1;
    }

    fn get_lr(&self) -> f32 {
        self.base_optimizer.get_lr()
    }

    fn set_lr(&mut self, lr: f32) {
        self.base_optimizer.set_lr(lr);
    }
}
impl<T: Optimizer> Lookahead<T> {
    /// Synchronizes the slow weights for `parameter` if the preceding `step`
    /// call completed a cycle of `k` fast steps. Intended to be called once
    /// per parameter after each `step`; at any other point in the cycle it is
    /// a no-op.
    pub fn slow_step(&mut self, parameter: &mut Tensor) -> Result<()> {
        if self.fast_step_count == 0 {
            self.update_slow_weights(parameter)?;
        }
        Ok(())
    }
}
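
// Convenience aliases for Lookahead over the optimizers provided by this crate.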
pub type LookaheadAdam = Lookahead<crate::adam::Adam>;
pub type LookaheadAdamW = Lookahead<crate::adam::AdamW>;
pub type LookaheadRAdam = Lookahead<crate::adam::RAdam>;
pub type LookaheadNAdam = Lookahead<crate::adam::NAdam>;
pub type LookaheadSGD = Lookahead<crate::sgd::SGD>;