irithyll_core/ensemble/lr_schedule.rs
1//! Learning rate scheduling for streaming gradient boosted trees.
2//!
3//! In standard (batch) gradient boosting the learning rate is fixed for the
4//! entire training run. Streaming ensembles see data indefinitely, so a
5//! fixed rate must balance early convergence against long-term stability.
6//! Learning rate schedulers resolve this tension by adapting the rate over
7//! the lifetime of the model.
8//!
9//! # Provided schedulers
10//!
11//! | Scheduler | Strategy |
12//! |-----------|----------|
13//! | [`ConstantLR`] | Fixed rate -- baseline behaviour, equivalent to no scheduling. |
14//! | [`LinearDecayLR`] | Linearly interpolates from `initial_lr` to `final_lr` over a fixed number of steps, then holds `final_lr`. |
15//! | [`ExponentialDecayLR`] | Multiplicative decay by `gamma` each step, floored at `1e-8` to avoid numerical zero. |
16//! | [`CosineAnnealingLR`] | Periodic cosine wave between `max_lr` and `min_lr`, useful for warm-restart style exploration. |
17//! | [`PlateauLR`] | Monitors the loss and reduces the rate by `factor` when improvement stalls for `patience` steps. |
18//!
19//! # Custom schedulers
20//!
21//! Implement the [`LRScheduler`] trait to build your own schedule:
22//!
23//! ```ignore
24//! use irithyll::ensemble::lr_schedule::LRScheduler;
25//!
26//! #[derive(Clone, Debug)]
27//! struct HalvingLR { lr: f64 }
28//!
29//! impl LRScheduler for HalvingLR {
30//! fn learning_rate(&mut self, step: u64, _loss: f64) -> f64 {
31//! let lr = self.lr;
//! self.lr = (self.lr * 0.5).max(1e-8);
33//! lr
34//! }
35//! fn reset(&mut self) { self.lr = 1.0; }
36//! }
37//! ```
38
39use core::f64::consts::PI;
40
41// ---------------------------------------------------------------------------
42// Trait
43// ---------------------------------------------------------------------------
44
/// A learning rate scheduler for streaming gradient boosted trees.
///
/// The ensemble calls [`learning_rate`](LRScheduler::learning_rate) once per
/// boosting round, passing the monotonically increasing `step` counter and the
/// most recent loss value. Implementations may use either or both to decide the
/// rate.
///
/// All schedulers must be `Send + Sync` so they can live inside async ensemble
/// wrappers and be shared across threads.
pub trait LRScheduler: Send + Sync {
    /// Return the learning rate for the given `step` and `current_loss`.
    ///
    /// Takes `&mut self` because stateful schedulers (e.g. [`PlateauLR`])
    /// update internal bookkeeping on every call.
    ///
    /// # Arguments
    ///
    /// * `step` -- Zero-based step counter. Incremented by the caller before
    ///   each invocation (0 on the first call, 1 on the second, ...).
    /// * `current_loss` -- The most recent loss value observed by the ensemble.
    ///   Schedulers that do not use loss feedback (everything except
    ///   [`PlateauLR`]) may ignore this argument.
    fn learning_rate(&mut self, step: u64, current_loss: f64) -> f64;

    /// Reset the scheduler to its initial state.
    ///
    /// Called when the ensemble is reset (e.g., after a concept-drift event
    /// triggers a full model rebuild).
    fn reset(&mut self);
}
72
73// ---------------------------------------------------------------------------
74// 1. ConstantLR
75// ---------------------------------------------------------------------------
76
/// Always returns the same learning rate.
///
/// This is the simplest scheduler and reproduces the behaviour of a plain
/// fixed-rate ensemble. It exists so that code paths expecting a `dyn
/// LRScheduler` can use a constant rate without special-casing.
///
/// # Example
///
/// ```ignore
/// use irithyll::ensemble::lr_schedule::{LRScheduler, ConstantLR};
///
/// let mut sched = ConstantLR::new(0.05);
/// assert!((sched.learning_rate(0, 1.0) - 0.05).abs() < f64::EPSILON);
/// assert!((sched.learning_rate(1000, 0.1) - 0.05).abs() < f64::EPSILON);
/// ```
#[derive(Clone, Debug)]
pub struct ConstantLR {
    /// The fixed learning rate, returned unchanged on every call.
    lr: f64,
}
97
98impl ConstantLR {
99 /// Create a constant-rate scheduler.
100 ///
101 /// # Arguments
102 ///
103 /// * `lr` -- The learning rate returned on every call.
104 pub fn new(lr: f64) -> Self {
105 Self { lr }
106 }
107}
108
109impl LRScheduler for ConstantLR {
110 #[inline]
111 fn learning_rate(&mut self, _step: u64, _current_loss: f64) -> f64 {
112 self.lr
113 }
114
115 fn reset(&mut self) {
116 // Nothing to reset -- the rate is stateless.
117 }
118}
119
120// ---------------------------------------------------------------------------
121// 2. LinearDecayLR
122// ---------------------------------------------------------------------------
123
/// Linearly interpolates the learning rate from `initial_lr` to `final_lr`
/// over `decay_steps`, then holds `final_lr` forever.
///
/// The formula is:
///
/// ```text
/// lr = initial_lr - (initial_lr - final_lr) * min(step / decay_steps, 1.0)
/// ```
///
/// This gives a smooth ramp-down that reaches `final_lr` at exactly
/// `step == decay_steps` and clamps there for all subsequent steps.
///
/// # Example
///
/// ```ignore
/// use irithyll::ensemble::lr_schedule::{LRScheduler, LinearDecayLR};
///
/// let mut sched = LinearDecayLR::new(0.1, 0.01, 100);
/// // At step 0 we get the initial rate.
/// assert!((sched.learning_rate(0, 0.0) - 0.1).abs() < 1e-12);
/// // At step 50 we're halfway.
/// assert!((sched.learning_rate(50, 0.0) - 0.055).abs() < 1e-12);
/// // At step 100 we've reached the final rate.
/// assert!((sched.learning_rate(100, 0.0) - 0.01).abs() < 1e-12);
/// // Beyond decay_steps the rate stays clamped.
/// assert!((sched.learning_rate(200, 0.0) - 0.01).abs() < 1e-12);
/// ```
#[derive(Clone, Debug)]
pub struct LinearDecayLR {
    /// Starting learning rate (returned at step 0).
    initial_lr: f64,
    /// Terminal learning rate (held after `decay_steps`).
    final_lr: f64,
    /// Number of steps over which the linear ramp is applied.
    decay_steps: u64,
}
160
161impl LinearDecayLR {
162 /// Create a linear-decay scheduler.
163 ///
164 /// # Arguments
165 ///
166 /// * `initial_lr` -- Rate at step 0.
167 /// * `final_lr` -- Rate from step `decay_steps` onward.
168 /// * `decay_steps` -- Length of the linear ramp in steps.
169 pub fn new(initial_lr: f64, final_lr: f64, decay_steps: u64) -> Self {
170 Self {
171 initial_lr,
172 final_lr,
173 decay_steps,
174 }
175 }
176}
177
178impl LRScheduler for LinearDecayLR {
179 #[inline]
180 fn learning_rate(&mut self, step: u64, _current_loss: f64) -> f64 {
181 let t = if self.decay_steps == 0 {
182 1.0
183 } else {
184 (step as f64 / self.decay_steps as f64).min(1.0)
185 };
186 self.initial_lr - (self.initial_lr - self.final_lr) * t
187 }
188
189 fn reset(&mut self) {
190 // Stateless -- nothing to reset.
191 }
192}
193
194// ---------------------------------------------------------------------------
195// 3. ExponentialDecayLR
196// ---------------------------------------------------------------------------
197
/// Multiplicative exponential decay: `lr = initial_lr * gamma^step`.
///
/// The rate is floored at `1e-8` to prevent numerical underflow that would
/// effectively freeze learning. For a half-life of *h* steps, set
/// `gamma = 0.5^(1/h)`.
///
/// # Example
///
/// ```ignore
/// use irithyll::ensemble::lr_schedule::{LRScheduler, ExponentialDecayLR};
///
/// let mut sched = ExponentialDecayLR::new(1.0, 0.9);
/// assert!((sched.learning_rate(0, 0.0) - 1.0).abs() < 1e-12);
/// assert!((sched.learning_rate(1, 0.0) - 0.9).abs() < 1e-12);
/// assert!((sched.learning_rate(2, 0.0) - 0.81).abs() < 1e-12);
/// ```
#[derive(Clone, Debug)]
pub struct ExponentialDecayLR {
    /// Learning rate at step 0.
    initial_lr: f64,
    /// Per-step multiplicative factor (typically in (0, 1)).
    gamma: f64,
}
221
222impl ExponentialDecayLR {
223 /// Create an exponential-decay scheduler.
224 ///
225 /// # Arguments
226 ///
227 /// * `initial_lr` -- Rate at step 0.
228 /// * `gamma` -- Multiplicative decay factor applied each step.
229 pub fn new(initial_lr: f64, gamma: f64) -> Self {
230 Self { initial_lr, gamma }
231 }
232}
233
234impl LRScheduler for ExponentialDecayLR {
235 #[inline]
236 fn learning_rate(&mut self, step: u64, _current_loss: f64) -> f64 {
237 crate::math::fmax(
238 self.initial_lr * crate::math::powi(self.gamma, step as i32),
239 1e-8,
240 )
241 }
242
243 fn reset(&mut self) {
244 // Stateless -- nothing to reset.
245 }
246}
247
248// ---------------------------------------------------------------------------
249// 4. CosineAnnealingLR
250// ---------------------------------------------------------------------------
251
/// Cosine annealing over a repeating period.
///
/// The learning rate follows a full cosine cycle from `max_lr` down to
/// `min_lr` and back, repeating every `period` steps:
///
/// ```text
/// lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + cos(2 * pi * (step % period) / period))
/// ```
///
/// The rate equals `max_lr` at every period boundary (`step % period == 0`),
/// dips to `min_lr` at the midpoint of the period, and climbs smoothly back
/// to `max_lr` by the period's end. This cyclical exploration of higher rates
/// can help the ensemble escape local plateaus in a streaming setting.
///
/// # Example
///
/// ```ignore
/// use irithyll::ensemble::lr_schedule::{LRScheduler, CosineAnnealingLR};
///
/// let mut sched = CosineAnnealingLR::new(0.1, 0.01, 100);
/// // Period start → max_lr.
/// assert!((sched.learning_rate(0, 0.0) - 0.1).abs() < 1e-12);
/// // Midpoint → min_lr.
/// assert!((sched.learning_rate(50, 0.0) - 0.01).abs() < 1e-12);
/// // Full period → back to max_lr.
/// assert!((sched.learning_rate(100, 0.0) - 0.1).abs() < 1e-12);
/// ```
#[derive(Clone, Debug)]
pub struct CosineAnnealingLR {
    /// Peak learning rate (at the start and end of each period).
    max_lr: f64,
    /// Trough learning rate (at the midpoint of each period).
    min_lr: f64,
    /// Number of steps per cosine cycle.
    period: u64,
}
287
288impl CosineAnnealingLR {
289 /// Create a cosine-annealing scheduler.
290 ///
291 /// # Arguments
292 ///
293 /// * `max_lr` -- Rate at the start (and end) of each cosine period.
294 /// * `min_lr` -- Rate at the midpoint of each period.
295 /// * `period` -- Length of one cosine cycle in steps.
296 pub fn new(max_lr: f64, min_lr: f64, period: u64) -> Self {
297 Self {
298 max_lr,
299 min_lr,
300 period,
301 }
302 }
303}
304
305impl LRScheduler for CosineAnnealingLR {
306 #[inline]
307 fn learning_rate(&mut self, step: u64, _current_loss: f64) -> f64 {
308 let phase = if self.period == 0 {
309 0.0
310 } else {
311 (step % self.period) as f64 / self.period as f64
312 };
313 self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1.0 + crate::math::cos(2.0 * PI * phase))
314 }
315
316 fn reset(&mut self) {
317 // Stateless -- nothing to reset.
318 }
319}
320
321// ---------------------------------------------------------------------------
322// 5. PlateauLR
323// ---------------------------------------------------------------------------
324
/// Reduce learning rate when the loss plateaus.
///
/// Monitors `current_loss` and reduces the rate by `factor` whenever the loss
/// has not improved for `patience` consecutive steps. "Improved" means the
/// new loss is strictly less than the best-seen loss; because the comparison
/// is a strict `<`, an equal loss -- or a `NaN` loss -- counts as stagnation,
/// never as improvement.
///
/// The rate is floored at `min_lr` to prevent it from vanishing entirely.
///
/// # Example
///
/// ```ignore
/// use irithyll::ensemble::lr_schedule::{LRScheduler, PlateauLR};
///
/// let mut sched = PlateauLR::new(0.1, 0.5, 3, 0.001);
///
/// // Improving loss -- rate stays at 0.1.
/// assert!((sched.learning_rate(0, 1.0) - 0.1).abs() < 1e-12);
/// assert!((sched.learning_rate(1, 0.9) - 0.1).abs() < 1e-12);
///
/// // Stagnating loss for patience=3 steps.
/// assert!((sched.learning_rate(2, 0.95) - 0.1).abs() < 1e-12);
/// assert!((sched.learning_rate(3, 0.95) - 0.1).abs() < 1e-12);
/// assert!((sched.learning_rate(4, 0.95) - 0.1).abs() < 1e-12);
///
/// // Patience exhausted -- rate drops to 0.1 * 0.5 = 0.05.
/// assert!((sched.learning_rate(5, 0.95) - 0.05).abs() < 1e-12);
/// ```
#[derive(Clone, Debug)]
pub struct PlateauLR {
    /// Starting learning rate.
    initial_lr: f64,
    /// Multiplicative factor applied when patience is exhausted (0 < factor < 1).
    factor: f64,
    /// Number of non-improving steps before a reduction.
    patience: u64,
    /// Minimum learning rate floor.
    min_lr: f64,

    // -- internal state --
    /// Best loss observed so far.
    best_loss: f64,
    /// Number of consecutive steps without improvement.
    steps_without_improvement: u64,
    /// The current (possibly reduced) learning rate.
    current_lr: f64,
}
371
372impl PlateauLR {
373 /// Create a plateau-aware scheduler.
374 ///
375 /// # Arguments
376 ///
377 /// * `initial_lr` -- Starting learning rate.
378 /// * `factor` -- Multiplicative reduction factor (e.g., 0.5 halves the rate).
379 /// * `patience` -- Number of non-improving steps to tolerate before reducing.
380 /// * `min_lr` -- Floor below which the rate will not be reduced.
381 pub fn new(initial_lr: f64, factor: f64, patience: u64, min_lr: f64) -> Self {
382 Self {
383 initial_lr,
384 factor,
385 patience,
386 min_lr,
387 best_loss: f64::INFINITY,
388 steps_without_improvement: 0,
389 current_lr: initial_lr,
390 }
391 }
392}
393
394impl LRScheduler for PlateauLR {
395 fn learning_rate(&mut self, _step: u64, current_loss: f64) -> f64 {
396 if current_loss < self.best_loss {
397 // Improvement -- record and reset counter.
398 self.best_loss = current_loss;
399 self.steps_without_improvement = 0;
400 } else {
401 self.steps_without_improvement += 1;
402
403 if self.steps_without_improvement > self.patience {
404 self.current_lr = (self.current_lr * self.factor).max(self.min_lr);
405 self.steps_without_improvement = 0;
406 }
407 }
408
409 self.current_lr
410 }
411
412 fn reset(&mut self) {
413 self.best_loss = f64::INFINITY;
414 self.steps_without_improvement = 0;
415 self.current_lr = self.initial_lr;
416 }
417}
418
419// ---------------------------------------------------------------------------
420// Tests
421// ---------------------------------------------------------------------------
422
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::boxed::Box;
    use alloc::vec;
    use alloc::vec::Vec;

    // -- ConstantLR --------------------------------------------------------

    /// Constant scheduler always returns the configured rate regardless of
    /// step or loss.
    #[test]
    fn test_constant_lr() {
        let mut sched = ConstantLR::new(0.05);

        for step in 0..100 {
            // Deliberately absurd loss value: ConstantLR must ignore it.
            let lr = sched.learning_rate(step, 999.0);
            assert!(
                (lr - 0.05).abs() < f64::EPSILON,
                "ConstantLR should always return 0.05, got {} at step {}",
                lr,
                step,
            );
        }
    }

    // -- LinearDecayLR -----------------------------------------------------

    /// Linear decay interpolates correctly between initial and final.
    #[test]
    fn test_linear_decay() {
        let mut sched = LinearDecayLR::new(0.1, 0.01, 100);

        let lr0 = sched.learning_rate(0, 0.0);
        assert!(
            (lr0 - 0.1).abs() < 1e-12,
            "step 0 should be initial_lr (0.1), got {}",
            lr0,
        );

        // Midpoint of the ramp: exactly halfway between initial and final.
        let lr50 = sched.learning_rate(50, 0.0);
        let expected_50 = 0.1 - (0.1 - 0.01) * 0.5;
        assert!(
            (lr50 - expected_50).abs() < 1e-12,
            "step 50 should be {}, got {}",
            expected_50,
            lr50,
        );

        let lr100 = sched.learning_rate(100, 0.0);
        assert!(
            (lr100 - 0.01).abs() < 1e-12,
            "step 100 should be final_lr (0.01), got {}",
            lr100,
        );
    }

    /// Linear decay clamps at final_lr for steps beyond decay_steps.
    #[test]
    fn test_linear_decay_clamps() {
        let mut sched = LinearDecayLR::new(0.1, 0.01, 50);

        let lr_before = sched.learning_rate(50, 0.0);
        let lr_after = sched.learning_rate(200, 0.0);
        assert!(
            (lr_before - 0.01).abs() < 1e-12,
            "at decay_steps should be final_lr, got {}",
            lr_before,
        );
        assert!(
            (lr_after - 0.01).abs() < 1e-12,
            "beyond decay_steps should still be final_lr, got {}",
            lr_after,
        );
    }

    // -- ExponentialDecayLR ------------------------------------------------

    /// Exponential decay follows gamma^step correctly.
    #[test]
    fn test_exponential_decay() {
        let mut sched = ExponentialDecayLR::new(1.0, 0.9);

        let lr0 = sched.learning_rate(0, 0.0);
        assert!(
            (lr0 - 1.0).abs() < 1e-12,
            "step 0 should be initial_lr (1.0), got {}",
            lr0,
        );

        let lr1 = sched.learning_rate(1, 0.0);
        assert!(
            (lr1 - 0.9).abs() < 1e-12,
            "step 1 should be 0.9, got {}",
            lr1,
        );

        let lr2 = sched.learning_rate(2, 0.0);
        assert!(
            (lr2 - 0.81).abs() < 1e-12,
            "step 2 should be 0.81, got {}",
            lr2,
        );

        // Looser tolerance here: reference value computed with std powi may
        // differ from the crate-local powi in the last few ulps.
        let lr10 = sched.learning_rate(10, 0.0);
        let expected_10 = 0.9_f64.powi(10);
        assert!(
            (lr10 - expected_10).abs() < 1e-10,
            "step 10 should be {}, got {}",
            expected_10,
            lr10,
        );
    }

    /// Exponential decay floors at 1e-8, never reaching zero.
    #[test]
    fn test_exponential_floor() {
        let mut sched = ExponentialDecayLR::new(1.0, 0.01);

        // After enough steps, gamma^step would be astronomically small.
        let lr = sched.learning_rate(10_000, 0.0);
        assert!(
            lr >= 1e-8,
            "exponential decay should floor at 1e-8, got {}",
            lr,
        );
        assert!(
            (lr - 1e-8).abs() < 1e-15,
            "at extreme steps the rate should equal the floor, got {}",
            lr,
        );
    }

    // -- CosineAnnealingLR -------------------------------------------------

    /// Cosine annealing hits max at period boundaries and min at midpoints.
    #[test]
    fn test_cosine_annealing() {
        let mut sched = CosineAnnealingLR::new(0.1, 0.01, 100);

        // At step 0 → max_lr (cos(0) = 1).
        let lr0 = sched.learning_rate(0, 0.0);
        assert!(
            (lr0 - 0.1).abs() < 1e-12,
            "period start should be max_lr (0.1), got {}",
            lr0,
        );

        // At step 50 → min_lr (cos(pi) = -1).
        let lr50 = sched.learning_rate(50, 0.0);
        assert!(
            (lr50 - 0.01).abs() < 1e-12,
            "period midpoint should be min_lr (0.01), got {}",
            lr50,
        );

        // At step 25 → halfway between max and min (cos(pi/2) = 0).
        let lr25 = sched.learning_rate(25, 0.0);
        let expected_25 = 0.01 + 0.5 * (0.1 - 0.01) * (1.0 + (2.0 * PI * 0.25).cos());
        assert!(
            (lr25 - expected_25).abs() < 1e-12,
            "quarter-period should be {}, got {}",
            expected_25,
            lr25,
        );
    }

    /// Cosine annealing wraps correctly at period boundaries.
    #[test]
    fn test_cosine_boundaries() {
        let mut sched = CosineAnnealingLR::new(0.1, 0.01, 100);

        // step % period == 0 → phase 0 → max_lr again.
        let at_boundary = sched.learning_rate(100, 0.0);
        assert!(
            (at_boundary - 0.1).abs() < 1e-12,
            "step==period should wrap to max_lr, got {}",
            at_boundary,
        );

        let second_mid = sched.learning_rate(150, 0.0);
        assert!(
            (second_mid - 0.01).abs() < 1e-12,
            "second period midpoint should be min_lr, got {}",
            second_mid,
        );
    }

    // -- PlateauLR ---------------------------------------------------------

    /// Plateau scheduler reduces the rate after patience non-improving steps.
    #[test]
    fn test_plateau_reduces() {
        let mut sched = PlateauLR::new(0.1, 0.5, 3, 0.001);

        // First call sets best_loss = 1.0.
        let lr = sched.learning_rate(0, 1.0);
        assert!(
            (lr - 0.1).abs() < 1e-12,
            "initial rate should be 0.1, got {}",
            lr
        );

        // Three non-improving steps (patience = 3).
        sched.learning_rate(1, 1.0); // counter: 1
        sched.learning_rate(2, 1.0); // counter: 2
        sched.learning_rate(3, 1.0); // counter: 3

        // Fourth non-improving step exceeds patience → reduce.
        let lr_reduced = sched.learning_rate(4, 1.0);
        assert!(
            (lr_reduced - 0.05).abs() < 1e-12,
            "after patience exceeded, rate should be 0.1*0.5 = 0.05, got {}",
            lr_reduced,
        );
    }

    /// Plateau scheduler resets its counter when loss improves.
    #[test]
    fn test_plateau_improvement_resets() {
        let mut sched = PlateauLR::new(0.1, 0.5, 2, 0.001);

        // Establish baseline.
        sched.learning_rate(0, 1.0);

        // Two non-improving steps (counter: 1, 2).
        sched.learning_rate(1, 1.5);
        sched.learning_rate(2, 1.5);

        // Now improve -- counter resets to 0.
        sched.learning_rate(3, 0.5);

        // One non-improving step after improvement (counter: 1).
        sched.learning_rate(4, 0.6);

        // Second non-improving step (counter: 2, still <= patience=2).
        let lr = sched.learning_rate(5, 0.6);
        assert!(
            (lr - 0.1).abs() < 1e-12,
            "improvement should have reset counter; rate should be 0.1, got {}",
            lr,
        );
    }

    /// Plateau scheduler never drops below min_lr.
    #[test]
    fn test_plateau_min_lr() {
        let mut sched = PlateauLR::new(0.1, 0.1, 0, 0.05);

        // With patience=0, every non-improving step triggers a reduction.
        sched.learning_rate(0, 1.0); // sets best_loss = 1.0

        // Non-improving: counter goes to 1 which exceeds patience (0).
        sched.learning_rate(1, 1.0); // reduce: 0.1 * 0.1 = 0.01 → clamped to 0.05

        let lr = sched.learning_rate(2, 1.0);
        assert!(
            lr >= 0.05 - 1e-12,
            "rate should never drop below min_lr (0.05), got {}",
            lr,
        );
    }

    /// Plateau reset restores the scheduler to its initial state.
    #[test]
    fn test_plateau_reset() {
        let mut sched = PlateauLR::new(0.1, 0.5, 1, 0.001);

        // Drive the rate down.
        sched.learning_rate(0, 1.0);
        sched.learning_rate(1, 1.0);
        sched.learning_rate(2, 1.0);

        // The rate should have been reduced at least once.
        // (Private-field access is fine: this module is a child of the
        // scheduler module.)
        let lr_before_reset = sched.current_lr;
        assert!(
            lr_before_reset < 0.1,
            "rate should have decreased before reset, got {}",
            lr_before_reset,
        );

        sched.reset();

        let lr_after = sched.learning_rate(0, 10.0);
        assert!(
            (lr_after - 0.1).abs() < 1e-12,
            "after reset, rate should be back to initial_lr (0.1), got {}",
            lr_after,
        );
    }

    // -- Cross-cutting property tests --------------------------------------

    /// Every scheduler must return a strictly positive learning rate for any
    /// non-negative step and finite loss value.
    #[test]
    fn test_all_positive() {
        let mut schedulers: Vec<Box<dyn LRScheduler>> = vec![
            Box::new(ConstantLR::new(0.05)),
            Box::new(LinearDecayLR::new(0.1, 0.001, 100)),
            Box::new(ExponentialDecayLR::new(1.0, 0.99)),
            Box::new(CosineAnnealingLR::new(0.1, 0.001, 50)),
            Box::new(PlateauLR::new(0.1, 0.5, 5, 0.001)),
        ];

        for (i, sched) in schedulers.iter_mut().enumerate() {
            for step in 0..500 {
                let lr = sched.learning_rate(step, 1.0);
                assert!(
                    lr > 0.0,
                    "scheduler {} returned non-positive lr {} at step {}",
                    i,
                    lr,
                    step,
                );
                assert!(
                    lr.is_finite(),
                    "scheduler {} returned non-finite lr at step {}",
                    i,
                    step,
                );
            }
        }
    }
}