use crate::{Model, TrainError, TrainResult};
use scirs2_core::ndarray::Array2;
use std::collections::HashMap;
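/// Common interface for ensembles that combine the predictions of multiple models.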
pub trait Ensemble {
fn predict(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>>;
fn num_models(&self) -> usize;
}
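/// Strategy used by [`VotingEnsemble`] to combine per-model predictions:
/// `Hard` counts (optionally weighted) argmax votes per sample, `Soft` averages
/// the models' raw class scores.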
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum VotingMode {
Hard,
Soft,
}
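/// Classification ensemble that combines member predictions by hard or soft voting,
/// with optional per-model weights that must sum to 1.0 (see `with_weights`).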
#[derive(Debug)]
pub struct VotingEnsemble<M: Model> {
models: Vec<M>,
mode: VotingMode,
weights: Option<Vec<f64>>,
}
impl<M: Model> VotingEnsemble<M> {
pub fn new(models: Vec<M>, mode: VotingMode) -> TrainResult<Self> {
if models.is_empty() {
return Err(TrainError::InvalidParameter(
"Ensemble must have at least one model".to_string(),
));
}
Ok(Self {
models,
mode,
weights: None,
})
}
pub fn with_weights(mut self, weights: Vec<f64>) -> TrainResult<Self> {
if weights.len() != self.models.len() {
return Err(TrainError::InvalidParameter(
"Number of weights must match number of models".to_string(),
));
}
let sum: f64 = weights.iter().sum();
if (sum - 1.0).abs() > 1e-6 {
return Err(TrainError::InvalidParameter(
"Weights must sum to 1.0".to_string(),
));
}
self.weights = Some(weights);
Ok(self)
}
pub fn mode(&self) -> VotingMode {
self.mode
}
}
impl<M: Model> Ensemble for VotingEnsemble<M> {
fn predict(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>> {
let batch_size = input.nrows();
let mut all_predictions = Vec::with_capacity(self.models.len());
for model in &self.models {
let pred = model.forward(&input.view())?;
all_predictions.push(pred);
}
let num_classes = all_predictions[0].ncols();
let mut ensemble_pred = Array2::zeros((batch_size, num_classes));
match self.mode {
VotingMode::Hard => {
for i in 0..batch_size {
let mut votes = vec![0.0; num_classes];
for (model_idx, pred) in all_predictions.iter().enumerate() {
let row = pred.row(i);
let class_idx = row
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| {
a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
})
.map(|(idx, _)| idx)
.unwrap_or(0);
let weight = self.weights.as_ref().map(|w| w[model_idx]).unwrap_or(1.0);
votes[class_idx] += weight;
}
let max_votes = votes.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
let winning_class = votes
.iter()
.position(|&v| (v - max_votes).abs() < 1e-10)
.unwrap_or(0);
ensemble_pred[[i, winning_class]] = 1.0;
}
}
VotingMode::Soft => {
for i in 0..batch_size {
for j in 0..num_classes {
let mut weighted_sum = 0.0;
for (model_idx, pred) in all_predictions.iter().enumerate() {
let weight = self.weights.as_ref().map(|w| w[model_idx]).unwrap_or(1.0);
weighted_sum += pred[[i, j]] * weight;
}
let normalizer = if self.weights.is_some() {
1.0
} else {
self.models.len() as f64
};
ensemble_pred[[i, j]] = weighted_sum / normalizer;
}
}
}
}
Ok(ensemble_pred)
}
fn num_models(&self) -> usize {
self.models.len()
}
}
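/// Ensemble that averages raw model outputs, optionally with per-model weights
/// that are normalized to sum to 1.0.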
#[derive(Debug)]
pub struct AveragingEnsemble<M: Model> {
models: Vec<M>,
weights: Option<Vec<f64>>,
}
impl<M: Model> AveragingEnsemble<M> {
pub fn new(models: Vec<M>) -> TrainResult<Self> {
if models.is_empty() {
return Err(TrainError::InvalidParameter(
"Ensemble must have at least one model".to_string(),
));
}
Ok(Self {
models,
weights: None,
})
}
pub fn with_weights(mut self, weights: Vec<f64>) -> TrainResult<Self> {
if weights.len() != self.models.len() {
return Err(TrainError::InvalidParameter(
"Number of weights must match number of models".to_string(),
));
}
let sum: f64 = weights.iter().sum();
if sum <= 0.0 {
return Err(TrainError::InvalidParameter(
"Weights must sum to a positive value".to_string(),
));
}
let normalized_weights: Vec<f64> = weights.iter().map(|w| w / sum).collect();
self.weights = Some(normalized_weights);
Ok(self)
}
}
impl<M: Model> Ensemble for AveragingEnsemble<M> {
fn predict(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>> {
let mut all_predictions = Vec::with_capacity(self.models.len());
for model in &self.models {
let pred = model.forward(&input.view())?;
all_predictions.push(pred);
}
let shape = all_predictions[0].raw_dim();
let mut ensemble_pred = Array2::zeros(shape);
for (model_idx, pred) in all_predictions.iter().enumerate() {
let weight = self.weights.as_ref().map(|w| w[model_idx]).unwrap_or(1.0);
for i in 0..pred.nrows() {
for j in 0..pred.ncols() {
ensemble_pred[[i, j]] += pred[[i, j]] * weight;
}
}
}
if self.weights.is_none() {
ensemble_pred /= self.models.len() as f64;
}
Ok(ensemble_pred)
}
fn num_models(&self) -> usize {
self.models.len()
}
}
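/// Two-level ensemble: base model outputs are concatenated into meta-features,
/// which are then fed to a meta-model for the final prediction.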
#[derive(Debug)]
pub struct StackingEnsemble<M: Model, Meta: Model> {
base_models: Vec<M>,
meta_model: Meta,
}
impl<M: Model, Meta: Model> StackingEnsemble<M, Meta> {
pub fn new(base_models: Vec<M>, meta_model: Meta) -> TrainResult<Self> {
if base_models.is_empty() {
return Err(TrainError::InvalidParameter(
"Ensemble must have at least one base model".to_string(),
));
}
Ok(Self {
base_models,
meta_model,
})
}
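/// Concatenates the outputs of all base models column-wise into a single
/// meta-feature matrix of shape `(batch_size, num_base_models * num_outputs_per_model)`.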
pub fn generate_meta_features(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>> {
let batch_size = input.nrows();
let mut all_predictions = Vec::with_capacity(self.base_models.len());
for model in &self.base_models {
let pred = model.forward(&input.view())?;
all_predictions.push(pred);
}
let num_features_per_model = all_predictions[0].ncols();
let total_features = self.base_models.len() * num_features_per_model;
let mut meta_features = Array2::zeros((batch_size, total_features));
for (model_idx, pred) in all_predictions.iter().enumerate() {
let start_col = model_idx * num_features_per_model;
for i in 0..batch_size {
for j in 0..num_features_per_model {
meta_features[[i, start_col + j]] = pred[[i, j]];
}
}
}
Ok(meta_features)
}
}
impl<M: Model, Meta: Model> Ensemble for StackingEnsemble<M, Meta> {
fn predict(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>> {
let meta_features = self.generate_meta_features(input)?;
self.meta_model.forward(&meta_features.view())
}
fn num_models(&self) -> usize {
self.base_models.len() + 1
}
}
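/// Utility for bagging: generates bootstrap sample indices and the corresponding
/// out-of-bag (OOB) indices for each estimator.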
#[derive(Debug)]
pub struct BaggingHelper {
pub n_estimators: usize,
pub random_seed: u64,
}
impl BaggingHelper {
pub fn new(n_estimators: usize, random_seed: u64) -> TrainResult<Self> {
if n_estimators == 0 {
return Err(TrainError::InvalidParameter(
"n_estimators must be positive".to_string(),
));
}
Ok(Self {
n_estimators,
random_seed,
})
}
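/// Samples `n_samples` indices with replacement, seeded deterministically from
/// `random_seed + estimator_idx` (wrapping) so each estimator gets a reproducible bootstrap.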
pub fn generate_bootstrap_indices(&self, n_samples: usize, estimator_idx: usize) -> Vec<usize> {
#[allow(unused_imports)]
use scirs2_core::random::{Rng, SeedableRng, StdRng};
let seed = self.random_seed.wrapping_add(estimator_idx as u64);
let mut rng = StdRng::seed_from_u64(seed);
(0..n_samples)
.map(|_| rng.gen_range(0..n_samples))
.collect()
}
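/// Returns the indices in `0..n_samples` that do not appear in `bootstrap_indices`.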
pub fn get_oob_indices(&self, n_samples: usize, bootstrap_indices: &[usize]) -> Vec<usize> {
let bootstrap_set: std::collections::HashSet<usize> =
bootstrap_indices.iter().cloned().collect();
(0..n_samples)
.filter(|idx| !bootstrap_set.contains(idx))
.collect()
}
}
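/// Weight-space ensemble ("model soup"): averages the parameters of several
/// trained models into a single set of weights instead of averaging predictions.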
#[derive(Debug, Clone)]
pub struct ModelSoup {
weights: HashMap<String, Array2<f64>>,
num_models: usize,
recipe: SoupRecipe,
}
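/// Recipe used to build a [`ModelSoup`].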
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SoupRecipe {
Uniform,
Greedy,
Weighted,
}
impl ModelSoup {
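/// Builds a soup by averaging every parameter uniformly across all models.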
pub fn uniform_soup(model_weights: Vec<HashMap<String, Array2<f64>>>) -> TrainResult<Self> {
if model_weights.is_empty() {
return Err(TrainError::InvalidParameter(
"At least one model required for soup".to_string(),
));
}
let num_models = model_weights.len();
let mut averaged_weights = HashMap::new();
let param_names: Vec<String> = model_weights[0].keys().cloned().collect();
for param_name in param_names {
let shape = model_weights[0][&param_name].raw_dim();
let mut averaged_param = Array2::zeros(shape);
for model_weight in &model_weights {
if let Some(param) = model_weight.get(&param_name) {
averaged_param += param;
} else {
return Err(TrainError::InvalidParameter(format!(
"Parameter '{}' not found in all models",
param_name
)));
}
}
averaged_param /= num_models as f64;
averaged_weights.insert(param_name, averaged_param);
}
Ok(Self {
weights: averaged_weights,
num_models,
recipe: SoupRecipe::Uniform,
})
}
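/// Builds a soup greedily: starting from the model with the best validation
/// accuracy, remaining models are added while the estimated soup accuracy improves.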
pub fn greedy_soup(
model_weights: Vec<HashMap<String, Array2<f64>>>,
val_accuracies: Vec<f64>,
) -> TrainResult<Self> {
if model_weights.is_empty() {
return Err(TrainError::InvalidParameter(
"At least one model required for soup".to_string(),
));
}
if model_weights.len() != val_accuracies.len() {
return Err(TrainError::InvalidParameter(
"Number of models must match number of validation accuracies".to_string(),
));
}
let best_idx = val_accuracies
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map(|(idx, _)| idx)
.unwrap_or(0);
let mut soup_indices = vec![best_idx];
let mut best_accuracy = val_accuracies[best_idx];
// Greedily add models whose estimated contribution improves the soup. The
// candidate soup accuracy is approximated as the mean of the current soup
// accuracy and the candidate model's validation accuracy.
loop {
let mut best_addition = None;
let mut best_new_accuracy = best_accuracy;
for (idx, acc) in val_accuracies.iter().enumerate() {
if soup_indices.contains(&idx) {
continue;
}
let potential_accuracy = (*acc + best_accuracy) / 2.0;
if potential_accuracy > best_new_accuracy {
best_new_accuracy = potential_accuracy;
best_addition = Some(idx);
}
}
match best_addition {
Some(idx) => {
soup_indices.push(idx);
best_accuracy = best_new_accuracy;
}
None => break,
}
}
let selected_weights: Vec<_> = soup_indices
.iter()
.map(|&idx| model_weights[idx].clone())
.collect();
let mut soup = Self::uniform_soup(selected_weights)?;
soup.recipe = SoupRecipe::Greedy;
soup.num_models = soup_indices.len();
Ok(soup)
}
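/// Builds a soup from a weighted average of parameters; weights must have a
/// positive sum and are normalized to sum to 1.0.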
pub fn weighted_soup(
model_weights: Vec<HashMap<String, Array2<f64>>>,
weights: Vec<f64>,
) -> TrainResult<Self> {
if model_weights.is_empty() {
return Err(TrainError::InvalidParameter(
"At least one model required for soup".to_string(),
));
}
if model_weights.len() != weights.len() {
return Err(TrainError::InvalidParameter(
"Number of models must match number of weights".to_string(),
));
}
let sum: f64 = weights.iter().sum();
if sum <= 0.0 {
return Err(TrainError::InvalidParameter(
"Weights must sum to positive value".to_string(),
));
}
let normalized_weights: Vec<f64> = weights.iter().map(|w| w / sum).collect();
let num_models = model_weights.len();
let mut averaged_weights = HashMap::new();
let param_names: Vec<String> = model_weights[0].keys().cloned().collect();
for param_name in param_names {
let shape = model_weights[0][&param_name].raw_dim();
let mut averaged_param = Array2::zeros(shape);
for (model_idx, model_weight) in model_weights.iter().enumerate() {
if let Some(param) = model_weight.get(&param_name) {
averaged_param = averaged_param + param * normalized_weights[model_idx];
} else {
return Err(TrainError::InvalidParameter(format!(
"Parameter '{}' not found in all models",
param_name
)));
}
}
averaged_weights.insert(param_name, averaged_param);
}
Ok(Self {
weights: averaged_weights,
num_models,
recipe: SoupRecipe::Weighted,
})
}
pub fn weights(&self) -> &HashMap<String, Array2<f64>> {
&self.weights
}
pub fn num_models(&self) -> usize {
self.num_models
}
pub fn recipe(&self) -> SoupRecipe {
self.recipe
}
pub fn get_parameter(&self, name: &str) -> Option<&Array2<f64>> {
self.weights.get(name)
}
pub fn into_weights(self) -> HashMap<String, Array2<f64>> {
self.weights
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::LinearModel;
use scirs2_core::ndarray::array;
fn create_test_model() -> LinearModel {
LinearModel::new(2, 2)
}
#[test]
fn test_voting_ensemble_hard() {
let model1 = create_test_model();
let model2 = create_test_model();
let ensemble = VotingEnsemble::new(vec![model1, model2], VotingMode::Hard).expect("unwrap");
assert_eq!(ensemble.num_models(), 2);
assert_eq!(ensemble.mode(), VotingMode::Hard);
let input = array![[1.0, 0.0], [0.0, 1.0]];
let pred = ensemble.predict(&input).expect("unwrap");
assert_eq!(pred.shape(), &[2, 2]);
}
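// Added sanity check: hard voting (as implemented in `predict`) marks exactly one
// winning class per sample with 1.0 and leaves the remaining entries at 0.0.
#[test]
fn test_voting_ensemble_hard_one_hot_rows() {
let model1 = create_test_model();
let model2 = create_test_model();
let ensemble = VotingEnsemble::new(vec![model1, model2], VotingMode::Hard).expect("unwrap");
let input = array![[1.0, 0.0], [0.0, 1.0]];
let pred = ensemble.predict(&input).expect("unwrap");
for i in 0..pred.nrows() {
let mut row_sum = 0.0;
for j in 0..pred.ncols() {
let v = pred[[i, j]];
assert!(v == 0.0 || v == 1.0);
row_sum += v;
}
assert!((row_sum - 1.0).abs() < 1e-12);
}
}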
#[test]
fn test_voting_ensemble_soft() {
let model1 = create_test_model();
let model2 = create_test_model();
let ensemble = VotingEnsemble::new(vec![model1, model2], VotingMode::Soft).expect("unwrap");
let input = array![[1.0, 0.0]];
let pred = ensemble.predict(&input).expect("unwrap");
assert_eq!(pred.shape(), &[1, 2]);
}
#[test]
fn test_voting_ensemble_with_weights() {
let model1 = create_test_model();
let model2 = create_test_model();
let ensemble = VotingEnsemble::new(vec![model1, model2], VotingMode::Soft)
.expect("unwrap")
.with_weights(vec![0.7, 0.3])
.expect("unwrap");
let input = array![[1.0, 0.0]];
let pred = ensemble.predict(&input).expect("unwrap");
assert_eq!(pred.shape(), &[1, 2]);
}
#[test]
fn test_voting_ensemble_invalid_weights() {
let model1 = create_test_model();
let model2 = create_test_model();
let ensemble = VotingEnsemble::new(vec![model1, model2], VotingMode::Soft).expect("unwrap");
let result = ensemble.with_weights(vec![0.5]);
assert!(result.is_err());
let model3 = create_test_model();
let model4 = create_test_model();
let ensemble2 =
VotingEnsemble::new(vec![model3, model4], VotingMode::Soft).expect("unwrap");
let result = ensemble2.with_weights(vec![0.5, 0.6]);
assert!(result.is_err());
}
#[test]
fn test_averaging_ensemble() {
let model1 = create_test_model();
let model2 = create_test_model();
let ensemble = AveragingEnsemble::new(vec![model1, model2]).expect("unwrap");
assert_eq!(ensemble.num_models(), 2);
let input = array![[1.0, 0.0], [0.0, 1.0]];
let pred = ensemble.predict(&input).expect("unwrap");
assert_eq!(pred.shape(), &[2, 2]);
}
#[test]
fn test_averaging_ensemble_with_weights() {
let model1 = create_test_model();
let model2 = create_test_model();
let ensemble = AveragingEnsemble::new(vec![model1, model2])
.expect("unwrap")
.with_weights(vec![2.0, 1.0])
.expect("unwrap");
let input = array![[1.0, 0.0]];
let pred = ensemble.predict(&input).expect("unwrap");
assert_eq!(pred.shape(), &[1, 2]);
}
#[test]
fn test_stacking_ensemble() {
let base1 = create_test_model();
let base2 = create_test_model();
let meta = LinearModel::new(4, 2);
let ensemble = StackingEnsemble::new(vec![base1, base2], meta).expect("unwrap");
assert_eq!(ensemble.num_models(), 3);
let input = array![[1.0, 0.0]];
let pred = ensemble.predict(&input).expect("unwrap");
assert_eq!(pred.nrows(), 1);
}
#[test]
fn test_stacking_meta_features() {
let base1 = create_test_model();
let base2 = create_test_model();
let meta = create_test_model();
let ensemble = StackingEnsemble::new(vec![base1, base2], meta).expect("unwrap");
let input = array![[1.0, 0.0]];
let meta_features = ensemble.generate_meta_features(&input).expect("unwrap");
assert_eq!(meta_features.shape(), &[1, 4]);
}
#[test]
fn test_bagging_helper() {
let helper = BaggingHelper::new(10, 42).expect("unwrap");
let indices = helper.generate_bootstrap_indices(100, 0);
assert_eq!(indices.len(), 100);
assert!(indices.iter().all(|&i| i < 100));
let oob = helper.get_oob_indices(100, &indices);
assert!(!oob.is_empty());
for &idx in &oob {
assert!(!indices.contains(&idx));
}
}
#[test]
fn test_bagging_helper_different_seeds() {
let helper = BaggingHelper::new(10, 42).expect("unwrap");
let indices1 = helper.generate_bootstrap_indices(50, 0);
let indices2 = helper.generate_bootstrap_indices(50, 1);
assert_ne!(indices1, indices2);
}
#[test]
fn test_bagging_helper_invalid() {
assert!(BaggingHelper::new(0, 42).is_err());
}
#[test]
fn test_ensemble_empty_models() {
let result = VotingEnsemble::<LinearModel>::new(vec![], VotingMode::Hard);
assert!(result.is_err());
let result = AveragingEnsemble::<LinearModel>::new(vec![]);
assert!(result.is_err());
}
#[test]
fn test_uniform_soup() {
let mut weights1 = HashMap::new();
weights1.insert("w".to_string(), array![[1.0, 2.0]]);
weights1.insert("b".to_string(), array![[0.5]]);
let mut weights2 = HashMap::new();
weights2.insert("w".to_string(), array![[3.0, 4.0]]);
weights2.insert("b".to_string(), array![[1.5]]);
let soup = ModelSoup::uniform_soup(vec![weights1, weights2]).expect("unwrap");
assert_eq!(soup.num_models(), 2);
assert_eq!(soup.recipe(), SoupRecipe::Uniform);
let w = soup.get_parameter("w").expect("unwrap");
assert_eq!(w[[0, 0]], 2.0);
assert_eq!(w[[0, 1]], 3.0);
let b = soup.get_parameter("b").expect("unwrap");
assert_eq!(b[[0, 0]], 1.0);
}
#[test]
fn test_uniform_soup_three_models() {
let mut weights1 = HashMap::new();
weights1.insert("w".to_string(), array![[1.0]]);
let mut weights2 = HashMap::new();
weights2.insert("w".to_string(), array![[2.0]]);
let mut weights3 = HashMap::new();
weights3.insert("w".to_string(), array![[3.0]]);
let soup = ModelSoup::uniform_soup(vec![weights1, weights2, weights3]).expect("unwrap");
let w = soup.get_parameter("w").expect("unwrap");
assert_eq!(w[[0, 0]], 2.0);
}
#[test]
fn test_greedy_soup() {
let mut weights1 = HashMap::new();
weights1.insert("w".to_string(), array![[1.0]]);
let mut weights2 = HashMap::new();
weights2.insert("w".to_string(), array![[2.0]]);
let mut weights3 = HashMap::new();
weights3.insert("w".to_string(), array![[3.0]]);
let accuracies = vec![0.8, 0.9, 0.85];
let soup =
ModelSoup::greedy_soup(vec![weights1, weights2, weights3], accuracies).expect("unwrap");
assert_eq!(soup.recipe(), SoupRecipe::Greedy);
assert!(soup.num_models() >= 1);
}
#[test]
fn test_weighted_soup() {
let mut weights1 = HashMap::new();
weights1.insert("w".to_string(), array![[1.0, 2.0]]);
let mut weights2 = HashMap::new();
weights2.insert("w".to_string(), array![[3.0, 4.0]]);
let soup =
ModelSoup::weighted_soup(vec![weights1, weights2], vec![2.0, 1.0]).expect("unwrap");
assert_eq!(soup.recipe(), SoupRecipe::Weighted);
let w = soup.get_parameter("w").expect("unwrap");
assert!((w[[0, 0]] - 1.6666666).abs() < 1e-5);
assert!((w[[0, 1]] - 2.6666666).abs() < 1e-5);
}
#[test]
fn test_soup_empty_models() {
let result = ModelSoup::uniform_soup(vec![]);
assert!(result.is_err());
}
#[test]
fn test_soup_mismatched_parameters() {
let mut weights1 = HashMap::new();
weights1.insert("w".to_string(), array![[1.0]]);
let mut weights2 = HashMap::new();
weights2.insert("b".to_string(), array![[2.0]]);
let result = ModelSoup::uniform_soup(vec![weights1, weights2]);
assert!(result.is_err());
}
#[test]
fn test_greedy_soup_mismatched_lengths() {
let mut weights1 = HashMap::new();
weights1.insert("w".to_string(), array![[1.0]]);
let result = ModelSoup::greedy_soup(vec![weights1], vec![0.8, 0.9]);
assert!(result.is_err());
}
#[test]
fn test_weighted_soup_invalid_weights() {
let mut weights1 = HashMap::new();
weights1.insert("w".to_string(), array![[1.0]]);
let mut weights2 = HashMap::new();
weights2.insert("w".to_string(), array![[2.0]]);
let result =
ModelSoup::weighted_soup(vec![weights1.clone(), weights2.clone()], vec![-1.0, 1.0]);
assert!(result.is_err());
let result = ModelSoup::weighted_soup(vec![weights1], vec![1.0, 2.0]);
assert!(result.is_err());
}
#[test]
fn test_soup_into_weights() {
let mut weights1 = HashMap::new();
weights1.insert("w".to_string(), array![[1.0]]);
let mut weights2 = HashMap::new();
weights2.insert("w".to_string(), array![[3.0]]);
let soup = ModelSoup::uniform_soup(vec![weights1, weights2]).expect("unwrap");
let final_weights = soup.into_weights();
assert_eq!(final_weights["w"][[0, 0]], 2.0);
}
#[test]
fn test_soup_multidimensional_weights() {
let mut weights1 = HashMap::new();
weights1.insert("conv".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);
let mut weights2 = HashMap::new();
weights2.insert("conv".to_string(), array![[5.0, 6.0], [7.0, 8.0]]);
let soup = ModelSoup::uniform_soup(vec![weights1, weights2]).expect("unwrap");
let conv = soup.get_parameter("conv").expect("unwrap");
assert_eq!(conv[[0, 0]], 3.0);
assert_eq!(conv[[0, 1]], 4.0);
assert_eq!(conv[[1, 0]], 5.0);
assert_eq!(conv[[1, 1]], 6.0);
}
}