use alloc::boxed::Box;
use alloc::vec::Vec;
use core::fmt;
use crate::learner::StreamingLearner;
/// Streaming stacked ensemble: a set of base learners whose predictions are
/// fed to a single meta-learner that produces the final prediction.
///
/// All learners are held as `Box<dyn StreamingLearner>`, so base learners of
/// different concrete types can be mixed in one stack.
pub struct StackedEnsemble {
/// Base learners; each contributes one prediction per sample to the
/// meta-learner's input vector, in this order.
base_learners: Vec<Box<dyn StreamingLearner>>,
/// Learner trained on (and queried with) the vector of base predictions,
/// optionally followed by the raw features.
meta_learner: Box<dyn StreamingLearner>,
/// When `true`, the original features are appended after the base
/// predictions in the meta-learner's input.
passthrough: bool,
/// Number of `train_one` calls since construction or the last `reset`.
samples_seen: u64,
}
impl StackedEnsemble {
    /// Creates a stack whose meta-learner sees only the base predictions
    /// (feature passthrough disabled).
    #[inline]
    pub fn new(
        base_learners: Vec<Box<dyn StreamingLearner>>,
        meta_learner: Box<dyn StreamingLearner>,
    ) -> Self {
        Self::with_passthrough(base_learners, meta_learner, false)
    }

    /// Creates a stack with an explicit passthrough setting.
    ///
    /// With `passthrough == true` the meta-learner's input is the base
    /// predictions followed by the original features; otherwise it is the
    /// base predictions alone. The sample counter starts at zero.
    #[inline]
    pub fn with_passthrough(
        base_learners: Vec<Box<dyn StreamingLearner>>,
        meta_learner: Box<dyn StreamingLearner>,
        passthrough: bool,
    ) -> Self {
        let samples_seen = 0;
        Self {
            base_learners,
            meta_learner,
            passthrough,
            samples_seen,
        }
    }

    /// Number of base learners in the stack.
    #[inline]
    pub fn n_base_learners(&self) -> usize {
        let Self { base_learners, .. } = self;
        base_learners.len()
    }

    /// Whether the original features are forwarded to the meta-learner.
    #[inline]
    pub fn passthrough(&self) -> bool {
        let Self { passthrough, .. } = self;
        *passthrough
    }

    /// Returns one prediction per base learner for `features`, in the same
    /// order as the base learners were supplied.
    #[inline]
    pub fn base_predictions(&self, features: &[f64]) -> Vec<f64> {
        let mut predictions = Vec::with_capacity(self.base_learners.len());
        for base in &self.base_learners {
            predictions.push(base.predict(features));
        }
        predictions
    }

    /// Assembles the meta-learner's input vector.
    ///
    /// Without passthrough this is a copy of `base_preds`; with passthrough
    /// it is `base_preds` followed by `features`.
    fn build_meta_features(&self, features: &[f64], base_preds: &[f64]) -> Vec<f64> {
        if !self.passthrough {
            return base_preds.to_vec();
        }

        // Reserve the full length up front so the two appends never reallocate.
        let total_len = base_preds.len() + features.len();
        let mut stacked = Vec::with_capacity(total_len);
        stacked.extend_from_slice(base_preds);
        stacked.extend_from_slice(features);
        stacked
    }
}
impl StreamingLearner for StackedEnsemble {
    /// Trains the whole stack on one sample.
    ///
    /// The base learners are queried *before* they are updated, so the
    /// meta-learner is fitted on predictions the base learners made without
    /// having seen this sample. Only afterwards is each base learner trained
    /// on the sample. The sample counter advances by one per call,
    /// independent of `weight`.
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
        // Step 1: predictions from the not-yet-updated base learners.
        let held_out_preds = self.base_predictions(features);

        // Step 2: fit the meta-learner on those predictions.
        let meta_input = self.build_meta_features(features, &held_out_preds);
        self.meta_learner.train_one(&meta_input, target, weight);

        // Step 3: now let every base learner see the sample.
        self.base_learners
            .iter_mut()
            .for_each(|base| base.train_one(features, target, weight));

        self.samples_seen += 1;
    }

    /// Predicts by running the base learners and passing their outputs
    /// (plus the raw features when passthrough is enabled) to the
    /// meta-learner.
    #[inline]
    fn predict(&self, features: &[f64]) -> f64 {
        let meta_input = {
            let preds = self.base_predictions(features);
            self.build_meta_features(features, &preds)
        };
        self.meta_learner.predict(&meta_input)
    }

    /// Number of samples this stack has been trained on since construction
    /// or the last `reset`.
    #[inline]
    fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    /// Resets every base learner, then the meta-learner, and zeroes the
    /// sample counter.
    fn reset(&mut self) {
        self.base_learners.iter_mut().for_each(|base| base.reset());
        self.meta_learner.reset();
        self.samples_seen = 0;
    }
}
impl fmt::Debug for StackedEnsemble {
    /// Formats a compact summary of the stack.
    ///
    /// The boxed learners themselves are not printed; only the number of
    /// base learners, the passthrough flag and the sample counter appear.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            base_learners,
            passthrough,
            samples_seen,
            ..
        } = self;
        let n_base_learners = base_learners.len();

        let mut summary = f.debug_struct("StackedEnsemble");
        summary.field("n_base_learners", &n_base_learners);
        summary.field("passthrough", passthrough);
        summary.field("samples_seen", samples_seen);
        summary.finish()
    }
}
// Unit tests for `StackedEnsemble`. The module is compiled only when both
// `test` and the `_stacked_tests_disabled` feature are active.
#[cfg(all(test, feature = "_stacked_tests_disabled"))]
mod tests {
    use super::*;
    use crate::learner::SGBTLearner;
    use crate::learners::linear::StreamingLinearModel;
    use crate::SGBTConfig;

    /// Small SGBT configuration shared by all tests.
    fn test_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(3)
            .n_bins(8)
            .build()
            .unwrap()
    }

    /// Two identically configured SGBT base learners.
    fn sgbt_bases() -> Vec<Box<dyn StreamingLearner>> {
        vec![
            Box::new(SGBTLearner::from_config(test_config())),
            Box::new(SGBTLearner::from_config(test_config())),
        ]
    }

    /// Linear model used as the meta-learner.
    fn linear_meta() -> Box<dyn StreamingLearner> {
        Box::new(StreamingLinearModel::new(0.01))
    }

    #[test]
    fn test_creation() {
        let stack = StackedEnsemble::new(sgbt_bases(), linear_meta());
        assert_eq!(stack.n_base_learners(), 2);
        assert!(!stack.passthrough());
        assert_eq!(stack.n_samples_seen(), 0);
    }

    #[test]
    fn test_train_and_predict() {
        let mut stack = StackedEnsemble::new(sgbt_bases(), linear_meta());
        for i in 0..50 {
            let x = i as f64 * 0.1;
            stack.train(&[x, x * 2.0], x * 3.0);
        }
        assert_eq!(stack.n_samples_seen(), 50);
        let pred = stack.predict(&[1.0, 2.0]);
        assert!(
            pred.is_finite(),
            "prediction should be finite, got {}",
            pred
        );
    }

    /// Every learner in the stack (bases and meta) must advance by exactly
    /// one sample per training call.
    #[test]
    fn test_temporal_holdout() {
        let mut stack = StackedEnsemble::new(sgbt_bases(), linear_meta());
        for bp in &stack.base_learners {
            assert_eq!(bp.n_samples_seen(), 0);
        }

        stack.train(&[1.0, 2.0], 3.0);
        for bp in &stack.base_learners {
            assert_eq!(bp.n_samples_seen(), 1);
        }
        assert_eq!(stack.meta_learner.n_samples_seen(), 1);
        assert_eq!(stack.n_samples_seen(), 1);

        stack.train(&[3.0, 4.0], 5.0);
        for bp in &stack.base_learners {
            assert_eq!(bp.n_samples_seen(), 2);
        }
        assert_eq!(stack.meta_learner.n_samples_seen(), 2);
        assert_eq!(stack.n_samples_seen(), 2);
    }

    /// Passthrough appends the raw features after the base predictions.
    #[test]
    fn test_passthrough() {
        let bases_a = sgbt_bases();
        let bases_b = sgbt_bases();
        let mut no_pass = StackedEnsemble::new(bases_a, linear_meta());
        let mut with_pass = StackedEnsemble::with_passthrough(bases_b, linear_meta(), true);
        assert!(!no_pass.passthrough());
        assert!(with_pass.passthrough());

        for i in 0..30 {
            let x = i as f64 * 0.1;
            let features = [x, x * 2.0];
            let target = x * 3.0 + 1.0;
            no_pass.train(&features, target);
            with_pass.train(&features, target);
        }

        let features = [1.0, 2.0];
        let base_preds = [0.5, 0.7];
        let meta_no = no_pass.build_meta_features(&features, &base_preds);
        let meta_yes = with_pass.build_meta_features(&features, &base_preds);
        assert_eq!(meta_no.len(), 2, "no passthrough: only base predictions");
        assert_eq!(
            meta_yes.len(),
            4,
            "passthrough: base predictions + original features"
        );
        assert!(
            crate::math::abs(meta_yes[2] - 1.0) < 1e-12,
            "original features appended"
        );
        assert!(
            crate::math::abs(meta_yes[3] - 2.0) < 1e-12,
            "original features appended"
        );
    }

    #[test]
    fn test_base_predictions() {
        let mut stack = StackedEnsemble::new(sgbt_bases(), linear_meta());
        let preds = stack.base_predictions(&[1.0, 2.0]);
        assert_eq!(preds.len(), 2);
        for p in &preds {
            // `p` is `&f64`; dereference so the call matches the by-value
            // `f64` argument used at every other `crate::math::abs` call site.
            assert!(
                crate::math::abs(*p) < 1e-12,
                "untrained base should predict ~0, got {}",
                p
            );
        }

        for i in 0..20 {
            let x = i as f64;
            stack.train(&[x, x * 0.5], x * 2.0);
        }
        let preds_after = stack.base_predictions(&[5.0, 2.5]);
        assert_eq!(preds_after.len(), 2);
        for p in &preds_after {
            assert!(p.is_finite());
        }
    }

    /// `reset` must clear the stack's counter and every contained learner.
    #[test]
    fn test_reset() {
        let mut stack = StackedEnsemble::new(sgbt_bases(), linear_meta());
        for i in 0..30 {
            let x = i as f64 * 0.1;
            stack.train(&[x, x * 2.0], x * 3.0);
        }
        assert_eq!(stack.n_samples_seen(), 30);

        stack.reset();
        assert_eq!(stack.n_samples_seen(), 0);
        for bp in &stack.base_learners {
            assert_eq!(bp.n_samples_seen(), 0);
        }
        assert_eq!(stack.meta_learner.n_samples_seen(), 0);

        let pred = stack.predict(&[1.0, 2.0]);
        assert!(
            crate::math::abs(pred) < 1e-12,
            "prediction after reset should be ~0, got {}",
            pred,
        );
    }

    /// The counter tracks calls, not the sample weight.
    #[test]
    fn test_n_samples_seen() {
        let mut stack = StackedEnsemble::new(sgbt_bases(), linear_meta());
        assert_eq!(stack.n_samples_seen(), 0);
        for i in 1..=10 {
            stack.train(&[i as f64], i as f64);
            assert_eq!(stack.n_samples_seen(), i);
        }
        // A weight of 5.0 still counts as a single sample.
        stack.train_one(&[11.0], 11.0, 5.0);
        assert_eq!(stack.n_samples_seen(), 11);
    }

    /// The stack must be usable behind `Box<dyn StreamingLearner>`.
    #[test]
    fn test_trait_object() {
        let stack = StackedEnsemble::new(sgbt_bases(), linear_meta());
        let mut boxed: Box<dyn StreamingLearner> = Box::new(stack);
        boxed.train(&[1.0, 2.0], 3.0);
        assert_eq!(boxed.n_samples_seen(), 1);
        let pred = boxed.predict(&[1.0, 2.0]);
        assert!(pred.is_finite());
        boxed.reset();
        assert_eq!(boxed.n_samples_seen(), 0);
    }

    /// Base learners of different concrete types can share one stack.
    #[test]
    fn test_heterogeneous_bases() {
        let bases: Vec<Box<dyn StreamingLearner>> = vec![
            Box::new(SGBTLearner::from_config(test_config())),
            Box::new(StreamingLinearModel::new(0.01)),
            Box::new(StreamingLinearModel::ridge(0.01, 0.001)),
        ];
        let meta = linear_meta();
        let mut stack = StackedEnsemble::new(bases, meta);
        assert_eq!(stack.n_base_learners(), 3);

        for i in 0..40 {
            let x = i as f64 * 0.1;
            stack.train(&[x, x * 0.5], 2.0 * x + 1.0);
        }
        assert_eq!(stack.n_samples_seen(), 40);

        let preds = stack.base_predictions(&[2.0, 1.0]);
        assert_eq!(preds.len(), 3);
        for p in &preds {
            assert!(p.is_finite(), "base prediction should be finite, got {}", p);
        }
        let final_pred = stack.predict(&[2.0, 1.0]);
        assert!(final_pred.is_finite());
    }

    /// Batch prediction must agree with row-by-row prediction.
    #[test]
    fn test_predict_batch() {
        let mut stack = StackedEnsemble::new(sgbt_bases(), linear_meta());
        for i in 0..30 {
            let x = i as f64 * 0.1;
            stack.train(&[x, x * 2.0], x * 3.0);
        }

        let rows: Vec<&[f64]> = vec![&[0.5, 1.0], &[1.5, 3.0], &[2.5, 5.0]];
        let batch = stack.predict_batch(&rows);
        assert_eq!(batch.len(), rows.len());
        for (i, row) in rows.iter().enumerate() {
            let individual = stack.predict(row);
            assert!(
                crate::math::abs(batch[i] - individual) < 1e-12,
                "batch[{}]={} != individual={}",
                i,
                batch[i],
                individual,
            );
        }
    }

    #[test]
    fn test_debug_impl() {
        let stack = StackedEnsemble::new(sgbt_bases(), linear_meta());
        let debug_str = format!("{:?}", stack);
        assert!(debug_str.contains("StackedEnsemble"));
        assert!(debug_str.contains("n_base_learners: 2"));
        assert!(debug_str.contains("passthrough: false"));
        assert!(debug_str.contains("samples_seen: 0"));
    }
}