1use ferrolearn_core::error::FerroError;
32use ferrolearn_core::introspection::HasClasses;
33use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
34use ferrolearn_core::traits::{Fit, Predict};
35use ndarray::{Array1, Array2};
36use num_traits::{Float, FromPrimitive, ToPrimitive};
37use serde::{Deserialize, Serialize};
38
39use crate::decision_tree::{
40 ClassificationCriterion, DecisionTreeClassifier, DecisionTreeRegressor,
41 FittedDecisionTreeClassifier, FittedDecisionTreeRegressor,
42};
43
/// Ensemble classifier that fits one decision tree per entry in
/// `max_depths` and predicts by majority (hard) vote across the trees.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VotingClassifier<F> {
    /// One tree is trained per entry; `None` means unlimited depth.
    pub max_depths: Vec<Option<usize>>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required at a leaf node.
    pub min_samples_leaf: usize,
    /// Split-quality criterion shared by every tree in the ensemble.
    pub criterion: ClassificationCriterion,
    /// Ties the float type `F` to the estimator without storing any data.
    _marker: std::marker::PhantomData<F>,
}
73
74impl<F: Float> VotingClassifier<F> {
75 #[must_use]
80 pub fn new() -> Self {
81 Self {
82 max_depths: vec![Some(2), Some(4), Some(6), None],
83 min_samples_split: 2,
84 min_samples_leaf: 1,
85 criterion: ClassificationCriterion::Gini,
86 _marker: std::marker::PhantomData,
87 }
88 }
89
90 #[must_use]
94 pub fn with_max_depths(mut self, max_depths: Vec<Option<usize>>) -> Self {
95 self.max_depths = max_depths;
96 self
97 }
98
99 #[must_use]
101 pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
102 self.min_samples_split = min_samples_split;
103 self
104 }
105
106 #[must_use]
108 pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
109 self.min_samples_leaf = min_samples_leaf;
110 self
111 }
112
113 #[must_use]
115 pub fn with_criterion(mut self, criterion: ClassificationCriterion) -> Self {
116 self.criterion = criterion;
117 self
118 }
119}
120
121impl<F: Float> Default for VotingClassifier<F> {
122 fn default() -> Self {
123 Self::new()
124 }
125}
126
/// A fitted voting classifier holding one trained tree per configured depth.
#[derive(Debug, Clone)]
pub struct FittedVotingClassifier<F> {
    // Trained ensemble members, in the order of `max_depths` at fit time.
    trees: Vec<FittedDecisionTreeClassifier<F>>,
    // Sorted, deduplicated class labels observed in the training targets.
    classes: Vec<usize>,
}

impl<F: Float + Send + Sync + 'static> FittedVotingClassifier<F> {
    /// Returns the number of trees in the fitted ensemble.
    #[must_use]
    pub fn n_estimators(&self) -> usize {
        self.trees.len()
    }
}
150
151impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for VotingClassifier<F> {
152 type Fitted = FittedVotingClassifier<F>;
153 type Error = FerroError;
154
155 fn fit(
164 &self,
165 x: &Array2<F>,
166 y: &Array1<usize>,
167 ) -> Result<FittedVotingClassifier<F>, FerroError> {
168 let n_samples = x.nrows();
169
170 if n_samples != y.len() {
171 return Err(FerroError::ShapeMismatch {
172 expected: vec![n_samples],
173 actual: vec![y.len()],
174 context: "y length must match number of samples in X".into(),
175 });
176 }
177 if n_samples == 0 {
178 return Err(FerroError::InsufficientSamples {
179 required: 1,
180 actual: 0,
181 context: "VotingClassifier requires at least one sample".into(),
182 });
183 }
184 if self.max_depths.is_empty() {
185 return Err(FerroError::InvalidParameter {
186 name: "max_depths".into(),
187 reason: "must contain at least one entry".into(),
188 });
189 }
190
191 let mut classes: Vec<usize> = y.iter().copied().collect();
193 classes.sort_unstable();
194 classes.dedup();
195
196 let mut trees = Vec::with_capacity(self.max_depths.len());
197 for &max_depth in &self.max_depths {
198 let tree = DecisionTreeClassifier::<F>::new()
199 .with_max_depth(max_depth)
200 .with_min_samples_split(self.min_samples_split)
201 .with_min_samples_leaf(self.min_samples_leaf)
202 .with_criterion(self.criterion);
203 let fitted = tree.fit(x, y)?;
204 trees.push(fitted);
205 }
206
207 Ok(FittedVotingClassifier { trees, classes })
208 }
209}
210
211impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedVotingClassifier<F> {
212 type Output = Array1<usize>;
213 type Error = FerroError;
214
215 fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
222 let n_samples = x.nrows();
223 let n_classes = self.classes.len();
224
225 let all_preds: Vec<Array1<usize>> = self
227 .trees
228 .iter()
229 .map(|tree| tree.predict(x))
230 .collect::<Result<Vec<_>, _>>()?;
231
232 let mut predictions = Array1::zeros(n_samples);
233 for i in 0..n_samples {
234 let mut votes = vec![0usize; n_classes];
235 for tree_preds in &all_preds {
236 let pred = tree_preds[i];
237 if let Some(class_idx) = self.classes.iter().position(|&c| c == pred) {
238 votes[class_idx] += 1;
239 }
240 }
241 let winner = votes
242 .iter()
243 .enumerate()
244 .max_by_key(|&(_, &count)| count)
245 .map(|(idx, _)| idx)
246 .unwrap_or(0);
247 predictions[i] = self.classes[winner];
248 }
249
250 Ok(predictions)
251 }
252}
253
254impl<F: Float + Send + Sync + 'static> HasClasses for FittedVotingClassifier<F> {
255 fn classes(&self) -> &[usize] {
256 &self.classes
257 }
258
259 fn n_classes(&self) -> usize {
260 self.classes.len()
261 }
262}
263
impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
    for VotingClassifier<F>
{
    /// Pipeline adapter: converts float targets to `usize` class labels,
    /// fits the ensemble, and boxes the result behind the pipeline trait.
    ///
    /// NOTE(review): labels that are negative, non-finite, or otherwise not
    /// representable as `usize` are silently coerced to class 0 by
    /// `unwrap_or(0)` — confirm this fallback is intended rather than an
    /// `InvalidParameter` error.
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
        let fitted = self.fit(x, &y_usize)?;
        Ok(Box::new(FittedVotingClassifierPipelineAdapter(fitted)))
    }
}
278
279struct FittedVotingClassifierPipelineAdapter<F: Float + Send + Sync + 'static>(
281 FittedVotingClassifier<F>,
282);
283
284impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
285 for FittedVotingClassifierPipelineAdapter<F>
286{
287 fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
288 let preds = self.0.predict(x)?;
289 Ok(preds.mapv(|v| F::from_usize(v).unwrap_or(F::nan())))
290 }
291}
292
/// Ensemble regressor that fits one decision tree per entry in
/// `max_depths` and predicts by averaging the trees' outputs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VotingRegressor<F> {
    /// One tree is trained per entry; `None` means unlimited depth.
    pub max_depths: Vec<Option<usize>>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required at a leaf node.
    pub min_samples_leaf: usize,
    /// Ties the float type `F` to the estimator without storing any data.
    _marker: std::marker::PhantomData<F>,
}
339
340impl<F: Float> VotingRegressor<F> {
341 #[must_use]
346 pub fn new() -> Self {
347 Self {
348 max_depths: vec![Some(2), Some(4), Some(6), None],
349 min_samples_split: 2,
350 min_samples_leaf: 1,
351 _marker: std::marker::PhantomData,
352 }
353 }
354
355 #[must_use]
357 pub fn with_max_depths(mut self, max_depths: Vec<Option<usize>>) -> Self {
358 self.max_depths = max_depths;
359 self
360 }
361
362 #[must_use]
364 pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
365 self.min_samples_split = min_samples_split;
366 self
367 }
368
369 #[must_use]
371 pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
372 self.min_samples_leaf = min_samples_leaf;
373 self
374 }
375}
376
377impl<F: Float> Default for VotingRegressor<F> {
378 fn default() -> Self {
379 Self::new()
380 }
381}
382
/// A fitted voting regressor holding one trained tree per configured depth.
#[derive(Debug, Clone)]
pub struct FittedVotingRegressor<F> {
    // Trained ensemble members, in the order of `max_depths` at fit time.
    trees: Vec<FittedDecisionTreeRegressor<F>>,
}

impl<F: Float + Send + Sync + 'static> FittedVotingRegressor<F> {
    /// Returns the number of trees in the fitted ensemble.
    #[must_use]
    pub fn n_estimators(&self) -> usize {
        self.trees.len()
    }
}
404
405impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for VotingRegressor<F> {
406 type Fitted = FittedVotingRegressor<F>;
407 type Error = FerroError;
408
409 fn fit(
418 &self,
419 x: &Array2<F>,
420 y: &Array1<F>,
421 ) -> Result<FittedVotingRegressor<F>, FerroError> {
422 let n_samples = x.nrows();
423
424 if n_samples != y.len() {
425 return Err(FerroError::ShapeMismatch {
426 expected: vec![n_samples],
427 actual: vec![y.len()],
428 context: "y length must match number of samples in X".into(),
429 });
430 }
431 if n_samples == 0 {
432 return Err(FerroError::InsufficientSamples {
433 required: 1,
434 actual: 0,
435 context: "VotingRegressor requires at least one sample".into(),
436 });
437 }
438 if self.max_depths.is_empty() {
439 return Err(FerroError::InvalidParameter {
440 name: "max_depths".into(),
441 reason: "must contain at least one entry".into(),
442 });
443 }
444
445 let mut trees = Vec::with_capacity(self.max_depths.len());
446 for &max_depth in &self.max_depths {
447 let tree = DecisionTreeRegressor::<F>::new()
448 .with_max_depth(max_depth)
449 .with_min_samples_split(self.min_samples_split)
450 .with_min_samples_leaf(self.min_samples_leaf);
451 let fitted = tree.fit(x, y)?;
452 trees.push(fitted);
453 }
454
455 Ok(FittedVotingRegressor { trees })
456 }
457}
458
459impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedVotingRegressor<F> {
460 type Output = Array1<F>;
461 type Error = FerroError;
462
463 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
470 let n_samples = x.nrows();
471 let n_trees_f = F::from(self.trees.len()).unwrap();
472
473 let all_preds: Vec<Array1<F>> = self
474 .trees
475 .iter()
476 .map(|tree| tree.predict(x))
477 .collect::<Result<Vec<_>, _>>()?;
478
479 let mut predictions = Array1::zeros(n_samples);
480 for i in 0..n_samples {
481 let mut sum = F::zero();
482 for tree_preds in &all_preds {
483 sum = sum + tree_preds[i];
484 }
485 predictions[i] = sum / n_trees_f;
486 }
487
488 Ok(predictions)
489 }
490}
491
492impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for VotingRegressor<F> {
494 fn fit_pipeline(
495 &self,
496 x: &Array2<F>,
497 y: &Array1<F>,
498 ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
499 let fitted = self.fit(x, y)?;
500 Ok(Box::new(fitted))
501 }
502}
503
504impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F> for FittedVotingRegressor<F> {
505 fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
506 self.predict(x)
507 }
508}
509
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Two linearly separable clusters of four 2-D points each
    /// (samples 0-3 → class 0, samples 4-7 → class 1).
    fn make_classification_data() -> (Array2<f64>, Array1<usize>) {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
        (x, y)
    }

    /// Six 2-D points with targets roughly tracking the first feature.
    fn make_regression_data() -> (Array2<f64>, Array1<f64>) {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
        )
        .unwrap();
        let y = array![1.0, 2.0, 3.0, 5.0, 6.0, 7.0];
        (x, y)
    }

    // ---- VotingClassifier ----

    #[test]
    fn test_voting_classifier_default() {
        let model = VotingClassifier::<f64>::new();
        assert_eq!(model.max_depths.len(), 4);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
    }

    #[test]
    fn test_voting_classifier_builder() {
        let model = VotingClassifier::<f64>::new()
            .with_max_depths(vec![Some(1), Some(3)])
            .with_min_samples_split(5)
            .with_min_samples_leaf(2)
            .with_criterion(ClassificationCriterion::Entropy);
        assert_eq!(model.max_depths.len(), 2);
        assert_eq!(model.min_samples_split, 5);
        assert_eq!(model.min_samples_leaf, 2);
        assert_eq!(model.criterion, ClassificationCriterion::Entropy);
    }

    #[test]
    fn test_voting_classifier_fit_predict() {
        let (x, y) = make_classification_data();
        let model = VotingClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        // Separable training data: the ensemble should recover every label.
        assert_eq!(preds.len(), 8);
        for i in 0..4 {
            assert_eq!(preds[i], 0, "sample {i} should be class 0");
        }
        for i in 4..8 {
            assert_eq!(preds[i], 1, "sample {i} should be class 1");
        }
    }

    #[test]
    fn test_voting_classifier_has_classes() {
        let (x, y) = make_classification_data();
        let model = VotingClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0, 1]);
        assert_eq!(fitted.n_classes(), 2);
    }

    #[test]
    fn test_voting_classifier_n_estimators() {
        // One fitted tree per max_depths entry.
        let (x, y) = make_classification_data();
        let model = VotingClassifier::<f64>::new()
            .with_max_depths(vec![Some(2), Some(4), None]);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 3);
    }

    #[test]
    fn test_voting_classifier_empty_data_error() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);
        let model = VotingClassifier::<f64>::new();
        let result = model.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_voting_classifier_shape_mismatch_error() {
        let x = Array2::<f64>::zeros((5, 2));
        let y = Array1::<usize>::zeros(3);
        let model = VotingClassifier::<f64>::new();
        let result = model.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_voting_classifier_empty_depths_error() {
        let (x, y) = make_classification_data();
        let model = VotingClassifier::<f64>::new().with_max_depths(vec![]);
        let result = model.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_voting_classifier_multiclass() {
        // Three well-separated clusters of three points each.
        let x = Array2::from_shape_vec(
            (9, 2),
            vec![
                1.0, 1.0, 2.0, 1.0, 1.0, 2.0, 4.0, 4.0, 5.0, 4.0, 4.0, 5.0, 8.0, 8.0, 9.0, 8.0,
                8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];

        let model = VotingClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 9);
        assert_eq!(fitted.n_classes(), 3);
    }

    // ---- VotingRegressor ----

    #[test]
    fn test_voting_regressor_default() {
        let model = VotingRegressor::<f64>::new();
        assert_eq!(model.max_depths.len(), 4);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
    }

    #[test]
    fn test_voting_regressor_builder() {
        let model = VotingRegressor::<f64>::new()
            .with_max_depths(vec![Some(1), Some(5)])
            .with_min_samples_split(3)
            .with_min_samples_leaf(2);
        assert_eq!(model.max_depths.len(), 2);
        assert_eq!(model.min_samples_split, 3);
        assert_eq!(model.min_samples_leaf, 2);
    }

    #[test]
    fn test_voting_regressor_fit_predict() {
        let (x, y) = make_regression_data();
        let model = VotingRegressor::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 6);
        // Averaging shallow and deep trees: predictions are close to the
        // targets but not exact, hence the loose tolerance.
        for i in 0..6 {
            let err = (preds[i] - y[i]).abs();
            assert!(
                err < 3.0,
                "prediction {:.2} should be close to target {:.2}",
                preds[i],
                y[i]
            );
        }
    }

    #[test]
    fn test_voting_regressor_n_estimators() {
        let (x, y) = make_regression_data();
        let model = VotingRegressor::<f64>::new()
            .with_max_depths(vec![Some(2), None]);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 2);
    }

    #[test]
    fn test_voting_regressor_empty_data_error() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);
        let model = VotingRegressor::<f64>::new();
        let result = model.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_voting_regressor_shape_mismatch_error() {
        let x = Array2::<f64>::zeros((5, 2));
        let y = Array1::<f64>::zeros(3);
        let model = VotingRegressor::<f64>::new();
        let result = model.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_voting_regressor_empty_depths_error() {
        let (x, y) = make_regression_data();
        let model = VotingRegressor::<f64>::new().with_max_depths(vec![]);
        let result = model.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_voting_regressor_averaging() {
        // With a single depth-unlimited tree, the "average" is that tree's
        // own output, which memorizes the training targets exactly.
        let (x, y) = make_regression_data();
        let model = VotingRegressor::<f64>::new().with_max_depths(vec![None]);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for i in 0..6 {
            assert!(
                (preds[i] - y[i]).abs() < 1e-10,
                "single unlimited tree should overfit training data"
            );
        }
    }

    // ---- f32 smoke tests ----

    #[test]
    fn test_voting_classifier_f32() {
        let x = Array2::<f32>::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1];
        let model = VotingClassifier::<f32>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_voting_regressor_f32() {
        let x = Array2::<f32>::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
        )
        .unwrap();
        let y = array![1.0_f32, 2.0, 3.0, 5.0, 6.0, 7.0];
        let model = VotingRegressor::<f32>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }
}