Skip to main content

irithyll_core/ensemble/
stacked.rs

1//! Polymorphic model stacking meta-learner for streaming ensembles.
2//!
3//! [`StackedEnsemble`] implements *stacked generalization* (Wolpert, 1992) in a
4//! streaming context. Multiple heterogeneous base learners -- any type implementing
5//! [`StreamingLearner`] -- produce predictions that are fed as features to a
6//! meta-learner which learns to optimally combine them.
7//!
8//! # Temporal Holdout Stacking
9//!
10//! In batch stacking, cross-validation prevents the meta-learner from seeing
11//! memorized training predictions. In a streaming setting we use **temporal
12//! holdout**: for each incoming sample `(x, y, w)`, base predictions are
13//! collected *before* the base learners are trained on that sample. This ensures
14//! the meta-learner always sees honest, out-of-sample-like predictions rather
15//! than memorized values -- the streaming analogue of leave-one-out stacking.
16//!
17//! # Recursive Stacking
18//!
19//! Because `StackedEnsemble` itself implements [`StreamingLearner`], it can be
20//! used as a base learner inside another `StackedEnsemble`, enabling arbitrarily
21//! deep stacking hierarchies.
22//!
23//! # Example
24//!
25//! ```text
26//! use irithyll::learner::{StreamingLearner, SGBTLearner};
27//! use irithyll::learners::linear::StreamingLinearModel;
28//! use irithyll::ensemble::stacked::StackedEnsemble;
29//! use irithyll::SGBTConfig;
30//!
31//! let config = SGBTConfig::builder()
32//!     .n_steps(5)
33//!     .learning_rate(0.1)
34//!     .grace_period(10)
35//!     .max_depth(3)
36//!     .n_bins(8)
37//!     .build()
38//!     .unwrap();
39//!
40//! let bases: Vec<Box<dyn StreamingLearner>> = vec![
41//!     Box::new(SGBTLearner::from_config(config)),
42//!     Box::new(StreamingLinearModel::new(0.01)),
43//! ];
44//! let meta: Box<dyn StreamingLearner> = Box::new(StreamingLinearModel::new(0.01));
45//!
46//! let mut stack = StackedEnsemble::new(bases, meta);
47//! stack.train(&[1.0, 2.0], 3.0);
48//! let pred = stack.predict(&[1.0, 2.0]);
49//! assert!(pred.is_finite());
50//! ```
51
52use alloc::boxed::Box;
53use alloc::vec::Vec;
54
55use core::fmt;
56
57use crate::learner::StreamingLearner;
58
59// ---------------------------------------------------------------------------
60// StackedEnsemble
61// ---------------------------------------------------------------------------
62
/// Polymorphic model stacking meta-learner using `Box<dyn StreamingLearner>`.
///
/// Combines predictions from heterogeneous base learners through a trainable
/// meta-learner. Uses temporal holdout to prevent information leakage: base
/// predictions are collected *before* training the bases on each sample.
///
/// # Note on `Clone`
///
/// `StackedEnsemble` cannot implement `Clone` because `Box<dyn StreamingLearner>`
/// is not `Clone`. If you need to snapshot the ensemble, serialize it instead.
pub struct StackedEnsemble {
    /// Base learners -- heterogeneous models wrapped as trait objects.
    /// Their predictions feed the meta-learner in construction order.
    base_learners: Vec<Box<dyn StreamingLearner>>,
    /// Meta-learner that combines base predictions (plus the raw feature
    /// vector when `passthrough` is enabled) into the final prediction.
    meta_learner: Box<dyn StreamingLearner>,
    /// Whether to pass original features alongside base predictions to the meta-learner.
    passthrough: bool,
    /// Total samples trained on since creation or the last `reset`.
    samples_seen: u64,
}
83
84// ---------------------------------------------------------------------------
85// Constructors and accessors
86// ---------------------------------------------------------------------------
87
88impl StackedEnsemble {
89    /// Create a new stacked ensemble with passthrough disabled.
90    ///
91    /// The meta-learner receives only base learner predictions as features.
92    ///
93    /// # Arguments
94    ///
95    /// * `base_learners` -- heterogeneous base models (at least one recommended)
96    /// * `meta_learner` -- combiner model trained on base predictions
97    #[inline]
98    pub fn new(
99        base_learners: Vec<Box<dyn StreamingLearner>>,
100        meta_learner: Box<dyn StreamingLearner>,
101    ) -> Self {
102        Self {
103            base_learners,
104            meta_learner,
105            passthrough: false,
106            samples_seen: 0,
107        }
108    }
109
110    /// Create a new stacked ensemble with configurable feature passthrough.
111    ///
112    /// When `passthrough` is `true`, the meta-learner receives both base
113    /// predictions *and* the original feature vector, enabling it to learn
114    /// corrections that depend on raw inputs.
115    ///
116    /// # Arguments
117    ///
118    /// * `base_learners` -- heterogeneous base models
119    /// * `meta_learner` -- combiner model
120    /// * `passthrough` -- if `true`, original features are appended to meta-features
121    #[inline]
122    pub fn with_passthrough(
123        base_learners: Vec<Box<dyn StreamingLearner>>,
124        meta_learner: Box<dyn StreamingLearner>,
125        passthrough: bool,
126    ) -> Self {
127        Self {
128            base_learners,
129            meta_learner,
130            passthrough,
131            samples_seen: 0,
132        }
133    }
134
135    /// Number of base learners in the ensemble.
136    #[inline]
137    pub fn n_base_learners(&self) -> usize {
138        self.base_learners.len()
139    }
140
141    /// Whether original features are passed through to the meta-learner.
142    #[inline]
143    pub fn passthrough(&self) -> bool {
144        self.passthrough
145    }
146
147    /// Get predictions from each base learner for inspection.
148    ///
149    /// Returns a vector with one prediction per base learner, in the same
150    /// order they were provided at construction time.
151    #[inline]
152    pub fn base_predictions(&self, features: &[f64]) -> Vec<f64> {
153        self.base_learners
154            .iter()
155            .map(|learner| learner.predict(features))
156            .collect()
157    }
158
159    /// Build the meta-feature vector from base predictions and optional original features.
160    ///
161    /// Layout: `[base1_pred, base2_pred, ..., baseK_pred, (original features if passthrough)]`
162    fn build_meta_features(&self, features: &[f64], base_preds: &[f64]) -> Vec<f64> {
163        if self.passthrough {
164            let mut meta_features = Vec::with_capacity(base_preds.len() + features.len());
165            meta_features.extend_from_slice(base_preds);
166            meta_features.extend_from_slice(features);
167            meta_features
168        } else {
169            base_preds.to_vec()
170        }
171    }
172}
173
174// ---------------------------------------------------------------------------
175// StreamingLearner impl -- enables recursive stacking
176// ---------------------------------------------------------------------------
177
178impl StreamingLearner for StackedEnsemble {
179    /// Train on a single weighted observation using temporal holdout.
180    ///
181    /// 1. Collect base predictions **before** training (temporal holdout).
182    /// 2. Build meta-features and train the meta-learner on `(meta_features, target, weight)`.
183    /// 3. Train each base learner on `(features, target, weight)`.
184    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
185        // Step 1: Collect pre-training predictions from base learners.
186        let base_preds: Vec<f64> = self
187            .base_learners
188            .iter()
189            .map(|learner| learner.predict(features))
190            .collect();
191
192        // Step 2: Build meta-features and train the meta-learner.
193        let meta_features = self.build_meta_features(features, &base_preds);
194        self.meta_learner.train_one(&meta_features, target, weight);
195
196        // Step 3: Train base learners AFTER meta-learner has used their predictions.
197        for learner in &mut self.base_learners {
198            learner.train_one(features, target, weight);
199        }
200
201        self.samples_seen += 1;
202    }
203
204    /// Predict by collecting base predictions and passing them through the meta-learner.
205    #[inline]
206    fn predict(&self, features: &[f64]) -> f64 {
207        let base_preds = self.base_predictions(features);
208        let meta_features = self.build_meta_features(features, &base_preds);
209        self.meta_learner.predict(&meta_features)
210    }
211
212    /// Total number of samples trained on since creation or last reset.
213    #[inline]
214    fn n_samples_seen(&self) -> u64 {
215        self.samples_seen
216    }
217
218    /// Reset all base learners, the meta-learner, and the sample counter.
219    fn reset(&mut self) {
220        for learner in &mut self.base_learners {
221            learner.reset();
222        }
223        self.meta_learner.reset();
224        self.samples_seen = 0;
225    }
226}
227
228// ---------------------------------------------------------------------------
229// Debug impl -- manual since Box<dyn StreamingLearner> does not impl Debug
230// ---------------------------------------------------------------------------
231
232impl fmt::Debug for StackedEnsemble {
233    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
234        f.debug_struct("StackedEnsemble")
235            .field("n_base_learners", &self.base_learners.len())
236            .field("passthrough", &self.passthrough)
237            .field("samples_seen", &self.samples_seen)
238            .finish()
239    }
240}
241
242// ---------------------------------------------------------------------------
243// Tests
244// ---------------------------------------------------------------------------
245
// Tests require SGBTLearner and StreamingLinearModel which live in the full
// `irithyll` crate, not in `irithyll-core`. These tests are exercised via the
// re-export layer in `irithyll::ensemble::stacked`.
//
// NOTE(review): in a strict `no_std + alloc` test build, the `vec!` and
// `format!` macros below require `use alloc::{format, vec};` -- confirm the
// crate enables `std` for test builds before re-enabling this module.
#[cfg(all(test, feature = "_stacked_tests_disabled"))]
mod tests {
    use super::*;
    use crate::learner::SGBTLearner;
    use crate::learners::linear::StreamingLinearModel;
    use crate::SGBTConfig;

    /// Shared minimal SGBT config for tests.
    fn test_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(3)
            .n_bins(8)
            .build()
            .unwrap()
    }

    /// Create a pair of SGBT base learners as trait objects.
    fn sgbt_bases() -> Vec<Box<dyn StreamingLearner>> {
        vec![
            Box::new(SGBTLearner::from_config(test_config())),
            Box::new(SGBTLearner::from_config(test_config())),
        ]
    }

    /// Create a linear meta-learner as a trait object.
    fn linear_meta() -> Box<dyn StreamingLearner> {
        Box::new(StreamingLinearModel::new(0.01))
    }

    #[test]
    fn test_creation() {
        let stack = StackedEnsemble::new(sgbt_bases(), linear_meta());
        assert_eq!(stack.n_base_learners(), 2);
        assert!(!stack.passthrough());
        assert_eq!(stack.n_samples_seen(), 0);
    }

    #[test]
    fn test_train_and_predict() {
        let mut stack = StackedEnsemble::new(sgbt_bases(), linear_meta());

        // Train on a simple pattern.
        for i in 0..50 {
            let x = i as f64 * 0.1;
            stack.train(&[x, x * 2.0], x * 3.0);
        }

        assert_eq!(stack.n_samples_seen(), 50);

        // Prediction should be finite and non-trivial after training.
        let pred = stack.predict(&[1.0, 2.0]);
        assert!(
            pred.is_finite(),
            "prediction should be finite, got {}",
            pred
        );
    }

    #[test]
    fn test_temporal_holdout() {
        // Verify that the meta-learner sees pre-training predictions by
        // checking that base learner sample counts advance correctly:
        // after training the stack once, each base should have seen 1 sample.
        let mut stack = StackedEnsemble::new(sgbt_bases(), linear_meta());

        // Before training, base learners have seen 0 samples.
        for bp in &stack.base_learners {
            assert_eq!(bp.n_samples_seen(), 0);
        }

        // Train one sample through the stack.
        stack.train(&[1.0, 2.0], 3.0);

        // After training, each base learner has seen exactly 1 sample.
        // The temporal holdout guarantee is that the meta-learner was trained
        // on predictions made *before* this sample was ingested by the bases.
        for bp in &stack.base_learners {
            assert_eq!(bp.n_samples_seen(), 1);
        }
        assert_eq!(stack.meta_learner.n_samples_seen(), 1);
        assert_eq!(stack.n_samples_seen(), 1);

        // Train a second sample and verify counts advance together.
        stack.train(&[3.0, 4.0], 5.0);
        for bp in &stack.base_learners {
            assert_eq!(bp.n_samples_seen(), 2);
        }
        assert_eq!(stack.meta_learner.n_samples_seen(), 2);
        assert_eq!(stack.n_samples_seen(), 2);
    }

    #[test]
    fn test_passthrough() {
        // With passthrough=true, meta-features should include original features.
        // We can verify this indirectly: a passthrough stack with a linear meta
        // should produce a different prediction than a non-passthrough stack,
        // because the meta-learner sees a wider feature vector.
        let bases_a = sgbt_bases();
        let bases_b = sgbt_bases();

        let mut no_pass = StackedEnsemble::new(bases_a, linear_meta());
        let mut with_pass = StackedEnsemble::with_passthrough(bases_b, linear_meta(), true);

        assert!(!no_pass.passthrough());
        assert!(with_pass.passthrough());

        // Train both on the same data.
        for i in 0..30 {
            let x = i as f64 * 0.1;
            let features = [x, x * 2.0];
            let target = x * 3.0 + 1.0;
            no_pass.train(&features, target);
            with_pass.train(&features, target);
        }

        // Verify meta-feature dimensions differ by checking that build_meta_features
        // produces different-length vectors.
        let features = [1.0, 2.0];
        let base_preds = [0.5, 0.7]; // mock base predictions
        let meta_no = no_pass.build_meta_features(&features, &base_preds);
        let meta_yes = with_pass.build_meta_features(&features, &base_preds);

        assert_eq!(meta_no.len(), 2, "no passthrough: only base predictions");
        assert_eq!(
            meta_yes.len(),
            4,
            "passthrough: base predictions + original features"
        );
        // Redundant parentheses removed from the abs() arguments below.
        assert!(
            crate::math::abs(meta_yes[2] - 1.0) < 1e-12,
            "original features appended"
        );
        assert!(
            crate::math::abs(meta_yes[3] - 2.0) < 1e-12,
            "original features appended"
        );
    }

    #[test]
    fn test_base_predictions() {
        let mut stack = StackedEnsemble::new(sgbt_bases(), linear_meta());

        // Before training, base predictions should all be zero (untrained models).
        let preds = stack.base_predictions(&[1.0, 2.0]);
        assert_eq!(preds.len(), 2);
        for p in &preds {
            // Deref: iterating `&preds` yields `&f64`, but every other call
            // site passes `f64` by value to `crate::math::abs`.
            assert!(
                crate::math::abs(*p) < 1e-12,
                "untrained base should predict ~0, got {}",
                p
            );
        }

        // Train a few samples.
        for i in 0..20 {
            let x = i as f64;
            stack.train(&[x, x * 0.5], x * 2.0);
        }

        // Base predictions should still return the correct count.
        let preds_after = stack.base_predictions(&[5.0, 2.5]);
        assert_eq!(preds_after.len(), 2);
        for p in &preds_after {
            assert!(p.is_finite());
        }
    }

    #[test]
    fn test_reset() {
        let mut stack = StackedEnsemble::new(sgbt_bases(), linear_meta());

        // Train some data.
        for i in 0..30 {
            let x = i as f64 * 0.1;
            stack.train(&[x, x * 2.0], x * 3.0);
        }
        assert_eq!(stack.n_samples_seen(), 30);

        // Reset everything.
        stack.reset();
        assert_eq!(stack.n_samples_seen(), 0);

        // All base learners should be reset.
        for bp in &stack.base_learners {
            assert_eq!(bp.n_samples_seen(), 0);
        }

        // Meta-learner should be reset.
        assert_eq!(stack.meta_learner.n_samples_seen(), 0);

        // Predictions after reset should be near zero (untrained state).
        let pred = stack.predict(&[1.0, 2.0]);
        assert!(
            crate::math::abs(pred) < 1e-12,
            "prediction after reset should be ~0, got {}",
            pred,
        );
    }

    #[test]
    fn test_n_samples_seen() {
        let mut stack = StackedEnsemble::new(sgbt_bases(), linear_meta());

        assert_eq!(stack.n_samples_seen(), 0);

        for i in 1..=10 {
            stack.train(&[i as f64], i as f64);
            assert_eq!(stack.n_samples_seen(), i);
        }

        // Weighted training also increments by 1 (sample count, not weight sum).
        stack.train_one(&[11.0], 11.0, 5.0);
        assert_eq!(stack.n_samples_seen(), 11);
    }

    #[test]
    fn test_trait_object() {
        // StackedEnsemble itself should work as Box<dyn StreamingLearner>,
        // enabling recursive stacking.
        let stack = StackedEnsemble::new(sgbt_bases(), linear_meta());
        let mut boxed: Box<dyn StreamingLearner> = Box::new(stack);

        boxed.train(&[1.0, 2.0], 3.0);
        assert_eq!(boxed.n_samples_seen(), 1);

        let pred = boxed.predict(&[1.0, 2.0]);
        assert!(pred.is_finite());

        boxed.reset();
        assert_eq!(boxed.n_samples_seen(), 0);
    }

    #[test]
    fn test_heterogeneous_bases() {
        // Mix SGBT and linear base learners -- the core polymorphism use case.
        let bases: Vec<Box<dyn StreamingLearner>> = vec![
            Box::new(SGBTLearner::from_config(test_config())),
            Box::new(StreamingLinearModel::new(0.01)),
            Box::new(StreamingLinearModel::ridge(0.01, 0.001)),
        ];
        let meta = linear_meta();

        let mut stack = StackedEnsemble::new(bases, meta);
        assert_eq!(stack.n_base_learners(), 3);

        // Train on a linear-ish pattern. Both SGBT and linear models should
        // contribute meaningful predictions.
        for i in 0..40 {
            let x = i as f64 * 0.1;
            stack.train(&[x, x * 0.5], 2.0 * x + 1.0);
        }

        assert_eq!(stack.n_samples_seen(), 40);

        let preds = stack.base_predictions(&[2.0, 1.0]);
        assert_eq!(preds.len(), 3);
        for p in &preds {
            assert!(p.is_finite(), "base prediction should be finite, got {}", p);
        }

        let final_pred = stack.predict(&[2.0, 1.0]);
        assert!(final_pred.is_finite());
    }

    #[test]
    fn test_predict_batch() {
        let mut stack = StackedEnsemble::new(sgbt_bases(), linear_meta());

        // Train enough samples for non-trivial predictions.
        for i in 0..30 {
            let x = i as f64 * 0.1;
            stack.train(&[x, x * 2.0], x * 3.0);
        }

        let rows: Vec<&[f64]> = vec![&[0.5, 1.0], &[1.5, 3.0], &[2.5, 5.0]];
        let batch = stack.predict_batch(&rows);

        // Batch results should exactly match individual predictions.
        assert_eq!(batch.len(), rows.len());
        for (i, row) in rows.iter().enumerate() {
            let individual = stack.predict(row);
            assert!(
                crate::math::abs(batch[i] - individual) < 1e-12,
                "batch[{}]={} != individual={}",
                i,
                batch[i],
                individual,
            );
        }
    }

    #[test]
    fn test_debug_impl() {
        let stack = StackedEnsemble::new(sgbt_bases(), linear_meta());
        let debug_str = format!("{:?}", stack);
        assert!(debug_str.contains("StackedEnsemble"));
        assert!(debug_str.contains("n_base_learners: 2"));
        assert!(debug_str.contains("passthrough: false"));
        assert!(debug_str.contains("samples_seen: 0"));
    }
}