1use anofox_ml_core::{Fit, Float, Predict, PredictProba, Result, RustMlError};
2use anofox_ml_trees::{DecisionTreeClassifier, FittedDecisionTreeClassifier, SplitCriterion};
3use ndarray::{Array1, Array2};
4use rand::rngs::StdRng;
5use rand::{Rng, SeedableRng};
6use rayon::prelude::*;
7
/// Configuration for a bagging ensemble of decision-tree classifiers.
///
/// Fitting this configuration (through its `Fit` implementation) trains
/// `n_estimators` trees, each on its own row sample of the training data.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct BaggingClassifier {
    /// Number of trees to train. Fitting rejects `0`.
    pub n_estimators: usize,
    /// Depth limit forwarded unchanged to every tree (`None` forwards no limit).
    pub max_depth: Option<usize>,
    /// Fraction of the training rows drawn for each tree; fitting requires a value in
    /// `(0, 1]`. `None` draws as many rows as the training set contains. The value only
    /// affects the sample size when `bootstrap` is `true`.
    pub max_samples: Option<f64>,
    /// When `true`, each tree's rows are drawn at random with replacement. When `false`,
    /// every tree is trained on all rows in their original order.
    pub bootstrap: bool,
    /// Seed of the random generator used to draw the per-tree row samples.
    pub seed: u64,
}
30
31impl BaggingClassifier {
32 pub fn new(n_estimators: usize) -> Self {
34 Self {
35 n_estimators,
36 max_depth: None,
37 max_samples: None,
38 bootstrap: true,
39 seed: 0,
40 }
41 }
42
43 pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
44 self.max_depth = max_depth;
45 self
46 }
47 pub fn with_max_samples(mut self, max_samples: Option<f64>) -> Self {
48 self.max_samples = max_samples;
49 self
50 }
51 pub fn with_bootstrap(mut self, bootstrap: bool) -> Self {
52 self.bootstrap = bootstrap;
53 self
54 }
55 pub fn with_seed(mut self, seed: u64) -> Self {
56 self.seed = seed;
57 self
58 }
59}
60
61impl Default for BaggingClassifier {
62 fn default() -> Self {
63 Self::new(10)
64 }
65}
66
/// A trained bagging ensemble, as returned by fitting a `BaggingClassifier`.
///
/// Both fields are private; the ensemble is queried through the methods and trait
/// implementations defined on this type.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
// Replaces the bound serde would derive for `Deserialize` with an explicit one on `F`.
#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
pub struct FittedBaggingClassifier<F: Float> {
    /// The fitted member trees, one per estimator.
    trees: Vec<FittedDecisionTreeClassifier<F>>,
    /// Number of columns in the training matrix; prediction inputs must match it.
    n_features: usize,
}
74
75impl<F: Float> Fit<F> for BaggingClassifier {
76 type Fitted = FittedBaggingClassifier<F>;
77
78 fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Self::Fitted> {
79 if x.nrows() != y.len() {
80 return Err(RustMlError::ShapeMismatch(format!(
81 "X has {} rows but y has {} elements",
82 x.nrows(),
83 y.len()
84 )));
85 }
86 if x.is_empty() {
87 return Err(RustMlError::EmptyInput("training data is empty".into()));
88 }
89 if self.n_estimators == 0 {
90 return Err(RustMlError::InvalidParameter(
91 "n_estimators must be > 0".into(),
92 ));
93 }
94
95 let n_samples = x.nrows();
96 let n_features = x.ncols();
97
98 let mut rng = StdRng::seed_from_u64(self.seed);
99
100 let draw_size = if let Some(frac) = self.max_samples {
102 if frac <= 0.0 || frac > 1.0 {
103 return Err(RustMlError::InvalidParameter(
104 "max_samples must be in (0, 1]".into(),
105 ));
106 }
107 (n_samples as f64 * frac).ceil() as usize
108 } else {
109 n_samples
110 };
111
112 let tree_params = DecisionTreeClassifier {
113 max_depth: self.max_depth,
114 min_samples_split: 2,
115 min_samples_leaf: 1,
116 criterion: SplitCriterion::Gini,
117 max_features: None,
118 sample_weight: None,
119 class_weight: None,
120 };
121
122 let sample_plans: Vec<Vec<usize>> = (0..self.n_estimators)
124 .map(|_| {
125 if self.bootstrap {
126 (0..draw_size)
127 .map(|_| rng.gen_range(0..n_samples))
128 .collect()
129 } else {
130 (0..n_samples).collect()
131 }
132 })
133 .collect();
134
135 let trees: Result<Vec<FittedDecisionTreeClassifier<F>>> = sample_plans
137 .into_par_iter()
138 .map(|row_indices| {
139 let x_sub = build_sub_matrix_rows(x, &row_indices);
140 let y_sub = Array1::from_vec(row_indices.iter().map(|&i| y[i]).collect::<Vec<F>>());
141 tree_params.fit(&x_sub, &y_sub)
142 })
143 .collect();
144 let trees = trees?;
145
146 Ok(FittedBaggingClassifier { trees, n_features })
147 }
148}
149
150impl<F: Float> Predict<F> for FittedBaggingClassifier<F> {
151 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>> {
152 if x.ncols() != self.n_features {
153 return Err(RustMlError::ShapeMismatch(format!(
154 "expected {} features, got {}",
155 self.n_features,
156 x.ncols()
157 )));
158 }
159
160 let n_samples = x.nrows();
161 let n_trees = self.trees.len();
162
163 let all_preds: Result<Vec<Array1<F>>> =
165 self.trees.par_iter().map(|tree| tree.predict(x)).collect();
166 let all_preds = all_preds?;
167
168 let mut predictions = Vec::with_capacity(n_samples);
170 let mut votes = Vec::with_capacity(n_trees);
171 for i in 0..n_samples {
172 votes.clear();
173 for tree_pred in &all_preds {
174 votes.push(tree_pred[i]);
175 }
176 predictions.push(majority_vote(&votes));
177 }
178
179 Ok(Array1::from_vec(predictions))
180 }
181}
182
183impl<F: Float> FittedBaggingClassifier<F> {
184 pub fn feature_importances(&self) -> Array1<F> {
186 let mut importances = vec![F::zero(); self.n_features];
187 let n_trees = F::from_usize(self.trees.len()).unwrap();
188
189 for tree in &self.trees {
190 let tree_importances = tree.feature_importances();
191 for (idx, &imp) in tree_importances.iter().enumerate() {
192 importances[idx] += imp / n_trees;
193 }
194 }
195
196 let sum: F = importances.iter().copied().fold(F::zero(), |a, b| a + b);
198 if sum > F::zero() {
199 Array1::from_vec(importances.into_iter().map(|v| v / sum).collect())
200 } else {
201 Array1::zeros(self.n_features)
202 }
203 }
204
205 pub fn n_estimators(&self) -> usize {
207 self.trees.len()
208 }
209
210 pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>> {
215 if x.ncols() != self.n_features {
216 return Err(RustMlError::ShapeMismatch(format!(
217 "expected {} features, got {}",
218 self.n_features,
219 x.ncols()
220 )));
221 }
222
223 let all_proba: Result<Vec<Array2<F>>> = self
225 .trees
226 .par_iter()
227 .map(|tree| tree.predict_proba(x))
228 .collect();
229 let all_proba = all_proba?;
230
231 let mut global_classes: Vec<F> = Vec::new();
233 let eps = F::from_f64(1e-9).unwrap();
234 for tree in &self.trees {
235 for c in tree.classes() {
236 if !global_classes.iter().any(|&gc| (gc - c).abs() < eps) {
237 global_classes.push(c);
238 }
239 }
240 }
241 global_classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
242
243 let n_samples = x.nrows();
244 let n_classes = global_classes.len();
245 let n_trees_f = F::from_usize(self.trees.len()).unwrap();
246 let mut avg_proba = Array2::<F>::zeros((n_samples, n_classes));
247
248 for (tree_idx, tree) in self.trees.iter().enumerate() {
250 let tree_classes = tree.classes();
251 let tree_proba = &all_proba[tree_idx];
252
253 for (local_ci, &tc) in tree_classes.iter().enumerate() {
254 if let Some(global_ci) = global_classes.iter().position(|&gc| (gc - tc).abs() < eps)
255 {
256 for i in 0..n_samples {
257 avg_proba[[i, global_ci]] += tree_proba[[i, local_ci]] / n_trees_f;
258 }
259 }
260 }
261 }
262
263 Ok(avg_proba)
264 }
265
266 pub fn score(&self, x: &Array2<F>, y: &Array1<F>) -> Result<f64> {
268 let preds = self.predict(x)?;
269 let n = y.len();
270 let correct = preds
271 .iter()
272 .zip(y.iter())
273 .filter(|(&p, &t)| (p - t).abs() < F::from_f64(1e-9).unwrap())
274 .count();
275 Ok(correct as f64 / n as f64)
276 }
277
278 pub fn classes(&self) -> Vec<F> {
280 let eps = F::from_f64(1e-9).unwrap();
281 let mut classes: Vec<F> = Vec::new();
282 for tree in &self.trees {
283 for c in tree.classes() {
284 if !classes.iter().any(|&gc| (gc - c).abs() < eps) {
285 classes.push(c);
286 }
287 }
288 }
289 classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
290 classes
291 }
292}
293
294fn build_sub_matrix_rows<F: Float>(x: &Array2<F>, row_indices: &[usize]) -> Array2<F> {
296 let n_rows = row_indices.len();
297 let n_cols = x.ncols();
298 let mut data = Vec::with_capacity(n_rows * n_cols);
299 for &ri in row_indices {
300 for ci in 0..n_cols {
301 data.push(x[[ri, ci]]);
302 }
303 }
304 Array2::from_shape_vec((n_rows, n_cols), data).expect("shape matches data length")
305}
306
307#[inline]
309fn majority_vote<F: Float>(votes: &[F]) -> F {
310 use std::collections::HashMap;
311 let mut counts: HashMap<u64, (F, usize)> = HashMap::new();
312 for &v in votes {
313 let key = v.to_f64().unwrap().to_bits();
314 counts.entry(key).and_modify(|e| e.1 += 1).or_insert((v, 1));
315 }
316 counts
317 .into_values()
318 .max_by_key(|&(_, count)| count)
319 .unwrap()
320 .0
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Two well-separated groups of three rows each (two features), labelled 0 and 1.
    fn two_group_data() -> (Array2<f64>, Array1<f64>) {
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        (x, y)
    }

    /// Fits 20 depth-3 trees with seed 42 on the two-group data and returns the
    /// ensemble together with the data it was trained on.
    fn fit_two_group_ensemble() -> (FittedBaggingClassifier<f64>, Array2<f64>, Array1<f64>) {
        let (x, y) = two_group_data();
        let config = BaggingClassifier::new(20)
            .with_max_depth(Some(3))
            .with_seed(42);
        let fitted: FittedBaggingClassifier<f64> = config.fit(&x, &y).unwrap();
        (fitted, x, y)
    }

    #[test]
    fn test_basic_classification() {
        let (fitted, x, y) = fit_two_group_ensemble();

        let preds = fitted.predict(&x).unwrap();
        for (predicted, expected) in preds.iter().zip(y.iter()) {
            assert_abs_diff_eq!(*predicted, *expected, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_reproducibility() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let config = BaggingClassifier::new(10).with_seed(123);

        // Fitting the same configuration twice must give the same predictions.
        let first: FittedBaggingClassifier<f64> = config.fit(&x, &y).unwrap();
        let second: FittedBaggingClassifier<f64> = config.fit(&x, &y).unwrap();

        let first_preds = first.predict(&x).unwrap();
        let second_preds = second.predict(&x).unwrap();

        for (a, b) in first_preds.iter().zip(second_preds.iter()) {
            assert_abs_diff_eq!(*a, *b, epsilon = 1e-15);
        }
    }

    #[test]
    fn test_feature_importances_sum_to_one() {
        let x = array![
            [1.0, 100.0],
            [2.0, 200.0],
            [3.0, 300.0],
            [4.0, 400.0],
            [5.0, 500.0],
            [6.0, 600.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let config = BaggingClassifier::new(20).with_seed(7);
        let fitted: FittedBaggingClassifier<f64> = config.fit(&x, &y).unwrap();

        let total: f64 = fitted.feature_importances().iter().sum();
        assert_abs_diff_eq!(total, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_predict_proba_rows_sum_to_one() {
        let (fitted, x, _y) = fit_two_group_ensemble();

        let proba = fitted.predict_proba(&x).unwrap();
        assert_eq!(proba.nrows(), x.nrows());
        for row in 0..proba.nrows() {
            let row_sum: f64 = proba.row(row).iter().sum();
            assert_abs_diff_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_score() {
        let (fitted, x, y) = fit_two_group_ensemble();

        let accuracy = fitted.score(&x, &y).unwrap();
        assert_abs_diff_eq!(accuracy, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_n_estimators() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![0.0, 0.0, 1.0, 1.0];

        let config = BaggingClassifier::new(7).with_seed(0);
        let fitted: FittedBaggingClassifier<f64> = config.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 7);
    }

    #[test]
    fn test_shape_mismatch_error() {
        // Two rows of features against three labels.
        let x = array![[1.0], [2.0]];
        let y = array![0.0, 1.0, 2.0];

        let outcome: std::result::Result<FittedBaggingClassifier<f64>, _> =
            BaggingClassifier::default().fit(&x, &y);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_predict_wrong_features_error() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0.0, 1.0];

        let config = BaggingClassifier::new(5).with_seed(0);
        let fitted: FittedBaggingClassifier<f64> = config.fit(&x, &y).unwrap();

        // Trained on two columns, queried with one.
        let one_column = array![[1.0], [2.0]];
        assert!(fitted.predict(&one_column).is_err());
    }

    #[test]
    fn test_empty_input_error() {
        let x: Array2<f64> = Array2::zeros((0, 2));
        let y: Array1<f64> = Array1::zeros(0);

        let outcome: std::result::Result<FittedBaggingClassifier<f64>, _> =
            BaggingClassifier::default().fit(&x, &y);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_zero_estimators_error() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0.0, 1.0];

        let outcome: std::result::Result<FittedBaggingClassifier<f64>, _> =
            BaggingClassifier::new(0).fit(&x, &y);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_multiclass() {
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0],
            [20.0, 2.0],
            [21.0, 2.0],
            [22.0, 2.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0];

        let config = BaggingClassifier::new(30)
            .with_max_depth(Some(5))
            .with_seed(42);
        let fitted: FittedBaggingClassifier<f64> = config.fit(&x, &y).unwrap();

        // Every prediction must be one of the labels seen during training.
        let preds = fitted.predict(&x).unwrap();
        let valid_labels: std::collections::HashSet<u64> =
            y.iter().map(|label| label.to_bits()).collect();
        for &p in preds.iter() {
            assert!(
                valid_labels.contains(&p.to_bits()),
                "prediction {p} is not a valid training label"
            );
        }
    }

    #[test]
    fn test_max_samples() {
        let (x, y) = two_group_data();

        let config = BaggingClassifier::new(30)
            .with_max_depth(Some(3))
            .with_max_samples(Some(0.5))
            .with_seed(42);
        let fitted: FittedBaggingClassifier<f64> = config.fit(&x, &y).unwrap();

        // Only the shape of the output is checked here.
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), y.len());
    }

    #[test]
    fn test_default() {
        let config = BaggingClassifier::default();
        assert_eq!(config.n_estimators, 10);
        assert!(config.bootstrap);
        assert!(config.max_depth.is_none());
        assert!(config.max_samples.is_none());
        assert_eq!(config.seed, 0);
    }
}
552
553impl<F: Float> PredictProba<F> for FittedBaggingClassifier<F> {
554 fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>> {
555 Self::predict_proba(self, x)
556 }
557}