//! # Lookahead Optimizer
//!
//! This module implements the Lookahead meta-optimizer that can wrap any base optimizer
//! to improve convergence and reduce variance in the optimization process.
//!
//! ## Algorithm
//!
//! Lookahead maintains two sets of weights:
//! - Fast weights (φ): Updated by the base optimizer
//! - Slow weights (θ): Updated every k steps using interpolation
//!
//! The update rule:
//! ```text
//! φ_t = base_optimizer_update(φ_{t-1}, g_t) // Fast weight update
//!
//! Every k steps:
//! θ_t = θ_{t-k} + α(φ_t - θ_{t-k}) // Slow weight update
//! φ_t = θ_t // Reset fast weights
//! ```
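//!
//! For a concrete feel of the slow update: with α = 0.5, a slow weight θ = 1.0,
//! and a fast weight φ = 2.0 after k fast steps, the interpolation gives
//! θ = 1.0 + 0.5 * (2.0 - 1.0) = 1.5, and the fast weights are reset to 1.5
//! before the next k steps begin.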
//!
//! ## Benefits
//!
//! - Reduces variance in optimization trajectories
//! - Improves convergence in noisy settings
//! - Can be combined with any base optimizer
//! - Often leads to better generalization
//!
//! ## Usage Example
//!
//! ```rust,no_run
//! use trustformers_optim::{Lookahead, AdamW};
//! use trustformers_core::traits::Optimizer;
//!
//! // Create base optimizer
//! let base_optimizer = AdamW::new(1e-3, (0.9, 0.999), 1e-8, 0.01);
//!
//! // Wrap with Lookahead
//! let mut lookahead = Lookahead::new(
//! base_optimizer,
//! 5, // Update slow weights every 5 steps
//! 0.5, // Interpolation factor
//! );
//!
//! // Drive it like any other optimizer with .zero_grad(), .update(), and .step();
//! // after each .step(), call .slow_step() on every parameter so the slow-weight
//! // interpolation can run on each k-th step (see `Lookahead::slow_step`).
//! ```
//!
//! ## Hyperparameter Guidelines
//!
//! ### k (Update Frequency)
//! - k=5 to k=10 is typical for most problems
//! - Smaller k: more frequent slow updates; more stable, but slower progress
//! - Larger k: less frequent slow updates; faster progress, but potentially less stable
//!
//! ### α (Interpolation Factor)
//! - α=0.5 is the recommended default
//! - α=0.8 for more aggressive slow weight updates
//! - α=0.2 for more conservative updates
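//!
//! These presets map directly onto the constructor. A sketch, reusing the
//! `AdamW` construction from the usage example above:
//!
//! ```rust,no_run
//! use trustformers_optim::{Lookahead, AdamW};
//!
//! let base = || AdamW::new(1e-3, (0.9, 0.999), 1e-8, 0.01);
//! let conservative = Lookahead::new(base(), 5, 0.2); // gentle slow-weight updates
//! let standard = Lookahead::new(base(), 5, 0.5);     // recommended default
//! let aggressive = Lookahead::new(base(), 10, 0.8);  // infrequent, large updates
//! ```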
//!
//! ## Implementation Notes
//!
//! - Stores slow weights alongside the base optimizer
//! - Memory overhead of one extra copy of the parameters (the slow weights), i.e. 2x parameter storage in total
//! - Works with any base optimizer that implements the Optimizer trait
//! - Adds no locking or interior mutability; thread-safety in data-parallel training follows that of the base optimizer
use crate::common::OptimizerState;
use std::collections::HashMap;
use trustformers_core::errors::{Result, TrustformersError};
use trustformers_core::tensor::Tensor;
use trustformers_core::traits::Optimizer;
/// Lookahead meta-optimizer that wraps any base optimizer.
///
/// Implements the Lookahead algorithm from "Lookahead Optimizer: k steps forward,
/// 1 step back" by Zhang et al. (2019). Lookahead maintains slow weights that are
/// updated every k steps using interpolation with the fast weights.
#[derive(Debug)]
pub struct Lookahead<T: Optimizer> {
/// Base optimizer for fast weight updates
base_optimizer: T,
/// Number of fast weight update steps before slow weight update
k: usize,
/// Interpolation factor for slow weight updates (α in the paper)
alpha: f32,
/// Optimizer state tracking steps
state: OptimizerState,
/// Slow weights (θ in the paper)
slow_weights: HashMap<String, Vec<f32>>,
/// Counter for fast weight updates since last slow weight update
fast_step_count: usize,
}
impl<T: Optimizer> Lookahead<T> {
/// Creates a new Lookahead meta-optimizer.
///
/// # Arguments
///
/// * `base_optimizer` - The base optimizer to wrap
/// * `k` - Number of fast steps before slow weight update (typical: 5-10)
/// * `alpha` - Interpolation factor for slow weights (typical: 0.5)
///
/// # Example
///
/// ```
/// use trustformers_optim::{Lookahead, AdamW};
///
/// let base = AdamW::new(1e-3, (0.9, 0.999), 1e-8, 0.01);
/// let optimizer = Lookahead::new(base, 5, 0.5);
/// ```
pub fn new(base_optimizer: T, k: usize, alpha: f32) -> Self {
assert!(k > 0, "k must be positive");
assert!(alpha > 0.0 && alpha <= 1.0, "alpha must be in (0, 1]");
Self {
base_optimizer,
k,
alpha,
state: OptimizerState::new(),
slow_weights: HashMap::new(),
fast_step_count: 0,
}
}
/// Get a reference to the base optimizer.
pub fn base_optimizer(&self) -> &T {
&self.base_optimizer
}
/// Get a mutable reference to the base optimizer.
pub fn base_optimizer_mut(&mut self) -> &mut T {
&mut self.base_optimizer
}
/// Initialize slow weights from current parameter values.
fn init_slow_weights(&mut self, parameter: &Tensor) -> Result<()> {
match parameter {
            Tensor::F32(param) => {
                // Parameters are keyed by their buffer address; this assumes the
                // tensor's storage is not reallocated between optimizer calls.
                let param_id = format!("{:p}", param.as_ptr());
                self.slow_weights
                    .entry(param_id)
                    .or_insert_with(|| param.iter().copied().collect());
                Ok(())
            },
_ => Err(TrustformersError::tensor_op_error(
"Unsupported tensor type for Lookahead",
"init_slow_weights",
)),
}
}
/// Update slow weights using interpolation with fast weights.
fn update_slow_weights(&mut self, parameter: &mut Tensor) -> Result<()> {
match parameter {
Tensor::F32(param) => {
let param_id = format!("{:p}", param.as_ptr());
                if let Some(slow_weights) = self.slow_weights.get_mut(&param_id) {
if slow_weights.len() != param.len() {
return Err(TrustformersError::tensor_op_error(
"Lookahead slow weights size mismatch",
"slow weights validation",
));
}
// Update slow weights: θ = θ + α(φ - θ)
for (slow_w, fast_w) in slow_weights.iter_mut().zip(param.iter()) {
*slow_w += self.alpha * (*fast_w - *slow_w);
}
// Copy slow weights back to parameters (reset fast weights)
for (p, slow_w) in param.iter_mut().zip(slow_weights.iter()) {
*p = *slow_w;
}
}
Ok(())
},
_ => Err(TrustformersError::tensor_op_error(
"Unsupported tensor type for Lookahead",
"update_slow_weights",
)),
}
}
}
impl<T: Optimizer> Optimizer for Lookahead<T> {
fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
// Initialize slow weights if this is the first time seeing this parameter
self.init_slow_weights(parameter)?;
// Apply base optimizer update (fast weight update)
self.base_optimizer.update(parameter, grad)?;
Ok(())
}
fn zero_grad(&mut self) {
self.base_optimizer.zero_grad();
}
fn step(&mut self) {
// Always step the base optimizer
self.base_optimizer.step();
self.fast_step_count += 1;
        // After every k-th fast step a slow-weight update is due. The
        // interpolation itself happens in `slow_step`, which the training loop
        // must call per parameter, because this method has no access to the
        // parameter tensors (see the example on `slow_step`).
        if self.fast_step_count >= self.k {
            self.fast_step_count = 0;
        }
self.state.step += 1;
}
fn get_lr(&self) -> f32 {
self.base_optimizer.get_lr()
}
fn set_lr(&mut self, lr: f32) {
self.base_optimizer.set_lr(lr);
}
}
impl<T: Optimizer> Lookahead<T> {
    /// Performs the slow-weight update for a single parameter.
    ///
    /// Call this for every parameter after each call to `step`; the
    /// interpolation only fires on every k-th fast step.
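    ///
    /// # Example
    ///
    /// A sketch of the intended call pattern in a training loop. `params` and
    /// `grads` are placeholders for your model's parameter and gradient
    /// tensors, not items from this crate:
    ///
    /// ```rust,ignore
    /// for (param, grad) in params.iter_mut().zip(grads.iter()) {
    ///     lookahead.update(param, grad)?; // fast weight update
    /// }
    /// lookahead.step(); // advances the fast-step counter
    /// for param in params.iter_mut() {
    ///     lookahead.slow_step(param)?; // interpolates only on every k-th step
    /// }
    /// ```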
pub fn slow_step(&mut self, parameter: &mut Tensor) -> Result<()> {
        if self.fast_step_count == 0 {
            // `fast_step_count` wraps to 0 in `step` after every k-th fast step,
            // so a slow-weight update is due. It is also 0 before the first
            // step, but slow and fast weights are still identical then, so the
            // interpolation is a harmless no-op.
            self.update_slow_weights(parameter)?;
        }
Ok(())
}
}
/// Convenience wrapper for Lookahead + Adam combination.
pub type LookaheadAdam = Lookahead<crate::adam::Adam>;
/// Convenience wrapper for Lookahead + AdamW combination.
pub type LookaheadAdamW = Lookahead<crate::adam::AdamW>;
/// Convenience wrapper for Lookahead + RAdam combination.
pub type LookaheadRAdam = Lookahead<crate::adam::RAdam>;
/// Convenience wrapper for Lookahead + NAdam combination.
pub type LookaheadNAdam = Lookahead<crate::adam::NAdam>;
/// Convenience wrapper for Lookahead + SGD combination.
pub type LookaheadSGD = Lookahead<crate::sgd::SGD>;
#[cfg(test)]
#[path = "lookahead_tests.rs"]
mod lookahead_tests;