Skip to main content

ferrolearn_tree/
adaboost.rs

1//! AdaBoost classifier.
2//!
3//! This module provides [`AdaBoostClassifier`], which implements the Adaptive
4//! Boosting algorithm using decision tree stumps (depth-1 trees) as base
5//! estimators. Two algorithm variants are supported:
6//!
7//! - **SAMME**: uses discrete class predictions and works with multiclass
8//!   problems directly.
9//! - **SAMME.R** (default): uses class probability estimates, typically giving
10//!   better performance than SAMME.
11//!
12//! # Examples
13//!
14//! ```
15//! use ferrolearn_tree::AdaBoostClassifier;
16//! use ferrolearn_core::{Fit, Predict};
17//! use ndarray::{array, Array1, Array2};
18//!
19//! let x = Array2::from_shape_vec((8, 2), vec![
20//!     1.0, 2.0,  2.0, 3.0,  3.0, 3.0,  4.0, 4.0,
21//!     5.0, 6.0,  6.0, 7.0,  7.0, 8.0,  8.0, 9.0,
22//! ]).unwrap();
23//! let y = array![0, 0, 0, 0, 1, 1, 1, 1];
24//!
25//! let model = AdaBoostClassifier::<f64>::new()
26//!     .with_n_estimators(50)
27//!     .with_random_state(42);
28//! let fitted = model.fit(&x, &y).unwrap();
29//! let preds = fitted.predict(&x).unwrap();
30//! ```
31
32use crate::decision_tree::{
33    self, ClassificationCriterion, Node, build_classification_tree_with_feature_subset,
34};
35use ferrolearn_core::error::FerroError;
36use ferrolearn_core::introspection::HasClasses;
37use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
38use ferrolearn_core::traits::{Fit, Predict};
39use ndarray::{Array1, Array2};
40use num_traits::{Float, FromPrimitive, ToPrimitive};
41
42// ---------------------------------------------------------------------------
43// Algorithm enum
44// ---------------------------------------------------------------------------
45
/// Selects which AdaBoost variant drives the boosting loop.
///
/// Both variants fit depth-1 decision trees; they differ in how each stump's
/// output feeds the sample-weight updates and the final ensemble decision.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AdaBoostAlgorithm {
    /// Discrete SAMME (Stagewise Additive Modeling using a Multi-class
    /// Exponential loss): every stump casts a single, hard class vote.
    Samme,
    /// SAMME.R, the "real" variant: every stump contributes class-probability
    /// estimates rather than a hard vote.
    ///
    /// Usually the stronger of the two variants.
    SammeR,
}
58
59// ---------------------------------------------------------------------------
60// AdaBoostClassifier
61// ---------------------------------------------------------------------------
62
63/// AdaBoost classifier using decision tree stumps as base estimators.
64///
65/// At each boosting round a decision tree stump (max depth = 1) is fitted
66/// to the weighted training data. Misclassified samples receive higher
67/// weight in subsequent rounds, allowing the ensemble to focus on hard
68/// examples.
69///
70/// # Type Parameters
71///
72/// - `F`: The floating-point type (`f32` or `f64`).
73#[derive(Debug, Clone)]
74pub struct AdaBoostClassifier<F> {
75    /// Number of boosting stages (stumps).
76    pub n_estimators: usize,
77    /// Learning rate (shrinkage). Lower values require more estimators.
78    pub learning_rate: f64,
79    /// Algorithm variant (`SAMME` or `SAMME.R`).
80    pub algorithm: AdaBoostAlgorithm,
81    /// Random seed for reproducibility.
82    pub random_state: Option<u64>,
83    _marker: std::marker::PhantomData<F>,
84}
85
86impl<F: Float> AdaBoostClassifier<F> {
87    /// Create a new `AdaBoostClassifier` with default settings.
88    ///
89    /// Defaults: `n_estimators = 50`, `learning_rate = 1.0`,
90    /// `algorithm = SAMME.R`, `random_state = None`.
91    #[must_use]
92    pub fn new() -> Self {
93        Self {
94            n_estimators: 50,
95            learning_rate: 1.0,
96            algorithm: AdaBoostAlgorithm::SammeR,
97            random_state: None,
98            _marker: std::marker::PhantomData,
99        }
100    }
101
102    /// Set the number of boosting stages.
103    #[must_use]
104    pub fn with_n_estimators(mut self, n: usize) -> Self {
105        self.n_estimators = n;
106        self
107    }
108
109    /// Set the learning rate.
110    #[must_use]
111    pub fn with_learning_rate(mut self, lr: f64) -> Self {
112        self.learning_rate = lr;
113        self
114    }
115
116    /// Set the algorithm variant.
117    #[must_use]
118    pub fn with_algorithm(mut self, algo: AdaBoostAlgorithm) -> Self {
119        self.algorithm = algo;
120        self
121    }
122
123    /// Set the random seed for reproducibility.
124    #[must_use]
125    pub fn with_random_state(mut self, seed: u64) -> Self {
126        self.random_state = Some(seed);
127        self
128    }
129}
130
131impl<F: Float> Default for AdaBoostClassifier<F> {
132    fn default() -> Self {
133        Self::new()
134    }
135}
136
137// ---------------------------------------------------------------------------
138// FittedAdaBoostClassifier
139// ---------------------------------------------------------------------------
140
141/// A fitted AdaBoost classifier.
142///
143/// Stores the sequence of stumps and their weights. Predictions are made
144/// by weighted majority vote (SAMME) or weighted probability averaging
145/// (SAMME.R).
146#[derive(Debug, Clone)]
147pub struct FittedAdaBoostClassifier<F> {
148    /// Sorted unique class labels.
149    classes: Vec<usize>,
150    /// Sequence of fitted tree stumps.
151    estimators: Vec<Vec<Node<F>>>,
152    /// Weight of each estimator (SAMME) or kept for SAMME.R bookkeeping.
153    estimator_weights: Vec<F>,
154    /// Number of features.
155    n_features: usize,
156    /// Number of classes.
157    n_classes: usize,
158    /// Algorithm used.
159    algorithm: AdaBoostAlgorithm,
160}
161
162impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for AdaBoostClassifier<F> {
163    type Fitted = FittedAdaBoostClassifier<F>;
164    type Error = FerroError;
165
166    /// Fit the AdaBoost classifier.
167    ///
168    /// # Errors
169    ///
170    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
171    /// numbers of samples.
172    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
173    /// Returns [`FerroError::InvalidParameter`] for invalid hyperparameters.
174    fn fit(
175        &self,
176        x: &Array2<F>,
177        y: &Array1<usize>,
178    ) -> Result<FittedAdaBoostClassifier<F>, FerroError> {
179        let (n_samples, n_features) = x.dim();
180
181        if n_samples != y.len() {
182            return Err(FerroError::ShapeMismatch {
183                expected: vec![n_samples],
184                actual: vec![y.len()],
185                context: "y length must match number of samples in X".into(),
186            });
187        }
188        if n_samples == 0 {
189            return Err(FerroError::InsufficientSamples {
190                required: 1,
191                actual: 0,
192                context: "AdaBoostClassifier requires at least one sample".into(),
193            });
194        }
195        if self.n_estimators == 0 {
196            return Err(FerroError::InvalidParameter {
197                name: "n_estimators".into(),
198                reason: "must be at least 1".into(),
199            });
200        }
201        if self.learning_rate <= 0.0 {
202            return Err(FerroError::InvalidParameter {
203                name: "learning_rate".into(),
204                reason: "must be positive".into(),
205            });
206        }
207
208        // Determine unique classes.
209        let mut classes: Vec<usize> = y.iter().copied().collect();
210        classes.sort_unstable();
211        classes.dedup();
212        let n_classes = classes.len();
213
214        if n_classes < 2 {
215            return Err(FerroError::InvalidParameter {
216                name: "y".into(),
217                reason: "need at least 2 distinct classes".into(),
218            });
219        }
220
221        let y_mapped: Vec<usize> = y
222            .iter()
223            .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
224            .collect();
225
226        match self.algorithm {
227            AdaBoostAlgorithm::Samme => {
228                self.fit_samme(x, &y_mapped, n_samples, n_features, n_classes, &classes)
229            }
230            AdaBoostAlgorithm::SammeR => {
231                self.fit_samme_r(x, &y_mapped, n_samples, n_features, n_classes, &classes)
232            }
233        }
234    }
235}
236
237impl<F: Float + Send + Sync + 'static> AdaBoostClassifier<F> {
238    /// Fit using the SAMME algorithm (discrete predictions).
239    fn fit_samme(
240        &self,
241        x: &Array2<F>,
242        y_mapped: &[usize],
243        n_samples: usize,
244        n_features: usize,
245        n_classes: usize,
246        classes: &[usize],
247    ) -> Result<FittedAdaBoostClassifier<F>, FerroError> {
248        let lr = F::from(self.learning_rate).unwrap();
249        let n_f = F::from(n_samples).unwrap();
250        let eps = F::from(1e-10).unwrap();
251
252        // Initialize sample weights uniformly.
253        let mut weights = vec![F::one() / n_f; n_samples];
254
255        let all_features: Vec<usize> = (0..n_features).collect();
256        let stump_params = decision_tree::TreeParams {
257            max_depth: Some(1),
258            min_samples_split: 2,
259            min_samples_leaf: 1,
260        };
261
262        let mut estimators = Vec::with_capacity(self.n_estimators);
263        let mut estimator_weights = Vec::with_capacity(self.n_estimators);
264
265        for _ in 0..self.n_estimators {
266            // Build weighted sample indices: replicate indices proportional to weight.
267            let indices = resample_weighted(&weights, n_samples);
268
269            let tree = build_classification_tree_with_feature_subset(
270                x,
271                y_mapped,
272                n_classes,
273                &indices,
274                &all_features,
275                &stump_params,
276                ClassificationCriterion::Gini,
277            );
278
279            // Compute predictions and weighted error.
280            let mut weighted_error = F::zero();
281            let mut preds = vec![0usize; n_samples];
282            for i in 0..n_samples {
283                let row = x.row(i);
284                let leaf_idx = decision_tree::traverse(&tree, &row);
285                if let Node::Leaf { value, .. } = tree[leaf_idx] {
286                    preds[i] = value.to_f64().map(|f| f.round() as usize).unwrap_or(0);
287                }
288                if preds[i] != y_mapped[i] {
289                    weighted_error = weighted_error + weights[i];
290                }
291            }
292
293            // Normalise error.
294            let weight_sum: F = weights.iter().copied().fold(F::zero(), |a, b| a + b);
295            let err = if weight_sum > F::zero() {
296                weighted_error / weight_sum
297            } else {
298                F::from(0.5).unwrap()
299            };
300
301            // If error is too high or zero, stop or skip.
302            if err >= F::one() - F::one() / F::from(n_classes).unwrap() {
303                // Error too high; stop boosting.
304                if estimators.is_empty() {
305                    // Keep at least one estimator.
306                    estimators.push(tree);
307                    estimator_weights.push(F::one());
308                }
309                break;
310            }
311
312            // Estimator weight: SAMME formula.
313            let alpha = lr * ((F::one() - err).max(eps) / err.max(eps)).ln()
314                + lr * (F::from(n_classes - 1).unwrap()).ln();
315
316            // Update sample weights.
317            for i in 0..n_samples {
318                if preds[i] != y_mapped[i] {
319                    weights[i] = weights[i] * alpha.exp();
320                }
321            }
322
323            // Normalise weights.
324            let new_sum: F = weights.iter().copied().fold(F::zero(), |a, b| a + b);
325            if new_sum > F::zero() {
326                for w in &mut weights {
327                    *w = *w / new_sum;
328                }
329            }
330
331            estimators.push(tree);
332            estimator_weights.push(alpha);
333        }
334
335        Ok(FittedAdaBoostClassifier {
336            classes: classes.to_vec(),
337            estimators,
338            estimator_weights,
339            n_features,
340            n_classes,
341            algorithm: AdaBoostAlgorithm::Samme,
342        })
343    }
344
345    /// Fit using the SAMME.R algorithm (real-valued / probability-based).
346    fn fit_samme_r(
347        &self,
348        x: &Array2<F>,
349        y_mapped: &[usize],
350        n_samples: usize,
351        n_features: usize,
352        n_classes: usize,
353        classes: &[usize],
354    ) -> Result<FittedAdaBoostClassifier<F>, FerroError> {
355        let lr = F::from(self.learning_rate).unwrap();
356        let n_f = F::from(n_samples).unwrap();
357        let eps = F::from(1e-10).unwrap();
358        let k_f = F::from(n_classes).unwrap();
359
360        // Initialize sample weights uniformly.
361        let mut weights = vec![F::one() / n_f; n_samples];
362
363        let all_features: Vec<usize> = (0..n_features).collect();
364        let stump_params = decision_tree::TreeParams {
365            max_depth: Some(1),
366            min_samples_split: 2,
367            min_samples_leaf: 1,
368        };
369
370        let mut estimators = Vec::with_capacity(self.n_estimators);
371        let mut estimator_weights = Vec::with_capacity(self.n_estimators);
372
373        for _ in 0..self.n_estimators {
374            let indices = resample_weighted(&weights, n_samples);
375
376            let tree = build_classification_tree_with_feature_subset(
377                x,
378                y_mapped,
379                n_classes,
380                &indices,
381                &all_features,
382                &stump_params,
383                ClassificationCriterion::Gini,
384            );
385
386            // Get class probability estimates for each sample.
387            let mut proba = vec![vec![F::zero(); n_classes]; n_samples];
388            for (i, proba_row) in proba.iter_mut().enumerate() {
389                let row = x.row(i);
390                let leaf_idx = decision_tree::traverse(&tree, &row);
391                if let Node::Leaf {
392                    class_distribution: Some(ref dist),
393                    ..
394                } = tree[leaf_idx]
395                {
396                    for (k, &p) in dist.iter().enumerate() {
397                        proba_row[k] = p.max(eps);
398                    }
399                } else {
400                    // Fallback: uniform.
401                    for val in proba_row.iter_mut() {
402                        *val = F::one() / k_f;
403                    }
404                }
405                // Normalise.
406                let row_sum: F = proba_row.iter().copied().fold(F::zero(), |a, b| a + b);
407                if row_sum > F::zero() {
408                    for val in proba_row.iter_mut() {
409                        *val = *val / row_sum;
410                    }
411                }
412            }
413
414            // SAMME.R weight update: based on log-probability.
415            // h_k(x) = (K-1) * (log(p_k(x)) - (1/K) * sum_j log(p_j(x)))
416            // Then update: w_i *= exp(-(K-1)/K * lr * sum_k y_{ik} * log(p_k(x)))
417            // Simplified: w_i *= exp(-lr * (K-1)/K * log(p_{y_i}(x)))
418            let factor = lr * (k_f - F::one()) / k_f;
419            let mut any_update = false;
420
421            for i in 0..n_samples {
422                let p_correct = proba[i][y_mapped[i]].max(eps);
423                let exponent = -factor * p_correct.ln();
424                weights[i] = weights[i] * exponent.exp();
425                if exponent.abs() > eps {
426                    any_update = true;
427                }
428            }
429
430            // Normalise weights.
431            let new_sum: F = weights.iter().copied().fold(F::zero(), |a, b| a + b);
432            if new_sum > F::zero() {
433                for w in &mut weights {
434                    *w = *w / new_sum;
435                }
436            }
437
438            estimators.push(tree);
439            estimator_weights.push(F::one()); // SAMME.R uses equal weight; prediction uses probabilities.
440
441            if !any_update {
442                break;
443            }
444        }
445
446        Ok(FittedAdaBoostClassifier {
447            classes: classes.to_vec(),
448            estimators,
449            estimator_weights,
450            n_features,
451            n_classes,
452            algorithm: AdaBoostAlgorithm::SammeR,
453        })
454    }
455}
456
457impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedAdaBoostClassifier<F> {
458    type Output = Array1<usize>;
459    type Error = FerroError;
460
461    /// Predict class labels.
462    ///
463    /// - **SAMME**: weighted majority vote using estimator weights.
464    /// - **SAMME.R**: weighted average of log-probabilities.
465    ///
466    /// # Errors
467    ///
468    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
469    /// not match the fitted model.
470    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
471        if x.ncols() != self.n_features {
472            return Err(FerroError::ShapeMismatch {
473                expected: vec![self.n_features],
474                actual: vec![x.ncols()],
475                context: "number of features must match fitted model".into(),
476            });
477        }
478
479        let n_samples = x.nrows();
480
481        match self.algorithm {
482            AdaBoostAlgorithm::Samme => self.predict_samme(x, n_samples),
483            AdaBoostAlgorithm::SammeR => self.predict_samme_r(x, n_samples),
484        }
485    }
486}
487
488impl<F: Float + Send + Sync + 'static> FittedAdaBoostClassifier<F> {
489    /// Predict using SAMME (weighted majority vote).
490    fn predict_samme(&self, x: &Array2<F>, n_samples: usize) -> Result<Array1<usize>, FerroError> {
491        let mut predictions = Array1::zeros(n_samples);
492
493        for i in 0..n_samples {
494            let row = x.row(i);
495            let mut class_scores = vec![F::zero(); self.n_classes];
496
497            for (t, tree_nodes) in self.estimators.iter().enumerate() {
498                let leaf_idx = decision_tree::traverse(tree_nodes, &row);
499                if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
500                    let class_idx = value.to_f64().map(|f| f.round() as usize).unwrap_or(0);
501                    if class_idx < self.n_classes {
502                        class_scores[class_idx] =
503                            class_scores[class_idx] + self.estimator_weights[t];
504                    }
505                }
506            }
507
508            let best = class_scores
509                .iter()
510                .enumerate()
511                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
512                .map(|(k, _)| k)
513                .unwrap_or(0);
514            predictions[i] = self.classes[best];
515        }
516
517        Ok(predictions)
518    }
519
520    /// Predict using SAMME.R (weighted probability averaging).
521    fn predict_samme_r(
522        &self,
523        x: &Array2<F>,
524        n_samples: usize,
525    ) -> Result<Array1<usize>, FerroError> {
526        let eps = F::from(1e-10).unwrap();
527        let k_f = F::from(self.n_classes).unwrap();
528        let k_minus_1 = k_f - F::one();
529
530        let mut predictions = Array1::zeros(n_samples);
531
532        for i in 0..n_samples {
533            let row = x.row(i);
534            let mut accumulated = vec![F::zero(); self.n_classes];
535
536            for tree_nodes in &self.estimators {
537                let leaf_idx = decision_tree::traverse(tree_nodes, &row);
538                if let Node::Leaf {
539                    class_distribution: Some(ref dist),
540                    ..
541                } = tree_nodes[leaf_idx]
542                {
543                    // h_k(x) = (K-1) * (log(p_k) - mean(log(p_j)))
544                    let log_probs: Vec<F> = dist.iter().map(|&p| p.max(eps).ln()).collect();
545                    let mean_log: F = log_probs.iter().copied().fold(F::zero(), |a, b| a + b) / k_f;
546
547                    for k in 0..self.n_classes {
548                        accumulated[k] = accumulated[k] + k_minus_1 * (log_probs[k] - mean_log);
549                    }
550                } else {
551                    // Leaf without distribution: predict from value.
552                    if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
553                        let class_idx = value.to_f64().map(|f| f.round() as usize).unwrap_or(0);
554                        if class_idx < self.n_classes {
555                            accumulated[class_idx] = accumulated[class_idx] + F::one();
556                        }
557                    }
558                }
559            }
560
561            let best = accumulated
562                .iter()
563                .enumerate()
564                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
565                .map(|(k, _)| k)
566                .unwrap_or(0);
567            predictions[i] = self.classes[best];
568        }
569
570        Ok(predictions)
571    }
572}
573
574impl<F: Float + Send + Sync + 'static> HasClasses for FittedAdaBoostClassifier<F> {
575    fn classes(&self) -> &[usize] {
576        &self.classes
577    }
578
579    fn n_classes(&self) -> usize {
580        self.classes.len()
581    }
582}
583
584// Pipeline integration.
585impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
586    for AdaBoostClassifier<F>
587{
588    fn fit_pipeline(
589        &self,
590        x: &Array2<F>,
591        y: &Array1<F>,
592    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
593        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
594        let fitted = self.fit(x, &y_usize)?;
595        Ok(Box::new(FittedAdaBoostPipelineAdapter(fitted)))
596    }
597}
598
599/// Pipeline adapter for `FittedAdaBoostClassifier<F>`.
600struct FittedAdaBoostPipelineAdapter<F: Float + Send + Sync + 'static>(FittedAdaBoostClassifier<F>);
601
602impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
603    for FittedAdaBoostPipelineAdapter<F>
604{
605    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
606        let preds = self.0.predict(x)?;
607        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or(F::nan())))
608    }
609}
610
611// ---------------------------------------------------------------------------
612// Internal helpers
613// ---------------------------------------------------------------------------
614
615/// Resample indices proportional to weights (weighted bootstrap).
616///
617/// Uses a systematic resampling approach: the cumulative weight distribution
618/// determines which original indices appear in the resampled set.
619fn resample_weighted<F: Float>(weights: &[F], n: usize) -> Vec<usize> {
620    if weights.is_empty() {
621        return Vec::new();
622    }
623
624    // Build cumulative distribution.
625    let mut cumsum = Vec::with_capacity(weights.len());
626    let mut running = F::zero();
627    for &w in weights {
628        running = running + w;
629        cumsum.push(running);
630    }
631
632    // Normalise (in case weights don't sum to 1).
633    let total = running;
634    if total <= F::zero() {
635        return (0..n).collect();
636    }
637
638    let mut indices = Vec::with_capacity(n);
639    let step = total / F::from(n).unwrap();
640    let mut threshold = step / F::from(2.0).unwrap(); // Start in the middle of the first bin.
641    let mut j = 0;
642
643    for _ in 0..n {
644        while j < cumsum.len() - 1 && cumsum[j] < threshold {
645            j += 1;
646        }
647        indices.push(j);
648        threshold = threshold + step;
649    }
650
651    indices
652}
653
654// ---------------------------------------------------------------------------
655// Tests
656// ---------------------------------------------------------------------------
657
658#[cfg(test)]
659mod tests {
660    use super::*;
661    use ndarray::array;
662
663    // -- SAMME.R tests --
664
665    #[test]
666    fn test_adaboost_sammer_binary_simple() {
667        let x = Array2::from_shape_vec(
668            (8, 2),
669            vec![
670                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
671            ],
672        )
673        .unwrap();
674        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
675
676        let model = AdaBoostClassifier::<f64>::new()
677            .with_n_estimators(50)
678            .with_random_state(42);
679        let fitted = model.fit(&x, &y).unwrap();
680        let preds = fitted.predict(&x).unwrap();
681
682        assert_eq!(preds.len(), 8);
683        for i in 0..4 {
684            assert_eq!(preds[i], 0);
685        }
686        for i in 4..8 {
687            assert_eq!(preds[i], 1);
688        }
689    }
690
691    #[test]
692    fn test_adaboost_sammer_multiclass() {
693        let x = Array2::from_shape_vec((9, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
694            .unwrap();
695        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];
696
697        let model = AdaBoostClassifier::<f64>::new()
698            .with_n_estimators(50)
699            .with_random_state(42);
700        let fitted = model.fit(&x, &y).unwrap();
701        let preds = fitted.predict(&x).unwrap();
702
703        assert_eq!(preds.len(), 9);
704        let correct = preds.iter().zip(y.iter()).filter(|(p, t)| p == t).count();
705        assert!(
706            correct >= 5,
707            "Expected at least 5/9 correct, got {}/9",
708            correct
709        );
710    }
711
712    // -- SAMME tests --
713
714    #[test]
715    fn test_adaboost_samme_binary_simple() {
716        let x = Array2::from_shape_vec(
717            (8, 2),
718            vec![
719                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
720            ],
721        )
722        .unwrap();
723        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
724
725        let model = AdaBoostClassifier::<f64>::new()
726            .with_n_estimators(50)
727            .with_algorithm(AdaBoostAlgorithm::Samme)
728            .with_random_state(42);
729        let fitted = model.fit(&x, &y).unwrap();
730        let preds = fitted.predict(&x).unwrap();
731
732        assert_eq!(preds.len(), 8);
733        for i in 0..4 {
734            assert_eq!(preds[i], 0);
735        }
736        for i in 4..8 {
737            assert_eq!(preds[i], 1);
738        }
739    }
740
741    #[test]
742    fn test_adaboost_samme_multiclass() {
743        let x = Array2::from_shape_vec((9, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
744            .unwrap();
745        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];
746
747        let model = AdaBoostClassifier::<f64>::new()
748            .with_n_estimators(50)
749            .with_algorithm(AdaBoostAlgorithm::Samme)
750            .with_random_state(42);
751        let fitted = model.fit(&x, &y).unwrap();
752        let preds = fitted.predict(&x).unwrap();
753
754        assert_eq!(preds.len(), 9);
755        let correct = preds.iter().zip(y.iter()).filter(|(p, t)| p == t).count();
756        assert!(
757            correct >= 5,
758            "Expected at least 5/9 correct for SAMME multiclass, got {}/9",
759            correct
760        );
761    }
762
763    // -- Common tests --
764
765    #[test]
766    fn test_adaboost_has_classes() {
767        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
768        let y = array![0, 1, 2, 0, 1, 2];
769
770        let model = AdaBoostClassifier::<f64>::new()
771            .with_n_estimators(5)
772            .with_random_state(0);
773        let fitted = model.fit(&x, &y).unwrap();
774
775        assert_eq!(fitted.classes(), &[0, 1, 2]);
776        assert_eq!(fitted.n_classes(), 3);
777    }
778
779    #[test]
780    fn test_adaboost_reproducibility() {
781        let x = Array2::from_shape_vec(
782            (8, 2),
783            vec![
784                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
785            ],
786        )
787        .unwrap();
788        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
789
790        let model = AdaBoostClassifier::<f64>::new()
791            .with_n_estimators(10)
792            .with_random_state(42);
793
794        let fitted1 = model.fit(&x, &y).unwrap();
795        let fitted2 = model.fit(&x, &y).unwrap();
796
797        let preds1 = fitted1.predict(&x).unwrap();
798        let preds2 = fitted2.predict(&x).unwrap();
799        assert_eq!(preds1, preds2);
800    }
801
802    #[test]
803    fn test_adaboost_shape_mismatch_fit() {
804        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
805        let y = array![0, 1];
806
807        let model = AdaBoostClassifier::<f64>::new().with_n_estimators(5);
808        assert!(model.fit(&x, &y).is_err());
809    }
810
811    #[test]
812    fn test_adaboost_shape_mismatch_predict() {
813        let x =
814            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
815        let y = array![0, 0, 1, 1];
816
817        let model = AdaBoostClassifier::<f64>::new()
818            .with_n_estimators(5)
819            .with_random_state(0);
820        let fitted = model.fit(&x, &y).unwrap();
821
822        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
823        assert!(fitted.predict(&x_bad).is_err());
824    }
825
826    #[test]
827    fn test_adaboost_empty_data() {
828        let x = Array2::<f64>::zeros((0, 2));
829        let y = Array1::<usize>::zeros(0);
830
831        let model = AdaBoostClassifier::<f64>::new().with_n_estimators(5);
832        assert!(model.fit(&x, &y).is_err());
833    }
834
835    #[test]
836    fn test_adaboost_single_class() {
837        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
838        let y = array![0, 0, 0];
839
840        let model = AdaBoostClassifier::<f64>::new().with_n_estimators(5);
841        assert!(model.fit(&x, &y).is_err());
842    }
843
844    #[test]
845    fn test_adaboost_zero_estimators() {
846        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
847        let y = array![0, 0, 1, 1];
848
849        let model = AdaBoostClassifier::<f64>::new().with_n_estimators(0);
850        assert!(model.fit(&x, &y).is_err());
851    }
852
853    #[test]
854    fn test_adaboost_invalid_learning_rate() {
855        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
856        let y = array![0, 0, 1, 1];
857
858        let model = AdaBoostClassifier::<f64>::new()
859            .with_n_estimators(5)
860            .with_learning_rate(0.0);
861        assert!(model.fit(&x, &y).is_err());
862    }
863
864    #[test]
865    fn test_adaboost_pipeline_integration() {
866        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
867        let y = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);
868
869        let model = AdaBoostClassifier::<f64>::new()
870            .with_n_estimators(10)
871            .with_random_state(42);
872        let fitted = model.fit_pipeline(&x, &y).unwrap();
873        let preds = fitted.predict_pipeline(&x).unwrap();
874        assert_eq!(preds.len(), 6);
875    }
876
877    #[test]
878    fn test_adaboost_f32_support() {
879        let x = Array2::from_shape_vec((6, 1), vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
880        let y = array![0, 0, 0, 1, 1, 1];
881
882        let model = AdaBoostClassifier::<f32>::new()
883            .with_n_estimators(10)
884            .with_random_state(42);
885        let fitted = model.fit(&x, &y).unwrap();
886        let preds = fitted.predict(&x).unwrap();
887        assert_eq!(preds.len(), 6);
888    }
889
890    #[test]
891    fn test_adaboost_default_trait() {
892        let model = AdaBoostClassifier::<f64>::default();
893        assert_eq!(model.n_estimators, 50);
894        assert!((model.learning_rate - 1.0).abs() < 1e-10);
895        assert_eq!(model.algorithm, AdaBoostAlgorithm::SammeR);
896    }
897
898    #[test]
899    fn test_adaboost_non_contiguous_labels() {
900        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
901        let y = array![10, 10, 10, 20, 20, 20];
902
903        let model = AdaBoostClassifier::<f64>::new()
904            .with_n_estimators(20)
905            .with_random_state(42);
906        let fitted = model.fit(&x, &y).unwrap();
907        let preds = fitted.predict(&x).unwrap();
908
909        assert_eq!(preds.len(), 6);
910        for &p in preds.iter() {
911            assert!(p == 10 || p == 20);
912        }
913    }
914
915    #[test]
916    fn test_adaboost_sammer_learning_rate_effect() {
917        let x =
918            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
919        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
920
921        // Low learning rate should still work (just slower convergence).
922        let model = AdaBoostClassifier::<f64>::new()
923            .with_n_estimators(50)
924            .with_learning_rate(0.1)
925            .with_random_state(42);
926        let fitted = model.fit(&x, &y).unwrap();
927        let preds = fitted.predict(&x).unwrap();
928        assert_eq!(preds.len(), 8);
929    }
930
931    #[test]
932    fn test_adaboost_samme_learning_rate_effect() {
933        let x =
934            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
935        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
936
937        let model = AdaBoostClassifier::<f64>::new()
938            .with_n_estimators(50)
939            .with_algorithm(AdaBoostAlgorithm::Samme)
940            .with_learning_rate(0.5)
941            .with_random_state(42);
942        let fitted = model.fit(&x, &y).unwrap();
943        let preds = fitted.predict(&x).unwrap();
944        assert_eq!(preds.len(), 8);
945    }
946
947    #[test]
948    fn test_adaboost_many_features() {
949        // 4 features, only first one is informative.
950        let x = Array2::from_shape_vec(
951            (8, 4),
952            vec![
953                1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0,
954                5.0, 0.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0, 7.0, 0.0, 0.0, 0.0, 8.0, 0.0, 0.0, 0.0,
955            ],
956        )
957        .unwrap();
958        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
959
960        let model = AdaBoostClassifier::<f64>::new()
961            .with_n_estimators(20)
962            .with_random_state(42);
963        let fitted = model.fit(&x, &y).unwrap();
964        let preds = fitted.predict(&x).unwrap();
965        assert_eq!(preds.len(), 8);
966    }
967
968    #[test]
969    fn test_adaboost_4_classes() {
970        let x = Array2::from_shape_vec(
971            (12, 1),
972            vec![
973                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
974            ],
975        )
976        .unwrap();
977        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3];
978
979        let model = AdaBoostClassifier::<f64>::new()
980            .with_n_estimators(50)
981            .with_random_state(42);
982        let fitted = model.fit(&x, &y).unwrap();
983        let preds = fitted.predict(&x).unwrap();
984
985        assert_eq!(preds.len(), 12);
986        assert_eq!(fitted.n_classes(), 4);
987    }
988
989    // -- Resample helper tests --
990
991    #[test]
992    fn test_resample_weighted_uniform() {
993        let weights = vec![0.25, 0.25, 0.25, 0.25];
994        let indices = resample_weighted(&weights, 4);
995        assert_eq!(indices.len(), 4);
996        // With uniform weights, each index should appear once.
997        for i in 0..4 {
998            assert_eq!(indices[i], i);
999        }
1000    }
1001
1002    #[test]
1003    fn test_resample_weighted_skewed() {
1004        let weights = vec![0.0, 0.0, 0.0, 1.0];
1005        let indices = resample_weighted(&weights, 4);
1006        assert_eq!(indices.len(), 4);
1007        // All weight on last index.
1008        for &idx in &indices {
1009            assert_eq!(idx, 3);
1010        }
1011    }
1012
1013    #[test]
1014    fn test_resample_weighted_empty() {
1015        let weights: Vec<f64> = Vec::new();
1016        let indices = resample_weighted(&weights, 0);
1017        assert!(indices.is_empty());
1018    }
1019
1020    #[test]
1021    fn test_adaboost_sammer_single_estimator() {
1022        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1023        let y = array![0, 0, 0, 1, 1, 1];
1024
1025        let model = AdaBoostClassifier::<f64>::new()
1026            .with_n_estimators(1)
1027            .with_random_state(42);
1028        let fitted = model.fit(&x, &y).unwrap();
1029        let preds = fitted.predict(&x).unwrap();
1030        assert_eq!(preds.len(), 6);
1031    }
1032
1033    #[test]
1034    fn test_adaboost_samme_single_estimator() {
1035        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1036        let y = array![0, 0, 0, 1, 1, 1];
1037
1038        let model = AdaBoostClassifier::<f64>::new()
1039            .with_n_estimators(1)
1040            .with_algorithm(AdaBoostAlgorithm::Samme)
1041            .with_random_state(42);
1042        let fitted = model.fit(&x, &y).unwrap();
1043        let preds = fitted.predict(&x).unwrap();
1044        assert_eq!(preds.len(), 6);
1045    }
1046
1047    #[test]
1048    fn test_adaboost_negative_learning_rate() {
1049        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
1050        let y = array![0, 0, 1, 1];
1051
1052        let model = AdaBoostClassifier::<f64>::new()
1053            .with_n_estimators(5)
1054            .with_learning_rate(-0.1);
1055        assert!(model.fit(&x, &y).is_err());
1056    }
1057}