Skip to main content

anofox_ml_ensemble/
bagging_classifier.rs

1use anofox_ml_core::{Fit, Float, Predict, PredictProba, Result, RustMlError};
2use anofox_ml_trees::{DecisionTreeClassifier, FittedDecisionTreeClassifier, SplitCriterion};
3use ndarray::{Array1, Array2};
4use rand::rngs::StdRng;
5use rand::{Rng, SeedableRng};
6use rayon::prelude::*;
7
8/// Bagging (Bootstrap Aggregating) classifier parameters (unfitted state).
9///
10/// Trains an ensemble of decision tree classifiers, each on a bootstrap sample
11/// of the data using the **full** feature set. Unlike `RandomForestClassifier`,
12/// bagging does not perform random feature subsampling at the tree level --
13/// every tree sees all features.
14///
15/// Predictions are made by majority vote across all trees.
16#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
17pub struct BaggingClassifier {
18    /// Number of trees in the ensemble.
19    pub n_estimators: usize,
20    /// Maximum depth of each tree.
21    pub max_depth: Option<usize>,
22    /// Fraction of samples to draw for each tree (with replacement when
23    /// `bootstrap=true`). If `None`, draws `n_samples`. Value in (0, 1].
24    pub max_samples: Option<f64>,
25    /// Whether to use bootstrap sampling. Default: true.
26    pub bootstrap: bool,
27    /// Random seed for reproducibility.
28    pub seed: u64,
29}
30
31impl BaggingClassifier {
32    /// Create a new `BaggingClassifier` with the given number of trees and default parameters.
33    pub fn new(n_estimators: usize) -> Self {
34        Self {
35            n_estimators,
36            max_depth: None,
37            max_samples: None,
38            bootstrap: true,
39            seed: 0,
40        }
41    }
42
43    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
44        self.max_depth = max_depth;
45        self
46    }
47    pub fn with_max_samples(mut self, max_samples: Option<f64>) -> Self {
48        self.max_samples = max_samples;
49        self
50    }
51    pub fn with_bootstrap(mut self, bootstrap: bool) -> Self {
52        self.bootstrap = bootstrap;
53        self
54    }
55    pub fn with_seed(mut self, seed: u64) -> Self {
56        self.seed = seed;
57        self
58    }
59}
60
61impl Default for BaggingClassifier {
62    fn default() -> Self {
63        Self::new(10)
64    }
65}
66
67/// Fitted bagging classifier.
68#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
69#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
70pub struct FittedBaggingClassifier<F: Float> {
71    trees: Vec<FittedDecisionTreeClassifier<F>>,
72    n_features: usize,
73}
74
75impl<F: Float> Fit<F> for BaggingClassifier {
76    type Fitted = FittedBaggingClassifier<F>;
77
78    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Self::Fitted> {
79        if x.nrows() != y.len() {
80            return Err(RustMlError::ShapeMismatch(format!(
81                "X has {} rows but y has {} elements",
82                x.nrows(),
83                y.len()
84            )));
85        }
86        if x.is_empty() {
87            return Err(RustMlError::EmptyInput("training data is empty".into()));
88        }
89        if self.n_estimators == 0 {
90            return Err(RustMlError::InvalidParameter(
91                "n_estimators must be > 0".into(),
92            ));
93        }
94
95        let n_samples = x.nrows();
96        let n_features = x.ncols();
97
98        let mut rng = StdRng::seed_from_u64(self.seed);
99
100        // Compute bootstrap sample size
101        let draw_size = if let Some(frac) = self.max_samples {
102            if frac <= 0.0 || frac > 1.0 {
103                return Err(RustMlError::InvalidParameter(
104                    "max_samples must be in (0, 1]".into(),
105                ));
106            }
107            (n_samples as f64 * frac).ceil() as usize
108        } else {
109            n_samples
110        };
111
112        let tree_params = DecisionTreeClassifier {
113            max_depth: self.max_depth,
114            min_samples_split: 2,
115            min_samples_leaf: 1,
116            criterion: SplitCriterion::Gini,
117            max_features: None,
118            sample_weight: None,
119            class_weight: None,
120        };
121
122        // Pre-generate bootstrap row indices for determinism
123        let sample_plans: Vec<Vec<usize>> = (0..self.n_estimators)
124            .map(|_| {
125                if self.bootstrap {
126                    (0..draw_size)
127                        .map(|_| rng.gen_range(0..n_samples))
128                        .collect()
129                } else {
130                    (0..n_samples).collect()
131                }
132            })
133            .collect();
134
135        // Train trees in parallel -- no feature subsampling
136        let trees: Result<Vec<FittedDecisionTreeClassifier<F>>> = sample_plans
137            .into_par_iter()
138            .map(|row_indices| {
139                let x_sub = build_sub_matrix_rows(x, &row_indices);
140                let y_sub = Array1::from_vec(row_indices.iter().map(|&i| y[i]).collect::<Vec<F>>());
141                tree_params.fit(&x_sub, &y_sub)
142            })
143            .collect();
144        let trees = trees?;
145
146        Ok(FittedBaggingClassifier { trees, n_features })
147    }
148}
149
150impl<F: Float> Predict<F> for FittedBaggingClassifier<F> {
151    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>> {
152        if x.ncols() != self.n_features {
153            return Err(RustMlError::ShapeMismatch(format!(
154                "expected {} features, got {}",
155                self.n_features,
156                x.ncols()
157            )));
158        }
159
160        let n_samples = x.nrows();
161        let n_trees = self.trees.len();
162
163        // Collect all tree predictions in parallel
164        let all_preds: Result<Vec<Array1<F>>> =
165            self.trees.par_iter().map(|tree| tree.predict(x)).collect();
166        let all_preds = all_preds?;
167
168        // Aggregate votes per sample (majority vote)
169        let mut predictions = Vec::with_capacity(n_samples);
170        let mut votes = Vec::with_capacity(n_trees);
171        for i in 0..n_samples {
172            votes.clear();
173            for tree_pred in &all_preds {
174                votes.push(tree_pred[i]);
175            }
176            predictions.push(majority_vote(&votes));
177        }
178
179        Ok(Array1::from_vec(predictions))
180    }
181}
182
183impl<F: Float> FittedBaggingClassifier<F> {
184    /// Feature importances averaged across all trees and normalized to sum to 1.
185    pub fn feature_importances(&self) -> Array1<F> {
186        let mut importances = vec![F::zero(); self.n_features];
187        let n_trees = F::from_usize(self.trees.len()).unwrap();
188
189        for tree in &self.trees {
190            let tree_importances = tree.feature_importances();
191            for (idx, &imp) in tree_importances.iter().enumerate() {
192                importances[idx] += imp / n_trees;
193            }
194        }
195
196        // Normalize so importances sum to 1
197        let sum: F = importances.iter().copied().fold(F::zero(), |a, b| a + b);
198        if sum > F::zero() {
199            Array1::from_vec(importances.into_iter().map(|v| v / sum).collect())
200        } else {
201            Array1::zeros(self.n_features)
202        }
203    }
204
205    /// Number of trees in the ensemble.
206    pub fn n_estimators(&self) -> usize {
207        self.trees.len()
208    }
209
210    /// Predict class probabilities for each sample.
211    ///
212    /// Returns an `Array2<F>` of shape `(n_samples, n_classes)` where each row
213    /// sums to 1.0. Probabilities are averaged across all trees in the ensemble.
214    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>> {
215        if x.ncols() != self.n_features {
216            return Err(RustMlError::ShapeMismatch(format!(
217                "expected {} features, got {}",
218                self.n_features,
219                x.ncols()
220            )));
221        }
222
223        // Collect probabilities from each tree in parallel
224        let all_proba: Result<Vec<Array2<F>>> = self
225            .trees
226            .par_iter()
227            .map(|tree| tree.predict_proba(x))
228            .collect();
229        let all_proba = all_proba?;
230
231        // Determine global class set (union of all trees' classes)
232        let mut global_classes: Vec<F> = Vec::new();
233        let eps = F::from_f64(1e-9).unwrap();
234        for tree in &self.trees {
235            for c in tree.classes() {
236                if !global_classes.iter().any(|&gc| (gc - c).abs() < eps) {
237                    global_classes.push(c);
238                }
239            }
240        }
241        global_classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
242
243        let n_samples = x.nrows();
244        let n_classes = global_classes.len();
245        let n_trees_f = F::from_usize(self.trees.len()).unwrap();
246        let mut avg_proba = Array2::<F>::zeros((n_samples, n_classes));
247
248        // Map each tree's probabilities to the global class indices and average
249        for (tree_idx, tree) in self.trees.iter().enumerate() {
250            let tree_classes = tree.classes();
251            let tree_proba = &all_proba[tree_idx];
252
253            for (local_ci, &tc) in tree_classes.iter().enumerate() {
254                if let Some(global_ci) = global_classes.iter().position(|&gc| (gc - tc).abs() < eps)
255                {
256                    for i in 0..n_samples {
257                        avg_proba[[i, global_ci]] += tree_proba[[i, local_ci]] / n_trees_f;
258                    }
259                }
260            }
261        }
262
263        Ok(avg_proba)
264    }
265
266    /// Compute classification accuracy on the given data.
267    pub fn score(&self, x: &Array2<F>, y: &Array1<F>) -> Result<f64> {
268        let preds = self.predict(x)?;
269        let n = y.len();
270        let correct = preds
271            .iter()
272            .zip(y.iter())
273            .filter(|(&p, &t)| (p - t).abs() < F::from_f64(1e-9).unwrap())
274            .count();
275        Ok(correct as f64 / n as f64)
276    }
277
278    /// Returns the unique sorted class labels across all trees.
279    pub fn classes(&self) -> Vec<F> {
280        let eps = F::from_f64(1e-9).unwrap();
281        let mut classes: Vec<F> = Vec::new();
282        for tree in &self.trees {
283            for c in tree.classes() {
284                if !classes.iter().any(|&gc| (gc - c).abs() < eps) {
285                    classes.push(c);
286                }
287            }
288        }
289        classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
290        classes
291    }
292}
293
294/// Build a sub-matrix selecting specific rows (all columns) from `x`.
295fn build_sub_matrix_rows<F: Float>(x: &Array2<F>, row_indices: &[usize]) -> Array2<F> {
296    let n_rows = row_indices.len();
297    let n_cols = x.ncols();
298    let mut data = Vec::with_capacity(n_rows * n_cols);
299    for &ri in row_indices {
300        for ci in 0..n_cols {
301            data.push(x[[ri, ci]]);
302        }
303    }
304    Array2::from_shape_vec((n_rows, n_cols), data).expect("shape matches data length")
305}
306
307/// Return the class that appears most frequently in `votes`.
308#[inline]
309fn majority_vote<F: Float>(votes: &[F]) -> F {
310    use std::collections::HashMap;
311    let mut counts: HashMap<u64, (F, usize)> = HashMap::new();
312    for &v in votes {
313        let key = v.to_f64().unwrap().to_bits();
314        counts.entry(key).and_modify(|e| e.1 += 1).or_insert((v, 1));
315    }
316    counts
317        .into_values()
318        .max_by_key(|&(_, count)| count)
319        .unwrap()
320        .0
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Two well-separated groups of three points each, labelled 0 and 1.
    fn two_clusters() -> (Array2<f64>, Array1<f64>) {
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        (x, y)
    }

    /// Fit the 20-tree, depth-3, seed-42 ensemble shared by several tests.
    fn fit_two_clusters() -> (FittedBaggingClassifier<f64>, Array2<f64>, Array1<f64>) {
        let (x, y) = two_clusters();
        let model: FittedBaggingClassifier<f64> = BaggingClassifier::new(20)
            .with_max_depth(Some(3))
            .with_seed(42)
            .fit(&x, &y)
            .unwrap();
        (model, x, y)
    }

    #[test]
    fn test_basic_classification() {
        let (model, x, y) = fit_two_clusters();

        let preds = model.predict(&x).unwrap();
        for (predicted, expected) in preds.iter().zip(y.iter()) {
            assert_abs_diff_eq!(*predicted, *expected, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_reproducibility() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let bc = BaggingClassifier::new(10).with_seed(123);

        // Fit from scratch and predict; called twice with identical settings.
        let fit_and_predict = || -> Array1<f64> {
            let fitted: FittedBaggingClassifier<f64> = bc.fit(&x, &y).unwrap();
            fitted.predict(&x).unwrap()
        };
        let first = fit_and_predict();
        let second = fit_and_predict();

        for (a, b) in first.iter().zip(second.iter()) {
            assert_abs_diff_eq!(*a, *b, epsilon = 1e-15);
        }
    }

    #[test]
    fn test_feature_importances_sum_to_one() {
        let x = array![
            [1.0, 100.0],
            [2.0, 200.0],
            [3.0, 300.0],
            [4.0, 400.0],
            [5.0, 500.0],
            [6.0, 600.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let fitted: FittedBaggingClassifier<f64> =
            BaggingClassifier::new(20).with_seed(7).fit(&x, &y).unwrap();

        let total: f64 = fitted.feature_importances().iter().sum();
        assert_abs_diff_eq!(total, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_predict_proba_rows_sum_to_one() {
        let (model, x, _y) = fit_two_clusters();

        let proba = model.predict_proba(&x).unwrap();
        assert_eq!(proba.nrows(), x.nrows());
        for row in 0..proba.nrows() {
            let row_sum: f64 = proba.row(row).iter().sum();
            assert_abs_diff_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_score() {
        let (model, x, y) = fit_two_clusters();

        let acc = model.score(&x, &y).unwrap();
        assert_abs_diff_eq!(acc, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_n_estimators() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![0.0, 0.0, 1.0, 1.0];

        let fitted: FittedBaggingClassifier<f64> =
            BaggingClassifier::new(7).with_seed(0).fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 7);
    }

    #[test]
    fn test_shape_mismatch_error() {
        let x = array![[1.0], [2.0]];
        let y = array![0.0, 1.0, 2.0];

        let outcome: std::result::Result<FittedBaggingClassifier<f64>, _> =
            BaggingClassifier::default().fit(&x, &y);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_predict_wrong_features_error() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0.0, 1.0];

        let fitted: FittedBaggingClassifier<f64> =
            BaggingClassifier::new(5).with_seed(0).fit(&x, &y).unwrap();

        // One column instead of the two the model was trained on.
        let x_bad = array![[1.0], [2.0]];
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_empty_input_error() {
        let x: Array2<f64> = Array2::zeros((0, 2));
        let y: Array1<f64> = Array1::zeros(0);

        let outcome: std::result::Result<FittedBaggingClassifier<f64>, _> =
            BaggingClassifier::default().fit(&x, &y);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_zero_estimators_error() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0.0, 1.0];

        let outcome: std::result::Result<FittedBaggingClassifier<f64>, _> =
            BaggingClassifier::new(0).fit(&x, &y);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_multiclass() {
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0],
            [20.0, 2.0],
            [21.0, 2.0],
            [22.0, 2.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0];

        let fitted: FittedBaggingClassifier<f64> = BaggingClassifier::new(30)
            .with_max_depth(Some(5))
            .with_seed(42)
            .fit(&x, &y)
            .unwrap();

        // Every prediction must be one of the labels seen in training.
        let valid_labels: std::collections::HashSet<u64> = y.iter().map(|v| v.to_bits()).collect();
        let preds = fitted.predict(&x).unwrap();
        for &p in preds.iter() {
            assert!(
                valid_labels.contains(&p.to_bits()),
                "prediction {p} is not a valid training label"
            );
        }
    }

    #[test]
    fn test_max_samples() {
        let (x, y) = two_clusters();

        let fitted: FittedBaggingClassifier<f64> = BaggingClassifier::new(30)
            .with_max_depth(Some(3))
            .with_max_samples(Some(0.5))
            .with_seed(42)
            .fit(&x, &y)
            .unwrap();

        // Half-size samples must still yield one prediction per row.
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), y.len());
    }

    #[test]
    fn test_default() {
        let bc = BaggingClassifier::default();
        assert_eq!(bc.n_estimators, 10);
        assert!(bc.bootstrap);
        assert!(bc.max_depth.is_none());
        assert!(bc.max_samples.is_none());
        assert_eq!(bc.seed, 0);
    }
}
552
553impl<F: Float> PredictProba<F> for FittedBaggingClassifier<F> {
554    fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>> {
555        Self::predict_proba(self, x)
556    }
557}