1use anofox_ml_core::{Fit, Float, Predict, PredictProba, Result, RustMlError};
2use anofox_ml_trees::{DecisionTreeRegressor, FittedDecisionTreeRegressor};
3use ndarray::{Array1, Array2};
4use rand::rngs::StdRng;
5use rand::seq::SliceRandom;
6use rand::SeedableRng;
7
8#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
14pub struct GradientBoostingClassifier {
15 pub n_estimators: usize,
17 pub learning_rate: f64,
19 pub max_depth: Option<usize>,
21 pub min_samples_split: usize,
23 pub min_samples_leaf: usize,
25 pub subsample: f64,
27 pub seed: u64,
29}
30
31impl GradientBoostingClassifier {
32 pub fn new() -> Self {
34 Self {
35 n_estimators: 100,
36 learning_rate: 0.1,
37 max_depth: Some(3),
38 min_samples_split: 2,
39 min_samples_leaf: 1,
40 subsample: 1.0,
41 seed: 0,
42 }
43 }
44
45 pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
47 self.n_estimators = n_estimators;
48 self
49 }
50
51 pub fn with_learning_rate(mut self, learning_rate: f64) -> Self {
53 self.learning_rate = learning_rate;
54 self
55 }
56
57 pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
59 self.max_depth = max_depth;
60 self
61 }
62
63 pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
65 self.min_samples_split = min_samples_split;
66 self
67 }
68
69 pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
71 self.min_samples_leaf = min_samples_leaf;
72 self
73 }
74
75 pub fn with_subsample(mut self, subsample: f64) -> Self {
77 self.subsample = subsample;
78 self
79 }
80
81 pub fn with_seed(mut self, seed: u64) -> Self {
83 self.seed = seed;
84 self
85 }
86}
87
88impl Default for GradientBoostingClassifier {
89 fn default() -> Self {
90 Self::new()
91 }
92}
93
94#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
96#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
97pub struct FittedGradientBoostingClassifier<F: Float> {
98 classes: Vec<F>,
100 tree_sets: Vec<Vec<FittedDecisionTreeRegressor<F>>>,
103 initial_values: Vec<F>,
105 learning_rate: F,
107 n_features: usize,
109}
110
111impl<F: Float> Fit<F> for GradientBoostingClassifier {
112 type Fitted = FittedGradientBoostingClassifier<F>;
113
114 fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Self::Fitted> {
115 if x.nrows() != y.len() {
116 return Err(RustMlError::ShapeMismatch(format!(
117 "X has {} rows but y has {} elements",
118 x.nrows(),
119 y.len()
120 )));
121 }
122 if x.is_empty() {
123 return Err(RustMlError::EmptyInput("training data is empty".into()));
124 }
125 if self.n_estimators == 0 {
126 return Err(RustMlError::InvalidParameter(
127 "n_estimators must be > 0".into(),
128 ));
129 }
130 if self.learning_rate <= 0.0 {
131 return Err(RustMlError::InvalidParameter(
132 "learning_rate must be > 0".into(),
133 ));
134 }
135 if self.subsample <= 0.0 || self.subsample > 1.0 {
136 return Err(RustMlError::InvalidParameter(
137 "subsample must be in (0, 1]".into(),
138 ));
139 }
140
141 let classes = unique_sorted(y);
143 let n_classes = classes.len();
144 if n_classes < 2 {
145 return Err(RustMlError::InvalidParameter(
146 "y must contain at least 2 distinct classes".into(),
147 ));
148 }
149
150 let n_features = x.ncols();
151 let lr = F::from_f64(self.learning_rate).unwrap();
152
153 if n_classes == 2 {
154 let (initial, trees) = self.fit_binary(x, y, &classes[1], lr)?;
156 Ok(FittedGradientBoostingClassifier {
157 classes,
158 tree_sets: vec![trees],
159 initial_values: vec![initial],
160 learning_rate: lr,
161 n_features,
162 })
163 } else {
164 let mut tree_sets = Vec::with_capacity(n_classes);
166 let mut initial_values = Vec::with_capacity(n_classes);
167
168 for class in &classes {
169 let (initial, trees) = self.fit_binary(x, y, class, lr)?;
170 tree_sets.push(trees);
171 initial_values.push(initial);
172 }
173
174 Ok(FittedGradientBoostingClassifier {
175 classes,
176 tree_sets,
177 initial_values,
178 learning_rate: lr,
179 n_features,
180 })
181 }
182 }
183}
184
185impl GradientBoostingClassifier {
186 fn fit_binary<F: Float>(
189 &self,
190 x: &Array2<F>,
191 y: &Array1<F>,
192 positive_class: &F,
193 lr: F,
194 ) -> Result<(F, Vec<FittedDecisionTreeRegressor<F>>)> {
195 let n_samples = x.nrows();
196 let eps = F::from_f64(1e-15).unwrap();
197
198 let binary_y: Array1<F> = y.mapv(|v| {
200 if (v - *positive_class).abs() < eps {
201 F::one()
202 } else {
203 F::zero()
204 }
205 });
206
207 let p = binary_y.sum() / F::from_usize(n_samples).unwrap();
209 let p_clipped = clamp(p, eps, F::one() - eps);
210 let initial_log_odds = (p_clipped / (F::one() - p_clipped)).ln();
211
212 let mut log_odds = Array1::from_elem(n_samples, initial_log_odds);
213
214 let tree_params = DecisionTreeRegressor {
215 max_depth: self.max_depth,
216 min_samples_split: self.min_samples_split,
217 min_samples_leaf: self.min_samples_leaf,
218 max_features: None,
219 sample_weight: None,
220 };
221
222 let mut rng = StdRng::seed_from_u64(self.seed);
223 let mut trees = Vec::with_capacity(self.n_estimators);
224 let subsample_size = ((self.subsample * n_samples as f64).round() as usize).max(1);
225
226 let mut probs = Array1::<F>::zeros(n_samples);
228 let mut residuals = Array1::<F>::zeros(n_samples);
229 let mut indices: Vec<usize> = (0..n_samples).collect();
230
231 for _ in 0..self.n_estimators {
232 for i in 0..n_samples {
234 probs[i] = sigmoid(log_odds[i]);
235 residuals[i] = binary_y[i] - probs[i];
236 }
237
238 let fitted_tree: FittedDecisionTreeRegressor<F> = if subsample_size < n_samples {
240 indices.clear();
241 indices.extend(0..n_samples);
242 indices.shuffle(&mut rng);
243 indices.truncate(subsample_size);
244 indices.sort_unstable();
245
246 let x_sub = build_sub_rows(x, &indices);
247 let r_sub = Array1::from_vec(indices.iter().map(|&i| residuals[i]).collect());
248 tree_params.fit(&x_sub, &r_sub)?
249 } else {
250 tree_params.fit(x, &residuals)?
251 };
252
253 let tree_preds = fitted_tree.predict(x)?;
255 log_odds += &(tree_preds * lr);
256
257 trees.push(fitted_tree);
258 }
259
260 Ok((initial_log_odds, trees))
261 }
262}
263
264impl<F: Float> Predict<F> for FittedGradientBoostingClassifier<F> {
265 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>> {
266 if x.ncols() != self.n_features {
267 return Err(RustMlError::ShapeMismatch(format!(
268 "expected {} features, got {}",
269 self.n_features,
270 x.ncols()
271 )));
272 }
273
274 let n_samples = x.nrows();
275
276 if self.classes.len() == 2 {
277 let log_odds = self.predict_log_odds(x, 0)?;
279 let half = F::from_f64(0.5).unwrap();
280
281 let predictions: Vec<F> = log_odds
282 .iter()
283 .map(|&lo| {
284 if sigmoid(lo) >= half {
285 self.classes[1]
286 } else {
287 self.classes[0]
288 }
289 })
290 .collect();
291
292 Ok(Array1::from_vec(predictions))
293 } else {
294 let mut all_log_odds = Vec::with_capacity(self.classes.len());
296 for k in 0..self.classes.len() {
297 all_log_odds.push(self.predict_log_odds(x, k)?);
298 }
299
300 let mut predictions = Vec::with_capacity(n_samples);
301 for sample_idx in 0..n_samples {
302 let mut best_class = 0;
303 let mut best_val = all_log_odds[0][sample_idx];
304 for (k, log_odds_k) in all_log_odds.iter().enumerate().skip(1) {
305 if log_odds_k[sample_idx] > best_val {
306 best_val = log_odds_k[sample_idx];
307 best_class = k;
308 }
309 }
310 predictions.push(self.classes[best_class]);
311 }
312
313 Ok(Array1::from_vec(predictions))
314 }
315 }
316}
317
318impl<F: Float> FittedGradientBoostingClassifier<F> {
319 pub fn n_estimators(&self) -> usize {
321 self.tree_sets.first().map_or(0, |ts| ts.len())
322 }
323
324 pub fn classes(&self) -> &[F] {
326 &self.classes
327 }
328
329 pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>> {
335 if x.ncols() != self.n_features {
336 return Err(RustMlError::ShapeMismatch(format!(
337 "expected {} features, got {}",
338 self.n_features,
339 x.ncols()
340 )));
341 }
342
343 let n_samples = x.nrows();
344 let n_classes = self.classes.len();
345
346 if n_classes == 2 {
347 let log_odds = self.predict_log_odds(x, 0)?;
349 let mut proba = Array2::<F>::zeros((n_samples, 2));
350 for i in 0..n_samples {
351 let p1 = sigmoid(log_odds[i]);
352 proba[[i, 0]] = F::one() - p1;
353 proba[[i, 1]] = p1;
354 }
355 Ok(proba)
356 } else {
357 let mut all_log_odds = Vec::with_capacity(n_classes);
359 for k in 0..n_classes {
360 all_log_odds.push(self.predict_log_odds(x, k)?);
361 }
362
363 let mut proba = Array2::<F>::zeros((n_samples, n_classes));
364 for i in 0..n_samples {
365 let mut max_lo = all_log_odds[0][i];
367 for k in 1..n_classes {
368 if all_log_odds[k][i] > max_lo {
369 max_lo = all_log_odds[k][i];
370 }
371 }
372 let mut sum = F::zero();
374 for k in 0..n_classes {
375 let e = (all_log_odds[k][i] - max_lo).exp();
376 proba[[i, k]] = e;
377 sum += e;
378 }
379 for k in 0..n_classes {
381 proba[[i, k]] /= sum;
382 }
383 }
384 Ok(proba)
385 }
386 }
387
388 pub fn feature_importances(&self) -> Array1<F> {
393 let mut importances = vec![F::zero(); self.n_features];
394 let mut total_trees = 0usize;
395
396 for tree_set in &self.tree_sets {
397 for tree in tree_set {
398 let tree_imp = tree.feature_importances();
399 for (j, &imp) in tree_imp.iter().enumerate() {
400 importances[j] += imp;
401 }
402 total_trees += 1;
403 }
404 }
405
406 if total_trees > 0 {
407 let total_f = F::from_usize(total_trees).unwrap();
408 for imp in &mut importances {
409 *imp /= total_f;
410 }
411 }
412
413 let sum: F = importances.iter().copied().fold(F::zero(), |a, b| a + b);
415 if sum > F::zero() {
416 Array1::from_vec(importances.into_iter().map(|v| v / sum).collect())
417 } else {
418 Array1::zeros(self.n_features)
419 }
420 }
421
422 fn predict_log_odds(&self, x: &Array2<F>, k: usize) -> Result<Array1<F>> {
424 let n_samples = x.nrows();
425 let mut log_odds = Array1::from_elem(n_samples, self.initial_values[k]);
426
427 for tree in &self.tree_sets[k] {
428 let tree_preds = tree.predict(x)?;
429 log_odds += &(tree_preds * self.learning_rate);
430 }
431
432 Ok(log_odds)
433 }
434}
435
436fn sigmoid<F: Float>(x: F) -> F {
442 F::one() / (F::one() + (-x).exp())
443}
444
445fn clamp<F: Float>(x: F, lo: F, hi: F) -> F {
447 if x < lo {
448 lo
449 } else if x > hi {
450 hi
451 } else {
452 x
453 }
454}
455
456fn unique_sorted<F: Float>(arr: &Array1<F>) -> Vec<F> {
458 let eps = F::from_f64(1e-9).unwrap();
459 let mut vals: Vec<F> = arr.to_vec();
460 vals.sort_by(|a, b| a.partial_cmp(b).unwrap());
461 vals.dedup_by(|a, b| (*a - *b).abs() < eps);
462 vals
463}
464
465fn build_sub_rows<F: Float>(x: &Array2<F>, row_indices: &[usize]) -> Array2<F> {
467 x.select(ndarray::Axis(0), row_indices)
468}
469
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    type FittedF64 = FittedGradientBoostingClassifier<f64>;

    /// Two well-separated clusters with three samples each.
    fn two_blobs() -> (Array2<f64>, Array1<f64>) {
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        (x, y)
    }

    /// Two well-separated clusters with four samples each (room for subsampling).
    fn two_blobs_of_four() -> (Array2<f64>, Array1<f64>) {
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [4.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0],
            [13.0, 1.0]
        ];
        let y = array![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0];
        (x, y)
    }

    /// Three well-separated clusters with three samples each.
    fn three_blobs() -> (Array2<f64>, Array1<f64>) {
        let x = array![
            [0.0, 0.0],
            [0.5, 0.0],
            [1.0, 0.0],
            [5.0, 5.0],
            [5.5, 5.0],
            [6.0, 5.0],
            [10.0, 10.0],
            [10.5, 10.0],
            [11.0, 10.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0];
        (x, y)
    }

    /// Asserts that every prediction equals its target label.
    fn assert_labels_match(preds: &Array1<f64>, y: &Array1<f64>) {
        assert_eq!(preds.len(), y.len());
        for (p, t) in preds.iter().zip(y.iter()) {
            assert_abs_diff_eq!(*p, *t, epsilon = 1e-10);
        }
    }

    /// Number of predictions that equal their target label.
    fn count_correct(preds: &Array1<f64>, y: &Array1<f64>) -> usize {
        preds
            .iter()
            .zip(y.iter())
            .filter(|(p, t)| (*p - *t).abs() < 1e-10)
            .count()
    }

    /// Fits a small model and checks the shape and row sums of `predict_proba`.
    fn assert_valid_proba(x: &Array2<f64>, y: &Array1<f64>, n_classes: usize) {
        let gb = GradientBoostingClassifier {
            n_estimators: 20,
            seed: 42,
            ..Default::default()
        };
        let fitted: FittedF64 = gb.fit(x, y).unwrap();

        let proba = fitted.predict_proba(x).unwrap();
        assert_eq!(proba.dim(), (x.nrows(), n_classes));
        for row in proba.outer_iter() {
            assert!(row.iter().all(|&p| (0.0..=1.0).contains(&p)));
            assert_abs_diff_eq!(row.sum(), 1.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_basic_binary_classification() {
        let (x, y) = two_blobs();

        let gb = GradientBoostingClassifier {
            n_estimators: 50,
            learning_rate: 0.1,
            max_depth: Some(3),
            seed: 42,
            ..Default::default()
        };
        let fitted: FittedF64 = gb.fit(&x, &y).unwrap();

        assert_labels_match(&fitted.predict(&x).unwrap(), &y);
    }

    #[test]
    fn test_multiclass_classification() {
        let (x, y) = three_blobs();

        let gb = GradientBoostingClassifier {
            n_estimators: 100,
            learning_rate: 0.1,
            max_depth: Some(3),
            seed: 42,
            ..Default::default()
        };
        let fitted: FittedF64 = gb.fit(&x, &y).unwrap();

        assert_labels_match(&fitted.predict(&x).unwrap(), &y);
        assert_eq!(fitted.classes().len(), 3);
    }

    #[test]
    fn test_reproducibility() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let gb = GradientBoostingClassifier {
            n_estimators: 20,
            seed: 123,
            ..Default::default()
        };

        let fitted1: FittedF64 = gb.fit(&x, &y).unwrap();
        let fitted2: FittedF64 = gb.fit(&x, &y).unwrap();

        let preds1 = fitted1.predict(&x).unwrap();
        let preds2 = fitted2.predict(&x).unwrap();

        for (a, b) in preds1.iter().zip(preds2.iter()) {
            assert_abs_diff_eq!(*a, *b, epsilon = 1e-15);
        }
    }

    #[test]
    fn test_subsample_binary() {
        let (x, y) = two_blobs_of_four();

        let gb = GradientBoostingClassifier {
            n_estimators: 80,
            learning_rate: 0.1,
            max_depth: Some(3),
            subsample: 0.75,
            seed: 42,
            ..Default::default()
        };
        let fitted: FittedF64 = gb.fit(&x, &y).unwrap();

        assert_labels_match(&fitted.predict(&x).unwrap(), &y);
    }

    #[test]
    fn test_shape_mismatch_error() {
        let x = array![[1.0], [2.0]];
        let y = array![0.0, 1.0, 2.0];

        let gb = GradientBoostingClassifier::default();
        let result: std::result::Result<FittedF64, _> = gb.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_predict_wrong_features_error() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0.0, 1.0];

        let gb = GradientBoostingClassifier {
            n_estimators: 5,
            seed: 0,
            ..Default::default()
        };
        let fitted: FittedF64 = gb.fit(&x, &y).unwrap();

        let x_bad = array![[1.0], [2.0]];
        assert!(fitted.predict(&x_bad).is_err());
        assert!(fitted.predict_proba(&x_bad).is_err());
    }

    #[test]
    fn test_single_class_error() {
        let x = array![[1.0], [2.0], [3.0]];
        let y = array![1.0, 1.0, 1.0];

        let gb = GradientBoostingClassifier::default();
        let result: std::result::Result<FittedF64, _> = gb.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_parameters() {
        let x = array![[1.0], [2.0]];
        let y = array![0.0, 1.0];

        let gb = GradientBoostingClassifier {
            n_estimators: 0,
            ..Default::default()
        };
        assert!(Fit::<f64>::fit(&gb, &x, &y).is_err());

        // Negative and zero learning rates are both rejected.
        for learning_rate in [-0.1, 0.0] {
            let gb = GradientBoostingClassifier {
                learning_rate,
                ..Default::default()
            };
            assert!(Fit::<f64>::fit(&gb, &x, &y).is_err());
        }

        // `subsample` has to lie in (0, 1].
        for subsample in [0.0, -0.5, 1.5] {
            let gb = GradientBoostingClassifier {
                subsample,
                ..Default::default()
            };
            assert!(Fit::<f64>::fit(&gb, &x, &y).is_err());
        }
    }

    #[test]
    fn test_n_estimators_one() {
        let (x, y) = two_blobs();

        let gb = GradientBoostingClassifier {
            n_estimators: 1,
            learning_rate: 0.1,
            max_depth: Some(3),
            seed: 42,
            ..Default::default()
        };
        let fitted: FittedF64 = gb.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 1);

        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), y.len());
    }

    #[test]
    fn test_predictions_are_valid_labels() {
        let (x, y) = three_blobs();

        let gb = GradientBoostingClassifier {
            n_estimators: 50,
            learning_rate: 0.1,
            max_depth: Some(3),
            seed: 42,
            ..Default::default()
        };
        let fitted: FittedF64 = gb.fit(&x, &y).unwrap();

        let preds = fitted.predict(&x).unwrap();
        let valid_labels: std::collections::HashSet<u64> = y.iter().map(|v| v.to_bits()).collect();
        for &p in preds.iter() {
            assert!(
                valid_labels.contains(&p.to_bits()),
                "prediction {p} is not a valid training label"
            );
        }
    }

    #[test]
    fn test_subsample_impact() {
        let (x, y) = two_blobs_of_four();

        let gb = GradientBoostingClassifier {
            n_estimators: 80,
            learning_rate: 0.1,
            max_depth: Some(3),
            subsample: 0.5,
            seed: 7,
            ..Default::default()
        };
        let fitted: FittedF64 = gb.fit(&x, &y).unwrap();

        let preds = fitted.predict(&x).unwrap();
        let accuracy = count_correct(&preds, &y) as f64 / y.len() as f64;
        assert!(
            accuracy >= 0.75,
            "subsample=0.5 should still achieve reasonable accuracy, got {accuracy}"
        );
    }

    #[test]
    fn test_learning_rate_zero_error_or_degrades() {
        let (x, y) = two_blobs();

        // A learning rate of exactly zero is rejected outright.
        let gb_zero = GradientBoostingClassifier {
            learning_rate: 0.0,
            ..Default::default()
        };
        assert!(Fit::<f64>::fit(&gb_zero, &x, &y).is_err());

        let gb_normal = GradientBoostingClassifier {
            n_estimators: 50,
            learning_rate: 0.1,
            max_depth: Some(3),
            seed: 42,
            ..Default::default()
        };
        let fitted_normal: FittedF64 = gb_normal.fit(&x, &y).unwrap();
        let correct_normal = count_correct(&fitted_normal.predict(&x).unwrap(), &y);

        // A tiny but valid learning rate must not beat the normal one.
        let gb_tiny = GradientBoostingClassifier {
            n_estimators: 50,
            learning_rate: 0.001,
            max_depth: Some(3),
            seed: 42,
            ..Default::default()
        };
        let fitted_tiny: FittedF64 = gb_tiny.fit(&x, &y).unwrap();
        let correct_tiny = count_correct(&fitted_tiny.predict(&x).unwrap(), &y);

        assert!(
            correct_normal >= correct_tiny,
            "normal lr ({correct_normal} correct) should be >= tiny lr ({correct_tiny} correct)"
        );
    }

    #[test]
    fn test_predict_proba_rows_sum_to_one() {
        let (x, y) = two_blobs();
        assert_valid_proba(&x, &y, 2);

        let (x, y) = three_blobs();
        assert_valid_proba(&x, &y, 3);
    }
}
798
799impl<F: Float> PredictProba<F> for FittedGradientBoostingClassifier<F> {
800 fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>> {
801 Self::predict_proba(self, x)
802 }
803}