use std::collections::HashMap;
use yscv_tensor::Tensor;
use super::error::OptimError;
use super::{Adagrad, Adam, AdamW, Lamb, Lars, RAdam, RmsProp, Sgd};
/// Sealed-trait pattern: `Sealed` is unnameable outside this crate, so
/// downstream crates cannot implement [`StepOptimizer`] themselves.
mod sealed {
    pub trait Sealed {}
}
/// Common stepping interface implemented by every concrete optimizer in
/// this module (see `impl_step_optimizer!` below).
///
/// Sealed via [`sealed::Sealed`]: only this crate may implement it.
pub trait StepOptimizer: sealed::Sealed {
    /// Applies one optimization update for the parameter identified by
    /// `parameter_id`, mutating `weights` in place using `grad`.
    ///
    /// # Errors
    /// Returns whatever [`OptimError`] the concrete optimizer produces.
    fn step(
        &mut self,
        parameter_id: u64,
        weights: &mut Tensor,
        grad: &Tensor,
    ) -> Result<(), OptimError>;
}
// Implements the private `Sealed` marker and `StepOptimizer` for each
// listed optimizer type, forwarding the trait method to the type's own
// `step`. NOTE(review): this assumes every listed type has an *inherent*
// `step` method with this exact signature — `<$ty>::step` then resolves
// to the inherent method rather than recursing into this trait impl.
macro_rules! impl_step_optimizer {
    ($($ty:ty),*) => {
        $(
            impl sealed::Sealed for $ty {}
            impl StepOptimizer for $ty {
                fn step(
                    &mut self,
                    parameter_id: u64,
                    weights: &mut Tensor,
                    grad: &Tensor,
                ) -> Result<(), OptimError> {
                    // Delegate to the inherent implementation.
                    <$ty>::step(self, parameter_id, weights, grad)
                }
            }
        )*
    };
}
impl_step_optimizer!(Sgd, Adam, AdamW, RmsProp, Adagrad, RAdam, Lamb, Lars);
/// Lookahead wrapper around an inner ("fast") optimizer: every `k` steps
/// the slow weights move a fraction `alpha` toward the fast weights, and
/// the fast weights are reset to the slow copy.
#[derive(Debug, Clone)]
pub struct Lookahead<O> {
    // Wrapped fast optimizer that performs every underlying update.
    inner: O,
    // Slow-weight interpolation factor used in `slow += alpha * (fast - slow)`.
    alpha: f32,
    // Synchronize slow/fast weights every `k` calls to `step`.
    k: usize,
    // Total number of `step` calls so far, shared across all parameter ids.
    step_count: usize,
    // Slow-weight copy per parameter id, created lazily on first `step`.
    slow_weights: HashMap<u64, Vec<f32>>,
}
impl<O: StepOptimizer> Lookahead<O> {
    /// Creates a Lookahead wrapper around `inner`.
    ///
    /// `alpha` is the slow-weight interpolation factor and `k` the number
    /// of fast steps between synchronizations. No validation is performed;
    /// with `k == 0` the synchronization simply never fires (see `step`).
    pub fn new(inner: O, alpha: f32, k: usize) -> Self {
        Self {
            inner,
            alpha,
            k,
            step_count: 0,
            slow_weights: HashMap::new(),
        }
    }

    /// Runs one inner-optimizer step for `parameter_id`, then — on every
    /// `k`-th call overall — interpolates the slow weights toward the fast
    /// weights (`slow += alpha * (fast - slow)`) and resets the fast
    /// weights to the slow copy.
    ///
    /// NOTE(review): `step_count` is shared across all parameter ids, so
    /// with multiple parameters the sync cadence counts *calls*, not
    /// per-parameter steps — confirm this matches the training loop.
    /// NOTE(review): the slow copy is snapshotted from the weights *after*
    /// the first inner step for that parameter, not from the pre-step
    /// values — confirm this initialization is intended.
    ///
    /// # Errors
    /// Propagates any [`OptimError`] returned by the inner optimizer.
    pub fn step(
        &mut self,
        parameter_id: u64,
        weights: &mut Tensor,
        grad: &Tensor,
    ) -> Result<(), OptimError> {
        self.inner.step(parameter_id, weights, grad)?;
        // Single entry-API lookup replaces the original insert-then-lookup
        // pair, whose second lookup contained a mangled borrow
        // (`¶meter_id`, a mojibake of `&parameter_id`) that did not
        // compile. Borrows of the disjoint fields `slow_weights`,
        // `step_count`, `k`, and `alpha` below are all permitted.
        let slow = self
            .slow_weights
            .entry(parameter_id)
            .or_insert_with(|| weights.data().to_vec());
        self.step_count += 1;
        // `is_multiple_of` is division-free: with k == 0 it is false for
        // every step_count >= 1, so no panic and no synchronization.
        if self.step_count.is_multiple_of(self.k) {
            let fast = weights.data_mut();
            for (s, f) in slow.iter_mut().zip(fast.iter_mut()) {
                *s += self.alpha * (*f - *s);
                *f = *s;
            }
        }
        Ok(())
    }
}