Skip to main content

ferrolearn_tree/
adaboost.rs

1//! AdaBoost classifier.
2//!
3//! This module provides [`AdaBoostClassifier`], which implements the Adaptive
4//! Boosting algorithm using decision tree stumps (depth-1 trees) as base
5//! estimators. Two algorithm variants are supported:
6//!
7//! - **SAMME**: uses discrete class predictions and works with multiclass
8//!   problems directly.
9//! - **SAMME.R** (default): uses class probability estimates, typically giving
10//!   better performance than SAMME.
11//!
12//! # Examples
13//!
14//! ```
15//! use ferrolearn_tree::AdaBoostClassifier;
16//! use ferrolearn_core::{Fit, Predict};
17//! use ndarray::{array, Array1, Array2};
18//!
19//! let x = Array2::from_shape_vec((8, 2), vec![
20//!     1.0, 2.0,  2.0, 3.0,  3.0, 3.0,  4.0, 4.0,
21//!     5.0, 6.0,  6.0, 7.0,  7.0, 8.0,  8.0, 9.0,
22//! ]).unwrap();
23//! let y = array![0, 0, 0, 0, 1, 1, 1, 1];
24//!
25//! let model = AdaBoostClassifier::<f64>::new()
26//!     .with_n_estimators(50)
27//!     .with_random_state(42);
28//! let fitted = model.fit(&x, &y).unwrap();
29//! let preds = fitted.predict(&x).unwrap();
30//! ```
31
32use crate::decision_tree::{
33    self, ClassificationCriterion, Node, build_classification_tree_with_feature_subset,
34};
35use ferrolearn_core::error::FerroError;
36use ferrolearn_core::introspection::HasClasses;
37use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
38use ferrolearn_core::traits::{Fit, Predict};
39use ndarray::{Array1, Array2};
40use num_traits::{Float, FromPrimitive, ToPrimitive};
41
42// ---------------------------------------------------------------------------
43// Algorithm enum
44// ---------------------------------------------------------------------------
45
/// Selects which boosting variant an AdaBoost classifier runs.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AdaBoostAlgorithm {
    /// SAMME (Stagewise Additive Modeling using a Multi-class Exponential
    /// loss).
    ///
    /// Every base estimator contributes a discrete class prediction.
    Samme,
    /// SAMME.R, the "real" variant, driven by class probability estimates
    /// instead of hard predictions.
    ///
    /// Generally outperforms SAMME.
    SammeR,
}
58
59// ---------------------------------------------------------------------------
60// AdaBoostClassifier
61// ---------------------------------------------------------------------------
62
63/// AdaBoost classifier using decision tree stumps as base estimators.
64///
65/// At each boosting round a decision tree stump (max depth = 1) is fitted
66/// to the weighted training data. Misclassified samples receive higher
67/// weight in subsequent rounds, allowing the ensemble to focus on hard
68/// examples.
69///
70/// # Type Parameters
71///
72/// - `F`: The floating-point type (`f32` or `f64`).
73#[derive(Debug, Clone)]
74pub struct AdaBoostClassifier<F> {
75    /// Number of boosting stages (stumps).
76    pub n_estimators: usize,
77    /// Learning rate (shrinkage). Lower values require more estimators.
78    pub learning_rate: f64,
79    /// Algorithm variant (`SAMME` or `SAMME.R`).
80    pub algorithm: AdaBoostAlgorithm,
81    /// Random seed for reproducibility.
82    pub random_state: Option<u64>,
83    _marker: std::marker::PhantomData<F>,
84}
85
86impl<F: Float> AdaBoostClassifier<F> {
87    /// Create a new `AdaBoostClassifier` with default settings.
88    ///
89    /// Defaults: `n_estimators = 50`, `learning_rate = 1.0`,
90    /// `algorithm = SAMME.R`, `random_state = None`.
91    #[must_use]
92    pub fn new() -> Self {
93        Self {
94            n_estimators: 50,
95            learning_rate: 1.0,
96            algorithm: AdaBoostAlgorithm::SammeR,
97            random_state: None,
98            _marker: std::marker::PhantomData,
99        }
100    }
101
102    /// Set the number of boosting stages.
103    #[must_use]
104    pub fn with_n_estimators(mut self, n: usize) -> Self {
105        self.n_estimators = n;
106        self
107    }
108
109    /// Set the learning rate.
110    #[must_use]
111    pub fn with_learning_rate(mut self, lr: f64) -> Self {
112        self.learning_rate = lr;
113        self
114    }
115
116    /// Set the algorithm variant.
117    #[must_use]
118    pub fn with_algorithm(mut self, algo: AdaBoostAlgorithm) -> Self {
119        self.algorithm = algo;
120        self
121    }
122
123    /// Set the random seed for reproducibility.
124    #[must_use]
125    pub fn with_random_state(mut self, seed: u64) -> Self {
126        self.random_state = Some(seed);
127        self
128    }
129}
130
131impl<F: Float> Default for AdaBoostClassifier<F> {
132    fn default() -> Self {
133        Self::new()
134    }
135}
136
137// ---------------------------------------------------------------------------
138// FittedAdaBoostClassifier
139// ---------------------------------------------------------------------------
140
141/// A fitted AdaBoost classifier.
142///
143/// Stores the sequence of stumps and their weights. Predictions are made
144/// by weighted majority vote (SAMME) or weighted probability averaging
145/// (SAMME.R).
146#[derive(Debug, Clone)]
147pub struct FittedAdaBoostClassifier<F> {
148    /// Sorted unique class labels.
149    classes: Vec<usize>,
150    /// Sequence of fitted tree stumps.
151    estimators: Vec<Vec<Node<F>>>,
152    /// Weight of each estimator (SAMME) or kept for SAMME.R bookkeeping.
153    estimator_weights: Vec<F>,
154    /// Number of features.
155    n_features: usize,
156    /// Number of classes.
157    n_classes: usize,
158    /// Algorithm used.
159    algorithm: AdaBoostAlgorithm,
160}
161
162impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for AdaBoostClassifier<F> {
163    type Fitted = FittedAdaBoostClassifier<F>;
164    type Error = FerroError;
165
166    /// Fit the AdaBoost classifier.
167    ///
168    /// # Errors
169    ///
170    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
171    /// numbers of samples.
172    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
173    /// Returns [`FerroError::InvalidParameter`] for invalid hyperparameters.
174    fn fit(
175        &self,
176        x: &Array2<F>,
177        y: &Array1<usize>,
178    ) -> Result<FittedAdaBoostClassifier<F>, FerroError> {
179        let (n_samples, n_features) = x.dim();
180
181        if n_samples != y.len() {
182            return Err(FerroError::ShapeMismatch {
183                expected: vec![n_samples],
184                actual: vec![y.len()],
185                context: "y length must match number of samples in X".into(),
186            });
187        }
188        if n_samples == 0 {
189            return Err(FerroError::InsufficientSamples {
190                required: 1,
191                actual: 0,
192                context: "AdaBoostClassifier requires at least one sample".into(),
193            });
194        }
195        if self.n_estimators == 0 {
196            return Err(FerroError::InvalidParameter {
197                name: "n_estimators".into(),
198                reason: "must be at least 1".into(),
199            });
200        }
201        if self.learning_rate <= 0.0 {
202            return Err(FerroError::InvalidParameter {
203                name: "learning_rate".into(),
204                reason: "must be positive".into(),
205            });
206        }
207
208        // Determine unique classes.
209        let mut classes: Vec<usize> = y.iter().copied().collect();
210        classes.sort_unstable();
211        classes.dedup();
212        let n_classes = classes.len();
213
214        if n_classes < 2 {
215            return Err(FerroError::InvalidParameter {
216                name: "y".into(),
217                reason: "need at least 2 distinct classes".into(),
218            });
219        }
220
221        let y_mapped: Vec<usize> = y
222            .iter()
223            .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
224            .collect();
225
226        match self.algorithm {
227            AdaBoostAlgorithm::Samme => {
228                self.fit_samme(x, &y_mapped, n_samples, n_features, n_classes, &classes)
229            }
230            AdaBoostAlgorithm::SammeR => {
231                self.fit_samme_r(x, &y_mapped, n_samples, n_features, n_classes, &classes)
232            }
233        }
234    }
235}
236
237impl<F: Float + Send + Sync + 'static> AdaBoostClassifier<F> {
238    /// Fit using the SAMME algorithm (discrete predictions).
239    fn fit_samme(
240        &self,
241        x: &Array2<F>,
242        y_mapped: &[usize],
243        n_samples: usize,
244        n_features: usize,
245        n_classes: usize,
246        classes: &[usize],
247    ) -> Result<FittedAdaBoostClassifier<F>, FerroError> {
248        let lr = F::from(self.learning_rate).unwrap();
249        let n_f = F::from(n_samples).unwrap();
250        let eps = F::from(1e-10).unwrap();
251
252        // Initialize sample weights uniformly.
253        let mut weights = vec![F::one() / n_f; n_samples];
254
255        let all_features: Vec<usize> = (0..n_features).collect();
256        let stump_params = decision_tree::TreeParams {
257            max_depth: Some(1),
258            min_samples_split: 2,
259            min_samples_leaf: 1,
260        };
261
262        let mut estimators = Vec::with_capacity(self.n_estimators);
263        let mut estimator_weights = Vec::with_capacity(self.n_estimators);
264
265        for _ in 0..self.n_estimators {
266            // Build weighted sample indices: replicate indices proportional to weight.
267            let indices = resample_weighted(&weights, n_samples);
268
269            let tree = build_classification_tree_with_feature_subset(
270                x,
271                y_mapped,
272                n_classes,
273                &indices,
274                &all_features,
275                &stump_params,
276                ClassificationCriterion::Gini,
277            );
278
279            // Compute predictions and weighted error.
280            let mut weighted_error = F::zero();
281            let mut preds = vec![0usize; n_samples];
282            for i in 0..n_samples {
283                let row = x.row(i);
284                let leaf_idx = decision_tree::traverse(&tree, &row);
285                if let Node::Leaf { value, .. } = tree[leaf_idx] {
286                    preds[i] = value.to_f64().map_or(0, |f| f.round() as usize);
287                }
288                if preds[i] != y_mapped[i] {
289                    weighted_error = weighted_error + weights[i];
290                }
291            }
292
293            // Normalise error.
294            let weight_sum: F = weights.iter().copied().fold(F::zero(), |a, b| a + b);
295            let err = if weight_sum > F::zero() {
296                weighted_error / weight_sum
297            } else {
298                F::from(0.5).unwrap()
299            };
300
301            // If error is too high or zero, stop or skip.
302            if err >= F::one() - F::one() / F::from(n_classes).unwrap() {
303                // Error too high; stop boosting.
304                if estimators.is_empty() {
305                    // Keep at least one estimator.
306                    estimators.push(tree);
307                    estimator_weights.push(F::one());
308                }
309                break;
310            }
311
312            // Estimator weight: SAMME formula.
313            let alpha = lr * ((F::one() - err).max(eps) / err.max(eps)).ln()
314                + lr * (F::from(n_classes - 1).unwrap()).ln();
315
316            // Update sample weights.
317            for i in 0..n_samples {
318                if preds[i] != y_mapped[i] {
319                    weights[i] = weights[i] * alpha.exp();
320                }
321            }
322
323            // Normalise weights.
324            let new_sum: F = weights.iter().copied().fold(F::zero(), |a, b| a + b);
325            if new_sum > F::zero() {
326                for w in &mut weights {
327                    *w = *w / new_sum;
328                }
329            }
330
331            estimators.push(tree);
332            estimator_weights.push(alpha);
333        }
334
335        Ok(FittedAdaBoostClassifier {
336            classes: classes.to_vec(),
337            estimators,
338            estimator_weights,
339            n_features,
340            n_classes,
341            algorithm: AdaBoostAlgorithm::Samme,
342        })
343    }
344
345    /// Fit using the SAMME.R algorithm (real-valued / probability-based).
346    fn fit_samme_r(
347        &self,
348        x: &Array2<F>,
349        y_mapped: &[usize],
350        n_samples: usize,
351        n_features: usize,
352        n_classes: usize,
353        classes: &[usize],
354    ) -> Result<FittedAdaBoostClassifier<F>, FerroError> {
355        let lr = F::from(self.learning_rate).unwrap();
356        let n_f = F::from(n_samples).unwrap();
357        let eps = F::from(1e-10).unwrap();
358        let k_f = F::from(n_classes).unwrap();
359
360        // Initialize sample weights uniformly.
361        let mut weights = vec![F::one() / n_f; n_samples];
362
363        let all_features: Vec<usize> = (0..n_features).collect();
364        let stump_params = decision_tree::TreeParams {
365            max_depth: Some(1),
366            min_samples_split: 2,
367            min_samples_leaf: 1,
368        };
369
370        let mut estimators = Vec::with_capacity(self.n_estimators);
371        let mut estimator_weights = Vec::with_capacity(self.n_estimators);
372
373        for _ in 0..self.n_estimators {
374            let indices = resample_weighted(&weights, n_samples);
375
376            let tree = build_classification_tree_with_feature_subset(
377                x,
378                y_mapped,
379                n_classes,
380                &indices,
381                &all_features,
382                &stump_params,
383                ClassificationCriterion::Gini,
384            );
385
386            // Get class probability estimates for each sample.
387            let mut proba = vec![vec![F::zero(); n_classes]; n_samples];
388            for (i, proba_row) in proba.iter_mut().enumerate() {
389                let row = x.row(i);
390                let leaf_idx = decision_tree::traverse(&tree, &row);
391                if let Node::Leaf {
392                    class_distribution: Some(ref dist),
393                    ..
394                } = tree[leaf_idx]
395                {
396                    for (k, &p) in dist.iter().enumerate() {
397                        proba_row[k] = p.max(eps);
398                    }
399                } else {
400                    // Fallback: uniform.
401                    for val in proba_row.iter_mut() {
402                        *val = F::one() / k_f;
403                    }
404                }
405                // Normalise.
406                let row_sum: F = proba_row.iter().copied().fold(F::zero(), |a, b| a + b);
407                if row_sum > F::zero() {
408                    for val in proba_row.iter_mut() {
409                        *val = *val / row_sum;
410                    }
411                }
412            }
413
414            // SAMME.R weight update: based on log-probability.
415            // h_k(x) = (K-1) * (log(p_k(x)) - (1/K) * sum_j log(p_j(x)))
416            // Then update: w_i *= exp(-(K-1)/K * lr * sum_k y_{ik} * log(p_k(x)))
417            // Simplified: w_i *= exp(-lr * (K-1)/K * log(p_{y_i}(x)))
418            let factor = lr * (k_f - F::one()) / k_f;
419            let mut any_update = false;
420
421            for i in 0..n_samples {
422                let p_correct = proba[i][y_mapped[i]].max(eps);
423                let exponent = -factor * p_correct.ln();
424                weights[i] = weights[i] * exponent.exp();
425                if exponent.abs() > eps {
426                    any_update = true;
427                }
428            }
429
430            // Normalise weights.
431            let new_sum: F = weights.iter().copied().fold(F::zero(), |a, b| a + b);
432            if new_sum > F::zero() {
433                for w in &mut weights {
434                    *w = *w / new_sum;
435                }
436            }
437
438            estimators.push(tree);
439            estimator_weights.push(F::one()); // SAMME.R uses equal weight; prediction uses probabilities.
440
441            if !any_update {
442                break;
443            }
444        }
445
446        Ok(FittedAdaBoostClassifier {
447            classes: classes.to_vec(),
448            estimators,
449            estimator_weights,
450            n_features,
451            n_classes,
452            algorithm: AdaBoostAlgorithm::SammeR,
453        })
454    }
455}
456
457impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedAdaBoostClassifier<F> {
458    type Output = Array1<usize>;
459    type Error = FerroError;
460
461    /// Predict class labels.
462    ///
463    /// - **SAMME**: weighted majority vote using estimator weights.
464    /// - **SAMME.R**: weighted average of log-probabilities.
465    ///
466    /// # Errors
467    ///
468    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
469    /// not match the fitted model.
470    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
471        if x.ncols() != self.n_features {
472            return Err(FerroError::ShapeMismatch {
473                expected: vec![self.n_features],
474                actual: vec![x.ncols()],
475                context: "number of features must match fitted model".into(),
476            });
477        }
478
479        let n_samples = x.nrows();
480
481        match self.algorithm {
482            AdaBoostAlgorithm::Samme => self.predict_samme(x, n_samples),
483            AdaBoostAlgorithm::SammeR => self.predict_samme_r(x, n_samples),
484        }
485    }
486}
487
488impl<F: Float + Send + Sync + 'static> FittedAdaBoostClassifier<F> {
489    /// Predict using SAMME (weighted majority vote).
490    fn predict_samme(&self, x: &Array2<F>, n_samples: usize) -> Result<Array1<usize>, FerroError> {
491        let mut predictions = Array1::zeros(n_samples);
492
493        for i in 0..n_samples {
494            let row = x.row(i);
495            let mut class_scores = vec![F::zero(); self.n_classes];
496
497            for (t, tree_nodes) in self.estimators.iter().enumerate() {
498                let leaf_idx = decision_tree::traverse(tree_nodes, &row);
499                if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
500                    let class_idx = value.to_f64().map_or(0, |f| f.round() as usize);
501                    if class_idx < self.n_classes {
502                        class_scores[class_idx] =
503                            class_scores[class_idx] + self.estimator_weights[t];
504                    }
505                }
506            }
507
508            let best = class_scores
509                .iter()
510                .enumerate()
511                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
512                .map_or(0, |(k, _)| k);
513            predictions[i] = self.classes[best];
514        }
515
516        Ok(predictions)
517    }
518
519    /// Predict using SAMME.R (weighted probability averaging).
520    fn predict_samme_r(
521        &self,
522        x: &Array2<F>,
523        n_samples: usize,
524    ) -> Result<Array1<usize>, FerroError> {
525        let eps = F::from(1e-10).unwrap();
526        let k_f = F::from(self.n_classes).unwrap();
527        let k_minus_1 = k_f - F::one();
528
529        let mut predictions = Array1::zeros(n_samples);
530
531        for i in 0..n_samples {
532            let row = x.row(i);
533            let mut accumulated = vec![F::zero(); self.n_classes];
534
535            for tree_nodes in &self.estimators {
536                let leaf_idx = decision_tree::traverse(tree_nodes, &row);
537                if let Node::Leaf {
538                    class_distribution: Some(ref dist),
539                    ..
540                } = tree_nodes[leaf_idx]
541                {
542                    // h_k(x) = (K-1) * (log(p_k) - mean(log(p_j)))
543                    let log_probs: Vec<F> = dist.iter().map(|&p| p.max(eps).ln()).collect();
544                    let mean_log: F = log_probs.iter().copied().fold(F::zero(), |a, b| a + b) / k_f;
545
546                    for k in 0..self.n_classes {
547                        accumulated[k] = accumulated[k] + k_minus_1 * (log_probs[k] - mean_log);
548                    }
549                } else {
550                    // Leaf without distribution: predict from value.
551                    if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
552                        let class_idx = value.to_f64().map_or(0, |f| f.round() as usize);
553                        if class_idx < self.n_classes {
554                            accumulated[class_idx] = accumulated[class_idx] + F::one();
555                        }
556                    }
557                }
558            }
559
560            let best = accumulated
561                .iter()
562                .enumerate()
563                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
564                .map_or(0, |(k, _)| k);
565            predictions[i] = self.classes[best];
566        }
567
568        Ok(predictions)
569    }
570}
571
572impl<F: Float + Send + Sync + 'static> HasClasses for FittedAdaBoostClassifier<F> {
573    fn classes(&self) -> &[usize] {
574        &self.classes
575    }
576
577    fn n_classes(&self) -> usize {
578        self.classes.len()
579    }
580}
581
582// Pipeline integration.
583impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
584    for AdaBoostClassifier<F>
585{
586    fn fit_pipeline(
587        &self,
588        x: &Array2<F>,
589        y: &Array1<F>,
590    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
591        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
592        let fitted = self.fit(x, &y_usize)?;
593        Ok(Box::new(FittedAdaBoostPipelineAdapter(fitted)))
594    }
595}
596
597/// Pipeline adapter for `FittedAdaBoostClassifier<F>`.
598struct FittedAdaBoostPipelineAdapter<F: Float + Send + Sync + 'static>(FittedAdaBoostClassifier<F>);
599
600impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
601    for FittedAdaBoostPipelineAdapter<F>
602{
603    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
604        let preds = self.0.predict(x)?;
605        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or_else(F::nan)))
606    }
607}
608
609// ---------------------------------------------------------------------------
610// Internal helpers
611// ---------------------------------------------------------------------------
612
613/// Resample indices proportional to weights (weighted bootstrap).
614///
615/// Uses a systematic resampling approach: the cumulative weight distribution
616/// determines which original indices appear in the resampled set.
617fn resample_weighted<F: Float>(weights: &[F], n: usize) -> Vec<usize> {
618    if weights.is_empty() {
619        return Vec::new();
620    }
621
622    // Build cumulative distribution.
623    let mut cumsum = Vec::with_capacity(weights.len());
624    let mut running = F::zero();
625    for &w in weights {
626        running = running + w;
627        cumsum.push(running);
628    }
629
630    // Normalise (in case weights don't sum to 1).
631    let total = running;
632    if total <= F::zero() {
633        return (0..n).collect();
634    }
635
636    let mut indices = Vec::with_capacity(n);
637    let step = total / F::from(n).unwrap();
638    let mut threshold = step / F::from(2.0).unwrap(); // Start in the middle of the first bin.
639    let mut j = 0;
640
641    for _ in 0..n {
642        while j < cumsum.len() - 1 && cumsum[j] < threshold {
643            j += 1;
644        }
645        indices.push(j);
646        threshold = threshold + step;
647    }
648
649    indices
650}
651
652// ---------------------------------------------------------------------------
653// Tests
654// ---------------------------------------------------------------------------
655
656#[cfg(test)]
657mod tests {
658    use super::*;
659    use ndarray::array;
660
661    // -- SAMME.R tests --
662
663    #[test]
664    fn test_adaboost_sammer_binary_simple() {
665        let x = Array2::from_shape_vec(
666            (8, 2),
667            vec![
668                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
669            ],
670        )
671        .unwrap();
672        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
673
674        let model = AdaBoostClassifier::<f64>::new()
675            .with_n_estimators(50)
676            .with_random_state(42);
677        let fitted = model.fit(&x, &y).unwrap();
678        let preds = fitted.predict(&x).unwrap();
679
680        assert_eq!(preds.len(), 8);
681        for i in 0..4 {
682            assert_eq!(preds[i], 0);
683        }
684        for i in 4..8 {
685            assert_eq!(preds[i], 1);
686        }
687    }
688
689    #[test]
690    fn test_adaboost_sammer_multiclass() {
691        let x = Array2::from_shape_vec((9, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
692            .unwrap();
693        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];
694
695        let model = AdaBoostClassifier::<f64>::new()
696            .with_n_estimators(50)
697            .with_random_state(42);
698        let fitted = model.fit(&x, &y).unwrap();
699        let preds = fitted.predict(&x).unwrap();
700
701        assert_eq!(preds.len(), 9);
702        let correct = preds.iter().zip(y.iter()).filter(|(p, t)| p == t).count();
703        assert!(
704            correct >= 5,
705            "Expected at least 5/9 correct, got {correct}/9"
706        );
707    }
708
709    // -- SAMME tests --
710
711    #[test]
712    fn test_adaboost_samme_binary_simple() {
713        let x = Array2::from_shape_vec(
714            (8, 2),
715            vec![
716                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
717            ],
718        )
719        .unwrap();
720        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
721
722        let model = AdaBoostClassifier::<f64>::new()
723            .with_n_estimators(50)
724            .with_algorithm(AdaBoostAlgorithm::Samme)
725            .with_random_state(42);
726        let fitted = model.fit(&x, &y).unwrap();
727        let preds = fitted.predict(&x).unwrap();
728
729        assert_eq!(preds.len(), 8);
730        for i in 0..4 {
731            assert_eq!(preds[i], 0);
732        }
733        for i in 4..8 {
734            assert_eq!(preds[i], 1);
735        }
736    }
737
738    #[test]
739    fn test_adaboost_samme_multiclass() {
740        let x = Array2::from_shape_vec((9, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
741            .unwrap();
742        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];
743
744        let model = AdaBoostClassifier::<f64>::new()
745            .with_n_estimators(50)
746            .with_algorithm(AdaBoostAlgorithm::Samme)
747            .with_random_state(42);
748        let fitted = model.fit(&x, &y).unwrap();
749        let preds = fitted.predict(&x).unwrap();
750
751        assert_eq!(preds.len(), 9);
752        let correct = preds.iter().zip(y.iter()).filter(|(p, t)| p == t).count();
753        assert!(
754            correct >= 5,
755            "Expected at least 5/9 correct for SAMME multiclass, got {correct}/9"
756        );
757    }
758
759    // -- Common tests --
760
761    #[test]
762    fn test_adaboost_has_classes() {
763        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
764        let y = array![0, 1, 2, 0, 1, 2];
765
766        let model = AdaBoostClassifier::<f64>::new()
767            .with_n_estimators(5)
768            .with_random_state(0);
769        let fitted = model.fit(&x, &y).unwrap();
770
771        assert_eq!(fitted.classes(), &[0, 1, 2]);
772        assert_eq!(fitted.n_classes(), 3);
773    }
774
775    #[test]
776    fn test_adaboost_reproducibility() {
777        let x = Array2::from_shape_vec(
778            (8, 2),
779            vec![
780                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
781            ],
782        )
783        .unwrap();
784        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
785
786        let model = AdaBoostClassifier::<f64>::new()
787            .with_n_estimators(10)
788            .with_random_state(42);
789
790        let fitted1 = model.fit(&x, &y).unwrap();
791        let fitted2 = model.fit(&x, &y).unwrap();
792
793        let preds1 = fitted1.predict(&x).unwrap();
794        let preds2 = fitted2.predict(&x).unwrap();
795        assert_eq!(preds1, preds2);
796    }
797
798    #[test]
799    fn test_adaboost_shape_mismatch_fit() {
800        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
801        let y = array![0, 1];
802
803        let model = AdaBoostClassifier::<f64>::new().with_n_estimators(5);
804        assert!(model.fit(&x, &y).is_err());
805    }
806
807    #[test]
808    fn test_adaboost_shape_mismatch_predict() {
809        let x =
810            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
811        let y = array![0, 0, 1, 1];
812
813        let model = AdaBoostClassifier::<f64>::new()
814            .with_n_estimators(5)
815            .with_random_state(0);
816        let fitted = model.fit(&x, &y).unwrap();
817
818        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
819        assert!(fitted.predict(&x_bad).is_err());
820    }
821
822    #[test]
823    fn test_adaboost_empty_data() {
824        let x = Array2::<f64>::zeros((0, 2));
825        let y = Array1::<usize>::zeros(0);
826
827        let model = AdaBoostClassifier::<f64>::new().with_n_estimators(5);
828        assert!(model.fit(&x, &y).is_err());
829    }
830
831    #[test]
832    fn test_adaboost_single_class() {
833        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
834        let y = array![0, 0, 0];
835
836        let model = AdaBoostClassifier::<f64>::new().with_n_estimators(5);
837        assert!(model.fit(&x, &y).is_err());
838    }
839
840    #[test]
841    fn test_adaboost_zero_estimators() {
842        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
843        let y = array![0, 0, 1, 1];
844
845        let model = AdaBoostClassifier::<f64>::new().with_n_estimators(0);
846        assert!(model.fit(&x, &y).is_err());
847    }
848
849    #[test]
850    fn test_adaboost_invalid_learning_rate() {
851        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
852        let y = array![0, 0, 1, 1];
853
854        let model = AdaBoostClassifier::<f64>::new()
855            .with_n_estimators(5)
856            .with_learning_rate(0.0);
857        assert!(model.fit(&x, &y).is_err());
858    }
859
860    #[test]
861    fn test_adaboost_pipeline_integration() {
862        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
863        let y = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);
864
865        let model = AdaBoostClassifier::<f64>::new()
866            .with_n_estimators(10)
867            .with_random_state(42);
868        let fitted = model.fit_pipeline(&x, &y).unwrap();
869        let preds = fitted.predict_pipeline(&x).unwrap();
870        assert_eq!(preds.len(), 6);
871    }
872
873    #[test]
874    fn test_adaboost_f32_support() {
875        let x = Array2::from_shape_vec((6, 1), vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
876        let y = array![0, 0, 0, 1, 1, 1];
877
878        let model = AdaBoostClassifier::<f32>::new()
879            .with_n_estimators(10)
880            .with_random_state(42);
881        let fitted = model.fit(&x, &y).unwrap();
882        let preds = fitted.predict(&x).unwrap();
883        assert_eq!(preds.len(), 6);
884    }
885
886    #[test]
887    fn test_adaboost_default_trait() {
888        let model = AdaBoostClassifier::<f64>::default();
889        assert_eq!(model.n_estimators, 50);
890        assert!((model.learning_rate - 1.0).abs() < 1e-10);
891        assert_eq!(model.algorithm, AdaBoostAlgorithm::SammeR);
892    }
893
894    #[test]
895    fn test_adaboost_non_contiguous_labels() {
896        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
897        let y = array![10, 10, 10, 20, 20, 20];
898
899        let model = AdaBoostClassifier::<f64>::new()
900            .with_n_estimators(20)
901            .with_random_state(42);
902        let fitted = model.fit(&x, &y).unwrap();
903        let preds = fitted.predict(&x).unwrap();
904
905        assert_eq!(preds.len(), 6);
906        for &p in &preds {
907            assert!(p == 10 || p == 20);
908        }
909    }
910
911    #[test]
912    fn test_adaboost_sammer_learning_rate_effect() {
913        let x =
914            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
915        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
916
917        // Low learning rate should still work (just slower convergence).
918        let model = AdaBoostClassifier::<f64>::new()
919            .with_n_estimators(50)
920            .with_learning_rate(0.1)
921            .with_random_state(42);
922        let fitted = model.fit(&x, &y).unwrap();
923        let preds = fitted.predict(&x).unwrap();
924        assert_eq!(preds.len(), 8);
925    }
926
927    #[test]
928    fn test_adaboost_samme_learning_rate_effect() {
929        let x =
930            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
931        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
932
933        let model = AdaBoostClassifier::<f64>::new()
934            .with_n_estimators(50)
935            .with_algorithm(AdaBoostAlgorithm::Samme)
936            .with_learning_rate(0.5)
937            .with_random_state(42);
938        let fitted = model.fit(&x, &y).unwrap();
939        let preds = fitted.predict(&x).unwrap();
940        assert_eq!(preds.len(), 8);
941    }
942
943    #[test]
944    fn test_adaboost_many_features() {
945        // 4 features, only first one is informative.
946        let x = Array2::from_shape_vec(
947            (8, 4),
948            vec![
949                1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0,
950                5.0, 0.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0, 7.0, 0.0, 0.0, 0.0, 8.0, 0.0, 0.0, 0.0,
951            ],
952        )
953        .unwrap();
954        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
955
956        let model = AdaBoostClassifier::<f64>::new()
957            .with_n_estimators(20)
958            .with_random_state(42);
959        let fitted = model.fit(&x, &y).unwrap();
960        let preds = fitted.predict(&x).unwrap();
961        assert_eq!(preds.len(), 8);
962    }
963
964    #[test]
965    fn test_adaboost_4_classes() {
966        let x = Array2::from_shape_vec(
967            (12, 1),
968            vec![
969                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
970            ],
971        )
972        .unwrap();
973        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3];
974
975        let model = AdaBoostClassifier::<f64>::new()
976            .with_n_estimators(50)
977            .with_random_state(42);
978        let fitted = model.fit(&x, &y).unwrap();
979        let preds = fitted.predict(&x).unwrap();
980
981        assert_eq!(preds.len(), 12);
982        assert_eq!(fitted.n_classes(), 4);
983    }
984
985    // -- Resample helper tests --
986
987    #[test]
988    fn test_resample_weighted_uniform() {
989        let weights = vec![0.25, 0.25, 0.25, 0.25];
990        let indices = resample_weighted(&weights, 4);
991        // With uniform weights, each index should appear once in order.
992        assert_eq!(indices, vec![0, 1, 2, 3]);
993    }
994
995    #[test]
996    fn test_resample_weighted_skewed() {
997        let weights = vec![0.0, 0.0, 0.0, 1.0];
998        let indices = resample_weighted(&weights, 4);
999        assert_eq!(indices.len(), 4);
1000        // All weight on last index.
1001        for &idx in &indices {
1002            assert_eq!(idx, 3);
1003        }
1004    }
1005
1006    #[test]
1007    fn test_resample_weighted_empty() {
1008        let weights: Vec<f64> = Vec::new();
1009        let indices = resample_weighted(&weights, 0);
1010        assert!(indices.is_empty());
1011    }
1012
1013    #[test]
1014    fn test_adaboost_sammer_single_estimator() {
1015        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1016        let y = array![0, 0, 0, 1, 1, 1];
1017
1018        let model = AdaBoostClassifier::<f64>::new()
1019            .with_n_estimators(1)
1020            .with_random_state(42);
1021        let fitted = model.fit(&x, &y).unwrap();
1022        let preds = fitted.predict(&x).unwrap();
1023        assert_eq!(preds.len(), 6);
1024    }
1025
1026    #[test]
1027    fn test_adaboost_samme_single_estimator() {
1028        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1029        let y = array![0, 0, 0, 1, 1, 1];
1030
1031        let model = AdaBoostClassifier::<f64>::new()
1032            .with_n_estimators(1)
1033            .with_algorithm(AdaBoostAlgorithm::Samme)
1034            .with_random_state(42);
1035        let fitted = model.fit(&x, &y).unwrap();
1036        let preds = fitted.predict(&x).unwrap();
1037        assert_eq!(preds.len(), 6);
1038    }
1039
1040    #[test]
1041    fn test_adaboost_negative_learning_rate() {
1042        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
1043        let y = array![0, 0, 1, 1];
1044
1045        let model = AdaBoostClassifier::<f64>::new()
1046            .with_n_estimators(5)
1047            .with_learning_rate(-0.1);
1048        assert!(model.fit(&x, &y).is_err());
1049    }
1050}