irithyll 10.0.0

Streaming ML in Rust -- gradient boosted trees, neural architectures (TTT/KAN/MoE/Mamba/SNN), AutoML, kernel methods, and composable pipelines
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
//! Polymorphic model stacking meta-learner for streaming ensembles.
//!
//! [`StackedEnsemble`] implements *stacked generalization* (Wolpert, 1992) in a
//! streaming context. Multiple heterogeneous base learners -- any type implementing
//! [`StreamingLearner`] -- produce predictions that are fed as features to a
//! meta-learner which learns to optimally combine them.
//!
//! # Temporal Holdout Stacking
//!
//! In batch stacking, cross-validation prevents the meta-learner from seeing
//! memorized training predictions. In a streaming setting we use **temporal
//! holdout**: for each incoming sample `(x, y, w)`, base predictions are
//! collected *before* the base learners are trained on that sample. This ensures
//! the meta-learner always sees honest, out-of-sample-like predictions rather
//! than memorized values -- the streaming analogue of leave-one-out stacking.
//!
//! # Recursive Stacking
//!
//! Because `StackedEnsemble` itself implements [`StreamingLearner`], it can be
//! used as a base learner inside another `StackedEnsemble`, enabling arbitrarily
//! deep stacking hierarchies.
//!
//! # Example
//!
//! ```
//! use irithyll::learner::{StreamingLearner, SGBTLearner};
//! use irithyll::learners::linear::StreamingLinearModel;
//! use irithyll::ensemble::stacked::StackedEnsemble;
//! use irithyll::SGBTConfig;
//!
//! let config = SGBTConfig::builder()
//!     .n_steps(5)
//!     .learning_rate(0.1)
//!     .grace_period(10)
//!     .max_depth(3)
//!     .n_bins(8)
//!     .build()
//!     .unwrap();
//!
//! let bases: Vec<Box<dyn StreamingLearner>> = vec![
//!     Box::new(SGBTLearner::from_config(config)),
//!     Box::new(StreamingLinearModel::new(0.01)),
//! ];
//! let meta: Box<dyn StreamingLearner> = Box::new(StreamingLinearModel::new(0.01));
//!
//! let mut stack = StackedEnsemble::new(bases, meta);
//! stack.train(&[1.0, 2.0], 3.0);
//! let pred = stack.predict(&[1.0, 2.0]);
//! assert!(pred.is_finite());
//! ```

use std::fmt;

use crate::learner::StreamingLearner;

// ---------------------------------------------------------------------------
// StackedEnsemble
// ---------------------------------------------------------------------------

/// Polymorphic model stacking meta-learner using `Box<dyn StreamingLearner>`.
///
/// Combines predictions from heterogeneous base learners through a trainable
/// meta-learner. Uses temporal holdout to prevent information leakage: base
/// predictions are collected *before* training the bases on each sample.
///
/// # Note on `Clone`
///
/// `StackedEnsemble` cannot implement `Clone` because `Box<dyn StreamingLearner>`
/// is not `Clone`. If you need to snapshot the ensemble, serialize it instead.
pub struct StackedEnsemble {
    /// Base learners -- heterogeneous models wrapped as trait objects.
    /// Order is significant: meta-feature columns follow this order.
    base_learners: Vec<Box<dyn StreamingLearner>>,
    /// Meta-learner that combines base predictions (plus the raw features when
    /// `passthrough` is enabled) into the final prediction.
    meta_learner: Box<dyn StreamingLearner>,
    /// Whether to pass original features alongside base predictions to the meta-learner.
    passthrough: bool,
    /// Total samples trained on -- counts `train_one` calls, not the weight sum.
    samples_seen: u64,
}

// ---------------------------------------------------------------------------
// Constructors and accessors
// ---------------------------------------------------------------------------

impl StackedEnsemble {
    /// Create a new stacked ensemble with passthrough disabled.
    ///
    /// The meta-learner sees only the base learners' predictions as its
    /// feature vector.
    ///
    /// # Arguments
    ///
    /// * `base_learners` -- heterogeneous base models (at least one recommended)
    /// * `meta_learner` -- combiner model trained on base predictions
    #[inline]
    pub fn new(
        base_learners: Vec<Box<dyn StreamingLearner>>,
        meta_learner: Box<dyn StreamingLearner>,
    ) -> Self {
        // Delegate to the general constructor with passthrough off.
        Self::with_passthrough(base_learners, meta_learner, false)
    }

    /// Create a new stacked ensemble with configurable feature passthrough.
    ///
    /// With `passthrough` set to `true`, the meta-learner is fed the original
    /// feature vector in addition to the base predictions, so it can learn
    /// input-dependent corrections.
    ///
    /// # Arguments
    ///
    /// * `base_learners` -- heterogeneous base models
    /// * `meta_learner` -- combiner model
    /// * `passthrough` -- if `true`, original features are appended to meta-features
    #[inline]
    pub fn with_passthrough(
        base_learners: Vec<Box<dyn StreamingLearner>>,
        meta_learner: Box<dyn StreamingLearner>,
        passthrough: bool,
    ) -> Self {
        Self {
            samples_seen: 0,
            passthrough,
            base_learners,
            meta_learner,
        }
    }

    /// Number of base learners in the ensemble.
    #[inline]
    pub fn n_base_learners(&self) -> usize {
        self.base_learners.len()
    }

    /// Whether original features are passed through to the meta-learner.
    #[inline]
    pub fn passthrough(&self) -> bool {
        self.passthrough
    }

    /// Get predictions from each base learner for inspection.
    ///
    /// The returned vector has one entry per base learner, ordered exactly as
    /// the learners were supplied at construction time.
    #[inline]
    pub fn base_predictions(&self, features: &[f64]) -> Vec<f64> {
        let mut preds = Vec::with_capacity(self.base_learners.len());
        for learner in &self.base_learners {
            preds.push(learner.predict(features));
        }
        preds
    }

    /// Build the meta-feature vector from base predictions and optional original features.
    ///
    /// Layout: `[base1_pred, base2_pred, ..., baseK_pred, (original features if passthrough)]`
    fn build_meta_features(&self, features: &[f64], base_preds: &[f64]) -> Vec<f64> {
        // Base predictions always come first; raw features are appended only
        // when passthrough is enabled.
        let mut meta_features = base_preds.to_vec();
        if self.passthrough {
            meta_features.extend_from_slice(features);
        }
        meta_features
    }
}

// ---------------------------------------------------------------------------
// StreamingLearner impl -- enables recursive stacking
// ---------------------------------------------------------------------------

impl StreamingLearner for StackedEnsemble {
    /// Train on a single weighted observation using temporal holdout.
    ///
    /// 1. Collect base predictions **before** training (temporal holdout).
    /// 2. Build meta-features and train the meta-learner on `(meta_features, target, weight)`.
    /// 3. Train each base learner on `(features, target, weight)`.
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
        // Step 1: snapshot each base learner's prediction before any of them
        // has seen this sample -- this is the temporal holdout guarantee.
        let mut base_preds = Vec::with_capacity(self.base_learners.len());
        for learner in &self.base_learners {
            base_preds.push(learner.predict(features));
        }

        // Step 2: the meta-learner trains on the honest (pre-update) predictions.
        let meta_features = self.build_meta_features(features, &base_preds);
        self.meta_learner.train_one(&meta_features, target, weight);

        // Step 3: only now do the base learners ingest the sample.
        self.base_learners
            .iter_mut()
            .for_each(|learner| learner.train_one(features, target, weight));

        self.samples_seen += 1;
    }

    /// Predict by collecting base predictions and passing them through the meta-learner.
    #[inline]
    fn predict(&self, features: &[f64]) -> f64 {
        let base_preds = self.base_predictions(features);
        self.meta_learner
            .predict(&self.build_meta_features(features, &base_preds))
    }

    /// Total number of samples trained on since creation or last reset.
    #[inline]
    fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    /// Reset all base learners, the meta-learner, and the sample counter.
    fn reset(&mut self) {
        self.base_learners.iter_mut().for_each(|learner| learner.reset());
        self.meta_learner.reset();
        self.samples_seen = 0;
    }

    /// Aggregate diagnostics from the first base learner as representative signal.
    #[allow(deprecated)]
    fn diagnostics_array(&self) -> [f64; 5] {
        // An empty ensemble has nothing to report; fall back to zeros.
        match self.base_learners.first() {
            Some(first) => first.diagnostics_array(),
            None => [0.0; 5],
        }
    }

    /// Sum of replacement counts across all base learners and the meta-learner.
    #[allow(deprecated)]
    fn replacement_count(&self) -> u64 {
        let mut total = self.meta_learner.replacement_count();
        for learner in &self.base_learners {
            total += learner.replacement_count();
        }
        total
    }

    /// Forward config adjustments to all base learners and the meta-learner.
    #[allow(deprecated)]
    fn adjust_config(&mut self, lr_multiplier: f64, lambda_delta: f64) {
        self.meta_learner.adjust_config(lr_multiplier, lambda_delta);
        for learner in &mut self.base_learners {
            learner.adjust_config(lr_multiplier, lambda_delta);
        }
    }
}

// ---------------------------------------------------------------------------
// Debug impl -- manual since Box<dyn StreamingLearner> does not impl Debug
// ---------------------------------------------------------------------------

impl fmt::Debug for StackedEnsemble {
    /// Summarize the ensemble without requiring `Debug` on the boxed learners:
    /// reports learner count, passthrough flag, and samples seen.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut dbg = formatter.debug_struct("StackedEnsemble");
        dbg.field("n_base_learners", &self.base_learners.len());
        dbg.field("passthrough", &self.passthrough);
        dbg.field("samples_seen", &self.samples_seen);
        dbg.finish()
    }
}

// ---------------------------------------------------------------------------
// DiagnosticSource impl
// ---------------------------------------------------------------------------

impl crate::automl::DiagnosticSource for StackedEnsemble {
    /// Always returns `None`: the ensemble wraps heterogeneous learners, so
    /// there is no single underlying config to surface at this level.
    fn config_diagnostics(&self) -> Option<crate::automl::ConfigDiagnostics> {
        None
    }
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use crate::learner::SGBTLearner;
    use crate::learners::linear::StreamingLinearModel;
    use crate::SGBTConfig;

    /// Shared minimal SGBT config for tests.
    fn test_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(3)
            .n_bins(8)
            .build()
            .unwrap()
    }

    /// Create a pair of SGBT base learners as trait objects.
    fn sgbt_bases() -> Vec<Box<dyn StreamingLearner>> {
        vec![
            Box::new(SGBTLearner::from_config(test_config())),
            Box::new(SGBTLearner::from_config(test_config())),
        ]
    }

    /// Create a linear meta-learner as a trait object.
    fn linear_meta() -> Box<dyn StreamingLearner> {
        Box::new(StreamingLinearModel::new(0.01))
    }

    /// A fresh ensemble reports its learner count and zeroed state.
    #[test]
    fn test_creation() {
        let stack = StackedEnsemble::new(sgbt_bases(), linear_meta());
        assert_eq!(stack.n_base_learners(), 2);
        assert!(!stack.passthrough());
        assert_eq!(stack.n_samples_seen(), 0);
    }

    /// End-to-end smoke test: train on a pattern, then predict.
    #[test]
    fn test_train_and_predict() {
        let mut stack = StackedEnsemble::new(sgbt_bases(), linear_meta());

        // Train on a simple pattern.
        for i in 0..50 {
            let x = i as f64 * 0.1;
            stack.train(&[x, x * 2.0], x * 3.0);
        }

        assert_eq!(stack.n_samples_seen(), 50);

        // Prediction should be finite and non-trivial after training.
        let pred = stack.predict(&[1.0, 2.0]);
        assert!(
            pred.is_finite(),
            "prediction should be finite, got {}",
            pred
        );
    }

    #[test]
    fn test_temporal_holdout() {
        // Verify that the meta-learner sees pre-training predictions by
        // checking that base learner sample counts advance correctly:
        // after training the stack once, each base should have seen 1 sample.
        let mut stack = StackedEnsemble::new(sgbt_bases(), linear_meta());

        // Before training, base learners have seen 0 samples.
        for bp in &stack.base_learners {
            assert_eq!(bp.n_samples_seen(), 0);
        }

        // Train one sample through the stack.
        stack.train(&[1.0, 2.0], 3.0);

        // After training, each base learner has seen exactly 1 sample.
        // The temporal holdout guarantee is that the meta-learner was trained
        // on predictions made *before* this sample was ingested by the bases.
        for bp in &stack.base_learners {
            assert_eq!(bp.n_samples_seen(), 1);
        }
        assert_eq!(stack.meta_learner.n_samples_seen(), 1);
        assert_eq!(stack.n_samples_seen(), 1);

        // Train a second sample and verify counts advance together.
        stack.train(&[3.0, 4.0], 5.0);
        for bp in &stack.base_learners {
            assert_eq!(bp.n_samples_seen(), 2);
        }
        assert_eq!(stack.meta_learner.n_samples_seen(), 2);
        assert_eq!(stack.n_samples_seen(), 2);
    }

    #[test]
    fn test_passthrough() {
        // With passthrough=true, meta-features should include original features.
        // We can verify this indirectly: a passthrough stack with a linear meta
        // should produce a different prediction than a non-passthrough stack,
        // because the meta-learner sees a wider feature vector.
        let bases_a = sgbt_bases();
        let bases_b = sgbt_bases();

        let mut no_pass = StackedEnsemble::new(bases_a, linear_meta());
        let mut with_pass = StackedEnsemble::with_passthrough(bases_b, linear_meta(), true);

        assert!(!no_pass.passthrough());
        assert!(with_pass.passthrough());

        // Train both on the same data.
        for i in 0..30 {
            let x = i as f64 * 0.1;
            let features = [x, x * 2.0];
            let target = x * 3.0 + 1.0;
            no_pass.train(&features, target);
            with_pass.train(&features, target);
        }

        // Verify meta-feature dimensions differ by checking that build_meta_features
        // produces different-length vectors.
        let features = [1.0, 2.0];
        let base_preds = [0.5, 0.7]; // mock base predictions
        let meta_no = no_pass.build_meta_features(&features, &base_preds);
        let meta_yes = with_pass.build_meta_features(&features, &base_preds);

        assert_eq!(meta_no.len(), 2, "no passthrough: only base predictions");
        assert_eq!(
            meta_yes.len(),
            4,
            "passthrough: base predictions + original features"
        );
        // Layout check: base predictions first, then raw features.
        assert!(
            (meta_yes[2] - 1.0).abs() < 1e-12,
            "original features appended"
        );
        assert!(
            (meta_yes[3] - 2.0).abs() < 1e-12,
            "original features appended"
        );
    }

    #[test]
    fn test_base_predictions() {
        let mut stack = StackedEnsemble::new(sgbt_bases(), linear_meta());

        // Before training, base predictions should all be zero (untrained models).
        let preds = stack.base_predictions(&[1.0, 2.0]);
        assert_eq!(preds.len(), 2);
        for p in &preds {
            assert!(
                p.abs() < 1e-12,
                "untrained base should predict ~0, got {}",
                p
            );
        }

        // Train a few samples.
        for i in 0..20 {
            let x = i as f64;
            stack.train(&[x, x * 0.5], x * 2.0);
        }

        // Base predictions should still return the correct count.
        let preds_after = stack.base_predictions(&[5.0, 2.5]);
        assert_eq!(preds_after.len(), 2);
        for p in &preds_after {
            assert!(p.is_finite());
        }
    }

    /// Reset must clear the counter, all bases, and the meta-learner.
    #[test]
    fn test_reset() {
        let mut stack = StackedEnsemble::new(sgbt_bases(), linear_meta());

        // Train some data.
        for i in 0..30 {
            let x = i as f64 * 0.1;
            stack.train(&[x, x * 2.0], x * 3.0);
        }
        assert_eq!(stack.n_samples_seen(), 30);

        // Reset everything.
        stack.reset();
        assert_eq!(stack.n_samples_seen(), 0);

        // All base learners should be reset.
        for bp in &stack.base_learners {
            assert_eq!(bp.n_samples_seen(), 0);
        }

        // Meta-learner should be reset.
        assert_eq!(stack.meta_learner.n_samples_seen(), 0);

        // Predictions after reset should be near zero (untrained state).
        let pred = stack.predict(&[1.0, 2.0]);
        assert!(
            pred.abs() < 1e-12,
            "prediction after reset should be ~0, got {}",
            pred,
        );
    }

    #[test]
    fn test_n_samples_seen() {
        let mut stack = StackedEnsemble::new(sgbt_bases(), linear_meta());

        assert_eq!(stack.n_samples_seen(), 0);

        for i in 1..=10 {
            stack.train(&[i as f64], i as f64);
            assert_eq!(stack.n_samples_seen(), i);
        }

        // Weighted training also increments by 1 (sample count, not weight sum).
        stack.train_one(&[11.0], 11.0, 5.0);
        assert_eq!(stack.n_samples_seen(), 11);
    }

    #[test]
    fn test_trait_object() {
        // StackedEnsemble itself should work as Box<dyn StreamingLearner>,
        // enabling recursive stacking.
        let stack = StackedEnsemble::new(sgbt_bases(), linear_meta());
        let mut boxed: Box<dyn StreamingLearner> = Box::new(stack);

        boxed.train(&[1.0, 2.0], 3.0);
        assert_eq!(boxed.n_samples_seen(), 1);

        let pred = boxed.predict(&[1.0, 2.0]);
        assert!(pred.is_finite());

        boxed.reset();
        assert_eq!(boxed.n_samples_seen(), 0);
    }

    #[test]
    fn test_heterogeneous_bases() {
        // Mix SGBT and linear base learners -- the core polymorphism use case.
        let bases: Vec<Box<dyn StreamingLearner>> = vec![
            Box::new(SGBTLearner::from_config(test_config())),
            Box::new(StreamingLinearModel::new(0.01)),
            Box::new(StreamingLinearModel::ridge(0.01, 0.001)),
        ];
        let meta = linear_meta();

        let mut stack = StackedEnsemble::new(bases, meta);
        assert_eq!(stack.n_base_learners(), 3);

        // Train on a linear-ish pattern. Both SGBT and linear models should
        // contribute meaningful predictions.
        for i in 0..40 {
            let x = i as f64 * 0.1;
            stack.train(&[x, x * 0.5], 2.0 * x + 1.0);
        }

        assert_eq!(stack.n_samples_seen(), 40);

        let preds = stack.base_predictions(&[2.0, 1.0]);
        assert_eq!(preds.len(), 3);
        for p in &preds {
            assert!(p.is_finite(), "base prediction should be finite, got {}", p);
        }

        let final_pred = stack.predict(&[2.0, 1.0]);
        assert!(final_pred.is_finite());
    }

    /// `predict_batch` (trait default) must agree with per-row `predict`.
    #[test]
    fn test_predict_batch() {
        let mut stack = StackedEnsemble::new(sgbt_bases(), linear_meta());

        // Train enough samples for non-trivial predictions.
        for i in 0..30 {
            let x = i as f64 * 0.1;
            stack.train(&[x, x * 2.0], x * 3.0);
        }

        let rows: Vec<&[f64]> = vec![&[0.5, 1.0], &[1.5, 3.0], &[2.5, 5.0]];
        let batch = stack.predict_batch(&rows);

        // Batch results should exactly match individual predictions.
        assert_eq!(batch.len(), rows.len());
        for (i, row) in rows.iter().enumerate() {
            let individual = stack.predict(row);
            assert!(
                (batch[i] - individual).abs() < 1e-12,
                "batch[{}]={} != individual={}",
                i,
                batch[i],
                individual,
            );
        }
    }

    /// The manual Debug impl reports the summary fields (not learner internals).
    #[test]
    fn test_debug_impl() {
        let stack = StackedEnsemble::new(sgbt_bases(), linear_meta());
        let debug_str = format!("{:?}", stack);
        assert!(debug_str.contains("StackedEnsemble"));
        assert!(debug_str.contains("n_base_learners: 2"));
        assert!(debug_str.contains("passthrough: false"));
        assert!(debug_str.contains("samples_seen: 0"));
    }
}