use crate::error::{OptimError, Result};
use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::ndarray_ext::{Array1, ArrayView1};
use scirs2_core::numeric::{Float, Zero};
use serde::{Deserialize, Serialize};

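/// Ranger optimizer: rectified Adam (RAdam) combined with Lookahead.
///
/// Fast weights follow RAdam updates (bias-corrected first and second moments
/// with a variance-rectification term), and every `lookahead_k` steps the slow
/// weights are interpolated toward the fast weights by `lookahead_alpha`.
///
/// # Example
///
/// A minimal usage sketch; exact import paths depend on how this crate
/// re-exports the optimizer, so treat them as assumptions.
///
/// ```ignore
/// use scirs2_core::ndarray_ext::array;
///
/// let mut optimizer = Ranger::<f64>::default();
/// let mut params = array![1.0, 2.0, 3.0];
/// for _ in 0..100 {
///     // Gradient of f(x) = sum(x_i^2), standing in for a real loss.
///     let grads = params.mapv(|x| 2.0 * x);
///     params = optimizer
///         .step(params.view(), grads.view())
///         .expect("step failed");
/// }
/// ```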
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Ranger<T: Float + ScalarOperand> {
    learning_rate: T,
    beta1: T,
    beta2: T,
    epsilon: T,
    weight_decay: T,

    /// Number of fast steps between slow-weight synchronizations.
    lookahead_k: usize,
    /// Interpolation factor for the slow-weight update, in (0, 1].
    lookahead_alpha: T,

    /// First-moment estimate (initialized lazily on the first step).
    momentum: Option<Array1<T>>,
    /// Second-moment estimate (initialized lazily on the first step).
    velocity: Option<Array1<T>>,

    /// Lookahead slow weights (initialized lazily on the first step).
    slow_weights: Option<Array1<T>>,

    step_count: usize,
    slow_update_count: usize,
}

impl<T: Float + ScalarOperand> Default for Ranger<T> {
    fn default() -> Self {
        Self::new(
            T::from(0.001).expect("unwrap failed"),
            T::from(0.9).expect("unwrap failed"),
            T::from(0.999).expect("unwrap failed"),
            T::from(1e-8).expect("unwrap failed"),
            T::zero(),
            5,
            T::from(0.5).expect("unwrap failed"),
        )
        .expect("unwrap failed")
    }
}

impl<T: Float + ScalarOperand> Ranger<T> {
    /// Creates a new Ranger optimizer after validating the hyperparameters.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        learning_rate: T,
        beta1: T,
        beta2: T,
        epsilon: T,
        weight_decay: T,
        lookahead_k: usize,
        lookahead_alpha: T,
    ) -> Result<Self> {
        if learning_rate.to_f64().expect("unwrap failed") <= 0.0 {
            return Err(OptimError::InvalidParameter(format!(
                "learning_rate must be positive, got {}",
                learning_rate.to_f64().expect("unwrap failed")
            )));
        }
        if beta1.to_f64().expect("unwrap failed") <= 0.0
            || beta1.to_f64().expect("unwrap failed") >= 1.0
        {
            return Err(OptimError::InvalidParameter(format!(
                "beta1 must be in (0, 1), got {}",
                beta1.to_f64().expect("unwrap failed")
            )));
        }
        if beta2.to_f64().expect("unwrap failed") <= 0.0
            || beta2.to_f64().expect("unwrap failed") >= 1.0
        {
            return Err(OptimError::InvalidParameter(format!(
                "beta2 must be in (0, 1), got {}",
                beta2.to_f64().expect("unwrap failed")
            )));
        }
        if epsilon.to_f64().expect("unwrap failed") <= 0.0 {
            return Err(OptimError::InvalidParameter(format!(
                "epsilon must be positive, got {}",
                epsilon.to_f64().expect("unwrap failed")
            )));
        }
        if weight_decay.to_f64().expect("unwrap failed") < 0.0 {
            return Err(OptimError::InvalidParameter(format!(
                "weight_decay must be non-negative, got {}",
                weight_decay.to_f64().expect("unwrap failed")
            )));
        }
        if lookahead_k == 0 {
            return Err(OptimError::InvalidParameter(
                "lookahead_k must be positive".to_string(),
            ));
        }
        if lookahead_alpha.to_f64().expect("unwrap failed") <= 0.0
            || lookahead_alpha.to_f64().expect("unwrap failed") > 1.0
        {
            return Err(OptimError::InvalidParameter(format!(
                "lookahead_alpha must be in (0, 1], got {}",
                lookahead_alpha.to_f64().expect("unwrap failed")
            )));
        }

        Ok(Self {
            learning_rate,
            beta1,
            beta2,
            epsilon,
            weight_decay,
            lookahead_k,
            lookahead_alpha,
            momentum: None,
            velocity: None,
            slow_weights: None,
            step_count: 0,
            slow_update_count: 0,
        })
    }

    /// Performs a single optimization step.
    ///
    /// Returns the updated fast weights, or the synchronized slow weights on
    /// every `lookahead_k`-th step.
    pub fn step(&mut self, params: ArrayView1<T>, grads: ArrayView1<T>) -> Result<Array1<T>> {
        let n = params.len();

        if grads.len() != n {
            return Err(OptimError::DimensionMismatch(format!(
                "Expected gradient size {}, got {}",
                n,
                grads.len()
            )));
        }

        // Lazily allocate optimizer state on the first step.
        if self.momentum.is_none() {
            self.momentum = Some(Array1::zeros(n));
            self.velocity = Some(Array1::zeros(n));
            self.slow_weights = Some(params.to_owned());
        }

        self.step_count += 1;
        let t = T::from(self.step_count).expect("unwrap failed");

        let momentum = self.momentum.as_mut().expect("unwrap failed");
        let velocity = self.velocity.as_mut().expect("unwrap failed");

        let one = T::one();
        let two = T::from(2).expect("unwrap failed");

        // L2-style weight decay: fold `weight_decay * params` into the gradient.
        let effective_grads = if self.weight_decay > T::zero() {
            grads.to_owned() + &(params.to_owned() * self.weight_decay)
        } else {
            grads.to_owned()
        };

        // First moment: exponential moving average of the gradients.
        for i in 0..n {
            momentum[i] = self.beta1 * momentum[i] + (one - self.beta1) * effective_grads[i];
        }

        // Second moment: exponential moving average of the squared gradients.
        for i in 0..n {
            let grad_sq = effective_grads[i] * effective_grads[i];
            velocity[i] = self.beta2 * velocity[i] + (one - self.beta2) * grad_sq;
        }

        let bias_correction1 = one - self.beta1.powf(t);
        let bias_correction2 = one - self.beta2.powf(t);

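        // RAdam rectification quantities (a sketch of the standard RAdam formulation):
        //   rho_inf = 2 / (1 - beta2) - 1
        //   rho_t   = rho_inf - 2 * t * beta2^t / (1 - beta2^t)
        //   r_t     = sqrt( ((rho_t - 4)(rho_t - 2) * rho_inf)
        //                 / ((rho_inf - 4)(rho_inf - 2) * rho_t) )
        // The adaptive step is used only when rho_t > 4; otherwise the variance
        // estimate is treated as unreliable and a plain momentum step is taken.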
        let rho_inf = two / (one - self.beta2) - one;
        let rho_t = rho_inf - two * t * self.beta2.powf(t) / bias_correction2;

        let mut updated_params = params.to_owned();

        if rho_t.to_f64().expect("unwrap failed") > 4.0 {
            let rect_term =
                ((rho_t - T::from(4).expect("unwrap failed")) * (rho_t - two) * rho_inf
                    / ((rho_inf - T::from(4).expect("unwrap failed")) * (rho_inf - two) * rho_t))
                    .sqrt();

            for i in 0..n {
                let m_hat = momentum[i] / bias_correction1;
                let v_hat = velocity[i] / bias_correction2;
                let step_size = self.learning_rate * rect_term / (v_hat.sqrt() + self.epsilon);
                updated_params[i] = updated_params[i] - step_size * m_hat;
            }
        } else {
            for i in 0..n {
                let m_hat = momentum[i] / bias_correction1;
                updated_params[i] = updated_params[i] - self.learning_rate * m_hat;
            }
        }

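        // Lookahead: every `lookahead_k` fast steps, interpolate the slow weights
        // toward the fast weights by `lookahead_alpha` and return the slow weights
        // as the synchronized parameters.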
        if self.step_count.is_multiple_of(self.lookahead_k) {
            let slow = self.slow_weights.as_mut().expect("unwrap failed");
            for i in 0..n {
                slow[i] = slow[i] + self.lookahead_alpha * (updated_params[i] - slow[i]);
            }
            self.slow_update_count += 1;

            Ok(slow.clone())
        } else {
            Ok(updated_params)
        }
    }

    /// Returns the number of optimization steps taken so far.
    pub fn step_count(&self) -> usize {
        self.step_count
    }

    /// Returns the number of lookahead slow-weight synchronizations performed.
    pub fn slow_update_count(&self) -> usize {
        self.slow_update_count
    }

    /// Resets all optimizer state.
    pub fn reset(&mut self) {
        self.momentum = None;
        self.velocity = None;
        self.slow_weights = None;
        self.step_count = 0;
        self.slow_update_count = 0;
    }

    /// Returns the lookahead slow weights, if initialized.
    pub fn slow_weights(&self) -> Option<&Array1<T>> {
        self.slow_weights.as_ref()
    }

    /// Returns `true` if the RAdam variance rectification is active at the
    /// current step count (i.e. `rho_t > 4`).
    pub fn is_rectified(&self) -> bool {
        if self.step_count == 0 {
            return false;
        }
        let t = T::from(self.step_count).expect("unwrap failed");
        let one = T::one();
        let two = T::from(2).expect("unwrap failed");
        let bias_correction2 = one - self.beta2.powf(t);
        let rho_inf = two / (one - self.beta2) - one;
        let rho_t = rho_inf - two * t * self.beta2.powf(t) / bias_correction2;
        rho_t.to_f64().expect("unwrap failed") > 4.0
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    #[test]
    fn test_ranger_creation() {
        let optimizer = Ranger::<f32>::default();
        assert_eq!(optimizer.step_count(), 0);
        assert_eq!(optimizer.slow_update_count(), 0);
    }

    #[test]
    fn test_ranger_custom_creation() {
        let optimizer =
            Ranger::<f32>::new(0.002, 0.95, 0.9999, 1e-7, 0.01, 6, 0.6).expect("unwrap failed");
        assert_eq!(optimizer.step_count(), 0);
    }

    #[test]
    fn test_ranger_single_step() {
        let mut optimizer = Ranger::<f32>::default();
        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];

        let updated_params = optimizer
            .step(params.view(), grads.view())
            .expect("unwrap failed");
        assert_eq!(updated_params.len(), 3);
        assert_eq!(optimizer.step_count(), 1);

        // With positive gradients, every parameter should decrease.
        for i in 0..3 {
            assert!(updated_params[i] < params[i]);
        }
    }

    #[test]
    fn test_ranger_slow_updates() {
        let mut optimizer =
            Ranger::<f32>::new(0.001, 0.9, 0.999, 1e-8, 0.0, 3, 0.5).expect("unwrap failed");
        let mut params = array![1.0, 2.0, 3.0];

        for _ in 0..3 {
            let grads = array![0.1, 0.2, 0.3];
            params = optimizer
                .step(params.view(), grads.view())
                .expect("unwrap failed");
        }
        assert_eq!(optimizer.slow_update_count(), 1);
    }

    #[test]
    fn test_ranger_convergence() {
        // Minimize f(x) = x^2 starting from x = 5.
        let mut optimizer =
            Ranger::<f64>::new(0.1, 0.9, 0.999, 1e-8, 0.0, 5, 0.5).expect("unwrap failed");
        let mut params = array![5.0];

        for _ in 0..500 {
            let grads = params.mapv(|x| 2.0 * x);
            params = optimizer
                .step(params.view(), grads.view())
                .expect("unwrap failed");
        }

        assert!(
            params[0].abs() < 0.1,
            "Failed to converge, got {}",
            params[0]
        );
    }

    #[test]
    fn test_ranger_reset() {
        let mut optimizer = Ranger::<f32>::default();
        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];

        for _ in 0..10 {
            optimizer
                .step(params.view(), grads.view())
                .expect("unwrap failed");
        }

        optimizer.reset();
        assert_eq!(optimizer.step_count(), 0);
        assert_eq!(optimizer.slow_update_count(), 0);
        assert!(optimizer.slow_weights().is_none());
    }

    #[test]
    fn test_ranger_rectification() {
        let mut optimizer = Ranger::<f32>::default();
        let params = array![1.0];
        let grads = array![0.1];

        // Rectification should be inactive before any steps are taken.
        assert!(!optimizer.is_rectified());

        for _ in 0..10 {
            optimizer
                .step(params.view(), grads.view())
                .expect("unwrap failed");
        }
        assert!(optimizer.is_rectified());
    }
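
    // A small additional check (sketch): the constructor should reject
    // out-of-range hyperparameters, mirroring the validation in `new`.
    #[test]
    fn test_ranger_invalid_hyperparameters() {
        assert!(Ranger::<f32>::new(-0.001, 0.9, 0.999, 1e-8, 0.0, 5, 0.5).is_err());
        assert!(Ranger::<f32>::new(0.001, 1.5, 0.999, 1e-8, 0.0, 5, 0.5).is_err());
        assert!(Ranger::<f32>::new(0.001, 0.9, 0.999, 1e-8, 0.0, 0, 0.5).is_err());
    }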
}