Skip to main content

optirs_core/optimizers/
ranger.rs

1// OptiRS - Ranger Optimizer
2// RAdam + Lookahead combination for improved convergence and stability
3// Reference: "Ranger - a synergistic optimizer" by Less Wright (2019)
4//
5// Ranger combines:
6// 1. RAdam (Rectified Adam) - Adaptive learning rate with variance rectification
7// 2. Lookahead - Slow and fast weight updates for stability
8//
9// This combination provides:
10// - Fast convergence from RAdam
11// - Stability and reduced variance from Lookahead
12// - Better generalization than either optimizer alone
13
14use crate::error::{OptimError, Result};
15use scirs2_core::ndarray::ScalarOperand;
16use scirs2_core::ndarray_ext::{Array1, ArrayView1};
17use scirs2_core::numeric::{Float, Zero};
18use serde::{Deserialize, Serialize};
19
20/// Ranger optimizer configuration
21///
22/// Ranger combines RAdam (Rectified Adam) with Lookahead mechanism.
23/// This standalone implementation integrates both algorithms efficiently.
24///
25/// # Key Features
26/// - Fast convergence from RAdam's variance rectification
27/// - Stability from Lookahead's slow weight trajectory
28/// - Reduced sensitivity to hyperparameter choices
29/// - Better generalization than Adam or RAdam alone
30///
31/// # Type Parameters
32/// - `T`: Floating-point type (f32 or f64)
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct Ranger<T: Float + ScalarOperand> {
35    // RAdam parameters
36    learning_rate: T,
37    beta1: T,
38    beta2: T,
39    epsilon: T,
40    weight_decay: T,
41
42    // Lookahead parameters
43    lookahead_k: usize,
44    lookahead_alpha: T,
45
46    // RAdam state
47    momentum: Option<Array1<T>>,
48    velocity: Option<Array1<T>>,
49
50    // Lookahead state
51    slow_weights: Option<Array1<T>>,
52
53    // Step counters
54    step_count: usize,
55    slow_update_count: usize,
56}
57
58impl<T: Float + ScalarOperand> Default for Ranger<T> {
59    fn default() -> Self {
60        Self::new(
61            T::from(0.001).expect("unwrap failed"), // learning_rate
62            T::from(0.9).expect("unwrap failed"),   // beta1
63            T::from(0.999).expect("unwrap failed"), // beta2
64            T::from(1e-8).expect("unwrap failed"),  // epsilon
65            T::zero(),                              // weight_decay
66            5,                                      // lookahead_k
67            T::from(0.5).expect("unwrap failed"),   // lookahead_alpha
68        )
69        .expect("unwrap failed")
70    }
71}
72
73impl<T: Float + ScalarOperand> Ranger<T> {
74    /// Create a new Ranger optimizer
75    ///
76    /// # Arguments
77    /// - `learning_rate`: Learning rate for RAdam (typically 0.001)
78    /// - `beta1`: First moment decay rate (typically 0.9)
79    /// - `beta2`: Second moment decay rate (typically 0.999)
80    /// - `epsilon`: Small constant for numerical stability (typically 1e-8)
81    /// - `weight_decay`: L2 regularization coefficient (typically 0.0)
82    /// - `lookahead_k`: Number of fast updates per slow update (typically 5-6)
83    /// - `lookahead_alpha`: Interpolation factor for slow weights (typically 0.5)
84    ///
85    /// # Example
86    /// ```
87    /// use optirs_core::optimizers::Ranger;
88    ///
89    /// let optimizer = Ranger::<f32>::new(
90    ///     0.001,  // learning_rate
91    ///     0.9,    // beta1
92    ///     0.999,  // beta2
93    ///     1e-8,   // epsilon
94    ///     0.0,    // weight_decay
95    ///     5,      // lookahead_k
96    ///     0.5     // lookahead_alpha
97    /// ).expect("unwrap failed");
98    /// ```
99    #[allow(clippy::too_many_arguments)]
100    pub fn new(
101        learning_rate: T,
102        beta1: T,
103        beta2: T,
104        epsilon: T,
105        weight_decay: T,
106        lookahead_k: usize,
107        lookahead_alpha: T,
108    ) -> Result<Self> {
109        // Validate parameters
110        if learning_rate.to_f64().expect("unwrap failed") <= 0.0 {
111            return Err(OptimError::InvalidParameter(format!(
112                "learning_rate must be positive, got {}",
113                learning_rate.to_f64().expect("unwrap failed")
114            )));
115        }
116        if beta1.to_f64().expect("unwrap failed") <= 0.0
117            || beta1.to_f64().expect("unwrap failed") >= 1.0
118        {
119            return Err(OptimError::InvalidParameter(format!(
120                "beta1 must be in (0, 1), got {}",
121                beta1.to_f64().expect("unwrap failed")
122            )));
123        }
124        if beta2.to_f64().expect("unwrap failed") <= 0.0
125            || beta2.to_f64().expect("unwrap failed") >= 1.0
126        {
127            return Err(OptimError::InvalidParameter(format!(
128                "beta2 must be in (0, 1), got {}",
129                beta2.to_f64().expect("unwrap failed")
130            )));
131        }
132        if epsilon.to_f64().expect("unwrap failed") <= 0.0 {
133            return Err(OptimError::InvalidParameter(format!(
134                "epsilon must be positive, got {}",
135                epsilon.to_f64().expect("unwrap failed")
136            )));
137        }
138        if weight_decay.to_f64().expect("unwrap failed") < 0.0 {
139            return Err(OptimError::InvalidParameter(format!(
140                "weight_decay must be non-negative, got {}",
141                weight_decay.to_f64().expect("unwrap failed")
142            )));
143        }
144        if lookahead_k == 0 {
145            return Err(OptimError::InvalidParameter(
146                "lookahead_k must be positive".to_string(),
147            ));
148        }
149        if lookahead_alpha.to_f64().expect("unwrap failed") <= 0.0
150            || lookahead_alpha.to_f64().expect("unwrap failed") > 1.0
151        {
152            return Err(OptimError::InvalidParameter(format!(
153                "lookahead_alpha must be in (0, 1], got {}",
154                lookahead_alpha.to_f64().expect("unwrap failed")
155            )));
156        }
157
158        Ok(Self {
159            learning_rate,
160            beta1,
161            beta2,
162            epsilon,
163            weight_decay,
164            lookahead_k,
165            lookahead_alpha,
166            momentum: None,
167            velocity: None,
168            slow_weights: None,
169            step_count: 0,
170            slow_update_count: 0,
171        })
172    }
173
174    /// Perform a single optimization step
175    ///
176    /// Combines RAdam (fast weights) with Lookahead (slow weights)
177    ///
178    /// # Example
179    /// ```
180    /// use optirs_core::optimizers::Ranger;
181    /// use scirs2_core::ndarray_ext::array;
182    ///
183    /// let mut optimizer = Ranger::<f32>::default();
184    /// let params = array![1.0, 2.0, 3.0];
185    /// let grads = array![0.1, 0.2, 0.3];
186    ///
187    /// let updated_params = optimizer.step(params.view(), grads.view()).expect("unwrap failed");
188    /// ```
189    pub fn step(&mut self, params: ArrayView1<T>, grads: ArrayView1<T>) -> Result<Array1<T>> {
190        let n = params.len();
191
192        if grads.len() != n {
193            return Err(OptimError::DimensionMismatch(format!(
194                "Expected gradient size {}, got {}",
195                n,
196                grads.len()
197            )));
198        }
199
200        // Initialize state on first step
201        if self.momentum.is_none() {
202            self.momentum = Some(Array1::zeros(n));
203            self.velocity = Some(Array1::zeros(n));
204            self.slow_weights = Some(params.to_owned());
205        }
206
207        self.step_count += 1;
208        let t = T::from(self.step_count).expect("unwrap failed");
209
210        let momentum = self.momentum.as_mut().expect("unwrap failed");
211        let velocity = self.velocity.as_mut().expect("unwrap failed");
212
213        let one = T::one();
214        let two = T::from(2).expect("unwrap failed");
215
216        // Apply weight decay if configured
217        let effective_grads = if self.weight_decay > T::zero() {
218            grads.to_owned() + &(params.to_owned() * self.weight_decay)
219        } else {
220            grads.to_owned()
221        };
222
223        // RAdam: Update biased first moment
224        for i in 0..n {
225            momentum[i] = self.beta1 * momentum[i] + (one - self.beta1) * effective_grads[i];
226        }
227
228        // RAdam: Update biased second moment
229        for i in 0..n {
230            let grad_sq = effective_grads[i] * effective_grads[i];
231            velocity[i] = self.beta2 * velocity[i] + (one - self.beta2) * grad_sq;
232        }
233
234        // RAdam: Compute bias correction
235        let bias_correction1 = one - self.beta1.powf(t);
236        let bias_correction2 = one - self.beta2.powf(t);
237
238        // RAdam: Compute SMA (Simple Moving Average) length
239        let rho_inf = two / (one - self.beta2) - one;
240        let rho_t = rho_inf - two * t * self.beta2.powf(t) / bias_correction2;
241
242        // RAdam: Apply variance rectification
243        let mut updated_params = params.to_owned();
244
245        if rho_t.to_f64().expect("unwrap failed") > 4.0 {
246            // Use adaptive learning rate with variance rectification
247            let rect_term =
248                ((rho_t - T::from(4).expect("unwrap failed")) * (rho_t - two) * rho_inf
249                    / ((rho_inf - T::from(4).expect("unwrap failed")) * (rho_inf - two) * rho_t))
250                    .sqrt();
251
252            for i in 0..n {
253                let m_hat = momentum[i] / bias_correction1;
254                let v_hat = velocity[i] / bias_correction2;
255                let step_size = self.learning_rate * rect_term / (v_hat.sqrt() + self.epsilon);
256                updated_params[i] = updated_params[i] - step_size * m_hat;
257            }
258        } else {
259            // Use simple momentum update during warmup
260            for i in 0..n {
261                let m_hat = momentum[i] / bias_correction1;
262                updated_params[i] = updated_params[i] - self.learning_rate * m_hat;
263            }
264        }
265
266        // Lookahead: Update slow weights every k steps
267        if self.step_count.is_multiple_of(self.lookahead_k) {
268            let slow = self.slow_weights.as_mut().expect("unwrap failed");
269            for i in 0..n {
270                slow[i] = slow[i] + self.lookahead_alpha * (updated_params[i] - slow[i]);
271            }
272            self.slow_update_count += 1;
273
274            // Synchronize fast weights with slow weights
275            // This is the key to Lookahead: we return the slow weights after update
276            Ok(slow.clone())
277        } else {
278            // Between slow updates, return fast weights
279            Ok(updated_params)
280        }
281    }
282
283    /// Get the number of optimization steps performed
284    pub fn step_count(&self) -> usize {
285        self.step_count
286    }
287
288    /// Get the number of slow weight updates performed
289    pub fn slow_update_count(&self) -> usize {
290        self.slow_update_count
291    }
292
293    /// Reset the optimizer state
294    pub fn reset(&mut self) {
295        self.momentum = None;
296        self.velocity = None;
297        self.slow_weights = None;
298        self.step_count = 0;
299        self.slow_update_count = 0;
300    }
301
302    /// Get the slow weights (Lookahead trajectory)
303    pub fn slow_weights(&self) -> Option<&Array1<T>> {
304        self.slow_weights.as_ref()
305    }
306
307    /// Check if variance rectification is active
308    pub fn is_rectified(&self) -> bool {
309        if self.step_count == 0 {
310            return false;
311        }
312        let t = T::from(self.step_count).expect("unwrap failed");
313        let one = T::one();
314        let two = T::from(2).expect("unwrap failed");
315        let bias_correction2 = one - self.beta2.powf(t);
316        let rho_inf = two / (one - self.beta2) - one;
317        let rho_t = rho_inf - two * t * self.beta2.powf(t) / bias_correction2;
318        rho_t.to_f64().expect("unwrap failed") > 4.0
319    }
320}
321
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    #[test]
    fn test_ranger_creation() {
        let optimizer = Ranger::<f32>::default();
        assert_eq!(optimizer.step_count(), 0);
        assert_eq!(optimizer.slow_update_count(), 0);
    }

    #[test]
    fn test_ranger_custom_creation() {
        let optimizer =
            Ranger::<f32>::new(0.002, 0.95, 0.9999, 1e-7, 0.01, 6, 0.6).expect("unwrap failed");
        assert_eq!(optimizer.step_count(), 0);
    }

    #[test]
    fn test_ranger_rejects_invalid_hyperparameters() {
        // learning_rate must be positive
        assert!(Ranger::<f64>::new(0.0, 0.9, 0.999, 1e-8, 0.0, 5, 0.5).is_err());
        assert!(Ranger::<f64>::new(-0.1, 0.9, 0.999, 1e-8, 0.0, 5, 0.5).is_err());
        // betas must lie in (0, 1)
        assert!(Ranger::<f64>::new(0.001, 1.0, 0.999, 1e-8, 0.0, 5, 0.5).is_err());
        assert!(Ranger::<f64>::new(0.001, 0.9, 0.0, 1e-8, 0.0, 5, 0.5).is_err());
        // epsilon must be positive, weight_decay non-negative
        assert!(Ranger::<f64>::new(0.001, 0.9, 0.999, 0.0, 0.0, 5, 0.5).is_err());
        assert!(Ranger::<f64>::new(0.001, 0.9, 0.999, 1e-8, -0.1, 5, 0.5).is_err());
        // lookahead_k must be non-zero, lookahead_alpha in (0, 1]
        assert!(Ranger::<f64>::new(0.001, 0.9, 0.999, 1e-8, 0.0, 0, 0.5).is_err());
        assert!(Ranger::<f64>::new(0.001, 0.9, 0.999, 1e-8, 0.0, 5, 0.0).is_err());
        assert!(Ranger::<f64>::new(0.001, 0.9, 0.999, 1e-8, 0.0, 5, 1.5).is_err());
        assert!(Ranger::<f64>::new(0.001, 0.9, 0.999, 1e-8, 0.0, 5, 1.0).is_ok());
    }

    #[test]
    fn test_ranger_single_step() {
        let mut optimizer = Ranger::<f32>::default();
        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];

        let updated_params = optimizer
            .step(params.view(), grads.view())
            .expect("unwrap failed");
        assert_eq!(updated_params.len(), 3);
        assert_eq!(optimizer.step_count(), 1);

        for i in 0..3 {
            assert!(updated_params[i] < params[i]);
        }
    }

    #[test]
    fn test_ranger_gradient_size_mismatch() {
        let mut optimizer = Ranger::<f64>::default();
        let params = array![1.0, 2.0];
        let grads = array![0.1];

        assert!(optimizer.step(params.view(), grads.view()).is_err());
        // A rejected step must not advance the optimizer.
        assert_eq!(optimizer.step_count(), 0);
        assert!(optimizer.slow_weights().is_none());
    }

    #[test]
    fn test_ranger_slow_updates() {
        let mut optimizer =
            Ranger::<f32>::new(0.001, 0.9, 0.999, 1e-8, 0.0, 3, 0.5).expect("unwrap failed");
        let mut params = array![1.0, 2.0, 3.0];

        for _ in 0..3 {
            let grads = array![0.1, 0.2, 0.3];
            params = optimizer
                .step(params.view(), grads.view())
                .expect("unwrap failed");
        }
        assert_eq!(optimizer.slow_update_count(), 1);
    }

    #[test]
    fn test_ranger_returns_slow_weights_on_sync_step() {
        let mut optimizer =
            Ranger::<f64>::new(0.01, 0.9, 0.999, 1e-8, 0.0, 2, 0.5).expect("unwrap failed");
        let mut params = array![1.0, -2.0];

        for step in 1..=4 {
            let grads = params.mapv(|x| 2.0 * x);
            params = optimizer
                .step(params.view(), grads.view())
                .expect("unwrap failed");
            if step % 2 == 0 {
                let slow = optimizer
                    .slow_weights()
                    .expect("slow weights exist after a step");
                assert_eq!(&params, slow);
            }
        }
        assert_eq!(optimizer.slow_update_count(), 2);
    }

    #[test]
    fn test_ranger_weight_decay_shrinks_params() {
        let mut optimizer =
            Ranger::<f64>::new(0.001, 0.9, 0.999, 1e-8, 0.1, 5, 0.5).expect("unwrap failed");
        let params = array![1.0, -1.0];
        let grads = array![0.0, 0.0];

        // With zero gradients, only the L2 term moves the parameters toward zero.
        let updated = optimizer
            .step(params.view(), grads.view())
            .expect("unwrap failed");
        assert!(updated[0] < 1.0 && updated[0] > 0.0);
        assert!(updated[1] > -1.0 && updated[1] < 0.0);
    }

    #[test]
    fn test_ranger_convergence() {
        // Use higher learning rate for this simple convex problem
        // Default 0.001 is tuned for neural networks
        let mut optimizer = Ranger::<f64>::new(
            0.1,   // learning_rate: higher for simple problem
            0.9,   // beta1
            0.999, // beta2
            1e-8,  // epsilon
            0.0,   // weight_decay
            5,     // lookahead_k
            0.5,   // lookahead_alpha
        )
        .expect("unwrap failed");
        let mut params = array![5.0];

        // Ranger combines RAdam (adaptive LR) with Lookahead (slow updates)
        for _ in 0..500 {
            let grads = params.mapv(|x| 2.0 * x);
            params = optimizer
                .step(params.view(), grads.view())
                .expect("unwrap failed");
        }

        assert!(
            params[0].abs() < 0.1,
            "Failed to converge, got {}",
            params[0]
        );
    }

    #[test]
    fn test_ranger_reset() {
        let mut optimizer = Ranger::<f32>::default();
        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];

        for _ in 0..10 {
            optimizer
                .step(params.view(), grads.view())
                .expect("unwrap failed");
        }

        optimizer.reset();
        assert_eq!(optimizer.step_count(), 0);
        assert_eq!(optimizer.slow_update_count(), 0);
        assert!(optimizer.slow_weights().is_none());
    }

    #[test]
    fn test_ranger_rectification() {
        let mut optimizer = Ranger::<f32>::default();
        let params = array![1.0];
        let grads = array![0.1];

        // Initially not rectified
        assert!(!optimizer.is_rectified());

        // After several steps, should be rectified
        for _ in 0..10 {
            optimizer
                .step(params.view(), grads.view())
                .expect("unwrap failed");
        }
        assert!(optimizer.is_rectified());
    }
}
440}