// ferrolearn_tree/voting.rs

1//! Voting ensemble classifiers and regressors.
2//!
3//! This module provides [`VotingClassifier`] and [`VotingRegressor`], which
4//! train multiple decision trees with different hyperparameter configurations
5//! on the full dataset and aggregate their predictions (majority vote for
6//! classification, averaging for regression).
7//!
8//! Unlike [`RandomForestClassifier`](crate::RandomForestClassifier), voting
9//! ensembles do **not** use bootstrap sampling — each tree sees the entire
10//! dataset. Diversity comes from varying the tree hyperparameters.
11//!
12//! # Examples
13//!
14//! ```
15//! use ferrolearn_tree::VotingClassifier;
16//! use ferrolearn_core::{Fit, Predict};
17//! use ndarray::{array, Array1, Array2};
18//!
19//! let x = Array2::from_shape_vec((8, 2), vec![
20//!     1.0, 2.0,  2.0, 3.0,  3.0, 3.0,  4.0, 4.0,
21//!     5.0, 6.0,  6.0, 7.0,  7.0, 8.0,  8.0, 9.0,
22//! ]).unwrap();
23//! let y = array![0, 0, 0, 0, 1, 1, 1, 1];
24//!
25//! let model = VotingClassifier::<f64>::new()
26//!     .with_max_depths(vec![Some(2), Some(3), Some(5), None]);
27//! let fitted = model.fit(&x, &y).unwrap();
28//! let preds = fitted.predict(&x).unwrap();
29//! ```
30
31use ferrolearn_core::error::FerroError;
32use ferrolearn_core::introspection::HasClasses;
33use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
34use ferrolearn_core::traits::{Fit, Predict};
35use ndarray::{Array1, Array2};
36use num_traits::{Float, FromPrimitive, ToPrimitive};
37use serde::{Deserialize, Serialize};
38
39use crate::decision_tree::{
40    ClassificationCriterion, DecisionTreeClassifier, DecisionTreeRegressor,
41    FittedDecisionTreeClassifier, FittedDecisionTreeRegressor,
42};
43
44// ---------------------------------------------------------------------------
45// VotingClassifier
46// ---------------------------------------------------------------------------
47
/// Voting ensemble classifier.
///
/// Trains multiple decision tree classifiers with different hyperparameter
/// configurations on the full dataset. Final predictions are made by majority
/// vote across all trees.
///
/// Diversity is introduced by varying `max_depth` across the ensemble members.
/// If no explicit depths are provided, a default set of depths is used.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VotingClassifier<F> {
    /// Maximum depth settings for each tree in the ensemble.
    /// Each entry produces one decision tree.
    pub max_depths: Vec<Option<usize>>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required in a leaf node.
    pub min_samples_leaf: usize,
    /// Splitting criterion for all trees.
    pub criterion: ClassificationCriterion,
    // Ties the otherwise-unused type parameter `F` to the struct without
    // storing a value of type `F`.
    _marker: std::marker::PhantomData<F>,
}
73
74impl<F: Float> VotingClassifier<F> {
75    /// Create a new `VotingClassifier` with default settings.
76    ///
77    /// Defaults: `max_depths = [Some(2), Some(4), Some(6), None]`,
78    /// `min_samples_split = 2`, `min_samples_leaf = 1`, `criterion = Gini`.
79    #[must_use]
80    pub fn new() -> Self {
81        Self {
82            max_depths: vec![Some(2), Some(4), Some(6), None],
83            min_samples_split: 2,
84            min_samples_leaf: 1,
85            criterion: ClassificationCriterion::Gini,
86            _marker: std::marker::PhantomData,
87        }
88    }
89
90    /// Set the maximum depth settings for each ensemble member.
91    ///
92    /// Each entry in the vector produces one decision tree.
93    #[must_use]
94    pub fn with_max_depths(mut self, max_depths: Vec<Option<usize>>) -> Self {
95        self.max_depths = max_depths;
96        self
97    }
98
99    /// Set the minimum number of samples required to split a node.
100    #[must_use]
101    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
102        self.min_samples_split = min_samples_split;
103        self
104    }
105
106    /// Set the minimum number of samples required in a leaf node.
107    #[must_use]
108    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
109        self.min_samples_leaf = min_samples_leaf;
110        self
111    }
112
113    /// Set the splitting criterion for all trees.
114    #[must_use]
115    pub fn with_criterion(mut self, criterion: ClassificationCriterion) -> Self {
116        self.criterion = criterion;
117        self
118    }
119}
120
121impl<F: Float> Default for VotingClassifier<F> {
122    fn default() -> Self {
123        Self::new()
124    }
125}
126
127// ---------------------------------------------------------------------------
128// FittedVotingClassifier
129// ---------------------------------------------------------------------------
130
/// A fitted voting ensemble classifier.
///
/// Stores the individually fitted decision trees and aggregates their
/// predictions by majority vote. Instances are produced by
/// [`VotingClassifier`]'s `fit`, which guarantees at least one tree and at
/// least one class.
#[derive(Debug, Clone)]
pub struct FittedVotingClassifier<F> {
    /// The fitted decision tree classifiers, one per configured depth.
    trees: Vec<FittedDecisionTreeClassifier<F>>,
    /// Sorted unique class labels (sorted and deduplicated during `fit`).
    classes: Vec<usize>,
}
142
143impl<F: Float + Send + Sync + 'static> FittedVotingClassifier<F> {
144    /// Returns the number of trees in the ensemble.
145    #[must_use]
146    pub fn n_estimators(&self) -> usize {
147        self.trees.len()
148    }
149}
150
151impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for VotingClassifier<F> {
152    type Fitted = FittedVotingClassifier<F>;
153    type Error = FerroError;
154
155    /// Fit the voting classifier by training each decision tree on the full dataset.
156    ///
157    /// # Errors
158    ///
159    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
160    /// numbers of samples.
161    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
162    /// Returns [`FerroError::InvalidParameter`] if configuration is invalid.
163    fn fit(
164        &self,
165        x: &Array2<F>,
166        y: &Array1<usize>,
167    ) -> Result<FittedVotingClassifier<F>, FerroError> {
168        let n_samples = x.nrows();
169
170        if n_samples != y.len() {
171            return Err(FerroError::ShapeMismatch {
172                expected: vec![n_samples],
173                actual: vec![y.len()],
174                context: "y length must match number of samples in X".into(),
175            });
176        }
177        if n_samples == 0 {
178            return Err(FerroError::InsufficientSamples {
179                required: 1,
180                actual: 0,
181                context: "VotingClassifier requires at least one sample".into(),
182            });
183        }
184        if self.max_depths.is_empty() {
185            return Err(FerroError::InvalidParameter {
186                name: "max_depths".into(),
187                reason: "must contain at least one entry".into(),
188            });
189        }
190
191        // Determine unique classes from the full dataset.
192        let mut classes: Vec<usize> = y.iter().copied().collect();
193        classes.sort_unstable();
194        classes.dedup();
195
196        let mut trees = Vec::with_capacity(self.max_depths.len());
197        for &max_depth in &self.max_depths {
198            let tree = DecisionTreeClassifier::<F>::new()
199                .with_max_depth(max_depth)
200                .with_min_samples_split(self.min_samples_split)
201                .with_min_samples_leaf(self.min_samples_leaf)
202                .with_criterion(self.criterion);
203            let fitted = tree.fit(x, y)?;
204            trees.push(fitted);
205        }
206
207        Ok(FittedVotingClassifier { trees, classes })
208    }
209}
210
211impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedVotingClassifier<F> {
212    type Output = Array1<usize>;
213    type Error = FerroError;
214
215    /// Predict class labels by majority vote across all trees.
216    ///
217    /// # Errors
218    ///
219    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
220    /// not match the fitted model.
221    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
222        let n_samples = x.nrows();
223        let n_classes = self.classes.len();
224
225        // Collect predictions from all trees.
226        let all_preds: Vec<Array1<usize>> = self
227            .trees
228            .iter()
229            .map(|tree| tree.predict(x))
230            .collect::<Result<Vec<_>, _>>()?;
231
232        let mut predictions = Array1::zeros(n_samples);
233        for i in 0..n_samples {
234            let mut votes = vec![0usize; n_classes];
235            for tree_preds in &all_preds {
236                let pred = tree_preds[i];
237                if let Some(class_idx) = self.classes.iter().position(|&c| c == pred) {
238                    votes[class_idx] += 1;
239                }
240            }
241            let winner = votes
242                .iter()
243                .enumerate()
244                .max_by_key(|&(_, &count)| count)
245                .map(|(idx, _)| idx)
246                .unwrap_or(0);
247            predictions[i] = self.classes[winner];
248        }
249
250        Ok(predictions)
251    }
252}
253
254impl<F: Float + Send + Sync + 'static> HasClasses for FittedVotingClassifier<F> {
255    fn classes(&self) -> &[usize] {
256        &self.classes
257    }
258
259    fn n_classes(&self) -> usize {
260        self.classes.len()
261    }
262}
263
264// Pipeline integration.
265impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
266    for VotingClassifier<F>
267{
268    fn fit_pipeline(
269        &self,
270        x: &Array2<F>,
271        y: &Array1<F>,
272    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
273        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
274        let fitted = self.fit(x, &y_usize)?;
275        Ok(Box::new(FittedVotingClassifierPipelineAdapter(fitted)))
276    }
277}
278
279/// Pipeline adapter for `FittedVotingClassifier<F>`.
280struct FittedVotingClassifierPipelineAdapter<F: Float + Send + Sync + 'static>(
281    FittedVotingClassifier<F>,
282);
283
impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedVotingClassifierPipelineAdapter<F>
{
    /// Predict through the pipeline, converting `usize` class labels back to `F`.
    ///
    /// Labels that cannot be represented in `F` become NaN — presumably this
    /// only occurs for extremely large label values; confirm against the
    /// `FromPrimitive` impls for the supported float types.
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        let preds = self.0.predict(x)?;
        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or(F::nan())))
    }
}
292
293// ---------------------------------------------------------------------------
294// VotingRegressor
295// ---------------------------------------------------------------------------
296
/// Voting ensemble regressor.
///
/// Trains multiple decision tree regressors with different hyperparameter
/// configurations on the full dataset. Final predictions are the average
/// across all trees.
///
/// Diversity is introduced by varying `max_depth` across the ensemble members.
/// If no explicit depths are provided, a default set of depths is used.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
///
/// # Examples
///
/// ```
/// use ferrolearn_tree::VotingRegressor;
/// use ferrolearn_core::{Fit, Predict};
/// use ndarray::{array, Array1, Array2};
///
/// let x = Array2::from_shape_vec((6, 2), vec![
///     1.0, 2.0,  2.0, 3.0,  3.0, 3.0,
///     5.0, 6.0,  6.0, 7.0,  7.0, 8.0,
/// ]).unwrap();
/// let y = array![1.0, 2.0, 3.0, 5.0, 6.0, 7.0];
///
/// let model = VotingRegressor::<f64>::new()
///     .with_max_depths(vec![Some(2), Some(4), None]);
/// let fitted = model.fit(&x, &y).unwrap();
/// let preds = fitted.predict(&x).unwrap();
/// assert_eq!(preds.len(), 6);
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VotingRegressor<F> {
    /// Maximum depth settings for each tree in the ensemble.
    /// Each entry produces one decision tree.
    pub max_depths: Vec<Option<usize>>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required in a leaf node.
    pub min_samples_leaf: usize,
    // Ties the otherwise-unused type parameter `F` to the struct without
    // storing a value of type `F`.
    _marker: std::marker::PhantomData<F>,
}
339
340impl<F: Float> VotingRegressor<F> {
341    /// Create a new `VotingRegressor` with default settings.
342    ///
343    /// Defaults: `max_depths = [Some(2), Some(4), Some(6), None]`,
344    /// `min_samples_split = 2`, `min_samples_leaf = 1`.
345    #[must_use]
346    pub fn new() -> Self {
347        Self {
348            max_depths: vec![Some(2), Some(4), Some(6), None],
349            min_samples_split: 2,
350            min_samples_leaf: 1,
351            _marker: std::marker::PhantomData,
352        }
353    }
354
355    /// Set the maximum depth settings for each ensemble member.
356    #[must_use]
357    pub fn with_max_depths(mut self, max_depths: Vec<Option<usize>>) -> Self {
358        self.max_depths = max_depths;
359        self
360    }
361
362    /// Set the minimum number of samples required to split a node.
363    #[must_use]
364    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
365        self.min_samples_split = min_samples_split;
366        self
367    }
368
369    /// Set the minimum number of samples required in a leaf node.
370    #[must_use]
371    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
372        self.min_samples_leaf = min_samples_leaf;
373        self
374    }
375}
376
377impl<F: Float> Default for VotingRegressor<F> {
378    fn default() -> Self {
379        Self::new()
380    }
381}
382
383// ---------------------------------------------------------------------------
384// FittedVotingRegressor
385// ---------------------------------------------------------------------------
386
/// A fitted voting ensemble regressor.
///
/// Stores the individually fitted decision tree regressors and aggregates
/// their predictions by averaging. Instances are produced by
/// [`VotingRegressor`]'s `fit`, which guarantees at least one tree.
#[derive(Debug, Clone)]
pub struct FittedVotingRegressor<F> {
    /// The fitted decision tree regressors, one per configured depth.
    trees: Vec<FittedDecisionTreeRegressor<F>>,
}
396
397impl<F: Float + Send + Sync + 'static> FittedVotingRegressor<F> {
398    /// Returns the number of trees in the ensemble.
399    #[must_use]
400    pub fn n_estimators(&self) -> usize {
401        self.trees.len()
402    }
403}
404
405impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for VotingRegressor<F> {
406    type Fitted = FittedVotingRegressor<F>;
407    type Error = FerroError;
408
409    /// Fit the voting regressor by training each decision tree on the full dataset.
410    ///
411    /// # Errors
412    ///
413    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
414    /// numbers of samples.
415    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
416    /// Returns [`FerroError::InvalidParameter`] if configuration is invalid.
417    fn fit(
418        &self,
419        x: &Array2<F>,
420        y: &Array1<F>,
421    ) -> Result<FittedVotingRegressor<F>, FerroError> {
422        let n_samples = x.nrows();
423
424        if n_samples != y.len() {
425            return Err(FerroError::ShapeMismatch {
426                expected: vec![n_samples],
427                actual: vec![y.len()],
428                context: "y length must match number of samples in X".into(),
429            });
430        }
431        if n_samples == 0 {
432            return Err(FerroError::InsufficientSamples {
433                required: 1,
434                actual: 0,
435                context: "VotingRegressor requires at least one sample".into(),
436            });
437        }
438        if self.max_depths.is_empty() {
439            return Err(FerroError::InvalidParameter {
440                name: "max_depths".into(),
441                reason: "must contain at least one entry".into(),
442            });
443        }
444
445        let mut trees = Vec::with_capacity(self.max_depths.len());
446        for &max_depth in &self.max_depths {
447            let tree = DecisionTreeRegressor::<F>::new()
448                .with_max_depth(max_depth)
449                .with_min_samples_split(self.min_samples_split)
450                .with_min_samples_leaf(self.min_samples_leaf);
451            let fitted = tree.fit(x, y)?;
452            trees.push(fitted);
453        }
454
455        Ok(FittedVotingRegressor { trees })
456    }
457}
458
459impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedVotingRegressor<F> {
460    type Output = Array1<F>;
461    type Error = FerroError;
462
463    /// Predict target values by averaging across all trees.
464    ///
465    /// # Errors
466    ///
467    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
468    /// not match the fitted model.
469    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
470        let n_samples = x.nrows();
471        let n_trees_f = F::from(self.trees.len()).unwrap();
472
473        let all_preds: Vec<Array1<F>> = self
474            .trees
475            .iter()
476            .map(|tree| tree.predict(x))
477            .collect::<Result<Vec<_>, _>>()?;
478
479        let mut predictions = Array1::zeros(n_samples);
480        for i in 0..n_samples {
481            let mut sum = F::zero();
482            for tree_preds in &all_preds {
483                sum = sum + tree_preds[i];
484            }
485            predictions[i] = sum / n_trees_f;
486        }
487
488        Ok(predictions)
489    }
490}
491
492// Pipeline integration.
493impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for VotingRegressor<F> {
494    fn fit_pipeline(
495        &self,
496        x: &Array2<F>,
497        y: &Array1<F>,
498    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
499        let fitted = self.fit(x, y)?;
500        Ok(Box::new(fitted))
501    }
502}
503
impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F> for FittedVotingRegressor<F> {
    /// Predict through the pipeline; the output element type already matches
    /// `F`, so this simply delegates to [`Predict::predict`].
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        self.predict(x)
    }
}
509
510// ---------------------------------------------------------------------------
511// Tests
512// ---------------------------------------------------------------------------
513
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Linearly separable two-class fixture: 8 samples, 2 features,
    /// first four samples are class 0, last four are class 1.
    fn make_classification_data() -> (Array2<f64>, Array1<usize>) {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
        (x, y)
    }

    /// Monotone regression fixture: 6 samples, 2 features.
    fn make_regression_data() -> (Array2<f64>, Array1<f64>) {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
        )
        .unwrap();
        let y = array![1.0, 2.0, 3.0, 5.0, 6.0, 7.0];
        (x, y)
    }

    // -- VotingClassifier tests --

    #[test]
    fn test_voting_classifier_default() {
        let model = VotingClassifier::<f64>::new();
        assert_eq!(model.max_depths.len(), 4);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
    }

    #[test]
    fn test_voting_classifier_builder() {
        let model = VotingClassifier::<f64>::new()
            .with_max_depths(vec![Some(1), Some(3)])
            .with_min_samples_split(5)
            .with_min_samples_leaf(2)
            .with_criterion(ClassificationCriterion::Entropy);
        assert_eq!(model.max_depths.len(), 2);
        assert_eq!(model.min_samples_split, 5);
        assert_eq!(model.min_samples_leaf, 2);
        assert_eq!(model.criterion, ClassificationCriterion::Entropy);
    }

    #[test]
    fn test_voting_classifier_fit_predict() {
        let (x, y) = make_classification_data();
        let model = VotingClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
        // On training data with a clear separation, should get most right.
        for i in 0..4 {
            assert_eq!(preds[i], 0, "sample {i} should be class 0");
        }
        for i in 4..8 {
            assert_eq!(preds[i], 1, "sample {i} should be class 1");
        }
    }

    #[test]
    fn test_voting_classifier_has_classes() {
        let (x, y) = make_classification_data();
        let model = VotingClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0, 1]);
        assert_eq!(fitted.n_classes(), 2);
    }

    #[test]
    fn test_voting_classifier_n_estimators() {
        let (x, y) = make_classification_data();
        let model = VotingClassifier::<f64>::new()
            .with_max_depths(vec![Some(2), Some(4), None]);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 3);
    }

    #[test]
    fn test_voting_classifier_empty_data_error() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);
        let model = VotingClassifier::<f64>::new();
        let result = model.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_voting_classifier_shape_mismatch_error() {
        let x = Array2::<f64>::zeros((5, 2));
        let y = Array1::<usize>::zeros(3);
        let model = VotingClassifier::<f64>::new();
        let result = model.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_voting_classifier_empty_depths_error() {
        let (x, y) = make_classification_data();
        let model = VotingClassifier::<f64>::new().with_max_depths(vec![]);
        let result = model.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_voting_classifier_multiclass() {
        // Three well-separated clusters of three samples each.
        let x = Array2::from_shape_vec(
            (9, 2),
            vec![
                1.0, 1.0, 2.0, 1.0, 1.0, 2.0, 4.0, 4.0, 5.0, 4.0, 4.0, 5.0, 8.0, 8.0, 9.0, 8.0,
                8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];

        let model = VotingClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 9);
        assert_eq!(fitted.n_classes(), 3);
    }

    // -- VotingRegressor tests --

    #[test]
    fn test_voting_regressor_default() {
        let model = VotingRegressor::<f64>::new();
        assert_eq!(model.max_depths.len(), 4);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
    }

    #[test]
    fn test_voting_regressor_builder() {
        let model = VotingRegressor::<f64>::new()
            .with_max_depths(vec![Some(1), Some(5)])
            .with_min_samples_split(3)
            .with_min_samples_leaf(2);
        assert_eq!(model.max_depths.len(), 2);
        assert_eq!(model.min_samples_split, 3);
        assert_eq!(model.min_samples_leaf, 2);
    }

    #[test]
    fn test_voting_regressor_fit_predict() {
        let (x, y) = make_regression_data();
        let model = VotingRegressor::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 6);
        // The average of multiple trees should approximate the training targets
        // on the training data.
        for i in 0..6 {
            let err = (preds[i] - y[i]).abs();
            assert!(
                err < 3.0,
                "prediction {:.2} should be close to target {:.2}",
                preds[i],
                y[i]
            );
        }
    }

    #[test]
    fn test_voting_regressor_n_estimators() {
        let (x, y) = make_regression_data();
        let model = VotingRegressor::<f64>::new()
            .with_max_depths(vec![Some(2), None]);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 2);
    }

    #[test]
    fn test_voting_regressor_empty_data_error() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);
        let model = VotingRegressor::<f64>::new();
        let result = model.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_voting_regressor_shape_mismatch_error() {
        let x = Array2::<f64>::zeros((5, 2));
        let y = Array1::<f64>::zeros(3);
        let model = VotingRegressor::<f64>::new();
        let result = model.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_voting_regressor_empty_depths_error() {
        let (x, y) = make_regression_data();
        let model = VotingRegressor::<f64>::new().with_max_depths(vec![]);
        let result = model.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_voting_regressor_averaging() {
        // With a single tree (unlimited depth), predictions on training data
        // should exactly match the targets.
        let (x, y) = make_regression_data();
        let model = VotingRegressor::<f64>::new().with_max_depths(vec![None]);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for i in 0..6 {
            assert!(
                (preds[i] - y[i]).abs() < 1e-10,
                "single unlimited tree should overfit training data"
            );
        }
    }

    #[test]
    fn test_voting_classifier_f32() {
        let x = Array2::<f32>::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1];
        let model = VotingClassifier::<f32>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_voting_regressor_f32() {
        let x = Array2::<f32>::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
        )
        .unwrap();
        let y = array![1.0_f32, 2.0, 3.0, 5.0, 6.0, 7.0];
        let model = VotingRegressor::<f32>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }
}