Skip to main content

anofox_ml_ensemble/
gradient_boosting_classifier.rs

1use anofox_ml_core::{Fit, Float, Predict, PredictProba, Result, RustMlError};
2use anofox_ml_trees::{DecisionTreeRegressor, FittedDecisionTreeRegressor};
3use ndarray::{Array1, Array2};
4use rand::rngs::StdRng;
5use rand::seq::SliceRandom;
6use rand::SeedableRng;
7
8/// Gradient boosting classifier parameters (unfitted state).
9///
10/// For binary classification, fits trees to the negative gradient of the
11/// log loss (logistic regression loss). For multi-class (>2 classes), uses
12/// a one-vs-rest strategy with separate sets of trees for each class.
13#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
14pub struct GradientBoostingClassifier {
15    /// Number of boosting rounds (trees per class for multi-class).
16    pub n_estimators: usize,
17    /// Shrinkage applied to each tree's contribution.
18    pub learning_rate: f64,
19    /// Maximum depth of each tree.
20    pub max_depth: Option<usize>,
21    /// Minimum samples required to split a node.
22    pub min_samples_split: usize,
23    /// Minimum samples required in a leaf node.
24    pub min_samples_leaf: usize,
25    /// Fraction of training samples used per tree.
26    pub subsample: f64,
27    /// Random seed for reproducibility.
28    pub seed: u64,
29}
30
31impl GradientBoostingClassifier {
32    /// Create a new `GradientBoostingClassifier` with default parameters.
33    pub fn new() -> Self {
34        Self {
35            n_estimators: 100,
36            learning_rate: 0.1,
37            max_depth: Some(3),
38            min_samples_split: 2,
39            min_samples_leaf: 1,
40            subsample: 1.0,
41            seed: 0,
42        }
43    }
44
45    /// Set the number of boosting rounds.
46    pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
47        self.n_estimators = n_estimators;
48        self
49    }
50
51    /// Set the learning rate (shrinkage).
52    pub fn with_learning_rate(mut self, learning_rate: f64) -> Self {
53        self.learning_rate = learning_rate;
54        self
55    }
56
57    /// Set the maximum depth of each tree.
58    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
59        self.max_depth = max_depth;
60        self
61    }
62
63    /// Set the minimum number of samples required to split a node.
64    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
65        self.min_samples_split = min_samples_split;
66        self
67    }
68
69    /// Set the minimum number of samples required in a leaf node.
70    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
71        self.min_samples_leaf = min_samples_leaf;
72        self
73    }
74
75    /// Set the fraction of samples used per boosting round.
76    pub fn with_subsample(mut self, subsample: f64) -> Self {
77        self.subsample = subsample;
78        self
79    }
80
81    /// Set the random seed for reproducibility.
82    pub fn with_seed(mut self, seed: u64) -> Self {
83        self.seed = seed;
84        self
85    }
86}
87
88impl Default for GradientBoostingClassifier {
89    fn default() -> Self {
90        Self::new()
91    }
92}
93
94/// Fitted gradient boosting classifier.
95#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
96#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
97pub struct FittedGradientBoostingClassifier<F: Float> {
98    /// Unique class labels sorted in ascending order.
99    classes: Vec<F>,
100    /// For binary: a single list of trees operating on log-odds.
101    /// For multi-class OVR: one list of trees per class.
102    tree_sets: Vec<Vec<FittedDecisionTreeRegressor<F>>>,
103    /// Initial log-odds per class set.
104    initial_values: Vec<F>,
105    /// Learning rate.
106    learning_rate: F,
107    /// Number of features expected.
108    n_features: usize,
109}
110
111impl<F: Float> Fit<F> for GradientBoostingClassifier {
112    type Fitted = FittedGradientBoostingClassifier<F>;
113
114    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Self::Fitted> {
115        if x.nrows() != y.len() {
116            return Err(RustMlError::ShapeMismatch(format!(
117                "X has {} rows but y has {} elements",
118                x.nrows(),
119                y.len()
120            )));
121        }
122        if x.is_empty() {
123            return Err(RustMlError::EmptyInput("training data is empty".into()));
124        }
125        if self.n_estimators == 0 {
126            return Err(RustMlError::InvalidParameter(
127                "n_estimators must be > 0".into(),
128            ));
129        }
130        if self.learning_rate <= 0.0 {
131            return Err(RustMlError::InvalidParameter(
132                "learning_rate must be > 0".into(),
133            ));
134        }
135        if self.subsample <= 0.0 || self.subsample > 1.0 {
136            return Err(RustMlError::InvalidParameter(
137                "subsample must be in (0, 1]".into(),
138            ));
139        }
140
141        // Discover unique classes.
142        let classes = unique_sorted(y);
143        let n_classes = classes.len();
144        if n_classes < 2 {
145            return Err(RustMlError::InvalidParameter(
146                "y must contain at least 2 distinct classes".into(),
147            ));
148        }
149
150        let n_features = x.ncols();
151        let lr = F::from_f64(self.learning_rate).unwrap();
152
153        if n_classes == 2 {
154            // Binary classification: single set of trees on log-odds.
155            let (initial, trees) = self.fit_binary(x, y, &classes[1], lr)?;
156            Ok(FittedGradientBoostingClassifier {
157                classes,
158                tree_sets: vec![trees],
159                initial_values: vec![initial],
160                learning_rate: lr,
161                n_features,
162            })
163        } else {
164            // Multi-class: one-vs-rest, one tree set per class.
165            let mut tree_sets = Vec::with_capacity(n_classes);
166            let mut initial_values = Vec::with_capacity(n_classes);
167
168            for class in &classes {
169                let (initial, trees) = self.fit_binary(x, y, class, lr)?;
170                tree_sets.push(trees);
171                initial_values.push(initial);
172            }
173
174            Ok(FittedGradientBoostingClassifier {
175                classes,
176                tree_sets,
177                initial_values,
178                learning_rate: lr,
179                n_features,
180            })
181        }
182    }
183}
184
185impl GradientBoostingClassifier {
186    /// Fit a binary gradient boosting model where the positive class is
187    /// `positive_class`. Returns (initial_log_odds, fitted_trees).
188    fn fit_binary<F: Float>(
189        &self,
190        x: &Array2<F>,
191        y: &Array1<F>,
192        positive_class: &F,
193        lr: F,
194    ) -> Result<(F, Vec<FittedDecisionTreeRegressor<F>>)> {
195        let n_samples = x.nrows();
196        let eps = F::from_f64(1e-15).unwrap();
197
198        // Convert labels to binary 0/1 for the positive class.
199        let binary_y: Array1<F> = y.mapv(|v| {
200            if (v - *positive_class).abs() < eps {
201                F::one()
202            } else {
203                F::zero()
204            }
205        });
206
207        // Initial prediction: log-odds of positive class frequency.
208        let p = binary_y.sum() / F::from_usize(n_samples).unwrap();
209        let p_clipped = clamp(p, eps, F::one() - eps);
210        let initial_log_odds = (p_clipped / (F::one() - p_clipped)).ln();
211
212        let mut log_odds = Array1::from_elem(n_samples, initial_log_odds);
213
214        let tree_params = DecisionTreeRegressor {
215            max_depth: self.max_depth,
216            min_samples_split: self.min_samples_split,
217            min_samples_leaf: self.min_samples_leaf,
218            max_features: None,
219            sample_weight: None,
220        };
221
222        let mut rng = StdRng::seed_from_u64(self.seed);
223        let mut trees = Vec::with_capacity(self.n_estimators);
224        let subsample_size = ((self.subsample * n_samples as f64).round() as usize).max(1);
225
226        // Pre-allocate reusable buffers outside the boosting loop.
227        let mut probs = Array1::<F>::zeros(n_samples);
228        let mut residuals = Array1::<F>::zeros(n_samples);
229        let mut indices: Vec<usize> = (0..n_samples).collect();
230
231        for _ in 0..self.n_estimators {
232            // Compute probabilities and residuals in-place.
233            for i in 0..n_samples {
234                probs[i] = sigmoid(log_odds[i]);
235                residuals[i] = binary_y[i] - probs[i];
236            }
237
238            // Fit tree to (subsampled) pseudo-residuals.
239            let fitted_tree: FittedDecisionTreeRegressor<F> = if subsample_size < n_samples {
240                indices.clear();
241                indices.extend(0..n_samples);
242                indices.shuffle(&mut rng);
243                indices.truncate(subsample_size);
244                indices.sort_unstable();
245
246                let x_sub = build_sub_rows(x, &indices);
247                let r_sub = Array1::from_vec(indices.iter().map(|&i| residuals[i]).collect());
248                tree_params.fit(&x_sub, &r_sub)?
249            } else {
250                tree_params.fit(x, &residuals)?
251            };
252
253            // Update log-odds on full training set.
254            let tree_preds = fitted_tree.predict(x)?;
255            log_odds += &(tree_preds * lr);
256
257            trees.push(fitted_tree);
258        }
259
260        Ok((initial_log_odds, trees))
261    }
262}
263
264impl<F: Float> Predict<F> for FittedGradientBoostingClassifier<F> {
265    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>> {
266        if x.ncols() != self.n_features {
267            return Err(RustMlError::ShapeMismatch(format!(
268                "expected {} features, got {}",
269                self.n_features,
270                x.ncols()
271            )));
272        }
273
274        let n_samples = x.nrows();
275
276        if self.classes.len() == 2 {
277            // Binary: single set of trees.
278            let log_odds = self.predict_log_odds(x, 0)?;
279            let half = F::from_f64(0.5).unwrap();
280
281            let predictions: Vec<F> = log_odds
282                .iter()
283                .map(|&lo| {
284                    if sigmoid(lo) >= half {
285                        self.classes[1]
286                    } else {
287                        self.classes[0]
288                    }
289                })
290                .collect();
291
292            Ok(Array1::from_vec(predictions))
293        } else {
294            // Multi-class: predict class with highest log-odds (OVR).
295            let mut all_log_odds = Vec::with_capacity(self.classes.len());
296            for k in 0..self.classes.len() {
297                all_log_odds.push(self.predict_log_odds(x, k)?);
298            }
299
300            let mut predictions = Vec::with_capacity(n_samples);
301            for sample_idx in 0..n_samples {
302                let mut best_class = 0;
303                let mut best_val = all_log_odds[0][sample_idx];
304                for (k, log_odds_k) in all_log_odds.iter().enumerate().skip(1) {
305                    if log_odds_k[sample_idx] > best_val {
306                        best_val = log_odds_k[sample_idx];
307                        best_class = k;
308                    }
309                }
310                predictions.push(self.classes[best_class]);
311            }
312
313            Ok(Array1::from_vec(predictions))
314        }
315    }
316}
317
318impl<F: Float> FittedGradientBoostingClassifier<F> {
319    /// Number of estimators per class set.
320    pub fn n_estimators(&self) -> usize {
321        self.tree_sets.first().map_or(0, |ts| ts.len())
322    }
323
324    /// The unique classes discovered during training.
325    pub fn classes(&self) -> &[F] {
326        &self.classes
327    }
328
329    /// Predict class probabilities for each sample.
330    ///
331    /// For binary classification, returns shape `(n_samples, 2)` with columns
332    /// `[P(class_0), P(class_1)]`. For multi-class, returns `(n_samples, n_classes)`
333    /// using softmax over the OVR log-odds.
334    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>> {
335        if x.ncols() != self.n_features {
336            return Err(RustMlError::ShapeMismatch(format!(
337                "expected {} features, got {}",
338                self.n_features,
339                x.ncols()
340            )));
341        }
342
343        let n_samples = x.nrows();
344        let n_classes = self.classes.len();
345
346        if n_classes == 2 {
347            // Binary: sigmoid of log-odds
348            let log_odds = self.predict_log_odds(x, 0)?;
349            let mut proba = Array2::<F>::zeros((n_samples, 2));
350            for i in 0..n_samples {
351                let p1 = sigmoid(log_odds[i]);
352                proba[[i, 0]] = F::one() - p1;
353                proba[[i, 1]] = p1;
354            }
355            Ok(proba)
356        } else {
357            // Multi-class: softmax over OVR log-odds
358            let mut all_log_odds = Vec::with_capacity(n_classes);
359            for k in 0..n_classes {
360                all_log_odds.push(self.predict_log_odds(x, k)?);
361            }
362
363            let mut proba = Array2::<F>::zeros((n_samples, n_classes));
364            for i in 0..n_samples {
365                // Find max for numerical stability
366                let mut max_lo = all_log_odds[0][i];
367                for k in 1..n_classes {
368                    if all_log_odds[k][i] > max_lo {
369                        max_lo = all_log_odds[k][i];
370                    }
371                }
372                // Compute exp(lo - max) and sum
373                let mut sum = F::zero();
374                for k in 0..n_classes {
375                    let e = (all_log_odds[k][i] - max_lo).exp();
376                    proba[[i, k]] = e;
377                    sum += e;
378                }
379                // Normalize
380                for k in 0..n_classes {
381                    proba[[i, k]] /= sum;
382                }
383            }
384            Ok(proba)
385        }
386    }
387
388    /// Feature importances averaged across all tree sets, normalized to sum to 1.
389    ///
390    /// Each tree's importances are accumulated and weighted by the number of
391    /// boosting rounds.
392    pub fn feature_importances(&self) -> Array1<F> {
393        let mut importances = vec![F::zero(); self.n_features];
394        let mut total_trees = 0usize;
395
396        for tree_set in &self.tree_sets {
397            for tree in tree_set {
398                let tree_imp = tree.feature_importances();
399                for (j, &imp) in tree_imp.iter().enumerate() {
400                    importances[j] += imp;
401                }
402                total_trees += 1;
403            }
404        }
405
406        if total_trees > 0 {
407            let total_f = F::from_usize(total_trees).unwrap();
408            for imp in &mut importances {
409                *imp /= total_f;
410            }
411        }
412
413        // Normalize to sum to 1
414        let sum: F = importances.iter().copied().fold(F::zero(), |a, b| a + b);
415        if sum > F::zero() {
416            Array1::from_vec(importances.into_iter().map(|v| v / sum).collect())
417        } else {
418            Array1::zeros(self.n_features)
419        }
420    }
421
422    /// Compute raw log-odds for the k-th tree set.
423    fn predict_log_odds(&self, x: &Array2<F>, k: usize) -> Result<Array1<F>> {
424        let n_samples = x.nrows();
425        let mut log_odds = Array1::from_elem(n_samples, self.initial_values[k]);
426
427        for tree in &self.tree_sets[k] {
428            let tree_preds = tree.predict(x)?;
429            log_odds += &(tree_preds * self.learning_rate);
430        }
431
432        Ok(log_odds)
433    }
434}
435
// ---------------------------------------------------------------------------
// Helper functions
// ---------------------------------------------------------------------------
439
440/// Sigmoid function: 1 / (1 + exp(-x)).
441fn sigmoid<F: Float>(x: F) -> F {
442    F::one() / (F::one() + (-x).exp())
443}
444
445/// Clamp value to [lo, hi].
446fn clamp<F: Float>(x: F, lo: F, hi: F) -> F {
447    if x < lo {
448        lo
449    } else if x > hi {
450        hi
451    } else {
452        x
453    }
454}
455
456/// Return sorted unique values from an array.
457fn unique_sorted<F: Float>(arr: &Array1<F>) -> Vec<F> {
458    let eps = F::from_f64(1e-9).unwrap();
459    let mut vals: Vec<F> = arr.to_vec();
460    vals.sort_by(|a, b| a.partial_cmp(b).unwrap());
461    vals.dedup_by(|a, b| (*a - *b).abs() < eps);
462    vals
463}
464
465/// Build a sub-matrix selecting specific rows from `x`.
466fn build_sub_rows<F: Float>(x: &Array2<F>, row_indices: &[usize]) -> Array2<F> {
467    x.select(ndarray::Axis(0), row_indices)
468}
469
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Number of predictions that match their target within 1e-10.
    fn count_correct(preds: &Array1<f64>, y: &Array1<f64>) -> usize {
        preds
            .iter()
            .zip(y.iter())
            .filter(|(p, t)| (*p - *t).abs() < 1e-10)
            .count()
    }

    #[test]
    fn test_basic_binary_classification() {
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let gb = GradientBoostingClassifier {
            n_estimators: 50,
            learning_rate: 0.1,
            max_depth: Some(3),
            seed: 42,
            ..Default::default()
        };
        let fitted: FittedGradientBoostingClassifier<f64> = gb.fit(&x, &y).unwrap();

        let preds = fitted.predict(&x).unwrap();
        for (p, t) in preds.iter().zip(y.iter()) {
            assert_abs_diff_eq!(*p, *t, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_multiclass_classification() {
        // Three classes with clearly separable data.
        let x = array![
            [0.0, 0.0],
            [0.5, 0.0],
            [1.0, 0.0],
            [5.0, 5.0],
            [5.5, 5.0],
            [6.0, 5.0],
            [10.0, 10.0],
            [10.5, 10.0],
            [11.0, 10.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0];

        let gb = GradientBoostingClassifier {
            n_estimators: 100,
            learning_rate: 0.1,
            max_depth: Some(3),
            seed: 42,
            ..Default::default()
        };
        let fitted: FittedGradientBoostingClassifier<f64> = gb.fit(&x, &y).unwrap();

        let preds = fitted.predict(&x).unwrap();
        for (p, t) in preds.iter().zip(y.iter()) {
            assert_abs_diff_eq!(*p, *t, epsilon = 1e-10);
        }

        // Verify all 3 classes were detected.
        assert_eq!(fitted.classes().len(), 3);
    }

    #[test]
    fn test_reproducibility() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let gb = GradientBoostingClassifier {
            n_estimators: 20,
            seed: 123,
            ..Default::default()
        };

        let fitted1: FittedGradientBoostingClassifier<f64> = gb.fit(&x, &y).unwrap();
        let fitted2: FittedGradientBoostingClassifier<f64> = gb.fit(&x, &y).unwrap();

        let preds1 = fitted1.predict(&x).unwrap();
        let preds2 = fitted2.predict(&x).unwrap();

        for (a, b) in preds1.iter().zip(preds2.iter()) {
            assert_abs_diff_eq!(*a, *b, epsilon = 1e-15);
        }
    }

    #[test]
    fn test_subsample_binary() {
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [4.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0],
            [13.0, 1.0]
        ];
        let y = array![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0];

        let gb = GradientBoostingClassifier {
            n_estimators: 80,
            learning_rate: 0.1,
            max_depth: Some(3),
            subsample: 0.75,
            seed: 42,
            ..Default::default()
        };
        let fitted: FittedGradientBoostingClassifier<f64> = gb.fit(&x, &y).unwrap();

        let preds = fitted.predict(&x).unwrap();
        for (p, t) in preds.iter().zip(y.iter()) {
            assert_abs_diff_eq!(*p, *t, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_shape_mismatch_error() {
        let x = array![[1.0], [2.0]];
        let y = array![0.0, 1.0, 2.0];

        let gb = GradientBoostingClassifier::default();
        let result: std::result::Result<FittedGradientBoostingClassifier<f64>, _> = gb.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_predict_wrong_features_error() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0.0, 1.0];

        let gb = GradientBoostingClassifier {
            n_estimators: 5,
            seed: 0,
            ..Default::default()
        };
        let fitted: FittedGradientBoostingClassifier<f64> = gb.fit(&x, &y).unwrap();

        let x_bad = array![[1.0], [2.0]];
        assert!(fitted.predict(&x_bad).is_err());
        assert!(fitted.predict_proba(&x_bad).is_err());
    }

    #[test]
    fn test_single_class_error() {
        let x = array![[1.0], [2.0], [3.0]];
        let y = array![1.0, 1.0, 1.0];

        let gb = GradientBoostingClassifier::default();
        let result: std::result::Result<FittedGradientBoostingClassifier<f64>, _> = gb.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_parameters() {
        let x = array![[1.0], [2.0]];
        let y = array![0.0, 1.0];

        let gb = GradientBoostingClassifier {
            n_estimators: 0,
            ..Default::default()
        };
        assert!(Fit::<f64>::fit(&gb, &x, &y).is_err());

        for learning_rate in [-0.1, 0.0] {
            let gb = GradientBoostingClassifier {
                learning_rate,
                ..Default::default()
            };
            assert!(
                Fit::<f64>::fit(&gb, &x, &y).is_err(),
                "learning_rate={learning_rate} should be rejected"
            );
        }

        for subsample in [0.0, 1.5] {
            let gb = GradientBoostingClassifier {
                subsample,
                ..Default::default()
            };
            assert!(
                Fit::<f64>::fit(&gb, &x, &y).is_err(),
                "subsample={subsample} should be rejected"
            );
        }
    }

    #[test]
    fn test_n_estimators_one() {
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let gb = GradientBoostingClassifier {
            n_estimators: 1,
            learning_rate: 0.1,
            max_depth: Some(3),
            seed: 42,
            ..Default::default()
        };
        let fitted: FittedGradientBoostingClassifier<f64> = gb.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 1);

        // Even a single boosting round should produce predictions with the correct length.
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), y.len());
    }

    #[test]
    fn test_predictions_are_valid_labels() {
        let x = array![
            [0.0, 0.0],
            [0.5, 0.0],
            [1.0, 0.0],
            [5.0, 5.0],
            [5.5, 5.0],
            [6.0, 5.0],
            [10.0, 10.0],
            [10.5, 10.0],
            [11.0, 10.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0];

        let gb = GradientBoostingClassifier {
            n_estimators: 50,
            learning_rate: 0.1,
            max_depth: Some(3),
            seed: 42,
            ..Default::default()
        };
        let fitted: FittedGradientBoostingClassifier<f64> = gb.fit(&x, &y).unwrap();

        let preds = fitted.predict(&x).unwrap();
        let valid_labels: std::collections::HashSet<u64> = y.iter().map(|v| v.to_bits()).collect();
        for &p in preds.iter() {
            assert!(
                valid_labels.contains(&p.to_bits()),
                "prediction {p} is not a valid training label"
            );
        }
    }

    #[test]
    fn test_subsample_impact() {
        // With subsample < 1.0, the model should still produce reasonable
        // predictions on clearly separable data.
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [4.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0],
            [13.0, 1.0]
        ];
        let y = array![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0];

        let gb = GradientBoostingClassifier {
            n_estimators: 80,
            learning_rate: 0.1,
            max_depth: Some(3),
            subsample: 0.5,
            seed: 7,
            ..Default::default()
        };
        let fitted: FittedGradientBoostingClassifier<f64> = gb.fit(&x, &y).unwrap();

        let preds = fitted.predict(&x).unwrap();
        let accuracy = count_correct(&preds, &y) as f64 / y.len() as f64;
        assert!(
            accuracy >= 0.75,
            "subsample=0.5 should still achieve reasonable accuracy, got {accuracy}"
        );
    }

    #[test]
    fn test_tiny_learning_rate_learns_no_faster() {
        // With the same number of rounds, a tiny learning rate should not fit the training
        // data better than a normal one.
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        // Normal learning rate should fit well with enough estimators.
        let gb_normal = GradientBoostingClassifier {
            n_estimators: 50,
            learning_rate: 0.1,
            max_depth: Some(3),
            seed: 42,
            ..Default::default()
        };
        let fitted_normal: FittedGradientBoostingClassifier<f64> = gb_normal.fit(&x, &y).unwrap();
        let correct_normal = count_correct(&fitted_normal.predict(&x).unwrap(), &y);

        // Tiny learning rate with the same number of estimators should learn slower.
        let gb_tiny = GradientBoostingClassifier {
            n_estimators: 50,
            learning_rate: 0.001,
            max_depth: Some(3),
            seed: 42,
            ..Default::default()
        };
        let fitted_tiny: FittedGradientBoostingClassifier<f64> = gb_tiny.fit(&x, &y).unwrap();
        let correct_tiny = count_correct(&fitted_tiny.predict(&x).unwrap(), &y);

        // The normal learning rate should be at least as accurate (likely better).
        assert!(
            correct_normal >= correct_tiny,
            "normal lr ({correct_normal} correct) should be >= tiny lr ({correct_tiny} correct)"
        );
    }

    #[test]
    fn test_predict_proba_binary_rows_sum_to_one() {
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let gb = GradientBoostingClassifier {
            n_estimators: 20,
            seed: 42,
            ..Default::default()
        };
        let fitted: FittedGradientBoostingClassifier<f64> = gb.fit(&x, &y).unwrap();

        let proba = fitted.predict_proba(&x).unwrap();
        assert_eq!(proba.dim(), (6, 2));

        let preds = fitted.predict(&x).unwrap();
        for (row, &pred) in proba.outer_iter().zip(preds.iter()) {
            assert_abs_diff_eq!(row.sum(), 1.0, epsilon = 1e-10);
            assert!(row.iter().all(|&p| (0.0..=1.0).contains(&p)));
            // `predict` thresholds the same sigmoid value at 0.5.
            let expected = if row[1] >= 0.5 { 1.0 } else { 0.0 };
            assert_abs_diff_eq!(pred, expected, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_predict_proba_multiclass_rows_sum_to_one() {
        let x = array![
            [0.0, 0.0],
            [0.5, 0.0],
            [1.0, 0.0],
            [5.0, 5.0],
            [5.5, 5.0],
            [6.0, 5.0],
            [10.0, 10.0],
            [10.5, 10.0],
            [11.0, 10.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0];

        let gb = GradientBoostingClassifier {
            n_estimators: 20,
            seed: 42,
            ..Default::default()
        };
        let fitted: FittedGradientBoostingClassifier<f64> = gb.fit(&x, &y).unwrap();

        let proba = fitted.predict_proba(&x).unwrap();
        assert_eq!(proba.dim(), (9, 3));
        for row in proba.outer_iter() {
            assert_abs_diff_eq!(row.sum(), 1.0, epsilon = 1e-10);
            assert!(row.iter().all(|&p| (0.0..=1.0).contains(&p)));
        }
    }
}
798
799impl<F: Float> PredictProba<F> for FittedGradientBoostingClassifier<F> {
800    fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>> {
801        Self::predict_proba(self, x)
802    }
803}