use super::multi_seed::{MultiSeedConfig, MultiSeedTrainer, TrainedMember};
use super::selection::{HillClimbingSelector, SelectionConfig};
use super::stacking::{RidgeStacker, SimpleAverageStacker, StackingConfig};
use super::traits::Stacker;
use crate::booster::GBDTConfig;
use crate::dataset::BinnedDataset;
use crate::tuner::Metric;
use crate::Result;
/// Summary figures describing a stacked ensemble, as produced by `StackedEnsemble::stats`.
#[derive(Debug, Clone)]
pub struct EnsembleStats {
    /// Number of members in the ensemble.
    pub n_members: usize,
    /// Owned copy of the stacker's weights, or `None` when the stacker reports none.
    pub weights: Option<Vec<f32>>,
    /// Out-of-fold metric of the ensemble itself.
    pub oof_metric: f32,
    /// Out-of-fold metric of every member, in member order.
    pub member_metrics: Vec<f32>,
    /// Best out-of-fold metric among the individual members: the minimum when the
    /// metric is lower-is-better, the maximum otherwise.
    pub best_individual: f32,
    /// Gain of the ensemble over `best_individual`, oriented so that a positive value
    /// means the ensemble scored better than its best member.
    pub improvement: f32,
}
/// A set of trained members together with the stacker that blends their predictions.
pub struct StackedEnsemble {
    /// Trained models whose predictions are fed to the stacker.
    members: Vec<TrainedMember>,
    /// Strategy that combines the per-member predictions into one output.
    stacker: Box<dyn Stacker>,
    /// Out-of-fold metric recorded for the ensemble at construction time.
    oof_metric: f32,
    /// Metric the scores are expressed in; its direction drives `stats`.
    metric: Metric,
}
impl StackedEnsemble {
    /// Assembles an ensemble from already-trained `members` and a `stacker`.
    ///
    /// `oof_metric` is stored unchanged and returned by [`Self::oof_metric`]. `metric`
    /// determines (through `lower_is_better()`) how [`Self::stats`] picks the best member
    /// and orients the improvement figure.
    pub fn new(
        members: Vec<TrainedMember>,
        stacker: Box<dyn Stacker>,
        oof_metric: f32,
        metric: Metric,
    ) -> Self {
        Self {
            members,
            stacker,
            oof_metric,
            metric,
        }
    }

    /// Number of members in the ensemble.
    pub fn n_members(&self) -> usize {
        self.members.len()
    }

    /// Weights reported by the stacker, or `None` if it reports none.
    pub fn weights(&self) -> Option<&[f32]> {
        self.stacker.weights()
    }

    /// Out-of-fold metric recorded for the ensemble when it was constructed.
    pub fn oof_metric(&self) -> f32 {
        self.oof_metric
    }

    /// Summarises the ensemble against its individual members.
    ///
    /// `best_individual` is the minimum member metric when the metric is lower-is-better
    /// and the maximum otherwise. Member metrics that are NaN are skipped for that
    /// purpose (they are still listed in `member_metrics`). With no usable member metric
    /// the value falls back to `f32::INFINITY` (lower-is-better) or `f32::NEG_INFINITY`.
    ///
    /// `improvement` is oriented so that a positive value means the ensemble beats its
    /// best individual member.
    pub fn stats(&self) -> EnsembleStats {
        let lower_is_better = self.metric.lower_is_better();
        let member_metrics: Vec<f32> = self.members.iter().map(|m| m.oof_metric).collect();

        // `f32::min` / `f32::max` return the other operand when one side is NaN, so
        // folding from the worst possible score ignores NaN metrics instead of
        // panicking on them (the previous `partial_cmp(..).unwrap()` did panic).
        let best_individual = if lower_is_better {
            member_metrics
                .iter()
                .copied()
                .fold(f32::INFINITY, f32::min)
        } else {
            member_metrics
                .iter()
                .copied()
                .fold(f32::NEG_INFINITY, f32::max)
        };

        let improvement = if lower_is_better {
            best_individual - self.oof_metric
        } else {
            self.oof_metric - best_individual
        };

        EnsembleStats {
            n_members: self.members.len(),
            weights: self.stacker.weights().map(|w| w.to_vec()),
            oof_metric: self.oof_metric,
            member_metrics,
            best_individual,
            improvement,
        }
    }

    /// Predicts for every row of `dataset` by running each member and letting the
    /// stacker combine the per-member outputs.
    ///
    /// An ensemble without members yields one `0.0` per dataset row.
    pub fn predict(&self, dataset: &BinnedDataset) -> Vec<f32> {
        if self.members.is_empty() {
            return vec![0.0; dataset.num_rows()];
        }

        let per_member: Vec<Vec<f32>> = self
            .members
            .iter()
            .map(|m| m.model.predict(dataset))
            .collect();
        self.stacker.combine(&per_member)
    }

    /// Predicts from raw (un-binned) `features` by running each member's `predict_raw`
    /// and letting the stacker combine the per-member outputs.
    ///
    /// An ensemble without members yields an empty vector; unlike [`Self::predict`], no
    /// row count is inferred from the flat feature slice here.
    pub fn predict_raw(&self, features: &[f64]) -> Vec<f32> {
        if self.members.is_empty() {
            return Vec::new();
        }

        let per_member: Vec<Vec<f32>> = self
            .members
            .iter()
            .map(|m| m.model.predict_raw(features))
            .collect();
        self.stacker.combine(&per_member)
    }

    /// Read-only view of the ensemble's members.
    pub fn members(&self) -> &[TrainedMember] {
        &self.members
    }
}
/// Fluent builder that trains, selects and stacks members into a `StackedEnsemble`.
pub struct EnsembleBuilder {
    /// Booster configuration handed to the multi-seed trainer.
    base_config: GBDTConfig,
    /// Seed / fold / parallelism settings for member training.
    multi_seed: MultiSeedConfig,
    /// Settings for the hill-climbing member selection.
    selection: SelectionConfig,
    /// Settings passed to the ridge stacker.
    stacking: StackingConfig,
    /// Explicit evaluation metric; `None` means "pick a default at build time".
    metric: Option<Metric>,
    /// `true` to stack with `RidgeStacker`, `false` to use `SimpleAverageStacker`.
    use_ridge: bool,
}
impl EnsembleBuilder {
pub fn new(base_config: GBDTConfig) -> Self {
Self {
base_config,
multi_seed: MultiSeedConfig::default(),
selection: SelectionConfig::default(),
stacking: StackingConfig::default(),
metric: None,
use_ridge: true,
}
}
pub fn with_n_seeds(mut self, n: usize) -> Self {
self.multi_seed.n_seeds = n;
self
}
pub fn with_base_seed(mut self, seed: u64) -> Self {
self.multi_seed.base_seed = seed;
self
}
pub fn with_n_folds(mut self, n: usize) -> Self {
self.multi_seed.n_folds = n;
self
}
pub fn with_parallel(mut self, parallel: bool) -> Self {
self.multi_seed.parallel = parallel;
self
}
pub fn with_max_models(mut self, max: usize) -> Self {
self.selection.max_models = max;
self
}
pub fn with_min_improvement(mut self, min: f32) -> Self {
self.selection.min_improvement = min;
self
}
pub fn with_ridge_alpha(mut self, alpha: f32) -> Self {
self.stacking.alpha = alpha;
self
}
pub fn with_rank_transform(mut self, enabled: bool) -> Self {
self.stacking.rank_transform = enabled;
self
}
pub fn with_simple_average(mut self) -> Self {
self.use_ridge = false;
self
}
pub fn with_metric(mut self, metric: Metric) -> Self {
self.metric = Some(metric);
self
}
pub fn with_multi_seed_config(mut self, config: MultiSeedConfig) -> Self {
self.multi_seed = config;
self
}
pub fn with_selection_config(mut self, config: SelectionConfig) -> Self {
self.selection = config;
self
}
pub fn with_stacking_config(mut self, config: StackingConfig) -> Self {
self.stacking = config;
self
}
pub fn build(self, dataset: &BinnedDataset) -> Result<StackedEnsemble> {
let targets = dataset.targets();
let metric = self.metric.unwrap_or({
match self.base_config.loss_type {
crate::booster::LossType::BinaryLogLoss => Metric::BinaryLogLoss,
crate::booster::LossType::MultiClassLogLoss { num_classes } => {
Metric::MultiClassLogLoss {
n_classes: num_classes,
}
}
_ => Metric::Mse,
}
});
let trainer =
MultiSeedTrainer::new(self.base_config.clone(), self.multi_seed).with_metric(metric);
let all_members = trainer.train(dataset)?;
let selector = HillClimbingSelector::new(self.selection, metric);
let selected_indices = selector.select(&all_members, targets);
let members: Vec<TrainedMember> = if selected_indices.is_empty() {
all_members
} else {
selected_indices
.iter()
.map(|&i| all_members[i].clone())
.collect()
};
let oof_preds: Vec<Vec<f32>> = members.iter().map(|m| m.oof_preds.clone()).collect();
let mut stacker: Box<dyn Stacker> = if self.use_ridge {
Box::new(RidgeStacker::new(self.stacking))
} else {
Box::new(SimpleAverageStacker::new())
};
stacker.fit(&oof_preds, targets);
let blended = stacker.combine(&oof_preds);
let oof_metric = metric.compute(&blended, targets);
Ok(StackedEnsemble::new(members, stacker, oof_metric, metric))
}
pub fn build_from_members(
self,
members: Vec<TrainedMember>,
targets: &[f32],
) -> Result<StackedEnsemble> {
let metric = self.metric.unwrap_or(Metric::Mse);
let selector = HillClimbingSelector::new(self.selection, metric);
let selected_indices = selector.select(&members, targets);
let selected: Vec<TrainedMember> = if selected_indices.is_empty() {
members
} else {
selected_indices
.iter()
.map(|&i| members[i].clone())
.collect()
};
let oof_preds: Vec<Vec<f32>> = selected.iter().map(|m| m.oof_preds.clone()).collect();
let mut stacker: Box<dyn Stacker> = if self.use_ridge {
Box::new(RidgeStacker::new(self.stacking))
} else {
Box::new(SimpleAverageStacker::new())
};
stacker.fit(&oof_preds, targets);
let blended = stacker.combine(&oof_preds);
let oof_metric = metric.compute(&blended, targets);
Ok(StackedEnsemble::new(selected, stacker, oof_metric, metric))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created builder uses five seeds and the ridge stacker.
    #[test]
    fn test_ensemble_builder_defaults() {
        let fresh = EnsembleBuilder::new(GBDTConfig::new());

        assert_eq!(fresh.multi_seed.n_seeds, 5);
        assert!(fresh.use_ridge);
    }

    /// Each fluent setter writes through to the matching configuration field.
    #[test]
    fn test_ensemble_builder_fluent_api() {
        let configured = EnsembleBuilder::new(GBDTConfig::new())
            .with_n_seeds(10)
            .with_base_seed(123)
            .with_ridge_alpha(5.0)
            .with_max_models(20)
            .with_rank_transform(true);

        assert_eq!(configured.multi_seed.n_seeds, 10);
        assert_eq!(configured.multi_seed.base_seed, 123);
        assert!((configured.stacking.alpha - 5.0).abs() < 1e-6);
        assert_eq!(configured.selection.max_models, 20);
        assert!(configured.stacking.rank_transform);
    }

    /// `with_simple_average` turns the ridge stacker off.
    #[test]
    fn test_ensemble_builder_simple_average() {
        let averaged = EnsembleBuilder::new(GBDTConfig::new()).with_simple_average();

        assert!(!averaged.use_ridge);
    }
}