1use ferrolearn_core::error::FerroError;
31use ferrolearn_core::introspection::{HasClasses, HasFeatureImportances};
32use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
33use ferrolearn_core::traits::{Fit, Predict};
34use ndarray::{Array1, Array2};
35use num_traits::{Float, FromPrimitive, ToPrimitive};
36use rand::SeedableRng;
37use rand::rngs::StdRng;
38use rayon::prelude::*;
39use serde::{Deserialize, Serialize};
40
41use crate::decision_tree::{
42 ClassificationCriterion, Node, TreeParams, compute_feature_importances, traverse,
43};
44use crate::extra_tree::{
45 build_extra_classification_tree_for_ensemble, build_extra_regression_tree_for_ensemble,
46};
47use crate::random_forest::MaxFeatures;
48
49fn resolve_max_features(strategy: MaxFeatures, n_features: usize) -> usize {
51 let result = match strategy {
52 MaxFeatures::Sqrt => (n_features as f64).sqrt().ceil() as usize,
53 MaxFeatures::Log2 => (n_features as f64).log2().ceil().max(1.0) as usize,
54 MaxFeatures::All => n_features,
55 MaxFeatures::Fixed(n) => n.min(n_features),
56 MaxFeatures::Fraction(f) => ((n_features as f64) * f).ceil() as usize,
57 };
58 result.max(1).min(n_features)
59}
60
61fn make_tree_params(
63 max_depth: Option<usize>,
64 min_samples_split: usize,
65 min_samples_leaf: usize,
66) -> TreeParams {
67 TreeParams {
68 max_depth,
69 min_samples_split,
70 min_samples_leaf,
71 }
72}
73
/// Extra-Trees (extremely randomized trees) classifier configuration.
///
/// Holds hyperparameters only; training happens through [`Fit::fit`], which
/// returns a [`FittedExtraTreesClassifier`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtraTreesClassifier<F> {
    /// Number of trees in the ensemble.
    pub n_estimators: usize,
    /// Maximum tree depth; `None` lets other stopping criteria limit growth.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required at a leaf node.
    pub min_samples_leaf: usize,
    /// Strategy for how many candidate features each split considers.
    pub max_features: MaxFeatures,
    /// When true, each tree trains on a with-replacement resample of the rows.
    pub bootstrap: bool,
    /// Split-quality criterion (e.g. Gini or entropy).
    pub criterion: ClassificationCriterion,
    /// Seed for reproducible results; `None` draws per-tree seeds from `rand::rng()`.
    pub random_state: Option<u64>,
    /// Thread count for parallel tree building; `None` uses rayon's default pool.
    pub n_jobs: Option<usize>,
    // Ties the float type `F` to the config without storing any data.
    _marker: std::marker::PhantomData<F>,
}
113
impl<F: Float> ExtraTreesClassifier<F> {
    /// Create a classifier with the default configuration: 100 trees,
    /// unlimited depth, sqrt feature sampling, Gini criterion, no bootstrap.
    #[must_use]
    pub fn new() -> Self {
        Self {
            n_estimators: 100,
            max_depth: None,
            min_samples_split: 2,
            min_samples_leaf: 1,
            max_features: MaxFeatures::Sqrt,
            bootstrap: false,
            criterion: ClassificationCriterion::Gini,
            random_state: None,
            n_jobs: None,
            _marker: std::marker::PhantomData,
        }
    }

    /// Set the number of trees in the ensemble.
    #[must_use]
    pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
        self.n_estimators = n_estimators;
        self
    }

    /// Set the maximum tree depth (`None` for unlimited).
    #[must_use]
    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
        self.max_depth = max_depth;
        self
    }

    /// Set the minimum number of samples required to split a node.
    #[must_use]
    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.min_samples_split = min_samples_split;
        self
    }

    /// Set the minimum number of samples required at a leaf.
    #[must_use]
    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.min_samples_leaf = min_samples_leaf;
        self
    }

    /// Set the per-split feature-sampling strategy.
    #[must_use]
    pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
        self.max_features = max_features;
        self
    }

    /// Enable or disable bootstrap resampling of rows per tree.
    #[must_use]
    pub fn with_bootstrap(mut self, bootstrap: bool) -> Self {
        self.bootstrap = bootstrap;
        self
    }

    /// Set the split-quality criterion.
    #[must_use]
    pub fn with_criterion(mut self, criterion: ClassificationCriterion) -> Self {
        self.criterion = criterion;
        self
    }

    /// Fix the random seed for reproducible fits.
    #[must_use]
    pub fn with_random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Set the number of threads used for parallel tree construction.
    #[must_use]
    pub fn with_n_jobs(mut self, n_jobs: usize) -> Self {
        self.n_jobs = Some(n_jobs);
        self
    }
}
200
201impl<F: Float> Default for ExtraTreesClassifier<F> {
202 fn default() -> Self {
203 Self::new()
204 }
205}
206
/// A trained extra-trees classifier produced by [`ExtraTreesClassifier`]'s `fit`.
#[derive(Debug, Clone)]
pub struct FittedExtraTreesClassifier<F> {
    /// One flattened node array per fitted tree.
    trees: Vec<Vec<Node<F>>>,
    /// Sorted, deduplicated class labels observed in the training targets.
    classes: Vec<usize>,
    /// Number of feature columns seen during fitting.
    n_features: usize,
    /// Impurity-based importances, normalized to sum to 1 when nonzero.
    feature_importances: Array1<F>,
}
226
impl<F: Float + Send + Sync + 'static> FittedExtraTreesClassifier<F> {
    /// Flattened node storage for every fitted tree.
    #[must_use]
    pub fn trees(&self) -> &[Vec<Node<F>>] {
        &self.trees
    }

    /// Number of feature columns the model was fitted on.
    #[must_use]
    pub fn n_features(&self) -> usize {
        self.n_features
    }

    /// Number of trees actually built during fitting.
    #[must_use]
    pub fn n_estimators(&self) -> usize {
        self.trees.len()
    }

    /// Class-membership probabilities, averaged over all trees.
    ///
    /// Each tree contributes the class distribution stored at the leaf the
    /// sample reaches; the per-class sums are divided by the tree count.
    /// NOTE(review): leaves without a `class_distribution` contribute
    /// nothing — presumably the ensemble tree builder always stores one;
    /// confirm, otherwise rows may sum to less than 1.
    ///
    /// # Errors
    ///
    /// Returns `FerroError::ShapeMismatch` when `x` has a different column
    /// count than the fitted model.
    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }
        let n_samples = x.nrows();
        let n_classes = self.classes.len();
        let n_trees_f = F::from(self.trees.len()).unwrap();
        let mut proba = Array2::zeros((n_samples, n_classes));

        for i in 0..n_samples {
            let row = x.row(i);
            for tree_nodes in &self.trees {
                let leaf_idx = traverse(tree_nodes, &row);
                if let Node::Leaf {
                    class_distribution: Some(ref dist),
                    ..
                } = tree_nodes[leaf_idx]
                {
                    for (j, &p) in dist.iter().enumerate() {
                        proba[[i, j]] = proba[[i, j]] + p;
                    }
                }
            }
            // Normalize the accumulated distributions by the tree count.
            for j in 0..n_classes {
                proba[[i, j]] = proba[[i, j]] / n_trees_f;
            }
        }
        Ok(proba)
    }

    /// Mean accuracy of `predict(x)` against the true labels `y`.
    ///
    /// # Errors
    ///
    /// Returns `FerroError::ShapeMismatch` when `y` length differs from the
    /// row count of `x`, or propagates any prediction error.
    pub fn score(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<F, FerroError> {
        if x.nrows() != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![x.nrows()],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        let preds = self.predict(x)?;
        Ok(crate::mean_accuracy(&preds, y))
    }

    /// Natural logarithm of [`Self::predict_proba`] output.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::predict_proba`].
    pub fn predict_log_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let proba = self.predict_proba(x)?;
        Ok(crate::log_proba(&proba))
    }
}
319
320impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for ExtraTreesClassifier<F> {
321 type Fitted = FittedExtraTreesClassifier<F>;
322 type Error = FerroError;
323
324 fn fit(
337 &self,
338 x: &Array2<F>,
339 y: &Array1<usize>,
340 ) -> Result<FittedExtraTreesClassifier<F>, FerroError> {
341 let (n_samples, n_features) = x.dim();
342
343 if n_samples != y.len() {
344 return Err(FerroError::ShapeMismatch {
345 expected: vec![n_samples],
346 actual: vec![y.len()],
347 context: "y length must match number of samples in X".into(),
348 });
349 }
350 if n_samples == 0 {
351 return Err(FerroError::InsufficientSamples {
352 required: 1,
353 actual: 0,
354 context: "ExtraTreesClassifier requires at least one sample".into(),
355 });
356 }
357 if self.n_estimators == 0 {
358 return Err(FerroError::InvalidParameter {
359 name: "n_estimators".into(),
360 reason: "must be at least 1".into(),
361 });
362 }
363
364 let mut classes: Vec<usize> = y.iter().copied().collect();
366 classes.sort_unstable();
367 classes.dedup();
368 let n_classes = classes.len();
369
370 let y_mapped: Vec<usize> = y
371 .iter()
372 .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
373 .collect();
374
375 let max_features_n = resolve_max_features(self.max_features, n_features);
376 let params = make_tree_params(
377 self.max_depth,
378 self.min_samples_split,
379 self.min_samples_leaf,
380 );
381 let criterion = self.criterion;
382 let bootstrap = self.bootstrap;
383
384 let tree_seeds: Vec<u64> = if let Some(seed) = self.random_state {
386 let mut master_rng = StdRng::seed_from_u64(seed);
387 (0..self.n_estimators)
388 .map(|_| {
389 use rand::RngCore;
390 master_rng.next_u64()
391 })
392 .collect()
393 } else {
394 (0..self.n_estimators)
395 .map(|_| {
396 use rand::RngCore;
397 rand::rng().next_u64()
398 })
399 .collect()
400 };
401
402 let trees: Vec<Vec<Node<F>>> = if let Some(n_jobs) = self.n_jobs {
404 let pool = rayon::ThreadPoolBuilder::new()
405 .num_threads(n_jobs)
406 .build()
407 .unwrap_or_else(|_| rayon::ThreadPoolBuilder::new().build().unwrap());
408 pool.install(|| {
409 tree_seeds
410 .par_iter()
411 .map(|&seed| {
412 build_single_classification_tree(
413 x,
414 &y_mapped,
415 n_classes,
416 n_samples,
417 n_features,
418 max_features_n,
419 ¶ms,
420 criterion,
421 bootstrap,
422 seed,
423 )
424 })
425 .collect()
426 })
427 } else {
428 tree_seeds
429 .par_iter()
430 .map(|&seed| {
431 build_single_classification_tree(
432 x,
433 &y_mapped,
434 n_classes,
435 n_samples,
436 n_features,
437 max_features_n,
438 ¶ms,
439 criterion,
440 bootstrap,
441 seed,
442 )
443 })
444 .collect()
445 };
446
447 let mut total_importances = Array1::<F>::zeros(n_features);
449 for tree_nodes in &trees {
450 let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
451 total_importances = total_importances + tree_imp;
452 }
453 let imp_sum: F = total_importances
454 .iter()
455 .copied()
456 .fold(F::zero(), |a, b| a + b);
457 if imp_sum > F::zero() {
458 total_importances.mapv_inplace(|v| v / imp_sum);
459 }
460
461 Ok(FittedExtraTreesClassifier {
462 trees,
463 classes,
464 n_features,
465 feature_importances: total_importances,
466 })
467 }
468}
469
/// Build one extremely randomized classification tree for the ensemble.
///
/// The per-tree RNG is seeded from `seed` and consumed first for the
/// optional bootstrap indices, then by the tree builder itself; the draw
/// order must not change or seeded results stop being reproducible.
#[allow(clippy::too_many_arguments)]
fn build_single_classification_tree<F: Float>(
    x: &Array2<F>,
    y_mapped: &[usize],
    n_classes: usize,
    n_samples: usize,
    n_features: usize,
    max_features_n: usize,
    params: &TreeParams,
    criterion: ClassificationCriterion,
    bootstrap: bool,
    seed: u64,
) -> Vec<Node<F>> {
    let mut rng = StdRng::seed_from_u64(seed);

    // With bootstrap, draw row indices with replacement; otherwise every row
    // is used exactly once. The `u64 % n` mapping carries negligible modulo
    // bias for realistic sample counts.
    let indices: Vec<usize> = if bootstrap {
        use rand::RngCore;
        (0..n_samples)
            .map(|_| (rng.next_u64() as usize) % n_samples)
            .collect()
    } else {
        (0..n_samples).collect()
    };

    build_extra_classification_tree_for_ensemble(
        x,
        y_mapped,
        n_classes,
        &indices,
        None, // presumably optional sample weights — confirm against the builder's signature
        params,
        criterion,
        n_features,
        max_features_n,
        &mut rng,
    )
}
508
509impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedExtraTreesClassifier<F> {
510 type Output = Array1<usize>;
511 type Error = FerroError;
512
513 fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
520 if x.ncols() != self.n_features {
521 return Err(FerroError::ShapeMismatch {
522 expected: vec![self.n_features],
523 actual: vec![x.ncols()],
524 context: "number of features must match fitted model".into(),
525 });
526 }
527
528 let n_samples = x.nrows();
529 let n_classes = self.classes.len();
530 let mut predictions = Array1::zeros(n_samples);
531
532 for i in 0..n_samples {
533 let row = x.row(i);
534 let mut votes = vec![0usize; n_classes];
535
536 for tree_nodes in &self.trees {
537 let leaf_idx = traverse(tree_nodes, &row);
538 if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
539 let class_idx = value.to_f64().map_or(0, |f| f.round() as usize);
540 if class_idx < n_classes {
541 votes[class_idx] += 1;
542 }
543 }
544 }
545
546 let winner = votes
547 .iter()
548 .enumerate()
549 .max_by_key(|&(_, &count)| count)
550 .map_or(0, |(idx, _)| idx);
551 predictions[i] = self.classes[winner];
552 }
553
554 Ok(predictions)
555 }
556}
557
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedExtraTreesClassifier<F> {
    /// Normalized impurity-based feature importances computed at fit time.
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}
563
impl<F: Float + Send + Sync + 'static> HasClasses for FittedExtraTreesClassifier<F> {
    /// Sorted, deduplicated class labels seen during fitting.
    fn classes(&self) -> &[usize] {
        &self.classes
    }

    /// Number of distinct classes.
    fn n_classes(&self) -> usize {
        self.classes.len()
    }
}
573
impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
    for ExtraTreesClassifier<F>
{
    /// Pipeline entry point: converts float targets to `usize` labels and fits.
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        // NOTE(review): targets that cannot convert to usize (negative or
        // non-finite) silently become class 0 — confirm this lenient
        // best-effort behavior is intended rather than an error.
        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
        let fitted = self.fit(x, &y_usize)?;
        Ok(Box::new(FittedExtraTreesClassifierPipelineAdapter(fitted)))
    }
}
588
/// Newtype adapter exposing a fitted classifier through the float-based
/// pipeline prediction interface.
struct FittedExtraTreesClassifierPipelineAdapter<F: Float + Send + Sync + 'static>(
    FittedExtraTreesClassifier<F>,
);
593
impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedExtraTreesClassifierPipelineAdapter<F>
{
    /// Predict labels and convert them back to `F`; labels that cannot be
    /// represented in `F` become NaN.
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        let preds = self.0.predict(x)?;
        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or_else(F::nan)))
    }
}
602
/// Extra-Trees (extremely randomized trees) regressor configuration.
///
/// Holds hyperparameters only; training happens through [`Fit::fit`], which
/// returns a [`FittedExtraTreesRegressor`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtraTreesRegressor<F> {
    /// Number of trees in the ensemble.
    pub n_estimators: usize,
    /// Maximum tree depth; `None` lets other stopping criteria limit growth.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required at a leaf node.
    pub min_samples_leaf: usize,
    /// Strategy for how many candidate features each split considers.
    pub max_features: MaxFeatures,
    /// When true, each tree trains on a with-replacement resample of the rows.
    pub bootstrap: bool,
    /// Seed for reproducible results; `None` draws per-tree seeds from `rand::rng()`.
    pub random_state: Option<u64>,
    /// Thread count for parallel tree building; `None` uses rayon's default pool.
    pub n_jobs: Option<usize>,
    // Ties the float type `F` to the config without storing any data.
    _marker: std::marker::PhantomData<F>,
}
639
impl<F: Float> ExtraTreesRegressor<F> {
    /// Create a regressor with the default configuration: 100 trees,
    /// unlimited depth, all features per split, no bootstrap.
    #[must_use]
    pub fn new() -> Self {
        Self {
            n_estimators: 100,
            max_depth: None,
            min_samples_split: 2,
            min_samples_leaf: 1,
            max_features: MaxFeatures::All,
            bootstrap: false,
            random_state: None,
            n_jobs: None,
            _marker: std::marker::PhantomData,
        }
    }

    /// Set the number of trees in the ensemble.
    #[must_use]
    pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
        self.n_estimators = n_estimators;
        self
    }

    /// Set the maximum tree depth (`None` for unlimited).
    #[must_use]
    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
        self.max_depth = max_depth;
        self
    }

    /// Set the minimum number of samples required to split a node.
    #[must_use]
    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.min_samples_split = min_samples_split;
        self
    }

    /// Set the minimum number of samples required at a leaf.
    #[must_use]
    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.min_samples_leaf = min_samples_leaf;
        self
    }

    /// Set the per-split feature-sampling strategy.
    #[must_use]
    pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
        self.max_features = max_features;
        self
    }

    /// Enable or disable bootstrap resampling of rows per tree.
    #[must_use]
    pub fn with_bootstrap(mut self, bootstrap: bool) -> Self {
        self.bootstrap = bootstrap;
        self
    }

    /// Fix the random seed for reproducible fits.
    #[must_use]
    pub fn with_random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Set the number of threads used for parallel tree construction.
    #[must_use]
    pub fn with_n_jobs(mut self, n_jobs: usize) -> Self {
        self.n_jobs = Some(n_jobs);
        self
    }
}
718
719impl<F: Float> Default for ExtraTreesRegressor<F> {
720 fn default() -> Self {
721 Self::new()
722 }
723}
724
/// A trained extra-trees regressor produced by [`ExtraTreesRegressor`]'s `fit`.
#[derive(Debug, Clone)]
pub struct FittedExtraTreesRegressor<F> {
    /// One flattened node array per fitted tree.
    trees: Vec<Vec<Node<F>>>,
    /// Number of feature columns seen during fitting.
    n_features: usize,
    /// Impurity-based importances, normalized to sum to 1 when nonzero.
    feature_importances: Array1<F>,
}
742
743impl<F: Float + Send + Sync + 'static> FittedExtraTreesRegressor<F> {
744 #[must_use]
746 pub fn trees(&self) -> &[Vec<Node<F>>] {
747 &self.trees
748 }
749
750 #[must_use]
752 pub fn n_features(&self) -> usize {
753 self.n_features
754 }
755
756 #[must_use]
758 pub fn n_estimators(&self) -> usize {
759 self.trees.len()
760 }
761
762 pub fn score(&self, x: &Array2<F>, y: &Array1<F>) -> Result<F, FerroError> {
770 if x.nrows() != y.len() {
771 return Err(FerroError::ShapeMismatch {
772 expected: vec![x.nrows()],
773 actual: vec![y.len()],
774 context: "y length must match number of samples in X".into(),
775 });
776 }
777 let preds = self.predict(x)?;
778 Ok(crate::r2_score(&preds, y))
779 }
780}
781
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for ExtraTreesRegressor<F> {
    type Fitted = FittedExtraTreesRegressor<F>;
    type Error = FerroError;

    /// Fit an extremely-randomized-trees regressor.
    ///
    /// Builds `n_estimators` trees in parallel via rayon, each seeded from a
    /// master RNG so results are reproducible for a fixed `random_state`.
    ///
    /// # Errors
    ///
    /// Returns `ShapeMismatch` when `x` and `y` disagree on sample count,
    /// `InsufficientSamples` when there are no rows, and `InvalidParameter`
    /// when `n_estimators` is zero.
    fn fit(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<FittedExtraTreesRegressor<F>, FerroError> {
        let (n_samples, n_features) = x.dim();

        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        if n_samples == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "ExtraTreesRegressor requires at least one sample".into(),
            });
        }
        if self.n_estimators == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_estimators".into(),
                reason: "must be at least 1".into(),
            });
        }

        let max_features_n = resolve_max_features(self.max_features, n_features);
        let params = make_tree_params(
            self.max_depth,
            self.min_samples_split,
            self.min_samples_leaf,
        );
        let bootstrap = self.bootstrap;

        // One seed per tree, drawn eagerly so that parallel tree building is
        // deterministic under a fixed `random_state`.
        let tree_seeds: Vec<u64> = if let Some(seed) = self.random_state {
            let mut master_rng = StdRng::seed_from_u64(seed);
            (0..self.n_estimators)
                .map(|_| {
                    use rand::RngCore;
                    master_rng.next_u64()
                })
                .collect()
        } else {
            (0..self.n_estimators)
                .map(|_| {
                    use rand::RngCore;
                    rand::rng().next_u64()
                })
                .collect()
        };

        // Honor an explicit `n_jobs` by installing a dedicated rayon pool,
        // falling back to a default-sized pool if the requested one cannot
        // be built; otherwise use the global pool.
        let trees: Vec<Vec<Node<F>>> = if let Some(n_jobs) = self.n_jobs {
            let pool = rayon::ThreadPoolBuilder::new()
                .num_threads(n_jobs)
                .build()
                .unwrap_or_else(|_| rayon::ThreadPoolBuilder::new().build().unwrap());
            pool.install(|| {
                tree_seeds
                    .par_iter()
                    .map(|&seed| {
                        build_single_regression_tree(
                            x,
                            y,
                            n_samples,
                            n_features,
                            max_features_n,
                            &params,
                            bootstrap,
                            seed,
                        )
                    })
                    .collect()
            })
        } else {
            tree_seeds
                .par_iter()
                .map(|&seed| {
                    build_single_regression_tree(
                        x,
                        y,
                        n_samples,
                        n_features,
                        max_features_n,
                        &params,
                        bootstrap,
                        seed,
                    )
                })
                .collect()
        };

        // Sum per-tree impurity-based importances, then renormalize to sum
        // to 1; an all-zero vector is left untouched to avoid dividing by 0.
        let mut total_importances = Array1::<F>::zeros(n_features);
        for tree_nodes in &trees {
            let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
            total_importances = total_importances + tree_imp;
        }
        let imp_sum: F = total_importances
            .iter()
            .copied()
            .fold(F::zero(), |a, b| a + b);
        if imp_sum > F::zero() {
            total_importances.mapv_inplace(|v| v / imp_sum);
        }

        Ok(FittedExtraTreesRegressor {
            trees,
            n_features,
            feature_importances: total_importances,
        })
    }
}
910
/// Build one extremely randomized regression tree for the ensemble.
///
/// The per-tree RNG is seeded from `seed` and consumed first for the
/// optional bootstrap indices, then by the tree builder itself; the draw
/// order must not change or seeded results stop being reproducible.
#[allow(clippy::too_many_arguments)]
fn build_single_regression_tree<F: Float>(
    x: &Array2<F>,
    y: &Array1<F>,
    n_samples: usize,
    n_features: usize,
    max_features_n: usize,
    params: &TreeParams,
    bootstrap: bool,
    seed: u64,
) -> Vec<Node<F>> {
    let mut rng = StdRng::seed_from_u64(seed);

    // With bootstrap, draw row indices with replacement; otherwise every row
    // is used exactly once. The `u64 % n` mapping carries negligible modulo
    // bias for realistic sample counts.
    let indices: Vec<usize> = if bootstrap {
        use rand::RngCore;
        (0..n_samples)
            .map(|_| (rng.next_u64() as usize) % n_samples)
            .collect()
    } else {
        (0..n_samples).collect()
    };

    build_extra_regression_tree_for_ensemble(
        x,
        y,
        &indices,
        None, // presumably optional sample weights — confirm against the builder's signature
        params,
        n_features,
        max_features_n,
        &mut rng,
    )
}
945
946impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedExtraTreesRegressor<F> {
947 type Output = Array1<F>;
948 type Error = FerroError;
949
950 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
957 if x.ncols() != self.n_features {
958 return Err(FerroError::ShapeMismatch {
959 expected: vec![self.n_features],
960 actual: vec![x.ncols()],
961 context: "number of features must match fitted model".into(),
962 });
963 }
964
965 let n_samples = x.nrows();
966 let n_trees_f = F::from(self.trees.len()).unwrap();
967 let mut predictions = Array1::zeros(n_samples);
968
969 for i in 0..n_samples {
970 let row = x.row(i);
971 let mut sum = F::zero();
972
973 for tree_nodes in &self.trees {
974 let leaf_idx = traverse(tree_nodes, &row);
975 if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
976 sum = sum + value;
977 }
978 }
979
980 predictions[i] = sum / n_trees_f;
981 }
982
983 Ok(predictions)
984 }
985}
986
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedExtraTreesRegressor<F> {
    /// Normalized impurity-based feature importances computed at fit time.
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}
992
impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for ExtraTreesRegressor<F> {
    /// Pipeline entry point: fits directly on float targets, no conversion
    /// needed (unlike the classifier's adapter).
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let fitted = self.fit(x, y)?;
        Ok(Box::new(fitted))
    }
}
1004
impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F> for FittedExtraTreesRegressor<F> {
    /// Delegates straight to [`Predict::predict`]; outputs are already `F`.
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        self.predict(x)
    }
}
1010
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // ---------------- classifier ----------------

    #[test]
    fn test_ensemble_classifier_simple() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        // Cleanly separable data: the ensemble should recover y exactly.
        assert_eq!(preds, y);
    }

    #[test]
    fn test_ensemble_classifier_no_bootstrap() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        // Bootstrap is off by default for extra trees.
        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        assert!(!model.bootstrap);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds, y);
    }

    #[test]
    fn test_ensemble_classifier_with_bootstrap() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_bootstrap(true)
            .with_random_state(42);
        assert!(model.bootstrap);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        // Resampling makes exact recovery non-deterministic; check shape only.
        assert_eq!(preds.len(), 8);
    }

    #[test]
    fn test_ensemble_classifier_predict_proba() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();

        // Each row is a probability distribution over the two classes.
        assert_eq!(proba.dim(), (6, 2));
        for i in 0..6 {
            let row_sum = proba.row(i).sum();
            assert_relative_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_ensemble_classifier_feature_importances() {
        // Feature 0 carries all the signal; feature 1 is constant.
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0, 1.0, 5.0, 1.0, 6.0, 1.0, 7.0, 1.0, 8.0, 1.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        let total: f64 = importances.sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-10);
        assert!(importances[0] > importances[1]);
    }

    #[test]
    fn test_ensemble_classifier_n_estimators() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(15)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 15);
    }

    #[test]
    fn test_ensemble_classifier_classes() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        // Non-contiguous labels: classes() must report the sorted originals.
        let y = array![0, 0, 0, 3, 3, 3];
        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0, 3]);
        assert_eq!(fitted.n_classes(), 2);
    }

    #[test]
    fn test_ensemble_classifier_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0];
        let model = ExtraTreesClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_ensemble_classifier_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);
        let model = ExtraTreesClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_ensemble_classifier_zero_estimators() {
        let x = Array2::from_shape_vec((2, 1), vec![1.0, 2.0]).unwrap();
        let y = array![0, 1];
        let model = ExtraTreesClassifier::<f64>::new().with_n_estimators(0);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_ensemble_classifier_deterministic() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        // Two fits with the same seed must produce identical predictions.
        let model1 = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(123);
        let model2 = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(123);

        let preds1 = model1.fit(&x, &y).unwrap().predict(&x).unwrap();
        let preds2 = model2.fit(&x, &y).unwrap().predict(&x).unwrap();
        assert_eq!(preds1, preds2);
    }

    #[test]
    fn test_ensemble_classifier_predict_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        // Predicting with 3 columns against a 2-feature model must fail.
        let x_wrong = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_wrong).is_err());
    }

    // ---------------- regressor ----------------

    #[test]
    fn test_ensemble_regressor_simple() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(20)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 6);
        // Randomized split points interpolate; allow a loose tolerance.
        for i in 0..6 {
            assert_relative_eq!(preds[i], y[i], epsilon = 1.0);
        }
    }

    #[test]
    fn test_ensemble_regressor_constant_target() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![5.0, 5.0, 5.0, 5.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        // A constant target must be predicted exactly.
        for &p in &preds {
            assert_relative_eq!(p, 5.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_ensemble_regressor_no_bootstrap() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        assert!(!model.bootstrap);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_ensemble_regressor_with_bootstrap() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_bootstrap(true)
            .with_random_state(42);
        assert!(model.bootstrap);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_ensemble_regressor_feature_importances() {
        // Feature 0 carries all the signal; feature 1 is constant.
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(20)
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        let total: f64 = importances.sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-10);
        assert!(importances[0] > importances[1]);
    }

    #[test]
    fn test_ensemble_regressor_n_estimators() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(7)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 7);
    }

    #[test]
    fn test_ensemble_regressor_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0];
        let model = ExtraTreesRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_ensemble_regressor_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);
        let model = ExtraTreesRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_ensemble_regressor_zero_estimators() {
        let x = Array2::from_shape_vec((2, 1), vec![1.0, 2.0]).unwrap();
        let y = array![1.0, 2.0];
        let model = ExtraTreesRegressor::<f64>::new().with_n_estimators(0);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_ensemble_regressor_deterministic() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        // Two fits with the same seed must produce identical predictions.
        let model1 = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(99);
        let model2 = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(99);

        let preds1 = model1.fit(&x, &y).unwrap().predict(&x).unwrap();
        let preds2 = model2.fit(&x, &y).unwrap().predict(&x).unwrap();

        for i in 0..6 {
            assert_relative_eq!(preds1[i], preds2[i], epsilon = 1e-12);
        }
    }

    #[test]
    fn test_ensemble_regressor_predict_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        let x_wrong = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_wrong).is_err());
    }

    // ---------------- builders and defaults ----------------

    #[test]
    fn test_ensemble_classifier_builder() {
        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(50)
            .with_max_depth(Some(5))
            .with_min_samples_split(10)
            .with_min_samples_leaf(3)
            .with_max_features(MaxFeatures::Log2)
            .with_bootstrap(true)
            .with_criterion(ClassificationCriterion::Entropy)
            .with_random_state(42)
            .with_n_jobs(4);

        assert_eq!(model.n_estimators, 50);
        assert_eq!(model.max_depth, Some(5));
        assert_eq!(model.min_samples_split, 10);
        assert_eq!(model.min_samples_leaf, 3);
        assert_eq!(model.max_features, MaxFeatures::Log2);
        assert!(model.bootstrap);
        assert_eq!(model.criterion, ClassificationCriterion::Entropy);
        assert_eq!(model.random_state, Some(42));
        assert_eq!(model.n_jobs, Some(4));
    }

    #[test]
    fn test_ensemble_regressor_builder() {
        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(25)
            .with_max_depth(Some(8))
            .with_min_samples_split(5)
            .with_min_samples_leaf(2)
            .with_max_features(MaxFeatures::Fraction(0.5))
            .with_bootstrap(true)
            .with_random_state(99)
            .with_n_jobs(2);

        assert_eq!(model.n_estimators, 25);
        assert_eq!(model.max_depth, Some(8));
        assert_eq!(model.min_samples_split, 5);
        assert_eq!(model.min_samples_leaf, 2);
        assert_eq!(model.max_features, MaxFeatures::Fraction(0.5));
        assert!(model.bootstrap);
        assert_eq!(model.random_state, Some(99));
        assert_eq!(model.n_jobs, Some(2));
    }

    #[test]
    fn test_ensemble_classifier_default() {
        let model = ExtraTreesClassifier::<f64>::default();
        assert_eq!(model.n_estimators, 100);
        assert_eq!(model.max_depth, None);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
        assert_eq!(model.max_features, MaxFeatures::Sqrt);
        assert!(!model.bootstrap);
        assert_eq!(model.criterion, ClassificationCriterion::Gini);
        assert_eq!(model.random_state, None);
        assert_eq!(model.n_jobs, None);
    }

    #[test]
    fn test_ensemble_regressor_default() {
        let model = ExtraTreesRegressor::<f64>::default();
        assert_eq!(model.n_estimators, 100);
        assert_eq!(model.max_depth, None);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
        assert_eq!(model.max_features, MaxFeatures::All);
        assert!(!model.bootstrap);
        assert_eq!(model.random_state, None);
        assert_eq!(model.n_jobs, None);
    }
}