// irithyll_core/ensemble/adaptive.rs
1//! Adaptive learning rate wrapper for SGBT ensembles.
2//!
3//! [`AdaptiveSGBT`] pairs an [`SGBT`] model with an [`LRScheduler`], adjusting
4//! the learning rate before each training step based on the scheduler's policy.
5//! This enables time-varying learning rates -- decay, cosine annealing, plateau
6//! reduction -- without modifying the core ensemble code.
7//!
8//! # Example
9//!
10//! ```
11//! use irithyll::ensemble::adaptive::AdaptiveSGBT;
12//! use irithyll::ensemble::lr_schedule::ExponentialDecayLR;
13//! use irithyll::SGBTConfig;
14//! use irithyll::learner::StreamingLearner;
15//!
16//! let config = SGBTConfig::builder()
17//!     .n_steps(10)
18//!     .learning_rate(0.1)
19//!     .build()
20//!     .unwrap();
21//!
22//! let mut model = AdaptiveSGBT::new(config, ExponentialDecayLR::new(0.1, 0.999));
23//! model.train(&[1.0, 2.0], 3.0);
24//! model.train(&[4.0, 5.0], 6.0);
25//!
26//! // The learning rate adapts over time.
27//! let pred = model.predict(&[1.0, 2.0]);
28//! ```
29
30use alloc::boxed::Box;
31
32use core::fmt;
33
34use crate::ensemble::config::SGBTConfig;
35use crate::ensemble::lr_schedule::LRScheduler;
36use crate::ensemble::SGBT;
37use crate::learner::StreamingLearner;
38use crate::loss::squared::SquaredLoss;
39use crate::loss::Loss;
40use crate::sample::{Observation, SampleRef};
41
/// SGBT ensemble with an attached learning rate scheduler.
///
/// Before each `train_one` call, AdaptiveSGBT:
///
/// 1. Computes the current prediction to estimate loss.
/// 2. Queries the scheduler for the new learning rate.
/// 3. Sets the learning rate on the inner SGBT.
/// 4. Delegates the actual training step.
///
/// This allows any [`LRScheduler`] -- exponential decay, cosine annealing,
/// plateau reduction -- to drive the ensemble's learning rate without touching
/// the core boosting logic.
///
/// # Loss Estimation
///
/// The scheduler receives squared error `(target - prediction)²` as its loss
/// signal. This is computed from the current ensemble prediction *before* the
/// training step, making it a one-step-lagged estimate. This works well for
/// schedulers like [`PlateauLR`](crate::ensemble::lr_schedule::PlateauLR) that
/// smooth over many steps.
pub struct AdaptiveSGBT<L: Loss = SquaredLoss> {
    /// The wrapped boosted ensemble that performs the actual training steps.
    inner: SGBT<L>,
    /// Learning-rate policy; boxed trait object so any scheduler type can be
    /// attached at construction time.
    scheduler: Box<dyn LRScheduler>,
    /// Number of training steps taken so far; passed to the scheduler on
    /// every step.
    step_count: u64,
    /// Squared error of the most recent pre-training prediction (the value
    /// last handed to the scheduler).
    last_loss: f64,
    /// The initial learning rate captured from the config at construction.
    base_lr: f64,
}
69
70impl AdaptiveSGBT<SquaredLoss> {
71    /// Create an adaptive SGBT with squared loss (regression).
72    ///
73    /// The initial learning rate is taken from the config and also stored as
74    /// `base_lr` for reference.
75    ///
76    /// ```
77    /// use irithyll::ensemble::adaptive::AdaptiveSGBT;
78    /// use irithyll::ensemble::lr_schedule::ConstantLR;
79    /// use irithyll::SGBTConfig;
80    ///
81    /// let config = SGBTConfig::builder()
82    ///     .n_steps(10)
83    ///     .learning_rate(0.05)
84    ///     .build()
85    ///     .unwrap();
86    /// let model = AdaptiveSGBT::new(config, ConstantLR::new(0.05));
87    /// ```
88    pub fn new(config: SGBTConfig, scheduler: impl LRScheduler + 'static) -> Self {
89        let base_lr = config.learning_rate;
90        Self {
91            inner: SGBT::new(config),
92            scheduler: Box::new(scheduler),
93            step_count: 0,
94            last_loss: 0.0,
95            base_lr,
96        }
97    }
98}
99
100impl<L: Loss> AdaptiveSGBT<L> {
101    /// Create an adaptive SGBT with a specific loss function.
102    ///
103    /// ```
104    /// use irithyll::ensemble::adaptive::AdaptiveSGBT;
105    /// use irithyll::ensemble::lr_schedule::LinearDecayLR;
106    /// use irithyll::loss::logistic::LogisticLoss;
107    /// use irithyll::SGBTConfig;
108    ///
109    /// let config = SGBTConfig::builder()
110    ///     .n_steps(10)
111    ///     .learning_rate(0.1)
112    ///     .build()
113    ///     .unwrap();
114    /// let model = AdaptiveSGBT::with_loss(
115    ///     config, LogisticLoss, LinearDecayLR::new(0.1, 0.001, 10_000),
116    /// );
117    /// ```
118    pub fn with_loss(config: SGBTConfig, loss: L, scheduler: impl LRScheduler + 'static) -> Self {
119        let base_lr = config.learning_rate;
120        Self {
121            inner: SGBT::with_loss(config, loss),
122            scheduler: Box::new(scheduler),
123            step_count: 0,
124            last_loss: 0.0,
125            base_lr,
126        }
127    }
128
129    /// Train on a single observation, adapting the learning rate first.
130    ///
131    /// This is the generic version accepting any `Observation` implementor.
132    /// For the `StreamingLearner` trait interface, use `train_one(features, target, weight)`.
133    pub fn train_one_obs(&mut self, sample: &impl Observation) {
134        // Estimate loss from current prediction.
135        let pred = self.inner.predict(sample.features());
136        let err = sample.target() - pred;
137        self.last_loss = err * err;
138
139        // Query scheduler for new learning rate.
140        let lr = self
141            .scheduler
142            .learning_rate(self.step_count, self.last_loss);
143        self.inner.set_learning_rate(lr);
144        self.step_count += 1;
145
146        // Delegate training.
147        self.inner.train_one(sample);
148    }
149
150    /// Current learning rate (as last set by the scheduler).
151    pub fn current_lr(&self) -> f64 {
152        self.inner.config().learning_rate
153    }
154
155    /// The initial learning rate from the original config.
156    pub fn base_lr(&self) -> f64 {
157        self.base_lr
158    }
159
160    /// Total scheduler steps (equal to samples trained).
161    pub fn step_count(&self) -> u64 {
162        self.step_count
163    }
164
165    /// Most recent loss value passed to the scheduler.
166    pub fn last_loss(&self) -> f64 {
167        self.last_loss
168    }
169
170    /// Immutable access to the scheduler.
171    pub fn scheduler(&self) -> &dyn LRScheduler {
172        &*self.scheduler
173    }
174
175    /// Mutable access to the scheduler.
176    pub fn scheduler_mut(&mut self) -> &mut dyn LRScheduler {
177        &mut *self.scheduler
178    }
179
180    /// Immutable access to the inner SGBT model.
181    pub fn inner(&self) -> &SGBT<L> {
182        &self.inner
183    }
184
185    /// Mutable access to the inner SGBT model.
186    pub fn inner_mut(&mut self) -> &mut SGBT<L> {
187        &mut self.inner
188    }
189
190    /// Consume the wrapper and return the inner SGBT model.
191    pub fn into_inner(self) -> SGBT<L> {
192        self.inner
193    }
194
195    /// Predict using the inner SGBT model.
196    pub fn predict(&self, features: &[f64]) -> f64 {
197        self.inner.predict(features)
198    }
199}
200
201// ---------------------------------------------------------------------------
202// StreamingLearner
203// ---------------------------------------------------------------------------
204
205impl<L: Loss> StreamingLearner for AdaptiveSGBT<L> {
206    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
207        let sample = SampleRef::weighted(features, target, weight);
208        self.train_one_obs(&sample);
209    }
210
211    fn predict(&self, features: &[f64]) -> f64 {
212        self.inner.predict(features)
213    }
214
215    fn n_samples_seen(&self) -> u64 {
216        self.inner.n_samples_seen()
217    }
218
219    fn reset(&mut self) {
220        self.inner.reset();
221        self.scheduler.reset();
222        self.step_count = 0;
223        self.last_loss = 0.0;
224    }
225}
226
227// ---------------------------------------------------------------------------
228// Trait impls
229// ---------------------------------------------------------------------------
230
231impl<L: Loss> fmt::Debug for AdaptiveSGBT<L> {
232    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
233        f.debug_struct("AdaptiveSGBT")
234            .field("inner", &self.inner)
235            .field("step_count", &self.step_count)
236            .field("last_loss", &self.last_loss)
237            .field("base_lr", &self.base_lr)
238            .field("current_lr", &self.current_lr())
239            .finish()
240    }
241}
242
243// ---------------------------------------------------------------------------
244// Tests
245// ---------------------------------------------------------------------------
246
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ensemble::lr_schedule::{ConstantLR, ExponentialDecayLR, PlateauLR};
    use alloc::boxed::Box;

    /// Shared fixture: a tiny 5-step config with learning rate 0.1.
    fn small_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .build()
            .unwrap()
    }

    #[test]
    fn construction_and_initial_state() {
        let model = AdaptiveSGBT::new(small_config(), ConstantLR::new(0.1));
        assert_eq!(model.step_count(), 0);
        assert_eq!(model.n_samples_seen(), 0);
        assert!((model.base_lr() - 0.1).abs() < 1e-12);
        assert!((model.current_lr() - 0.1).abs() < 1e-12);
    }

    #[test]
    fn train_increments_step_count() {
        let mut model = AdaptiveSGBT::new(small_config(), ConstantLR::new(0.1));
        model.train(&[1.0, 2.0], 3.0);
        model.train(&[4.0, 5.0], 6.0);
        // Both counters advance in lockstep with training calls.
        assert_eq!(model.step_count(), 2);
        assert_eq!(model.n_samples_seen(), 2);
    }

    #[test]
    fn exponential_decay_reduces_lr() {
        let mut model = AdaptiveSGBT::new(small_config(), ExponentialDecayLR::new(0.1, 0.9));

        // Ten steps with gamma = 0.9 should pull the rate well below the
        // initial 0.1.
        for step in 0..10u32 {
            let x = f64::from(step);
            model.train(&[x, x * 2.0], x);
        }

        let lr = model.current_lr();
        assert!(lr < 0.1, "LR should have decayed from 0.1, got {}", lr);
    }

    #[test]
    fn predict_is_finite() {
        let mut model = AdaptiveSGBT::new(small_config(), ConstantLR::new(0.1));
        model.train(&[1.0, 2.0], 3.0);
        let pred = model.predict(&[1.0, 2.0]);
        assert!(pred.is_finite(), "prediction should be finite, got {}", pred);
    }

    #[test]
    fn reset_clears_all_state() {
        let mut model = AdaptiveSGBT::new(small_config(), ConstantLR::new(0.1));
        model.train(&[1.0], 2.0);
        model.train(&[3.0], 4.0);
        assert_eq!(model.step_count(), 2);

        // After reset, all counters and the cached loss go back to zero.
        model.reset();
        assert_eq!(model.step_count(), 0);
        assert_eq!(model.n_samples_seen(), 0);
        assert!((model.last_loss()).abs() < 1e-12);
    }

    #[test]
    fn as_streaming_learner_trait_object() {
        // The wrapper must remain usable behind `dyn StreamingLearner`.
        let mut boxed: Box<dyn StreamingLearner> =
            Box::new(AdaptiveSGBT::new(small_config(), ConstantLR::new(0.1)));
        boxed.train(&[1.0, 2.0], 3.0);
        assert_eq!(boxed.n_samples_seen(), 1);
        assert!(boxed.predict(&[1.0, 2.0]).is_finite());
    }

    #[test]
    fn plateau_lr_reduces_on_stagnation() {
        // PlateauLR with patience=5, factor=0.5, min_lr=1e-6.
        let mut model =
            AdaptiveSGBT::new(small_config(), PlateauLR::new(0.1, 0.5, 5, 1e-6));

        // Constant data makes the loss signal stagnate quickly.
        (0..50).for_each(|_| model.train(&[1.0, 1.0], 1.0));

        let lr = model.current_lr();
        assert!(lr <= 0.1, "PlateauLR should have reduced LR, got {}", lr);
    }
}