Skip to main content

ferrolearn_tree/
voting.rs

//! Voting ensemble classifiers and regressors.
//!
//! This module provides [`VotingClassifier`] and [`VotingRegressor`], which
//! train multiple decision trees with different hyperparameter configurations
//! on the full dataset and aggregate their predictions (majority vote for
//! classification, averaging for regression).
//!
//! Unlike [`RandomForestClassifier`](crate::RandomForestClassifier), voting
//! ensembles do **not** use bootstrap sampling — each tree sees the entire
//! dataset. Diversity comes from varying the tree hyperparameters.
//!
//! # Examples
//!
//! ```
//! use ferrolearn_tree::VotingClassifier;
//! use ferrolearn_core::{Fit, Predict};
//! use ndarray::{array, Array1, Array2};
//!
//! let x = Array2::from_shape_vec((8, 2), vec![
//!     1.0, 2.0,  2.0, 3.0,  3.0, 3.0,  4.0, 4.0,
//!     5.0, 6.0,  6.0, 7.0,  7.0, 8.0,  8.0, 9.0,
//! ]).unwrap();
//! let y = array![0, 0, 0, 0, 1, 1, 1, 1];
//!
//! let model = VotingClassifier::<f64>::new()
//!     .with_max_depths(vec![Some(2), Some(3), Some(5), None]);
//! let fitted = model.fit(&x, &y).unwrap();
//! let preds = fitted.predict(&x).unwrap();
//! ```
31use ferrolearn_core::error::FerroError;
32use ferrolearn_core::introspection::HasClasses;
33use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
34use ferrolearn_core::traits::{Fit, Predict};
35use ndarray::{Array1, Array2};
36use num_traits::{Float, FromPrimitive, ToPrimitive};
37use serde::{Deserialize, Serialize};
38
39use crate::decision_tree::{
40    ClassificationCriterion, DecisionTreeClassifier, DecisionTreeRegressor,
41    FittedDecisionTreeClassifier, FittedDecisionTreeRegressor,
42};
43
// ---------------------------------------------------------------------------
// VotingClassifier
// ---------------------------------------------------------------------------
48/// Voting ensemble classifier.
49///
50/// Trains multiple decision tree classifiers with different hyperparameter
51/// configurations on the full dataset. Final predictions are made by majority
52/// vote across all trees.
53///
54/// Diversity is introduced by varying `max_depth` across the ensemble members.
55/// If no explicit depths are provided, a default set of depths is used.
56///
57/// # Type Parameters
58///
59/// - `F`: The floating-point type (`f32` or `f64`).
60#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct VotingClassifier<F> {
62    /// Maximum depth settings for each tree in the ensemble.
63    /// Each entry produces one decision tree.
64    pub max_depths: Vec<Option<usize>>,
65    /// Minimum number of samples required to split an internal node.
66    pub min_samples_split: usize,
67    /// Minimum number of samples required in a leaf node.
68    pub min_samples_leaf: usize,
69    /// Splitting criterion for all trees.
70    pub criterion: ClassificationCriterion,
71    _marker: std::marker::PhantomData<F>,
72}
73
74impl<F: Float> VotingClassifier<F> {
75    /// Create a new `VotingClassifier` with default settings.
76    ///
77    /// Defaults: `max_depths = [Some(2), Some(4), Some(6), None]`,
78    /// `min_samples_split = 2`, `min_samples_leaf = 1`, `criterion = Gini`.
79    #[must_use]
80    pub fn new() -> Self {
81        Self {
82            max_depths: vec![Some(2), Some(4), Some(6), None],
83            min_samples_split: 2,
84            min_samples_leaf: 1,
85            criterion: ClassificationCriterion::Gini,
86            _marker: std::marker::PhantomData,
87        }
88    }
89
90    /// Set the maximum depth settings for each ensemble member.
91    ///
92    /// Each entry in the vector produces one decision tree.
93    #[must_use]
94    pub fn with_max_depths(mut self, max_depths: Vec<Option<usize>>) -> Self {
95        self.max_depths = max_depths;
96        self
97    }
98
99    /// Set the minimum number of samples required to split a node.
100    #[must_use]
101    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
102        self.min_samples_split = min_samples_split;
103        self
104    }
105
106    /// Set the minimum number of samples required in a leaf node.
107    #[must_use]
108    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
109        self.min_samples_leaf = min_samples_leaf;
110        self
111    }
112
113    /// Set the splitting criterion for all trees.
114    #[must_use]
115    pub fn with_criterion(mut self, criterion: ClassificationCriterion) -> Self {
116        self.criterion = criterion;
117        self
118    }
119}
120
121impl<F: Float> Default for VotingClassifier<F> {
122    fn default() -> Self {
123        Self::new()
124    }
125}
126
// ---------------------------------------------------------------------------
// FittedVotingClassifier
// ---------------------------------------------------------------------------
131/// A fitted voting ensemble classifier.
132///
133/// Stores the individually fitted decision trees and aggregates their
134/// predictions by majority vote.
135#[derive(Debug, Clone)]
136pub struct FittedVotingClassifier<F> {
137    /// The fitted decision tree classifiers.
138    trees: Vec<FittedDecisionTreeClassifier<F>>,
139    /// Sorted unique class labels.
140    classes: Vec<usize>,
141}
142
143impl<F: Float + Send + Sync + 'static> FittedVotingClassifier<F> {
144    /// Returns the number of trees in the ensemble.
145    #[must_use]
146    pub fn n_estimators(&self) -> usize {
147        self.trees.len()
148    }
149}
150
151impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for VotingClassifier<F> {
152    type Fitted = FittedVotingClassifier<F>;
153    type Error = FerroError;
154
155    /// Fit the voting classifier by training each decision tree on the full dataset.
156    ///
157    /// # Errors
158    ///
159    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
160    /// numbers of samples.
161    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
162    /// Returns [`FerroError::InvalidParameter`] if configuration is invalid.
163    fn fit(
164        &self,
165        x: &Array2<F>,
166        y: &Array1<usize>,
167    ) -> Result<FittedVotingClassifier<F>, FerroError> {
168        let n_samples = x.nrows();
169
170        if n_samples != y.len() {
171            return Err(FerroError::ShapeMismatch {
172                expected: vec![n_samples],
173                actual: vec![y.len()],
174                context: "y length must match number of samples in X".into(),
175            });
176        }
177        if n_samples == 0 {
178            return Err(FerroError::InsufficientSamples {
179                required: 1,
180                actual: 0,
181                context: "VotingClassifier requires at least one sample".into(),
182            });
183        }
184        if self.max_depths.is_empty() {
185            return Err(FerroError::InvalidParameter {
186                name: "max_depths".into(),
187                reason: "must contain at least one entry".into(),
188            });
189        }
190
191        // Determine unique classes from the full dataset.
192        let mut classes: Vec<usize> = y.iter().copied().collect();
193        classes.sort_unstable();
194        classes.dedup();
195
196        let mut trees = Vec::with_capacity(self.max_depths.len());
197        for &max_depth in &self.max_depths {
198            let tree = DecisionTreeClassifier::<F>::new()
199                .with_max_depth(max_depth)
200                .with_min_samples_split(self.min_samples_split)
201                .with_min_samples_leaf(self.min_samples_leaf)
202                .with_criterion(self.criterion);
203            let fitted = tree.fit(x, y)?;
204            trees.push(fitted);
205        }
206
207        Ok(FittedVotingClassifier { trees, classes })
208    }
209}
210
211impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedVotingClassifier<F> {
212    type Output = Array1<usize>;
213    type Error = FerroError;
214
215    /// Predict class labels by majority vote across all trees.
216    ///
217    /// # Errors
218    ///
219    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
220    /// not match the fitted model.
221    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
222        let n_samples = x.nrows();
223        let n_classes = self.classes.len();
224
225        // Collect predictions from all trees.
226        let all_preds: Vec<Array1<usize>> = self
227            .trees
228            .iter()
229            .map(|tree| tree.predict(x))
230            .collect::<Result<Vec<_>, _>>()?;
231
232        let mut predictions = Array1::zeros(n_samples);
233        for i in 0..n_samples {
234            let mut votes = vec![0usize; n_classes];
235            for tree_preds in &all_preds {
236                let pred = tree_preds[i];
237                if let Some(class_idx) = self.classes.iter().position(|&c| c == pred) {
238                    votes[class_idx] += 1;
239                }
240            }
241            let winner = votes
242                .iter()
243                .enumerate()
244                .max_by_key(|&(_, &count)| count)
245                .map_or(0, |(idx, _)| idx);
246            predictions[i] = self.classes[winner];
247        }
248
249        Ok(predictions)
250    }
251}
252
253impl<F: Float + Send + Sync + 'static> HasClasses for FittedVotingClassifier<F> {
254    fn classes(&self) -> &[usize] {
255        &self.classes
256    }
257
258    fn n_classes(&self) -> usize {
259        self.classes.len()
260    }
261}
262
263// Pipeline integration.
264impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
265    for VotingClassifier<F>
266{
267    fn fit_pipeline(
268        &self,
269        x: &Array2<F>,
270        y: &Array1<F>,
271    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
272        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
273        let fitted = self.fit(x, &y_usize)?;
274        Ok(Box::new(FittedVotingClassifierPipelineAdapter(fitted)))
275    }
276}
277
278/// Pipeline adapter for `FittedVotingClassifier<F>`.
279struct FittedVotingClassifierPipelineAdapter<F: Float + Send + Sync + 'static>(
280    FittedVotingClassifier<F>,
281);
282
283impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
284    for FittedVotingClassifierPipelineAdapter<F>
285{
286    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
287        let preds = self.0.predict(x)?;
288        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or_else(F::nan)))
289    }
290}
291
// ---------------------------------------------------------------------------
// VotingRegressor
// ---------------------------------------------------------------------------
296/// Voting ensemble regressor.
297///
298/// Trains multiple decision tree regressors with different hyperparameter
299/// configurations on the full dataset. Final predictions are the average
300/// across all trees.
301///
302/// Diversity is introduced by varying `max_depth` across the ensemble members.
303/// If no explicit depths are provided, a default set of depths is used.
304///
305/// # Type Parameters
306///
307/// - `F`: The floating-point type (`f32` or `f64`).
308///
309/// # Examples
310///
311/// ```
312/// use ferrolearn_tree::VotingRegressor;
313/// use ferrolearn_core::{Fit, Predict};
314/// use ndarray::{array, Array1, Array2};
315///
316/// let x = Array2::from_shape_vec((6, 2), vec![
317///     1.0, 2.0,  2.0, 3.0,  3.0, 3.0,
318///     5.0, 6.0,  6.0, 7.0,  7.0, 8.0,
319/// ]).unwrap();
320/// let y = array![1.0, 2.0, 3.0, 5.0, 6.0, 7.0];
321///
322/// let model = VotingRegressor::<f64>::new()
323///     .with_max_depths(vec![Some(2), Some(4), None]);
324/// let fitted = model.fit(&x, &y).unwrap();
325/// let preds = fitted.predict(&x).unwrap();
326/// assert_eq!(preds.len(), 6);
327/// ```
328#[derive(Debug, Clone, Serialize, Deserialize)]
329pub struct VotingRegressor<F> {
330    /// Maximum depth settings for each tree in the ensemble.
331    pub max_depths: Vec<Option<usize>>,
332    /// Minimum number of samples required to split an internal node.
333    pub min_samples_split: usize,
334    /// Minimum number of samples required in a leaf node.
335    pub min_samples_leaf: usize,
336    _marker: std::marker::PhantomData<F>,
337}
338
339impl<F: Float> VotingRegressor<F> {
340    /// Create a new `VotingRegressor` with default settings.
341    ///
342    /// Defaults: `max_depths = [Some(2), Some(4), Some(6), None]`,
343    /// `min_samples_split = 2`, `min_samples_leaf = 1`.
344    #[must_use]
345    pub fn new() -> Self {
346        Self {
347            max_depths: vec![Some(2), Some(4), Some(6), None],
348            min_samples_split: 2,
349            min_samples_leaf: 1,
350            _marker: std::marker::PhantomData,
351        }
352    }
353
354    /// Set the maximum depth settings for each ensemble member.
355    #[must_use]
356    pub fn with_max_depths(mut self, max_depths: Vec<Option<usize>>) -> Self {
357        self.max_depths = max_depths;
358        self
359    }
360
361    /// Set the minimum number of samples required to split a node.
362    #[must_use]
363    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
364        self.min_samples_split = min_samples_split;
365        self
366    }
367
368    /// Set the minimum number of samples required in a leaf node.
369    #[must_use]
370    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
371        self.min_samples_leaf = min_samples_leaf;
372        self
373    }
374}
375
376impl<F: Float> Default for VotingRegressor<F> {
377    fn default() -> Self {
378        Self::new()
379    }
380}
381
// ---------------------------------------------------------------------------
// FittedVotingRegressor
// ---------------------------------------------------------------------------
386/// A fitted voting ensemble regressor.
387///
388/// Stores the individually fitted decision tree regressors and aggregates
389/// their predictions by averaging.
390#[derive(Debug, Clone)]
391pub struct FittedVotingRegressor<F> {
392    /// The fitted decision tree regressors.
393    trees: Vec<FittedDecisionTreeRegressor<F>>,
394}
395
396impl<F: Float + Send + Sync + 'static> FittedVotingRegressor<F> {
397    /// Returns the number of trees in the ensemble.
398    #[must_use]
399    pub fn n_estimators(&self) -> usize {
400        self.trees.len()
401    }
402}
403
404impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for VotingRegressor<F> {
405    type Fitted = FittedVotingRegressor<F>;
406    type Error = FerroError;
407
408    /// Fit the voting regressor by training each decision tree on the full dataset.
409    ///
410    /// # Errors
411    ///
412    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
413    /// numbers of samples.
414    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
415    /// Returns [`FerroError::InvalidParameter`] if configuration is invalid.
416    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedVotingRegressor<F>, FerroError> {
417        let n_samples = x.nrows();
418
419        if n_samples != y.len() {
420            return Err(FerroError::ShapeMismatch {
421                expected: vec![n_samples],
422                actual: vec![y.len()],
423                context: "y length must match number of samples in X".into(),
424            });
425        }
426        if n_samples == 0 {
427            return Err(FerroError::InsufficientSamples {
428                required: 1,
429                actual: 0,
430                context: "VotingRegressor requires at least one sample".into(),
431            });
432        }
433        if self.max_depths.is_empty() {
434            return Err(FerroError::InvalidParameter {
435                name: "max_depths".into(),
436                reason: "must contain at least one entry".into(),
437            });
438        }
439
440        let mut trees = Vec::with_capacity(self.max_depths.len());
441        for &max_depth in &self.max_depths {
442            let tree = DecisionTreeRegressor::<F>::new()
443                .with_max_depth(max_depth)
444                .with_min_samples_split(self.min_samples_split)
445                .with_min_samples_leaf(self.min_samples_leaf);
446            let fitted = tree.fit(x, y)?;
447            trees.push(fitted);
448        }
449
450        Ok(FittedVotingRegressor { trees })
451    }
452}
453
454impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedVotingRegressor<F> {
455    type Output = Array1<F>;
456    type Error = FerroError;
457
458    /// Predict target values by averaging across all trees.
459    ///
460    /// # Errors
461    ///
462    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
463    /// not match the fitted model.
464    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
465        let n_samples = x.nrows();
466        let n_trees_f = F::from(self.trees.len()).unwrap();
467
468        let all_preds: Vec<Array1<F>> = self
469            .trees
470            .iter()
471            .map(|tree| tree.predict(x))
472            .collect::<Result<Vec<_>, _>>()?;
473
474        let mut predictions = Array1::zeros(n_samples);
475        for i in 0..n_samples {
476            let mut sum = F::zero();
477            for tree_preds in &all_preds {
478                sum = sum + tree_preds[i];
479            }
480            predictions[i] = sum / n_trees_f;
481        }
482
483        Ok(predictions)
484    }
485}
486
487// Pipeline integration.
488impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for VotingRegressor<F> {
489    fn fit_pipeline(
490        &self,
491        x: &Array2<F>,
492        y: &Array1<F>,
493    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
494        let fitted = self.fit(x, y)?;
495        Ok(Box::new(fitted))
496    }
497}
498
499impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F> for FittedVotingRegressor<F> {
500    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
501        self.predict(x)
502    }
503}
504
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Eight 2-D points in two well-separated clusters of four.
    fn make_classification_data() -> (Array2<f64>, Array1<usize>) {
        let features = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        (features, array![0, 0, 0, 0, 1, 1, 1, 1])
    }

    /// Six 2-D points with roughly linear targets.
    fn make_regression_data() -> (Array2<f64>, Array1<f64>) {
        let features = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
        )
        .unwrap();
        (features, array![1.0, 2.0, 3.0, 5.0, 6.0, 7.0])
    }

    // -- VotingClassifier tests --

    #[test]
    fn test_voting_classifier_default() {
        let clf = VotingClassifier::<f64>::new();
        assert_eq!(clf.max_depths.len(), 4);
        assert_eq!(clf.min_samples_split, 2);
        assert_eq!(clf.min_samples_leaf, 1);
    }

    #[test]
    fn test_voting_classifier_builder() {
        let clf = VotingClassifier::<f64>::new()
            .with_max_depths(vec![Some(1), Some(3)])
            .with_min_samples_split(5)
            .with_min_samples_leaf(2)
            .with_criterion(ClassificationCriterion::Entropy);
        assert_eq!(clf.max_depths.len(), 2);
        assert_eq!(clf.min_samples_split, 5);
        assert_eq!(clf.min_samples_leaf, 2);
        assert_eq!(clf.criterion, ClassificationCriterion::Entropy);
    }

    #[test]
    fn test_voting_classifier_fit_predict() {
        let (features, labels) = make_classification_data();
        let trained = VotingClassifier::<f64>::new().fit(&features, &labels).unwrap();
        let predicted = trained.predict(&features).unwrap();

        assert_eq!(predicted.len(), 8);
        // On training data with a clear separation, should get most right.
        for i in 0..4 {
            assert_eq!(predicted[i], 0, "sample {i} should be class 0");
        }
        for i in 4..8 {
            assert_eq!(predicted[i], 1, "sample {i} should be class 1");
        }
    }

    #[test]
    fn test_voting_classifier_has_classes() {
        let (features, labels) = make_classification_data();
        let trained = VotingClassifier::<f64>::new().fit(&features, &labels).unwrap();
        assert_eq!(trained.classes(), &[0, 1]);
        assert_eq!(trained.n_classes(), 2);
    }

    #[test]
    fn test_voting_classifier_n_estimators() {
        let (features, labels) = make_classification_data();
        let clf = VotingClassifier::<f64>::new().with_max_depths(vec![Some(2), Some(4), None]);
        assert_eq!(clf.fit(&features, &labels).unwrap().n_estimators(), 3);
    }

    #[test]
    fn test_voting_classifier_empty_data_error() {
        let features = Array2::<f64>::zeros((0, 2));
        let labels = Array1::<usize>::zeros(0);
        assert!(VotingClassifier::<f64>::new().fit(&features, &labels).is_err());
    }

    #[test]
    fn test_voting_classifier_shape_mismatch_error() {
        let features = Array2::<f64>::zeros((5, 2));
        let labels = Array1::<usize>::zeros(3);
        assert!(VotingClassifier::<f64>::new().fit(&features, &labels).is_err());
    }

    #[test]
    fn test_voting_classifier_empty_depths_error() {
        let (features, labels) = make_classification_data();
        let clf = VotingClassifier::<f64>::new().with_max_depths(vec![]);
        assert!(clf.fit(&features, &labels).is_err());
    }

    #[test]
    fn test_voting_classifier_multiclass() {
        // Three clusters of three points each -> three classes.
        let features = Array2::from_shape_vec(
            (9, 2),
            vec![
                1.0, 1.0, 2.0, 1.0, 1.0, 2.0, 4.0, 4.0, 5.0, 4.0, 4.0, 5.0, 8.0, 8.0, 9.0, 8.0,
                8.0, 9.0,
            ],
        )
        .unwrap();
        let labels = array![0, 0, 0, 1, 1, 1, 2, 2, 2];

        let trained = VotingClassifier::<f64>::new().fit(&features, &labels).unwrap();
        let predicted = trained.predict(&features).unwrap();
        assert_eq!(predicted.len(), 9);
        assert_eq!(trained.n_classes(), 3);
    }

    // -- VotingRegressor tests --

    #[test]
    fn test_voting_regressor_default() {
        let reg = VotingRegressor::<f64>::new();
        assert_eq!(reg.max_depths.len(), 4);
        assert_eq!(reg.min_samples_split, 2);
        assert_eq!(reg.min_samples_leaf, 1);
    }

    #[test]
    fn test_voting_regressor_builder() {
        let reg = VotingRegressor::<f64>::new()
            .with_max_depths(vec![Some(1), Some(5)])
            .with_min_samples_split(3)
            .with_min_samples_leaf(2);
        assert_eq!(reg.max_depths.len(), 2);
        assert_eq!(reg.min_samples_split, 3);
        assert_eq!(reg.min_samples_leaf, 2);
    }

    #[test]
    fn test_voting_regressor_fit_predict() {
        let (features, targets) = make_regression_data();
        let trained = VotingRegressor::<f64>::new().fit(&features, &targets).unwrap();
        let predicted = trained.predict(&features).unwrap();

        assert_eq!(predicted.len(), 6);
        // The average of multiple trees should approximate the training targets
        // on the training data.
        for i in 0..6 {
            let err = (predicted[i] - targets[i]).abs();
            assert!(
                err < 3.0,
                "prediction {:.2} should be close to target {:.2}",
                predicted[i],
                targets[i]
            );
        }
    }

    #[test]
    fn test_voting_regressor_n_estimators() {
        let (features, targets) = make_regression_data();
        let reg = VotingRegressor::<f64>::new().with_max_depths(vec![Some(2), None]);
        assert_eq!(reg.fit(&features, &targets).unwrap().n_estimators(), 2);
    }

    #[test]
    fn test_voting_regressor_empty_data_error() {
        let features = Array2::<f64>::zeros((0, 2));
        let targets = Array1::<f64>::zeros(0);
        assert!(VotingRegressor::<f64>::new().fit(&features, &targets).is_err());
    }

    #[test]
    fn test_voting_regressor_shape_mismatch_error() {
        let features = Array2::<f64>::zeros((5, 2));
        let targets = Array1::<f64>::zeros(3);
        assert!(VotingRegressor::<f64>::new().fit(&features, &targets).is_err());
    }

    #[test]
    fn test_voting_regressor_empty_depths_error() {
        let (features, targets) = make_regression_data();
        let reg = VotingRegressor::<f64>::new().with_max_depths(vec![]);
        assert!(reg.fit(&features, &targets).is_err());
    }

    #[test]
    fn test_voting_regressor_averaging() {
        // With a single tree (unlimited depth), predictions on training data
        // should exactly match the targets.
        let (features, targets) = make_regression_data();
        let reg = VotingRegressor::<f64>::new().with_max_depths(vec![None]);
        let predicted = reg.fit(&features, &targets).unwrap().predict(&features).unwrap();

        for i in 0..6 {
            assert!(
                (predicted[i] - targets[i]).abs() < 1e-10,
                "single unlimited tree should overfit training data"
            );
        }
    }

    #[test]
    fn test_voting_classifier_f32() {
        let features = Array2::<f32>::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
        )
        .unwrap();
        let labels = array![0, 0, 0, 1, 1, 1];
        let trained = VotingClassifier::<f32>::new().fit(&features, &labels).unwrap();
        assert_eq!(trained.predict(&features).unwrap().len(), 6);
    }

    #[test]
    fn test_voting_regressor_f32() {
        let features = Array2::<f32>::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
        )
        .unwrap();
        let targets = array![1.0_f32, 2.0, 3.0, 5.0, 6.0, 7.0];
        let trained = VotingRegressor::<f32>::new().fit(&features, &targets).unwrap();
        assert_eq!(trained.predict(&features).unwrap().len(), 6);
    }
}