optirs_core/optimizers/lookahead.rs

// Lookahead optimizer
//
// Implements the Lookahead optimization algorithm from:
// "Lookahead Optimizer: k steps forward, 1 step back" (Zhang et al., 2019)

use crate::error::{OptimError, Result};
use crate::optimizers::Optimizer;
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use std::marker::PhantomData;

/// Lookahead optimizer
///
/// Implements the "Lookahead Optimizer: k steps forward, 1 step back" algorithm.
/// This optimizer maintains two sets of weights: "fast" weights that are updated by
/// an inner optimizer, and "slow" weights that follow behind at a controlled pace.
///
/// The algorithm proceeds by:
/// 1. Starting with both sets of weights synchronized
/// 2. Letting the fast weights explore using the inner optimizer for k steps
/// 3. Then updating the slow weights to move partially toward the fast weights
/// 4. Resetting the fast weights back to the slow weights
/// 5. Repeating this process
///
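/// At each synchronization the slow weights follow φ ← φ + α(θₖ − φ), where θₖ are
/// the fast weights after k inner steps; the fast weights are then reset to φ.
///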
/// This provides more stable optimization by allowing aggressive exploration while
/// maintaining a conservative trajectory.
///
/// # Parameters
///
/// * `inner_optimizer` - The optimizer to use for fast weight updates
/// * `alpha` - The step size for slow weight updates (default: 0.5)
/// * `k` - The number of fast weight updates before updating slow weights (default: 5)
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizers::{Lookahead, SGD};
/// use optirs_core::Optimizer;
///
/// // Create an inner optimizer
/// let sgd = SGD::new(0.01);
///
/// // Wrap it with Lookahead
/// let mut optimizer = Lookahead::new(sgd);
///
/// // Use like any other optimizer
/// let params = Array1::zeros(10);
/// let gradients = Array1::ones(10);
/// let updated_params = optimizer.step(&params, &gradients).unwrap();
/// ```
pub struct Lookahead<A, O, D>
where
    A: Float + ScalarOperand + Debug,
    O: Optimizer<A, D> + Clone,
    D: Dimension,
{
    /// Inner optimizer for fast weights
    inner_optimizer: O,
    /// Step size for slow weights update (alpha)
    alpha: A,
    /// Synchronization period (k)
    k: usize,
    /// Current step counter
    current_step: usize,
    /// Slow weights
    slow_weights: Option<Array<A, D>>,
    /// Fast weights
    fast_weights: Option<Array<A, D>>,
    /// Use slow weights for evaluation
    use_slow_weights: bool,
    /// Dimension type marker
    _phantom: PhantomData<D>,
}

impl<A, O, D> Lookahead<A, O, D>
where
    A: Float + ScalarOperand + Debug,
    O: Optimizer<A, D> + Clone,
    D: Dimension,
{
    /// Creates a new Lookahead optimizer with the given inner optimizer and default settings
    pub fn new(inner_optimizer: O) -> Self {
        Self {
            inner_optimizer,
            alpha: A::from(0.5).unwrap(), // Default alpha is 0.5
            k: 5,                         // Default k is 5
            current_step: 0,
            slow_weights: None,
            fast_weights: None,
            use_slow_weights: false,
            _phantom: PhantomData,
        }
    }

    /// Creates a new Lookahead optimizer with the specified alpha and k values
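    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the same `SGD` constructor used in the struct-level
    /// example; the explicit type annotation pins the element and dimension types:
    ///
    /// ```
    /// use optirs_core::optimizers::{Lookahead, SGD};
    ///
    /// let sgd = SGD::new(0.01);
    /// // Move the slow weights 80% of the way toward the fast weights every 10 steps.
    /// let optimizer: Lookahead<f64, SGD<f64>, scirs2_core::ndarray::Ix1> =
    ///     Lookahead::with_config(sgd, 0.8, 10);
    /// ```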
    pub fn with_config(inner_optimizer: O, alpha: A, k: usize) -> Self {
        Self {
            inner_optimizer,
            alpha,
            k,
            current_step: 0,
            slow_weights: None,
            fast_weights: None,
            use_slow_weights: false,
            _phantom: PhantomData,
        }
    }

    /// Set the alpha parameter (slow weights step size)
    pub fn with_alpha(mut self, alpha: A) -> Self {
        self.alpha = alpha;
        self
    }

    /// Set the k parameter (synchronization period)
    pub fn with_k(mut self, k: usize) -> Self {
        self.k = k;
        self
    }

    /// Get the inner optimizer
    pub fn inner_optimizer(&self) -> &O {
        &self.inner_optimizer
    }

    /// Get a mutable reference to the inner optimizer
    pub fn inner_optimizer_mut(&mut self) -> &mut O {
        &mut self.inner_optimizer
    }

    /// Get the alpha parameter (slow weights step size)
    pub fn alpha(&self) -> A {
        self.alpha
    }

    /// Get the k parameter (synchronization period)
    pub fn k(&self) -> usize {
        self.k
    }

    /// Switches to using slow weights for evaluation
    /// Call this before evaluation to get better performance
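    ///
    /// # Example
    ///
    /// A sketch of the evaluate-then-resume pattern, assuming the same `SGD` API as
    /// the struct-level example:
    ///
    /// ```
    /// use scirs2_core::ndarray::Array1;
    /// use optirs_core::optimizers::{Lookahead, SGD};
    /// use optirs_core::Optimizer;
    ///
    /// let mut optimizer = Lookahead::new(SGD::new(0.01));
    /// let params = Array1::zeros(4);
    /// let gradients = Array1::ones(4);
    /// let _ = optimizer.step(&params, &gradients).unwrap();
    ///
    /// optimizer.use_slow_weights_for_eval();  // evaluate with the slow weights
    /// // ... run evaluation ...
    /// optimizer.use_fast_weights_for_train(); // resume training with the fast weights
    /// ```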
    pub fn use_slow_weights_for_eval(&mut self) {
        self.use_slow_weights = true;
    }

    /// Switches to using fast weights for training
    /// Call this after evaluation to resume training
    pub fn use_fast_weights_for_train(&mut self) {
        self.use_slow_weights = false;
    }

    /// Resets the internal state
    pub fn reset(&mut self) {
        self.current_step = 0;
        self.slow_weights = None;
        self.fast_weights = None;
    }
}

impl<A, O, D> Clone for Lookahead<A, O, D>
where
    A: Float + ScalarOperand + Debug,
    O: Optimizer<A, D> + Clone,
    D: Dimension,
{
    fn clone(&self) -> Self {
        Self {
            inner_optimizer: self.inner_optimizer.clone(),
            alpha: self.alpha,
            k: self.k,
            current_step: self.current_step,
            slow_weights: self.slow_weights.clone(),
            fast_weights: self.fast_weights.clone(),
            use_slow_weights: self.use_slow_weights,
            _phantom: PhantomData,
        }
    }
}

impl<A, O, D> Debug for Lookahead<A, O, D>
where
    A: Float + ScalarOperand + Debug,
    O: Optimizer<A, D> + Clone + Debug,
    D: Dimension,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Lookahead")
            .field("inner_optimizer", &self.inner_optimizer)
            .field("alpha", &self.alpha)
            .field("k", &self.k)
            .field("current_step", &self.current_step)
            .field("use_slow_weights", &self.use_slow_weights)
            .finish()
    }
}

impl<A, O, D> Optimizer<A, D> for Lookahead<A, O, D>
where
    A: Float + ScalarOperand + Debug + Send + Sync,
    O: Optimizer<A, D> + Clone + Send + Sync,
    D: Dimension,
{
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        // Initialize weights if first step
        if self.slow_weights.is_none() {
            self.slow_weights = Some(params.clone());
            self.fast_weights = Some(params.clone());
        }

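        // Note: after this first-step initialization, the optimizer tracks its fast
        // weights internally; subsequent calls do not read the `params` argument.
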
        // Get mutable references to weights
        let fast_weights = match &mut self.fast_weights {
            Some(w) => w,
            None => {
                return Err(OptimError::OptimizationError(
                    "Fast weights not initialized".to_string(),
                ))
            }
        };

        let slow_weights = match &mut self.slow_weights {
            Some(w) => w,
            None => {
                return Err(OptimError::OptimizationError(
                    "Slow weights not initialized".to_string(),
                ))
            }
        };

        // Update fast weights using inner optimizer
        *fast_weights = self.inner_optimizer.step(fast_weights, gradients)?;

        // Increment step counter
        self.current_step += 1;

        // If we've reached k steps, update slow weights and reset fast weights
        if self.current_step >= self.k {
            // Update slow weights: φₜ ← φₜ₋₁ + α(θₜ,ₖ - φₜ₋₁)
            // Compute difference between fast and slow weights
            let diff = &*fast_weights - &*slow_weights;

            // Update slow weights by moving alpha of the way toward fast weights
            *slow_weights = &*slow_weights + &(diff * self.alpha);

            // Reset fast weights to slow weights
            *fast_weights = slow_weights.clone();

            // Reset step counter
            self.current_step = 0;
        }

        // Return the appropriate weights (slow for evaluation, fast for training)
        if self.use_slow_weights {
            Ok(slow_weights.clone())
        } else {
            Ok(fast_weights.clone())
        }
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        self.inner_optimizer.set_learning_rate(learning_rate);
    }

    fn get_learning_rate(&self) -> A {
        self.inner_optimizer.get_learning_rate()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimizers::sgd::SGD;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_lookahead_creation() {
        let sgd = SGD::new(0.01);
        let optimizer: Lookahead<f64, SGD<f64>, scirs2_core::ndarray::Ix1> = Lookahead::new(sgd);

        assert_abs_diff_eq!(optimizer.alpha(), 0.5);
        assert_eq!(optimizer.k(), 5);
        assert_abs_diff_eq!(optimizer.get_learning_rate(), 0.01);
    }

    #[test]
    fn test_lookahead_with_config() {
        let sgd = SGD::new(0.01);
        let optimizer: Lookahead<f64, SGD<f64>, scirs2_core::ndarray::Ix1> =
            Lookahead::with_config(sgd, 0.8, 10);

        assert_abs_diff_eq!(optimizer.alpha(), 0.8);
        assert_eq!(optimizer.k(), 10);
    }

    #[test]
    fn test_lookahead_step() {
        let mut sgd = SGD::new(0.1);
        sgd.set_momentum(0.0);
        let mut optimizer: Lookahead<f64, SGD<f64>, scirs2_core::ndarray::Ix1> =
            Lookahead::with_config(sgd, 0.5, 2);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        // First step
        let updated_params = optimizer.step(&params, &gradients).unwrap();

        // After first step, fast weights should be updated by SGD but slow weights unchanged
        // SGD update: params - lr * gradients = [1.0, 2.0, 3.0] - 0.1 * [0.1, 0.2, 0.3] = [0.99, 1.98, 2.97]
        assert_abs_diff_eq!(updated_params[0], 0.99, epsilon = 1e-6);
        assert_abs_diff_eq!(updated_params[1], 1.98, epsilon = 1e-6);
        assert_abs_diff_eq!(updated_params[2], 2.97, epsilon = 1e-6);

        // Second step
        let updated_params2 = optimizer.step(&updated_params, &gradients).unwrap();

        // After second step (which is k), slow weights should be updated and fast weights reset to slow weights
        // SGD update on fast weights: [0.99, 1.98, 2.97] - 0.1 * [0.1, 0.2, 0.3] = [0.98, 1.96, 2.94]
        // Slow weights update: [1.0, 2.0, 3.0] + 0.5 * ([0.98, 1.96, 2.94] - [1.0, 2.0, 3.0])
        //                    = [1.0, 2.0, 3.0] + 0.5 * [-0.02, -0.04, -0.06]
        //                    = [0.99, 1.98, 2.97]
        // Fast weights are reset to slow weights = [0.99, 1.98, 2.97]

        // The returned value should be the fast weights (which are now reset to slow weights)
        assert_abs_diff_eq!(updated_params2[0], 0.99, epsilon = 1e-6);
        assert_abs_diff_eq!(updated_params2[1], 1.98, epsilon = 1e-6);
        assert_abs_diff_eq!(updated_params2[2], 2.97, epsilon = 1e-6);

        // Third step (starting a new cycle)
        let updated_params3 = optimizer.step(&updated_params2, &gradients).unwrap();

        // SGD update on fast weights: [0.99, 1.98, 2.97] - 0.1 * [0.1, 0.2, 0.3] = [0.98, 1.96, 2.94]
        assert_abs_diff_eq!(updated_params3[0], 0.98, epsilon = 1e-6);
        assert_abs_diff_eq!(updated_params3[1], 1.96, epsilon = 1e-6);
        assert_abs_diff_eq!(updated_params3[2], 2.94, epsilon = 1e-6);
    }

    #[test]
    fn test_slow_weights_for_eval() {
        let mut sgd = SGD::new(0.1);
        sgd.set_momentum(0.0);
        let mut optimizer: Lookahead<f64, SGD<f64>, scirs2_core::ndarray::Ix1> =
            Lookahead::with_config(sgd, 0.5, 2);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        // First step
        let updated_params = optimizer.step(&params, &gradients).unwrap();

        // Switch to slow weights for evaluation
        optimizer.use_slow_weights_for_eval();

        // Get the parameters when using slow weights
        let eval_params = optimizer.step(&updated_params, &gradients).unwrap();

        // This second step reaches k = 2, so the slow weights are synchronized:
        // slow = [1.0, 2.0, 3.0] + 0.5 * ([0.98, 1.96, 2.94] - [1.0, 2.0, 3.0]) = [0.99, 1.98, 2.97]
        // Because use_slow_weights is set, the slow weights are returned.
        assert_abs_diff_eq!(eval_params[0], 0.99, epsilon = 1e-6);
        assert_abs_diff_eq!(eval_params[1], 1.98, epsilon = 1e-6);
        assert_abs_diff_eq!(eval_params[2], 2.97, epsilon = 1e-6);

        // Switch back to fast weights for training
        optimizer.use_fast_weights_for_train();

        // Should be back to fast weights
        let train_params = optimizer.step(&eval_params, &gradients).unwrap();
        assert!(train_params[0] < 1.0);
    }

    #[test]
    fn test_reset() {
        let sgd = SGD::new(0.1);
        let mut optimizer: Lookahead<f64, SGD<f64>, scirs2_core::ndarray::Ix1> =
            Lookahead::new(sgd);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        // Do a step to initialize weights
        let _ = optimizer.step(&params, &gradients).unwrap();

        // Reset
        optimizer.reset();

        // Both fast and slow weights should be None, verified by new initialization
        let updated_params = optimizer.step(&params, &gradients).unwrap();
        // First step after reset should be equivalent to first step on a new optimizer
        assert_abs_diff_eq!(updated_params[0], 0.99, epsilon = 1e-6);
    }
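
    // Minimal sketch of a check for the `with_alpha` / `with_k` builder methods,
    // assuming they are chained fluently on top of `new` as defined above.
    #[test]
    fn test_builder_methods() {
        let sgd = SGD::new(0.01);
        let optimizer: Lookahead<f64, SGD<f64>, scirs2_core::ndarray::Ix1> =
            Lookahead::new(sgd).with_alpha(0.9).with_k(3);

        assert_abs_diff_eq!(optimizer.alpha(), 0.9);
        assert_eq!(optimizer.k(), 3);
    }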
}