use crate::error::{OptimError, Result};
use crate::optimizers::Optimizer;
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use std::marker::PhantomData;

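/// Lookahead optimizer wrapper (Zhang et al., 2019, "Lookahead Optimizer:
/// k steps forward, 1 step back").
///
/// Lookahead maintains two sets of weights: "fast" weights updated by the
/// wrapped inner optimizer on every step, and "slow" weights that are pulled
/// toward the fast weights by a factor `alpha` once every `k` steps. The
/// fast weights are then restarted from the updated slow weights.
///
/// # Example
///
/// A minimal usage sketch, assuming this crate's `SGD` optimizer (mirroring
/// the unit tests below):
///
/// ```ignore
/// use scirs2_core::ndarray::{Array1, Ix1};
///
/// let sgd = SGD::new(0.1);
/// let mut optimizer: Lookahead<f64, SGD<f64>, Ix1> =
///     Lookahead::with_config(sgd, 0.5, 5);
///
/// let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
/// let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);
/// let updated = optimizer.step(&params, &gradients)?;
/// ```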
pub struct Lookahead<A, O, D>
where
    A: Float + ScalarOperand + Debug,
    O: Optimizer<A, D> + Clone,
    D: Dimension,
{
    /// The wrapped inner (fast) optimizer.
    inner_optimizer: O,
    /// Interpolation factor for the slow-weight update.
    alpha: A,
    /// Number of fast steps between slow-weight synchronizations.
    k: usize,
    /// Fast steps taken since the last synchronization.
    current_step: usize,
    /// Slow weights, initialized on the first call to `step`.
    slow_weights: Option<Array<A, D>>,
    /// Fast weights, initialized on the first call to `step`.
    fast_weights: Option<Array<A, D>>,
    /// Whether `step` returns the slow weights (evaluation) instead of the
    /// fast weights (training).
    use_slow_weights: bool,
    _phantom: PhantomData<D>,
}

impl<A, O, D> Lookahead<A, O, D>
where
    A: Float + ScalarOperand + Debug,
    O: Optimizer<A, D> + Clone,
    D: Dimension,
{
    /// Creates a `Lookahead` wrapper with the defaults from the paper:
    /// `alpha = 0.5` and `k = 5`.
    pub fn new(inner_optimizer: O) -> Self {
        Self {
            inner_optimizer,
            alpha: A::from(0.5).unwrap(),
            k: 5,
            current_step: 0,
            slow_weights: None,
            fast_weights: None,
            use_slow_weights: false,
            _phantom: PhantomData,
        }
    }

    /// Creates a `Lookahead` wrapper with a custom interpolation factor
    /// `alpha` and synchronization period `k`.
    pub fn with_config(inner_optimizer: O, alpha: A, k: usize) -> Self {
        Self {
            inner_optimizer,
            alpha,
            k,
            current_step: 0,
            slow_weights: None,
            fast_weights: None,
            use_slow_weights: false,
            _phantom: PhantomData,
        }
    }

    /// Sets the interpolation factor `alpha` (builder style).
    pub fn with_alpha(mut self, alpha: A) -> Self {
        self.alpha = alpha;
        self
    }

    /// Sets the synchronization period `k` (builder style).
    pub fn with_k(mut self, k: usize) -> Self {
        self.k = k;
        self
    }

    /// Returns a reference to the inner optimizer.
    pub fn inner_optimizer(&self) -> &O {
        &self.inner_optimizer
    }

    /// Returns a mutable reference to the inner optimizer.
    pub fn inner_optimizer_mut(&mut self) -> &mut O {
        &mut self.inner_optimizer
    }

    /// Returns the interpolation factor `alpha`.
    pub fn alpha(&self) -> A {
        self.alpha
    }

    /// Returns the synchronization period `k`.
    pub fn k(&self) -> usize {
        self.k
    }

    /// Makes `step` return the slow weights, the recommended setting for
    /// evaluation.
    pub fn use_slow_weights_for_eval(&mut self) {
        self.use_slow_weights = true;
    }

    /// Makes `step` return the fast weights, the setting for training
    /// (the default).
    pub fn use_fast_weights_for_train(&mut self) {
        self.use_slow_weights = false;
    }

    /// Resets the Lookahead state (step counter and both weight copies).
    /// The inner optimizer's own state is left untouched.
    pub fn reset(&mut self) {
        self.current_step = 0;
        self.slow_weights = None;
        self.fast_weights = None;
    }
}

impl<A, O, D> Clone for Lookahead<A, O, D>
where
    A: Float + ScalarOperand + Debug,
    O: Optimizer<A, D> + Clone,
    D: Dimension,
{
    fn clone(&self) -> Self {
        Self {
            inner_optimizer: self.inner_optimizer.clone(),
            alpha: self.alpha,
            k: self.k,
            current_step: self.current_step,
            slow_weights: self.slow_weights.clone(),
            fast_weights: self.fast_weights.clone(),
            use_slow_weights: self.use_slow_weights,
            _phantom: PhantomData,
        }
    }
}

impl<A, O, D> Debug for Lookahead<A, O, D>
where
    A: Float + ScalarOperand + Debug,
    O: Optimizer<A, D> + Clone + Debug,
    D: Dimension,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Lookahead")
            .field("inner_optimizer", &self.inner_optimizer)
            .field("alpha", &self.alpha)
            .field("k", &self.k)
            .field("current_step", &self.current_step)
            .field("use_slow_weights", &self.use_slow_weights)
            .finish()
    }
}

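// Every `k` fast steps, the slow weights are pulled toward the fast weights
// and the fast weights are restarted from the result:
//
//   slow <- slow + alpha * (fast - slow)
//   fast <- slow
//
// Between synchronizations, `step` simply delegates to the inner optimizer.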
impl<A, O, D> Optimizer<A, D> for Lookahead<A, O, D>
where
    A: Float + ScalarOperand + Debug + Send + Sync,
    O: Optimizer<A, D> + Clone + Send + Sync,
    D: Dimension,
{
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        // Initialize both weight copies from the incoming parameters on the
        // first step.
        if self.slow_weights.is_none() {
            self.slow_weights = Some(params.clone());
            self.fast_weights = Some(params.clone());
        }

        let fast_weights = match &mut self.fast_weights {
            Some(w) => w,
            None => {
                return Err(OptimError::OptimizationError(
                    "Fast weights not initialized".to_string(),
                ))
            }
        };

        let slow_weights = match &mut self.slow_weights {
            Some(w) => w,
            None => {
                return Err(OptimError::OptimizationError(
                    "Slow weights not initialized".to_string(),
                ))
            }
        };

        // Take one fast step with the inner optimizer.
        *fast_weights = self.inner_optimizer.step(fast_weights, gradients)?;

        self.current_step += 1;

        // Synchronize every `k` fast steps.
        if self.current_step >= self.k {
            let diff = &*fast_weights - &*slow_weights;

            // slow <- slow + alpha * (fast - slow)
            *slow_weights = &*slow_weights + &(diff * self.alpha);

            // Restart the fast weights from the updated slow weights.
            *fast_weights = slow_weights.clone();

            self.current_step = 0;
        }

        if self.use_slow_weights {
            Ok(slow_weights.clone())
        } else {
            Ok(fast_weights.clone())
        }
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        self.inner_optimizer.set_learning_rate(learning_rate);
    }

    fn get_learning_rate(&self) -> A {
        self.inner_optimizer.get_learning_rate()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimizers::sgd::SGD;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_lookahead_creation() {
        let sgd = SGD::new(0.01);
        let optimizer: Lookahead<f64, SGD<f64>, scirs2_core::ndarray::Ix1> = Lookahead::new(sgd);

        assert_abs_diff_eq!(optimizer.alpha(), 0.5);
        assert_eq!(optimizer.k(), 5);
        assert_abs_diff_eq!(optimizer.get_learning_rate(), 0.01);
    }

    #[test]
    fn test_lookahead_with_config() {
        let sgd = SGD::new(0.01);
        let optimizer: Lookahead<f64, SGD<f64>, scirs2_core::ndarray::Ix1> =
            Lookahead::with_config(sgd, 0.8, 10);

        assert_abs_diff_eq!(optimizer.alpha(), 0.8);
        assert_eq!(optimizer.k(), 10);
    }

    #[test]
    fn test_lookahead_step() {
        let mut sgd = SGD::new(0.1);
        sgd.set_momentum(0.0);
        let mut optimizer: Lookahead<f64, SGD<f64>, scirs2_core::ndarray::Ix1> =
            Lookahead::with_config(sgd, 0.5, 2);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        // First step: a plain SGD update, e.g. 1.0 - 0.1 * 0.1 = 0.99.
        let updated_params = optimizer.step(&params, &gradients).unwrap();

        assert_abs_diff_eq!(updated_params[0], 0.99, epsilon = 1e-6);
        assert_abs_diff_eq!(updated_params[1], 1.98, epsilon = 1e-6);
        assert_abs_diff_eq!(updated_params[2], 2.97, epsilon = 1e-6);

        // Second step reaches k = 2, triggering a synchronization: the fast
        // weights become 0.98, so slow = 1.0 + 0.5 * (0.98 - 1.0) = 0.99,
        // and the fast weights restart from there.
        let updated_params2 = optimizer.step(&updated_params, &gradients).unwrap();

        assert_abs_diff_eq!(updated_params2[0], 0.99, epsilon = 1e-6);
        assert_abs_diff_eq!(updated_params2[1], 1.98, epsilon = 1e-6);
        assert_abs_diff_eq!(updated_params2[2], 2.97, epsilon = 1e-6);

        // Third step: one more fast SGD update, 0.99 - 0.01 = 0.98.
        let updated_params3 = optimizer.step(&updated_params2, &gradients).unwrap();

        assert_abs_diff_eq!(updated_params3[0], 0.98, epsilon = 1e-6);
        assert_abs_diff_eq!(updated_params3[1], 1.96, epsilon = 1e-6);
        assert_abs_diff_eq!(updated_params3[2], 2.94, epsilon = 1e-6);
    }

    #[test]
    fn test_slow_weights_for_eval() {
        let mut sgd = SGD::new(0.1);
        sgd.set_momentum(0.0);
        let mut optimizer: Lookahead<f64, SGD<f64>, scirs2_core::ndarray::Ix1> =
            Lookahead::with_config(sgd, 0.5, 2);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        let updated_params = optimizer.step(&params, &gradients).unwrap();

        optimizer.use_slow_weights_for_eval();

        // This step triggers a synchronization (k = 2), and with slow weights
        // selected, the synchronized slow weights are returned.
        let eval_params = optimizer.step(&updated_params, &gradients).unwrap();

        assert_abs_diff_eq!(eval_params[0], 0.99, epsilon = 1e-6);
        assert_abs_diff_eq!(eval_params[1], 1.98, epsilon = 1e-6);
        assert_abs_diff_eq!(eval_params[2], 2.97, epsilon = 1e-6);

        optimizer.use_fast_weights_for_train();

        let train_params = optimizer.step(&eval_params, &gradients).unwrap();
        assert!(train_params[0] < 1.0);
    }

    #[test]
    fn test_reset() {
        let sgd = SGD::new(0.1);
        let mut optimizer: Lookahead<f64, SGD<f64>, scirs2_core::ndarray::Ix1> =
            Lookahead::new(sgd);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        let _ = optimizer.step(&params, &gradients).unwrap();

        optimizer.reset();

        // After the reset, the slow and fast weights are re-initialized
        // from `params`, so this behaves like a fresh first step.
        let updated_params = optimizer.step(&params, &gradients).unwrap();
        assert_abs_diff_eq!(updated_params[0], 0.99, epsilon = 1e-6);
    }
}