//! Lookahead optimizer wrapper (`yscv_optim/lookahead.rs`).
1use std::collections::HashMap;
2
3use yscv_tensor::Tensor;
4
5use super::error::OptimError;
6use super::{Adagrad, Adam, AdamW, Lamb, Lars, RAdam, RmsProp, Sgd};
7
8/// Trait for optimizers that support a per-parameter `step` update.
pub trait StepOptimizer {
    /// Applies one optimisation step to a single parameter's weights,
    /// updating `weights` in place from `grad`.
    ///
    /// * `parameter_id` — identifier for this parameter; implementations that
    ///   keep per-parameter state (e.g. [`Lookahead`]) use it as a map key.
    /// * `weights` — parameter tensor, mutated in place.
    /// * `grad` — gradient with respect to `weights`.
    ///
    /// # Errors
    ///
    /// Returns an [`OptimError`] when the update cannot be applied.
    fn step(
        &mut self,
        parameter_id: u64,
        weights: &mut Tensor,
        grad: &Tensor,
    ) -> Result<(), OptimError>;
}
17
18macro_rules! impl_step_optimizer {
19    ($($ty:ty),*) => {
20        $(
21            impl StepOptimizer for $ty {
22                fn step(
23                    &mut self,
24                    parameter_id: u64,
25                    weights: &mut Tensor,
26                    grad: &Tensor,
27                ) -> Result<(), OptimError> {
28                    <$ty>::step(self, parameter_id, weights, grad)
29                }
30            }
31        )*
32    };
33}
34
35impl_step_optimizer!(Sgd, Adam, AdamW, RmsProp, Adagrad, RAdam, Lamb, Lars);
36
37/// Lookahead optimizer wrapper.
38///
39/// Maintains "slow weights" that are periodically interpolated toward the fast
40/// weights produced by the inner optimizer.  Every `k` calls to `step`, the
41/// slow weights are updated via:
42///
43/// ```text
44/// slow_w = slow_w + alpha * (fast_w - slow_w)
45/// fast_w = slow_w
46/// ```
47#[derive(Debug, Clone)]
48pub struct Lookahead<O> {
49    inner: O,
50    alpha: f32,
51    k: usize,
52    step_count: usize,
53    slow_weights: HashMap<u64, Vec<f32>>,
54}
55
56impl<O: StepOptimizer> Lookahead<O> {
57    /// Creates a new `Lookahead` wrapper around the given optimizer.
58    ///
59    /// * `alpha` — interpolation coefficient (typically 0.5).
60    /// * `k` — synchronisation period (typically 5).
61    pub fn new(inner: O, alpha: f32, k: usize) -> Self {
62        Self {
63            inner,
64            alpha,
65            k,
66            step_count: 0,
67            slow_weights: HashMap::new(),
68        }
69    }
70
71    /// Performs one optimisation step.
72    ///
73    /// 1. Delegates to the inner optimizer to update the fast weights.
74    /// 2. Increments the internal step counter.
75    /// 3. Every `k` steps, synchronises slow and fast weights.
76    pub fn step(
77        &mut self,
78        parameter_id: u64,
79        weights: &mut Tensor,
80        grad: &Tensor,
81    ) -> Result<(), OptimError> {
82        // Inner (fast) update.
83        self.inner.step(parameter_id, weights, grad)?;
84
85        // Initialise slow weights on first encounter.
86        self.slow_weights
87            .entry(parameter_id)
88            .or_insert_with(|| weights.data().to_vec());
89
90        self.step_count += 1;
91
92        // Synchronise every k steps.
93        if self.step_count.is_multiple_of(self.k) {
94            let slow = self
95                .slow_weights
96                .get_mut(&parameter_id)
97                .expect("slow weights must exist");
98            let fast = weights.data_mut();
99            for (s, f) in slow.iter_mut().zip(fast.iter_mut()) {
100                *s += self.alpha * (*f - *s);
101                *f = *s;
102            }
103        }
104
105        Ok(())
106    }
107}