// ferrolearn_tree/voting.rs
//! Voting ensemble classifiers and regressors.
//!
//! This module provides [`VotingClassifier`] and [`VotingRegressor`], which
//! train multiple decision trees with different hyperparameter configurations
//! on the full dataset and aggregate their predictions (majority vote for
//! classification, averaging for regression).
//!
//! Unlike [`RandomForestClassifier`](crate::RandomForestClassifier), voting
//! ensembles do **not** use bootstrap sampling — each tree sees the entire
//! dataset. Diversity comes from varying the tree hyperparameters.
//!
//! # Examples
//!
//! ```
//! use ferrolearn_tree::VotingClassifier;
//! use ferrolearn_core::{Fit, Predict};
//! use ndarray::{array, Array1, Array2};
//!
//! let x = Array2::from_shape_vec((8, 2), vec![
//!     1.0, 2.0,  2.0, 3.0,  3.0, 3.0,  4.0, 4.0,
//!     5.0, 6.0,  6.0, 7.0,  7.0, 8.0,  8.0, 9.0,
//! ]).unwrap();
//! let y = array![0, 0, 0, 0, 1, 1, 1, 1];
//!
//! let model = VotingClassifier::<f64>::new()
//!     .with_max_depths(vec![Some(2), Some(3), Some(5), None]);
//! let fitted = model.fit(&x, &y).unwrap();
//! let preds = fitted.predict(&x).unwrap();
//! assert_eq!(preds.len(), 8);
//! ```
30
31use ferrolearn_core::error::FerroError;
32use ferrolearn_core::introspection::HasClasses;
33use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
34use ferrolearn_core::traits::{Fit, Predict};
35use ndarray::{Array1, Array2};
36use num_traits::{Float, FromPrimitive, ToPrimitive};
37use serde::{Deserialize, Serialize};
38
39use crate::decision_tree::{
40    ClassificationCriterion, DecisionTreeClassifier, DecisionTreeRegressor,
41    FittedDecisionTreeClassifier, FittedDecisionTreeRegressor,
42};
43
44// ---------------------------------------------------------------------------
45// VotingClassifier
46// ---------------------------------------------------------------------------
47
/// Voting ensemble classifier.
///
/// Trains multiple decision tree classifiers with different hyperparameter
/// configurations on the full dataset. Final predictions are made by majority
/// vote across all trees.
///
/// Diversity is introduced by varying `max_depth` across the ensemble members.
/// If no explicit depths are provided, a default set of depths is used.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VotingClassifier<F> {
    /// Maximum depth settings for each tree in the ensemble.
    /// Each entry produces one decision tree; `None` means no depth limit.
    pub max_depths: Vec<Option<usize>>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required in a leaf node.
    pub min_samples_leaf: usize,
    /// Splitting criterion shared by all trees in the ensemble.
    pub criterion: ClassificationCriterion,
    // Ties the generic float type `F` to the struct without storing a value.
    _marker: std::marker::PhantomData<F>,
}
73
74impl<F: Float> VotingClassifier<F> {
75    /// Create a new `VotingClassifier` with default settings.
76    ///
77    /// Defaults: `max_depths = [Some(2), Some(4), Some(6), None]`,
78    /// `min_samples_split = 2`, `min_samples_leaf = 1`, `criterion = Gini`.
79    #[must_use]
80    pub fn new() -> Self {
81        Self {
82            max_depths: vec![Some(2), Some(4), Some(6), None],
83            min_samples_split: 2,
84            min_samples_leaf: 1,
85            criterion: ClassificationCriterion::Gini,
86            _marker: std::marker::PhantomData,
87        }
88    }
89
90    /// Set the maximum depth settings for each ensemble member.
91    ///
92    /// Each entry in the vector produces one decision tree.
93    #[must_use]
94    pub fn with_max_depths(mut self, max_depths: Vec<Option<usize>>) -> Self {
95        self.max_depths = max_depths;
96        self
97    }
98
99    /// Set the minimum number of samples required to split a node.
100    #[must_use]
101    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
102        self.min_samples_split = min_samples_split;
103        self
104    }
105
106    /// Set the minimum number of samples required in a leaf node.
107    #[must_use]
108    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
109        self.min_samples_leaf = min_samples_leaf;
110        self
111    }
112
113    /// Set the splitting criterion for all trees.
114    #[must_use]
115    pub fn with_criterion(mut self, criterion: ClassificationCriterion) -> Self {
116        self.criterion = criterion;
117        self
118    }
119}
120
impl<F: Float> Default for VotingClassifier<F> {
    /// Equivalent to [`VotingClassifier::new`].
    fn default() -> Self {
        Self::new()
    }
}
126
127// ---------------------------------------------------------------------------
128// FittedVotingClassifier
129// ---------------------------------------------------------------------------
130
/// A fitted voting ensemble classifier.
///
/// Stores the individually fitted decision trees and aggregates their
/// predictions by majority vote.
#[derive(Debug, Clone)]
pub struct FittedVotingClassifier<F> {
    /// The fitted decision tree classifiers, one per configured depth.
    trees: Vec<FittedDecisionTreeClassifier<F>>,
    /// Sorted unique class labels observed in the training targets.
    classes: Vec<usize>,
}
142
143impl<F: Float + Send + Sync + 'static> FittedVotingClassifier<F> {
144    /// Returns the number of trees in the ensemble.
145    #[must_use]
146    pub fn n_estimators(&self) -> usize {
147        self.trees.len()
148    }
149
150    /// Mean accuracy on the given test data and labels.
151    /// Equivalent to sklearn's `ClassifierMixin.score`.
152    ///
153    /// # Errors
154    ///
155    /// Returns [`FerroError::ShapeMismatch`] if `x.nrows() != y.len()` or
156    /// the feature count does not match the training data.
157    pub fn score(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<F, FerroError> {
158        if x.nrows() != y.len() {
159            return Err(FerroError::ShapeMismatch {
160                expected: vec![x.nrows()],
161                actual: vec![y.len()],
162                context: "y length must match number of samples in X".into(),
163            });
164        }
165        let preds = self.predict(x)?;
166        Ok(crate::mean_accuracy(&preds, y))
167    }
168
169    /// Predict class probabilities by averaging the per-tree
170    /// `predict_proba` outputs (sklearn's `voting='soft'` semantics).
171    ///
172    /// Returns shape `(n_samples, n_classes)`. Each row sums to 1.
173    ///
174    /// # Errors
175    ///
176    /// Returns [`FerroError::ShapeMismatch`] if the number of features
177    /// does not match the fitted model.
178    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
179        let n_samples = x.nrows();
180        let n_classes = self.classes.len();
181        let n_trees_f = F::from(self.trees.len()).unwrap();
182        let mut proba = Array2::<F>::zeros((n_samples, n_classes));
183
184        for tree in &self.trees {
185            let tree_proba = tree.predict_proba(x)?;
186            for i in 0..n_samples {
187                for j in 0..n_classes {
188                    proba[[i, j]] = proba[[i, j]] + tree_proba[[i, j]];
189                }
190            }
191        }
192        for i in 0..n_samples {
193            for j in 0..n_classes {
194                proba[[i, j]] = proba[[i, j]] / n_trees_f;
195            }
196        }
197        Ok(proba)
198    }
199
200    /// Element-wise log of [`predict_proba`](Self::predict_proba). Mirrors
201    /// sklearn's `ClassifierMixin.predict_log_proba`.
202    ///
203    /// # Errors
204    ///
205    /// Forwards any error from [`predict_proba`](Self::predict_proba).
206    pub fn predict_log_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
207        let proba = self.predict_proba(x)?;
208        Ok(crate::log_proba(&proba))
209    }
210}
211
212impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for VotingClassifier<F> {
213    type Fitted = FittedVotingClassifier<F>;
214    type Error = FerroError;
215
216    /// Fit the voting classifier by training each decision tree on the full dataset.
217    ///
218    /// # Errors
219    ///
220    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
221    /// numbers of samples.
222    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
223    /// Returns [`FerroError::InvalidParameter`] if configuration is invalid.
224    fn fit(
225        &self,
226        x: &Array2<F>,
227        y: &Array1<usize>,
228    ) -> Result<FittedVotingClassifier<F>, FerroError> {
229        let n_samples = x.nrows();
230
231        if n_samples != y.len() {
232            return Err(FerroError::ShapeMismatch {
233                expected: vec![n_samples],
234                actual: vec![y.len()],
235                context: "y length must match number of samples in X".into(),
236            });
237        }
238        if n_samples == 0 {
239            return Err(FerroError::InsufficientSamples {
240                required: 1,
241                actual: 0,
242                context: "VotingClassifier requires at least one sample".into(),
243            });
244        }
245        if self.max_depths.is_empty() {
246            return Err(FerroError::InvalidParameter {
247                name: "max_depths".into(),
248                reason: "must contain at least one entry".into(),
249            });
250        }
251
252        // Determine unique classes from the full dataset.
253        let mut classes: Vec<usize> = y.iter().copied().collect();
254        classes.sort_unstable();
255        classes.dedup();
256
257        let mut trees = Vec::with_capacity(self.max_depths.len());
258        for &max_depth in &self.max_depths {
259            let tree = DecisionTreeClassifier::<F>::new()
260                .with_max_depth(max_depth)
261                .with_min_samples_split(self.min_samples_split)
262                .with_min_samples_leaf(self.min_samples_leaf)
263                .with_criterion(self.criterion);
264            let fitted = tree.fit(x, y)?;
265            trees.push(fitted);
266        }
267
268        Ok(FittedVotingClassifier { trees, classes })
269    }
270}
271
272impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedVotingClassifier<F> {
273    type Output = Array1<usize>;
274    type Error = FerroError;
275
276    /// Predict class labels by majority vote across all trees.
277    ///
278    /// # Errors
279    ///
280    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
281    /// not match the fitted model.
282    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
283        let n_samples = x.nrows();
284        let n_classes = self.classes.len();
285
286        // Collect predictions from all trees.
287        let all_preds: Vec<Array1<usize>> = self
288            .trees
289            .iter()
290            .map(|tree| tree.predict(x))
291            .collect::<Result<Vec<_>, _>>()?;
292
293        let mut predictions = Array1::zeros(n_samples);
294        for i in 0..n_samples {
295            let mut votes = vec![0usize; n_classes];
296            for tree_preds in &all_preds {
297                let pred = tree_preds[i];
298                if let Some(class_idx) = self.classes.iter().position(|&c| c == pred) {
299                    votes[class_idx] += 1;
300                }
301            }
302            let winner = votes
303                .iter()
304                .enumerate()
305                .max_by_key(|&(_, &count)| count)
306                .map_or(0, |(idx, _)| idx);
307            predictions[i] = self.classes[winner];
308        }
309
310        Ok(predictions)
311    }
312}
313
impl<F: Float + Send + Sync + 'static> HasClasses for FittedVotingClassifier<F> {
    /// Sorted unique class labels seen during fitting.
    fn classes(&self) -> &[usize] {
        &self.classes
    }

    /// Number of distinct classes seen during fitting.
    fn n_classes(&self) -> usize {
        self.classes.len()
    }
}
323
324// Pipeline integration.
325impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
326    for VotingClassifier<F>
327{
328    fn fit_pipeline(
329        &self,
330        x: &Array2<F>,
331        y: &Array1<F>,
332    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
333        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
334        let fitted = self.fit(x, &y_usize)?;
335        Ok(Box::new(FittedVotingClassifierPipelineAdapter(fitted)))
336    }
337}
338
/// Pipeline adapter for `FittedVotingClassifier<F>`.
///
/// Newtype wrapper that lets the classifier participate in float-typed
/// pipelines by re-encoding its `usize` label output as `F`.
struct FittedVotingClassifierPipelineAdapter<F: Float + Send + Sync + 'static>(
    FittedVotingClassifier<F>,
);
343
impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedVotingClassifierPipelineAdapter<F>
{
    /// Predict labels and re-encode them as floats for the pipeline.
    ///
    /// NOTE(review): a label that `F::from_usize` cannot represent becomes
    /// NaN — confirm downstream pipeline stages tolerate NaN outputs.
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        let preds = self.0.predict(x)?;
        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or_else(F::nan)))
    }
}
352
353// ---------------------------------------------------------------------------
354// VotingRegressor
355// ---------------------------------------------------------------------------
356
/// Voting ensemble regressor.
///
/// Trains multiple decision tree regressors with different hyperparameter
/// configurations on the full dataset. Final predictions are the average
/// across all trees.
///
/// Diversity is introduced by varying `max_depth` across the ensemble members.
/// If no explicit depths are provided, a default set of depths is used.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
///
/// # Examples
///
/// ```
/// use ferrolearn_tree::VotingRegressor;
/// use ferrolearn_core::{Fit, Predict};
/// use ndarray::{array, Array1, Array2};
///
/// let x = Array2::from_shape_vec((6, 2), vec![
///     1.0, 2.0,  2.0, 3.0,  3.0, 3.0,
///     5.0, 6.0,  6.0, 7.0,  7.0, 8.0,
/// ]).unwrap();
/// let y = array![1.0, 2.0, 3.0, 5.0, 6.0, 7.0];
///
/// let model = VotingRegressor::<f64>::new()
///     .with_max_depths(vec![Some(2), Some(4), None]);
/// let fitted = model.fit(&x, &y).unwrap();
/// let preds = fitted.predict(&x).unwrap();
/// assert_eq!(preds.len(), 6);
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VotingRegressor<F> {
    /// Maximum depth settings for each tree in the ensemble.
    /// Each entry produces one decision tree; `None` means no depth limit.
    pub max_depths: Vec<Option<usize>>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required in a leaf node.
    pub min_samples_leaf: usize,
    // Ties the generic float type `F` to the struct without storing a value.
    _marker: std::marker::PhantomData<F>,
}
399
400impl<F: Float> VotingRegressor<F> {
401    /// Create a new `VotingRegressor` with default settings.
402    ///
403    /// Defaults: `max_depths = [Some(2), Some(4), Some(6), None]`,
404    /// `min_samples_split = 2`, `min_samples_leaf = 1`.
405    #[must_use]
406    pub fn new() -> Self {
407        Self {
408            max_depths: vec![Some(2), Some(4), Some(6), None],
409            min_samples_split: 2,
410            min_samples_leaf: 1,
411            _marker: std::marker::PhantomData,
412        }
413    }
414
415    /// Set the maximum depth settings for each ensemble member.
416    #[must_use]
417    pub fn with_max_depths(mut self, max_depths: Vec<Option<usize>>) -> Self {
418        self.max_depths = max_depths;
419        self
420    }
421
422    /// Set the minimum number of samples required to split a node.
423    #[must_use]
424    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
425        self.min_samples_split = min_samples_split;
426        self
427    }
428
429    /// Set the minimum number of samples required in a leaf node.
430    #[must_use]
431    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
432        self.min_samples_leaf = min_samples_leaf;
433        self
434    }
435}
436
impl<F: Float> Default for VotingRegressor<F> {
    /// Equivalent to [`VotingRegressor::new`].
    fn default() -> Self {
        Self::new()
    }
}
442
443// ---------------------------------------------------------------------------
444// FittedVotingRegressor
445// ---------------------------------------------------------------------------
446
/// A fitted voting ensemble regressor.
///
/// Stores the individually fitted decision tree regressors and aggregates
/// their predictions by averaging.
#[derive(Debug, Clone)]
pub struct FittedVotingRegressor<F> {
    /// The fitted decision tree regressors, one per configured depth.
    trees: Vec<FittedDecisionTreeRegressor<F>>,
}
456
457impl<F: Float + Send + Sync + 'static> FittedVotingRegressor<F> {
458    /// Returns the number of trees in the ensemble.
459    #[must_use]
460    pub fn n_estimators(&self) -> usize {
461        self.trees.len()
462    }
463
464    /// R² coefficient of determination on the given test data.
465    /// Equivalent to sklearn's `RegressorMixin.score`.
466    ///
467    /// # Errors
468    ///
469    /// Returns [`FerroError::ShapeMismatch`] if `x.nrows() != y.len()` or
470    /// the feature count does not match the training data.
471    pub fn score(&self, x: &Array2<F>, y: &Array1<F>) -> Result<F, FerroError> {
472        if x.nrows() != y.len() {
473            return Err(FerroError::ShapeMismatch {
474                expected: vec![x.nrows()],
475                actual: vec![y.len()],
476                context: "y length must match number of samples in X".into(),
477            });
478        }
479        let preds = self.predict(x)?;
480        Ok(crate::r2_score(&preds, y))
481    }
482}
483
484impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for VotingRegressor<F> {
485    type Fitted = FittedVotingRegressor<F>;
486    type Error = FerroError;
487
488    /// Fit the voting regressor by training each decision tree on the full dataset.
489    ///
490    /// # Errors
491    ///
492    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
493    /// numbers of samples.
494    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
495    /// Returns [`FerroError::InvalidParameter`] if configuration is invalid.
496    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedVotingRegressor<F>, FerroError> {
497        let n_samples = x.nrows();
498
499        if n_samples != y.len() {
500            return Err(FerroError::ShapeMismatch {
501                expected: vec![n_samples],
502                actual: vec![y.len()],
503                context: "y length must match number of samples in X".into(),
504            });
505        }
506        if n_samples == 0 {
507            return Err(FerroError::InsufficientSamples {
508                required: 1,
509                actual: 0,
510                context: "VotingRegressor requires at least one sample".into(),
511            });
512        }
513        if self.max_depths.is_empty() {
514            return Err(FerroError::InvalidParameter {
515                name: "max_depths".into(),
516                reason: "must contain at least one entry".into(),
517            });
518        }
519
520        let mut trees = Vec::with_capacity(self.max_depths.len());
521        for &max_depth in &self.max_depths {
522            let tree = DecisionTreeRegressor::<F>::new()
523                .with_max_depth(max_depth)
524                .with_min_samples_split(self.min_samples_split)
525                .with_min_samples_leaf(self.min_samples_leaf);
526            let fitted = tree.fit(x, y)?;
527            trees.push(fitted);
528        }
529
530        Ok(FittedVotingRegressor { trees })
531    }
532}
533
534impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedVotingRegressor<F> {
535    type Output = Array1<F>;
536    type Error = FerroError;
537
538    /// Predict target values by averaging across all trees.
539    ///
540    /// # Errors
541    ///
542    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
543    /// not match the fitted model.
544    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
545        let n_samples = x.nrows();
546        let n_trees_f = F::from(self.trees.len()).unwrap();
547
548        let all_preds: Vec<Array1<F>> = self
549            .trees
550            .iter()
551            .map(|tree| tree.predict(x))
552            .collect::<Result<Vec<_>, _>>()?;
553
554        let mut predictions = Array1::zeros(n_samples);
555        for i in 0..n_samples {
556            let mut sum = F::zero();
557            for tree_preds in &all_preds {
558                sum = sum + tree_preds[i];
559            }
560            predictions[i] = sum / n_trees_f;
561        }
562
563        Ok(predictions)
564    }
565}
566
// Pipeline integration.
impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for VotingRegressor<F> {
    /// Fit within a pipeline. Regression targets are already floats, so no
    /// label conversion is needed; the fitted model is boxed directly.
    ///
    /// # Errors
    ///
    /// Forwards any error from [`Fit::fit`].
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let fitted = self.fit(x, y)?;
        Ok(Box::new(fitted))
    }
}
578
impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F> for FittedVotingRegressor<F> {
    /// Delegate directly to [`Predict::predict`]; outputs are already floats.
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        self.predict(x)
    }
}
584
585// ---------------------------------------------------------------------------
586// Tests
587// ---------------------------------------------------------------------------
588
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Two linearly separable clusters of four points each (classes 0 and 1).
    fn make_classification_data() -> (Array2<f64>, Array1<usize>) {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
        (x, y)
    }

    /// Six points with roughly monotone targets for regression tests.
    fn make_regression_data() -> (Array2<f64>, Array1<f64>) {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
        )
        .unwrap();
        let y = array![1.0, 2.0, 3.0, 5.0, 6.0, 7.0];
        (x, y)
    }

    // -- VotingClassifier tests --

    #[test]
    fn test_voting_classifier_default() {
        // Defaults documented on `VotingClassifier::new`.
        let model = VotingClassifier::<f64>::new();
        assert_eq!(model.max_depths.len(), 4);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
    }

    #[test]
    fn test_voting_classifier_builder() {
        // Each builder method should overwrite exactly its own field.
        let model = VotingClassifier::<f64>::new()
            .with_max_depths(vec![Some(1), Some(3)])
            .with_min_samples_split(5)
            .with_min_samples_leaf(2)
            .with_criterion(ClassificationCriterion::Entropy);
        assert_eq!(model.max_depths.len(), 2);
        assert_eq!(model.min_samples_split, 5);
        assert_eq!(model.min_samples_leaf, 2);
        assert_eq!(model.criterion, ClassificationCriterion::Entropy);
    }

    #[test]
    fn test_voting_classifier_fit_predict() {
        let (x, y) = make_classification_data();
        let model = VotingClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
        // On training data with a clear separation, should get most right.
        for i in 0..4 {
            assert_eq!(preds[i], 0, "sample {i} should be class 0");
        }
        for i in 4..8 {
            assert_eq!(preds[i], 1, "sample {i} should be class 1");
        }
    }

    #[test]
    fn test_voting_classifier_has_classes() {
        let (x, y) = make_classification_data();
        let model = VotingClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0, 1]);
        assert_eq!(fitted.n_classes(), 2);
    }

    #[test]
    fn test_voting_classifier_n_estimators() {
        // One fitted tree per entry in `max_depths`.
        let (x, y) = make_classification_data();
        let model = VotingClassifier::<f64>::new().with_max_depths(vec![Some(2), Some(4), None]);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 3);
    }

    #[test]
    fn test_voting_classifier_empty_data_error() {
        // Zero samples must be rejected (InsufficientSamples).
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);
        let model = VotingClassifier::<f64>::new();
        let result = model.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_voting_classifier_shape_mismatch_error() {
        // x has 5 rows but y has 3 entries (ShapeMismatch).
        let x = Array2::<f64>::zeros((5, 2));
        let y = Array1::<usize>::zeros(3);
        let model = VotingClassifier::<f64>::new();
        let result = model.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_voting_classifier_empty_depths_error() {
        // An empty ensemble configuration is invalid (InvalidParameter).
        let (x, y) = make_classification_data();
        let model = VotingClassifier::<f64>::new().with_max_depths(vec![]);
        let result = model.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_voting_classifier_multiclass() {
        // Three well-separated clusters exercise the multiclass path.
        let x = Array2::from_shape_vec(
            (9, 2),
            vec![
                1.0, 1.0, 2.0, 1.0, 1.0, 2.0, 4.0, 4.0, 5.0, 4.0, 4.0, 5.0, 8.0, 8.0, 9.0, 8.0,
                8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];

        let model = VotingClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 9);
        assert_eq!(fitted.n_classes(), 3);
    }

    // -- VotingRegressor tests --

    #[test]
    fn test_voting_regressor_default() {
        // Defaults documented on `VotingRegressor::new`.
        let model = VotingRegressor::<f64>::new();
        assert_eq!(model.max_depths.len(), 4);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
    }

    #[test]
    fn test_voting_regressor_builder() {
        // Each builder method should overwrite exactly its own field.
        let model = VotingRegressor::<f64>::new()
            .with_max_depths(vec![Some(1), Some(5)])
            .with_min_samples_split(3)
            .with_min_samples_leaf(2);
        assert_eq!(model.max_depths.len(), 2);
        assert_eq!(model.min_samples_split, 3);
        assert_eq!(model.min_samples_leaf, 2);
    }

    #[test]
    fn test_voting_regressor_fit_predict() {
        let (x, y) = make_regression_data();
        let model = VotingRegressor::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 6);
        // The average of multiple trees should approximate the training targets
        // on the training data.
        for i in 0..6 {
            let err = (preds[i] - y[i]).abs();
            assert!(
                err < 3.0,
                "prediction {:.2} should be close to target {:.2}",
                preds[i],
                y[i]
            );
        }
    }

    #[test]
    fn test_voting_regressor_n_estimators() {
        // One fitted tree per entry in `max_depths`.
        let (x, y) = make_regression_data();
        let model = VotingRegressor::<f64>::new().with_max_depths(vec![Some(2), None]);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 2);
    }

    #[test]
    fn test_voting_regressor_empty_data_error() {
        // Zero samples must be rejected (InsufficientSamples).
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);
        let model = VotingRegressor::<f64>::new();
        let result = model.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_voting_regressor_shape_mismatch_error() {
        // x has 5 rows but y has 3 entries (ShapeMismatch).
        let x = Array2::<f64>::zeros((5, 2));
        let y = Array1::<f64>::zeros(3);
        let model = VotingRegressor::<f64>::new();
        let result = model.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_voting_regressor_empty_depths_error() {
        // An empty ensemble configuration is invalid (InvalidParameter).
        let (x, y) = make_regression_data();
        let model = VotingRegressor::<f64>::new().with_max_depths(vec![]);
        let result = model.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_voting_regressor_averaging() {
        // With a single tree (unlimited depth), predictions on training data
        // should exactly match the targets.
        let (x, y) = make_regression_data();
        let model = VotingRegressor::<f64>::new().with_max_depths(vec![None]);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for i in 0..6 {
            assert!(
                (preds[i] - y[i]).abs() < 1e-10,
                "single unlimited tree should overfit training data"
            );
        }
    }

    #[test]
    fn test_voting_classifier_f32() {
        // Smoke test: the classifier works with the f32 float type.
        let x = Array2::<f32>::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1];
        let model = VotingClassifier::<f32>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_voting_regressor_f32() {
        // Smoke test: the regressor works with the f32 float type.
        let x = Array2::<f32>::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
        )
        .unwrap();
        let y = array![1.0_f32, 2.0, 3.0, 5.0, 6.0, 7.0];
        let model = VotingRegressor::<f32>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }
}