optirs_core/optimizers/
ranger.rs

1// OptiRS - Ranger Optimizer
2// RAdam + Lookahead combination for improved convergence and stability
3// Reference: "Ranger - a synergistic optimizer" by Less Wright (2019)
4//
5// Ranger combines:
6// 1. RAdam (Rectified Adam) - Adaptive learning rate with variance rectification
7// 2. Lookahead - Slow and fast weight updates for stability
8//
9// This combination provides:
10// - Fast convergence from RAdam
11// - Stability and reduced variance from Lookahead
12// - Better generalization than either optimizer alone
13
14use crate::error::{OptimError, Result};
15use scirs2_core::ndarray::ScalarOperand;
16use scirs2_core::ndarray_ext::{Array1, ArrayView1};
17use scirs2_core::numeric::{Float, Zero};
18use serde::{Deserialize, Serialize};
19
20/// Ranger optimizer configuration
21///
22/// Ranger combines RAdam (Rectified Adam) with Lookahead mechanism.
23/// This standalone implementation integrates both algorithms efficiently.
24///
25/// # Key Features
26/// - Fast convergence from RAdam's variance rectification
27/// - Stability from Lookahead's slow weight trajectory
28/// - Reduced sensitivity to hyperparameter choices
29/// - Better generalization than Adam or RAdam alone
30///
31/// # Type Parameters
32/// - `T`: Floating-point type (f32 or f64)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Ranger<T: Float + ScalarOperand> {
    // RAdam parameters
    /// Base learning rate used for the RAdam fast-weight update.
    learning_rate: T,
    /// Exponential decay rate for the first-moment (mean) estimate.
    beta1: T,
    /// Exponential decay rate for the second-moment (uncentered variance) estimate.
    beta2: T,
    /// Small constant added to the denominator for numerical stability.
    epsilon: T,
    /// L2 regularization coefficient; folded into the gradient when > 0.
    weight_decay: T,

    // Lookahead parameters
    /// Number of fast (RAdam) steps between each slow-weight update.
    lookahead_k: usize,
    /// Interpolation factor in (0, 1] pulling slow weights toward fast weights.
    lookahead_alpha: T,

    // RAdam state
    /// First-moment estimate; `None` until the first call to `step`.
    momentum: Option<Array1<T>>,
    /// Second-moment estimate; `None` until the first call to `step`.
    velocity: Option<Array1<T>>,

    // Lookahead state
    /// Slow-weight trajectory, initialized from the first `step`'s parameters.
    slow_weights: Option<Array1<T>>,

    // Step counters
    /// Total number of fast (RAdam) optimization steps performed.
    step_count: usize,
    /// Number of slow (Lookahead) weight updates performed.
    slow_update_count: usize,
}
57
58impl<T: Float + ScalarOperand> Default for Ranger<T> {
59    fn default() -> Self {
60        Self::new(
61            T::from(0.001).unwrap(), // learning_rate
62            T::from(0.9).unwrap(),   // beta1
63            T::from(0.999).unwrap(), // beta2
64            T::from(1e-8).unwrap(),  // epsilon
65            T::zero(),               // weight_decay
66            5,                       // lookahead_k
67            T::from(0.5).unwrap(),   // lookahead_alpha
68        )
69        .unwrap()
70    }
71}
72
73impl<T: Float + ScalarOperand> Ranger<T> {
74    /// Create a new Ranger optimizer
75    ///
76    /// # Arguments
77    /// - `learning_rate`: Learning rate for RAdam (typically 0.001)
78    /// - `beta1`: First moment decay rate (typically 0.9)
79    /// - `beta2`: Second moment decay rate (typically 0.999)
80    /// - `epsilon`: Small constant for numerical stability (typically 1e-8)
81    /// - `weight_decay`: L2 regularization coefficient (typically 0.0)
82    /// - `lookahead_k`: Number of fast updates per slow update (typically 5-6)
83    /// - `lookahead_alpha`: Interpolation factor for slow weights (typically 0.5)
84    ///
85    /// # Example
86    /// ```
87    /// use optirs_core::optimizers::Ranger;
88    ///
89    /// let optimizer = Ranger::<f32>::new(
90    ///     0.001,  // learning_rate
91    ///     0.9,    // beta1
92    ///     0.999,  // beta2
93    ///     1e-8,   // epsilon
94    ///     0.0,    // weight_decay
95    ///     5,      // lookahead_k
96    ///     0.5     // lookahead_alpha
97    /// ).unwrap();
98    /// ```
99    #[allow(clippy::too_many_arguments)]
100    pub fn new(
101        learning_rate: T,
102        beta1: T,
103        beta2: T,
104        epsilon: T,
105        weight_decay: T,
106        lookahead_k: usize,
107        lookahead_alpha: T,
108    ) -> Result<Self> {
109        // Validate parameters
110        if learning_rate.to_f64().unwrap() <= 0.0 {
111            return Err(OptimError::InvalidParameter(format!(
112                "learning_rate must be positive, got {}",
113                learning_rate.to_f64().unwrap()
114            )));
115        }
116        if beta1.to_f64().unwrap() <= 0.0 || beta1.to_f64().unwrap() >= 1.0 {
117            return Err(OptimError::InvalidParameter(format!(
118                "beta1 must be in (0, 1), got {}",
119                beta1.to_f64().unwrap()
120            )));
121        }
122        if beta2.to_f64().unwrap() <= 0.0 || beta2.to_f64().unwrap() >= 1.0 {
123            return Err(OptimError::InvalidParameter(format!(
124                "beta2 must be in (0, 1), got {}",
125                beta2.to_f64().unwrap()
126            )));
127        }
128        if epsilon.to_f64().unwrap() <= 0.0 {
129            return Err(OptimError::InvalidParameter(format!(
130                "epsilon must be positive, got {}",
131                epsilon.to_f64().unwrap()
132            )));
133        }
134        if weight_decay.to_f64().unwrap() < 0.0 {
135            return Err(OptimError::InvalidParameter(format!(
136                "weight_decay must be non-negative, got {}",
137                weight_decay.to_f64().unwrap()
138            )));
139        }
140        if lookahead_k == 0 {
141            return Err(OptimError::InvalidParameter(
142                "lookahead_k must be positive".to_string(),
143            ));
144        }
145        if lookahead_alpha.to_f64().unwrap() <= 0.0 || lookahead_alpha.to_f64().unwrap() > 1.0 {
146            return Err(OptimError::InvalidParameter(format!(
147                "lookahead_alpha must be in (0, 1], got {}",
148                lookahead_alpha.to_f64().unwrap()
149            )));
150        }
151
152        Ok(Self {
153            learning_rate,
154            beta1,
155            beta2,
156            epsilon,
157            weight_decay,
158            lookahead_k,
159            lookahead_alpha,
160            momentum: None,
161            velocity: None,
162            slow_weights: None,
163            step_count: 0,
164            slow_update_count: 0,
165        })
166    }
167
168    /// Perform a single optimization step
169    ///
170    /// Combines RAdam (fast weights) with Lookahead (slow weights)
171    ///
172    /// # Example
173    /// ```
174    /// use optirs_core::optimizers::Ranger;
175    /// use scirs2_core::ndarray_ext::array;
176    ///
177    /// let mut optimizer = Ranger::<f32>::default();
178    /// let params = array![1.0, 2.0, 3.0];
179    /// let grads = array![0.1, 0.2, 0.3];
180    ///
181    /// let updated_params = optimizer.step(params.view(), grads.view()).unwrap();
182    /// ```
183    pub fn step(&mut self, params: ArrayView1<T>, grads: ArrayView1<T>) -> Result<Array1<T>> {
184        let n = params.len();
185
186        if grads.len() != n {
187            return Err(OptimError::DimensionMismatch(format!(
188                "Expected gradient size {}, got {}",
189                n,
190                grads.len()
191            )));
192        }
193
194        // Initialize state on first step
195        if self.momentum.is_none() {
196            self.momentum = Some(Array1::zeros(n));
197            self.velocity = Some(Array1::zeros(n));
198            self.slow_weights = Some(params.to_owned());
199        }
200
201        self.step_count += 1;
202        let t = T::from(self.step_count).unwrap();
203
204        let momentum = self.momentum.as_mut().unwrap();
205        let velocity = self.velocity.as_mut().unwrap();
206
207        let one = T::one();
208        let two = T::from(2).unwrap();
209
210        // Apply weight decay if configured
211        let effective_grads = if self.weight_decay > T::zero() {
212            grads.to_owned() + &(params.to_owned() * self.weight_decay)
213        } else {
214            grads.to_owned()
215        };
216
217        // RAdam: Update biased first moment
218        for i in 0..n {
219            momentum[i] = self.beta1 * momentum[i] + (one - self.beta1) * effective_grads[i];
220        }
221
222        // RAdam: Update biased second moment
223        for i in 0..n {
224            let grad_sq = effective_grads[i] * effective_grads[i];
225            velocity[i] = self.beta2 * velocity[i] + (one - self.beta2) * grad_sq;
226        }
227
228        // RAdam: Compute bias correction
229        let bias_correction1 = one - self.beta1.powf(t);
230        let bias_correction2 = one - self.beta2.powf(t);
231
232        // RAdam: Compute SMA (Simple Moving Average) length
233        let rho_inf = two / (one - self.beta2) - one;
234        let rho_t = rho_inf - two * t * self.beta2.powf(t) / bias_correction2;
235
236        // RAdam: Apply variance rectification
237        let mut updated_params = params.to_owned();
238
239        if rho_t.to_f64().unwrap() > 4.0 {
240            // Use adaptive learning rate with variance rectification
241            let rect_term = ((rho_t - T::from(4).unwrap()) * (rho_t - two) * rho_inf
242                / ((rho_inf - T::from(4).unwrap()) * (rho_inf - two) * rho_t))
243                .sqrt();
244
245            for i in 0..n {
246                let m_hat = momentum[i] / bias_correction1;
247                let v_hat = velocity[i] / bias_correction2;
248                let step_size = self.learning_rate * rect_term / (v_hat.sqrt() + self.epsilon);
249                updated_params[i] = updated_params[i] - step_size * m_hat;
250            }
251        } else {
252            // Use simple momentum update during warmup
253            for i in 0..n {
254                let m_hat = momentum[i] / bias_correction1;
255                updated_params[i] = updated_params[i] - self.learning_rate * m_hat;
256            }
257        }
258
259        // Lookahead: Update slow weights every k steps
260        if self.step_count.is_multiple_of(self.lookahead_k) {
261            let slow = self.slow_weights.as_mut().unwrap();
262            for i in 0..n {
263                slow[i] = slow[i] + self.lookahead_alpha * (updated_params[i] - slow[i]);
264            }
265            self.slow_update_count += 1;
266
267            // Synchronize fast weights with slow weights
268            // This is the key to Lookahead: we return the slow weights after update
269            Ok(slow.clone())
270        } else {
271            // Between slow updates, return fast weights
272            Ok(updated_params)
273        }
274    }
275
276    /// Get the number of optimization steps performed
277    pub fn step_count(&self) -> usize {
278        self.step_count
279    }
280
281    /// Get the number of slow weight updates performed
282    pub fn slow_update_count(&self) -> usize {
283        self.slow_update_count
284    }
285
286    /// Reset the optimizer state
287    pub fn reset(&mut self) {
288        self.momentum = None;
289        self.velocity = None;
290        self.slow_weights = None;
291        self.step_count = 0;
292        self.slow_update_count = 0;
293    }
294
295    /// Get the slow weights (Lookahead trajectory)
296    pub fn slow_weights(&self) -> Option<&Array1<T>> {
297        self.slow_weights.as_ref()
298    }
299
300    /// Check if variance rectification is active
301    pub fn is_rectified(&self) -> bool {
302        if self.step_count == 0 {
303            return false;
304        }
305        let t = T::from(self.step_count).unwrap();
306        let one = T::one();
307        let two = T::from(2).unwrap();
308        let bias_correction2 = one - self.beta2.powf(t);
309        let rho_inf = two / (one - self.beta2) - one;
310        let rho_t = rho_inf - two * t * self.beta2.powf(t) / bias_correction2;
311        rho_t.to_f64().unwrap() > 4.0
312    }
313}
314
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE: the previous `use approx::assert_relative_eq;` was removed — it
    // was never referenced and only produced an unused-import warning.
    use scirs2_core::ndarray_ext::array;

    #[test]
    fn test_ranger_creation() {
        let optimizer = Ranger::<f32>::default();
        assert_eq!(optimizer.step_count(), 0);
        assert_eq!(optimizer.slow_update_count(), 0);
    }

    #[test]
    fn test_ranger_custom_creation() {
        let optimizer = Ranger::<f32>::new(0.002, 0.95, 0.9999, 1e-7, 0.01, 6, 0.6).unwrap();
        assert_eq!(optimizer.step_count(), 0);
    }

    #[test]
    fn test_ranger_single_step() {
        let mut optimizer = Ranger::<f32>::default();
        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];

        let updated_params = optimizer.step(params.view(), grads.view()).unwrap();
        assert_eq!(updated_params.len(), 3);
        assert_eq!(optimizer.step_count(), 1);

        // Positive gradients must move every parameter downward.
        for i in 0..3 {
            assert!(updated_params[i] < params[i]);
        }
    }

    #[test]
    fn test_ranger_slow_updates() {
        // With lookahead_k = 3, exactly one slow update happens in 3 steps.
        let mut optimizer = Ranger::<f32>::new(0.001, 0.9, 0.999, 1e-8, 0.0, 3, 0.5).unwrap();
        let mut params = array![1.0, 2.0, 3.0];

        for _ in 0..3 {
            let grads = array![0.1, 0.2, 0.3];
            params = optimizer.step(params.view(), grads.view()).unwrap();
        }
        assert_eq!(optimizer.slow_update_count(), 1);
    }

    #[test]
    fn test_ranger_convergence() {
        // Use higher learning rate for this simple convex problem
        // Default 0.001 is tuned for neural networks
        let mut optimizer = Ranger::<f64>::new(
            0.1,   // learning_rate: higher for simple problem
            0.9,   // beta1
            0.999, // beta2
            1e-8,  // epsilon
            0.0,   // weight_decay
            5,     // lookahead_k
            0.5,   // lookahead_alpha
        )
        .unwrap();
        let mut params = array![5.0];

        // Minimize f(x) = x^2 (gradient 2x); Ranger should drive x toward 0.
        for _ in 0..500 {
            let grads = params.mapv(|x| 2.0 * x);
            params = optimizer.step(params.view(), grads.view()).unwrap();
        }

        assert!(
            params[0].abs() < 0.1,
            "Failed to converge, got {}",
            params[0]
        );
    }

    #[test]
    fn test_ranger_reset() {
        let mut optimizer = Ranger::<f32>::default();
        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];

        for _ in 0..10 {
            optimizer.step(params.view(), grads.view()).unwrap();
        }

        optimizer.reset();
        assert_eq!(optimizer.step_count(), 0);
        assert_eq!(optimizer.slow_update_count(), 0);
        assert!(optimizer.slow_weights().is_none());
    }

    #[test]
    fn test_ranger_rectification() {
        let mut optimizer = Ranger::<f32>::default();
        let params = array![1.0];
        let grads = array![0.1];

        // Initially not rectified
        assert!(!optimizer.is_rectified());

        // After several steps, rho_t exceeds 4 and rectification activates.
        for _ in 0..10 {
            optimizer.step(params.view(), grads.view()).unwrap();
        }
        assert!(optimizer.is_rectified());
    }
}
421}