1use crate::dataset::Dataset;
11use crate::error::{Result, ScryLearnError};
12
/// Strategy used by `VotingClassifier` to combine base-estimator outputs.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Voting {
    /// Majority vote over predicted class labels.
    Hard,
    /// Argmax of the (weighted) sum of per-class probabilities; requires
    /// every base estimator to support `predict_proba`.
    Soft,
}
29
/// Object-safe interface that lets heterogeneous models participate in
/// ensembles (`VotingClassifier`, `StackingClassifier`).
pub trait EnsembleClassifier: Send + Sync {
    /// Fits the estimator on `data`.
    fn fit(&mut self, data: &Dataset) -> Result<()>;

    /// Predicts one target value (class label or regression output) per row
    /// of `features`.
    fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>>;

    /// Predicts per-class probabilities, one row per input sample.
    ///
    /// The default implementation errors, signalling that the estimator has
    /// no probability support (soft voting relies on this method).
    fn predict_proba(&self, _features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        Err(ScryLearnError::InvalidParameter(
            "this estimator does not support predict_proba".into(),
        ))
    }

    /// Clones the estimator behind the trait object (object-safe stand-in
    /// for `Clone`, which cannot be a supertrait of a `dyn`-usable trait).
    fn clone_box(&self) -> Box<dyn EnsembleClassifier>;
}
58
// Makes `Box<dyn EnsembleClassifier>` cloneable (e.g. for `#[derive(Clone)]`
// on the ensemble structs) by delegating to the object-safe `clone_box`.
impl Clone for Box<dyn EnsembleClassifier> {
    fn clone(&self) -> Self {
        self.clone_box()
    }
}
64
/// Implements `EnsembleClassifier` for a model type whose inherent API offers
/// `fit`/`predict` but no probability output; the trait's erroring
/// `predict_proba` default is left in place.
///
/// Inside the impl, `self.fit(...)` / `self.predict(...)` resolve to the
/// type's *inherent* methods (inherent methods take precedence over trait
/// methods in Rust method resolution), so these calls forward rather than
/// recursing into the trait impl itself.
macro_rules! impl_ensemble_no_proba {
    ($ty:path) => {
        impl EnsembleClassifier for $ty {
            fn fit(&mut self, data: &Dataset) -> Result<()> {
                self.fit(data)
            }
            fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
                self.predict(features)
            }
            fn clone_box(&self) -> Box<dyn EnsembleClassifier> {
                Box::new(self.clone())
            }
        }
    };
}
84
/// Implements `EnsembleClassifier` for a model type whose inherent API offers
/// `fit`, `predict` *and* `predict_proba` (usable with soft voting).
///
/// As in `impl_ensemble_no_proba!`, the `self.method(...)` calls resolve to
/// the type's inherent methods, not the trait methods being defined.
macro_rules! impl_ensemble_with_proba {
    ($ty:path) => {
        impl EnsembleClassifier for $ty {
            fn fit(&mut self, data: &Dataset) -> Result<()> {
                self.fit(data)
            }
            fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
                self.predict(features)
            }
            fn predict_proba(&self, features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
                self.predict_proba(features)
            }
            fn clone_box(&self) -> Box<dyn EnsembleClassifier> {
                Box::new(self.clone())
            }
        }
    };
}
103
// Estimators with native probability support; usable with soft voting.
impl_ensemble_with_proba!(crate::tree::DecisionTreeClassifier);
impl_ensemble_with_proba!(crate::tree::RandomForestClassifier);
impl_ensemble_with_proba!(crate::naive_bayes::GaussianNb);
impl_ensemble_with_proba!(crate::naive_bayes::BernoulliNB);
impl_ensemble_with_proba!(crate::naive_bayes::MultinomialNB);

// Estimators without `predict_proba`; hard voting and stacking only.
// NOTE(review): logistic-regression models can usually emit probabilities —
// confirm whether LogisticRegression belongs in the `with_proba` group above.
impl_ensemble_no_proba!(crate::tree::DecisionTreeRegressor);
impl_ensemble_no_proba!(crate::linear::LogisticRegression);
impl_ensemble_no_proba!(crate::linear::LinearRegression);
impl_ensemble_no_proba!(crate::linear::LassoRegression);
impl_ensemble_no_proba!(crate::linear::ElasticNet);
impl_ensemble_no_proba!(crate::neighbors::KnnClassifier);
impl_ensemble_no_proba!(crate::neighbors::KnnRegressor);
impl_ensemble_no_proba!(crate::svm::LinearSVC);
impl_ensemble_no_proba!(crate::svm::LinearSVR);
#[cfg(feature = "experimental")]
impl_ensemble_no_proba!(crate::svm::KernelSVC);
#[cfg(feature = "experimental")]
impl_ensemble_no_proba!(crate::svm::KernelSVR);
125
/// Ensemble that combines fitted base estimators by (weighted) hard or soft
/// voting over their per-sample predictions.
#[derive(Clone)]
#[non_exhaustive]
pub struct VotingClassifier {
    // Heterogeneous base estimators; fit in place by `fit`.
    estimators: Vec<Box<dyn EnsembleClassifier>>,
    // Hard (label majority) or soft (probability summing) voting.
    voting_strategy: Voting,
    // Optional per-estimator weights; `None` means uniform (1.0 each).
    weights: Option<Vec<f64>>,
    // Set by `fit`; `predict` errors until then.
    fitted: bool,
    // Number of target classes observed at fit time (`data.n_classes()`).
    n_classes: usize,
}
162
163impl std::fmt::Debug for VotingClassifier {
164 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165 f.debug_struct("VotingClassifier")
166 .field("n_estimators", &self.estimators.len())
167 .field("voting", &self.voting_strategy)
168 .field("weights", &self.weights)
169 .field("fitted", &self.fitted)
170 .finish()
171 }
172}
173
174impl VotingClassifier {
175 pub fn new(estimators: Vec<Box<dyn EnsembleClassifier>>) -> Self {
177 Self {
178 estimators,
179 voting_strategy: Voting::Hard,
180 weights: None,
181 fitted: false,
182 n_classes: 0,
183 }
184 }
185
186 pub fn voting(mut self, v: Voting) -> Self {
188 self.voting_strategy = v;
189 self
190 }
191
192 pub fn weights(mut self, w: Vec<f64>) -> Self {
194 self.weights = Some(w);
195 self
196 }
197
198 pub fn fit(&mut self, data: &Dataset) -> Result<()> {
200 if data.n_samples() == 0 {
201 return Err(ScryLearnError::EmptyDataset);
202 }
203 if self.estimators.is_empty() {
204 return Err(ScryLearnError::InvalidParameter(
205 "VotingClassifier requires at least one estimator".into(),
206 ));
207 }
208 if let Some(ref w) = self.weights {
209 if w.len() != self.estimators.len() {
210 return Err(ScryLearnError::InvalidParameter(format!(
211 "weights length ({}) must match estimators length ({})",
212 w.len(),
213 self.estimators.len(),
214 )));
215 }
216 }
217
218 self.n_classes = data.n_classes();
219
220 for est in &mut self.estimators {
221 est.fit(data)?;
222 }
223 self.fitted = true;
224 Ok(())
225 }
226
227 pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
229 if !self.fitted {
230 return Err(ScryLearnError::NotFitted);
231 }
232
233 match self.voting_strategy {
234 Voting::Hard => self.predict_hard(features),
235 Voting::Soft => self.predict_soft(features),
236 }
237 }
238
239 fn predict_hard(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
241 let n = features.len();
242 let n_classes = self.n_classes;
243
244 let all_preds: Vec<Vec<f64>> = self
246 .estimators
247 .iter()
248 .map(|est| est.predict(features))
249 .collect::<Result<_>>()?;
250
251 let weights = self.uniform_weights();
252
253 let mut result = Vec::with_capacity(n);
254 for sample_idx in 0..n {
255 let mut votes = vec![0.0_f64; n_classes.max(1)];
256 for (est_idx, preds) in all_preds.iter().enumerate() {
257 let class = preds[sample_idx] as usize;
258 if class < votes.len() {
259 votes[class] += weights[est_idx];
260 }
261 }
262 let best_class = votes
263 .iter()
264 .enumerate()
265 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
266 .map_or(0, |(idx, _)| idx);
267 result.push(best_class as f64);
268 }
269
270 Ok(result)
271 }
272
273 fn predict_soft(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
275 let n = features.len();
276 let n_classes = self.n_classes;
277 let weights = self.uniform_weights();
278
279 let mut avg_proba = vec![vec![0.0; n_classes]; n];
280
281 for (est_idx, est) in self.estimators.iter().enumerate() {
282 let probas = est.predict_proba(features)?;
283 for (sample_idx, proba) in probas.iter().enumerate() {
284 for (class_idx, &p) in proba.iter().enumerate() {
285 if class_idx < n_classes {
286 avg_proba[sample_idx][class_idx] += p * weights[est_idx];
287 }
288 }
289 }
290 }
291
292 let result: Vec<f64> = avg_proba
293 .iter()
294 .map(|proba| {
295 proba
296 .iter()
297 .enumerate()
298 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
299 .map_or(0.0, |(idx, _)| idx as f64)
300 })
301 .collect();
302
303 Ok(result)
304 }
305
306 fn uniform_weights(&self) -> Vec<f64> {
308 self.weights
309 .clone()
310 .unwrap_or_else(|| vec![1.0; self.estimators.len()])
311 }
312}
313
/// Stacked generalization: base estimators produce out-of-fold predictions
/// which become the features used to train a final (meta) estimator.
#[derive(Clone)]
#[non_exhaustive]
pub struct StackingClassifier {
    // Base (level-0) estimators; refit on the full data at the end of `fit`.
    estimators: Vec<Box<dyn EnsembleClassifier>>,
    // Meta (level-1) estimator trained on out-of-fold base predictions.
    final_estimator: Box<dyn EnsembleClassifier>,
    // Number of cross-validation folds (default 5; must be >= 2).
    cv: usize,
    // Seed for the fold-shuffling RNG (default 42).
    seed: u64,
    // Set by `fit`; `predict` errors until then.
    fitted: bool,
}
354
355impl std::fmt::Debug for StackingClassifier {
356 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
357 f.debug_struct("StackingClassifier")
358 .field("n_estimators", &self.estimators.len())
359 .field("cv", &self.cv)
360 .field("fitted", &self.fitted)
361 .finish()
362 }
363}
364
365impl StackingClassifier {
366 pub fn new(
370 estimators: Vec<Box<dyn EnsembleClassifier>>,
371 final_estimator: Box<dyn EnsembleClassifier>,
372 ) -> Self {
373 Self {
374 estimators,
375 final_estimator,
376 cv: 5,
377 seed: 42,
378 fitted: false,
379 }
380 }
381
382 pub fn cv(mut self, k: usize) -> Self {
384 self.cv = k;
385 self
386 }
387
388 pub fn seed(mut self, s: u64) -> Self {
390 self.seed = s;
391 self
392 }
393
394 pub fn fit(&mut self, data: &Dataset) -> Result<()> {
403 data.validate_finite()?;
404 if data.n_samples() == 0 {
405 return Err(ScryLearnError::EmptyDataset);
406 }
407 if self.estimators.is_empty() {
408 return Err(ScryLearnError::InvalidParameter(
409 "StackingClassifier requires at least one base estimator".into(),
410 ));
411 }
412 if self.cv < 2 {
413 return Err(ScryLearnError::InvalidParameter(
414 "cv must be at least 2".into(),
415 ));
416 }
417
418 let n_samples = data.n_samples();
419 let n_estimators = self.estimators.len();
420
421 let folds = generate_fold_indices(n_samples, self.cv, self.seed);
423
424 let mut meta_features = vec![vec![0.0; n_estimators]; n_samples];
426
427 for (fold_idx, test_indices) in folds.iter().enumerate() {
428 let train_indices: Vec<usize> = (0..n_samples)
429 .filter(|i| !test_indices.contains(i))
430 .collect();
431
432 let train_data = data.subset(&train_indices);
433 let test_features = Self::extract_features(data, test_indices);
434
435 for (est_idx, est_template) in self.estimators.iter().enumerate() {
436 let mut est = est_template.clone_box();
437 est.fit(&train_data)?;
438 let preds = est.predict(&test_features)?;
439
440 for (local_idx, &global_idx) in test_indices.iter().enumerate() {
441 meta_features[global_idx][est_idx] = preds[local_idx];
442 }
443
444 let _ = fold_idx;
446 }
447 }
448
449 let meta_columns: Vec<Vec<f64>> = (0..n_estimators)
451 .map(|est_idx| meta_features.iter().map(|row| row[est_idx]).collect())
452 .collect();
453 let feature_names: Vec<String> = (0..n_estimators).map(|i| format!("est_{i}")).collect();
454
455 let meta_dataset = Dataset::new(meta_columns, data.target.clone(), feature_names, "target");
456
457 self.final_estimator.fit(&meta_dataset)?;
459
460 for est in &mut self.estimators {
462 est.fit(data)?;
463 }
464
465 self.fitted = true;
466 Ok(())
467 }
468
469 pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
474 if !self.fitted {
475 return Err(ScryLearnError::NotFitted);
476 }
477
478 let n = features.len();
479 let n_estimators = self.estimators.len();
480
481 let base_preds: Vec<Vec<f64>> = self
483 .estimators
484 .iter()
485 .map(|est| est.predict(features))
486 .collect::<Result<_>>()?;
487
488 let meta_features: Vec<Vec<f64>> = (0..n)
490 .map(|i| (0..n_estimators).map(|j| base_preds[j][i]).collect())
491 .collect();
492
493 self.final_estimator.predict(&meta_features)
494 }
495
496 fn extract_features(data: &Dataset, indices: &[usize]) -> Vec<Vec<f64>> {
498 indices.iter().map(|&i| data.sample(i)).collect()
499 }
500}
501
502fn generate_fold_indices(n: usize, k: usize, seed: u64) -> Vec<Vec<usize>> {
504 let mut indices: Vec<usize> = (0..n).collect();
505 let mut rng = crate::rng::FastRng::new(seed);
506
507 for i in (1..indices.len()).rev() {
509 let j = rng.usize(0..=i);
510 indices.swap(i, j);
511 }
512
513 let fold_size = n / k;
514 let remainder = n % k;
515 let mut folds = Vec::with_capacity(k);
516 let mut start = 0;
517 for fold in 0..k {
518 let extra = usize::from(fold < remainder);
519 let end = start + fold_size + extra;
520 folds.push(indices[start..end].to_vec());
521 start = end;
522 }
523
524 folds
525}
526
527#[cfg(test)]
528mod tests {
529 use super::*;
530 use crate::tree::DecisionTreeClassifier;
531
532 fn make_iris_like_data() -> Dataset {
533 let mut f1 = Vec::new();
535 let mut f2 = Vec::new();
536 let mut target = Vec::new();
537 let mut rng = crate::rng::FastRng::new(42);
538
539 for _ in 0..40 {
541 f1.push(1.0 + rng.f64() * 0.5);
542 f2.push(1.0 + rng.f64() * 0.5);
543 target.push(0.0);
544 }
545 for _ in 0..40 {
547 f1.push(5.0 + rng.f64() * 0.5);
548 f2.push(5.0 + rng.f64() * 0.5);
549 target.push(1.0);
550 }
551 for _ in 0..40 {
553 f1.push(1.0 + rng.f64() * 0.5);
554 f2.push(5.0 + rng.f64() * 0.5);
555 target.push(2.0);
556 }
557
558 Dataset::new(
559 vec![f1, f2],
560 target,
561 vec!["f1".into(), "f2".into()],
562 "class",
563 )
564 }
565
566 #[test]
567 fn test_voting_hard_basic() {
568 let data = make_iris_like_data();
569
570 let mut vc = VotingClassifier::new(vec![
571 Box::new(DecisionTreeClassifier::new().max_depth(3)),
572 Box::new(DecisionTreeClassifier::new().max_depth(5)),
573 Box::new(DecisionTreeClassifier::new().max_depth(7)),
574 ])
575 .voting(Voting::Hard);
576
577 vc.fit(&data).unwrap();
578 let features = data.feature_matrix();
579 let preds = vc.predict(&features).unwrap();
580
581 let acc = preds
582 .iter()
583 .zip(data.target.iter())
584 .filter(|(p, t)| (*p - *t).abs() < 1e-6)
585 .count() as f64
586 / data.n_samples() as f64;
587
588 assert!(
589 acc >= 0.85,
590 "VotingClassifier hard vote accuracy should be ≥ 85%, got {:.1}%",
591 acc * 100.0,
592 );
593 }
594
595 #[test]
596 fn test_voting_soft_basic() {
597 let data = make_iris_like_data();
598
599 let mut vc = VotingClassifier::new(vec![
601 Box::new(DecisionTreeClassifier::new().max_depth(3)),
602 Box::new(DecisionTreeClassifier::new().max_depth(5)),
603 Box::new(DecisionTreeClassifier::new().max_depth(7)),
604 ])
605 .voting(Voting::Soft);
606
607 vc.fit(&data).unwrap();
608 let features = data.feature_matrix();
609 let preds = vc.predict(&features).unwrap();
610
611 let acc = preds
612 .iter()
613 .zip(data.target.iter())
614 .filter(|(p, t)| (*p - *t).abs() < 1e-6)
615 .count() as f64
616 / data.n_samples() as f64;
617
618 assert!(
619 acc >= 0.85,
620 "VotingClassifier soft vote accuracy should be ≥ 85%, got {:.1}%",
621 acc * 100.0,
622 );
623 }
624
625 #[test]
626 fn test_voting_weighted() {
627 let data = make_iris_like_data();
628
629 let mut vc = VotingClassifier::new(vec![
630 Box::new(DecisionTreeClassifier::new().max_depth(3)),
631 Box::new(DecisionTreeClassifier::new().max_depth(5)),
632 ])
633 .voting(Voting::Hard)
634 .weights(vec![1.0, 2.0]);
635
636 vc.fit(&data).unwrap();
637 let features = data.feature_matrix();
638 let preds = vc.predict(&features).unwrap();
639 assert_eq!(preds.len(), data.n_samples());
640 }
641
642 #[test]
643 fn test_voting_not_fitted() {
644 let vc = VotingClassifier::new(vec![Box::new(DecisionTreeClassifier::new())]);
645 let result = vc.predict(&[vec![1.0, 2.0]]);
646 assert!(result.is_err());
647 }
648
649 #[test]
650 fn test_voting_empty_estimators() {
651 let data = make_iris_like_data();
652 let mut vc = VotingClassifier::new(vec![]);
653 assert!(vc.fit(&data).is_err());
654 }
655
656 #[test]
657 fn test_voting_weights_mismatch() {
658 let data = make_iris_like_data();
659 let mut vc = VotingClassifier::new(vec![Box::new(DecisionTreeClassifier::new())])
660 .weights(vec![1.0, 2.0]); assert!(vc.fit(&data).is_err());
662 }
663
664 #[test]
665 fn test_stacking_basic() {
666 let data = make_iris_like_data();
667
668 let mut sc = StackingClassifier::new(
669 vec![
670 Box::new(DecisionTreeClassifier::new().max_depth(3)),
671 Box::new(DecisionTreeClassifier::new().max_depth(7)),
672 ],
673 Box::new(DecisionTreeClassifier::new().max_depth(5)),
674 )
675 .cv(3)
676 .seed(42);
677
678 sc.fit(&data).unwrap();
679 let features = data.feature_matrix();
680 let preds = sc.predict(&features).unwrap();
681
682 assert_eq!(preds.len(), data.n_samples());
683
684 let acc = preds
685 .iter()
686 .zip(data.target.iter())
687 .filter(|(p, t)| (*p - *t).abs() < 1e-6)
688 .count() as f64
689 / data.n_samples() as f64;
690
691 assert!(
692 acc >= 0.70,
693 "StackingClassifier accuracy should be ≥ 70%, got {:.1}%",
694 acc * 100.0,
695 );
696 }
697
698 #[test]
699 fn test_stacking_not_fitted() {
700 let sc = StackingClassifier::new(
701 vec![Box::new(DecisionTreeClassifier::new())],
702 Box::new(DecisionTreeClassifier::new()),
703 );
704 let result = sc.predict(&[vec![1.0, 2.0]]);
705 assert!(result.is_err());
706 }
707
708 #[test]
709 fn test_stacking_empty_estimators() {
710 let data = make_iris_like_data();
711 let mut sc = StackingClassifier::new(vec![], Box::new(DecisionTreeClassifier::new()));
712 assert!(sc.fit(&data).is_err());
713 }
714
715 #[test]
716 fn test_stacking_cv_too_small() {
717 let data = make_iris_like_data();
718 let mut sc = StackingClassifier::new(
719 vec![Box::new(DecisionTreeClassifier::new())],
720 Box::new(DecisionTreeClassifier::new()),
721 )
722 .cv(1);
723 assert!(sc.fit(&data).is_err());
724 }
725
726 #[test]
727 fn test_generate_fold_indices() {
728 let folds = generate_fold_indices(10, 3, 42);
729 assert_eq!(folds.len(), 3);
730 let total: usize = folds.iter().map(std::vec::Vec::len).sum();
731 assert_eq!(total, 10);
732 let mut all: Vec<usize> = folds.into_iter().flatten().collect();
734 all.sort_unstable();
735 assert_eq!(all, (0..10).collect::<Vec<_>>());
736 }
737
738 #[test]
739 fn test_voting_accuracy_ge_individual() {
740 let data = make_iris_like_data();
741 let features = data.feature_matrix();
742
743 let mut dt1 = DecisionTreeClassifier::new().max_depth(2);
745 dt1.fit(&data).unwrap();
746 let preds1 = dt1.predict(&features).unwrap();
747 let acc1 = preds1
748 .iter()
749 .zip(data.target.iter())
750 .filter(|(p, t)| (*p - *t).abs() < 1e-6)
751 .count() as f64
752 / data.n_samples() as f64;
753
754 let mut vc = VotingClassifier::new(vec![
756 Box::new(DecisionTreeClassifier::new().max_depth(2)),
757 Box::new(DecisionTreeClassifier::new().max_depth(4)),
758 Box::new(DecisionTreeClassifier::new().max_depth(6)),
759 ])
760 .voting(Voting::Hard);
761
762 vc.fit(&data).unwrap();
763 let preds_vc = vc.predict(&features).unwrap();
764 let acc_vc = preds_vc
765 .iter()
766 .zip(data.target.iter())
767 .filter(|(p, t)| (*p - *t).abs() < 1e-6)
768 .count() as f64
769 / data.n_samples() as f64;
770
771 assert!(
773 acc_vc >= acc1 - 0.05,
774 "VotingClassifier ({:.1}%) should be ≥ individual DT ({:.1}%) - 5%",
775 acc_vc * 100.0,
776 acc1 * 100.0,
777 );
778 }
779}