irithyll_core/ensemble/adaptive.rs
1//! Adaptive learning rate wrapper for SGBT ensembles.
2//!
3//! [`AdaptiveSGBT`] pairs an [`SGBT`] model with an [`LRScheduler`], adjusting
4//! the learning rate before each training step based on the scheduler's policy.
5//! This enables time-varying learning rates -- decay, cosine annealing, plateau
6//! reduction -- without modifying the core ensemble code.
7//!
8//! # Example
9//!
10//! ```text
11//! use irithyll::ensemble::adaptive::AdaptiveSGBT;
12//! use irithyll::ensemble::lr_schedule::ExponentialDecayLR;
13//! use irithyll::SGBTConfig;
14//! use irithyll::learner::StreamingLearner;
15//!
16//! let config = SGBTConfig::builder()
17//! .n_steps(10)
18//! .learning_rate(0.1)
19//! .build()
20//! .unwrap();
21//!
22//! let mut model = AdaptiveSGBT::new(config, ExponentialDecayLR::new(0.1, 0.999));
23//! model.train(&[1.0, 2.0], 3.0);
24//! model.train(&[4.0, 5.0], 6.0);
25//!
26//! // The learning rate adapts over time.
27//! let pred = model.predict(&[1.0, 2.0]);
28//! ```
29
30use alloc::boxed::Box;
31
32use core::fmt;
33
34use crate::ensemble::config::SGBTConfig;
35use crate::ensemble::lr_schedule::LRScheduler;
36use crate::ensemble::SGBT;
37use crate::learner::StreamingLearner;
38use crate::loss::squared::SquaredLoss;
39use crate::loss::Loss;
40use crate::sample::{Observation, SampleRef};
41
/// SGBT ensemble with an attached learning rate scheduler.
///
/// Before each `train_one` call, AdaptiveSGBT:
///
/// 1. Computes the current prediction to estimate loss.
/// 2. Queries the scheduler for the new learning rate.
/// 3. Sets the learning rate on the inner SGBT.
/// 4. Delegates the actual training step.
///
/// This allows any [`LRScheduler`] -- exponential decay, cosine annealing,
/// plateau reduction -- to drive the ensemble's learning rate without touching
/// the core boosting logic.
///
/// # Loss Estimation
///
/// The scheduler receives squared error `(target - prediction)²` as its loss
/// signal. This is computed from the current ensemble prediction *before* the
/// training step, making it a one-step-lagged estimate. This works well for
/// schedulers like [`PlateauLR`](crate::ensemble::lr_schedule::PlateauLR) that
/// smooth over many steps.
pub struct AdaptiveSGBT<L: Loss = SquaredLoss> {
    // The wrapped boosting ensemble that performs the actual training.
    inner: SGBT<L>,
    // Boxed so any scheduler policy can be plugged in at runtime.
    scheduler: Box<dyn LRScheduler>,
    // Number of scheduler steps taken; one per `train_one_obs` call.
    step_count: u64,
    // Squared error of the most recent pre-update prediction (the value
    // last passed to the scheduler). Starts at 0.0 before any training.
    last_loss: f64,
    // Learning rate from the original config, kept for reference only;
    // the live rate lives in the inner SGBT's config.
    base_lr: f64,
}
69
70impl AdaptiveSGBT<SquaredLoss> {
71 /// Create an adaptive SGBT with squared loss (regression).
72 ///
73 /// The initial learning rate is taken from the config and also stored as
74 /// `base_lr` for reference.
75 ///
76 /// ```text
77 /// use irithyll::ensemble::adaptive::AdaptiveSGBT;
78 /// use irithyll::ensemble::lr_schedule::ConstantLR;
79 /// use irithyll::SGBTConfig;
80 ///
81 /// let config = SGBTConfig::builder()
82 /// .n_steps(10)
83 /// .learning_rate(0.05)
84 /// .build()
85 /// .unwrap();
86 /// let model = AdaptiveSGBT::new(config, ConstantLR::new(0.05));
87 /// ```
88 pub fn new(config: SGBTConfig, scheduler: impl LRScheduler + 'static) -> Self {
89 let base_lr = config.learning_rate;
90 Self {
91 inner: SGBT::new(config),
92 scheduler: Box::new(scheduler),
93 step_count: 0,
94 last_loss: 0.0,
95 base_lr,
96 }
97 }
98}
99
100impl<L: Loss> AdaptiveSGBT<L> {
101 /// Create an adaptive SGBT with a specific loss function.
102 ///
103 /// ```text
104 /// use irithyll::ensemble::adaptive::AdaptiveSGBT;
105 /// use irithyll::ensemble::lr_schedule::LinearDecayLR;
106 /// use irithyll::loss::logistic::LogisticLoss;
107 /// use irithyll::SGBTConfig;
108 ///
109 /// let config = SGBTConfig::builder()
110 /// .n_steps(10)
111 /// .learning_rate(0.1)
112 /// .build()
113 /// .unwrap();
114 /// let model = AdaptiveSGBT::with_loss(
115 /// config, LogisticLoss, LinearDecayLR::new(0.1, 0.001, 10_000),
116 /// );
117 /// ```
118 pub fn with_loss(config: SGBTConfig, loss: L, scheduler: impl LRScheduler + 'static) -> Self {
119 let base_lr = config.learning_rate;
120 Self {
121 inner: SGBT::with_loss(config, loss),
122 scheduler: Box::new(scheduler),
123 step_count: 0,
124 last_loss: 0.0,
125 base_lr,
126 }
127 }
128
129 /// Train on a single observation, adapting the learning rate first.
130 ///
131 /// This is the generic version accepting any `Observation` implementor.
132 /// For the `StreamingLearner` trait interface, use `train_one(features, target, weight)`.
133 pub fn train_one_obs(&mut self, sample: &impl Observation) {
134 // Estimate loss from current prediction.
135 let pred = self.inner.predict(sample.features());
136 let err = sample.target() - pred;
137 self.last_loss = err * err;
138
139 // Query scheduler for new learning rate.
140 let lr = self
141 .scheduler
142 .learning_rate(self.step_count, self.last_loss);
143 self.inner.set_learning_rate(lr);
144 self.step_count += 1;
145
146 // Delegate training.
147 self.inner.train_one(sample);
148 }
149
150 /// Current learning rate (as last set by the scheduler).
151 pub fn current_lr(&self) -> f64 {
152 self.inner.config().learning_rate
153 }
154
155 /// The initial learning rate from the original config.
156 pub fn base_lr(&self) -> f64 {
157 self.base_lr
158 }
159
160 /// Total scheduler steps (equal to samples trained).
161 pub fn step_count(&self) -> u64 {
162 self.step_count
163 }
164
165 /// Most recent loss value passed to the scheduler.
166 pub fn last_loss(&self) -> f64 {
167 self.last_loss
168 }
169
170 /// Immutable access to the scheduler.
171 pub fn scheduler(&self) -> &dyn LRScheduler {
172 &*self.scheduler
173 }
174
175 /// Mutable access to the scheduler.
176 pub fn scheduler_mut(&mut self) -> &mut dyn LRScheduler {
177 &mut *self.scheduler
178 }
179
180 /// Immutable access to the inner SGBT model.
181 pub fn inner(&self) -> &SGBT<L> {
182 &self.inner
183 }
184
185 /// Mutable access to the inner SGBT model.
186 pub fn inner_mut(&mut self) -> &mut SGBT<L> {
187 &mut self.inner
188 }
189
190 /// Consume the wrapper and return the inner SGBT model.
191 pub fn into_inner(self) -> SGBT<L> {
192 self.inner
193 }
194
195 /// Predict using the inner SGBT model.
196 pub fn predict(&self, features: &[f64]) -> f64 {
197 self.inner.predict(features)
198 }
199}
200
201// ---------------------------------------------------------------------------
202// StreamingLearner
203// ---------------------------------------------------------------------------
204
205impl<L: Loss> StreamingLearner for AdaptiveSGBT<L> {
206 fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
207 let sample = SampleRef::weighted(features, target, weight);
208 self.train_one_obs(&sample);
209 }
210
211 fn predict(&self, features: &[f64]) -> f64 {
212 self.inner.predict(features)
213 }
214
215 fn n_samples_seen(&self) -> u64 {
216 self.inner.n_samples_seen()
217 }
218
219 fn reset(&mut self) {
220 self.inner.reset();
221 self.scheduler.reset();
222 self.step_count = 0;
223 self.last_loss = 0.0;
224 }
225}
226
227// ---------------------------------------------------------------------------
228// Trait impls
229// ---------------------------------------------------------------------------
230
231impl<L: Loss> fmt::Debug for AdaptiveSGBT<L> {
232 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
233 f.debug_struct("AdaptiveSGBT")
234 .field("inner", &self.inner)
235 .field("step_count", &self.step_count)
236 .field("last_loss", &self.last_loss)
237 .field("base_lr", &self.base_lr)
238 .field("current_lr", &self.current_lr())
239 .finish()
240 }
241}
242
243// ---------------------------------------------------------------------------
244// Tests
245// ---------------------------------------------------------------------------
246
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ensemble::lr_schedule::{ConstantLR, ExponentialDecayLR, PlateauLR};
    use alloc::boxed::Box;

    /// Small shared config: 5 boosting steps, initial LR 0.1.
    fn test_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .build()
            .unwrap()
    }

    #[test]
    fn construction_and_initial_state() {
        let model = AdaptiveSGBT::new(test_config(), ConstantLR::new(0.1));
        assert_eq!(model.step_count(), 0);
        assert_eq!(model.n_samples_seen(), 0);
        assert!((model.base_lr() - 0.1).abs() < 1e-12);
        assert!((model.current_lr() - 0.1).abs() < 1e-12);
    }

    #[test]
    fn train_increments_step_count() {
        let mut model = AdaptiveSGBT::new(test_config(), ConstantLR::new(0.1));
        model.train(&[1.0, 2.0], 3.0);
        model.train(&[4.0, 5.0], 6.0);
        assert_eq!(model.step_count(), 2);
        assert_eq!(model.n_samples_seen(), 2);
    }

    #[test]
    fn exponential_decay_reduces_lr() {
        let mut model = AdaptiveSGBT::new(test_config(), ExponentialDecayLR::new(0.1, 0.9));

        // Train a few steps -- LR should decrease.
        for i in 0..10 {
            model.train(&[i as f64, (i * 2) as f64], i as f64);
        }

        // After 10 steps with gamma=0.9, LR should be much less than initial.
        let lr = model.current_lr();
        assert!(lr < 0.1, "LR should have decayed from 0.1, got {}", lr);
    }

    #[test]
    fn predict_is_finite() {
        let mut model = AdaptiveSGBT::new(test_config(), ConstantLR::new(0.1));
        model.train(&[1.0, 2.0], 3.0);
        let pred = model.predict(&[1.0, 2.0]);
        assert!(
            pred.is_finite(),
            "prediction should be finite, got {}",
            pred
        );
    }

    #[test]
    fn reset_clears_all_state() {
        let mut model = AdaptiveSGBT::new(test_config(), ConstantLR::new(0.1));
        model.train(&[1.0], 2.0);
        model.train(&[3.0], 4.0);
        assert_eq!(model.step_count(), 2);

        model.reset();
        assert_eq!(model.step_count(), 0);
        assert_eq!(model.n_samples_seen(), 0);
        assert!((model.last_loss()).abs() < 1e-12);
    }

    #[test]
    fn as_streaming_learner_trait_object() {
        let model = AdaptiveSGBT::new(test_config(), ConstantLR::new(0.1));
        let mut boxed: Box<dyn StreamingLearner> = Box::new(model);
        boxed.train(&[1.0, 2.0], 3.0);
        assert_eq!(boxed.n_samples_seen(), 1);
        let pred = boxed.predict(&[1.0, 2.0]);
        assert!(pred.is_finite());
    }

    #[test]
    fn plateau_lr_reduces_on_stagnation() {
        // PlateauLR with patience=5, factor=0.5, min_lr=1e-6
        let scheduler = PlateauLR::new(0.1, 0.5, 5, 1e-6);
        let mut model = AdaptiveSGBT::new(test_config(), scheduler);

        // Feed constant data -- the ensemble converges on the single
        // point, so the loss signal stagnates within a few steps.
        for _ in 0..50 {
            model.train(&[1.0, 1.0], 1.0);
        }

        // After stagnation, LR should have been reduced at least once.
        // BUG FIX: the previous assertion used `<= 0.1`, which is
        // trivially true (the initial LR is exactly 0.1) and would pass
        // even if PlateauLR never fired. Strict `<` actually verifies a
        // reduction happened.
        let lr = model.current_lr();
        assert!(lr < 0.1, "PlateauLR should have reduced LR, got {}", lr);
    }
}
343}