Skip to main content

ferrolearn_tree/
adaboost.rs

1//! AdaBoost classifier.
2//!
3//! This module provides [`AdaBoostClassifier`], which implements the Adaptive
4//! Boosting algorithm using decision tree stumps (depth-1 trees) as base
5//! estimators. Two algorithm variants are supported:
6//!
//! - **SAMME** (default): uses discrete class predictions and works with multiclass
//!   problems directly.
//! - **SAMME.R**: uses class probability estimates, typically giving
//!   better performance than SAMME.
11//!
12//! # Examples
13//!
14//! ```
15//! use ferrolearn_tree::AdaBoostClassifier;
16//! use ferrolearn_core::{Fit, Predict};
17//! use ndarray::{array, Array1, Array2};
18//!
19//! let x = Array2::from_shape_vec((8, 2), vec![
20//!     1.0, 2.0,  2.0, 3.0,  3.0, 3.0,  4.0, 4.0,
21//!     5.0, 6.0,  6.0, 7.0,  7.0, 8.0,  8.0, 9.0,
22//! ]).unwrap();
23//! let y = array![0, 0, 0, 0, 1, 1, 1, 1];
24//!
25//! let model = AdaBoostClassifier::<f64>::new()
26//!     .with_n_estimators(50)
27//!     .with_random_state(42);
28//! let fitted = model.fit(&x, &y).unwrap();
29//! let preds = fitted.predict(&x).unwrap();
30//! ```
31
32use crate::decision_tree::{
33    self, ClassificationCriterion, Node, build_classification_tree_with_feature_subset,
34};
35use ferrolearn_core::error::FerroError;
36use ferrolearn_core::introspection::{HasClasses, HasFeatureImportances};
37use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
38use ferrolearn_core::traits::{Fit, Predict};
39use ndarray::{Array1, Array2};
40use num_traits::{Float, FromPrimitive, ToPrimitive};
41
42// ---------------------------------------------------------------------------
43// Algorithm enum
44// ---------------------------------------------------------------------------
45
/// AdaBoost algorithm variant.
///
/// The `Default` variant is [`AdaBoostAlgorithm::Samme`], the same algorithm
/// that `AdaBoostClassifier::new` selects.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum AdaBoostAlgorithm {
    /// SAMME: Stagewise Additive Modeling using a Multi-class Exponential loss.
    ///
    /// Uses discrete class predictions from each base estimator.
    #[default]
    Samme,
    /// SAMME.R: the "real" variant that uses class probability estimates.
    ///
    /// Generally outperforms SAMME.
    SammeR,
}
58
59// ---------------------------------------------------------------------------
60// AdaBoostClassifier
61// ---------------------------------------------------------------------------
62
63/// AdaBoost classifier using decision tree stumps as base estimators.
64///
65/// At each boosting round a decision tree stump (max depth = 1) is fitted
66/// to the weighted training data. Misclassified samples receive higher
67/// weight in subsequent rounds, allowing the ensemble to focus on hard
68/// examples.
69///
70/// # Type Parameters
71///
72/// - `F`: The floating-point type (`f32` or `f64`).
73#[derive(Debug, Clone)]
74pub struct AdaBoostClassifier<F> {
75    /// Number of boosting stages (stumps).
76    pub n_estimators: usize,
77    /// Learning rate (shrinkage). Lower values require more estimators.
78    pub learning_rate: f64,
79    /// Algorithm variant (`SAMME` or `SAMME.R`).
80    pub algorithm: AdaBoostAlgorithm,
81    /// Random seed for reproducibility.
82    pub random_state: Option<u64>,
83    _marker: std::marker::PhantomData<F>,
84}
85
86impl<F: Float> AdaBoostClassifier<F> {
87    /// Create a new `AdaBoostClassifier` with default settings.
88    ///
89    /// Defaults: `n_estimators = 50`, `learning_rate = 1.0`,
90    /// `algorithm = SAMME`, `random_state = None`.
91    ///
92    /// The default algorithm is `SAMME` to match scikit-learn ≥ 1.4,
93    /// which removed `SAMME.R` in 1.6 and made `SAMME` the only option.
94    #[must_use]
95    pub fn new() -> Self {
96        Self {
97            n_estimators: 50,
98            learning_rate: 1.0,
99            algorithm: AdaBoostAlgorithm::Samme,
100            random_state: None,
101            _marker: std::marker::PhantomData,
102        }
103    }
104
105    /// Set the number of boosting stages.
106    #[must_use]
107    pub fn with_n_estimators(mut self, n: usize) -> Self {
108        self.n_estimators = n;
109        self
110    }
111
112    /// Set the learning rate.
113    #[must_use]
114    pub fn with_learning_rate(mut self, lr: f64) -> Self {
115        self.learning_rate = lr;
116        self
117    }
118
119    /// Set the algorithm variant.
120    #[must_use]
121    pub fn with_algorithm(mut self, algo: AdaBoostAlgorithm) -> Self {
122        self.algorithm = algo;
123        self
124    }
125
126    /// Set the random seed for reproducibility.
127    #[must_use]
128    pub fn with_random_state(mut self, seed: u64) -> Self {
129        self.random_state = Some(seed);
130        self
131    }
132}
133
134impl<F: Float> Default for AdaBoostClassifier<F> {
135    fn default() -> Self {
136        Self::new()
137    }
138}
139
140// ---------------------------------------------------------------------------
141// FittedAdaBoostClassifier
142// ---------------------------------------------------------------------------
143
144/// A fitted AdaBoost classifier.
145///
146/// Stores the sequence of stumps and their weights. Predictions are made
147/// by weighted majority vote (SAMME) or weighted probability averaging
148/// (SAMME.R).
149#[derive(Debug, Clone)]
150pub struct FittedAdaBoostClassifier<F> {
151    /// Sorted unique class labels.
152    classes: Vec<usize>,
153    /// Sequence of fitted tree stumps.
154    estimators: Vec<Vec<Node<F>>>,
155    /// Weight of each estimator (SAMME) or kept for SAMME.R bookkeeping.
156    estimator_weights: Vec<F>,
157    /// Number of features.
158    n_features: usize,
159    /// Number of classes.
160    n_classes: usize,
161    /// Algorithm used.
162    algorithm: AdaBoostAlgorithm,
163    /// Per-feature importance scores aggregated across the boosted stumps,
164    /// weighted by `estimator_weights` (normalized to sum to 1).
165    feature_importances: Array1<F>,
166}
167
168impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedAdaBoostClassifier<F> {
169    fn feature_importances(&self) -> &Array1<F> {
170        &self.feature_importances
171    }
172}
173
174impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for AdaBoostClassifier<F> {
175    type Fitted = FittedAdaBoostClassifier<F>;
176    type Error = FerroError;
177
178    /// Fit the AdaBoost classifier.
179    ///
180    /// # Errors
181    ///
182    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
183    /// numbers of samples.
184    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
185    /// Returns [`FerroError::InvalidParameter`] for invalid hyperparameters.
186    fn fit(
187        &self,
188        x: &Array2<F>,
189        y: &Array1<usize>,
190    ) -> Result<FittedAdaBoostClassifier<F>, FerroError> {
191        let (n_samples, n_features) = x.dim();
192
193        if n_samples != y.len() {
194            return Err(FerroError::ShapeMismatch {
195                expected: vec![n_samples],
196                actual: vec![y.len()],
197                context: "y length must match number of samples in X".into(),
198            });
199        }
200        if n_samples == 0 {
201            return Err(FerroError::InsufficientSamples {
202                required: 1,
203                actual: 0,
204                context: "AdaBoostClassifier requires at least one sample".into(),
205            });
206        }
207        if self.n_estimators == 0 {
208            return Err(FerroError::InvalidParameter {
209                name: "n_estimators".into(),
210                reason: "must be at least 1".into(),
211            });
212        }
213        if self.learning_rate <= 0.0 {
214            return Err(FerroError::InvalidParameter {
215                name: "learning_rate".into(),
216                reason: "must be positive".into(),
217            });
218        }
219
220        // Determine unique classes.
221        let mut classes: Vec<usize> = y.iter().copied().collect();
222        classes.sort_unstable();
223        classes.dedup();
224        let n_classes = classes.len();
225
226        if n_classes < 2 {
227            return Err(FerroError::InvalidParameter {
228                name: "y".into(),
229                reason: "need at least 2 distinct classes".into(),
230            });
231        }
232
233        let y_mapped: Vec<usize> = y
234            .iter()
235            .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
236            .collect();
237
238        match self.algorithm {
239            AdaBoostAlgorithm::Samme => {
240                self.fit_samme(x, &y_mapped, n_samples, n_features, n_classes, &classes)
241            }
242            AdaBoostAlgorithm::SammeR => {
243                self.fit_samme_r(x, &y_mapped, n_samples, n_features, n_classes, &classes)
244            }
245        }
246    }
247}
248
249impl<F: Float + Send + Sync + 'static> AdaBoostClassifier<F> {
250    /// Fit using the SAMME algorithm (discrete predictions).
251    fn fit_samme(
252        &self,
253        x: &Array2<F>,
254        y_mapped: &[usize],
255        n_samples: usize,
256        n_features: usize,
257        n_classes: usize,
258        classes: &[usize],
259    ) -> Result<FittedAdaBoostClassifier<F>, FerroError> {
260        let lr = F::from(self.learning_rate).unwrap();
261        let n_f = F::from(n_samples).unwrap();
262        let eps = F::from(1e-10).unwrap();
263
264        // Initialize sample weights uniformly.
265        let mut weights = vec![F::one() / n_f; n_samples];
266
267        let all_features: Vec<usize> = (0..n_features).collect();
268        let stump_params = decision_tree::TreeParams {
269            max_depth: Some(1),
270            min_samples_split: 2,
271            min_samples_leaf: 1,
272        };
273
274        let mut estimators = Vec::with_capacity(self.n_estimators);
275        let mut estimator_weights = Vec::with_capacity(self.n_estimators);
276
277        for _ in 0..self.n_estimators {
278            // Build weighted sample indices: replicate indices proportional to weight.
279            let indices = resample_weighted(&weights, n_samples);
280
281            let tree = build_classification_tree_with_feature_subset(
282                x,
283                y_mapped,
284                n_classes,
285                &indices,
286                &all_features,
287                &stump_params,
288                ClassificationCriterion::Gini,
289            );
290
291            // Compute predictions and weighted error.
292            let mut weighted_error = F::zero();
293            let mut preds = vec![0usize; n_samples];
294            for i in 0..n_samples {
295                let row = x.row(i);
296                let leaf_idx = decision_tree::traverse(&tree, &row);
297                if let Node::Leaf { value, .. } = tree[leaf_idx] {
298                    preds[i] = value.to_f64().map_or(0, |f| f.round() as usize);
299                }
300                if preds[i] != y_mapped[i] {
301                    weighted_error = weighted_error + weights[i];
302                }
303            }
304
305            // Normalise error.
306            let weight_sum: F = weights.iter().copied().fold(F::zero(), |a, b| a + b);
307            let err = if weight_sum > F::zero() {
308                weighted_error / weight_sum
309            } else {
310                F::from(0.5).unwrap()
311            };
312
313            // If error is too high or zero, stop or skip.
314            if err >= F::one() - F::one() / F::from(n_classes).unwrap() {
315                // Error too high; stop boosting.
316                if estimators.is_empty() {
317                    // Keep at least one estimator.
318                    estimators.push(tree);
319                    estimator_weights.push(F::one());
320                }
321                break;
322            }
323
324            // Estimator weight: SAMME formula.
325            let alpha = lr * ((F::one() - err).max(eps) / err.max(eps)).ln()
326                + lr * (F::from(n_classes - 1).unwrap()).ln();
327
328            // Update sample weights.
329            for i in 0..n_samples {
330                if preds[i] != y_mapped[i] {
331                    weights[i] = weights[i] * alpha.exp();
332                }
333            }
334
335            // Normalise weights.
336            let new_sum: F = weights.iter().copied().fold(F::zero(), |a, b| a + b);
337            if new_sum > F::zero() {
338                for w in &mut weights {
339                    *w = *w / new_sum;
340                }
341            }
342
343            estimators.push(tree);
344            estimator_weights.push(alpha);
345        }
346
347        let feature_importances = decision_tree::aggregate_tree_importances(
348            &estimators,
349            None,
350            Some(&estimator_weights),
351            n_features,
352        );
353
354        Ok(FittedAdaBoostClassifier {
355            classes: classes.to_vec(),
356            estimators,
357            estimator_weights,
358            n_features,
359            n_classes,
360            algorithm: AdaBoostAlgorithm::Samme,
361            feature_importances,
362        })
363    }
364
365    /// Fit using the SAMME.R algorithm (real-valued / probability-based).
366    fn fit_samme_r(
367        &self,
368        x: &Array2<F>,
369        y_mapped: &[usize],
370        n_samples: usize,
371        n_features: usize,
372        n_classes: usize,
373        classes: &[usize],
374    ) -> Result<FittedAdaBoostClassifier<F>, FerroError> {
375        let lr = F::from(self.learning_rate).unwrap();
376        let n_f = F::from(n_samples).unwrap();
377        let eps = F::from(1e-10).unwrap();
378        let k_f = F::from(n_classes).unwrap();
379
380        // Initialize sample weights uniformly.
381        let mut weights = vec![F::one() / n_f; n_samples];
382
383        let all_features: Vec<usize> = (0..n_features).collect();
384        let stump_params = decision_tree::TreeParams {
385            max_depth: Some(1),
386            min_samples_split: 2,
387            min_samples_leaf: 1,
388        };
389
390        let mut estimators = Vec::with_capacity(self.n_estimators);
391        let mut estimator_weights = Vec::with_capacity(self.n_estimators);
392
393        for _ in 0..self.n_estimators {
394            let indices = resample_weighted(&weights, n_samples);
395
396            let tree = build_classification_tree_with_feature_subset(
397                x,
398                y_mapped,
399                n_classes,
400                &indices,
401                &all_features,
402                &stump_params,
403                ClassificationCriterion::Gini,
404            );
405
406            // Get class probability estimates for each sample.
407            let mut proba = vec![vec![F::zero(); n_classes]; n_samples];
408            for (i, proba_row) in proba.iter_mut().enumerate() {
409                let row = x.row(i);
410                let leaf_idx = decision_tree::traverse(&tree, &row);
411                if let Node::Leaf {
412                    class_distribution: Some(ref dist),
413                    ..
414                } = tree[leaf_idx]
415                {
416                    for (k, &p) in dist.iter().enumerate() {
417                        proba_row[k] = p.max(eps);
418                    }
419                } else {
420                    // Fallback: uniform.
421                    for val in proba_row.iter_mut() {
422                        *val = F::one() / k_f;
423                    }
424                }
425                // Normalise.
426                let row_sum: F = proba_row.iter().copied().fold(F::zero(), |a, b| a + b);
427                if row_sum > F::zero() {
428                    for val in proba_row.iter_mut() {
429                        *val = *val / row_sum;
430                    }
431                }
432            }
433
434            // SAMME.R weight update: based on log-probability.
435            // h_k(x) = (K-1) * (log(p_k(x)) - (1/K) * sum_j log(p_j(x)))
436            // Then update: w_i *= exp(-(K-1)/K * lr * sum_k y_{ik} * log(p_k(x)))
437            // Simplified: w_i *= exp(-lr * (K-1)/K * log(p_{y_i}(x)))
438            let factor = lr * (k_f - F::one()) / k_f;
439            let mut any_update = false;
440
441            for i in 0..n_samples {
442                let p_correct = proba[i][y_mapped[i]].max(eps);
443                let exponent = -factor * p_correct.ln();
444                weights[i] = weights[i] * exponent.exp();
445                if exponent.abs() > eps {
446                    any_update = true;
447                }
448            }
449
450            // Normalise weights.
451            let new_sum: F = weights.iter().copied().fold(F::zero(), |a, b| a + b);
452            if new_sum > F::zero() {
453                for w in &mut weights {
454                    *w = *w / new_sum;
455                }
456            }
457
458            estimators.push(tree);
459            estimator_weights.push(F::one()); // SAMME.R uses equal weight; prediction uses probabilities.
460
461            if !any_update {
462                break;
463            }
464        }
465
466        let feature_importances = decision_tree::aggregate_tree_importances(
467            &estimators,
468            None,
469            Some(&estimator_weights),
470            n_features,
471        );
472
473        Ok(FittedAdaBoostClassifier {
474            classes: classes.to_vec(),
475            estimators,
476            estimator_weights,
477            n_features,
478            n_classes,
479            algorithm: AdaBoostAlgorithm::SammeR,
480            feature_importances,
481        })
482    }
483}
484
485impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedAdaBoostClassifier<F> {
486    type Output = Array1<usize>;
487    type Error = FerroError;
488
489    /// Predict class labels.
490    ///
491    /// - **SAMME**: weighted majority vote using estimator weights.
492    /// - **SAMME.R**: weighted average of log-probabilities.
493    ///
494    /// # Errors
495    ///
496    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
497    /// not match the fitted model.
498    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
499        if x.ncols() != self.n_features {
500            return Err(FerroError::ShapeMismatch {
501                expected: vec![self.n_features],
502                actual: vec![x.ncols()],
503                context: "number of features must match fitted model".into(),
504            });
505        }
506
507        let n_samples = x.nrows();
508
509        match self.algorithm {
510            AdaBoostAlgorithm::Samme => self.predict_samme(x, n_samples),
511            AdaBoostAlgorithm::SammeR => self.predict_samme_r(x, n_samples),
512        }
513    }
514}
515
516impl<F: Float + Send + Sync + 'static> FittedAdaBoostClassifier<F> {
517    /// Predict using SAMME (weighted majority vote).
518    fn predict_samme(&self, x: &Array2<F>, n_samples: usize) -> Result<Array1<usize>, FerroError> {
519        let mut predictions = Array1::zeros(n_samples);
520
521        for i in 0..n_samples {
522            let row = x.row(i);
523            let mut class_scores = vec![F::zero(); self.n_classes];
524
525            for (t, tree_nodes) in self.estimators.iter().enumerate() {
526                let leaf_idx = decision_tree::traverse(tree_nodes, &row);
527                if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
528                    let class_idx = value.to_f64().map_or(0, |f| f.round() as usize);
529                    if class_idx < self.n_classes {
530                        class_scores[class_idx] =
531                            class_scores[class_idx] + self.estimator_weights[t];
532                    }
533                }
534            }
535
536            let best = class_scores
537                .iter()
538                .enumerate()
539                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
540                .map_or(0, |(k, _)| k);
541            predictions[i] = self.classes[best];
542        }
543
544        Ok(predictions)
545    }
546
547    /// Predict using SAMME.R (weighted probability averaging).
548    fn predict_samme_r(
549        &self,
550        x: &Array2<F>,
551        n_samples: usize,
552    ) -> Result<Array1<usize>, FerroError> {
553        let eps = F::from(1e-10).unwrap();
554        let k_f = F::from(self.n_classes).unwrap();
555        let k_minus_1 = k_f - F::one();
556
557        let mut predictions = Array1::zeros(n_samples);
558
559        for i in 0..n_samples {
560            let row = x.row(i);
561            let mut accumulated = vec![F::zero(); self.n_classes];
562
563            for tree_nodes in &self.estimators {
564                let leaf_idx = decision_tree::traverse(tree_nodes, &row);
565                if let Node::Leaf {
566                    class_distribution: Some(ref dist),
567                    ..
568                } = tree_nodes[leaf_idx]
569                {
570                    // h_k(x) = (K-1) * (log(p_k) - mean(log(p_j)))
571                    let log_probs: Vec<F> = dist.iter().map(|&p| p.max(eps).ln()).collect();
572                    let mean_log: F = log_probs.iter().copied().fold(F::zero(), |a, b| a + b) / k_f;
573
574                    for k in 0..self.n_classes {
575                        accumulated[k] = accumulated[k] + k_minus_1 * (log_probs[k] - mean_log);
576                    }
577                } else {
578                    // Leaf without distribution: predict from value.
579                    if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
580                        let class_idx = value.to_f64().map_or(0, |f| f.round() as usize);
581                        if class_idx < self.n_classes {
582                            accumulated[class_idx] = accumulated[class_idx] + F::one();
583                        }
584                    }
585                }
586            }
587
588            let best = accumulated
589                .iter()
590                .enumerate()
591                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
592                .map_or(0, |(k, _)| k);
593            predictions[i] = self.classes[best];
594        }
595
596        Ok(predictions)
597    }
598
599    /// Mean accuracy on the given test data and labels.
600    /// Equivalent to sklearn's `ClassifierMixin.score`.
601    ///
602    /// # Errors
603    ///
604    /// Returns [`FerroError::ShapeMismatch`] if `x.nrows() != y.len()` or
605    /// the feature count does not match the training data.
606    pub fn score(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<F, FerroError> {
607        if x.nrows() != y.len() {
608            return Err(FerroError::ShapeMismatch {
609                expected: vec![x.nrows()],
610                actual: vec![y.len()],
611                context: "y length must match number of samples in X".into(),
612            });
613        }
614        let preds = self.predict(x)?;
615        Ok(crate::mean_accuracy(&preds, y))
616    }
617
618    /// Predict class probabilities for each sample. Mirrors sklearn's
619    /// `AdaBoostClassifier.predict_proba`.
620    ///
621    /// SAMME: normalizes the weighted-vote vector per row.
622    /// SAMME.R: applies softmax to the accumulated `(K-1)*(log p_k - mean)`
623    /// scores per row.
624    ///
625    /// Returns shape `(n_samples, n_classes)`; rows sum to 1.
626    ///
627    /// # Errors
628    ///
629    /// Returns [`FerroError::ShapeMismatch`] if the number of features
630    /// does not match the fitted model.
631    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
632        if x.ncols() != self.n_features {
633            return Err(FerroError::ShapeMismatch {
634                expected: vec![self.n_features],
635                actual: vec![x.ncols()],
636                context: "number of features must match fitted model".into(),
637            });
638        }
639        let n_samples = x.nrows();
640        let n_classes = self.n_classes;
641        let mut proba = Array2::<F>::zeros((n_samples, n_classes));
642
643        match self.algorithm {
644            AdaBoostAlgorithm::Samme => {
645                for i in 0..n_samples {
646                    let row = x.row(i);
647                    let mut scores = vec![F::zero(); n_classes];
648                    for (t, tree_nodes) in self.estimators.iter().enumerate() {
649                        let leaf_idx = decision_tree::traverse(tree_nodes, &row);
650                        if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
651                            let class_idx = value.to_f64().map_or(0, |f| f.round() as usize);
652                            if class_idx < n_classes {
653                                scores[class_idx] = scores[class_idx] + self.estimator_weights[t];
654                            }
655                        }
656                    }
657                    let total: F = scores.iter().copied().fold(F::zero(), |a, b| a + b);
658                    if total > F::zero() {
659                        for k in 0..n_classes {
660                            proba[[i, k]] = scores[k] / total;
661                        }
662                    } else {
663                        let u = F::one() / F::from(n_classes).unwrap();
664                        for k in 0..n_classes {
665                            proba[[i, k]] = u;
666                        }
667                    }
668                }
669            }
670            AdaBoostAlgorithm::SammeR => {
671                let eps = F::from(1e-10).unwrap();
672                let k_f = F::from(n_classes).unwrap();
673                let k_minus_1 = k_f - F::one();
674                for i in 0..n_samples {
675                    let row = x.row(i);
676                    let mut accumulated = vec![F::zero(); n_classes];
677                    for tree_nodes in &self.estimators {
678                        let leaf_idx = decision_tree::traverse(tree_nodes, &row);
679                        if let Node::Leaf {
680                            class_distribution: Some(ref dist),
681                            ..
682                        } = tree_nodes[leaf_idx]
683                        {
684                            let log_probs: Vec<F> =
685                                dist.iter().map(|&p| p.max(eps).ln()).collect();
686                            let mean_log: F =
687                                log_probs.iter().copied().fold(F::zero(), |a, b| a + b) / k_f;
688                            for k in 0..n_classes {
689                                accumulated[k] =
690                                    accumulated[k] + k_minus_1 * (log_probs[k] - mean_log);
691                            }
692                        } else if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
693                            let class_idx = value.to_f64().map_or(0, |f| f.round() as usize);
694                            if class_idx < n_classes {
695                                accumulated[class_idx] = accumulated[class_idx] + F::one();
696                            }
697                        }
698                    }
699                    // Softmax of accumulated.
700                    let max_score = accumulated
701                        .iter()
702                        .copied()
703                        .fold(F::neg_infinity(), |a, b| if b > a { b } else { a });
704                    let mut sum_exp = F::zero();
705                    for k in 0..n_classes {
706                        let e = (accumulated[k] - max_score).exp();
707                        proba[[i, k]] = e;
708                        sum_exp = sum_exp + e;
709                    }
710                    if sum_exp > F::zero() {
711                        for k in 0..n_classes {
712                            proba[[i, k]] = proba[[i, k]] / sum_exp;
713                        }
714                    }
715                }
716            }
717        }
718        Ok(proba)
719    }
720
721    /// Element-wise log of [`predict_proba`](Self::predict_proba). Mirrors
722    /// sklearn's `ClassifierMixin.predict_log_proba`.
723    ///
724    /// # Errors
725    ///
726    /// Forwards any error from [`predict_proba`](Self::predict_proba).
727    pub fn predict_log_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
728        let proba = self.predict_proba(x)?;
729        Ok(crate::log_proba(&proba))
730    }
731
732    /// Per-class raw scores. Mirrors sklearn's
733    /// `AdaBoostClassifier.decision_function`.
734    ///
735    /// SAMME: returns the cumulative weighted vote per class (unnormalized).
736    /// SAMME.R: returns the accumulated `(K-1)*(log p_k - mean log p)`
737    /// scores.
738    ///
739    /// Returns shape `(n_samples, n_classes)`. The argmax of each row
740    /// agrees with [`Predict::predict`].
741    ///
742    /// # Errors
743    ///
744    /// Returns [`FerroError::ShapeMismatch`] if the number of features
745    /// does not match the fitted model.
746    pub fn decision_function(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
747        if x.ncols() != self.n_features {
748            return Err(FerroError::ShapeMismatch {
749                expected: vec![self.n_features],
750                actual: vec![x.ncols()],
751                context: "number of features must match fitted model".into(),
752            });
753        }
754        let n_samples = x.nrows();
755        let n_classes = self.n_classes;
756        let mut out = Array2::<F>::zeros((n_samples, n_classes));
757
758        match self.algorithm {
759            AdaBoostAlgorithm::Samme => {
760                for i in 0..n_samples {
761                    let row = x.row(i);
762                    for (t, tree_nodes) in self.estimators.iter().enumerate() {
763                        let leaf_idx = decision_tree::traverse(tree_nodes, &row);
764                        if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
765                            let class_idx = value.to_f64().map_or(0, |f| f.round() as usize);
766                            if class_idx < n_classes {
767                                out[[i, class_idx]] =
768                                    out[[i, class_idx]] + self.estimator_weights[t];
769                            }
770                        }
771                    }
772                }
773            }
774            AdaBoostAlgorithm::SammeR => {
775                let eps = F::from(1e-10).unwrap();
776                let k_f = F::from(n_classes).unwrap();
777                let k_minus_1 = k_f - F::one();
778                for i in 0..n_samples {
779                    let row = x.row(i);
780                    for tree_nodes in &self.estimators {
781                        let leaf_idx = decision_tree::traverse(tree_nodes, &row);
782                        if let Node::Leaf {
783                            class_distribution: Some(ref dist),
784                            ..
785                        } = tree_nodes[leaf_idx]
786                        {
787                            let log_probs: Vec<F> =
788                                dist.iter().map(|&p| p.max(eps).ln()).collect();
789                            let mean_log: F =
790                                log_probs.iter().copied().fold(F::zero(), |a, b| a + b) / k_f;
791                            for k in 0..n_classes {
792                                out[[i, k]] =
793                                    out[[i, k]] + k_minus_1 * (log_probs[k] - mean_log);
794                            }
795                        } else if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
796                            let class_idx = value.to_f64().map_or(0, |f| f.round() as usize);
797                            if class_idx < n_classes {
798                                out[[i, class_idx]] = out[[i, class_idx]] + F::one();
799                            }
800                        }
801                    }
802                }
803            }
804        }
805        Ok(out)
806    }
807}
808
809impl<F: Float + Send + Sync + 'static> HasClasses for FittedAdaBoostClassifier<F> {
810    fn classes(&self) -> &[usize] {
811        &self.classes
812    }
813
814    fn n_classes(&self) -> usize {
815        self.classes.len()
816    }
817}
818
819// Pipeline integration.
820impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
821    for AdaBoostClassifier<F>
822{
823    fn fit_pipeline(
824        &self,
825        x: &Array2<F>,
826        y: &Array1<F>,
827    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
828        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
829        let fitted = self.fit(x, &y_usize)?;
830        Ok(Box::new(FittedAdaBoostPipelineAdapter(fitted)))
831    }
832}
833
834/// Pipeline adapter for `FittedAdaBoostClassifier<F>`.
835struct FittedAdaBoostPipelineAdapter<F: Float + Send + Sync + 'static>(FittedAdaBoostClassifier<F>);
836
837impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
838    for FittedAdaBoostPipelineAdapter<F>
839{
840    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
841        let preds = self.0.predict(x)?;
842        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or_else(F::nan)))
843    }
844}
845
846// ---------------------------------------------------------------------------
847// Internal helpers
848// ---------------------------------------------------------------------------
849
850/// Resample indices proportional to weights (weighted bootstrap).
851///
852/// Uses a systematic resampling approach: the cumulative weight distribution
853/// determines which original indices appear in the resampled set.
854fn resample_weighted<F: Float>(weights: &[F], n: usize) -> Vec<usize> {
855    if weights.is_empty() {
856        return Vec::new();
857    }
858
859    // Build cumulative distribution.
860    let mut cumsum = Vec::with_capacity(weights.len());
861    let mut running = F::zero();
862    for &w in weights {
863        running = running + w;
864        cumsum.push(running);
865    }
866
867    // Normalise (in case weights don't sum to 1).
868    let total = running;
869    if total <= F::zero() {
870        return (0..n).collect();
871    }
872
873    let mut indices = Vec::with_capacity(n);
874    let step = total / F::from(n).unwrap();
875    let mut threshold = step / F::from(2.0).unwrap(); // Start in the middle of the first bin.
876    let mut j = 0;
877
878    for _ in 0..n {
879        while j < cumsum.len() - 1 && cumsum[j] < threshold {
880            j += 1;
881        }
882        indices.push(j);
883        threshold = threshold + step;
884    }
885
886    indices
887}
888
889// ---------------------------------------------------------------------------
890// Tests
891// ---------------------------------------------------------------------------
892
#[cfg(test)]
mod tests {
    //! Unit tests for the AdaBoost classifier and its resampling helper.

    use super::*;
    use ndarray::array;

    // -- Shared fixtures --

    /// Eight two-feature samples; the first four are labelled 0, the rest 1.
    fn binary_2d_fixture() -> (Array2<f64>, Array1<usize>) {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        (x, array![0, 0, 0, 0, 1, 1, 1, 1])
    }

    /// Nine single-feature samples split evenly across labels 0, 1 and 2.
    fn three_class_1d_fixture() -> (Array2<f64>, Array1<usize>) {
        let x = Array2::from_shape_vec((9, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
            .unwrap();
        (x, array![0, 0, 0, 1, 1, 1, 2, 2, 2])
    }

    /// Counts positions where prediction and target agree.
    fn count_correct(preds: &Array1<usize>, y: &Array1<usize>) -> usize {
        preds.iter().zip(y.iter()).filter(|(p, t)| p == t).count()
    }

    // -- SAMME.R tests --

    #[test]
    fn test_adaboost_sammer_binary_simple() {
        let (x, y) = binary_2d_fixture();

        let model = AdaBoostClassifier::<f64>::new()
            .with_n_estimators(50)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
        // Every training sample must be recovered exactly.
        assert_eq!(preds, y);
    }

    #[test]
    fn test_adaboost_sammer_multiclass() {
        let (x, y) = three_class_1d_fixture();

        let model = AdaBoostClassifier::<f64>::new()
            .with_n_estimators(50)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 9);
        let correct = count_correct(&preds, &y);
        assert!(
            correct >= 5,
            "Expected at least 5/9 correct, got {correct}/9"
        );
    }

    // -- SAMME tests --

    #[test]
    fn test_adaboost_samme_binary_simple() {
        let (x, y) = binary_2d_fixture();

        let model = AdaBoostClassifier::<f64>::new()
            .with_n_estimators(50)
            .with_algorithm(AdaBoostAlgorithm::Samme)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
        // Every training sample must be recovered exactly.
        assert_eq!(preds, y);
    }

    #[test]
    fn test_adaboost_samme_multiclass() {
        let (x, y) = three_class_1d_fixture();

        let model = AdaBoostClassifier::<f64>::new()
            .with_n_estimators(50)
            .with_algorithm(AdaBoostAlgorithm::Samme)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 9);
        let correct = count_correct(&preds, &y);
        assert!(
            correct >= 5,
            "Expected at least 5/9 correct for SAMME multiclass, got {correct}/9"
        );
    }

    // -- Common tests --

    #[test]
    fn test_adaboost_has_classes() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 1, 2, 0, 1, 2];

        let model = AdaBoostClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(0);
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.classes(), &[0, 1, 2]);
        assert_eq!(fitted.n_classes(), 3);
    }

    #[test]
    fn test_adaboost_reproducibility() {
        let (x, y) = binary_2d_fixture();

        let model = AdaBoostClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);

        // Fitting twice with the same seed must give identical predictions.
        let fitted1 = model.fit(&x, &y).unwrap();
        let fitted2 = model.fit(&x, &y).unwrap();

        let preds1 = fitted1.predict(&x).unwrap();
        let preds2 = fitted2.predict(&x).unwrap();
        assert_eq!(preds1, preds2);
    }

    #[test]
    fn test_adaboost_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 1];

        let model = AdaBoostClassifier::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_adaboost_shape_mismatch_predict() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = AdaBoostClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(0);
        let fitted = model.fit(&x, &y).unwrap();

        // Three features at predict time versus two at fit time.
        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_adaboost_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);

        let model = AdaBoostClassifier::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_adaboost_single_class() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0];

        let model = AdaBoostClassifier::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_adaboost_zero_estimators() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = AdaBoostClassifier::<f64>::new().with_n_estimators(0);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_adaboost_invalid_learning_rate() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = AdaBoostClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_learning_rate(0.0);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_adaboost_pipeline_integration() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);

        let model = AdaBoostClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_adaboost_f32_support() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = AdaBoostClassifier::<f32>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_adaboost_default_trait() {
        let model = AdaBoostClassifier::<f64>::default();
        assert_eq!(model.n_estimators, 50);
        assert!((model.learning_rate - 1.0).abs() < 1e-10);
        // The module documentation names SAMME.R as the default variant.
        assert_eq!(model.algorithm, AdaBoostAlgorithm::SammeR);
    }

    #[test]
    fn test_adaboost_non_contiguous_labels() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![10, 10, 10, 20, 20, 20];

        let model = AdaBoostClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 6);
        // Predictions must come from the original label set.
        for &p in &preds {
            assert!(p == 10 || p == 20);
        }
    }

    #[test]
    fn test_adaboost_sammer_learning_rate_effect() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        // Low learning rate should still work (just slower convergence).
        let model = AdaBoostClassifier::<f64>::new()
            .with_n_estimators(50)
            .with_learning_rate(0.1)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 8);
    }

    #[test]
    fn test_adaboost_samme_learning_rate_effect() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = AdaBoostClassifier::<f64>::new()
            .with_n_estimators(50)
            .with_algorithm(AdaBoostAlgorithm::Samme)
            .with_learning_rate(0.5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 8);
    }

    #[test]
    fn test_adaboost_many_features() {
        // 4 features, only first one is informative.
        let x = Array2::from_shape_vec(
            (8, 4),
            vec![
                1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0,
                5.0, 0.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0, 7.0, 0.0, 0.0, 0.0, 8.0, 0.0, 0.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = AdaBoostClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 8);
    }

    #[test]
    fn test_adaboost_4_classes() {
        let x = Array2::from_shape_vec(
            (12, 1),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3];

        let model = AdaBoostClassifier::<f64>::new()
            .with_n_estimators(50)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 12);
        assert_eq!(fitted.n_classes(), 4);
    }

    // -- Resample helper tests --

    #[test]
    fn test_resample_weighted_uniform() {
        let weights = vec![0.25, 0.25, 0.25, 0.25];
        let indices = resample_weighted(&weights, 4);
        // With uniform weights, each index should appear once in order.
        assert_eq!(indices, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_resample_weighted_skewed() {
        let weights = vec![0.0, 0.0, 0.0, 1.0];
        let indices = resample_weighted(&weights, 4);
        // All weight on last index.
        assert_eq!(indices, vec![3, 3, 3, 3]);
    }

    #[test]
    fn test_resample_weighted_empty() {
        let weights: Vec<f64> = Vec::new();
        let indices = resample_weighted(&weights, 0);
        assert!(indices.is_empty());
    }

    #[test]
    fn test_adaboost_sammer_single_estimator() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = AdaBoostClassifier::<f64>::new()
            .with_n_estimators(1)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_adaboost_samme_single_estimator() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = AdaBoostClassifier::<f64>::new()
            .with_n_estimators(1)
            .with_algorithm(AdaBoostAlgorithm::Samme)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_adaboost_negative_learning_rate() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = AdaBoostClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_learning_rate(-0.1);
        assert!(model.fit(&x, &y).is_err());
    }
}