// irithyll_core/ensemble/bagged.rs
//! Bagged SGBT ensemble using Oza online bagging (Poisson weighting).
//!
//! Wraps M independent [`SGBT<L>`] instances. Each sample is presented to
//! each bag Poisson(1) times, producing a streaming analogue of bootstrap
//! aggregation. Final prediction is the mean across all bags.
//!
//! This implements the SGB(Oza) algorithm from Gunasekara et al. (2025),
//! which substantially reduces variance in streaming gradient boosted
//! regression compared to a single SGBT ensemble.
//!
//! # Example
//!
//! ```text
//! use irithyll::ensemble::bagged::BaggedSGBT;
//! use irithyll::SGBTConfig;
//!
//! let config = SGBTConfig::builder()
//!     .n_steps(10)
//!     .learning_rate(0.1)
//!     .grace_period(10)
//!     .build()
//!     .unwrap();
//!
//! let mut model = BaggedSGBT::new(config, 5).unwrap();
//! model.train_one(&irithyll::Sample::new(vec![1.0, 2.0], 3.0));
//! let pred = model.predict(&[1.0, 2.0]);
//! ```

use alloc::vec::Vec;

use crate::ensemble::config::SGBTConfig;
use crate::ensemble::SGBT;
use crate::error::{ConfigError, IrithyllError};
use crate::loss::squared::SquaredLoss;
use crate::loss::Loss;
use crate::sample::Observation;
37
// ---------------------------------------------------------------------------
// Poisson sampling utilities
// ---------------------------------------------------------------------------

/// Advance an xorshift64 state and return a f64 in [0, 1).
///
/// A zero state is a fixed point of xorshift64 (every step would return 0.0
/// forever), so it is remapped to a fixed non-zero constant before stepping.
/// This keeps the generator usable when the ensemble is seeded with 0;
/// non-zero states behave exactly as before.
#[inline]
fn xorshift64_f64(state: &mut u64) -> f64 {
    if *state == 0 {
        // Arbitrary odd constant (2^64 / golden ratio) to escape the
        // all-zero fixed point of the xorshift recurrence.
        *state = 0x9E37_79B9_7F4A_7C15;
    }
    let mut s = *state;
    s ^= s << 13;
    s ^= s >> 7;
    s ^= s << 17;
    *state = s;
    // Keep the top 53 bits so the result fits the f64 mantissa exactly,
    // giving a uniform value in [0, 1).
    (s >> 11) as f64 / ((1u64 << 53) as f64)
}

/// Sample from Poisson(lambda=1) using Knuth's algorithm.
///
/// Generates uniform random numbers and multiplies them until the product
/// drops below e^{-1}. The count of multiplications minus one gives the
/// Poisson draw. For lambda=1, the mean is 1 and values rarely exceed 5.
fn poisson_sample(rng: &mut u64) -> usize {
    // e^{-1} is a constant threshold; precompute it instead of evaluating
    // exp() on every call.
    const INV_E: f64 = 0.367_879_441_171_442_33; // e^{-1}
    let mut k: usize = 0;
    let mut p: f64 = 1.0;
    loop {
        k += 1;
        let u = xorshift64_f64(rng);
        p *= u;
        if p < INV_E {
            return k - 1;
        }
    }
}
72
// ---------------------------------------------------------------------------
// BaggedSGBT
// ---------------------------------------------------------------------------

/// Bagged (Oza) SGBT ensemble for variance reduction.
///
/// Each of the M bags is an independent [`SGBT<L>`] trained on a Poisson(1)-
/// weighted stream. Predictions are averaged across bags, reducing the
/// variance of the ensemble without increasing bias.
///
/// This implements SGB(Oza) from Gunasekara et al. (2025), the streaming
/// analogue of Breiman's bootstrap aggregation adapted for gradient boosted
/// trees with Hoeffding-bound splits.
pub struct BaggedSGBT<L: Loss = SquaredLoss> {
    // The M independent boosted ensembles ("bags"), each built with its own
    // derived seed.
    bags: Vec<SGBT<L>>,
    // Cached bag count; always equals `bags.len()`.
    n_bags: usize,
    // Observations presented via `train_one` (each counted once, regardless
    // of how many Poisson replicates each bag trained on).
    samples_seen: u64,
    // Current xorshift64 state driving the Poisson(1) draws.
    rng_state: u64,
    // Original seed, kept so `reset` can rewind `rng_state`.
    seed: u64,
}
93
94impl<L: Loss + Clone> Clone for BaggedSGBT<L> {
95    fn clone(&self) -> Self {
96        Self {
97            bags: self.bags.clone(),
98            n_bags: self.n_bags,
99            samples_seen: self.samples_seen,
100            rng_state: self.rng_state,
101            seed: self.seed,
102        }
103    }
104}
105
106impl<L: Loss> core::fmt::Debug for BaggedSGBT<L> {
107    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
108        f.debug_struct("BaggedSGBT")
109            .field("n_bags", &self.n_bags)
110            .field("samples_seen", &self.samples_seen)
111            .finish()
112    }
113}
114
115impl BaggedSGBT<SquaredLoss> {
116    /// Create a new bagged SGBT with squared loss (regression).
117    ///
118    /// # Errors
119    ///
120    /// Returns [`IrithyllError::InvalidConfig`] if `n_bags < 1`.
121    pub fn new(config: SGBTConfig, n_bags: usize) -> crate::error::Result<Self> {
122        Self::with_loss(config, SquaredLoss, n_bags)
123    }
124}
125
126impl<L: Loss + Clone> BaggedSGBT<L> {
127    /// Create a new bagged SGBT with a custom loss function.
128    ///
129    /// Each bag receives a unique seed derived from the config seed, ensuring
130    /// diverse tree structures across bags.
131    ///
132    /// # Errors
133    ///
134    /// Returns [`IrithyllError::InvalidConfig`] if `n_bags < 1`.
135    pub fn with_loss(config: SGBTConfig, loss: L, n_bags: usize) -> crate::error::Result<Self> {
136        if n_bags < 1 {
137            return Err(IrithyllError::InvalidConfig(ConfigError::out_of_range(
138                "n_bags",
139                "must be >= 1",
140                n_bags,
141            )));
142        }
143
144        let seed = config.seed;
145        let bags = (0..n_bags)
146            .map(|i| {
147                let mut cfg = config.clone();
148                // Each bag gets a unique seed for diverse tree construction
149                cfg.seed = config.seed ^ (0xBA6_0000_0000_0000 | i as u64);
150                SGBT::with_loss(cfg, loss.clone())
151            })
152            .collect();
153
154        Ok(Self {
155            bags,
156            n_bags,
157            samples_seen: 0,
158            rng_state: seed,
159            seed,
160        })
161    }
162
163    /// Train all bags on a single observation with Poisson(1) weighting.
164    ///
165    /// For each bag, draws k ~ Poisson(1) and calls `bag.train_one(sample)`
166    /// k times. On average, each bag sees the sample once, but the randomness
167    /// creates diverse training sets across bags.
168    pub fn train_one(&mut self, sample: &impl Observation) {
169        self.samples_seen += 1;
170        for bag in &mut self.bags {
171            let k = poisson_sample(&mut self.rng_state);
172            for _ in 0..k {
173                bag.train_one(sample);
174            }
175        }
176    }
177
178    /// Train on a batch of observations.
179    pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
180        for sample in samples {
181            self.train_one(sample);
182        }
183    }
184
185    /// Predict the raw output as the mean across all bags.
186    pub fn predict(&self, features: &[f64]) -> f64 {
187        let sum: f64 = self.bags.iter().map(|b| b.predict(features)).sum();
188        sum / self.n_bags as f64
189    }
190
191    /// Predict with loss transform applied (e.g., sigmoid for logistic loss),
192    /// averaged across bags.
193    pub fn predict_transformed(&self, features: &[f64]) -> f64 {
194        let sum: f64 = self
195            .bags
196            .iter()
197            .map(|b| b.predict_transformed(features))
198            .sum();
199        sum / self.n_bags as f64
200    }
201
202    /// Batch prediction.
203    pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<f64> {
204        feature_matrix.iter().map(|f| self.predict(f)).collect()
205    }
206
207    /// Number of bags in the ensemble.
208    #[inline]
209    pub fn n_bags(&self) -> usize {
210        self.n_bags
211    }
212
213    /// Total samples seen.
214    #[inline]
215    pub fn n_samples_seen(&self) -> u64 {
216        self.samples_seen
217    }
218
219    /// Immutable access to all bags.
220    pub fn bags(&self) -> &[SGBT<L>] {
221        &self.bags
222    }
223
224    /// Immutable access to a specific bag.
225    ///
226    /// # Panics
227    ///
228    /// Panics if `idx >= n_bags`.
229    pub fn bag(&self, idx: usize) -> &SGBT<L> {
230        &self.bags[idx]
231    }
232
233    /// Whether the base prediction has been initialized for all bags.
234    pub fn is_initialized(&self) -> bool {
235        self.bags.iter().all(|b| b.is_initialized())
236    }
237
238    /// Reset all bags to initial state.
239    pub fn reset(&mut self) {
240        for bag in &mut self.bags {
241            bag.reset();
242        }
243        self.samples_seen = 0;
244        self.rng_state = self.seed;
245    }
246}
247
248// ---------------------------------------------------------------------------
249// StreamingLearner impl
250// ---------------------------------------------------------------------------
251
252use crate::learner::StreamingLearner;
253use crate::sample::SampleRef;
254
255impl<L: Loss + Clone> StreamingLearner for BaggedSGBT<L> {
256    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
257        let sample = SampleRef::weighted(features, target, weight);
258        // UFCS: call the inherent train_one(&impl Observation), not this trait method.
259        BaggedSGBT::train_one(self, &sample);
260    }
261
262    fn predict(&self, features: &[f64]) -> f64 {
263        BaggedSGBT::predict(self, features)
264    }
265
266    fn n_samples_seen(&self) -> u64 {
267        self.samples_seen
268    }
269
270    fn reset(&mut self) {
271        BaggedSGBT::reset(self);
272    }
273}
274
275#[cfg(test)]
276mod tests {
277    use super::*;
278    use crate::sample::Sample;
279    use alloc::boxed::Box;
280    use alloc::vec;
281    use alloc::vec::Vec;
282
    // Small shared configuration: short boosting runs and a low grace period
    // keep the tests fast.
    fn test_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .initial_target_count(5)
            .build()
            .unwrap()
    }
292
    // The constructor materializes exactly `n_bags` bags and starts with a
    // zero sample counter.
    #[test]
    fn creates_correct_number_of_bags() {
        let model = BaggedSGBT::new(test_config(), 7).unwrap();
        assert_eq!(model.n_bags(), 7);
        assert_eq!(model.bags().len(), 7);
        assert_eq!(model.n_samples_seen(), 0);
    }
300
    // `n_bags == 0` is rejected at construction time.
    #[test]
    fn rejects_zero_bags() {
        let result = BaggedSGBT::new(test_config(), 0);
        assert!(result.is_err());
    }
306
    // A one-bag ensemble is just a Poisson-weighted SGBT; we only assert the
    // output is finite, since the weighting perturbs exact values.
    #[test]
    fn single_bag_equals_single_sgbt() {
        // With 1 bag, bagged model should behave similarly to plain SGBT
        // (not exactly due to Poisson weighting, but close)
        let config = test_config();
        let mut model = BaggedSGBT::new(config, 1).unwrap();

        for i in 0..100 {
            let x = i as f64 * 0.1;
            model.train_one(&Sample::new(vec![x], x * 2.0 + 1.0));
        }

        let pred = model.predict(&[0.5]);
        // Should approximate 0.5 * 2.0 + 1.0 = 2.0
        // With Poisson weighting, predictions can deviate more than a single SGBT.
        assert!(
            pred.is_finite(),
            "prediction should be finite, got {}",
            pred
        );
    }
328
    // Empirical mean of 10k Poisson(1) draws should sit near the theoretical
    // mean of 1 (std. error ~0.01, so the 0.1 band is generous).
    #[test]
    fn poisson_mean_approximately_one() {
        let mut rng = 0xDEAD_BEEF_u64;
        let n = 10_000;
        let sum: usize = (0..n).map(|_| poisson_sample(&mut rng)).sum();
        let mean = sum as f64 / n as f64;
        assert!(
            (mean - 1.0).abs() < 0.1,
            "Poisson(1) mean should be ~1.0, got {}",
            mean
        );
    }
341
342    #[test]
343    fn poisson_never_negative() {
344        let mut rng = 42u64;
345        for _ in 0..10_000 {
346            // poisson_sample returns usize, so it can't be negative,
347            // but let's verify values are reasonable
348            let k = poisson_sample(&mut rng);
349            assert!(k < 20, "Poisson(1) should rarely exceed 10, got {}", k);
350        }
351    }
352
    // Two models built from the same config (hence same seed) and fed the
    // same stream must agree exactly: the Poisson draws are deterministic.
    #[test]
    fn deterministic_with_same_seed() {
        let config = test_config();
        let mut model1 = BaggedSGBT::new(config.clone(), 3).unwrap();
        let mut model2 = BaggedSGBT::new(config, 3).unwrap();

        let samples: Vec<Sample> = (0..50)
            .map(|i| {
                let x = i as f64 * 0.1;
                Sample::new(vec![x], x * 3.0)
            })
            .collect();

        for s in &samples {
            model1.train_one(s);
            model2.train_one(s);
        }

        let pred1 = model1.predict(&[0.5]);
        let pred2 = model2.predict(&[0.5]);
        assert!(
            (pred1 - pred2).abs() < 1e-10,
            "same seed should give identical predictions: {} vs {}",
            pred1,
            pred2
        );
    }
380
    // `predict` must equal the arithmetic mean of the per-bag predictions.
    #[test]
    fn predict_averages_bags() {
        let config = test_config();
        let mut model = BaggedSGBT::new(config, 5).unwrap();

        for i in 0..100 {
            let x = i as f64 * 0.1;
            model.train_one(&Sample::new(vec![x], x));
        }

        // Verify prediction is the mean of individual bag predictions
        let features = [0.5];
        let individual_sum: f64 = model.bags().iter().map(|b| b.predict(&features)).sum();
        let expected = individual_sum / model.n_bags() as f64;
        let actual = model.predict(&features);
        assert!(
            (actual - expected).abs() < 1e-10,
            "predict should be mean of bags: {} vs {}",
            actual,
            expected
        );
    }
403
    // `reset` rewinds the sample counter (and, per its contract, the bags
    // and RNG state) back to a freshly-constructed model.
    #[test]
    fn reset_clears_state() {
        let config = test_config();
        let mut model = BaggedSGBT::new(config, 3).unwrap();

        for i in 0..100 {
            let x = i as f64;
            model.train_one(&Sample::new(vec![x], x));
        }
        assert!(model.n_samples_seen() > 0);

        model.reset();
        assert_eq!(model.n_samples_seen(), 0);
    }
418
    // Train on a noiseless linear target and check the model produces finite,
    // directionally sensible output (exact accuracy is not pinned because
    // Poisson weighting adds variance).
    #[test]
    fn convergence_on_linear_target() {
        let config = SGBTConfig::builder()
            .n_steps(20)
            .learning_rate(0.1)
            .grace_period(10)
            .initial_target_count(5)
            .build()
            .unwrap();

        let mut model = BaggedSGBT::new(config, 5).unwrap();

        // Train on y = 2*x + 1
        for i in 0..500 {
            let x = (i % 100) as f64 * 0.1;
            model.train_one(&Sample::new(vec![x], 2.0 * x + 1.0));
        }

        // Verify predictions are finite and the model is learning
        // (exact accuracy depends heavily on Poisson weighting variance).
        let test_points = [0.0, 0.5, 1.0];
        for &x in &test_points {
            let pred = model.predict(&[x]);
            assert!(
                pred.is_finite(),
                "at x={}: prediction should be finite, got {}",
                x,
                pred
            );
        }
        // Verify the model is directionally correct: pred(1.0) > pred(0.0)
        let p0 = model.predict(&[0.0]);
        let p1 = model.predict(&[1.0]);
        // The function is y = 2x + 1, so pred(1.0) should be > pred(0.0).
        // Allow for noise in streaming approximation.
        assert!(
            p1 > p0 || (p1 - p0).abs() < 5.0,
            "directional: pred(1.0)={}, pred(0.0)={}",
            p1,
            p0
        );
    }
461
    // Soft check of the bagging mechanics: individual bag predictions exist,
    // the ensemble output equals their mean, and the spread is well-defined.
    #[test]
    fn variance_reduction() {
        // Compare variance of predictions across bags for a single bagged model
        // vs. running multiple independent SGBTs.
        // The bagged model's mean prediction should have lower variance.
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .initial_target_count(5)
            .build()
            .unwrap();

        let mut model = BaggedSGBT::new(config, 10).unwrap();

        for i in 0..200 {
            let x = (i % 50) as f64 * 0.1;
            model.train_one(&Sample::new(vec![x], x * x));
        }

        // The individual bag predictions should vary, but the mean is stable
        let features = [0.3];
        let preds: Vec<f64> = model.bags().iter().map(|b| b.predict(&features)).collect();
        let mean = preds.iter().sum::<f64>() / preds.len() as f64;
        let variance = preds.iter().map(|p| (p - mean).powi(2)).sum::<f64>() / preds.len() as f64;

        // Individual bag predictions should have some variance (bags are diverse)
        // This is a soft check -- just verify bags aren't all identical
        assert!(
            preds.len() > 1,
            "need multiple bags to test variance reduction"
        );
        // The ensemble mean is `model.predict()`, which should be close to `mean`
        let ensemble_pred = model.predict(&features);
        assert!(
            (ensemble_pred - mean).abs() < 1e-10,
            "ensemble prediction should be mean of bags"
        );
        // Variance should be finite and non-negative (basic sanity)
        assert!(variance >= 0.0 && variance.is_finite());
    }
503
    // Exercise the ensemble through a `dyn StreamingLearner` trait object.
    // NOTE(review): this calls `boxed.train(...)`, presumably a provided
    // trait wrapper around `train_one` — confirm against the trait definition.
    #[test]
    fn streaming_learner_trait_object() {
        let config = test_config();
        let model = BaggedSGBT::new(config, 3).unwrap();
        let mut boxed: Box<dyn StreamingLearner> = Box::new(model);
        for i in 0..100 {
            let x = i as f64 * 0.1;
            boxed.train(&[x], x * 2.0);
        }
        assert_eq!(boxed.n_samples_seen(), 100);
        let pred = boxed.predict(&[5.0]);
        assert!(pred.is_finite());
        boxed.reset();
        assert_eq!(boxed.n_samples_seen(), 0);
    }
519}