Skip to main content

irithyll_core/ensemble/
bagged.rs

1//! Bagged SGBT ensemble using Oza online bagging (Poisson weighting).
2//!
3//! Wraps M independent [`SGBT<L>`] instances. Each sample is presented to
4//! each bag Poisson(1) times, producing a streaming analogue of bootstrap
5//! aggregation. Final prediction is the mean across all bags.
6//!
7//! This implements the SGB(Oza) algorithm from Gunasekara et al. (2025),
8//! which substantially reduces variance in streaming gradient boosted
9//! regression compared to a single SGBT ensemble.
10//!
11//! # Example
12//!
13//! ```text
14//! use irithyll::ensemble::bagged::BaggedSGBT;
15//! use irithyll::SGBTConfig;
16//!
17//! let config = SGBTConfig::builder()
18//!     .n_steps(10)
19//!     .learning_rate(0.1)
20//!     .grace_period(10)
21//!     .build()
22//!     .unwrap();
23//!
24//! let mut model = BaggedSGBT::new(config, 5).unwrap();
25//! model.train_one(&irithyll::Sample::new(vec![1.0, 2.0], 3.0));
26//! let pred = model.predict(&[1.0, 2.0]);
27//! ```
28
29use alloc::vec::Vec;
30
31use crate::ensemble::config::SGBTConfig;
32use crate::ensemble::SGBT;
33use crate::error::{ConfigError, IrithyllError};
34use crate::loss::squared::SquaredLoss;
35use crate::loss::Loss;
36use crate::sample::Observation;
37
38use crate::rng::xorshift64_f64;
39
40// ---------------------------------------------------------------------------
41// Poisson sampling utilities
42// ---------------------------------------------------------------------------
43
44/// Sample from Poisson(lambda=1) using Knuth's algorithm.
45///
46/// Generates uniform random numbers and multiplies them until the product
47/// drops below e^{-1}. The count of multiplications minus one gives the
48/// Poisson draw. For lambda=1, the mean is 1 and values rarely exceed 5.
49fn poisson_sample(rng: &mut u64) -> usize {
50    let l = crate::math::exp(-1.0_f64); // e^{-1} = 0.36787944...
51    let mut k: usize = 0;
52    let mut p: f64 = 1.0;
53    loop {
54        k += 1;
55        let u = xorshift64_f64(rng);
56        p *= u;
57        if p < l {
58            return k - 1;
59        }
60    }
61}
62
63// ---------------------------------------------------------------------------
64// BaggedSGBT
65// ---------------------------------------------------------------------------
66
/// Bagged (Oza) SGBT ensemble for variance reduction.
///
/// Each of the M bags is an independent [`SGBT<L>`] trained on a Poisson(1)-
/// weighted stream. Predictions are averaged across bags, reducing the
/// variance of the ensemble without increasing bias.
///
/// This implements SGB(Oza) from Gunasekara et al. (2025), the streaming
/// analogue of Breiman's bootstrap aggregation adapted for gradient boosted
/// trees with Hoeffding-bound splits.
pub struct BaggedSGBT<L: Loss = SquaredLoss> {
    // The independent base ensembles, one per bag.
    bags: Vec<SGBT<L>>,
    // Number of bags; fixed at construction and always equal to `bags.len()`.
    n_bags: usize,
    // Count of samples presented to `train_one`, including samples that are
    // subsequently rejected by the non-finite guard.
    samples_seen: u64,
    // Current xorshift64 state driving the Poisson(1) draws.
    rng_state: u64,
    // The original seed, retained so `reset()` can restore `rng_state`.
    seed: u64,
}
83
84impl<L: Loss + Clone> Clone for BaggedSGBT<L> {
85    fn clone(&self) -> Self {
86        Self {
87            bags: self.bags.clone(),
88            n_bags: self.n_bags,
89            samples_seen: self.samples_seen,
90            rng_state: self.rng_state,
91            seed: self.seed,
92        }
93    }
94}
95
96impl<L: Loss> core::fmt::Debug for BaggedSGBT<L> {
97    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
98        f.debug_struct("BaggedSGBT")
99            .field("n_bags", &self.n_bags)
100            .field("samples_seen", &self.samples_seen)
101            .finish()
102    }
103}
104
impl BaggedSGBT<SquaredLoss> {
    /// Create a new bagged SGBT with squared loss (regression).
    ///
    /// Convenience constructor that delegates to [`Self::with_loss`] with
    /// [`SquaredLoss`].
    ///
    /// # Errors
    ///
    /// Returns [`IrithyllError::InvalidConfig`] if `n_bags < 1`.
    pub fn new(config: SGBTConfig, n_bags: usize) -> crate::error::Result<Self> {
        Self::with_loss(config, SquaredLoss, n_bags)
    }
}
115
116impl<L: Loss + Clone> BaggedSGBT<L> {
117    /// Create a new bagged SGBT with a custom loss function.
118    ///
119    /// Each bag receives a unique seed derived from the config seed, ensuring
120    /// diverse tree structures across bags.
121    ///
122    /// # Errors
123    ///
124    /// Returns [`IrithyllError::InvalidConfig`] if `n_bags < 1`.
125    pub fn with_loss(config: SGBTConfig, loss: L, n_bags: usize) -> crate::error::Result<Self> {
126        if n_bags < 1 {
127            return Err(IrithyllError::InvalidConfig(ConfigError::out_of_range(
128                "n_bags",
129                "must be >= 1",
130                n_bags,
131            )));
132        }
133
134        let seed = config.seed;
135        let bags = (0..n_bags)
136            .map(|i| {
137                let mut cfg = config.clone();
138                // Each bag gets a unique seed for diverse tree construction
139                cfg.seed = config.seed ^ (0xBA6_0000_0000_0000 | i as u64);
140                SGBT::with_loss(cfg, loss.clone())
141            })
142            .collect();
143
144        Ok(Self {
145            bags,
146            n_bags,
147            samples_seen: 0,
148            rng_state: seed,
149            seed,
150        })
151    }
152
153    /// Train all bags on a single observation with Poisson(1) weighting.
154    ///
155    /// For each bag, draws k ~ Poisson(1) and calls `bag.train_one(sample)`
156    /// k times. On average, each bag sees the sample once, but the randomness
157    /// creates diverse training sets across bags.
158    pub fn train_one(&mut self, sample: &impl Observation) {
159        self.samples_seen += 1;
160        let target = sample.target();
161        let features = sample.features();
162
163        // Guard: skip non-finite inputs to prevent NaN/Inf from corrupting model state.
164        if !target.is_finite() || !features.iter().all(|f| f.is_finite()) {
165            return;
166        }
167
168        for bag in &mut self.bags {
169            let k = poisson_sample(&mut self.rng_state);
170            for _ in 0..k {
171                bag.train_one(sample);
172            }
173        }
174    }
175
176    /// Train on a batch of observations.
177    pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
178        for sample in samples {
179            self.train_one(sample);
180        }
181    }
182
183    /// Predict the raw output as the mean across all bags.
184    pub fn predict(&self, features: &[f64]) -> f64 {
185        let sum: f64 = self.bags.iter().map(|b| b.predict(features)).sum();
186        sum / self.n_bags as f64
187    }
188
189    /// Predict with loss transform applied (e.g., sigmoid for logistic loss),
190    /// averaged across bags.
191    pub fn predict_transformed(&self, features: &[f64]) -> f64 {
192        let sum: f64 = self
193            .bags
194            .iter()
195            .map(|b| b.predict_transformed(features))
196            .sum();
197        sum / self.n_bags as f64
198    }
199
200    /// Batch prediction.
201    pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<f64> {
202        feature_matrix.iter().map(|f| self.predict(f)).collect()
203    }
204
205    /// Number of bags in the ensemble.
206    #[inline]
207    pub fn n_bags(&self) -> usize {
208        self.n_bags
209    }
210
211    /// Total samples seen.
212    #[inline]
213    pub fn n_samples_seen(&self) -> u64 {
214        self.samples_seen
215    }
216
217    /// Immutable access to all bags.
218    pub fn bags(&self) -> &[SGBT<L>] {
219        &self.bags
220    }
221
222    /// Immutable access to a specific bag.
223    ///
224    /// # Panics
225    ///
226    /// Panics if `idx >= n_bags`.
227    pub fn bag(&self, idx: usize) -> &SGBT<L> {
228        &self.bags[idx]
229    }
230
231    /// Whether the base prediction has been initialized for all bags.
232    pub fn is_initialized(&self) -> bool {
233        self.bags.iter().all(|b| b.is_initialized())
234    }
235
236    /// Reset all bags to initial state.
237    pub fn reset(&mut self) {
238        for bag in &mut self.bags {
239            bag.reset();
240        }
241        self.samples_seen = 0;
242        self.rng_state = self.seed;
243    }
244}
245
246// ---------------------------------------------------------------------------
247// StreamingLearner impl
248// ---------------------------------------------------------------------------
249
250use crate::learner::StreamingLearner;
251use crate::sample::SampleRef;
252
253impl<L: Loss + Clone> StreamingLearner for BaggedSGBT<L> {
254    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
255        let sample = SampleRef::weighted(features, target, weight);
256        // UFCS: call the inherent train_one(&impl Observation), not this trait method.
257        BaggedSGBT::train_one(self, &sample);
258    }
259
260    fn predict(&self, features: &[f64]) -> f64 {
261        BaggedSGBT::predict(self, features)
262    }
263
264    fn n_samples_seen(&self) -> u64 {
265        self.samples_seen
266    }
267
268    fn reset(&mut self) {
269        BaggedSGBT::reset(self);
270    }
271}
272
273#[cfg(test)]
274mod tests {
275    use super::*;
276    use crate::sample::Sample;
277    use alloc::boxed::Box;
278    use alloc::vec;
279    use alloc::vec::Vec;
280
    // Baseline configuration shared by the tests: small step count and short
    // grace period so models warm up quickly.
    fn test_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .initial_target_count(5)
            .build()
            .unwrap()
    }

    // Constructor should allocate exactly `n_bags` bags and no samples yet.
    #[test]
    fn creates_correct_number_of_bags() {
        let model = BaggedSGBT::new(test_config(), 7).unwrap();
        assert_eq!(model.n_bags(), 7);
        assert_eq!(model.bags().len(), 7);
        assert_eq!(model.n_samples_seen(), 0);
    }

    // `with_loss` validates `n_bags >= 1` and returns InvalidConfig otherwise.
    #[test]
    fn rejects_zero_bags() {
        let result = BaggedSGBT::new(test_config(), 0);
        assert!(result.is_err());
    }

    #[test]
    fn single_bag_equals_single_sgbt() {
        // With 1 bag, bagged model should behave similarly to plain SGBT
        // (not exactly due to Poisson weighting, but close)
        let config = test_config();
        let mut model = BaggedSGBT::new(config, 1).unwrap();

        for i in 0..100 {
            let x = i as f64 * 0.1;
            model.train_one(&Sample::new(vec![x], x * 2.0 + 1.0));
        }

        let pred = model.predict(&[0.5]);
        // Should approximate 0.5 * 2.0 + 1.0 = 2.0
        // With Poisson weighting, predictions can deviate more than a single SGBT.
        assert!(
            pred.is_finite(),
            "prediction should be finite, got {}",
            pred
        );
    }

    // Empirical mean of 10k Poisson(1) draws should be close to 1.0.
    #[test]
    fn poisson_mean_approximately_one() {
        let mut rng = 0xDEAD_BEEF_u64;
        let n = 10_000;
        let sum: usize = (0..n).map(|_| poisson_sample(&mut rng)).sum();
        let mean = sum as f64 / n as f64;
        assert!(
            (mean - 1.0).abs() < 0.1,
            "Poisson(1) mean should be ~1.0, got {}",
            mean
        );
    }

    // Sanity bound on the tail: Poisson(1) draws should stay small.
    #[test]
    fn poisson_never_negative() {
        let mut rng = 42u64;
        for _ in 0..10_000 {
            // poisson_sample returns usize, so it can't be negative,
            // but let's verify values are reasonable
            let k = poisson_sample(&mut rng);
            assert!(k < 20, "Poisson(1) should rarely exceed 10, got {}", k);
        }
    }

    // Identical config (and hence seed) must produce identical models.
    #[test]
    fn deterministic_with_same_seed() {
        let config = test_config();
        let mut model1 = BaggedSGBT::new(config.clone(), 3).unwrap();
        let mut model2 = BaggedSGBT::new(config, 3).unwrap();

        let samples: Vec<Sample> = (0..50)
            .map(|i| {
                let x = i as f64 * 0.1;
                Sample::new(vec![x], x * 3.0)
            })
            .collect();

        for s in &samples {
            model1.train_one(s);
            model2.train_one(s);
        }

        let pred1 = model1.predict(&[0.5]);
        let pred2 = model2.predict(&[0.5]);
        assert!(
            (pred1 - pred2).abs() < 1e-10,
            "same seed should give identical predictions: {} vs {}",
            pred1,
            pred2
        );
    }

    #[test]
    fn predict_averages_bags() {
        let config = test_config();
        let mut model = BaggedSGBT::new(config, 5).unwrap();

        for i in 0..100 {
            let x = i as f64 * 0.1;
            model.train_one(&Sample::new(vec![x], x));
        }

        // Verify prediction is the mean of individual bag predictions
        let features = [0.5];
        let individual_sum: f64 = model.bags().iter().map(|b| b.predict(&features)).sum();
        let expected = individual_sum / model.n_bags() as f64;
        let actual = model.predict(&features);
        assert!(
            (actual - expected).abs() < 1e-10,
            "predict should be mean of bags: {} vs {}",
            actual,
            expected
        );
    }

    // `reset` must zero the sample counter (and restore RNG state internally).
    #[test]
    fn reset_clears_state() {
        let config = test_config();
        let mut model = BaggedSGBT::new(config, 3).unwrap();

        for i in 0..100 {
            let x = i as f64;
            model.train_one(&Sample::new(vec![x], x));
        }
        assert!(model.n_samples_seen() > 0);

        model.reset();
        assert_eq!(model.n_samples_seen(), 0);
    }

    #[test]
    fn convergence_on_linear_target() {
        let config = SGBTConfig::builder()
            .n_steps(20)
            .learning_rate(0.1)
            .grace_period(10)
            .initial_target_count(5)
            .build()
            .unwrap();

        let mut model = BaggedSGBT::new(config, 5).unwrap();

        // Train on y = 2*x + 1
        for i in 0..500 {
            let x = (i % 100) as f64 * 0.1;
            model.train_one(&Sample::new(vec![x], 2.0 * x + 1.0));
        }

        // Verify predictions are finite and the model is learning
        // (exact accuracy depends heavily on Poisson weighting variance).
        let test_points = [0.0, 0.5, 1.0];
        for &x in &test_points {
            let pred = model.predict(&[x]);
            assert!(
                pred.is_finite(),
                "at x={}: prediction should be finite, got {}",
                x,
                pred
            );
        }
        // Verify the model is directionally correct: pred(1.0) > pred(0.0)
        let p0 = model.predict(&[0.0]);
        let p1 = model.predict(&[1.0]);
        // The function is y = 2x + 1, so pred(1.0) should be > pred(0.0).
        // Allow for noise in streaming approximation.
        assert!(
            p1 > p0 || (p1 - p0).abs() < 5.0,
            "directional: pred(1.0)={}, pred(0.0)={}",
            p1,
            p0
        );
    }

    #[test]
    fn variance_reduction() {
        // Compare variance of predictions across bags for a single bagged model
        // vs. running multiple independent SGBTs.
        // The bagged model's mean prediction should have lower variance.
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .initial_target_count(5)
            .build()
            .unwrap();

        let mut model = BaggedSGBT::new(config, 10).unwrap();

        for i in 0..200 {
            let x = (i % 50) as f64 * 0.1;
            model.train_one(&Sample::new(vec![x], x * x));
        }

        // The individual bag predictions should vary, but the mean is stable
        let features = [0.3];
        let preds: Vec<f64> = model.bags().iter().map(|b| b.predict(&features)).collect();
        let mean = preds.iter().sum::<f64>() / preds.len() as f64;
        let variance = preds.iter().map(|p| (p - mean).powi(2)).sum::<f64>() / preds.len() as f64;

        // Individual bag predictions should have some variance (bags are diverse)
        // This is a soft check -- just verify bags aren't all identical
        assert!(
            preds.len() > 1,
            "need multiple bags to test variance reduction"
        );
        // The ensemble mean is `model.predict()`, which should be close to `mean`
        let ensemble_pred = model.predict(&features);
        assert!(
            (ensemble_pred - mean).abs() < 1e-10,
            "ensemble prediction should be mean of bags"
        );
        // Variance should be finite and non-negative (basic sanity)
        assert!(variance >= 0.0 && variance.is_finite());
    }

    // Exercises the model through a boxed `dyn StreamingLearner` to confirm
    // the trait impl delegates correctly (train/predict/reset round-trip).
    #[test]
    fn streaming_learner_trait_object() {
        let config = test_config();
        let model = BaggedSGBT::new(config, 3).unwrap();
        let mut boxed: Box<dyn StreamingLearner> = Box::new(model);
        for i in 0..100 {
            let x = i as f64 * 0.1;
            boxed.train(&[x], x * 2.0);
        }
        assert_eq!(boxed.n_samples_seen(), 100);
        let pred = boxed.predict(&[5.0]);
        assert!(pred.is_finite());
        boxed.reset();
        assert_eq!(boxed.n_samples_seen(), 0);
    }
517}