1use ferrolearn_core::error::FerroError;
32use ferrolearn_core::introspection::HasClasses;
33use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
34use ferrolearn_core::traits::{Fit, Predict};
35use ndarray::{Array1, Array2};
36use num_traits::{Float, FromPrimitive, ToPrimitive};
37use serde::{Deserialize, Serialize};
38
39use crate::decision_tree::{
40 ClassificationCriterion, DecisionTreeClassifier, DecisionTreeRegressor,
41 FittedDecisionTreeClassifier, FittedDecisionTreeRegressor,
42};
43
/// Hard-voting ensemble of decision-tree classifiers.
///
/// One [`DecisionTreeClassifier`] is trained per entry in `max_depths`; at
/// prediction time each tree casts one vote and the majority class wins.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VotingClassifier<F> {
    /// Depth limit per ensemble member; `None` means unlimited depth.
    pub max_depths: Vec<Option<usize>>,
    /// Minimum samples required to split a node (applied to every tree).
    pub min_samples_split: usize,
    /// Minimum samples required at a leaf (applied to every tree).
    pub min_samples_leaf: usize,
    /// Split-quality criterion shared by all trees.
    pub criterion: ClassificationCriterion,
    // Ties the float type `F` to the struct without storing a value of it.
    _marker: std::marker::PhantomData<F>,
}
73
74impl<F: Float> VotingClassifier<F> {
75 #[must_use]
80 pub fn new() -> Self {
81 Self {
82 max_depths: vec![Some(2), Some(4), Some(6), None],
83 min_samples_split: 2,
84 min_samples_leaf: 1,
85 criterion: ClassificationCriterion::Gini,
86 _marker: std::marker::PhantomData,
87 }
88 }
89
90 #[must_use]
94 pub fn with_max_depths(mut self, max_depths: Vec<Option<usize>>) -> Self {
95 self.max_depths = max_depths;
96 self
97 }
98
99 #[must_use]
101 pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
102 self.min_samples_split = min_samples_split;
103 self
104 }
105
106 #[must_use]
108 pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
109 self.min_samples_leaf = min_samples_leaf;
110 self
111 }
112
113 #[must_use]
115 pub fn with_criterion(mut self, criterion: ClassificationCriterion) -> Self {
116 self.criterion = criterion;
117 self
118 }
119}
120
121impl<F: Float> Default for VotingClassifier<F> {
122 fn default() -> Self {
123 Self::new()
124 }
125}
126
/// Result of fitting a [`VotingClassifier`]: the trained trees plus the
/// distinct class labels observed in the training targets.
#[derive(Debug, Clone)]
pub struct FittedVotingClassifier<F> {
    // One fitted tree per configured max depth.
    trees: Vec<FittedDecisionTreeClassifier<F>>,
    // Sorted ascending and deduplicated during `fit`.
    classes: Vec<usize>,
}
142
143impl<F: Float + Send + Sync + 'static> FittedVotingClassifier<F> {
144 #[must_use]
146 pub fn n_estimators(&self) -> usize {
147 self.trees.len()
148 }
149}
150
151impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for VotingClassifier<F> {
152 type Fitted = FittedVotingClassifier<F>;
153 type Error = FerroError;
154
155 fn fit(
164 &self,
165 x: &Array2<F>,
166 y: &Array1<usize>,
167 ) -> Result<FittedVotingClassifier<F>, FerroError> {
168 let n_samples = x.nrows();
169
170 if n_samples != y.len() {
171 return Err(FerroError::ShapeMismatch {
172 expected: vec![n_samples],
173 actual: vec![y.len()],
174 context: "y length must match number of samples in X".into(),
175 });
176 }
177 if n_samples == 0 {
178 return Err(FerroError::InsufficientSamples {
179 required: 1,
180 actual: 0,
181 context: "VotingClassifier requires at least one sample".into(),
182 });
183 }
184 if self.max_depths.is_empty() {
185 return Err(FerroError::InvalidParameter {
186 name: "max_depths".into(),
187 reason: "must contain at least one entry".into(),
188 });
189 }
190
191 let mut classes: Vec<usize> = y.iter().copied().collect();
193 classes.sort_unstable();
194 classes.dedup();
195
196 let mut trees = Vec::with_capacity(self.max_depths.len());
197 for &max_depth in &self.max_depths {
198 let tree = DecisionTreeClassifier::<F>::new()
199 .with_max_depth(max_depth)
200 .with_min_samples_split(self.min_samples_split)
201 .with_min_samples_leaf(self.min_samples_leaf)
202 .with_criterion(self.criterion);
203 let fitted = tree.fit(x, y)?;
204 trees.push(fitted);
205 }
206
207 Ok(FittedVotingClassifier { trees, classes })
208 }
209}
210
211impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedVotingClassifier<F> {
212 type Output = Array1<usize>;
213 type Error = FerroError;
214
215 fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
222 let n_samples = x.nrows();
223 let n_classes = self.classes.len();
224
225 let all_preds: Vec<Array1<usize>> = self
227 .trees
228 .iter()
229 .map(|tree| tree.predict(x))
230 .collect::<Result<Vec<_>, _>>()?;
231
232 let mut predictions = Array1::zeros(n_samples);
233 for i in 0..n_samples {
234 let mut votes = vec![0usize; n_classes];
235 for tree_preds in &all_preds {
236 let pred = tree_preds[i];
237 if let Some(class_idx) = self.classes.iter().position(|&c| c == pred) {
238 votes[class_idx] += 1;
239 }
240 }
241 let winner = votes
242 .iter()
243 .enumerate()
244 .max_by_key(|&(_, &count)| count)
245 .map_or(0, |(idx, _)| idx);
246 predictions[i] = self.classes[winner];
247 }
248
249 Ok(predictions)
250 }
251}
252
253impl<F: Float + Send + Sync + 'static> HasClasses for FittedVotingClassifier<F> {
254 fn classes(&self) -> &[usize] {
255 &self.classes
256 }
257
258 fn n_classes(&self) -> usize {
259 self.classes.len()
260 }
261}
262
263impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
265 for VotingClassifier<F>
266{
267 fn fit_pipeline(
268 &self,
269 x: &Array2<F>,
270 y: &Array1<F>,
271 ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
272 let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
273 let fitted = self.fit(x, &y_usize)?;
274 Ok(Box::new(FittedVotingClassifierPipelineAdapter(fitted)))
275 }
276}
277
278struct FittedVotingClassifierPipelineAdapter<F: Float + Send + Sync + 'static>(
280 FittedVotingClassifier<F>,
281);
282
283impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
284 for FittedVotingClassifierPipelineAdapter<F>
285{
286 fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
287 let preds = self.0.predict(x)?;
288 Ok(preds.mapv(|v| F::from_usize(v).unwrap_or_else(F::nan)))
289 }
290}
291
/// Averaging ensemble of decision-tree regressors.
///
/// One [`DecisionTreeRegressor`] is trained per entry in `max_depths`; the
/// prediction is the mean of the individual tree outputs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VotingRegressor<F> {
    /// Depth limit per ensemble member; `None` means unlimited depth.
    pub max_depths: Vec<Option<usize>>,
    /// Minimum samples required to split a node (applied to every tree).
    pub min_samples_split: usize,
    /// Minimum samples required at a leaf (applied to every tree).
    pub min_samples_leaf: usize,
    // Ties the float type `F` to the struct without storing a value of it.
    _marker: std::marker::PhantomData<F>,
}
338
339impl<F: Float> VotingRegressor<F> {
340 #[must_use]
345 pub fn new() -> Self {
346 Self {
347 max_depths: vec![Some(2), Some(4), Some(6), None],
348 min_samples_split: 2,
349 min_samples_leaf: 1,
350 _marker: std::marker::PhantomData,
351 }
352 }
353
354 #[must_use]
356 pub fn with_max_depths(mut self, max_depths: Vec<Option<usize>>) -> Self {
357 self.max_depths = max_depths;
358 self
359 }
360
361 #[must_use]
363 pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
364 self.min_samples_split = min_samples_split;
365 self
366 }
367
368 #[must_use]
370 pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
371 self.min_samples_leaf = min_samples_leaf;
372 self
373 }
374}
375
376impl<F: Float> Default for VotingRegressor<F> {
377 fn default() -> Self {
378 Self::new()
379 }
380}
381
/// Result of fitting a [`VotingRegressor`]: the trained trees whose
/// predictions are averaged at inference time.
#[derive(Debug, Clone)]
pub struct FittedVotingRegressor<F> {
    // One fitted tree per configured max depth; non-empty (checked in `fit`).
    trees: Vec<FittedDecisionTreeRegressor<F>>,
}
395
396impl<F: Float + Send + Sync + 'static> FittedVotingRegressor<F> {
397 #[must_use]
399 pub fn n_estimators(&self) -> usize {
400 self.trees.len()
401 }
402}
403
404impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for VotingRegressor<F> {
405 type Fitted = FittedVotingRegressor<F>;
406 type Error = FerroError;
407
408 fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedVotingRegressor<F>, FerroError> {
417 let n_samples = x.nrows();
418
419 if n_samples != y.len() {
420 return Err(FerroError::ShapeMismatch {
421 expected: vec![n_samples],
422 actual: vec![y.len()],
423 context: "y length must match number of samples in X".into(),
424 });
425 }
426 if n_samples == 0 {
427 return Err(FerroError::InsufficientSamples {
428 required: 1,
429 actual: 0,
430 context: "VotingRegressor requires at least one sample".into(),
431 });
432 }
433 if self.max_depths.is_empty() {
434 return Err(FerroError::InvalidParameter {
435 name: "max_depths".into(),
436 reason: "must contain at least one entry".into(),
437 });
438 }
439
440 let mut trees = Vec::with_capacity(self.max_depths.len());
441 for &max_depth in &self.max_depths {
442 let tree = DecisionTreeRegressor::<F>::new()
443 .with_max_depth(max_depth)
444 .with_min_samples_split(self.min_samples_split)
445 .with_min_samples_leaf(self.min_samples_leaf);
446 let fitted = tree.fit(x, y)?;
447 trees.push(fitted);
448 }
449
450 Ok(FittedVotingRegressor { trees })
451 }
452}
453
454impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedVotingRegressor<F> {
455 type Output = Array1<F>;
456 type Error = FerroError;
457
458 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
465 let n_samples = x.nrows();
466 let n_trees_f = F::from(self.trees.len()).unwrap();
467
468 let all_preds: Vec<Array1<F>> = self
469 .trees
470 .iter()
471 .map(|tree| tree.predict(x))
472 .collect::<Result<Vec<_>, _>>()?;
473
474 let mut predictions = Array1::zeros(n_samples);
475 for i in 0..n_samples {
476 let mut sum = F::zero();
477 for tree_preds in &all_preds {
478 sum = sum + tree_preds[i];
479 }
480 predictions[i] = sum / n_trees_f;
481 }
482
483 Ok(predictions)
484 }
485}
486
487impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for VotingRegressor<F> {
489 fn fit_pipeline(
490 &self,
491 x: &Array2<F>,
492 y: &Array1<F>,
493 ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
494 let fitted = self.fit(x, y)?;
495 Ok(Box::new(fitted))
496 }
497}
498
499impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F> for FittedVotingRegressor<F> {
500 fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
501 self.predict(x)
502 }
503}
504
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    // Two well-separated clusters of four points each: class 0 has small
    // coordinates, class 1 has large coordinates.
    fn make_classification_data() -> (Array2<f64>, Array1<usize>) {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
        (x, y)
    }

    // Six points whose targets grow roughly with the coordinates, so shallow
    // trees can approximate them and an unlimited tree fits them exactly.
    fn make_regression_data() -> (Array2<f64>, Array1<f64>) {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
        )
        .unwrap();
        let y = array![1.0, 2.0, 3.0, 5.0, 6.0, 7.0];
        (x, y)
    }

    // `new()` should produce the documented defaults (4 depths, split=2, leaf=1).
    #[test]
    fn test_voting_classifier_default() {
        let model = VotingClassifier::<f64>::new();
        assert_eq!(model.max_depths.len(), 4);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
    }

    // Each builder method should override exactly its own field.
    #[test]
    fn test_voting_classifier_builder() {
        let model = VotingClassifier::<f64>::new()
            .with_max_depths(vec![Some(1), Some(3)])
            .with_min_samples_split(5)
            .with_min_samples_leaf(2)
            .with_criterion(ClassificationCriterion::Entropy);
        assert_eq!(model.max_depths.len(), 2);
        assert_eq!(model.min_samples_split, 5);
        assert_eq!(model.min_samples_leaf, 2);
        assert_eq!(model.criterion, ClassificationCriterion::Entropy);
    }

    // On linearly separable data the ensemble should classify the training
    // set perfectly.
    #[test]
    fn test_voting_classifier_fit_predict() {
        let (x, y) = make_classification_data();
        let model = VotingClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
        for i in 0..4 {
            assert_eq!(preds[i], 0, "sample {i} should be class 0");
        }
        for i in 4..8 {
            assert_eq!(preds[i], 1, "sample {i} should be class 1");
        }
    }

    // The `HasClasses` impl should expose the sorted, deduplicated labels.
    #[test]
    fn test_voting_classifier_has_classes() {
        let (x, y) = make_classification_data();
        let model = VotingClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0, 1]);
        assert_eq!(fitted.n_classes(), 2);
    }

    // One tree should be trained per entry in `max_depths`.
    #[test]
    fn test_voting_classifier_n_estimators() {
        let (x, y) = make_classification_data();
        let model = VotingClassifier::<f64>::new().with_max_depths(vec![Some(2), Some(4), None]);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 3);
    }

    // Zero samples must be rejected (InsufficientSamples).
    #[test]
    fn test_voting_classifier_empty_data_error() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);
        let model = VotingClassifier::<f64>::new();
        let result = model.fit(&x, &y);
        assert!(result.is_err());
    }

    // Mismatched x/y lengths must be rejected (ShapeMismatch).
    #[test]
    fn test_voting_classifier_shape_mismatch_error() {
        let x = Array2::<f64>::zeros((5, 2));
        let y = Array1::<usize>::zeros(3);
        let model = VotingClassifier::<f64>::new();
        let result = model.fit(&x, &y);
        assert!(result.is_err());
    }

    // An empty depth schedule must be rejected (InvalidParameter).
    #[test]
    fn test_voting_classifier_empty_depths_error() {
        let (x, y) = make_classification_data();
        let model = VotingClassifier::<f64>::new().with_max_depths(vec![]);
        let result = model.fit(&x, &y);
        assert!(result.is_err());
    }

    // Voting is not limited to binary problems: three clusters, three classes.
    #[test]
    fn test_voting_classifier_multiclass() {
        let x = Array2::from_shape_vec(
            (9, 2),
            vec![
                1.0, 1.0, 2.0, 1.0, 1.0, 2.0, 4.0, 4.0, 5.0, 4.0, 4.0, 5.0, 8.0, 8.0, 9.0, 8.0,
                8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];

        let model = VotingClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 9);
        assert_eq!(fitted.n_classes(), 3);
    }

    // `new()` should produce the documented defaults (4 depths, split=2, leaf=1).
    #[test]
    fn test_voting_regressor_default() {
        let model = VotingRegressor::<f64>::new();
        assert_eq!(model.max_depths.len(), 4);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
    }

    // Each builder method should override exactly its own field.
    #[test]
    fn test_voting_regressor_builder() {
        let model = VotingRegressor::<f64>::new()
            .with_max_depths(vec![Some(1), Some(5)])
            .with_min_samples_split(3)
            .with_min_samples_leaf(2);
        assert_eq!(model.max_depths.len(), 2);
        assert_eq!(model.min_samples_split, 3);
        assert_eq!(model.min_samples_leaf, 2);
    }

    // Averaged predictions should land near the training targets; the loose
    // tolerance allows for the shallow (depth 2) trees in the default mix.
    #[test]
    fn test_voting_regressor_fit_predict() {
        let (x, y) = make_regression_data();
        let model = VotingRegressor::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 6);
        for i in 0..6 {
            let err = (preds[i] - y[i]).abs();
            assert!(
                err < 3.0,
                "prediction {:.2} should be close to target {:.2}",
                preds[i],
                y[i]
            );
        }
    }

    // One tree should be trained per entry in `max_depths`.
    #[test]
    fn test_voting_regressor_n_estimators() {
        let (x, y) = make_regression_data();
        let model = VotingRegressor::<f64>::new().with_max_depths(vec![Some(2), None]);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 2);
    }

    // Zero samples must be rejected (InsufficientSamples).
    #[test]
    fn test_voting_regressor_empty_data_error() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);
        let model = VotingRegressor::<f64>::new();
        let result = model.fit(&x, &y);
        assert!(result.is_err());
    }

    // Mismatched x/y lengths must be rejected (ShapeMismatch).
    #[test]
    fn test_voting_regressor_shape_mismatch_error() {
        let x = Array2::<f64>::zeros((5, 2));
        let y = Array1::<f64>::zeros(3);
        let model = VotingRegressor::<f64>::new();
        let result = model.fit(&x, &y);
        assert!(result.is_err());
    }

    // An empty depth schedule must be rejected (InvalidParameter).
    #[test]
    fn test_voting_regressor_empty_depths_error() {
        let (x, y) = make_regression_data();
        let model = VotingRegressor::<f64>::new().with_max_depths(vec![]);
        let result = model.fit(&x, &y);
        assert!(result.is_err());
    }

    // A single unlimited-depth tree memorizes the training set, so the
    // "average" of one tree must reproduce the targets exactly.
    #[test]
    fn test_voting_regressor_averaging() {
        let (x, y) = make_regression_data();
        let model = VotingRegressor::<f64>::new().with_max_depths(vec![None]);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for i in 0..6 {
            assert!(
                (preds[i] - y[i]).abs() < 1e-10,
                "single unlimited tree should overfit training data"
            );
        }
    }

    // The ensemble should be usable with f32 as well as f64 (classifier).
    #[test]
    fn test_voting_classifier_f32() {
        let x = Array2::<f32>::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1];
        let model = VotingClassifier::<f32>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    // The ensemble should be usable with f32 as well as f64 (regressor).
    #[test]
    fn test_voting_regressor_f32() {
        let x = Array2::<f32>::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
        )
        .unwrap();
        let y = array![1.0_f32, 2.0, 3.0, 5.0, 6.0, 7.0];
        let model = VotingRegressor::<f32>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }
}