// irithyll_core/ensemble/lr_schedule.rs
//! Learning rate scheduling for streaming gradient boosted trees.
//!
//! In standard (batch) gradient boosting the learning rate is fixed for the
//! entire training run. Streaming ensembles see data indefinitely, so a
//! fixed rate must balance early convergence against long-term stability.
//! Learning rate schedulers resolve this tension by adapting the rate over
//! the lifetime of the model.
//!
//! # Provided schedulers
//!
//! | Scheduler | Strategy |
//! |-----------|----------|
//! | [`ConstantLR`] | Fixed rate -- baseline behaviour, equivalent to no scheduling. |
//! | [`LinearDecayLR`] | Linearly interpolates from `initial_lr` to `final_lr` over a fixed number of steps, then holds `final_lr`. |
//! | [`ExponentialDecayLR`] | Multiplicative decay by `gamma` each step, floored at `1e-8` to avoid numerical zero. |
//! | [`CosineAnnealingLR`] | Periodic cosine wave between `max_lr` and `min_lr`, useful for warm-restart style exploration. |
//! | [`PlateauLR`] | Monitors the loss and reduces the rate by `factor` when improvement stalls for `patience` steps. |
//!
//! # Custom schedulers
//!
//! Implement the [`LRScheduler`] trait to build your own schedule:
//!
//! ```ignore
//! use irithyll::ensemble::lr_schedule::LRScheduler;
//!
//! #[derive(Clone, Debug)]
//! struct HalvingLR { lr: f64 }
//!
//! impl LRScheduler for HalvingLR {
//!     fn learning_rate(&mut self, step: u64, _loss: f64) -> f64 {
//!         let lr = self.lr;
//!         // Halve the stored rate for the next call, flooring the *rate*
//!         // (not the factor) at 1e-8 so it never collapses to zero.
//!         self.lr = (self.lr * 0.5).max(1e-8);
//!         lr
//!     }
//!     fn reset(&mut self) { self.lr = 1.0; }
//! }
//! ```

use core::f64::consts::PI;

// ---------------------------------------------------------------------------
// Trait
// ---------------------------------------------------------------------------

/// A learning rate scheduler for streaming gradient boosted trees.
///
/// The ensemble calls [`learning_rate`](LRScheduler::learning_rate) once per
/// boosting round, passing the monotonically increasing `step` counter and the
/// most recent loss value. Implementations may use either or both to decide the
/// rate.
///
/// All schedulers must be `Send + Sync` so they can live inside async ensemble
/// wrappers and be shared across threads.
pub trait LRScheduler: Send + Sync {
    /// Return the learning rate for the given `step` and `current_loss`.
    ///
    /// Takes `&mut self` because loss-driven schedulers (e.g. [`PlateauLR`])
    /// update internal state on every call.
    ///
    /// # Arguments
    ///
    /// * `step` -- Zero-based step counter. Incremented by the caller before
    ///   each invocation (0 on the first call, 1 on the second, ...).
    /// * `current_loss` -- The most recent loss value observed by the ensemble.
    ///   Schedulers that do not use loss feedback (everything except
    ///   [`PlateauLR`]) may ignore this argument.
    fn learning_rate(&mut self, step: u64, current_loss: f64) -> f64;

    /// Reset the scheduler to its initial state.
    ///
    /// Called when the ensemble is reset (e.g., after a concept-drift event
    /// triggers a full model rebuild).
    fn reset(&mut self);
}

// ---------------------------------------------------------------------------
// 1. ConstantLR
// ---------------------------------------------------------------------------

/// Always returns the same learning rate.
///
/// This is the simplest scheduler and reproduces the behaviour of a plain
/// fixed-rate ensemble. It exists so that code paths expecting a `dyn
/// LRScheduler` can use a constant rate without special-casing.
///
/// # Example
///
/// ```ignore
/// use irithyll::ensemble::lr_schedule::{LRScheduler, ConstantLR};
///
/// let mut sched = ConstantLR::new(0.05);
/// assert!(crate::math::abs((sched.learning_rate(0, 1.0) - 0.05)) < f64::EPSILON);
/// assert!(crate::math::abs((sched.learning_rate(1000, 0.1) - 0.05)) < f64::EPSILON);
/// ```
#[derive(Clone, Debug)]
pub struct ConstantLR {
    /// The fixed learning rate, returned unchanged on every call.
    lr: f64,
}

98impl ConstantLR {
99    /// Create a constant-rate scheduler.
100    ///
101    /// # Arguments
102    ///
103    /// * `lr` -- The learning rate returned on every call.
104    pub fn new(lr: f64) -> Self {
105        Self { lr }
106    }
107}
108
109impl LRScheduler for ConstantLR {
110    #[inline]
111    fn learning_rate(&mut self, _step: u64, _current_loss: f64) -> f64 {
112        self.lr
113    }
114
115    fn reset(&mut self) {
116        // Nothing to reset -- the rate is stateless.
117    }
118}
119
// ---------------------------------------------------------------------------
// 2. LinearDecayLR
// ---------------------------------------------------------------------------

/// Linearly interpolates the learning rate from `initial_lr` to `final_lr`
/// over `decay_steps`, then holds `final_lr` forever.
///
/// The formula is:
///
/// ```text
/// lr = initial_lr - (initial_lr - final_lr) * min(step / decay_steps, 1.0)
/// ```
///
/// This gives a smooth ramp-down that reaches `final_lr` at exactly
/// `step == decay_steps` and clamps there for all subsequent steps. A
/// `decay_steps` of zero is treated as an already-finished ramp: the
/// scheduler returns `final_lr` from step 0 onward.
///
/// # Example
///
/// ```ignore
/// use irithyll::ensemble::lr_schedule::{LRScheduler, LinearDecayLR};
///
/// let mut sched = LinearDecayLR::new(0.1, 0.01, 100);
/// // At step 0 we get the initial rate.
/// assert!(crate::math::abs((sched.learning_rate(0, 0.0) - 0.1)) < 1e-12);
/// // At step 50 we're halfway.
/// assert!(crate::math::abs((sched.learning_rate(50, 0.0) - 0.055)) < 1e-12);
/// // At step 100 we've reached the final rate.
/// assert!(crate::math::abs((sched.learning_rate(100, 0.0) - 0.01)) < 1e-12);
/// // Beyond decay_steps the rate stays clamped.
/// assert!(crate::math::abs((sched.learning_rate(200, 0.0) - 0.01)) < 1e-12);
/// ```
#[derive(Clone, Debug)]
pub struct LinearDecayLR {
    /// Starting learning rate.
    initial_lr: f64,
    /// Terminal learning rate (held after `decay_steps`).
    final_lr: f64,
    /// Number of steps over which the linear ramp is applied.
    decay_steps: u64,
}

161impl LinearDecayLR {
162    /// Create a linear-decay scheduler.
163    ///
164    /// # Arguments
165    ///
166    /// * `initial_lr` -- Rate at step 0.
167    /// * `final_lr` -- Rate from step `decay_steps` onward.
168    /// * `decay_steps` -- Length of the linear ramp in steps.
169    pub fn new(initial_lr: f64, final_lr: f64, decay_steps: u64) -> Self {
170        Self {
171            initial_lr,
172            final_lr,
173            decay_steps,
174        }
175    }
176}
177
178impl LRScheduler for LinearDecayLR {
179    #[inline]
180    fn learning_rate(&mut self, step: u64, _current_loss: f64) -> f64 {
181        let t = if self.decay_steps == 0 {
182            1.0
183        } else {
184            (step as f64 / self.decay_steps as f64).min(1.0)
185        };
186        self.initial_lr - (self.initial_lr - self.final_lr) * t
187    }
188
189    fn reset(&mut self) {
190        // Stateless -- nothing to reset.
191    }
192}
193
// ---------------------------------------------------------------------------
// 3. ExponentialDecayLR
// ---------------------------------------------------------------------------

/// Multiplicative exponential decay: `lr = initial_lr * gamma^step`.
///
/// The rate is floored at `1e-8` to prevent numerical underflow that would
/// effectively freeze learning. For a half-life of *h* steps, set
/// `gamma = 0.5^(1/h)`.
///
/// # Example
///
/// ```ignore
/// use irithyll::ensemble::lr_schedule::{LRScheduler, ExponentialDecayLR};
///
/// let mut sched = ExponentialDecayLR::new(1.0, 0.9);
/// assert!(crate::math::abs((sched.learning_rate(0, 0.0) - 1.0)) < 1e-12);
/// assert!(crate::math::abs((sched.learning_rate(1, 0.0) - 0.9)) < 1e-12);
/// assert!(crate::math::abs((sched.learning_rate(2, 0.0) - 0.81)) < 1e-12);
/// ```
#[derive(Clone, Debug)]
pub struct ExponentialDecayLR {
    /// Learning rate at step 0.
    initial_lr: f64,
    /// Per-step multiplicative factor (typically in (0, 1); by the formula,
    /// a gamma above 1 would grow the rate instead of decaying it).
    gamma: f64,
}

222impl ExponentialDecayLR {
223    /// Create an exponential-decay scheduler.
224    ///
225    /// # Arguments
226    ///
227    /// * `initial_lr` -- Rate at step 0.
228    /// * `gamma` -- Multiplicative decay factor applied each step.
229    pub fn new(initial_lr: f64, gamma: f64) -> Self {
230        Self { initial_lr, gamma }
231    }
232}
233
234impl LRScheduler for ExponentialDecayLR {
235    #[inline]
236    fn learning_rate(&mut self, step: u64, _current_loss: f64) -> f64 {
237        crate::math::fmax(
238            self.initial_lr * crate::math::powi(self.gamma, step as i32),
239            1e-8,
240        )
241    }
242
243    fn reset(&mut self) {
244        // Stateless -- nothing to reset.
245    }
246}
247
// ---------------------------------------------------------------------------
// 4. CosineAnnealingLR
// ---------------------------------------------------------------------------

/// Periodic cosine annealing between `max_lr` and `min_lr`.
///
/// The learning rate traces one full cosine wave per `period` steps:
///
/// ```text
/// lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + cos(2 * pi * (step % period) / period))
/// ```
///
/// Within each period the rate starts at `max_lr`, descends to `min_lr` at
/// the midpoint, and climbs back up to `max_lr` by the period boundary, so
/// the schedule is continuous across cycles. The periodic return to
/// `max_lr` plays the role of a warm restart, which can help the ensemble
/// escape local plateaus in a streaming setting.
///
/// # Example
///
/// ```ignore
/// use irithyll::ensemble::lr_schedule::{LRScheduler, CosineAnnealingLR};
///
/// let mut sched = CosineAnnealingLR::new(0.1, 0.01, 100);
/// // Period start → max_lr.
/// assert!(crate::math::abs((sched.learning_rate(0, 0.0) - 0.1)) < 1e-12);
/// // Midpoint → min_lr.
/// assert!(crate::math::abs((sched.learning_rate(50, 0.0) - 0.01)) < 1e-12);
/// // Full period → back to max_lr.
/// assert!(crate::math::abs((sched.learning_rate(100, 0.0) - 0.1)) < 1e-12);
/// ```
#[derive(Clone, Debug)]
pub struct CosineAnnealingLR {
    /// Peak learning rate (at the start and end of each period).
    max_lr: f64,
    /// Trough learning rate (at the midpoint of each period).
    min_lr: f64,
    /// Number of steps per cosine cycle.
    period: u64,
}

288impl CosineAnnealingLR {
289    /// Create a cosine-annealing scheduler.
290    ///
291    /// # Arguments
292    ///
293    /// * `max_lr` -- Rate at the start (and end) of each cosine period.
294    /// * `min_lr` -- Rate at the midpoint of each period.
295    /// * `period` -- Length of one cosine cycle in steps.
296    pub fn new(max_lr: f64, min_lr: f64, period: u64) -> Self {
297        Self {
298            max_lr,
299            min_lr,
300            period,
301        }
302    }
303}
304
305impl LRScheduler for CosineAnnealingLR {
306    #[inline]
307    fn learning_rate(&mut self, step: u64, _current_loss: f64) -> f64 {
308        let phase = if self.period == 0 {
309            0.0
310        } else {
311            (step % self.period) as f64 / self.period as f64
312        };
313        self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1.0 + crate::math::cos(2.0 * PI * phase))
314    }
315
316    fn reset(&mut self) {
317        // Stateless -- nothing to reset.
318    }
319}
320
// ---------------------------------------------------------------------------
// 5. PlateauLR
// ---------------------------------------------------------------------------

/// Reduce learning rate when the loss plateaus.
///
/// Monitors `current_loss` and reduces the rate by `factor` whenever the loss
/// has not improved for `patience` consecutive steps. "Improved" means the
/// new loss is strictly less than the best-seen loss. The reduction fires on
/// the call at which the stall counter first *exceeds* `patience`, i.e. on
/// the `patience + 1`-th consecutive non-improving call, after which the
/// counter restarts from zero.
///
/// The rate is floored at `min_lr` to prevent it from vanishing entirely.
///
/// # Example
///
/// ```ignore
/// use irithyll::ensemble::lr_schedule::{LRScheduler, PlateauLR};
///
/// let mut sched = PlateauLR::new(0.1, 0.5, 3, 0.001);
///
/// // Improving loss -- rate stays at 0.1.
/// assert!(crate::math::abs((sched.learning_rate(0, 1.0) - 0.1)) < 1e-12);
/// assert!(crate::math::abs((sched.learning_rate(1, 0.9) - 0.1)) < 1e-12);
///
/// // Stagnating loss for patience=3 steps.
/// assert!(crate::math::abs((sched.learning_rate(2, 0.95) - 0.1)) < 1e-12);
/// assert!(crate::math::abs((sched.learning_rate(3, 0.95) - 0.1)) < 1e-12);
/// assert!(crate::math::abs((sched.learning_rate(4, 0.95) - 0.1)) < 1e-12);
///
/// // Patience exhausted -- rate drops to 0.1 * 0.5 = 0.05.
/// assert!(crate::math::abs((sched.learning_rate(5, 0.95) - 0.05)) < 1e-12);
/// ```
#[derive(Clone, Debug)]
pub struct PlateauLR {
    /// Starting learning rate.
    initial_lr: f64,
    /// Multiplicative factor applied when patience is exhausted (0 < factor < 1).
    factor: f64,
    /// Number of non-improving steps before a reduction.
    patience: u64,
    /// Minimum learning rate floor.
    min_lr: f64,

    // -- internal state --
    /// Best loss observed so far.
    best_loss: f64,
    /// Number of consecutive steps without improvement.
    steps_without_improvement: u64,
    /// The current (possibly reduced) learning rate.
    current_lr: f64,
}

372impl PlateauLR {
373    /// Create a plateau-aware scheduler.
374    ///
375    /// # Arguments
376    ///
377    /// * `initial_lr` -- Starting learning rate.
378    /// * `factor` -- Multiplicative reduction factor (e.g., 0.5 halves the rate).
379    /// * `patience` -- Number of non-improving steps to tolerate before reducing.
380    /// * `min_lr` -- Floor below which the rate will not be reduced.
381    pub fn new(initial_lr: f64, factor: f64, patience: u64, min_lr: f64) -> Self {
382        Self {
383            initial_lr,
384            factor,
385            patience,
386            min_lr,
387            best_loss: f64::INFINITY,
388            steps_without_improvement: 0,
389            current_lr: initial_lr,
390        }
391    }
392}
393
394impl LRScheduler for PlateauLR {
395    fn learning_rate(&mut self, _step: u64, current_loss: f64) -> f64 {
396        if current_loss < self.best_loss {
397            // Improvement -- record and reset counter.
398            self.best_loss = current_loss;
399            self.steps_without_improvement = 0;
400        } else {
401            self.steps_without_improvement += 1;
402
403            if self.steps_without_improvement > self.patience {
404                self.current_lr = (self.current_lr * self.factor).max(self.min_lr);
405                self.steps_without_improvement = 0;
406            }
407        }
408
409        self.current_lr
410    }
411
412    fn reset(&mut self) {
413        self.best_loss = f64::INFINITY;
414        self.steps_without_improvement = 0;
415        self.current_lr = self.initial_lr;
416    }
417}
418
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use alloc::boxed::Box;
    use alloc::vec;
    use alloc::vec::Vec;

    // -- ConstantLR --------------------------------------------------------

    /// Constant scheduler always returns the configured rate regardless of
    /// step or loss.
    #[test]
    fn test_constant_lr() {
        let mut sched = ConstantLR::new(0.05);

        // The loss argument (999.0) is arbitrary: ConstantLR ignores it.
        for step in 0..100 {
            let lr = sched.learning_rate(step, 999.0);
            assert!(
                (lr - 0.05).abs() < f64::EPSILON,
                "ConstantLR should always return 0.05, got {} at step {}",
                lr,
                step,
            );
        }
    }

    // -- LinearDecayLR -----------------------------------------------------

    /// Linear decay interpolates correctly between initial and final.
    #[test]
    fn test_linear_decay() {
        let mut sched = LinearDecayLR::new(0.1, 0.01, 100);

        let lr0 = sched.learning_rate(0, 0.0);
        assert!(
            (lr0 - 0.1).abs() < 1e-12,
            "step 0 should be initial_lr (0.1), got {}",
            lr0,
        );

        let lr50 = sched.learning_rate(50, 0.0);
        // Halfway through the ramp: t = 0.5.
        let expected_50 = 0.1 - (0.1 - 0.01) * 0.5;
        assert!(
            (lr50 - expected_50).abs() < 1e-12,
            "step 50 should be {}, got {}",
            expected_50,
            lr50,
        );

        let lr100 = sched.learning_rate(100, 0.0);
        assert!(
            (lr100 - 0.01).abs() < 1e-12,
            "step 100 should be final_lr (0.01), got {}",
            lr100,
        );
    }

    /// Linear decay clamps at final_lr for steps beyond decay_steps.
    #[test]
    fn test_linear_decay_clamps() {
        let mut sched = LinearDecayLR::new(0.1, 0.01, 50);

        let lr_before = sched.learning_rate(50, 0.0);
        let lr_after = sched.learning_rate(200, 0.0);
        assert!(
            (lr_before - 0.01).abs() < 1e-12,
            "at decay_steps should be final_lr, got {}",
            lr_before,
        );
        assert!(
            (lr_after - 0.01).abs() < 1e-12,
            "beyond decay_steps should still be final_lr, got {}",
            lr_after,
        );
    }

    // -- ExponentialDecayLR ------------------------------------------------

    /// Exponential decay follows gamma^step correctly.
    #[test]
    fn test_exponential_decay() {
        let mut sched = ExponentialDecayLR::new(1.0, 0.9);

        let lr0 = sched.learning_rate(0, 0.0);
        assert!(
            (lr0 - 1.0).abs() < 1e-12,
            "step 0 should be initial_lr (1.0), got {}",
            lr0,
        );

        let lr1 = sched.learning_rate(1, 0.0);
        assert!(
            (lr1 - 0.9).abs() < 1e-12,
            "step 1 should be 0.9, got {}",
            lr1,
        );

        let lr2 = sched.learning_rate(2, 0.0);
        assert!(
            (lr2 - 0.81).abs() < 1e-12,
            "step 2 should be 0.81, got {}",
            lr2,
        );

        let lr10 = sched.learning_rate(10, 0.0);
        let expected_10 = 0.9_f64.powi(10);
        // Looser tolerance: expected value comes from std powi while the
        // scheduler uses crate::math::powi.
        assert!(
            (lr10 - expected_10).abs() < 1e-10,
            "step 10 should be {}, got {}",
            expected_10,
            lr10,
        );
    }

    /// Exponential decay floors at 1e-8, never reaching zero.
    #[test]
    fn test_exponential_floor() {
        let mut sched = ExponentialDecayLR::new(1.0, 0.01);

        // After enough steps, gamma^step would be astronomically small.
        let lr = sched.learning_rate(10_000, 0.0);
        assert!(
            lr >= 1e-8,
            "exponential decay should floor at 1e-8, got {}",
            lr,
        );
        assert!(
            (lr - 1e-8).abs() < 1e-15,
            "at extreme steps the rate should equal the floor, got {}",
            lr,
        );
    }

    // -- CosineAnnealingLR -------------------------------------------------

    /// Cosine annealing hits max at period boundaries and min at midpoints.
    #[test]
    fn test_cosine_annealing() {
        let mut sched = CosineAnnealingLR::new(0.1, 0.01, 100);

        // At step 0 → max_lr.
        let lr0 = sched.learning_rate(0, 0.0);
        assert!(
            (lr0 - 0.1).abs() < 1e-12,
            "period start should be max_lr (0.1), got {}",
            lr0,
        );

        // At step 50 → min_lr (cos(pi) = -1).
        let lr50 = sched.learning_rate(50, 0.0);
        assert!(
            (lr50 - 0.01).abs() < 1e-12,
            "period midpoint should be min_lr (0.01), got {}",
            lr50,
        );

        // At step 25 → halfway between max and min.
        let lr25 = sched.learning_rate(25, 0.0);
        let expected_25 = 0.01 + 0.5 * (0.1 - 0.01) * (1.0 + (2.0 * PI * 0.25).cos());
        assert!(
            (lr25 - expected_25).abs() < 1e-12,
            "quarter-period should be {}, got {}",
            expected_25,
            lr25,
        );
    }

    /// Cosine annealing wraps correctly at period boundaries.
    #[test]
    fn test_cosine_boundaries() {
        let mut sched = CosineAnnealingLR::new(0.1, 0.01, 100);

        let at_boundary = sched.learning_rate(100, 0.0);
        assert!(
            (at_boundary - 0.1).abs() < 1e-12,
            "step==period should wrap to max_lr, got {}",
            at_boundary,
        );

        let second_mid = sched.learning_rate(150, 0.0);
        assert!(
            (second_mid - 0.01).abs() < 1e-12,
            "second period midpoint should be min_lr, got {}",
            second_mid,
        );
    }

    // -- PlateauLR ---------------------------------------------------------

    /// Plateau scheduler reduces the rate after patience non-improving steps.
    #[test]
    fn test_plateau_reduces() {
        let mut sched = PlateauLR::new(0.1, 0.5, 3, 0.001);

        // First call sets best_loss = 1.0.
        let lr = sched.learning_rate(0, 1.0);
        assert!(
            (lr - 0.1).abs() < 1e-12,
            "initial rate should be 0.1, got {}",
            lr
        );

        // Three non-improving steps (patience = 3).
        sched.learning_rate(1, 1.0); // counter: 1
        sched.learning_rate(2, 1.0); // counter: 2
        sched.learning_rate(3, 1.0); // counter: 3

        // Fourth non-improving step exceeds patience → reduce.
        let lr_reduced = sched.learning_rate(4, 1.0);
        assert!(
            (lr_reduced - 0.05).abs() < 1e-12,
            "after patience exceeded, rate should be 0.1*0.5 = 0.05, got {}",
            lr_reduced,
        );
    }

    /// Plateau scheduler resets its counter when loss improves.
    #[test]
    fn test_plateau_improvement_resets() {
        let mut sched = PlateauLR::new(0.1, 0.5, 2, 0.001);

        // Establish baseline.
        sched.learning_rate(0, 1.0);

        // Two non-improving steps (counter: 1, 2).
        sched.learning_rate(1, 1.5);
        sched.learning_rate(2, 1.5);

        // Now improve -- counter resets to 0.
        sched.learning_rate(3, 0.5);

        // One non-improving step after improvement (counter: 1).
        sched.learning_rate(4, 0.6);

        // Second non-improving step (counter: 2, still <= patience=2).
        let lr = sched.learning_rate(5, 0.6);
        assert!(
            (lr - 0.1).abs() < 1e-12,
            "improvement should have reset counter; rate should be 0.1, got {}",
            lr,
        );
    }

    /// Plateau scheduler never drops below min_lr.
    #[test]
    fn test_plateau_min_lr() {
        let mut sched = PlateauLR::new(0.1, 0.1, 0, 0.05);

        // With patience=0, every non-improving step triggers a reduction.
        sched.learning_rate(0, 1.0); // sets best_loss = 1.0

        // Non-improving: counter goes to 1 which exceeds patience (0).
        sched.learning_rate(1, 1.0); // reduce: 0.1 * 0.1 = 0.01 → clamped to 0.05

        let lr = sched.learning_rate(2, 1.0);
        assert!(
            lr >= 0.05 - 1e-12,
            "rate should never drop below min_lr (0.05), got {}",
            lr,
        );
    }

    /// Plateau reset restores the scheduler to its initial state.
    #[test]
    fn test_plateau_reset() {
        let mut sched = PlateauLR::new(0.1, 0.5, 1, 0.001);

        // Drive the rate down.
        sched.learning_rate(0, 1.0);
        sched.learning_rate(1, 1.0);
        sched.learning_rate(2, 1.0);

        // The rate should have been reduced at least once.
        let lr_before_reset = sched.current_lr;
        assert!(
            lr_before_reset < 0.1,
            "rate should have decreased before reset, got {}",
            lr_before_reset,
        );

        sched.reset();

        // A worse loss (10.0) after reset must not matter: best_loss was
        // cleared back to infinity, so this call counts as an improvement.
        let lr_after = sched.learning_rate(0, 10.0);
        assert!(
            (lr_after - 0.1).abs() < 1e-12,
            "after reset, rate should be back to initial_lr (0.1), got {}",
            lr_after,
        );
    }

    // -- Cross-cutting property tests --------------------------------------

    /// Every scheduler must return a strictly positive learning rate for any
    /// non-negative step and finite loss value.
    #[test]
    fn test_all_positive() {
        let mut schedulers: Vec<Box<dyn LRScheduler>> = vec![
            Box::new(ConstantLR::new(0.05)),
            Box::new(LinearDecayLR::new(0.1, 0.001, 100)),
            Box::new(ExponentialDecayLR::new(1.0, 0.99)),
            Box::new(CosineAnnealingLR::new(0.1, 0.001, 50)),
            Box::new(PlateauLR::new(0.1, 0.5, 5, 0.001)),
        ];

        for (i, sched) in schedulers.iter_mut().enumerate() {
            for step in 0..500 {
                let lr = sched.learning_rate(step, 1.0);
                assert!(
                    lr > 0.0,
                    "scheduler {} returned non-positive lr {} at step {}",
                    i,
                    lr,
                    step,
                );
                assert!(
                    lr.is_finite(),
                    "scheduler {} returned non-finite lr at step {}",
                    i,
                    step,
                );
            }
        }
    }
}