use crate::AdaBoostValidParams;
use linfa::{
dataset::{AsTargets, AsTargetsMut, FromTargetArrayOwned},
error::Error,
traits::*,
DatasetBase,
};
use ndarray::{Array1, Array2, Axis};
use ndarray_rand::rand::distributions::WeightedIndex;
use ndarray_rand::rand::prelude::*;
use ndarray_rand::rand::Rng;
use std::{cmp::Eq, collections::HashMap, hash::Hash};
/// Vote weight given to a base learner whose weighted training error is zero.
///
/// `fit` derives a learner's weight from `ln((1 - error) / error)`, which has no finite value
/// when `error` is zero, so this fixed large constant is pushed instead and boosting stops.
const PERFECT_MODEL_WEIGHT: f64 = 1e6;
/// A boosted ensemble: a list of base learners, one vote weight per learner, and the class
/// labels the ensemble was built for.
///
/// `M` is the type of the base learner and `L` the type of a class label.
#[derive(Debug, Clone)]
pub struct AdaBoost<M, L> {
    /// Base learners, in the order they were added to the ensemble.
    pub models: Vec<M>,
    /// Vote weight of each learner; entry `i` belongs to `models[i]`.
    pub model_weights: Vec<f64>,
    /// Class labels associated with the ensemble.
    pub classes: Vec<L>,
}

impl<M, L> AdaBoost<M, L> {
    /// Number of base learners held by the ensemble.
    pub fn n_estimators(&self) -> usize {
        let Self { models, .. } = self;
        models.len()
    }

    /// Vote weights of the base learners, borrowed as a slice.
    pub fn weights(&self) -> &[f64] {
        self.model_weights.as_slice()
    }
}
/// Ensemble prediction by weighted majority vote.
///
/// Every base learner predicts a label for each output element; each prediction adds that
/// learner's entry of `model_weights` to the tally of the predicted class, and the class with the
/// largest tally is written to the output.
impl<F: Clone, T, M, L> PredictInplace<Array2<F>, T> for AdaBoost<M, L>
where
    M: PredictInplace<Array2<F>, T>,
    <T as AsTargets>::Elem: Copy + Eq + Hash + std::fmt::Debug + Into<usize>,
    T: AsTargets + AsTargetsMut<Elem = <T as AsTargets>::Elem>,
    usize: Into<<T as AsTargets>::Elem>,
{
    /// Writes the ensemble's prediction for every row of `x` into `y`.
    ///
    /// Labels are converted to `usize` class indices for voting and converted back when the
    /// winner is stored. When several classes share the highest tally, the smallest class index
    /// wins, so repeated calls always return the same labels.
    ///
    /// # Panics
    ///
    /// Panics if `x` and `y` disagree on the number of samples, if there are fewer weights than
    /// models, or if an output element received no vote (for example, the ensemble has no
    /// models but `y` is not empty).
    fn predict_inplace(&self, x: &Array2<F>, y: &mut T) {
        let y_array = y.as_targets();
        assert_eq!(
            x.nrows(),
            y_array.len_of(Axis(0)),
            "The number of data points must match the number of outputs."
        );
        assert!(
            self.model_weights.len() >= self.models.len(),
            "Every model must have a corresponding weight."
        );

        // One tally per output element: class index -> accumulated model weight.
        let mut prediction_maps = y_array.map(|_| HashMap::new());

        // Vote as each model predicts, so only one model's predictions are alive at a time.
        for (model, &weight) in self.models.iter().zip(self.model_weights.iter()) {
            let mut pred = model.default_target(x);
            model.predict_inplace(x, &mut pred);
            let pred_array = pred.as_targets();
            for (vote_map, &pred_val) in prediction_maps.iter_mut().zip(pred_array.iter()) {
                let class_idx: usize = pred_val.into();
                *vote_map.entry(class_idx).or_insert(0.0) += weight;
            }
        }

        let final_predictions = prediction_maps.map(|votes| {
            votes
                .iter()
                .max_by(|(k1, v1), (k2, v2)| {
                    // `HashMap` iteration order is arbitrary, so ties must not depend on it:
                    // on equal (or incomparable, i.e. NaN) tallies the smaller class index is
                    // ranked higher and therefore always selected.
                    v1.partial_cmp(v2)
                        .unwrap_or(std::cmp::Ordering::Equal)
                        .then_with(|| k2.cmp(k1))
                })
                .map(|(k, _)| (*k).into())
                .expect("No predictions found")
        });

        let mut y_array_mut = y.as_targets_mut();
        for (y, pred) in y_array_mut.iter_mut().zip(final_predictions.iter()) {
            *y = *pred;
        }
    }

    /// Builds an output container for `x` by delegating to the first base learner.
    ///
    /// # Panics
    ///
    /// Panics if the ensemble holds no models.
    fn default_target(&self, x: &Array2<F>) -> T {
        self.models
            .first()
            .expect("AdaBoost ensemble contains no models")
            .default_target(x)
    }
}
/// Training of the boosted ensemble.
///
/// Each round draws a weighted bootstrap sample of the dataset, fits one base learner on it,
/// measures that learner's weighted error on the full dataset, and increases the weight of the
/// samples it misclassified before the next round.
impl<D, T, P, R> Fit<Array2<D>, T, Error> for AdaBoostValidParams<P, R>
where
    D: Clone + ndarray::ScalarOperand,
    T: FromTargetArrayOwned<Owned = T> + AsTargets + Clone,
    T::Elem: Copy + Eq + Hash + std::fmt::Debug + Into<usize>,
    P: Fit<Array2<D>, T, Error> + Clone,
    P::Object: PredictInplace<Array2<D>, T>,
    R: Rng + Clone,
    usize: Into<T::Elem>,
{
    type Object = AdaBoost<P::Object, T::Elem>;

    /// Fits up to `n_estimators` base learners on `dataset` and returns the weighted ensemble.
    ///
    /// Boosting stops early when a learner has zero weighted error (it is kept with
    /// `PERFECT_MODEL_WEIGHT`) or when a learner's weighted error reaches `(k - 1) / k`, where
    /// `k` is the number of distinct classes (that learner is discarded).
    ///
    /// # Errors
    ///
    /// * `Error::Parameters` if the dataset is empty, has fewer than two distinct classes, or
    ///   the sample weights cannot be turned into a sampling distribution.
    /// * `Error::NotConverged` if the sample weights sum to zero, a base learner fails to fit,
    ///   the very first learner reaches the error threshold, or no learner was kept at all.
    fn fit(
        &self,
        dataset: &DatasetBase<Array2<D>, T>,
    ) -> core::result::Result<Self::Object, Error> {
        let n_samples = dataset.records.nrows();
        if n_samples == 0 {
            return Err(Error::Parameters(
                "Cannot fit AdaBoost on empty dataset".to_string(),
            ));
        }

        // Distinct labels, ordered by their class index.
        let labels = dataset.targets.as_targets();
        let distinct: std::collections::HashSet<T::Elem> = labels.iter().copied().collect();
        let mut classes: Vec<T::Elem> = distinct.into_iter().collect();
        classes.sort_unstable_by_key(|class| -> usize { (*class).into() });
        if classes.len() < 2 {
            return Err(Error::Parameters(
                "AdaBoost requires at least 2 classes".to_string(),
            ));
        }

        // A learner whose weighted error reaches this value is no better than guessing.
        let n_classes = classes.len() as f64;
        let max_error = (n_classes - 1.0) / n_classes;

        let mut weights = Array1::from_elem(n_samples, 1.0 / n_samples as f64);
        let mut models = Vec::with_capacity(self.n_estimators);
        let mut model_weights = Vec::with_capacity(self.n_estimators);
        let mut rng = self.rng.clone();

        for round in 0..self.n_estimators {
            // Re-normalise so the weights form a distribution for this round.
            let total = weights.sum();
            if total <= 0.0 {
                return Err(Error::NotConverged(format!(
                    "Sample weights sum to zero at iteration {}",
                    round
                )));
            }
            weights /= total;

            // Draw `n_samples` row indices with probability proportional to their weight.
            let sampler = WeightedIndex::new(weights.iter().copied())
                .map_err(|_| Error::Parameters("Invalid sample weights".to_string()))?;
            let mut picked: Vec<usize> = Vec::with_capacity(n_samples);
            for _ in 0..n_samples {
                picked.push(sampler.sample(&mut rng));
            }

            let resampled = DatasetBase::new(
                dataset.records.select(Axis(0), &picked),
                T::new_targets(labels.select(Axis(0), &picked)),
            );
            let learner = self.model_params.fit(&resampled).map_err(|e| {
                Error::NotConverged(format!(
                    "Base learner failed to fit at iteration {}: {}",
                    round, e
                ))
            })?;

            // Score the learner on the full (not resampled) dataset.
            let mut guess_targets = learner.default_target(&dataset.records);
            learner.predict_inplace(&dataset.records, &mut guess_targets);
            let guesses = guess_targets.as_targets();

            // `true` where the learner's class index differs from the true one.
            let missed: Vec<bool> = labels
                .iter()
                .zip(guesses.iter())
                .map(|(truth, guess)| {
                    let truth_idx: usize = (*truth).into();
                    let guess_idx: usize = (*guess).into();
                    truth_idx != guess_idx
                })
                .collect();

            let weighted_error: f64 = missed
                .iter()
                .zip(weights.iter())
                .fold(0.0_f64, |acc, (&wrong, &w)| if wrong { acc + w } else { acc });

            if weighted_error <= 0.0 {
                // The weight formula below is undefined for zero error; keep the learner with a
                // fixed large weight and stop.
                model_weights.push(PERFECT_MODEL_WEIGHT);
                models.push(learner);
                break;
            }

            if weighted_error >= max_error {
                if models.is_empty() {
                    return Err(Error::NotConverged(format!(
                        "First base learner performs worse than random guessing (error: {:.4}, threshold: {:.4})",
                        weighted_error, max_error
                    )));
                }
                break;
            }

            let odds = (1.0 - weighted_error) / weighted_error;
            let alpha = self.learning_rate * (odds.ln() + (n_classes - 1.0).ln());

            // Emphasise the samples this learner got wrong.
            let boost = alpha.exp();
            for (w, &wrong) in weights.iter_mut().zip(missed.iter()) {
                if wrong {
                    *w *= boost;
                }
            }

            model_weights.push(alpha);
            models.push(learner);
        }

        if models.is_empty() {
            return Err(Error::NotConverged(
                "No models were successfully trained".to_string(),
            ));
        }

        Ok(AdaBoost {
            models,
            model_weights,
            classes,
        })
    }
}