Skip to main content

linfa_ensemble/
adaboost.rs

1use crate::AdaBoostValidParams;
2use linfa::{
3    dataset::{AsTargets, AsTargetsMut, FromTargetArrayOwned},
4    error::Error,
5    traits::*,
6    DatasetBase,
7};
8use ndarray::{Array1, Array2, Axis};
9use ndarray_rand::rand::distributions::WeightedIndex;
10use ndarray_rand::rand::prelude::*;
11use ndarray_rand::rand::Rng;
12use std::{cmp::Eq, collections::HashMap, hash::Hash};
13
/// Weight (alpha) given to a base learner whose weighted training error is zero.
///
/// The SAMME weight formula contains `ln((1 - error) / error)`, which is unbounded as the
/// error approaches zero, so a large finite value is used as a stand-in for such a learner.
const PERFECT_MODEL_WEIGHT: f64 = 1_000_000.0;
16
/// An AdaBoost (Adaptive Boosting) ensemble classifier, as produced by fitting the
/// AdaBoost parameters on a dataset.
///
/// ## Structure
///
/// AdaBoost combines several weak learners into one stronger classifier. In contrast to
/// bagging approaches (such as Random Forest), the learners are trained one after another,
/// and every new learner concentrates on the samples its predecessors got wrong.
///
/// The ensemble keeps the fitted base learners of type `M` in the order they were trained,
/// one weight (alpha) per learner, and the distinct class labels of type `L` found in the
/// training targets. A learner with a lower weighted training error receives a larger alpha
/// and therefore has more influence on the final prediction.
///
/// ## Algorithm Overview
///
/// Given a [`DatasetBase`] `D` with `n` samples and `K` distinct classes:
/// 1. Start with uniform sample weights: `w_i = 1/n`
/// 2. For each iteration `t` from 1 to T (the configured number of estimators):
///    a. Normalize the sample weights so that they sum to one
///    b. Draw `n` samples with replacement, with probability proportional to the sample
///       weights, and fit a base learner on the resampled data
///    c. Predict on the full training set and sum the weights of the misclassified samples
///       to obtain the weighted error
///    d. Stop early if the error is zero (the learner is kept with a very large weight) or
///       if the error is at least `(K - 1) / K` (the learner is discarded)
///    e. Compute the learner weight `alpha = learning_rate * (ln((1 - error) / error) + ln(K - 1))`
///    f. Multiply the weight of every misclassified sample by `exp(alpha)`
///
/// ## Prediction Algorithm
///
/// Predictions are obtained by weighted majority voting:
/// - every learner votes for its predicted class with a strength equal to its alpha
/// - the class that collects the largest total vote is returned
///
/// ## Example
///
/// ```no_run
/// use linfa::prelude::{Fit, Predict};
/// use linfa_ensemble::AdaBoostParams;
/// use linfa_trees::DecisionTree;
/// use ndarray_rand::rand::SeedableRng;
/// use rand::rngs::SmallRng;
///
/// // Shuffle the Iris dataset and split it into a training and a test part
/// let mut rng = SmallRng::seed_from_u64(42);
/// let (train, test) = linfa_datasets::iris()
///     .shuffle(&mut rng)
///     .split_with_ratio(0.8);
///
/// // Boost depth-one decision trees (stumps)
/// let adaboost_model = AdaBoostParams::new(DecisionTree::params().max_depth(Some(1)))
///     .n_estimators(50)
///     .learning_rate(1.0)
///     .fit(&train)
///     .unwrap();
///
/// // Classify the held-out samples
/// let predictions = adaboost_model.predict(&test);
/// ```
///
/// ## References
///
/// * Freund, Y., & Schapire, R. E. (1997). A decision-theoretic generalization of on-line learning
///   and an application to boosting. Journal of Computer and System Sciences, 55(1), 119-139.
/// * [Scikit-Learn AdaBoost Documentation](https://scikit-learn.org/stable/modules/ensemble.html#adaboost)
/// * [An Introduction to Statistical Learning](https://www.statlearning.com/), Chapter 8
#[derive(Clone, Debug)]
pub struct AdaBoost<M, L> {
    /// Fitted base learners, in the order they were trained
    pub models: Vec<M>,
    /// Voting weight (alpha) of each base learner; entry `i` belongs to `models[i]`
    pub model_weights: Vec<f64>,
    /// Distinct class labels found in the training targets
    pub classes: Vec<L>,
}

impl<M, L> AdaBoost<M, L> {
    /// Number of base learners stored in the ensemble.
    ///
    /// This is the length of `models`, which can be smaller than the requested number of
    /// estimators when training stopped early.
    pub fn n_estimators(&self) -> usize {
        let fitted = &self.models;
        fitted.len()
    }

    /// Voting weights (alpha values) of the base learners, borrowed as a slice.
    pub fn weights(&self) -> &[f64] {
        self.model_weights.as_slice()
    }
}
98
99impl<F: Clone, T, M, L> PredictInplace<Array2<F>, T> for AdaBoost<M, L>
100where
101    M: PredictInplace<Array2<F>, T>,
102    <T as AsTargets>::Elem: Copy + Eq + Hash + std::fmt::Debug + Into<usize>,
103    T: AsTargets + AsTargetsMut<Elem = <T as AsTargets>::Elem>,
104    usize: Into<<T as AsTargets>::Elem>,
105{
106    fn predict_inplace(&self, x: &Array2<F>, y: &mut T) {
107        let y_array = y.as_targets();
108        assert_eq!(
109            x.nrows(),
110            y_array.len_of(Axis(0)),
111            "The number of data points must match the number of outputs."
112        );
113
114        // Collect predictions from all models
115        let mut all_predictions = Vec::with_capacity(self.models.len());
116        for model in &self.models {
117            let mut pred = model.default_target(x);
118            model.predict_inplace(x, &mut pred);
119            all_predictions.push(pred);
120        }
121
122        // Create a map for each sample to accumulate weighted votes
123        let mut prediction_maps = y_array.map(|_| HashMap::new());
124
125        // Accumulate weighted predictions from each model
126        for (model_idx, prediction) in all_predictions.iter().enumerate() {
127            let pred_array = prediction.as_targets();
128            let weight = self.model_weights[model_idx];
129
130            // For each sample, add the model's weighted prediction
131            for (vote_map, &pred_val) in prediction_maps.iter_mut().zip(pred_array.iter()) {
132                let class_idx: usize = pred_val.into();
133                *vote_map.entry(class_idx).or_insert(0.0) += weight;
134            }
135        }
136
137        // For each sample, select the class with the highest weighted vote
138        let final_predictions = prediction_maps.map(|votes| {
139            votes
140                .iter()
141                .max_by(|(_, v1), (_, v2)| v1.partial_cmp(v2).unwrap())
142                .map(|(k, _)| (*k).into())
143                .expect("No predictions found")
144        });
145
146        // Write final predictions to output
147        let mut y_array_mut = y.as_targets_mut();
148        for (y, pred) in y_array_mut.iter_mut().zip(final_predictions.iter()) {
149            *y = *pred;
150        }
151    }
152
153    fn default_target(&self, x: &Array2<F>) -> T {
154        self.models[0].default_target(x)
155    }
156}
157
158impl<D, T, P, R> Fit<Array2<D>, T, Error> for AdaBoostValidParams<P, R>
159where
160    D: Clone + ndarray::ScalarOperand,
161    T: FromTargetArrayOwned<Owned = T> + AsTargets + Clone,
162    T::Elem: Copy + Eq + Hash + std::fmt::Debug + Into<usize>,
163    P: Fit<Array2<D>, T, Error> + Clone,
164    P::Object: PredictInplace<Array2<D>, T>,
165    R: Rng + Clone,
166    usize: Into<T::Elem>,
167{
168    type Object = AdaBoost<P::Object, T::Elem>;
169
170    fn fit(
171        &self,
172        dataset: &DatasetBase<Array2<D>, T>,
173    ) -> core::result::Result<Self::Object, Error> {
174        let n_samples = dataset.records.nrows();
175
176        if n_samples == 0 {
177            return Err(Error::Parameters(
178                "Cannot fit AdaBoost on empty dataset".to_string(),
179            ));
180        }
181
182        // Extract unique class labels from target array
183        let target_array = dataset.targets.as_targets();
184        let mut classes_set: Vec<T::Elem> = target_array
185            .iter()
186            .copied()
187            .collect::<std::collections::HashSet<_>>()
188            .into_iter()
189            .collect();
190        // Sort by converting to usize for ordering
191        classes_set.sort_unstable_by_key(|x| (*x).into());
192
193        if classes_set.len() < 2 {
194            return Err(Error::Parameters(
195                "AdaBoost requires at least 2 classes".to_string(),
196            ));
197        }
198
199        // Initialize sample weights uniformly
200        let mut sample_weights = Array1::from_elem(n_samples, 1.0 / n_samples as f64);
201
202        let mut models = Vec::with_capacity(self.n_estimators);
203        let mut model_weights = Vec::with_capacity(self.n_estimators);
204
205        let mut rng = self.rng.clone();
206
207        for iteration in 0..self.n_estimators {
208            // Normalize weights to sum to 1
209            let weight_sum = sample_weights.sum();
210            if weight_sum <= 0.0 {
211                return Err(Error::NotConverged(format!(
212                    "Sample weights sum to zero at iteration {}",
213                    iteration
214                )));
215            }
216            sample_weights /= weight_sum;
217
218            // Resample dataset according to sample weights
219            // This is the practical implementation of AdaBoost when base learners don't support weights
220            let dist = WeightedIndex::new(sample_weights.iter().copied())
221                .map_err(|_| Error::Parameters("Invalid sample weights".to_string()))?;
222
223            let bootstrap_indices: Vec<usize> =
224                (0..n_samples).map(|_| dist.sample(&mut rng)).collect();
225
226            // Create bootstrap dataset by selecting rows according to weights
227            let bootstrap_records = dataset.records.select(Axis(0), &bootstrap_indices);
228            let bootstrap_targets_array = target_array.select(Axis(0), &bootstrap_indices);
229
230            // Convert to owned target type using new_targets
231            let bootstrap_targets = T::new_targets(bootstrap_targets_array);
232            let bootstrap_dataset = DatasetBase::new(bootstrap_records, bootstrap_targets);
233
234            // Fit base learner on resampled dataset
235            let model = self.model_params.fit(&bootstrap_dataset).map_err(|e| {
236                Error::NotConverged(format!(
237                    "Base learner failed to fit at iteration {}: {}",
238                    iteration, e
239                ))
240            })?;
241
242            // Make predictions on training data
243            let mut predictions = model.default_target(&dataset.records);
244            model.predict_inplace(&dataset.records, &mut predictions);
245            let pred_array = predictions.as_targets();
246
247            // Calculate weighted error
248            let mut weighted_error = 0.0;
249            for ((true_label, pred_label), weight) in target_array
250                .iter()
251                .zip(pred_array.iter())
252                .zip(sample_weights.iter())
253            {
254                let true_idx: usize = (*true_label).into();
255                let pred_idx: usize = (*pred_label).into();
256
257                if true_idx != pred_idx {
258                    weighted_error += *weight;
259                }
260            }
261
262            // Handle edge cases for weighted error
263            if weighted_error <= 0.0 {
264                // Perfect prediction - add model with maximum weight and stop
265                model_weights.push(PERFECT_MODEL_WEIGHT); // Large weight for perfect model
266                models.push(model);
267                break;
268            }
269
270            // For multi-class SAMME, check if error rate is above the random guessing threshold
271            let k = classes_set.len() as f64;
272            let error_threshold = (k - 1.0) / k;
273
274            if weighted_error >= error_threshold {
275                // Worse than random guessing for multi-class - don't add this model
276                if models.is_empty() {
277                    return Err(Error::NotConverged(format!(
278                        "First base learner performs worse than random guessing (error: {:.4}, threshold: {:.4})",
279                        weighted_error, error_threshold
280                    )));
281                }
282                break;
283            }
284
285            // Calculate model weight (alpha) using SAMME algorithm
286            // For multi-class: alpha = learning_rate * (log((1 - error) / error) + log(K - 1))
287            // where K is number of classes
288            let error_ratio = (1.0 - weighted_error) / weighted_error;
289            let alpha = self.learning_rate * (error_ratio.ln() + (k - 1.0).ln());
290
291            // Update sample weights
292            for ((true_label, pred_label), weight) in target_array
293                .iter()
294                .zip(pred_array.iter())
295                .zip(sample_weights.iter_mut())
296            {
297                let true_idx: usize = (*true_label).into();
298                let pred_idx: usize = (*pred_label).into();
299
300                if true_idx != pred_idx {
301                    // Increase weight for misclassified samples
302                    *weight *= alpha.exp();
303                }
304            }
305
306            model_weights.push(alpha);
307            models.push(model);
308        }
309
310        if models.is_empty() {
311            return Err(Error::NotConverged(
312                "No models were successfully trained".to_string(),
313            ));
314        }
315
316        Ok(AdaBoost {
317            models,
318            model_weights,
319            classes: classes_set,
320        })
321    }
322}