// ferrolearn_tree/gradient_boosting.rs
1//! Gradient boosting classifiers and regressors.
2//!
3//! This module provides [`GradientBoostingClassifier`] and [`GradientBoostingRegressor`],
4//! which build ensembles of decision trees sequentially. Each tree fits the negative
5//! gradient (pseudo-residuals) of the loss function, progressively reducing prediction error.
6//!
7//! # Regression Losses
8//!
9//! - **`LeastSquares`** (L2): mean squared error; pseudo-residuals are `y - F(x)`.
10//! - **`Lad`** (L1): least absolute deviation; pseudo-residuals are `sign(y - F(x))`.
11//! - **`Huber`**: a blend of L2 (for small residuals) and L1 (for large residuals),
12//!   controlled by the `alpha` quantile parameter (default 0.9).
13//!
14//! # Classification Loss
15//!
16//! - **`LogLoss`**: binary and multiclass logistic loss. For binary classification a
17//!   single model is trained on log-odds; for *K*-class problems *K* trees are built
18//!   per boosting round (one-vs-rest in probability space via softmax).
19//!
20//! # Examples
21//!
22//! ```
23//! use ferrolearn_tree::GradientBoostingRegressor;
24//! use ferrolearn_core::{Fit, Predict};
25//! use ndarray::{array, Array1, Array2};
26//!
27//! let x = Array2::from_shape_vec((8, 1), vec![
28//!     1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
29//! ]).unwrap();
30//! let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];
31//!
32//! let model = GradientBoostingRegressor::<f64>::new()
33//!     .with_n_estimators(50)
34//!     .with_learning_rate(0.1)
35//!     .with_random_state(42);
36//! let fitted = model.fit(&x, &y).unwrap();
37//! let preds = fitted.predict(&x).unwrap();
38//! assert_eq!(preds.len(), 8);
39//! ```
40
41use ferrolearn_core::error::FerroError;
42use ferrolearn_core::introspection::{HasClasses, HasFeatureImportances};
43use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
44use ferrolearn_core::traits::{Fit, Predict};
45use ndarray::{Array1, Array2};
46use num_traits::{Float, FromPrimitive, ToPrimitive};
47use rand::SeedableRng;
48use rand::rngs::StdRng;
49use rand::seq::index::sample as rand_sample_indices;
50
51use crate::decision_tree::{
52    self, Node, build_regression_tree_with_feature_subset, compute_feature_importances,
53};
54
55// ---------------------------------------------------------------------------
56// Regression loss enum
57// ---------------------------------------------------------------------------
58
/// Loss function for gradient boosting regression.
///
/// The chosen loss determines both the pseudo-residuals fitted at each
/// boosting round and the initial constant prediction (mean of `y` for
/// [`Self::LeastSquares`], median of `y` for [`Self::Lad`] and [`Self::Huber`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RegressionLoss {
    /// Least squares (L2) loss; pseudo-residuals are `y - F(x)`.
    LeastSquares,
    /// Least absolute deviation (L1) loss; pseudo-residuals are `sign(y - F(x))`.
    Lad,
    /// Huber loss: L2 for small residuals, L1 for large residuals,
    /// with the crossover controlled by the `huber_alpha` quantile.
    Huber,
}
69
/// Loss function for gradient boosting classification.
///
/// Only log-loss is currently provided: binary problems train a single
/// log-odds model, multiclass problems train one tree sequence per class.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClassificationLoss {
    /// Log-loss (logistic / cross-entropy) for binary and multiclass.
    LogLoss,
}
76
77// ---------------------------------------------------------------------------
78// GradientBoostingRegressor
79// ---------------------------------------------------------------------------
80
/// Gradient boosting regressor.
///
/// Builds an additive model in a forward stage-wise fashion, fitting each
/// regression tree to the negative gradient of the loss function evaluated
/// on the current ensemble prediction.
///
/// Configure via the `with_*` builder methods, then call [`Fit::fit`] to
/// obtain a [`FittedGradientBoostingRegressor`].
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone)]
pub struct GradientBoostingRegressor<F> {
    /// Number of boosting stages (trees). Default: 100.
    pub n_estimators: usize,
    /// Learning rate (shrinkage) applied to each tree's contribution. Default: 0.1.
    pub learning_rate: f64,
    /// Maximum depth of each tree. Default: `Some(3)`.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split an internal node. Default: 2.
    pub min_samples_split: usize,
    /// Minimum number of samples required in a leaf node. Default: 1.
    pub min_samples_leaf: usize,
    /// Fraction of samples to use for fitting each tree (stochastic boosting).
    /// Must lie in `(0, 1]`. Default: 1.0 (use every sample).
    pub subsample: f64,
    /// Loss function. Default: [`RegressionLoss::LeastSquares`].
    pub loss: RegressionLoss,
    /// Alpha quantile for Huber loss (only used when `loss == Huber`). Default: 0.9.
    pub huber_alpha: f64,
    /// Random seed for reproducibility; `None` draws a seed from the thread RNG.
    pub random_state: Option<u64>,
    // Zero-sized marker tying the generic float type `F` to the struct.
    _marker: std::marker::PhantomData<F>,
}
112
113impl<F: Float> GradientBoostingRegressor<F> {
114    /// Create a new `GradientBoostingRegressor` with default settings.
115    ///
116    /// Defaults: `n_estimators = 100`, `learning_rate = 0.1`,
117    /// `max_depth = Some(3)`, `min_samples_split = 2`,
118    /// `min_samples_leaf = 1`, `subsample = 1.0`,
119    /// `loss = LeastSquares`, `huber_alpha = 0.9`.
120    #[must_use]
121    pub fn new() -> Self {
122        Self {
123            n_estimators: 100,
124            learning_rate: 0.1,
125            max_depth: Some(3),
126            min_samples_split: 2,
127            min_samples_leaf: 1,
128            subsample: 1.0,
129            loss: RegressionLoss::LeastSquares,
130            huber_alpha: 0.9,
131            random_state: None,
132            _marker: std::marker::PhantomData,
133        }
134    }
135
136    /// Set the number of boosting stages.
137    #[must_use]
138    pub fn with_n_estimators(mut self, n: usize) -> Self {
139        self.n_estimators = n;
140        self
141    }
142
143    /// Set the learning rate (shrinkage).
144    #[must_use]
145    pub fn with_learning_rate(mut self, lr: f64) -> Self {
146        self.learning_rate = lr;
147        self
148    }
149
150    /// Set the maximum tree depth.
151    #[must_use]
152    pub fn with_max_depth(mut self, d: Option<usize>) -> Self {
153        self.max_depth = d;
154        self
155    }
156
157    /// Set the minimum number of samples to split a node.
158    #[must_use]
159    pub fn with_min_samples_split(mut self, n: usize) -> Self {
160        self.min_samples_split = n;
161        self
162    }
163
164    /// Set the minimum number of samples in a leaf.
165    #[must_use]
166    pub fn with_min_samples_leaf(mut self, n: usize) -> Self {
167        self.min_samples_leaf = n;
168        self
169    }
170
171    /// Set the subsample ratio (fraction of training data per tree).
172    #[must_use]
173    pub fn with_subsample(mut self, ratio: f64) -> Self {
174        self.subsample = ratio;
175        self
176    }
177
178    /// Set the loss function.
179    #[must_use]
180    pub fn with_loss(mut self, loss: RegressionLoss) -> Self {
181        self.loss = loss;
182        self
183    }
184
185    /// Set the alpha quantile for Huber loss.
186    #[must_use]
187    pub fn with_huber_alpha(mut self, alpha: f64) -> Self {
188        self.huber_alpha = alpha;
189        self
190    }
191
192    /// Set the random seed for reproducibility.
193    #[must_use]
194    pub fn with_random_state(mut self, seed: u64) -> Self {
195        self.random_state = Some(seed);
196        self
197    }
198}
199
impl<F: Float> Default for GradientBoostingRegressor<F> {
    /// Equivalent to [`GradientBoostingRegressor::new`].
    fn default() -> Self {
        Self::new()
    }
}
205
206// ---------------------------------------------------------------------------
207// FittedGradientBoostingRegressor
208// ---------------------------------------------------------------------------
209
/// A fitted gradient boosting regressor.
///
/// Stores the initial prediction (intercept) and the sequence of fitted trees.
/// Predictions are computed as `init + learning_rate * sum(tree_predictions)`.
#[derive(Debug, Clone)]
pub struct FittedGradientBoostingRegressor<F> {
    /// Initial prediction (mean of training targets for L2 loss, median for L1/Huber).
    init: F,
    /// Learning rate used during training (converted to `F`).
    learning_rate: F,
    /// Sequence of fitted trees (one per boosting round), each stored as a
    /// flat node array walked with `decision_tree::traverse`.
    trees: Vec<Vec<Node<F>>>,
    /// Number of features seen during fitting; re-checked at predict time.
    n_features: usize,
    /// Per-feature importance scores (normalised to sum to 1 when non-zero).
    feature_importances: Array1<F>,
}
227
228impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for GradientBoostingRegressor<F> {
229    type Fitted = FittedGradientBoostingRegressor<F>;
230    type Error = FerroError;
231
232    /// Fit the gradient boosting regressor.
233    ///
234    /// # Errors
235    ///
236    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
237    /// numbers of samples.
238    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
239    /// Returns [`FerroError::InvalidParameter`] for invalid hyperparameters.
240    fn fit(
241        &self,
242        x: &Array2<F>,
243        y: &Array1<F>,
244    ) -> Result<FittedGradientBoostingRegressor<F>, FerroError> {
245        let (n_samples, n_features) = x.dim();
246
247        if n_samples != y.len() {
248            return Err(FerroError::ShapeMismatch {
249                expected: vec![n_samples],
250                actual: vec![y.len()],
251                context: "y length must match number of samples in X".into(),
252            });
253        }
254        if n_samples == 0 {
255            return Err(FerroError::InsufficientSamples {
256                required: 1,
257                actual: 0,
258                context: "GradientBoostingRegressor requires at least one sample".into(),
259            });
260        }
261        if self.n_estimators == 0 {
262            return Err(FerroError::InvalidParameter {
263                name: "n_estimators".into(),
264                reason: "must be at least 1".into(),
265            });
266        }
267        if self.learning_rate <= 0.0 {
268            return Err(FerroError::InvalidParameter {
269                name: "learning_rate".into(),
270                reason: "must be positive".into(),
271            });
272        }
273        if self.subsample <= 0.0 || self.subsample > 1.0 {
274            return Err(FerroError::InvalidParameter {
275                name: "subsample".into(),
276                reason: "must be in (0, 1]".into(),
277            });
278        }
279
280        let lr = F::from(self.learning_rate).unwrap();
281        let params = decision_tree::TreeParams {
282            max_depth: self.max_depth,
283            min_samples_split: self.min_samples_split,
284            min_samples_leaf: self.min_samples_leaf,
285        };
286
287        // Initial prediction.
288        let init = match self.loss {
289            RegressionLoss::LeastSquares => {
290                let sum: F = y.iter().copied().fold(F::zero(), |a, b| a + b);
291                sum / F::from(n_samples).unwrap()
292            }
293            RegressionLoss::Lad | RegressionLoss::Huber => median_f(y),
294        };
295
296        // Current predictions for each sample.
297        let mut f_vals = Array1::from_elem(n_samples, init);
298
299        let all_features: Vec<usize> = (0..n_features).collect();
300        let subsample_size = ((self.subsample * n_samples as f64).ceil() as usize)
301            .max(1)
302            .min(n_samples);
303
304        let mut rng = match self.random_state {
305            Some(seed) => StdRng::seed_from_u64(seed),
306            None => {
307                use rand::RngCore;
308                StdRng::seed_from_u64(rand::rng().next_u64())
309            }
310        };
311
312        let mut trees = Vec::with_capacity(self.n_estimators);
313
314        for _ in 0..self.n_estimators {
315            // Compute pseudo-residuals (negative gradient).
316            let residuals = compute_regression_residuals(y, &f_vals, self.loss, self.huber_alpha);
317
318            // Subsample indices.
319            let sample_indices = if subsample_size < n_samples {
320                rand_sample_indices(&mut rng, n_samples, subsample_size).into_vec()
321            } else {
322                (0..n_samples).collect()
323            };
324
325            // Build a regression tree on the pseudo-residuals.
326            let tree = build_regression_tree_with_feature_subset(
327                x,
328                &residuals,
329                &sample_indices,
330                &all_features,
331                &params,
332            );
333
334            // Update predictions.
335            for i in 0..n_samples {
336                let row = x.row(i);
337                let leaf_idx = decision_tree::traverse(&tree, &row);
338                if let Node::Leaf { value, .. } = tree[leaf_idx] {
339                    f_vals[i] = f_vals[i] + lr * value;
340                }
341            }
342
343            trees.push(tree);
344        }
345
346        // Compute feature importances across all trees.
347        let mut total_importances = Array1::<F>::zeros(n_features);
348        for tree_nodes in &trees {
349            let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
350            total_importances = total_importances + tree_imp;
351        }
352        let imp_sum: F = total_importances
353            .iter()
354            .copied()
355            .fold(F::zero(), |a, b| a + b);
356        if imp_sum > F::zero() {
357            total_importances.mapv_inplace(|v| v / imp_sum);
358        }
359
360        Ok(FittedGradientBoostingRegressor {
361            init,
362            learning_rate: lr,
363            trees,
364            n_features,
365            feature_importances: total_importances,
366        })
367    }
368}
369
impl<F: Float + Send + Sync + 'static> FittedGradientBoostingRegressor<F> {
    /// Returns the initial prediction (intercept) of the boosted model.
    #[must_use]
    pub fn init(&self) -> F {
        self.init
    }

    /// Returns the learning rate used during training.
    #[must_use]
    pub fn learning_rate(&self) -> F {
        self.learning_rate
    }

    /// Returns a reference to the sequence of fitted trees,
    /// one flat node array per boosting round.
    #[must_use]
    pub fn trees(&self) -> &[Vec<Node<F>>] {
        &self.trees
    }

    /// Returns the number of features the model was trained on.
    #[must_use]
    pub fn n_features(&self) -> usize {
        self.n_features
    }
}
395
396impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedGradientBoostingRegressor<F> {
397    type Output = Array1<F>;
398    type Error = FerroError;
399
400    /// Predict target values.
401    ///
402    /// # Errors
403    ///
404    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
405    /// not match the fitted model.
406    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
407        if x.ncols() != self.n_features {
408            return Err(FerroError::ShapeMismatch {
409                expected: vec![self.n_features],
410                actual: vec![x.ncols()],
411                context: "number of features must match fitted model".into(),
412            });
413        }
414
415        let n_samples = x.nrows();
416        let mut predictions = Array1::from_elem(n_samples, self.init);
417
418        for i in 0..n_samples {
419            let row = x.row(i);
420            for tree_nodes in &self.trees {
421                let leaf_idx = decision_tree::traverse(tree_nodes, &row);
422                if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
423                    predictions[i] = predictions[i] + self.learning_rate * value;
424                }
425            }
426        }
427
428        Ok(predictions)
429    }
430}
431
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F>
    for FittedGradientBoostingRegressor<F>
{
    /// Returns per-feature importance scores, normalised to sum to 1
    /// (all zeros if no split contributed any importance).
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}
439
440// Pipeline integration.
441impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for GradientBoostingRegressor<F> {
442    fn fit_pipeline(
443        &self,
444        x: &Array2<F>,
445        y: &Array1<F>,
446    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
447        let fitted = self.fit(x, y)?;
448        Ok(Box::new(fitted))
449    }
450}
451
impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedGradientBoostingRegressor<F>
{
    /// Pipeline prediction: delegates directly to [`Predict::predict`].
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        self.predict(x)
    }
}
459
460// ---------------------------------------------------------------------------
461// GradientBoostingClassifier
462// ---------------------------------------------------------------------------
463
/// Gradient boosting classifier.
///
/// For binary classification a single model is trained on log-odds residuals.
/// For multiclass (*K* classes), *K* regression trees are built per boosting
/// round (one-vs-rest in probability space via softmax).
///
/// Configure via the `with_*` builder methods, then call [`Fit::fit`] to
/// obtain a [`FittedGradientBoostingClassifier`].
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone)]
pub struct GradientBoostingClassifier<F> {
    /// Number of boosting stages. Default: 100.
    pub n_estimators: usize,
    /// Learning rate (shrinkage). Default: 0.1.
    pub learning_rate: f64,
    /// Maximum depth of each tree. Default: `Some(3)`.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split an internal node. Default: 2.
    pub min_samples_split: usize,
    /// Minimum number of samples required in a leaf node. Default: 1.
    pub min_samples_leaf: usize,
    /// Fraction of samples to use for fitting each tree. Must lie in `(0, 1]`.
    /// Default: 1.0 (use every sample).
    pub subsample: f64,
    /// Classification loss function. Default: [`ClassificationLoss::LogLoss`].
    pub loss: ClassificationLoss,
    /// Random seed for reproducibility; `None` draws a seed from the thread RNG.
    pub random_state: Option<u64>,
    // Zero-sized marker tying the generic float type `F` to the struct.
    _marker: std::marker::PhantomData<F>,
}
493
494impl<F: Float> GradientBoostingClassifier<F> {
495    /// Create a new `GradientBoostingClassifier` with default settings.
496    ///
497    /// Defaults: `n_estimators = 100`, `learning_rate = 0.1`,
498    /// `max_depth = Some(3)`, `min_samples_split = 2`,
499    /// `min_samples_leaf = 1`, `subsample = 1.0`,
500    /// `loss = LogLoss`.
501    #[must_use]
502    pub fn new() -> Self {
503        Self {
504            n_estimators: 100,
505            learning_rate: 0.1,
506            max_depth: Some(3),
507            min_samples_split: 2,
508            min_samples_leaf: 1,
509            subsample: 1.0,
510            loss: ClassificationLoss::LogLoss,
511            random_state: None,
512            _marker: std::marker::PhantomData,
513        }
514    }
515
516    /// Set the number of boosting stages.
517    #[must_use]
518    pub fn with_n_estimators(mut self, n: usize) -> Self {
519        self.n_estimators = n;
520        self
521    }
522
523    /// Set the learning rate (shrinkage).
524    #[must_use]
525    pub fn with_learning_rate(mut self, lr: f64) -> Self {
526        self.learning_rate = lr;
527        self
528    }
529
530    /// Set the maximum tree depth.
531    #[must_use]
532    pub fn with_max_depth(mut self, d: Option<usize>) -> Self {
533        self.max_depth = d;
534        self
535    }
536
537    /// Set the minimum number of samples to split a node.
538    #[must_use]
539    pub fn with_min_samples_split(mut self, n: usize) -> Self {
540        self.min_samples_split = n;
541        self
542    }
543
544    /// Set the minimum number of samples in a leaf.
545    #[must_use]
546    pub fn with_min_samples_leaf(mut self, n: usize) -> Self {
547        self.min_samples_leaf = n;
548        self
549    }
550
551    /// Set the subsample ratio.
552    #[must_use]
553    pub fn with_subsample(mut self, ratio: f64) -> Self {
554        self.subsample = ratio;
555        self
556    }
557
558    /// Set the random seed for reproducibility.
559    #[must_use]
560    pub fn with_random_state(mut self, seed: u64) -> Self {
561        self.random_state = Some(seed);
562        self
563    }
564}
565
impl<F: Float> Default for GradientBoostingClassifier<F> {
    /// Equivalent to [`GradientBoostingClassifier::new`].
    fn default() -> Self {
        Self::new()
    }
}
571
572// ---------------------------------------------------------------------------
573// FittedGradientBoostingClassifier
574// ---------------------------------------------------------------------------
575
/// A fitted gradient boosting classifier.
///
/// For binary classification, stores a single sequence of trees predicting log-odds.
/// For multiclass, stores `K` sequences of trees (one per class).
#[derive(Debug, Clone)]
pub struct FittedGradientBoostingClassifier<F> {
    /// Sorted unique class labels as seen in the training targets.
    classes: Vec<usize>,
    /// Initial predictions per class: a single log-odds value for binary,
    /// one log-prior per class for multiclass.
    init: Vec<F>,
    /// Learning rate used during training (converted to `F`).
    learning_rate: F,
    /// Trees: for binary, `trees[0]` has all trees. For multiclass,
    /// `trees[k]` has trees for class k.
    trees: Vec<Vec<Vec<Node<F>>>>,
    /// Number of features seen during fitting.
    n_features: usize,
    /// Per-feature importance scores (normalised to sum to 1 when non-zero).
    feature_importances: Array1<F>,
}
596
597impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>>
598    for GradientBoostingClassifier<F>
599{
600    type Fitted = FittedGradientBoostingClassifier<F>;
601    type Error = FerroError;
602
603    /// Fit the gradient boosting classifier.
604    ///
605    /// # Errors
606    ///
607    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
608    /// numbers of samples.
609    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
610    /// Returns [`FerroError::InvalidParameter`] for invalid hyperparameters.
611    fn fit(
612        &self,
613        x: &Array2<F>,
614        y: &Array1<usize>,
615    ) -> Result<FittedGradientBoostingClassifier<F>, FerroError> {
616        let (n_samples, n_features) = x.dim();
617
618        if n_samples != y.len() {
619            return Err(FerroError::ShapeMismatch {
620                expected: vec![n_samples],
621                actual: vec![y.len()],
622                context: "y length must match number of samples in X".into(),
623            });
624        }
625        if n_samples == 0 {
626            return Err(FerroError::InsufficientSamples {
627                required: 1,
628                actual: 0,
629                context: "GradientBoostingClassifier requires at least one sample".into(),
630            });
631        }
632        if self.n_estimators == 0 {
633            return Err(FerroError::InvalidParameter {
634                name: "n_estimators".into(),
635                reason: "must be at least 1".into(),
636            });
637        }
638        if self.learning_rate <= 0.0 {
639            return Err(FerroError::InvalidParameter {
640                name: "learning_rate".into(),
641                reason: "must be positive".into(),
642            });
643        }
644        if self.subsample <= 0.0 || self.subsample > 1.0 {
645            return Err(FerroError::InvalidParameter {
646                name: "subsample".into(),
647                reason: "must be in (0, 1]".into(),
648            });
649        }
650
651        // Determine unique classes.
652        let mut classes: Vec<usize> = y.iter().copied().collect();
653        classes.sort_unstable();
654        classes.dedup();
655        let n_classes = classes.len();
656
657        if n_classes < 2 {
658            return Err(FerroError::InvalidParameter {
659                name: "y".into(),
660                reason: "need at least 2 distinct classes".into(),
661            });
662        }
663
664        let y_mapped: Vec<usize> = y
665            .iter()
666            .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
667            .collect();
668
669        let lr = F::from(self.learning_rate).unwrap();
670        let params = decision_tree::TreeParams {
671            max_depth: self.max_depth,
672            min_samples_split: self.min_samples_split,
673            min_samples_leaf: self.min_samples_leaf,
674        };
675
676        let all_features: Vec<usize> = (0..n_features).collect();
677        let subsample_size = ((self.subsample * n_samples as f64).ceil() as usize)
678            .max(1)
679            .min(n_samples);
680
681        let mut rng = match self.random_state {
682            Some(seed) => StdRng::seed_from_u64(seed),
683            None => {
684                use rand::RngCore;
685                StdRng::seed_from_u64(rand::rng().next_u64())
686            }
687        };
688
689        if n_classes == 2 {
690            // Binary classification: single model on log-odds.
691            self.fit_binary(
692                x,
693                &y_mapped,
694                n_samples,
695                n_features,
696                &classes,
697                lr,
698                &params,
699                &all_features,
700                subsample_size,
701                &mut rng,
702            )
703        } else {
704            // Multiclass: K trees per round.
705            self.fit_multiclass(
706                x,
707                &y_mapped,
708                n_samples,
709                n_features,
710                n_classes,
711                &classes,
712                lr,
713                &params,
714                &all_features,
715                subsample_size,
716                &mut rng,
717            )
718        }
719    }
720}
721
impl<F: Float + Send + Sync + 'static> GradientBoostingClassifier<F> {
    /// Fit binary classification (log-loss on log-odds).
    ///
    /// Trains a single sequence of regression trees on the log-odds scale:
    /// each round converts the current scores to probabilities, fits a tree
    /// to the residuals `y - p`, and adds `lr * tree(x)` to the scores.
    /// Returns a fitted model whose `trees[0]` holds all trees and whose
    /// `init` holds the single initial log-odds value.
    #[allow(clippy::too_many_arguments)]
    fn fit_binary(
        &self,
        x: &Array2<F>,
        y_mapped: &[usize],
        n_samples: usize,
        n_features: usize,
        classes: &[usize],
        lr: F,
        params: &decision_tree::TreeParams,
        all_features: &[usize],
        subsample_size: usize,
        rng: &mut StdRng,
    ) -> Result<FittedGradientBoostingClassifier<F>, FerroError> {
        // Count positive class proportion for initial log-odds.
        let pos_count = y_mapped.iter().filter(|&&c| c == 1).count();
        let p = F::from(pos_count).unwrap() / F::from(n_samples).unwrap();
        // Clip p away from 0 and 1 so the log-odds below stay finite even
        // when one class is absent from the (already validated) targets.
        let eps = F::from(1e-15).unwrap();
        let p_clipped = p.max(eps).min(F::one() - eps);
        let init_val = (p_clipped / (F::one() - p_clipped)).ln();

        // f_vals[i] = current log-odds score for sample i.
        let mut f_vals = Array1::from_elem(n_samples, init_val);
        let mut trees_seq: Vec<Vec<Node<F>>> = Vec::with_capacity(self.n_estimators);

        for _ in 0..self.n_estimators {
            // Compute probabilities from current log-odds.
            let probs: Vec<F> = f_vals.iter().map(|&fv| sigmoid(fv)).collect();

            // Pseudo-residuals: y - p.
            let mut residuals = Array1::zeros(n_samples);
            for i in 0..n_samples {
                let yi = F::from(y_mapped[i]).unwrap();
                residuals[i] = yi - probs[i];
            }

            // Subsample rows for this round (all rows when subsample == 1.0).
            let sample_indices = if subsample_size < n_samples {
                rand_sample_indices(rng, n_samples, subsample_size).into_vec()
            } else {
                (0..n_samples).collect()
            };

            // Build tree on residuals.
            let tree = build_regression_tree_with_feature_subset(
                x,
                &residuals,
                &sample_indices,
                all_features,
                params,
            );

            // Update f_vals: shrunken additive step applied to every sample
            // (not just the subsample) so residuals stay consistent.
            for i in 0..n_samples {
                let row = x.row(i);
                let leaf_idx = decision_tree::traverse(&tree, &row);
                if let Node::Leaf { value, .. } = tree[leaf_idx] {
                    f_vals[i] = f_vals[i] + lr * value;
                }
            }

            trees_seq.push(tree);
        }

        // Feature importances: aggregate per-tree, then normalise to sum to 1.
        let mut total_importances = Array1::<F>::zeros(n_features);
        for tree_nodes in &trees_seq {
            let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
            total_importances = total_importances + tree_imp;
        }
        let imp_sum: F = total_importances
            .iter()
            .copied()
            .fold(F::zero(), |a, b| a + b);
        if imp_sum > F::zero() {
            total_importances.mapv_inplace(|v| v / imp_sum);
        }

        Ok(FittedGradientBoostingClassifier {
            classes: classes.to_vec(),
            init: vec![init_val],
            learning_rate: lr,
            trees: vec![trees_seq],
            n_features,
            feature_importances: total_importances,
        })
    }

    /// Fit multiclass classification (K trees per round, softmax).
    ///
    /// Maintains one raw-score vector per class; each round computes softmax
    /// probabilities across classes, then fits one tree per class to the
    /// one-vs-rest residuals `1{y == k} - p_k` using a shared row subsample.
    /// Returns a fitted model with `trees[k]` holding class `k`'s trees and
    /// `init[k]` its initial log-prior.
    #[allow(clippy::too_many_arguments)]
    fn fit_multiclass(
        &self,
        x: &Array2<F>,
        y_mapped: &[usize],
        n_samples: usize,
        n_features: usize,
        n_classes: usize,
        classes: &[usize],
        lr: F,
        params: &decision_tree::TreeParams,
        all_features: &[usize],
        subsample_size: usize,
        rng: &mut StdRng,
    ) -> Result<FittedGradientBoostingClassifier<F>, FerroError> {
        // Initial log-prior for each class, with the class proportion
        // floored at eps so the logarithm stays finite for empty counts.
        let mut class_counts = vec![0usize; n_classes];
        for &c in y_mapped {
            class_counts[c] += 1;
        }
        let n_f = F::from(n_samples).unwrap();
        let eps = F::from(1e-15).unwrap();
        let init_vals: Vec<F> = class_counts
            .iter()
            .map(|&cnt| {
                let p = (F::from(cnt).unwrap() / n_f).max(eps);
                p.ln()
            })
            .collect();

        // f_vals[k][i] = current raw score for class k, sample i.
        let mut f_vals: Vec<Array1<F>> = init_vals
            .iter()
            .map(|&init| Array1::from_elem(n_samples, init))
            .collect();

        // One tree sequence per class, filled round by round below.
        let mut trees_per_class: Vec<Vec<Vec<Node<F>>>> = (0..n_classes)
            .map(|_| Vec::with_capacity(self.n_estimators))
            .collect();

        for _ in 0..self.n_estimators {
            // Compute softmax probabilities across classes for every sample.
            let probs = softmax_matrix(&f_vals, n_samples, n_classes);

            // One shared row subsample per round, reused for all K trees.
            let sample_indices = if subsample_size < n_samples {
                rand_sample_indices(rng, n_samples, subsample_size).into_vec()
            } else {
                (0..n_samples).collect()
            };

            // For each class, compute residuals and fit a tree.
            for k in 0..n_classes {
                // One-vs-rest pseudo-residuals: 1{y == k} - p_k.
                let mut residuals = Array1::zeros(n_samples);
                for i in 0..n_samples {
                    let yi_k = if y_mapped[i] == k {
                        F::one()
                    } else {
                        F::zero()
                    };
                    residuals[i] = yi_k - probs[k][i];
                }

                let tree = build_regression_tree_with_feature_subset(
                    x,
                    &residuals,
                    &sample_indices,
                    all_features,
                    params,
                );

                // Update f_vals for class k (every sample, not just subsample).
                for (i, fv) in f_vals[k].iter_mut().enumerate() {
                    let row = x.row(i);
                    let leaf_idx = decision_tree::traverse(&tree, &row);
                    if let Node::Leaf { value, .. } = tree[leaf_idx] {
                        *fv = *fv + lr * value;
                    }
                }

                trees_per_class[k].push(tree);
            }
        }

        // Feature importances aggregated across all classes and rounds,
        // then normalised to sum to 1.
        let mut total_importances = Array1::<F>::zeros(n_features);
        for class_trees in &trees_per_class {
            for tree_nodes in class_trees {
                let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
                total_importances = total_importances + tree_imp;
            }
        }
        let imp_sum: F = total_importances
            .iter()
            .copied()
            .fold(F::zero(), |a, b| a + b);
        if imp_sum > F::zero() {
            total_importances.mapv_inplace(|v| v / imp_sum);
        }

        Ok(FittedGradientBoostingClassifier {
            classes: classes.to_vec(),
            init: init_vals,
            learning_rate: lr,
            trees: trees_per_class,
            n_features,
            feature_importances: total_importances,
        })
    }
}
922
923impl<F: Float + Send + Sync + 'static> FittedGradientBoostingClassifier<F> {
924    /// Returns the initial predictions per class (log-odds or log-prior).
925    #[must_use]
926    pub fn init(&self) -> &[F] {
927        &self.init
928    }
929
930    /// Returns the learning rate used during training.
931    #[must_use]
932    pub fn learning_rate(&self) -> F {
933        self.learning_rate
934    }
935
936    /// Returns a reference to the tree ensemble.
937    ///
938    /// For binary classification, `trees()[0]` contains all trees.
939    /// For multiclass, `trees()[k]` contains trees for class `k`.
940    #[must_use]
941    pub fn trees(&self) -> &[Vec<Vec<Node<F>>>] {
942        &self.trees
943    }
944
945    /// Returns the number of features the model was trained on.
946    #[must_use]
947    pub fn n_features(&self) -> usize {
948        self.n_features
949    }
950}
951
952impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedGradientBoostingClassifier<F> {
953    type Output = Array1<usize>;
954    type Error = FerroError;
955
956    /// Predict class labels.
957    ///
958    /// # Errors
959    ///
960    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
961    /// not match the fitted model.
962    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
963        if x.ncols() != self.n_features {
964            return Err(FerroError::ShapeMismatch {
965                expected: vec![self.n_features],
966                actual: vec![x.ncols()],
967                context: "number of features must match fitted model".into(),
968            });
969        }
970
971        let n_samples = x.nrows();
972        let n_classes = self.classes.len();
973
974        if n_classes == 2 {
975            // Binary: single log-odds model.
976            let init = self.init[0];
977            let mut predictions = Array1::zeros(n_samples);
978            for i in 0..n_samples {
979                let row = x.row(i);
980                let mut f_val = init;
981                for tree_nodes in &self.trees[0] {
982                    let leaf_idx = decision_tree::traverse(tree_nodes, &row);
983                    if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
984                        f_val = f_val + self.learning_rate * value;
985                    }
986                }
987                let prob = sigmoid(f_val);
988                let class_idx = if prob >= F::from(0.5).unwrap() { 1 } else { 0 };
989                predictions[i] = self.classes[class_idx];
990            }
991            Ok(predictions)
992        } else {
993            // Multiclass: K models, argmax of softmax.
994            let mut predictions = Array1::zeros(n_samples);
995            for i in 0..n_samples {
996                let row = x.row(i);
997                let mut scores = Vec::with_capacity(n_classes);
998                for k in 0..n_classes {
999                    let mut f_val = self.init[k];
1000                    for tree_nodes in &self.trees[k] {
1001                        let leaf_idx = decision_tree::traverse(tree_nodes, &row);
1002                        if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
1003                            f_val = f_val + self.learning_rate * value;
1004                        }
1005                    }
1006                    scores.push(f_val);
1007                }
1008                let best_k = scores
1009                    .iter()
1010                    .enumerate()
1011                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
1012                    .map(|(k, _)| k)
1013                    .unwrap_or(0);
1014                predictions[i] = self.classes[best_k];
1015            }
1016            Ok(predictions)
1017        }
1018    }
1019}
1020
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F>
    for FittedGradientBoostingClassifier<F>
{
    /// Normalized feature importances, aggregated during `fit` across all
    /// trees (summed over every class and boosting round, then scaled to
    /// sum to 1 when any split occurred).
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}
1028
impl<F: Float + Send + Sync + 'static> HasClasses for FittedGradientBoostingClassifier<F> {
    /// The distinct class labels the model was fitted on.
    fn classes(&self) -> &[usize] {
        &self.classes
    }

    /// Number of distinct classes.
    fn n_classes(&self) -> usize {
        self.classes.len()
    }
}
1038
1039// Pipeline integration.
1040impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
1041    for GradientBoostingClassifier<F>
1042{
1043    fn fit_pipeline(
1044        &self,
1045        x: &Array2<F>,
1046        y: &Array1<F>,
1047    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
1048        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
1049        let fitted = self.fit(x, &y_usize)?;
1050        Ok(Box::new(FittedGbcPipelineAdapter(fitted)))
1051    }
1052}
1053
/// Pipeline adapter for `FittedGradientBoostingClassifier<F>`.
///
/// Newtype wrapper that lets the fitted classifier (which predicts `usize`
/// labels) satisfy the float-based `FittedPipelineEstimator` interface.
struct FittedGbcPipelineAdapter<F: Float + Send + Sync + 'static>(
    FittedGradientBoostingClassifier<F>,
);
1058
1059impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
1060    for FittedGbcPipelineAdapter<F>
1061{
1062    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
1063        let preds = self.0.predict(x)?;
1064        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or(F::nan())))
1065    }
1066}
1067
1068// ---------------------------------------------------------------------------
1069// Internal helpers
1070// ---------------------------------------------------------------------------
1071
1072/// Sigmoid function: 1 / (1 + exp(-x)).
1073fn sigmoid<F: Float>(x: F) -> F {
1074    F::one() / (F::one() + (-x).exp())
1075}
1076
1077/// Compute softmax probabilities for each class across all samples.
1078///
1079/// Returns `probs[k][i]` = probability of class k for sample i.
1080fn softmax_matrix<F: Float>(
1081    f_vals: &[Array1<F>],
1082    n_samples: usize,
1083    n_classes: usize,
1084) -> Vec<Vec<F>> {
1085    let mut probs: Vec<Vec<F>> = vec![vec![F::zero(); n_samples]; n_classes];
1086
1087    for i in 0..n_samples {
1088        // Find max for numerical stability.
1089        let max_val = (0..n_classes)
1090            .map(|k| f_vals[k][i])
1091            .fold(F::neg_infinity(), |a, b| if b > a { b } else { a });
1092
1093        let mut sum = F::zero();
1094        let mut exps = vec![F::zero(); n_classes];
1095        for k in 0..n_classes {
1096            exps[k] = (f_vals[k][i] - max_val).exp();
1097            sum = sum + exps[k];
1098        }
1099
1100        let eps = F::from(1e-15).unwrap();
1101        if sum < eps {
1102            sum = eps;
1103        }
1104
1105        for k in 0..n_classes {
1106            probs[k][i] = exps[k] / sum;
1107        }
1108    }
1109
1110    probs
1111}
1112
1113/// Compute the median of an Array1.
1114fn median_f<F: Float>(arr: &Array1<F>) -> F {
1115    let mut sorted: Vec<F> = arr.iter().copied().collect();
1116    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
1117    let n = sorted.len();
1118    if n == 0 {
1119        return F::zero();
1120    }
1121    if n % 2 == 1 {
1122        sorted[n / 2]
1123    } else {
1124        (sorted[n / 2 - 1] + sorted[n / 2]) / F::from(2.0).unwrap()
1125    }
1126}
1127
1128/// Compute the quantile of a slice at level `alpha` (0..1).
1129fn quantile_f<F: Float>(vals: &[F], alpha: f64) -> F {
1130    if vals.is_empty() {
1131        return F::zero();
1132    }
1133    let mut sorted: Vec<F> = vals.to_vec();
1134    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
1135    let idx = ((sorted.len() as f64 - 1.0) * alpha).round() as usize;
1136    let idx = idx.min(sorted.len() - 1);
1137    sorted[idx]
1138}
1139
1140/// Compute pseudo-residuals (negative gradient) for regression losses.
1141fn compute_regression_residuals<F: Float>(
1142    y: &Array1<F>,
1143    f_vals: &Array1<F>,
1144    loss: RegressionLoss,
1145    huber_alpha: f64,
1146) -> Array1<F> {
1147    let n = y.len();
1148    match loss {
1149        RegressionLoss::LeastSquares => {
1150            // negative gradient of 0.5*(y - f)^2 is (y - f)
1151            let mut residuals = Array1::zeros(n);
1152            for i in 0..n {
1153                residuals[i] = y[i] - f_vals[i];
1154            }
1155            residuals
1156        }
1157        RegressionLoss::Lad => {
1158            // negative gradient of |y - f| is sign(y - f)
1159            let mut residuals = Array1::zeros(n);
1160            for i in 0..n {
1161                let diff = y[i] - f_vals[i];
1162                residuals[i] = if diff > F::zero() {
1163                    F::one()
1164                } else if diff < F::zero() {
1165                    -F::one()
1166                } else {
1167                    F::zero()
1168                };
1169            }
1170            residuals
1171        }
1172        RegressionLoss::Huber => {
1173            // Compute residuals and delta from quantile.
1174            let raw_residuals: Vec<F> = (0..n).map(|i| (y[i] - f_vals[i]).abs()).collect();
1175            let delta = quantile_f(&raw_residuals, huber_alpha);
1176
1177            let mut residuals = Array1::zeros(n);
1178            for i in 0..n {
1179                let diff = y[i] - f_vals[i];
1180                if diff.abs() <= delta {
1181                    residuals[i] = diff;
1182                } else if diff > F::zero() {
1183                    residuals[i] = delta;
1184                } else {
1185                    residuals[i] = -delta;
1186                }
1187            }
1188            residuals
1189        }
1190    }
1191}
1192
1193// ---------------------------------------------------------------------------
1194// Tests
1195// ---------------------------------------------------------------------------
1196
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // -- Regressor tests --

    // Step-function data: the regressor must learn to separate the two
    // flat groups around the midpoint 3.0.
    #[test]
    fn test_gbr_simple_least_squares() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = GradientBoostingRegressor::<f64>::new()
            .with_n_estimators(50)
            .with_learning_rate(0.1)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
        for i in 0..4 {
            assert!(preds[i] < 3.0, "Expected ~1.0, got {}", preds[i]);
        }
        for i in 4..8 {
            assert!(preds[i] > 3.0, "Expected ~5.0, got {}", preds[i]);
        }
    }

    #[test]
    fn test_gbr_lad_loss() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = GradientBoostingRegressor::<f64>::new()
            .with_n_estimators(50)
            .with_loss(RegressionLoss::Lad)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
        // LAD should still separate the two groups.
        for i in 0..4 {
            assert!(preds[i] < 3.5, "LAD expected <3.5, got {}", preds[i]);
        }
        for i in 4..8 {
            assert!(preds[i] > 2.5, "LAD expected >2.5, got {}", preds[i]);
        }
    }

    // Smoke test: Huber loss trains and predicts without panicking.
    #[test]
    fn test_gbr_huber_loss() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = GradientBoostingRegressor::<f64>::new()
            .with_n_estimators(50)
            .with_loss(RegressionLoss::Huber)
            .with_huber_alpha(0.9)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
    }

    // Same seed, same data => bitwise-identical predictions.
    #[test]
    fn test_gbr_reproducibility() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = GradientBoostingRegressor::<f64>::new()
            .with_n_estimators(20)
            .with_random_state(123);

        let fitted1 = model.fit(&x, &y).unwrap();
        let fitted2 = model.fit(&x, &y).unwrap();

        let preds1 = fitted1.predict(&x).unwrap();
        let preds2 = fitted2.predict(&x).unwrap();

        for (p1, p2) in preds1.iter().zip(preds2.iter()) {
            assert_relative_eq!(*p1, *p2, epsilon = 1e-10);
        }
    }

    // Only column 0 is informative; the other two are constant zero, so
    // all importance mass should concentrate on column 0.
    #[test]
    fn test_gbr_feature_importances() {
        let x = Array2::from_shape_vec(
            (10, 3),
            vec![
                1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0, 0.0, 5.0, 0.0, 0.0, 6.0,
                0.0, 0.0, 7.0, 0.0, 0.0, 8.0, 0.0, 0.0, 9.0, 0.0, 0.0, 10.0, 0.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0, 5.0];

        let model = GradientBoostingRegressor::<f64>::new()
            .with_n_estimators(20)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 3);
        // First feature should be most important.
        assert!(importances[0] > importances[1]);
        assert!(importances[0] > importances[2]);
    }

    #[test]
    fn test_gbr_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0];

        let model = GradientBoostingRegressor::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_gbr_shape_mismatch_predict() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = GradientBoostingRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(0);
        let fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_gbr_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);

        let model = GradientBoostingRegressor::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_gbr_zero_estimators() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = GradientBoostingRegressor::<f64>::new().with_n_estimators(0);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_gbr_invalid_learning_rate() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = GradientBoostingRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_learning_rate(0.0);
        assert!(model.fit(&x, &y).is_err());
    }

    // Subsample must lie in (0, 1]; both 0.0 and 1.5 are rejected.
    #[test]
    fn test_gbr_invalid_subsample() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = GradientBoostingRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_subsample(0.0);
        assert!(model.fit(&x, &y).is_err());

        let model2 = GradientBoostingRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_subsample(1.5);
        assert!(model2.fit(&x, &y).is_err());
    }

    // Smoke test: stochastic boosting (row subsampling) runs end-to-end.
    #[test]
    fn test_gbr_subsample() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = GradientBoostingRegressor::<f64>::new()
            .with_n_estimators(50)
            .with_subsample(0.5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
    }

    #[test]
    fn test_gbr_pipeline_integration() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = GradientBoostingRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_gbr_f32_support() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);

        let model = GradientBoostingRegressor::<f32>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    // Depth-1 stumps must still train and produce predictions.
    #[test]
    fn test_gbr_max_depth() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = GradientBoostingRegressor::<f64>::new()
            .with_n_estimators(20)
            .with_max_depth(Some(1))
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 8);
    }

    #[test]
    fn test_gbr_default_trait() {
        let model = GradientBoostingRegressor::<f64>::default();
        assert_eq!(model.n_estimators, 100);
        assert!((model.learning_rate - 0.1).abs() < 1e-10);
    }

    // -- Classifier tests --

    #[test]
    fn test_gbc_binary_simple() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = GradientBoostingClassifier::<f64>::new()
            .with_n_estimators(50)
            .with_learning_rate(0.1)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
        for i in 0..4 {
            assert_eq!(preds[i], 0, "Expected 0 at index {}, got {}", i, preds[i]);
        }
        for i in 4..8 {
            assert_eq!(preds[i], 1, "Expected 1 at index {}, got {}", i, preds[i]);
        }
    }

    #[test]
    fn test_gbc_multiclass() {
        let x = Array2::from_shape_vec((9, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
            .unwrap();
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];

        let model = GradientBoostingClassifier::<f64>::new()
            .with_n_estimators(50)
            .with_learning_rate(0.1)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 9);
        // At least training data should mostly be correct.
        let correct = preds.iter().zip(y.iter()).filter(|(p, t)| p == t).count();
        assert!(
            correct >= 6,
            "Expected at least 6/9 correct, got {}/9",
            correct
        );
    }

    #[test]
    fn test_gbc_has_classes() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 1, 2, 0, 1, 2];

        let model = GradientBoostingClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(0);
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.classes(), &[0, 1, 2]);
        assert_eq!(fitted.n_classes(), 3);
    }

    #[test]
    fn test_gbc_reproducibility() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = GradientBoostingClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);

        let fitted1 = model.fit(&x, &y).unwrap();
        let fitted2 = model.fit(&x, &y).unwrap();

        let preds1 = fitted1.predict(&x).unwrap();
        let preds2 = fitted2.predict(&x).unwrap();
        assert_eq!(preds1, preds2);
    }

    #[test]
    fn test_gbc_feature_importances() {
        let x = Array2::from_shape_vec(
            (10, 3),
            vec![
                1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0, 0.0, 5.0, 0.0, 0.0, 6.0,
                0.0, 0.0, 7.0, 0.0, 0.0, 8.0, 0.0, 0.0, 9.0, 0.0, 0.0, 10.0, 0.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];

        let model = GradientBoostingClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 3);
        assert!(importances[0] > importances[1]);
        assert!(importances[0] > importances[2]);
    }

    #[test]
    fn test_gbc_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 1];

        let model = GradientBoostingClassifier::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_gbc_shape_mismatch_predict() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = GradientBoostingClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(0);
        let fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_gbc_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);

        let model = GradientBoostingClassifier::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }

    // Classification needs at least two distinct labels.
    #[test]
    fn test_gbc_single_class() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0];

        let model = GradientBoostingClassifier::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_gbc_zero_estimators() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = GradientBoostingClassifier::<f64>::new().with_n_estimators(0);
        assert!(model.fit(&x, &y).is_err());
    }

    // Pipeline path: float targets are converted to usize labels internally.
    #[test]
    fn test_gbc_pipeline_integration() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);

        let model = GradientBoostingClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_gbc_f32_support() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = GradientBoostingClassifier::<f32>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_gbc_subsample() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = GradientBoostingClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_subsample(0.5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 8);
    }

    #[test]
    fn test_gbc_default_trait() {
        let model = GradientBoostingClassifier::<f64>::default();
        assert_eq!(model.n_estimators, 100);
        assert!((model.learning_rate - 0.1).abs() < 1e-10);
    }

    // Labels need not be 0..K-1; predictions must come back in the
    // original label space.
    #[test]
    fn test_gbc_non_contiguous_labels() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![10, 10, 10, 20, 20, 20];

        let model = GradientBoostingClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 6);
        for &p in preds.iter() {
            assert!(p == 10 || p == 20);
        }
    }

    // -- Helper tests --

    #[test]
    fn test_sigmoid() {
        assert_relative_eq!(sigmoid(0.0f64), 0.5, epsilon = 1e-10);
        assert!(sigmoid(10.0f64) > 0.999);
        assert!(sigmoid(-10.0f64) < 0.001);
    }

    #[test]
    fn test_median_f_odd() {
        let arr = array![3.0, 1.0, 2.0];
        assert_relative_eq!(median_f(&arr), 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_median_f_even() {
        let arr = array![4.0, 1.0, 3.0, 2.0];
        assert_relative_eq!(median_f(&arr), 2.5, epsilon = 1e-10);
    }

    #[test]
    fn test_median_f_empty() {
        let arr = Array1::<f64>::zeros(0);
        assert_relative_eq!(median_f(&arr), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_quantile_f() {
        let vals = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let q90 = quantile_f(&vals, 0.9);
        assert!(q90 >= 4.0 && q90 <= 5.0);
    }

    #[test]
    fn test_regression_residuals_least_squares() {
        let y = array![1.0, 2.0, 3.0];
        let f = array![0.5, 2.5, 2.0];
        let r = compute_regression_residuals(&y, &f, RegressionLoss::LeastSquares, 0.9);
        assert_relative_eq!(r[0], 0.5, epsilon = 1e-10);
        assert_relative_eq!(r[1], -0.5, epsilon = 1e-10);
        assert_relative_eq!(r[2], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_regression_residuals_lad() {
        let y = array![1.0, 2.0, 3.0];
        let f = array![0.5, 2.5, 3.0];
        let r = compute_regression_residuals(&y, &f, RegressionLoss::Lad, 0.9);
        assert_relative_eq!(r[0], 1.0, epsilon = 1e-10);
        assert_relative_eq!(r[1], -1.0, epsilon = 1e-10);
        assert_relative_eq!(r[2], 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_regression_residuals_huber() {
        let y = array![1.0, 2.0, 10.0, 3.0, 4.0];
        let f = array![1.5, 2.5, 2.0, 3.5, 4.5];
        // abs residuals: [0.5, 0.5, 8.0, 0.5, 0.5]
        // alpha=0.9 quantile index = round(4 * 0.9) = 4 => sorted[4] = 8.0
        // So delta = 8.0, meaning all residuals are within delta and treated as L2.
        let r = compute_regression_residuals(&y, &f, RegressionLoss::Huber, 0.9);
        // All residuals should be y - f.
        assert_relative_eq!(r[0], -0.5, epsilon = 1e-10);
        assert_relative_eq!(r[1], -0.5, epsilon = 1e-10);
        assert_relative_eq!(r[2], 8.0, epsilon = 1e-10);
        assert_relative_eq!(r[3], -0.5, epsilon = 1e-10);
        assert_relative_eq!(r[4], -0.5, epsilon = 1e-10);

        // Test with lower alpha to trigger clipping.
        // alpha=0.1, quantile idx = round(4*0.1) = 0 => sorted[0] = 0.5
        // delta = 0.5, so the 8.0 residual is clipped.
        let r2 = compute_regression_residuals(&y, &f, RegressionLoss::Huber, 0.1);
        assert_relative_eq!(r2[0], -0.5, epsilon = 1e-10);
        // Third residual: diff=8.0 > delta=0.5, so clipped to delta=0.5.
        assert_relative_eq!(r2[2], 0.5, epsilon = 1e-10);
    }

    // Smoke test: more than three classes builds K ensembles per round.
    #[test]
    fn test_gbc_multiclass_4_classes() {
        let x = Array2::from_shape_vec(
            (12, 1),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3];

        let model = GradientBoostingClassifier::<f64>::new()
            .with_n_estimators(50)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 12);
        assert_eq!(fitted.n_classes(), 4);
    }

    #[test]
    fn test_gbc_invalid_learning_rate() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = GradientBoostingClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_learning_rate(-0.1);
        assert!(model.fit(&x, &y).is_err());
    }
}