use crate::error::{OptimError, Result};
use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::ndarray_ext::{Array1, ArrayView1};
use scirs2_core::numeric::{Float, Zero};
use serde::{Deserialize, Serialize};

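/// Ranger optimizer: Rectified Adam (RAdam) combined with Lookahead.
///
/// The inner RAdam loop maintains exponential moving averages of the gradient
/// and the squared gradient, rectifying the adaptive step once enough samples
/// have accumulated. Every `lookahead_k` fast steps, the slow weights are
/// interpolated toward the fast weights by a factor of `lookahead_alpha`, and
/// the fast weights are reset to the slow weights.
///
/// A minimal usage sketch (the import path for `Ranger` is assumed; adjust it
/// to wherever the crate exports this type):
///
/// ```ignore
/// use scirs2_core::ndarray_ext::array;
///
/// let mut opt = Ranger::<f64>::default();
/// let params = array![1.0, 2.0, 3.0];
/// let grads = array![0.1, 0.2, 0.3];
/// let updated = opt.step(params.view(), grads.view()).unwrap();
/// assert_eq!(updated.len(), 3);
/// ```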
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Ranger<T: Float + ScalarOperand> {
    /// Step size for the fast (inner RAdam) updates.
    learning_rate: T,
    /// Exponential decay rate for the first-moment estimate.
    beta1: T,
    /// Exponential decay rate for the second-moment estimate.
    beta2: T,
    /// Small constant added to the denominator for numerical stability.
    epsilon: T,
    /// L2 penalty coefficient added to the gradients.
    weight_decay: T,

    /// Number of fast steps between slow-weight synchronizations.
    lookahead_k: usize,
    /// Interpolation factor used when updating the slow weights.
    lookahead_alpha: T,

    /// First-moment estimate (lazily initialized on the first step).
    momentum: Option<Array1<T>>,
    /// Second-moment estimate (lazily initialized on the first step).
    velocity: Option<Array1<T>>,

    /// Lookahead slow weights (lazily initialized on the first step).
    slow_weights: Option<Array1<T>>,

    /// Number of fast steps taken so far.
    step_count: usize,
    /// Number of slow-weight synchronizations performed so far.
    slow_update_count: usize,
}

impl<T: Float + ScalarOperand> Default for Ranger<T> {
    fn default() -> Self {
        Self::new(
            T::from(0.001).unwrap(), // learning_rate
            T::from(0.9).unwrap(),   // beta1
            T::from(0.999).unwrap(), // beta2
            T::from(1e-8).unwrap(),  // epsilon
            T::zero(),               // weight_decay
            5,                       // lookahead_k
            T::from(0.5).unwrap(),   // lookahead_alpha
        )
        .unwrap()
    }
}

impl<T: Float + ScalarOperand> Ranger<T> {
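    /// Creates a new Ranger optimizer with the given hyperparameters.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::InvalidParameter` if `learning_rate` or `epsilon`
    /// is not positive, `beta1` or `beta2` is outside `(0, 1)`,
    /// `weight_decay` is negative, `lookahead_k` is zero, or
    /// `lookahead_alpha` is outside `(0, 1]`.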
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        learning_rate: T,
        beta1: T,
        beta2: T,
        epsilon: T,
        weight_decay: T,
        lookahead_k: usize,
        lookahead_alpha: T,
    ) -> Result<Self> {
        if learning_rate.to_f64().unwrap() <= 0.0 {
            return Err(OptimError::InvalidParameter(format!(
                "learning_rate must be positive, got {}",
                learning_rate.to_f64().unwrap()
            )));
        }
        if beta1.to_f64().unwrap() <= 0.0 || beta1.to_f64().unwrap() >= 1.0 {
            return Err(OptimError::InvalidParameter(format!(
                "beta1 must be in (0, 1), got {}",
                beta1.to_f64().unwrap()
            )));
        }
        if beta2.to_f64().unwrap() <= 0.0 || beta2.to_f64().unwrap() >= 1.0 {
            return Err(OptimError::InvalidParameter(format!(
                "beta2 must be in (0, 1), got {}",
                beta2.to_f64().unwrap()
            )));
        }
        if epsilon.to_f64().unwrap() <= 0.0 {
            return Err(OptimError::InvalidParameter(format!(
                "epsilon must be positive, got {}",
                epsilon.to_f64().unwrap()
            )));
        }
        if weight_decay.to_f64().unwrap() < 0.0 {
            return Err(OptimError::InvalidParameter(format!(
                "weight_decay must be non-negative, got {}",
                weight_decay.to_f64().unwrap()
            )));
        }
        if lookahead_k == 0 {
            return Err(OptimError::InvalidParameter(
                "lookahead_k must be positive".to_string(),
            ));
        }
        if lookahead_alpha.to_f64().unwrap() <= 0.0 || lookahead_alpha.to_f64().unwrap() > 1.0 {
            return Err(OptimError::InvalidParameter(format!(
                "lookahead_alpha must be in (0, 1], got {}",
                lookahead_alpha.to_f64().unwrap()
            )));
        }

        Ok(Self {
            learning_rate,
            beta1,
            beta2,
            epsilon,
            weight_decay,
            lookahead_k,
            lookahead_alpha,
            momentum: None,
            velocity: None,
            slow_weights: None,
            step_count: 0,
            slow_update_count: 0,
        })
    }

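    /// Performs a single optimization step and returns the updated parameters.
    ///
    /// State is lazily initialized on the first call. On every
    /// `lookahead_k`-th step the slow weights are synchronized and returned,
    /// so callers should feed the returned parameters back into the next
    /// `step` call.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::DimensionMismatch` if `grads` and `params` have
    /// different lengths.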
    pub fn step(&mut self, params: ArrayView1<T>, grads: ArrayView1<T>) -> Result<Array1<T>> {
        let n = params.len();

        if grads.len() != n {
            return Err(OptimError::DimensionMismatch(format!(
                "Expected gradient size {}, got {}",
                n,
                grads.len()
            )));
        }

        // Lazily initialize the optimizer state on the first step.
        if self.momentum.is_none() {
            self.momentum = Some(Array1::zeros(n));
            self.velocity = Some(Array1::zeros(n));
            self.slow_weights = Some(params.to_owned());
        }

        self.step_count += 1;
        let t = T::from(self.step_count).unwrap();

        let momentum = self.momentum.as_mut().unwrap();
        let velocity = self.velocity.as_mut().unwrap();

        let one = T::one();
        let two = T::from(2).unwrap();

        // Apply L2 weight decay by adding it to the gradients.
        let effective_grads = if self.weight_decay > T::zero() {
            grads.to_owned() + &(params.to_owned() * self.weight_decay)
        } else {
            grads.to_owned()
        };

        // Update the biased first-moment estimate.
        for i in 0..n {
            momentum[i] = self.beta1 * momentum[i] + (one - self.beta1) * effective_grads[i];
        }

        // Update the biased second-moment estimate.
        for i in 0..n {
            let grad_sq = effective_grads[i] * effective_grads[i];
            velocity[i] = self.beta2 * velocity[i] + (one - self.beta2) * grad_sq;
        }

        let bias_correction1 = one - self.beta1.powf(t);
        let bias_correction2 = one - self.beta2.powf(t);

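        // Length of the approximated simple moving average (SMA) from the
        // RAdam paper: rho_inf = 2 / (1 - beta2) - 1 and
        // rho_t = rho_inf - 2 * t * beta2^t / (1 - beta2^t).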
        let rho_inf = two / (one - self.beta2) - one;
        let rho_t = rho_inf - two * t * self.beta2.powf(t) / bias_correction2;

        let mut updated_params = params.to_owned();

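        // RAdam rectification: when rho_t exceeds 4, the variance of the
        // adaptive step is tractable and the rectified Adam update is applied;
        // otherwise the step falls back to a plain bias-corrected momentum
        // update without the adaptive denominator.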
        if rho_t.to_f64().unwrap() > 4.0 {
            let rect_term = ((rho_t - T::from(4).unwrap()) * (rho_t - two) * rho_inf
                / ((rho_inf - T::from(4).unwrap()) * (rho_inf - two) * rho_t))
                .sqrt();

            for i in 0..n {
                let m_hat = momentum[i] / bias_correction1;
                let v_hat = velocity[i] / bias_correction2;
                let step_size = self.learning_rate * rect_term / (v_hat.sqrt() + self.epsilon);
                updated_params[i] = updated_params[i] - step_size * m_hat;
            }
        } else {
            for i in 0..n {
                let m_hat = momentum[i] / bias_correction1;
                updated_params[i] = updated_params[i] - self.learning_rate * m_hat;
            }
        }

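        // Lookahead: every `lookahead_k` fast steps, pull the slow weights
        // toward the fast weights by a factor of `lookahead_alpha`.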
        if self.step_count.is_multiple_of(self.lookahead_k) {
            let slow = self.slow_weights.as_mut().unwrap();
            for i in 0..n {
                slow[i] = slow[i] + self.lookahead_alpha * (updated_params[i] - slow[i]);
            }
            self.slow_update_count += 1;

            // On sync steps, return the slow weights so the caller's fast
            // weights are reset to them.
            Ok(slow.clone())
        } else {
            Ok(updated_params)
        }
    }

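    /// Returns the number of fast optimization steps taken so far.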
    pub fn step_count(&self) -> usize {
        self.step_count
    }

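    /// Returns the number of slow-weight (Lookahead) synchronizations performed.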
    pub fn slow_update_count(&self) -> usize {
        self.slow_update_count
    }

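    /// Resets the optimizer state, clearing all moment estimates, slow
    /// weights, and step counters.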
    pub fn reset(&mut self) {
        self.momentum = None;
        self.velocity = None;
        self.slow_weights = None;
        self.step_count = 0;
        self.slow_update_count = 0;
    }

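    /// Returns the current Lookahead slow weights, if initialized.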
    pub fn slow_weights(&self) -> Option<&Array1<T>> {
        self.slow_weights.as_ref()
    }

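    /// Returns `true` if the variance rectification term is currently active,
    /// i.e. the approximated SMA length `rho_t` exceeds 4.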
    pub fn is_rectified(&self) -> bool {
        if self.step_count == 0 {
            return false;
        }
        let t = T::from(self.step_count).unwrap();
        let one = T::one();
        let two = T::from(2).unwrap();
        let bias_correction2 = one - self.beta2.powf(t);
        let rho_inf = two / (one - self.beta2) - one;
        let rho_t = rho_inf - two * t * self.beta2.powf(t) / bias_correction2;
        rho_t.to_f64().unwrap() > 4.0
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    #[test]
    fn test_ranger_creation() {
        let optimizer = Ranger::<f32>::default();
        assert_eq!(optimizer.step_count(), 0);
        assert_eq!(optimizer.slow_update_count(), 0);
    }

    #[test]
    fn test_ranger_custom_creation() {
        let optimizer = Ranger::<f32>::new(0.002, 0.95, 0.9999, 1e-7, 0.01, 6, 0.6).unwrap();
        assert_eq!(optimizer.step_count(), 0);
    }

    #[test]
    fn test_ranger_single_step() {
        let mut optimizer = Ranger::<f32>::default();
        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];

        let updated_params = optimizer.step(params.view(), grads.view()).unwrap();
        assert_eq!(updated_params.len(), 3);
        assert_eq!(optimizer.step_count(), 1);

        // Positive gradients should move every parameter downward.
        for i in 0..3 {
            assert!(updated_params[i] < params[i]);
        }
    }

    #[test]
    fn test_ranger_slow_updates() {
        let mut optimizer = Ranger::<f32>::new(0.001, 0.9, 0.999, 1e-8, 0.0, 3, 0.5).unwrap();
        let mut params = array![1.0, 2.0, 3.0];

        // With lookahead_k = 3, three fast steps trigger exactly one slow update.
        for _ in 0..3 {
            let grads = array![0.1, 0.2, 0.3];
            params = optimizer.step(params.view(), grads.view()).unwrap();
        }
        assert_eq!(optimizer.slow_update_count(), 1);
    }

    #[test]
    fn test_ranger_convergence() {
        let mut optimizer = Ranger::<f64>::new(
            0.1,   // learning_rate
            0.9,   // beta1
            0.999, // beta2
            1e-8,  // epsilon
            0.0,   // weight_decay
            5,     // lookahead_k
            0.5,   // lookahead_alpha
        )
        .unwrap();
        let mut params = array![5.0];

        // Minimize f(x) = x^2, whose gradient is 2x.
        for _ in 0..500 {
            let grads = params.mapv(|x| 2.0 * x);
            params = optimizer.step(params.view(), grads.view()).unwrap();
        }

        assert!(
            params[0].abs() < 0.1,
            "Failed to converge, got {}",
            params[0]
        );
    }

    #[test]
    fn test_ranger_reset() {
        let mut optimizer = Ranger::<f32>::default();
        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];

        for _ in 0..10 {
            optimizer.step(params.view(), grads.view()).unwrap();
        }

        optimizer.reset();
        assert_eq!(optimizer.step_count(), 0);
        assert_eq!(optimizer.slow_update_count(), 0);
        assert!(optimizer.slow_weights().is_none());
    }

    #[test]
    fn test_ranger_rectification() {
        let mut optimizer = Ranger::<f32>::default();
        let params = array![1.0];
        let grads = array![0.1];

        // No steps taken yet, so rectification cannot be active.
        assert!(!optimizer.is_rectified());

        // After enough steps, rho_t grows past the rectification threshold.
        for _ in 0..10 {
            optimizer.step(params.view(), grads.view()).unwrap();
        }
        assert!(optimizer.is_rectified());
    }
}