// irithyll_core/ensemble/stacked.rs
1//! Polymorphic model stacking meta-learner for streaming ensembles.
2//!
3//! [`StackedEnsemble`] implements *stacked generalization* (Wolpert, 1992) in a
4//! streaming context. Multiple heterogeneous base learners -- any type implementing
5//! [`StreamingLearner`] -- produce predictions that are fed as features to a
6//! meta-learner which learns to optimally combine them.
7//!
8//! # Temporal Holdout Stacking
9//!
10//! In batch stacking, cross-validation prevents the meta-learner from seeing
11//! memorized training predictions. In a streaming setting we use **temporal
12//! holdout**: for each incoming sample `(x, y, w)`, base predictions are
13//! collected *before* the base learners are trained on that sample. This ensures
14//! the meta-learner always sees honest, out-of-sample-like predictions rather
15//! than memorized values -- the streaming analogue of leave-one-out stacking.
16//!
17//! # Recursive Stacking
18//!
19//! Because `StackedEnsemble` itself implements [`StreamingLearner`], it can be
20//! used as a base learner inside another `StackedEnsemble`, enabling arbitrarily
21//! deep stacking hierarchies.
22//!
23//! # Example
24//!
25//! ```text
26//! use irithyll::learner::{StreamingLearner, SGBTLearner};
27//! use irithyll::learners::linear::StreamingLinearModel;
28//! use irithyll::ensemble::stacked::StackedEnsemble;
29//! use irithyll::SGBTConfig;
30//!
31//! let config = SGBTConfig::builder()
32//! .n_steps(5)
33//! .learning_rate(0.1)
34//! .grace_period(10)
35//! .max_depth(3)
36//! .n_bins(8)
37//! .build()
38//! .unwrap();
39//!
40//! let bases: Vec<Box<dyn StreamingLearner>> = vec![
41//! Box::new(SGBTLearner::from_config(config)),
42//! Box::new(StreamingLinearModel::new(0.01)),
43//! ];
44//! let meta: Box<dyn StreamingLearner> = Box::new(StreamingLinearModel::new(0.01));
45//!
46//! let mut stack = StackedEnsemble::new(bases, meta);
47//! stack.train(&[1.0, 2.0], 3.0);
48//! let pred = stack.predict(&[1.0, 2.0]);
49//! assert!(pred.is_finite());
50//! ```
51
52use alloc::boxed::Box;
53use alloc::vec::Vec;
54
55use core::fmt;
56
57use crate::learner::StreamingLearner;
58
59// ---------------------------------------------------------------------------
60// StackedEnsemble
61// ---------------------------------------------------------------------------
62
/// Polymorphic model stacking meta-learner using `Box<dyn StreamingLearner>`.
///
/// Combines predictions from heterogeneous base learners through a trainable
/// meta-learner. Uses temporal holdout to prevent information leakage: base
/// predictions are collected *before* training the bases on each sample.
///
/// # Note on `Clone`
///
/// `StackedEnsemble` cannot implement `Clone` because `Box<dyn StreamingLearner>`
/// is not `Clone`. If you need to snapshot the ensemble, serialize it instead.
pub struct StackedEnsemble {
    /// Base learners -- heterogeneous models wrapped as trait objects.
    /// Construction order is significant: `base_predictions` and the
    /// meta-feature layout both follow this order.
    base_learners: Vec<Box<dyn StreamingLearner>>,
    /// Meta-learner that combines base predictions (plus the original
    /// features when `passthrough` is enabled).
    meta_learner: Box<dyn StreamingLearner>,
    /// Whether to pass original features alongside base predictions to the meta-learner.
    passthrough: bool,
    /// Total samples trained on since creation or the last `reset`.
    samples_seen: u64,
}
83
84// ---------------------------------------------------------------------------
85// Constructors and accessors
86// ---------------------------------------------------------------------------
87
88impl StackedEnsemble {
89 /// Create a new stacked ensemble with passthrough disabled.
90 ///
91 /// The meta-learner receives only base learner predictions as features.
92 ///
93 /// # Arguments
94 ///
95 /// * `base_learners` -- heterogeneous base models (at least one recommended)
96 /// * `meta_learner` -- combiner model trained on base predictions
97 #[inline]
98 pub fn new(
99 base_learners: Vec<Box<dyn StreamingLearner>>,
100 meta_learner: Box<dyn StreamingLearner>,
101 ) -> Self {
102 Self {
103 base_learners,
104 meta_learner,
105 passthrough: false,
106 samples_seen: 0,
107 }
108 }
109
110 /// Create a new stacked ensemble with configurable feature passthrough.
111 ///
112 /// When `passthrough` is `true`, the meta-learner receives both base
113 /// predictions *and* the original feature vector, enabling it to learn
114 /// corrections that depend on raw inputs.
115 ///
116 /// # Arguments
117 ///
118 /// * `base_learners` -- heterogeneous base models
119 /// * `meta_learner` -- combiner model
120 /// * `passthrough` -- if `true`, original features are appended to meta-features
121 #[inline]
122 pub fn with_passthrough(
123 base_learners: Vec<Box<dyn StreamingLearner>>,
124 meta_learner: Box<dyn StreamingLearner>,
125 passthrough: bool,
126 ) -> Self {
127 Self {
128 base_learners,
129 meta_learner,
130 passthrough,
131 samples_seen: 0,
132 }
133 }
134
135 /// Number of base learners in the ensemble.
136 #[inline]
137 pub fn n_base_learners(&self) -> usize {
138 self.base_learners.len()
139 }
140
141 /// Whether original features are passed through to the meta-learner.
142 #[inline]
143 pub fn passthrough(&self) -> bool {
144 self.passthrough
145 }
146
147 /// Get predictions from each base learner for inspection.
148 ///
149 /// Returns a vector with one prediction per base learner, in the same
150 /// order they were provided at construction time.
151 #[inline]
152 pub fn base_predictions(&self, features: &[f64]) -> Vec<f64> {
153 self.base_learners
154 .iter()
155 .map(|learner| learner.predict(features))
156 .collect()
157 }
158
159 /// Build the meta-feature vector from base predictions and optional original features.
160 ///
161 /// Layout: `[base1_pred, base2_pred, ..., baseK_pred, (original features if passthrough)]`
162 fn build_meta_features(&self, features: &[f64], base_preds: &[f64]) -> Vec<f64> {
163 if self.passthrough {
164 let mut meta_features = Vec::with_capacity(base_preds.len() + features.len());
165 meta_features.extend_from_slice(base_preds);
166 meta_features.extend_from_slice(features);
167 meta_features
168 } else {
169 base_preds.to_vec()
170 }
171 }
172}
173
174// ---------------------------------------------------------------------------
175// StreamingLearner impl -- enables recursive stacking
176// ---------------------------------------------------------------------------
177
178impl StreamingLearner for StackedEnsemble {
179 /// Train on a single weighted observation using temporal holdout.
180 ///
181 /// 1. Collect base predictions **before** training (temporal holdout).
182 /// 2. Build meta-features and train the meta-learner on `(meta_features, target, weight)`.
183 /// 3. Train each base learner on `(features, target, weight)`.
184 fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
185 // Step 1: Collect pre-training predictions from base learners.
186 let base_preds: Vec<f64> = self
187 .base_learners
188 .iter()
189 .map(|learner| learner.predict(features))
190 .collect();
191
192 // Step 2: Build meta-features and train the meta-learner.
193 let meta_features = self.build_meta_features(features, &base_preds);
194 self.meta_learner.train_one(&meta_features, target, weight);
195
196 // Step 3: Train base learners AFTER meta-learner has used their predictions.
197 for learner in &mut self.base_learners {
198 learner.train_one(features, target, weight);
199 }
200
201 self.samples_seen += 1;
202 }
203
204 /// Predict by collecting base predictions and passing them through the meta-learner.
205 #[inline]
206 fn predict(&self, features: &[f64]) -> f64 {
207 let base_preds = self.base_predictions(features);
208 let meta_features = self.build_meta_features(features, &base_preds);
209 self.meta_learner.predict(&meta_features)
210 }
211
212 /// Total number of samples trained on since creation or last reset.
213 #[inline]
214 fn n_samples_seen(&self) -> u64 {
215 self.samples_seen
216 }
217
218 /// Reset all base learners, the meta-learner, and the sample counter.
219 fn reset(&mut self) {
220 for learner in &mut self.base_learners {
221 learner.reset();
222 }
223 self.meta_learner.reset();
224 self.samples_seen = 0;
225 }
226}
227
228// ---------------------------------------------------------------------------
229// Debug impl -- manual since Box<dyn StreamingLearner> does not impl Debug
230// ---------------------------------------------------------------------------
231
232impl fmt::Debug for StackedEnsemble {
233 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
234 f.debug_struct("StackedEnsemble")
235 .field("n_base_learners", &self.base_learners.len())
236 .field("passthrough", &self.passthrough)
237 .field("samples_seen", &self.samples_seen)
238 .finish()
239 }
240}
241
242// ---------------------------------------------------------------------------
243// Tests
244// ---------------------------------------------------------------------------
245
246// Tests require SGBTLearner and StreamingLinearModel which live in the full
247// `irithyll` crate, not in `irithyll-core`. These tests are exercised via the
248// re-export layer in `irithyll::ensemble::stacked`.
#[cfg(all(test, feature = "_stacked_tests_disabled"))]
mod tests {
    use super::*;
    use crate::learner::SGBTLearner;
    use crate::learners::linear::StreamingLinearModel;
    use crate::SGBTConfig;

    /// Shared minimal SGBT config for tests.
    fn test_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(3)
            .n_bins(8)
            .build()
            .unwrap()
    }

    /// Create a pair of SGBT base learners as trait objects.
    fn sgbt_bases() -> Vec<Box<dyn StreamingLearner>> {
        vec![
            Box::new(SGBTLearner::from_config(test_config())),
            Box::new(SGBTLearner::from_config(test_config())),
        ]
    }

    /// Create a linear meta-learner as a trait object.
    fn linear_meta() -> Box<dyn StreamingLearner> {
        Box::new(StreamingLinearModel::new(0.01))
    }

    #[test]
    fn test_creation() {
        let stack = StackedEnsemble::new(sgbt_bases(), linear_meta());
        assert_eq!(stack.n_base_learners(), 2);
        assert!(!stack.passthrough());
        assert_eq!(stack.n_samples_seen(), 0);
    }

    #[test]
    fn test_train_and_predict() {
        let mut stack = StackedEnsemble::new(sgbt_bases(), linear_meta());

        // Train on a simple pattern.
        for i in 0..50 {
            let x = i as f64 * 0.1;
            stack.train(&[x, x * 2.0], x * 3.0);
        }

        assert_eq!(stack.n_samples_seen(), 50);

        // Prediction should be finite and non-trivial after training.
        let pred = stack.predict(&[1.0, 2.0]);
        assert!(
            pred.is_finite(),
            "prediction should be finite, got {}",
            pred
        );
    }

    #[test]
    fn test_temporal_holdout() {
        // Verify that the meta-learner sees pre-training predictions by
        // checking that base learner sample counts advance correctly:
        // after training the stack once, each base should have seen 1 sample.
        let mut stack = StackedEnsemble::new(sgbt_bases(), linear_meta());

        // Before training, base learners have seen 0 samples.
        for bp in &stack.base_learners {
            assert_eq!(bp.n_samples_seen(), 0);
        }

        // Train one sample through the stack.
        stack.train(&[1.0, 2.0], 3.0);

        // After training, each base learner has seen exactly 1 sample.
        // The temporal holdout guarantee is that the meta-learner was trained
        // on predictions made *before* this sample was ingested by the bases.
        for bp in &stack.base_learners {
            assert_eq!(bp.n_samples_seen(), 1);
        }
        assert_eq!(stack.meta_learner.n_samples_seen(), 1);
        assert_eq!(stack.n_samples_seen(), 1);

        // Train a second sample and verify counts advance together.
        stack.train(&[3.0, 4.0], 5.0);
        for bp in &stack.base_learners {
            assert_eq!(bp.n_samples_seen(), 2);
        }
        assert_eq!(stack.meta_learner.n_samples_seen(), 2);
        assert_eq!(stack.n_samples_seen(), 2);
    }

    #[test]
    fn test_passthrough() {
        // With passthrough=true, meta-features should include original features.
        // We can verify this indirectly: a passthrough stack with a linear meta
        // should produce a different prediction than a non-passthrough stack,
        // because the meta-learner sees a wider feature vector.
        let bases_a = sgbt_bases();
        let bases_b = sgbt_bases();

        let mut no_pass = StackedEnsemble::new(bases_a, linear_meta());
        let mut with_pass = StackedEnsemble::with_passthrough(bases_b, linear_meta(), true);

        assert!(!no_pass.passthrough());
        assert!(with_pass.passthrough());

        // Train both on the same data.
        for i in 0..30 {
            let x = i as f64 * 0.1;
            let features = [x, x * 2.0];
            let target = x * 3.0 + 1.0;
            no_pass.train(&features, target);
            with_pass.train(&features, target);
        }

        // Verify meta-feature dimensions differ by checking that build_meta_features
        // produces different-length vectors.
        let features = [1.0, 2.0];
        let base_preds = [0.5, 0.7]; // mock base predictions
        let meta_no = no_pass.build_meta_features(&features, &base_preds);
        let meta_yes = with_pass.build_meta_features(&features, &base_preds);

        assert_eq!(meta_no.len(), 2, "no passthrough: only base predictions");
        assert_eq!(
            meta_yes.len(),
            4,
            "passthrough: base predictions + original features"
        );
        assert!(
            crate::math::abs(meta_yes[2] - 1.0) < 1e-12,
            "original features appended"
        );
        assert!(
            crate::math::abs(meta_yes[3] - 2.0) < 1e-12,
            "original features appended"
        );
    }

    #[test]
    fn test_base_predictions() {
        let mut stack = StackedEnsemble::new(sgbt_bases(), linear_meta());

        // Before training, base predictions should all be zero (untrained models).
        let preds = stack.base_predictions(&[1.0, 2.0]);
        assert_eq!(preds.len(), 2);
        for p in &preds {
            // `p` is `&f64`; free functions get no auto-deref, so dereference
            // explicitly before calling `crate::math::abs`.
            assert!(
                crate::math::abs(*p) < 1e-12,
                "untrained base should predict ~0, got {}",
                p
            );
        }

        // Train a few samples.
        for i in 0..20 {
            let x = i as f64;
            stack.train(&[x, x * 0.5], x * 2.0);
        }

        // Base predictions should still return the correct count.
        let preds_after = stack.base_predictions(&[5.0, 2.5]);
        assert_eq!(preds_after.len(), 2);
        for p in &preds_after {
            assert!(p.is_finite());
        }
    }

    #[test]
    fn test_reset() {
        let mut stack = StackedEnsemble::new(sgbt_bases(), linear_meta());

        // Train some data.
        for i in 0..30 {
            let x = i as f64 * 0.1;
            stack.train(&[x, x * 2.0], x * 3.0);
        }
        assert_eq!(stack.n_samples_seen(), 30);

        // Reset everything.
        stack.reset();
        assert_eq!(stack.n_samples_seen(), 0);

        // All base learners should be reset.
        for bp in &stack.base_learners {
            assert_eq!(bp.n_samples_seen(), 0);
        }

        // Meta-learner should be reset.
        assert_eq!(stack.meta_learner.n_samples_seen(), 0);

        // Predictions after reset should be near zero (untrained state).
        let pred = stack.predict(&[1.0, 2.0]);
        assert!(
            crate::math::abs(pred) < 1e-12,
            "prediction after reset should be ~0, got {}",
            pred,
        );
    }

    #[test]
    fn test_n_samples_seen() {
        let mut stack = StackedEnsemble::new(sgbt_bases(), linear_meta());

        assert_eq!(stack.n_samples_seen(), 0);

        for i in 1..=10 {
            stack.train(&[i as f64], i as f64);
            assert_eq!(stack.n_samples_seen(), i);
        }

        // Weighted training also increments by 1 (sample count, not weight sum).
        stack.train_one(&[11.0], 11.0, 5.0);
        assert_eq!(stack.n_samples_seen(), 11);
    }

    #[test]
    fn test_trait_object() {
        // StackedEnsemble itself should work as Box<dyn StreamingLearner>,
        // enabling recursive stacking.
        let stack = StackedEnsemble::new(sgbt_bases(), linear_meta());
        let mut boxed: Box<dyn StreamingLearner> = Box::new(stack);

        boxed.train(&[1.0, 2.0], 3.0);
        assert_eq!(boxed.n_samples_seen(), 1);

        let pred = boxed.predict(&[1.0, 2.0]);
        assert!(pred.is_finite());

        boxed.reset();
        assert_eq!(boxed.n_samples_seen(), 0);
    }

    #[test]
    fn test_heterogeneous_bases() {
        // Mix SGBT and linear base learners -- the core polymorphism use case.
        let bases: Vec<Box<dyn StreamingLearner>> = vec![
            Box::new(SGBTLearner::from_config(test_config())),
            Box::new(StreamingLinearModel::new(0.01)),
            Box::new(StreamingLinearModel::ridge(0.01, 0.001)),
        ];
        let meta = linear_meta();

        let mut stack = StackedEnsemble::new(bases, meta);
        assert_eq!(stack.n_base_learners(), 3);

        // Train on a linear-ish pattern. Both SGBT and linear models should
        // contribute meaningful predictions.
        for i in 0..40 {
            let x = i as f64 * 0.1;
            stack.train(&[x, x * 0.5], 2.0 * x + 1.0);
        }

        assert_eq!(stack.n_samples_seen(), 40);

        let preds = stack.base_predictions(&[2.0, 1.0]);
        assert_eq!(preds.len(), 3);
        for p in &preds {
            assert!(p.is_finite(), "base prediction should be finite, got {}", p);
        }

        let final_pred = stack.predict(&[2.0, 1.0]);
        assert!(final_pred.is_finite());
    }

    #[test]
    fn test_predict_batch() {
        let mut stack = StackedEnsemble::new(sgbt_bases(), linear_meta());

        // Train enough samples for non-trivial predictions.
        for i in 0..30 {
            let x = i as f64 * 0.1;
            stack.train(&[x, x * 2.0], x * 3.0);
        }

        let rows: Vec<&[f64]> = vec![&[0.5, 1.0], &[1.5, 3.0], &[2.5, 5.0]];
        let batch = stack.predict_batch(&rows);

        // Batch results should exactly match individual predictions.
        assert_eq!(batch.len(), rows.len());
        for (i, row) in rows.iter().enumerate() {
            let individual = stack.predict(row);
            assert!(
                crate::math::abs(batch[i] - individual) < 1e-12,
                "batch[{}]={} != individual={}",
                i,
                batch[i],
                individual,
            );
        }
    }

    #[test]
    fn test_debug_impl() {
        let stack = StackedEnsemble::new(sgbt_bases(), linear_meta());
        let debug_str = format!("{:?}", stack);
        assert!(debug_str.contains("StackedEnsemble"));
        assert!(debug_str.contains("n_base_learners: 2"));
        assert!(debug_str.contains("passthrough: false"));
        assert!(debug_str.contains("samples_seen: 0"));
    }
}