ferrolearn_tree/gradient_boosting.rs

//! Gradient boosting classifiers and regressors.
//!
//! This module provides [`GradientBoostingClassifier`] and [`GradientBoostingRegressor`],
//! which build ensembles of decision trees sequentially. Each tree fits the negative
//! gradient (pseudo-residuals) of the loss function, progressively reducing prediction error.
//!
//! # Regression Losses
//!
//! - **`LeastSquares`** (L2): mean squared error; pseudo-residuals are `y - F(x)`.
//! - **`Lad`** (L1): least absolute deviation; pseudo-residuals are `sign(y - F(x))`.
//! - **`Huber`**: a blend of L2 (for small residuals) and L1 (for large residuals),
//!   controlled by the `alpha` quantile parameter (default 0.9).
//!
//! # Classification Loss
//!
//! - **`LogLoss`**: binary and multiclass logistic loss. For binary classification a
//!   single model is trained on log-odds; for *K*-class problems *K* trees are built
//!   per boosting round (one-vs-rest in probability space via softmax).
//!
//! # Examples
//!
//! ```
//! use ferrolearn_tree::GradientBoostingRegressor;
//! use ferrolearn_core::{Fit, Predict};
//! use ndarray::{array, Array1, Array2};
//!
//! let x = Array2::from_shape_vec((8, 1), vec![
//!     1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
//! ]).unwrap();
//! let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];
//!
//! let model = GradientBoostingRegressor::<f64>::new()
//!     .with_n_estimators(50)
//!     .with_learning_rate(0.1)
//!     .with_random_state(42);
//! let fitted = model.fit(&x, &y).unwrap();
//! let preds = fitted.predict(&x).unwrap();
//! assert_eq!(preds.len(), 8);
//! ```
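//!
//! Classification is analogous; a minimal sketch, assuming
//! [`GradientBoostingClassifier`] is re-exported at the crate root like the
//! regressor above:
//!
//! ```
//! use ferrolearn_tree::GradientBoostingClassifier;
//! use ferrolearn_core::{Fit, Predict};
//! use ndarray::{array, Array2};
//!
//! let x = Array2::from_shape_vec((6, 1), vec![
//!     1.0, 2.0, 3.0, 7.0, 8.0, 9.0,
//! ]).unwrap();
//! let y = array![0_usize, 0, 0, 1, 1, 1];
//!
//! let model = GradientBoostingClassifier::<f64>::new()
//!     .with_n_estimators(20)
//!     .with_random_state(0);
//! let fitted = model.fit(&x, &y).unwrap();
//! let labels = fitted.predict(&x).unwrap();
//! assert_eq!(labels.len(), 6);
//! ```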

use ferrolearn_core::error::FerroError;
use ferrolearn_core::introspection::{HasClasses, HasFeatureImportances};
use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
use ferrolearn_core::traits::{Fit, Predict};
use ndarray::{Array1, Array2};
use num_traits::{Float, FromPrimitive, ToPrimitive};
use rand::SeedableRng;
use rand::rngs::StdRng;
use rand::seq::index::sample as rand_sample_indices;

use crate::decision_tree::{
    self, Node, build_regression_tree_with_feature_subset, compute_feature_importances,
};

// ---------------------------------------------------------------------------
// Regression loss enum
// ---------------------------------------------------------------------------

/// Loss function for gradient boosting regression.
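///
/// Selected via [`GradientBoostingRegressor::with_loss`]. A minimal sketch,
/// assuming `RegressionLoss` is re-exported at the crate root alongside the
/// estimators:
///
/// ```
/// use ferrolearn_tree::{GradientBoostingRegressor, RegressionLoss};
///
/// let model = GradientBoostingRegressor::<f64>::new()
///     .with_loss(RegressionLoss::Huber)
///     .with_huber_alpha(0.9);
/// assert_eq!(model.loss, RegressionLoss::Huber);
/// ```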
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RegressionLoss {
    /// Least squares (L2) loss.
    LeastSquares,
    /// Least absolute deviation (L1) loss.
    Lad,
    /// Huber loss: L2 for small residuals, L1 for large residuals.
    Huber,
}

/// Loss function for gradient boosting classification.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClassificationLoss {
    /// Log-loss (logistic / cross-entropy) for binary and multiclass.
    LogLoss,
}

// ---------------------------------------------------------------------------
// GradientBoostingRegressor
// ---------------------------------------------------------------------------

/// Gradient boosting regressor.
///
/// Builds an additive model in a forward stage-wise fashion, fitting each
/// regression tree to the negative gradient of the loss function evaluated
/// on the current ensemble prediction.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone)]
pub struct GradientBoostingRegressor<F> {
    /// Number of boosting stages (trees).
    pub n_estimators: usize,
    /// Learning rate (shrinkage) applied to each tree's contribution.
    pub learning_rate: f64,
    /// Maximum depth of each tree.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required in a leaf node.
    pub min_samples_leaf: usize,
    /// Fraction of samples to use for fitting each tree (stochastic boosting).
    pub subsample: f64,
    /// Loss function.
    pub loss: RegressionLoss,
    /// Alpha quantile for Huber loss (only used when `loss == Huber`).
    pub huber_alpha: f64,
    /// Random seed for reproducibility.
    pub random_state: Option<u64>,
    _marker: std::marker::PhantomData<F>,
}

impl<F: Float> GradientBoostingRegressor<F> {
    /// Create a new `GradientBoostingRegressor` with default settings.
    ///
    /// Defaults: `n_estimators = 100`, `learning_rate = 0.1`,
    /// `max_depth = Some(3)`, `min_samples_split = 2`,
    /// `min_samples_leaf = 1`, `subsample = 1.0`,
    /// `loss = LeastSquares`, `huber_alpha = 0.9`.
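    ///
    /// # Examples
    ///
    /// Builder setters override individual defaults and leave the rest
    /// intact:
    ///
    /// ```
    /// # use ferrolearn_tree::GradientBoostingRegressor;
    /// let model = GradientBoostingRegressor::<f64>::new().with_subsample(0.8);
    /// assert_eq!(model.n_estimators, 100); // other defaults retained
    /// assert!((model.subsample - 0.8).abs() < 1e-12);
    /// ```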
    #[must_use]
    pub fn new() -> Self {
        Self {
            n_estimators: 100,
            learning_rate: 0.1,
            max_depth: Some(3),
            min_samples_split: 2,
            min_samples_leaf: 1,
            subsample: 1.0,
            loss: RegressionLoss::LeastSquares,
            huber_alpha: 0.9,
            random_state: None,
            _marker: std::marker::PhantomData,
        }
    }

    /// Set the number of boosting stages.
    #[must_use]
    pub fn with_n_estimators(mut self, n: usize) -> Self {
        self.n_estimators = n;
        self
    }

    /// Set the learning rate (shrinkage).
    #[must_use]
    pub fn with_learning_rate(mut self, lr: f64) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Set the maximum tree depth.
    #[must_use]
    pub fn with_max_depth(mut self, d: Option<usize>) -> Self {
        self.max_depth = d;
        self
    }

    /// Set the minimum number of samples to split a node.
    #[must_use]
    pub fn with_min_samples_split(mut self, n: usize) -> Self {
        self.min_samples_split = n;
        self
    }

    /// Set the minimum number of samples in a leaf.
    #[must_use]
    pub fn with_min_samples_leaf(mut self, n: usize) -> Self {
        self.min_samples_leaf = n;
        self
    }

    /// Set the subsample ratio (fraction of training data per tree).
    #[must_use]
    pub fn with_subsample(mut self, ratio: f64) -> Self {
        self.subsample = ratio;
        self
    }

    /// Set the loss function.
    #[must_use]
    pub fn with_loss(mut self, loss: RegressionLoss) -> Self {
        self.loss = loss;
        self
    }

    /// Set the alpha quantile for Huber loss.
    #[must_use]
    pub fn with_huber_alpha(mut self, alpha: f64) -> Self {
        self.huber_alpha = alpha;
        self
    }

    /// Set the random seed for reproducibility.
    #[must_use]
    pub fn with_random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }
}

impl<F: Float> Default for GradientBoostingRegressor<F> {
    fn default() -> Self {
        Self::new()
    }
}

// ---------------------------------------------------------------------------
// FittedGradientBoostingRegressor
// ---------------------------------------------------------------------------

/// A fitted gradient boosting regressor.
///
/// Stores the initial prediction (intercept) and the sequence of fitted trees.
/// Predictions are computed as `init + learning_rate * sum(tree_predictions)`.
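///
/// For example, under `LeastSquares` loss the intercept is the mean of the
/// training targets and the ensemble holds one tree per boosting stage. A
/// minimal sketch, assuming the crate-root re-exports used in the module
/// example:
///
/// ```
/// # use ferrolearn_tree::GradientBoostingRegressor;
/// # use ferrolearn_core::Fit;
/// # use ndarray::{array, Array2};
/// let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
/// let y = array![1.0, 2.0, 3.0, 4.0];
/// let fitted = GradientBoostingRegressor::<f64>::new()
///     .with_n_estimators(25)
///     .with_random_state(0)
///     .fit(&x, &y)
///     .unwrap();
/// assert_eq!(fitted.trees().len(), 25);
/// assert!((fitted.init() - 2.5).abs() < 1e-12); // mean of y under L2 loss
/// ```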
#[derive(Debug, Clone)]
pub struct FittedGradientBoostingRegressor<F> {
    /// Initial prediction (mean of training targets for L2 loss, median for L1/Huber).
    init: F,
    /// Learning rate used during training.
    learning_rate: F,
    /// Sequence of fitted trees (one per boosting round).
    trees: Vec<Vec<Node<F>>>,
    /// Number of features.
    n_features: usize,
    /// Per-feature importance scores (normalised).
    feature_importances: Array1<F>,
}

impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for GradientBoostingRegressor<F> {
    type Fitted = FittedGradientBoostingRegressor<F>;
    type Error = FerroError;

    /// Fit the gradient boosting regressor.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
    /// numbers of samples.
    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
    /// Returns [`FerroError::InvalidParameter`] for invalid hyperparameters.
    fn fit(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<FittedGradientBoostingRegressor<F>, FerroError> {
        let (n_samples, n_features) = x.dim();

        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        if n_samples == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "GradientBoostingRegressor requires at least one sample".into(),
            });
        }
        if self.n_estimators == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_estimators".into(),
                reason: "must be at least 1".into(),
            });
        }
        if self.learning_rate <= 0.0 {
            return Err(FerroError::InvalidParameter {
                name: "learning_rate".into(),
                reason: "must be positive".into(),
            });
        }
        if self.subsample <= 0.0 || self.subsample > 1.0 {
            return Err(FerroError::InvalidParameter {
                name: "subsample".into(),
                reason: "must be in (0, 1]".into(),
            });
        }

        let lr = F::from(self.learning_rate).unwrap();
        let params = decision_tree::TreeParams {
            max_depth: self.max_depth,
            min_samples_split: self.min_samples_split,
            min_samples_leaf: self.min_samples_leaf,
        };

        // Initial prediction.
        let init = match self.loss {
            RegressionLoss::LeastSquares => {
                let sum: F = y.iter().copied().fold(F::zero(), |a, b| a + b);
                sum / F::from(n_samples).unwrap()
            }
            RegressionLoss::Lad | RegressionLoss::Huber => median_f(y),
        };

        // Current predictions for each sample.
        let mut f_vals = Array1::from_elem(n_samples, init);

        let all_features: Vec<usize> = (0..n_features).collect();
        let subsample_size = ((self.subsample * n_samples as f64).ceil() as usize)
            .max(1)
            .min(n_samples);

        let mut rng = if let Some(seed) = self.random_state {
            StdRng::seed_from_u64(seed)
        } else {
            use rand::RngCore;
            StdRng::seed_from_u64(rand::rng().next_u64())
        };

        let mut trees = Vec::with_capacity(self.n_estimators);

        for _ in 0..self.n_estimators {
            // Compute pseudo-residuals (negative gradient).
            let residuals = compute_regression_residuals(y, &f_vals, self.loss, self.huber_alpha);

            // Subsample indices.
            let sample_indices = if subsample_size < n_samples {
                rand_sample_indices(&mut rng, n_samples, subsample_size).into_vec()
            } else {
                (0..n_samples).collect()
            };

            // Build a regression tree on the pseudo-residuals.
            let tree = build_regression_tree_with_feature_subset(
                x,
                &residuals,
                &sample_indices,
                &all_features,
                &params,
            );

            // Update predictions.
            for i in 0..n_samples {
                let row = x.row(i);
                let leaf_idx = decision_tree::traverse(&tree, &row);
                if let Node::Leaf { value, .. } = tree[leaf_idx] {
                    f_vals[i] = f_vals[i] + lr * value;
                }
            }

            trees.push(tree);
        }

        // Compute feature importances across all trees.
        let mut total_importances = Array1::<F>::zeros(n_features);
        for tree_nodes in &trees {
            let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
            total_importances = total_importances + tree_imp;
        }
        let imp_sum: F = total_importances
            .iter()
            .copied()
            .fold(F::zero(), |a, b| a + b);
        if imp_sum > F::zero() {
            total_importances.mapv_inplace(|v| v / imp_sum);
        }

        Ok(FittedGradientBoostingRegressor {
            init,
            learning_rate: lr,
            trees,
            n_features,
            feature_importances: total_importances,
        })
    }
}

impl<F: Float + Send + Sync + 'static> FittedGradientBoostingRegressor<F> {
    /// Returns the initial prediction (intercept) of the boosted model.
    #[must_use]
    pub fn init(&self) -> F {
        self.init
    }

    /// Returns the learning rate used during training.
    #[must_use]
    pub fn learning_rate(&self) -> F {
        self.learning_rate
    }

    /// Returns a reference to the sequence of fitted trees.
    #[must_use]
    pub fn trees(&self) -> &[Vec<Node<F>>] {
        &self.trees
    }

    /// Returns the number of features the model was trained on.
    #[must_use]
    pub fn n_features(&self) -> usize {
        self.n_features
    }

    /// R² coefficient of determination on the given test data.
    /// Equivalent to sklearn's `RegressorMixin.score`.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `x.nrows() != y.len()` or
    /// the feature count does not match the training data.
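    ///
    /// # Examples
    ///
    /// A minimal sketch scoring on the training data (same assumed
    /// re-exports as the module example):
    ///
    /// ```
    /// # use ferrolearn_tree::GradientBoostingRegressor;
    /// # use ferrolearn_core::Fit;
    /// # use ndarray::{array, Array2};
    /// # let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    /// # let y = array![1.0, 2.0, 3.0, 4.0];
    /// # let fitted = GradientBoostingRegressor::<f64>::new()
    /// #     .with_n_estimators(50)
    /// #     .with_random_state(0)
    /// #     .fit(&x, &y)
    /// #     .unwrap();
    /// let r2 = fitted.score(&x, &y).unwrap();
    /// assert!(r2 <= 1.0); // R² is bounded above by 1
    /// ```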
    pub fn score(&self, x: &Array2<F>, y: &Array1<F>) -> Result<F, FerroError> {
        if x.nrows() != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![x.nrows()],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        let preds = self.predict(x)?;
        Ok(crate::r2_score(&preds, y))
    }
}

impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedGradientBoostingRegressor<F> {
    type Output = Array1<F>;
    type Error = FerroError;

    /// Predict target values.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
    /// not match the fitted model.
    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }

        let n_samples = x.nrows();
        let mut predictions = Array1::from_elem(n_samples, self.init);

        for i in 0..n_samples {
            let row = x.row(i);
            for tree_nodes in &self.trees {
                let leaf_idx = decision_tree::traverse(tree_nodes, &row);
                if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
                    predictions[i] = predictions[i] + self.learning_rate * value;
                }
            }
        }

        Ok(predictions)
    }
}

impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F>
    for FittedGradientBoostingRegressor<F>
{
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}

// Pipeline integration.
impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for GradientBoostingRegressor<F> {
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let fitted = self.fit(x, y)?;
        Ok(Box::new(fitted))
    }
}

impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedGradientBoostingRegressor<F>
{
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        self.predict(x)
    }
}

// ---------------------------------------------------------------------------
// GradientBoostingClassifier
// ---------------------------------------------------------------------------

/// Gradient boosting classifier.
///
/// For binary classification a single model is trained on log-odds residuals.
/// For multiclass (*K* classes), *K* regression trees are built per boosting
/// round (one-vs-rest in probability space via softmax).
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone)]
pub struct GradientBoostingClassifier<F> {
    /// Number of boosting stages.
    pub n_estimators: usize,
    /// Learning rate (shrinkage).
    pub learning_rate: f64,
    /// Maximum depth of each tree.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required in a leaf node.
    pub min_samples_leaf: usize,
    /// Fraction of samples to use for fitting each tree.
    pub subsample: f64,
    /// Classification loss function.
    pub loss: ClassificationLoss,
    /// Random seed for reproducibility.
    pub random_state: Option<u64>,
    _marker: std::marker::PhantomData<F>,
}

impl<F: Float> GradientBoostingClassifier<F> {
    /// Create a new `GradientBoostingClassifier` with default settings.
    ///
    /// Defaults: `n_estimators = 100`, `learning_rate = 0.1`,
    /// `max_depth = Some(3)`, `min_samples_split = 2`,
    /// `min_samples_leaf = 1`, `subsample = 1.0`,
    /// `loss = LogLoss`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            n_estimators: 100,
            learning_rate: 0.1,
            max_depth: Some(3),
            min_samples_split: 2,
            min_samples_leaf: 1,
            subsample: 1.0,
            loss: ClassificationLoss::LogLoss,
            random_state: None,
            _marker: std::marker::PhantomData,
        }
    }

    /// Set the number of boosting stages.
    #[must_use]
    pub fn with_n_estimators(mut self, n: usize) -> Self {
        self.n_estimators = n;
        self
    }

    /// Set the learning rate (shrinkage).
    #[must_use]
    pub fn with_learning_rate(mut self, lr: f64) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Set the maximum tree depth.
    #[must_use]
    pub fn with_max_depth(mut self, d: Option<usize>) -> Self {
        self.max_depth = d;
        self
    }

    /// Set the minimum number of samples to split a node.
    #[must_use]
    pub fn with_min_samples_split(mut self, n: usize) -> Self {
        self.min_samples_split = n;
        self
    }

    /// Set the minimum number of samples in a leaf.
    #[must_use]
    pub fn with_min_samples_leaf(mut self, n: usize) -> Self {
        self.min_samples_leaf = n;
        self
    }

    /// Set the subsample ratio.
    #[must_use]
    pub fn with_subsample(mut self, ratio: f64) -> Self {
        self.subsample = ratio;
        self
    }

    /// Set the random seed for reproducibility.
    #[must_use]
    pub fn with_random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }
}

impl<F: Float> Default for GradientBoostingClassifier<F> {
    fn default() -> Self {
        Self::new()
    }
}

// ---------------------------------------------------------------------------
// FittedGradientBoostingClassifier
// ---------------------------------------------------------------------------

/// A fitted gradient boosting classifier.
///
/// For binary classification, stores a single sequence of trees predicting log-odds.
/// For multiclass, stores `K` sequences of trees (one per class).
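///
/// For instance, a three-class fit with `n_estimators = 10` stores ten trees
/// in `trees()[k]` for each class `k`. A minimal sketch, under the same
/// assumed crate-root re-exports as the module example:
///
/// ```
/// # use ferrolearn_tree::GradientBoostingClassifier;
/// # use ferrolearn_core::Fit;
/// # use ndarray::{array, Array2};
/// let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 4.0, 5.0, 8.0, 9.0]).unwrap();
/// let y = array![0_usize, 0, 1, 1, 2, 2];
/// let fitted = GradientBoostingClassifier::<f64>::new()
///     .with_n_estimators(10)
///     .with_random_state(0)
///     .fit(&x, &y)
///     .unwrap();
/// assert_eq!(fitted.trees().len(), 3);     // one sequence per class
/// assert_eq!(fitted.trees()[0].len(), 10); // one tree per boosting round
/// ```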
#[derive(Debug, Clone)]
pub struct FittedGradientBoostingClassifier<F> {
    /// Sorted unique class labels.
    classes: Vec<usize>,
    /// Initial predictions per class (log-odds or log-prior).
    init: Vec<F>,
    /// Learning rate.
    learning_rate: F,
    /// Trees: for binary, `trees[0]` has all trees. For multiclass,
    /// `trees[k]` has trees for class k.
    trees: Vec<Vec<Vec<Node<F>>>>,
    /// Number of features.
    n_features: usize,
    /// Per-feature importance scores (normalised).
    feature_importances: Array1<F>,
}

impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>>
    for GradientBoostingClassifier<F>
{
    type Fitted = FittedGradientBoostingClassifier<F>;
    type Error = FerroError;

    /// Fit the gradient boosting classifier.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
    /// numbers of samples.
    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
    /// Returns [`FerroError::InvalidParameter`] for invalid hyperparameters.
    fn fit(
        &self,
        x: &Array2<F>,
        y: &Array1<usize>,
    ) -> Result<FittedGradientBoostingClassifier<F>, FerroError> {
        let (n_samples, n_features) = x.dim();

        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        if n_samples == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "GradientBoostingClassifier requires at least one sample".into(),
            });
        }
        if self.n_estimators == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_estimators".into(),
                reason: "must be at least 1".into(),
            });
        }
        if self.learning_rate <= 0.0 {
            return Err(FerroError::InvalidParameter {
                name: "learning_rate".into(),
                reason: "must be positive".into(),
            });
        }
        if self.subsample <= 0.0 || self.subsample > 1.0 {
            return Err(FerroError::InvalidParameter {
                name: "subsample".into(),
                reason: "must be in (0, 1]".into(),
            });
        }

        // Determine unique classes.
        let mut classes: Vec<usize> = y.iter().copied().collect();
        classes.sort_unstable();
        classes.dedup();
        let n_classes = classes.len();

        if n_classes < 2 {
            return Err(FerroError::InvalidParameter {
                name: "y".into(),
                reason: "need at least 2 distinct classes".into(),
            });
        }

        let y_mapped: Vec<usize> = y
            .iter()
            .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
            .collect();

        let lr = F::from(self.learning_rate).unwrap();
        let params = decision_tree::TreeParams {
            max_depth: self.max_depth,
            min_samples_split: self.min_samples_split,
            min_samples_leaf: self.min_samples_leaf,
        };

        let all_features: Vec<usize> = (0..n_features).collect();
        let subsample_size = ((self.subsample * n_samples as f64).ceil() as usize)
            .max(1)
            .min(n_samples);

        let mut rng = if let Some(seed) = self.random_state {
            StdRng::seed_from_u64(seed)
        } else {
            use rand::RngCore;
            StdRng::seed_from_u64(rand::rng().next_u64())
        };

        if n_classes == 2 {
            // Binary classification: single model on log-odds.
            self.fit_binary(
                x,
                &y_mapped,
                n_samples,
                n_features,
                &classes,
                lr,
                &params,
                &all_features,
                subsample_size,
                &mut rng,
            )
        } else {
            // Multiclass: K trees per round.
            self.fit_multiclass(
                x,
                &y_mapped,
                n_samples,
                n_features,
                n_classes,
                &classes,
                lr,
                &params,
                &all_features,
                subsample_size,
                &mut rng,
            )
        }
    }
}

impl<F: Float + Send + Sync + 'static> GradientBoostingClassifier<F> {
    /// Fit binary classification (log-loss on log-odds).
    #[allow(clippy::too_many_arguments)]
    fn fit_binary(
        &self,
        x: &Array2<F>,
        y_mapped: &[usize],
        n_samples: usize,
        n_features: usize,
        classes: &[usize],
        lr: F,
        params: &decision_tree::TreeParams,
        all_features: &[usize],
        subsample_size: usize,
        rng: &mut StdRng,
    ) -> Result<FittedGradientBoostingClassifier<F>, FerroError> {
        // Count positive class proportion for initial log-odds.
        let pos_count = y_mapped.iter().filter(|&&c| c == 1).count();
        let p = F::from(pos_count).unwrap() / F::from(n_samples).unwrap();
        let eps = F::from(1e-15).unwrap();
        let p_clipped = p.max(eps).min(F::one() - eps);
        let init_val = (p_clipped / (F::one() - p_clipped)).ln();

        let mut f_vals = Array1::from_elem(n_samples, init_val);
        let mut trees_seq: Vec<Vec<Node<F>>> = Vec::with_capacity(self.n_estimators);

        for _ in 0..self.n_estimators {
            // Compute probabilities from current log-odds.
            let probs: Vec<F> = f_vals.iter().map(|&fv| sigmoid(fv)).collect();

            // Pseudo-residuals: y - p.
            let mut residuals = Array1::zeros(n_samples);
            for i in 0..n_samples {
                let yi = F::from(y_mapped[i]).unwrap();
                residuals[i] = yi - probs[i];
            }

            // Subsample.
            let sample_indices = if subsample_size < n_samples {
                rand_sample_indices(rng, n_samples, subsample_size).into_vec()
            } else {
                (0..n_samples).collect()
            };

            // Build tree on residuals.
            let tree = build_regression_tree_with_feature_subset(
                x,
                &residuals,
                &sample_indices,
                all_features,
                params,
            );

            // Update f_vals.
            for i in 0..n_samples {
                let row = x.row(i);
                let leaf_idx = decision_tree::traverse(&tree, &row);
                if let Node::Leaf { value, .. } = tree[leaf_idx] {
                    f_vals[i] = f_vals[i] + lr * value;
                }
            }

            trees_seq.push(tree);
        }

        // Feature importances.
        let mut total_importances = Array1::<F>::zeros(n_features);
        for tree_nodes in &trees_seq {
            let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
            total_importances = total_importances + tree_imp;
        }
        let imp_sum: F = total_importances
            .iter()
            .copied()
            .fold(F::zero(), |a, b| a + b);
        if imp_sum > F::zero() {
            total_importances.mapv_inplace(|v| v / imp_sum);
        }

        Ok(FittedGradientBoostingClassifier {
            classes: classes.to_vec(),
            init: vec![init_val],
            learning_rate: lr,
            trees: vec![trees_seq],
            n_features,
            feature_importances: total_importances,
        })
    }

    /// Fit multiclass classification (K trees per round, softmax).
    #[allow(clippy::too_many_arguments)]
    fn fit_multiclass(
        &self,
        x: &Array2<F>,
        y_mapped: &[usize],
        n_samples: usize,
        n_features: usize,
        n_classes: usize,
        classes: &[usize],
        lr: F,
        params: &decision_tree::TreeParams,
        all_features: &[usize],
        subsample_size: usize,
        rng: &mut StdRng,
    ) -> Result<FittedGradientBoostingClassifier<F>, FerroError> {
        // Initial log-prior for each class.
        let mut class_counts = vec![0usize; n_classes];
        for &c in y_mapped {
            class_counts[c] += 1;
        }
        let n_f = F::from(n_samples).unwrap();
        let eps = F::from(1e-15).unwrap();
        let init_vals: Vec<F> = class_counts
            .iter()
            .map(|&cnt| {
                let p = (F::from(cnt).unwrap() / n_f).max(eps);
                p.ln()
            })
            .collect();

        // f_vals[k][i] = current raw score for class k, sample i.
        let mut f_vals: Vec<Array1<F>> = init_vals
            .iter()
            .map(|&init| Array1::from_elem(n_samples, init))
            .collect();

        let mut trees_per_class: Vec<Vec<Vec<Node<F>>>> = (0..n_classes)
            .map(|_| Vec::with_capacity(self.n_estimators))
            .collect();

        for _ in 0..self.n_estimators {
            // Compute softmax probabilities.
            let probs = softmax_matrix(&f_vals, n_samples, n_classes);

            // Subsample.
            let sample_indices = if subsample_size < n_samples {
                rand_sample_indices(rng, n_samples, subsample_size).into_vec()
            } else {
                (0..n_samples).collect()
            };

            // For each class, compute residuals and fit a tree.
            for k in 0..n_classes {
                let mut residuals = Array1::zeros(n_samples);
                for i in 0..n_samples {
                    let yi_k = if y_mapped[i] == k {
                        F::one()
                    } else {
                        F::zero()
                    };
                    residuals[i] = yi_k - probs[k][i];
                }

                let tree = build_regression_tree_with_feature_subset(
                    x,
                    &residuals,
                    &sample_indices,
                    all_features,
                    params,
                );

                // Update f_vals for class k.
                for (i, fv) in f_vals[k].iter_mut().enumerate() {
                    let row = x.row(i);
                    let leaf_idx = decision_tree::traverse(&tree, &row);
                    if let Node::Leaf { value, .. } = tree[leaf_idx] {
                        *fv = *fv + lr * value;
                    }
                }

                trees_per_class[k].push(tree);
            }
        }

        // Feature importances aggregated across all classes and rounds.
        let mut total_importances = Array1::<F>::zeros(n_features);
        for class_trees in &trees_per_class {
            for tree_nodes in class_trees {
                let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
                total_importances = total_importances + tree_imp;
            }
        }
        let imp_sum: F = total_importances
            .iter()
            .copied()
            .fold(F::zero(), |a, b| a + b);
        if imp_sum > F::zero() {
            total_importances.mapv_inplace(|v| v / imp_sum);
        }

        Ok(FittedGradientBoostingClassifier {
            classes: classes.to_vec(),
            init: init_vals,
            learning_rate: lr,
            trees: trees_per_class,
            n_features,
            feature_importances: total_importances,
        })
    }
}

impl<F: Float + Send + Sync + 'static> FittedGradientBoostingClassifier<F> {
    /// Returns the initial predictions per class (log-odds or log-prior).
    #[must_use]
    pub fn init(&self) -> &[F] {
        &self.init
    }

    /// Returns the learning rate used during training.
    #[must_use]
    pub fn learning_rate(&self) -> F {
        self.learning_rate
    }

    /// Returns a reference to the tree ensemble.
    ///
    /// For binary classification, `trees()[0]` contains all trees.
    /// For multiclass, `trees()[k]` contains trees for class `k`.
    #[must_use]
    pub fn trees(&self) -> &[Vec<Vec<Node<F>>>] {
        &self.trees
    }

    /// Returns the number of features the model was trained on.
    #[must_use]
    pub fn n_features(&self) -> usize {
        self.n_features
    }

    /// Mean accuracy on the given test data and labels.
    /// Equivalent to sklearn's `ClassifierMixin.score`.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `x.nrows() != y.len()` or
    /// the feature count does not match the training data.
    pub fn score(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<F, FerroError> {
        if x.nrows() != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![x.nrows()],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        let preds = self.predict(x)?;
        Ok(crate::mean_accuracy(&preds, y))
    }

    /// Predict class probabilities. Mirrors sklearn's
    /// `GradientBoostingClassifier.predict_proba`.
    ///
    /// Binary: applies the logistic link to the cumulative log-odds.
    /// Multiclass: softmax over K cumulative scores.
    ///
    /// Returns shape `(n_samples, n_classes)`; rows sum to 1.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if the number of features
    /// does not match the fitted model.
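    ///
    /// # Examples
    ///
    /// A minimal binary sketch (same assumed re-exports as the module
    /// example):
    ///
    /// ```
    /// # use ferrolearn_tree::GradientBoostingClassifier;
    /// # use ferrolearn_core::Fit;
    /// # use ndarray::{array, Array2};
    /// # let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 8.0, 9.0]).unwrap();
    /// # let y = array![0_usize, 0, 1, 1];
    /// # let fitted = GradientBoostingClassifier::<f64>::new()
    /// #     .with_n_estimators(10)
    /// #     .with_random_state(0)
    /// #     .fit(&x, &y)
    /// #     .unwrap();
    /// let proba = fitted.predict_proba(&x).unwrap();
    /// assert_eq!(proba.dim(), (4, 2));
    /// let row_sum = proba.row(0).sum();
    /// assert!((row_sum - 1.0).abs() < 1e-9); // rows sum to 1
    /// ```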
    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }
        let n_samples = x.nrows();
        let n_classes = self.classes.len();
        let mut proba = Array2::<F>::zeros((n_samples, n_classes));

        if n_classes == 2 {
            let init = self.init[0];
            for i in 0..n_samples {
                let row = x.row(i);
                let mut f_val = init;
                for tree_nodes in &self.trees[0] {
                    let leaf_idx = decision_tree::traverse(tree_nodes, &row);
                    if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
                        f_val = f_val + self.learning_rate * value;
                    }
                }
                let p1 = sigmoid(f_val);
                proba[[i, 0]] = F::one() - p1;
                proba[[i, 1]] = p1;
            }
        } else {
            for i in 0..n_samples {
                let row = x.row(i);
                let mut scores = vec![F::zero(); n_classes];
                for k in 0..n_classes {
                    let mut f_val = self.init[k];
                    for tree_nodes in &self.trees[k] {
                        let leaf_idx = decision_tree::traverse(tree_nodes, &row);
                        if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
                            f_val = f_val + self.learning_rate * value;
                        }
                    }
                    scores[k] = f_val;
                }
                let max_s = scores
                    .iter()
                    .copied()
                    .fold(F::neg_infinity(), |a, b| if b > a { b } else { a });
                let mut sum_exp = F::zero();
                for k in 0..n_classes {
                    let e = (scores[k] - max_s).exp();
                    proba[[i, k]] = e;
                    sum_exp = sum_exp + e;
                }
                if sum_exp > F::zero() {
                    for k in 0..n_classes {
                        proba[[i, k]] = proba[[i, k]] / sum_exp;
                    }
                }
            }
        }
        Ok(proba)
    }

    /// Element-wise log of [`predict_proba`](Self::predict_proba). Mirrors
    /// sklearn's `ClassifierMixin.predict_log_proba`.
    ///
    /// # Errors
    ///
    /// Forwards any error from [`predict_proba`](Self::predict_proba).
    pub fn predict_log_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let proba = self.predict_proba(x)?;
        Ok(crate::log_proba(&proba))
    }

    /// Cumulative raw scores per sample (pre-link). Mirrors sklearn's
    /// `GradientBoostingClassifier.decision_function`.
    ///
    /// Binary: shape `(n_samples, 1)` containing the cumulative log-odds.
    /// Multiclass: shape `(n_samples, n_classes)` containing per-class
    /// cumulative scores. (sklearn returns shape `(n_samples,)` for the
    /// binary case; ferrolearn keeps a 2-D shape for type-uniformity.)
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if the number of features
    /// does not match the fitted model.
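    ///
    /// # Examples
    ///
    /// A minimal binary sketch: the logistic link applied to the returned
    /// log-odds reproduces the positive-class column of
    /// [`predict_proba`](Self::predict_proba) (same assumed re-exports as
    /// the module example):
    ///
    /// ```
    /// # use ferrolearn_tree::GradientBoostingClassifier;
    /// # use ferrolearn_core::Fit;
    /// # use ndarray::{array, Array2};
    /// # let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 8.0, 9.0]).unwrap();
    /// # let y = array![0_usize, 0, 1, 1];
    /// # let fitted = GradientBoostingClassifier::<f64>::new()
    /// #     .with_n_estimators(10)
    /// #     .with_random_state(0)
    /// #     .fit(&x, &y)
    /// #     .unwrap();
    /// let scores = fitted.decision_function(&x).unwrap();
    /// assert_eq!(scores.dim(), (4, 1)); // binary: one log-odds column
    /// let p1 = 1.0 / (1.0 + (-scores[[0, 0]]).exp());
    /// let proba = fitted.predict_proba(&x).unwrap();
    /// assert!((p1 - proba[[0, 1]]).abs() < 1e-12);
    /// ```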
    pub fn decision_function(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }
        let n_samples = x.nrows();
        let n_classes = self.classes.len();

        if n_classes == 2 {
            let init = self.init[0];
            let mut out = Array2::<F>::zeros((n_samples, 1));
            for i in 0..n_samples {
                let row = x.row(i);
                let mut f_val = init;
                for tree_nodes in &self.trees[0] {
                    let leaf_idx = decision_tree::traverse(tree_nodes, &row);
                    if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
                        f_val = f_val + self.learning_rate * value;
                    }
                }
                out[[i, 0]] = f_val;
            }
            Ok(out)
        } else {
            let mut out = Array2::<F>::zeros((n_samples, n_classes));
            for i in 0..n_samples {
                let row = x.row(i);
                for k in 0..n_classes {
                    let mut f_val = self.init[k];
                    for tree_nodes in &self.trees[k] {
                        let leaf_idx = decision_tree::traverse(tree_nodes, &row);
                        if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
                            f_val = f_val + self.learning_rate * value;
                        }
                    }
                    out[[i, k]] = f_val;
                }
            }
            Ok(out)
        }
    }
}

impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedGradientBoostingClassifier<F> {
    type Output = Array1<usize>;
    type Error = FerroError;

    /// Predict class labels.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
    /// not match the fitted model.
    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }

        let n_samples = x.nrows();
        let n_classes = self.classes.len();

        if n_classes == 2 {
            // Binary: single log-odds model.
            let init = self.init[0];
            let mut predictions = Array1::zeros(n_samples);
            for i in 0..n_samples {
                let row = x.row(i);
                let mut f_val = init;
                for tree_nodes in &self.trees[0] {
                    let leaf_idx = decision_tree::traverse(tree_nodes, &row);
                    if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
                        f_val = f_val + self.learning_rate * value;
                    }
                }
                let prob = sigmoid(f_val);
                let class_idx = if prob >= F::from(0.5).unwrap() { 1 } else { 0 };
                predictions[i] = self.classes[class_idx];
            }
            Ok(predictions)
        } else {
            // Multiclass: K models, argmax of softmax.
            let mut predictions = Array1::zeros(n_samples);
            for i in 0..n_samples {
                let row = x.row(i);
                let mut scores = Vec::with_capacity(n_classes);
                for k in 0..n_classes {
                    let mut f_val = self.init[k];
                    for tree_nodes in &self.trees[k] {
                        let leaf_idx = decision_tree::traverse(tree_nodes, &row);
                        if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
                            f_val = f_val + self.learning_rate * value;
                        }
                    }
                    scores.push(f_val);
                }
                let best_k = scores
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                    .map_or(0, |(k, _)| k);
                predictions[i] = self.classes[best_k];
            }
            Ok(predictions)
        }
    }
}

impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F>
    for FittedGradientBoostingClassifier<F>
{
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}

impl<F: Float + Send + Sync + 'static> HasClasses for FittedGradientBoostingClassifier<F> {
    fn classes(&self) -> &[usize] {
        &self.classes
    }

    fn n_classes(&self) -> usize {
        self.classes.len()
    }
}

// Pipeline integration.
impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
    for GradientBoostingClassifier<F>
{
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
        let fitted = self.fit(x, &y_usize)?;
        Ok(Box::new(FittedGbcPipelineAdapter(fitted)))
    }
}

/// Pipeline adapter for `FittedGradientBoostingClassifier<F>`.
struct FittedGbcPipelineAdapter<F: Float + Send + Sync + 'static>(
    FittedGradientBoostingClassifier<F>,
);

impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedGbcPipelineAdapter<F>
{
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        let preds = self.0.predict(x)?;
        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or_else(F::nan)))
    }
}

// ---------------------------------------------------------------------------
// Internal helpers
// ---------------------------------------------------------------------------

/// Sigmoid function: 1 / (1 + exp(-x)).
fn sigmoid<F: Float>(x: F) -> F {
    F::one() / (F::one() + (-x).exp())
}

/// Compute softmax probabilities for each class across all samples.
///
/// Returns `probs[k][i]` = probability of class k for sample i.
fn softmax_matrix<F: Float>(
    f_vals: &[Array1<F>],
    n_samples: usize,
    n_classes: usize,
) -> Vec<Vec<F>> {
    let mut probs: Vec<Vec<F>> = vec![vec![F::zero(); n_samples]; n_classes];

    for i in 0..n_samples {
        // Find max for numerical stability.
        let max_val = (0..n_classes)
            .map(|k| f_vals[k][i])
            .fold(F::neg_infinity(), |a, b| if b > a { b } else { a });

        let mut sum = F::zero();
        let mut exps = vec![F::zero(); n_classes];
        for k in 0..n_classes {
            exps[k] = (f_vals[k][i] - max_val).exp();
            sum = sum + exps[k];
        }

        let eps = F::from(1e-15).unwrap();
        if sum < eps {
            sum = eps;
        }

        for k in 0..n_classes {
            probs[k][i] = exps[k] / sum;
        }
    }

    probs
}

/// Compute the median of an Array1.
fn median_f<F: Float>(arr: &Array1<F>) -> F {
    let mut sorted: Vec<F> = arr.iter().copied().collect();
    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
    let n = sorted.len();
    if n == 0 {
        return F::zero();
    }
    if n % 2 == 1 {
        sorted[n / 2]
    } else {
        (sorted[n / 2 - 1] + sorted[n / 2]) / F::from(2.0).unwrap()
    }
}

/// Compute the quantile of a slice at level `alpha` (0..1).
fn quantile_f<F: Float>(vals: &[F], alpha: f64) -> F {
    if vals.is_empty() {
        return F::zero();
    }
    let mut sorted: Vec<F> = vals.to_vec();
    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
    let idx = ((sorted.len() as f64 - 1.0) * alpha).round() as usize;
    let idx = idx.min(sorted.len() - 1);
    sorted[idx]
}

/// Compute pseudo-residuals (negative gradient) for regression losses.
fn compute_regression_residuals<F: Float>(
    y: &Array1<F>,
    f_vals: &Array1<F>,
    loss: RegressionLoss,
    huber_alpha: f64,
) -> Array1<F> {
    let n = y.len();
    match loss {
        RegressionLoss::LeastSquares => {
            // negative gradient of 0.5*(y - f)^2 is (y - f)
            let mut residuals = Array1::zeros(n);
            for i in 0..n {
                residuals[i] = y[i] - f_vals[i];
            }
            residuals
        }
        RegressionLoss::Lad => {
            // negative gradient of |y - f| is sign(y - f)
            let mut residuals = Array1::zeros(n);
            for i in 0..n {
                let diff = y[i] - f_vals[i];
                residuals[i] = if diff > F::zero() {
                    F::one()
                } else if diff < F::zero() {
                    -F::one()
                } else {
                    F::zero()
                };
            }
            residuals
        }
        RegressionLoss::Huber => {
            // Compute residuals and delta from quantile.
            let raw_residuals: Vec<F> = (0..n).map(|i| (y[i] - f_vals[i]).abs()).collect();
            let delta = quantile_f(&raw_residuals, huber_alpha);

            let mut residuals = Array1::zeros(n);
            for i in 0..n {
                let diff = y[i] - f_vals[i];
                if diff.abs() <= delta {
                    residuals[i] = diff;
                } else if diff > F::zero() {
                    residuals[i] = delta;
                } else {
                    residuals[i] = -delta;
                }
            }
            residuals
        }
    }
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // -- Regressor tests --

    #[test]
    fn test_gbr_simple_least_squares() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = GradientBoostingRegressor::<f64>::new()
            .with_n_estimators(50)
            .with_learning_rate(0.1)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
        for i in 0..4 {
            assert!(preds[i] < 3.0, "Expected ~1.0, got {}", preds[i]);
        }
        for i in 4..8 {
            assert!(preds[i] > 3.0, "Expected ~5.0, got {}", preds[i]);
        }
    }

    #[test]
    fn test_gbr_lad_loss() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = GradientBoostingRegressor::<f64>::new()
            .with_n_estimators(50)
            .with_loss(RegressionLoss::Lad)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
        // LAD should still separate the two groups.
        for i in 0..4 {
            assert!(preds[i] < 3.5, "LAD expected <3.5, got {}", preds[i]);
        }
        for i in 4..8 {
            assert!(preds[i] > 2.5, "LAD expected >2.5, got {}", preds[i]);
        }
    }

    #[test]
    fn test_gbr_huber_loss() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = GradientBoostingRegressor::<f64>::new()
            .with_n_estimators(50)
            .with_loss(RegressionLoss::Huber)
            .with_huber_alpha(0.9)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
    }
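
    // Direct check of the Huber pseudo-residual helper: a sketch of the
    // clipping behaviour. Residuals within the alpha-quantile delta pass
    // through unchanged; larger ones are clipped to ±delta.
    #[test]
    fn test_huber_residual_clipping() {
        let y = array![0.0, 1.0, 2.0, 3.0, 100.0];
        let f_vals = Array1::from_elem(5, 0.0);
        // |residuals| = [0, 1, 2, 3, 100]; with alpha = 0.75 the quantile
        // helper picks delta = 3 from the sorted absolute residuals.
        let r = compute_regression_residuals(&y, &f_vals, RegressionLoss::Huber, 0.75);
        assert_relative_eq!(r[1], 1.0); // inside delta: passes through
        assert_relative_eq!(r[4], 3.0); // outlier: clipped to +delta
    }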
1442
1443    #[test]
1444    fn test_gbr_reproducibility() {
1445        let x =
1446            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
1447        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];
1448
1449        let model = GradientBoostingRegressor::<f64>::new()
1450            .with_n_estimators(20)
1451            .with_random_state(123);
1452
1453        let fitted1 = model.fit(&x, &y).unwrap();
1454        let fitted2 = model.fit(&x, &y).unwrap();
1455
1456        let preds1 = fitted1.predict(&x).unwrap();
1457        let preds2 = fitted2.predict(&x).unwrap();
1458
1459        for (p1, p2) in preds1.iter().zip(preds2.iter()) {
1460            assert_relative_eq!(*p1, *p2, epsilon = 1e-10);
1461        }
1462    }

    #[test]
    fn test_gbr_feature_importances() {
        let x = Array2::from_shape_vec(
            (10, 3),
            vec![
                1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0, 0.0, 5.0, 0.0, 0.0, 6.0,
                0.0, 0.0, 7.0, 0.0, 0.0, 8.0, 0.0, 0.0, 9.0, 0.0, 0.0, 10.0, 0.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0, 5.0];

        let model = GradientBoostingRegressor::<f64>::new()
            .with_n_estimators(20)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 3);
        // First feature should be most important.
        assert!(importances[0] > importances[1]);
        assert!(importances[0] > importances[2]);
    }

    #[test]
    fn test_gbr_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0];

        let model = GradientBoostingRegressor::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_gbr_shape_mismatch_predict() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = GradientBoostingRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(0);
        let fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_gbr_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);

        let model = GradientBoostingRegressor::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_gbr_zero_estimators() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = GradientBoostingRegressor::<f64>::new().with_n_estimators(0);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_gbr_invalid_learning_rate() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = GradientBoostingRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_learning_rate(0.0);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_gbr_invalid_subsample() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = GradientBoostingRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_subsample(0.0);
        assert!(model.fit(&x, &y).is_err());

        let model2 = GradientBoostingRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_subsample(1.5);
        assert!(model2.fit(&x, &y).is_err());
    }
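
    // A minimal boundary sketch added here, assuming the valid subsample
    // range is (0.0, 1.0] as the rejections of 0.0 and 1.5 above suggest:
    // fitting with a subsample of exactly 1.0 should succeed.
    #[test]
    fn test_gbr_subsample_upper_bound() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = GradientBoostingRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_subsample(1.0);
        assert!(model.fit(&x, &y).is_ok());
    }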

    #[test]
    fn test_gbr_subsample() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = GradientBoostingRegressor::<f64>::new()
            .with_n_estimators(50)
            .with_subsample(0.5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
    }

    #[test]
    fn test_gbr_pipeline_integration() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = GradientBoostingRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_gbr_f32_support() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);

        let model = GradientBoostingRegressor::<f32>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_gbr_max_depth() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = GradientBoostingRegressor::<f64>::new()
            .with_n_estimators(20)
            .with_max_depth(Some(1))
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 8);
    }

    #[test]
    fn test_gbr_default_trait() {
        let model = GradientBoostingRegressor::<f64>::default();
        assert_eq!(model.n_estimators, 100);
        assert!((model.learning_rate - 0.1).abs() < 1e-10);
    }

    // -- Classifier tests --

    #[test]
    fn test_gbc_binary_simple() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = GradientBoostingClassifier::<f64>::new()
            .with_n_estimators(50)
            .with_learning_rate(0.1)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
        for i in 0..4 {
            assert_eq!(preds[i], 0, "Expected 0 at index {}, got {}", i, preds[i]);
        }
        for i in 4..8 {
            assert_eq!(preds[i], 1, "Expected 1 at index {}, got {}", i, preds[i]);
        }
    }
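
    // The exact per-sample assertions above (rather than an accuracy
    // threshold) are intentional: the two clusters are well separated,
    // so 50 rounds at learning rate 0.1 should be ample to fit the
    // training labels exactly.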

    #[test]
    fn test_gbc_multiclass() {
        let x = Array2::from_shape_vec((9, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
            .unwrap();
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];

        let model = GradientBoostingClassifier::<f64>::new()
            .with_n_estimators(50)
            .with_learning_rate(0.1)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 9);
        // At minimum, the model should classify most of its own training data correctly.
        let correct = preds.iter().zip(y.iter()).filter(|(p, t)| p == t).count();
        assert!(
            correct >= 6,
            "Expected at least 6/9 correct, got {correct}/9"
        );
    }

    #[test]
    fn test_gbc_has_classes() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 1, 2, 0, 1, 2];

        let model = GradientBoostingClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(0);
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.classes(), &[0, 1, 2]);
        assert_eq!(fitted.n_classes(), 3);
    }

    #[test]
    fn test_gbc_reproducibility() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = GradientBoostingClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);

        let fitted1 = model.fit(&x, &y).unwrap();
        let fitted2 = model.fit(&x, &y).unwrap();

        let preds1 = fitted1.predict(&x).unwrap();
        let preds2 = fitted2.predict(&x).unwrap();
        assert_eq!(preds1, preds2);
    }

    #[test]
    fn test_gbc_feature_importances() {
        let x = Array2::from_shape_vec(
            (10, 3),
            vec![
                1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0, 0.0, 5.0, 0.0, 0.0, 6.0,
                0.0, 0.0, 7.0, 0.0, 0.0, 8.0, 0.0, 0.0, 9.0, 0.0, 0.0, 10.0, 0.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];

        let model = GradientBoostingClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 3);
        assert!(importances[0] > importances[1]);
        assert!(importances[0] > importances[2]);
    }

    #[test]
    fn test_gbc_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 1];

        let model = GradientBoostingClassifier::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_gbc_shape_mismatch_predict() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = GradientBoostingClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(0);
        let fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_gbc_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);

        let model = GradientBoostingClassifier::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_gbc_single_class() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0];

        let model = GradientBoostingClassifier::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_gbc_zero_estimators() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = GradientBoostingClassifier::<f64>::new().with_n_estimators(0);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_gbc_pipeline_integration() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);

        let model = GradientBoostingClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_gbc_f32_support() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = GradientBoostingClassifier::<f32>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_gbc_subsample() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = GradientBoostingClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_subsample(0.5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 8);
    }

    #[test]
    fn test_gbc_default_trait() {
        let model = GradientBoostingClassifier::<f64>::default();
        assert_eq!(model.n_estimators, 100);
        assert!((model.learning_rate - 0.1).abs() < 1e-10);
    }

    #[test]
    fn test_gbc_non_contiguous_labels() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![10, 10, 10, 20, 20, 20];

        let model = GradientBoostingClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 6);
        for &p in &preds {
            assert!(p == 10 || p == 20);
        }
    }
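
    // The test above pins down label handling: predictions come back in
    // the caller's label space (10/20), presumably via an internal mapping
    // from arbitrary integer labels to contiguous class indices and back.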

    // -- Helper tests --

    #[test]
    fn test_sigmoid() {
        assert_relative_eq!(sigmoid(0.0f64), 0.5, epsilon = 1e-10);
        assert!(sigmoid(10.0f64) > 0.999);
        assert!(sigmoid(-10.0f64) < 0.001);
    }
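
    // A complementary-identity sketch added here: mathematically
    // sigmoid(x) + sigmoid(-x) = 1, and any standard implementation
    // should satisfy this to well within floating-point tolerance.
    #[test]
    fn test_sigmoid_complement() {
        for &x in &[0.5f64, 2.0, 7.0] {
            assert_relative_eq!(sigmoid(x) + sigmoid(-x), 1.0, epsilon = 1e-10);
        }
    }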

    #[test]
    fn test_median_f_odd() {
        let arr = array![3.0, 1.0, 2.0];
        assert_relative_eq!(median_f(&arr), 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_median_f_even() {
        let arr = array![4.0, 1.0, 3.0, 2.0];
        assert_relative_eq!(median_f(&arr), 2.5, epsilon = 1e-10);
    }

    #[test]
    fn test_median_f_empty() {
        let arr = Array1::<f64>::zeros(0);
        assert_relative_eq!(median_f(&arr), 0.0, epsilon = 1e-10);
    }
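
    // A single-element sketch added here: the median of a one-element
    // array should be that element under any reasonable definition.
    #[test]
    fn test_median_f_single() {
        let arr = array![7.0];
        assert_relative_eq!(median_f(&arr), 7.0, epsilon = 1e-10);
    }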

    #[test]
    fn test_quantile_f() {
        let vals = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let q90 = quantile_f(&vals, 0.9);
        assert!((4.0..=5.0).contains(&q90));
    }
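
    // A boundary sketch added here, assuming `quantile_f` uses the
    // nearest-rank index round((n - 1) * q) worked through in the Huber
    // test below; at q = 0.0 and q = 1.0 both nearest-rank and linear
    // interpolation reduce to the minimum and maximum.
    #[test]
    fn test_quantile_f_boundaries() {
        let vals = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        assert_relative_eq!(quantile_f(&vals, 0.0), 1.0, epsilon = 1e-10);
        assert_relative_eq!(quantile_f(&vals, 1.0), 5.0, epsilon = 1e-10);
    }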

    #[test]
    fn test_regression_residuals_least_squares() {
        let y = array![1.0, 2.0, 3.0];
        let f = array![0.5, 2.5, 2.0];
        let r = compute_regression_residuals(&y, &f, RegressionLoss::LeastSquares, 0.9);
        assert_relative_eq!(r[0], 0.5, epsilon = 1e-10);
        assert_relative_eq!(r[1], -0.5, epsilon = 1e-10);
        assert_relative_eq!(r[2], 1.0, epsilon = 1e-10);
    }
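
    // For least squares the negative gradient of 0.5 * (y - F)^2 with
    // respect to F is exactly y - F, so the pseudo-residuals above are
    // just the raw residuals.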

    #[test]
    fn test_regression_residuals_lad() {
        let y = array![1.0, 2.0, 3.0];
        let f = array![0.5, 2.5, 3.0];
        let r = compute_regression_residuals(&y, &f, RegressionLoss::Lad, 0.9);
        assert_relative_eq!(r[0], 1.0, epsilon = 1e-10);
        assert_relative_eq!(r[1], -1.0, epsilon = 1e-10);
        assert_relative_eq!(r[2], 0.0, epsilon = 1e-10);
    }
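
    // For LAD the negative gradient of |y - F| is sign(y - F); at y == F
    // the subgradient is ambiguous, and the zero expected above shows this
    // implementation takes it to be 0 there.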

    #[test]
    fn test_regression_residuals_huber() {
        let y = array![1.0, 2.0, 10.0, 3.0, 4.0];
        let f = array![1.5, 2.5, 2.0, 3.5, 4.5];
        // abs residuals: [0.5, 0.5, 8.0, 0.5, 0.5]
        // alpha=0.9 quantile index = round(4 * 0.9) = 4 => sorted[4] = 8.0
        // So delta = 8.0, meaning all residuals are within delta and treated as L2.
        let r = compute_regression_residuals(&y, &f, RegressionLoss::Huber, 0.9);
        // All residuals should be y - f.
        assert_relative_eq!(r[0], -0.5, epsilon = 1e-10);
        assert_relative_eq!(r[1], -0.5, epsilon = 1e-10);
        assert_relative_eq!(r[2], 8.0, epsilon = 1e-10);
        assert_relative_eq!(r[3], -0.5, epsilon = 1e-10);
        assert_relative_eq!(r[4], -0.5, epsilon = 1e-10);

        // Test with lower alpha to trigger clipping.
        // alpha=0.1, quantile idx = round(4*0.1) = 0 => sorted[0] = 0.5
        // delta = 0.5, so the 8.0 residual is clipped.
        let r2 = compute_regression_residuals(&y, &f, RegressionLoss::Huber, 0.1);
        assert_relative_eq!(r2[0], -0.5, epsilon = 1e-10);
        // Third residual: diff=8.0 > delta=0.5, so clipped to delta=0.5.
        assert_relative_eq!(r2[2], 0.5, epsilon = 1e-10);
    }
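
    // Taken together, the two cases above pin down the Huber pseudo-residual:
    // r = y - F when |y - F| <= delta, and delta * sign(y - F) otherwise,
    // where delta is the alpha-quantile of the absolute residuals.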

    #[test]
    fn test_gbc_multiclass_4_classes() {
        let x = Array2::from_shape_vec(
            (12, 1),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3];

        let model = GradientBoostingClassifier::<f64>::new()
            .with_n_estimators(50)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 12);
        assert_eq!(fitted.n_classes(), 4);
    }

    #[test]
    fn test_gbc_invalid_learning_rate() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = GradientBoostingClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_learning_rate(-0.1);
        assert!(model.fit(&x, &y).is_err());
    }
}