1use ferrolearn_core::error::FerroError;
31use ferrolearn_core::introspection::{HasClasses, HasFeatureImportances};
32use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
33use ferrolearn_core::traits::{Fit, Predict};
34use ndarray::{Array1, Array2};
35use num_traits::{Float, FromPrimitive, ToPrimitive};
36use rand::SeedableRng;
37use rand::rngs::StdRng;
38use rayon::prelude::*;
39use serde::{Deserialize, Serialize};
40
41use crate::decision_tree::{
42 ClassificationCriterion, Node, TreeParams, compute_feature_importances, traverse,
43};
44use crate::extra_tree::{
45 build_extra_classification_tree_for_ensemble, build_extra_regression_tree_for_ensemble,
46};
47use crate::random_forest::MaxFeatures;
48
49fn resolve_max_features(strategy: MaxFeatures, n_features: usize) -> usize {
51 let result = match strategy {
52 MaxFeatures::Sqrt => (n_features as f64).sqrt().ceil() as usize,
53 MaxFeatures::Log2 => (n_features as f64).log2().ceil().max(1.0) as usize,
54 MaxFeatures::All => n_features,
55 MaxFeatures::Fixed(n) => n.min(n_features),
56 MaxFeatures::Fraction(f) => ((n_features as f64) * f).ceil() as usize,
57 };
58 result.max(1).min(n_features)
59}
60
61fn make_tree_params(
63 max_depth: Option<usize>,
64 min_samples_split: usize,
65 min_samples_leaf: usize,
66) -> TreeParams {
67 TreeParams {
68 max_depth,
69 min_samples_split,
70 min_samples_leaf,
71 }
72}
73
/// Extremely-randomized-trees classifier: hyper-parameter/builder type.
///
/// Holds configuration only; training happens in [`Fit::fit`], which
/// returns a [`FittedExtraTreesClassifier`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtraTreesClassifier<F> {
    /// Number of trees in the ensemble (default 100).
    pub n_estimators: usize,
    /// Maximum tree depth; `None` grows until other criteria stop splitting.
    pub max_depth: Option<usize>,
    /// Minimum samples required to split an internal node (default 2).
    pub min_samples_split: usize,
    /// Minimum samples required at a leaf node (default 1).
    pub min_samples_leaf: usize,
    /// Strategy for the number of candidate features per split (default `Sqrt`).
    pub max_features: MaxFeatures,
    /// Whether each tree trains on a bootstrap sample (default `false`).
    pub bootstrap: bool,
    /// Split-quality criterion (default `Gini`).
    pub criterion: ClassificationCriterion,
    /// Seed for deterministic fitting; `None` draws seeds from the thread RNG.
    pub random_state: Option<u64>,
    /// Worker-thread count for parallel fitting; `None` uses rayon's default pool.
    pub n_jobs: Option<usize>,
    // Ties the element type `F` to the struct without storing any data.
    _marker: std::marker::PhantomData<F>,
}
113
114impl<F: Float> ExtraTreesClassifier<F> {
115 #[must_use]
122 pub fn new() -> Self {
123 Self {
124 n_estimators: 100,
125 max_depth: None,
126 min_samples_split: 2,
127 min_samples_leaf: 1,
128 max_features: MaxFeatures::Sqrt,
129 bootstrap: false,
130 criterion: ClassificationCriterion::Gini,
131 random_state: None,
132 n_jobs: None,
133 _marker: std::marker::PhantomData,
134 }
135 }
136
137 #[must_use]
139 pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
140 self.n_estimators = n_estimators;
141 self
142 }
143
144 #[must_use]
146 pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
147 self.max_depth = max_depth;
148 self
149 }
150
151 #[must_use]
153 pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
154 self.min_samples_split = min_samples_split;
155 self
156 }
157
158 #[must_use]
160 pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
161 self.min_samples_leaf = min_samples_leaf;
162 self
163 }
164
165 #[must_use]
167 pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
168 self.max_features = max_features;
169 self
170 }
171
172 #[must_use]
174 pub fn with_bootstrap(mut self, bootstrap: bool) -> Self {
175 self.bootstrap = bootstrap;
176 self
177 }
178
179 #[must_use]
181 pub fn with_criterion(mut self, criterion: ClassificationCriterion) -> Self {
182 self.criterion = criterion;
183 self
184 }
185
186 #[must_use]
188 pub fn with_random_state(mut self, seed: u64) -> Self {
189 self.random_state = Some(seed);
190 self
191 }
192
193 #[must_use]
195 pub fn with_n_jobs(mut self, n_jobs: usize) -> Self {
196 self.n_jobs = Some(n_jobs);
197 self
198 }
199}
200
/// Defaults match [`ExtraTreesClassifier::new`].
impl<F: Float> Default for ExtraTreesClassifier<F> {
    fn default() -> Self {
        Self::new()
    }
}
206
/// Trained extra-trees classifier produced by [`Fit::fit`].
#[derive(Debug, Clone)]
pub struct FittedExtraTreesClassifier<F> {
    /// One flattened node arena per estimator.
    trees: Vec<Vec<Node<F>>>,
    /// Sorted, deduplicated class labels observed during fitting.
    classes: Vec<usize>,
    /// Number of feature columns seen at fit time.
    n_features: usize,
    /// Impurity-based importances, normalized to sum to 1 (all zeros when
    /// no split contributed any importance).
    feature_importances: Array1<F>,
}
226
227impl<F: Float + Send + Sync + 'static> FittedExtraTreesClassifier<F> {
228 #[must_use]
230 pub fn trees(&self) -> &[Vec<Node<F>>] {
231 &self.trees
232 }
233
234 #[must_use]
236 pub fn n_features(&self) -> usize {
237 self.n_features
238 }
239
240 #[must_use]
242 pub fn n_estimators(&self) -> usize {
243 self.trees.len()
244 }
245
246 pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
255 if x.ncols() != self.n_features {
256 return Err(FerroError::ShapeMismatch {
257 expected: vec![self.n_features],
258 actual: vec![x.ncols()],
259 context: "number of features must match fitted model".into(),
260 });
261 }
262 let n_samples = x.nrows();
263 let n_classes = self.classes.len();
264 let n_trees_f = F::from(self.trees.len()).unwrap();
265 let mut proba = Array2::zeros((n_samples, n_classes));
266
267 for i in 0..n_samples {
268 let row = x.row(i);
269 for tree_nodes in &self.trees {
270 let leaf_idx = traverse(tree_nodes, &row);
271 if let Node::Leaf {
272 class_distribution: Some(ref dist),
273 ..
274 } = tree_nodes[leaf_idx]
275 {
276 for (j, &p) in dist.iter().enumerate() {
277 proba[[i, j]] = proba[[i, j]] + p;
278 }
279 }
280 }
281 for j in 0..n_classes {
283 proba[[i, j]] = proba[[i, j]] / n_trees_f;
284 }
285 }
286 Ok(proba)
287 }
288}
289
290impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for ExtraTreesClassifier<F> {
291 type Fitted = FittedExtraTreesClassifier<F>;
292 type Error = FerroError;
293
294 fn fit(
307 &self,
308 x: &Array2<F>,
309 y: &Array1<usize>,
310 ) -> Result<FittedExtraTreesClassifier<F>, FerroError> {
311 let (n_samples, n_features) = x.dim();
312
313 if n_samples != y.len() {
314 return Err(FerroError::ShapeMismatch {
315 expected: vec![n_samples],
316 actual: vec![y.len()],
317 context: "y length must match number of samples in X".into(),
318 });
319 }
320 if n_samples == 0 {
321 return Err(FerroError::InsufficientSamples {
322 required: 1,
323 actual: 0,
324 context: "ExtraTreesClassifier requires at least one sample".into(),
325 });
326 }
327 if self.n_estimators == 0 {
328 return Err(FerroError::InvalidParameter {
329 name: "n_estimators".into(),
330 reason: "must be at least 1".into(),
331 });
332 }
333
334 let mut classes: Vec<usize> = y.iter().copied().collect();
336 classes.sort_unstable();
337 classes.dedup();
338 let n_classes = classes.len();
339
340 let y_mapped: Vec<usize> = y
341 .iter()
342 .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
343 .collect();
344
345 let max_features_n = resolve_max_features(self.max_features, n_features);
346 let params = make_tree_params(
347 self.max_depth,
348 self.min_samples_split,
349 self.min_samples_leaf,
350 );
351 let criterion = self.criterion;
352 let bootstrap = self.bootstrap;
353
354 let tree_seeds: Vec<u64> = if let Some(seed) = self.random_state {
356 let mut master_rng = StdRng::seed_from_u64(seed);
357 (0..self.n_estimators)
358 .map(|_| {
359 use rand::RngCore;
360 master_rng.next_u64()
361 })
362 .collect()
363 } else {
364 (0..self.n_estimators)
365 .map(|_| {
366 use rand::RngCore;
367 rand::rng().next_u64()
368 })
369 .collect()
370 };
371
372 let trees: Vec<Vec<Node<F>>> = if let Some(n_jobs) = self.n_jobs {
374 let pool = rayon::ThreadPoolBuilder::new()
375 .num_threads(n_jobs)
376 .build()
377 .unwrap_or_else(|_| rayon::ThreadPoolBuilder::new().build().unwrap());
378 pool.install(|| {
379 tree_seeds
380 .par_iter()
381 .map(|&seed| {
382 build_single_classification_tree(
383 x,
384 &y_mapped,
385 n_classes,
386 n_samples,
387 n_features,
388 max_features_n,
389 ¶ms,
390 criterion,
391 bootstrap,
392 seed,
393 )
394 })
395 .collect()
396 })
397 } else {
398 tree_seeds
399 .par_iter()
400 .map(|&seed| {
401 build_single_classification_tree(
402 x,
403 &y_mapped,
404 n_classes,
405 n_samples,
406 n_features,
407 max_features_n,
408 ¶ms,
409 criterion,
410 bootstrap,
411 seed,
412 )
413 })
414 .collect()
415 };
416
417 let mut total_importances = Array1::<F>::zeros(n_features);
419 for tree_nodes in &trees {
420 let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
421 total_importances = total_importances + tree_imp;
422 }
423 let imp_sum: F = total_importances
424 .iter()
425 .copied()
426 .fold(F::zero(), |a, b| a + b);
427 if imp_sum > F::zero() {
428 total_importances.mapv_inplace(|v| v / imp_sum);
429 }
430
431 Ok(FittedExtraTreesClassifier {
432 trees,
433 classes,
434 n_features,
435 feature_importances: total_importances,
436 })
437 }
438}
439
440#[allow(clippy::too_many_arguments)]
442fn build_single_classification_tree<F: Float>(
443 x: &Array2<F>,
444 y_mapped: &[usize],
445 n_classes: usize,
446 n_samples: usize,
447 n_features: usize,
448 max_features_n: usize,
449 params: &TreeParams,
450 criterion: ClassificationCriterion,
451 bootstrap: bool,
452 seed: u64,
453) -> Vec<Node<F>> {
454 let mut rng = StdRng::seed_from_u64(seed);
455
456 let indices: Vec<usize> = if bootstrap {
457 use rand::RngCore;
458 (0..n_samples)
459 .map(|_| (rng.next_u64() as usize) % n_samples)
460 .collect()
461 } else {
462 (0..n_samples).collect()
463 };
464
465 build_extra_classification_tree_for_ensemble(
466 x,
467 y_mapped,
468 n_classes,
469 &indices,
470 None, params,
472 criterion,
473 n_features,
474 max_features_n,
475 &mut rng,
476 )
477}
478
479impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedExtraTreesClassifier<F> {
480 type Output = Array1<usize>;
481 type Error = FerroError;
482
483 fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
490 if x.ncols() != self.n_features {
491 return Err(FerroError::ShapeMismatch {
492 expected: vec![self.n_features],
493 actual: vec![x.ncols()],
494 context: "number of features must match fitted model".into(),
495 });
496 }
497
498 let n_samples = x.nrows();
499 let n_classes = self.classes.len();
500 let mut predictions = Array1::zeros(n_samples);
501
502 for i in 0..n_samples {
503 let row = x.row(i);
504 let mut votes = vec![0usize; n_classes];
505
506 for tree_nodes in &self.trees {
507 let leaf_idx = traverse(tree_nodes, &row);
508 if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
509 let class_idx = value.to_f64().map_or(0, |f| f.round() as usize);
510 if class_idx < n_classes {
511 votes[class_idx] += 1;
512 }
513 }
514 }
515
516 let winner = votes
517 .iter()
518 .enumerate()
519 .max_by_key(|&(_, &count)| count)
520 .map_or(0, |(idx, _)| idx);
521 predictions[i] = self.classes[winner];
522 }
523
524 Ok(predictions)
525 }
526}
527
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedExtraTreesClassifier<F> {
    /// Impurity-based importances computed at fit time (normalized to sum
    /// to 1 unless all zero).
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}
533
impl<F: Float + Send + Sync + 'static> HasClasses for FittedExtraTreesClassifier<F> {
    /// Sorted unique class labels seen during fitting.
    fn classes(&self) -> &[usize] {
        &self.classes
    }

    /// Number of distinct classes.
    fn n_classes(&self) -> usize {
        self.classes.len()
    }
}
543
impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
    for ExtraTreesClassifier<F>
{
    /// Pipeline entry point: casts float targets to `usize` labels, fits,
    /// and wraps the result in an adapter exposing `F`-valued predictions.
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        // NOTE(review): non-convertible targets (negative / NaN / too large)
        // silently become label 0 here — confirm this is intended.
        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
        let fitted = self.fit(x, &y_usize)?;
        Ok(Box::new(FittedExtraTreesClassifierPipelineAdapter(fitted)))
    }
}
558
/// Newtype adapter giving [`FittedExtraTreesClassifier`] the pipeline
/// interface (integer labels exposed as `F` values).
struct FittedExtraTreesClassifierPipelineAdapter<F: Float + Send + Sync + 'static>(
    FittedExtraTreesClassifier<F>,
);
563
impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedExtraTreesClassifierPipelineAdapter<F>
{
    /// Predict integer labels and convert them to `F` (NaN when a label is
    /// not representable in `F`).
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        let preds = self.0.predict(x)?;
        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or_else(F::nan)))
    }
}
572
/// Extremely-randomized-trees regressor: hyper-parameter/builder type.
///
/// Holds configuration only; training happens in [`Fit::fit`], which
/// returns a [`FittedExtraTreesRegressor`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtraTreesRegressor<F> {
    /// Number of trees in the ensemble (default 100).
    pub n_estimators: usize,
    /// Maximum tree depth; `None` grows until other criteria stop splitting.
    pub max_depth: Option<usize>,
    /// Minimum samples required to split an internal node (default 2).
    pub min_samples_split: usize,
    /// Minimum samples required at a leaf node (default 1).
    pub min_samples_leaf: usize,
    /// Strategy for the number of candidate features per split (default `All`).
    pub max_features: MaxFeatures,
    /// Whether each tree trains on a bootstrap sample (default `false`).
    pub bootstrap: bool,
    /// Seed for deterministic fitting; `None` draws seeds from the thread RNG.
    pub random_state: Option<u64>,
    /// Worker-thread count for parallel fitting; `None` uses rayon's default pool.
    pub n_jobs: Option<usize>,
    // Ties the element type `F` to the struct without storing any data.
    _marker: std::marker::PhantomData<F>,
}
609
610impl<F: Float> ExtraTreesRegressor<F> {
611 #[must_use]
618 pub fn new() -> Self {
619 Self {
620 n_estimators: 100,
621 max_depth: None,
622 min_samples_split: 2,
623 min_samples_leaf: 1,
624 max_features: MaxFeatures::All,
625 bootstrap: false,
626 random_state: None,
627 n_jobs: None,
628 _marker: std::marker::PhantomData,
629 }
630 }
631
632 #[must_use]
634 pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
635 self.n_estimators = n_estimators;
636 self
637 }
638
639 #[must_use]
641 pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
642 self.max_depth = max_depth;
643 self
644 }
645
646 #[must_use]
648 pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
649 self.min_samples_split = min_samples_split;
650 self
651 }
652
653 #[must_use]
655 pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
656 self.min_samples_leaf = min_samples_leaf;
657 self
658 }
659
660 #[must_use]
662 pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
663 self.max_features = max_features;
664 self
665 }
666
667 #[must_use]
669 pub fn with_bootstrap(mut self, bootstrap: bool) -> Self {
670 self.bootstrap = bootstrap;
671 self
672 }
673
674 #[must_use]
676 pub fn with_random_state(mut self, seed: u64) -> Self {
677 self.random_state = Some(seed);
678 self
679 }
680
681 #[must_use]
683 pub fn with_n_jobs(mut self, n_jobs: usize) -> Self {
684 self.n_jobs = Some(n_jobs);
685 self
686 }
687}
688
/// Defaults match [`ExtraTreesRegressor::new`].
impl<F: Float> Default for ExtraTreesRegressor<F> {
    fn default() -> Self {
        Self::new()
    }
}
694
/// Trained extra-trees regressor produced by [`Fit::fit`].
#[derive(Debug, Clone)]
pub struct FittedExtraTreesRegressor<F> {
    /// One flattened node arena per estimator.
    trees: Vec<Vec<Node<F>>>,
    /// Number of feature columns seen at fit time.
    n_features: usize,
    /// Impurity-based importances, normalized to sum to 1 (all zeros when
    /// no split contributed any importance).
    feature_importances: Array1<F>,
}
712
713impl<F: Float + Send + Sync + 'static> FittedExtraTreesRegressor<F> {
714 #[must_use]
716 pub fn trees(&self) -> &[Vec<Node<F>>] {
717 &self.trees
718 }
719
720 #[must_use]
722 pub fn n_features(&self) -> usize {
723 self.n_features
724 }
725
726 #[must_use]
728 pub fn n_estimators(&self) -> usize {
729 self.trees.len()
730 }
731}
732
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for ExtraTreesRegressor<F> {
    type Fitted = FittedExtraTreesRegressor<F>;
    type Error = FerroError;

    /// Fit the ensemble on `x` (`n_samples x n_features`) and continuous
    /// targets `y`.
    ///
    /// # Errors
    /// - [`FerroError::ShapeMismatch`] if `y.len() != x.nrows()`
    /// - [`FerroError::InsufficientSamples`] if `x` has no rows
    /// - [`FerroError::InvalidParameter`] if `n_estimators == 0`
    fn fit(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<FittedExtraTreesRegressor<F>, FerroError> {
        let (n_samples, n_features) = x.dim();

        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        if n_samples == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "ExtraTreesRegressor requires at least one sample".into(),
            });
        }
        if self.n_estimators == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_estimators".into(),
                reason: "must be at least 1".into(),
            });
        }

        let max_features_n = resolve_max_features(self.max_features, n_features);
        let params = make_tree_params(
            self.max_depth,
            self.min_samples_split,
            self.min_samples_leaf,
        );
        let bootstrap = self.bootstrap;

        // Per-tree seeds: drawn from a seeded master RNG for reproducibility,
        // or from the thread RNG otherwise.
        let tree_seeds: Vec<u64> = if let Some(seed) = self.random_state {
            let mut master_rng = StdRng::seed_from_u64(seed);
            (0..self.n_estimators)
                .map(|_| {
                    use rand::RngCore;
                    master_rng.next_u64()
                })
                .collect()
        } else {
            (0..self.n_estimators)
                .map(|_| {
                    use rand::RngCore;
                    rand::rng().next_u64()
                })
                .collect()
        };

        // Trees are independent, so build them in parallel. An explicit
        // `n_jobs` gets a dedicated pool (best-effort: falls back to a
        // default-sized pool if that pool cannot be built).
        let trees: Vec<Vec<Node<F>>> = if let Some(n_jobs) = self.n_jobs {
            let pool = rayon::ThreadPoolBuilder::new()
                .num_threads(n_jobs)
                .build()
                .unwrap_or_else(|_| rayon::ThreadPoolBuilder::new().build().unwrap());
            pool.install(|| {
                tree_seeds
                    .par_iter()
                    .map(|&seed| {
                        build_single_regression_tree(
                            x,
                            y,
                            n_samples,
                            n_features,
                            max_features_n,
                            &params,
                            bootstrap,
                            seed,
                        )
                    })
                    .collect()
            })
        } else {
            tree_seeds
                .par_iter()
                .map(|&seed| {
                    build_single_regression_tree(
                        x,
                        y,
                        n_samples,
                        n_features,
                        max_features_n,
                        &params,
                        bootstrap,
                        seed,
                    )
                })
                .collect()
        };

        // Sum impurity-based importances over all trees, then normalize so
        // they sum to 1 (left at zero when no split contributed importance).
        let mut total_importances = Array1::<F>::zeros(n_features);
        for tree_nodes in &trees {
            let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
            total_importances = total_importances + tree_imp;
        }
        let imp_sum: F = total_importances
            .iter()
            .copied()
            .fold(F::zero(), |a, b| a + b);
        if imp_sum > F::zero() {
            total_importances.mapv_inplace(|v| v / imp_sum);
        }

        Ok(FittedExtraTreesRegressor {
            trees,
            n_features,
            feature_importances: total_importances,
        })
    }
}
861
862#[allow(clippy::too_many_arguments)]
864fn build_single_regression_tree<F: Float>(
865 x: &Array2<F>,
866 y: &Array1<F>,
867 n_samples: usize,
868 n_features: usize,
869 max_features_n: usize,
870 params: &TreeParams,
871 bootstrap: bool,
872 seed: u64,
873) -> Vec<Node<F>> {
874 let mut rng = StdRng::seed_from_u64(seed);
875
876 let indices: Vec<usize> = if bootstrap {
877 use rand::RngCore;
878 (0..n_samples)
879 .map(|_| (rng.next_u64() as usize) % n_samples)
880 .collect()
881 } else {
882 (0..n_samples).collect()
883 };
884
885 build_extra_regression_tree_for_ensemble(
886 x,
887 y,
888 &indices,
889 None, params,
891 n_features,
892 max_features_n,
893 &mut rng,
894 )
895}
896
897impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedExtraTreesRegressor<F> {
898 type Output = Array1<F>;
899 type Error = FerroError;
900
901 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
908 if x.ncols() != self.n_features {
909 return Err(FerroError::ShapeMismatch {
910 expected: vec![self.n_features],
911 actual: vec![x.ncols()],
912 context: "number of features must match fitted model".into(),
913 });
914 }
915
916 let n_samples = x.nrows();
917 let n_trees_f = F::from(self.trees.len()).unwrap();
918 let mut predictions = Array1::zeros(n_samples);
919
920 for i in 0..n_samples {
921 let row = x.row(i);
922 let mut sum = F::zero();
923
924 for tree_nodes in &self.trees {
925 let leaf_idx = traverse(tree_nodes, &row);
926 if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
927 sum = sum + value;
928 }
929 }
930
931 predictions[i] = sum / n_trees_f;
932 }
933
934 Ok(predictions)
935 }
936}
937
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedExtraTreesRegressor<F> {
    /// Impurity-based importances computed at fit time (normalized to sum
    /// to 1 unless all zero).
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}
943
impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for ExtraTreesRegressor<F> {
    /// Pipeline entry point: fits directly on the float targets.
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let fitted = self.fit(x, y)?;
        Ok(Box::new(fitted))
    }
}
955
impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F> for FittedExtraTreesRegressor<F> {
    /// Delegate to [`Predict::predict`]; outputs are already `F`-valued.
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        self.predict(x)
    }
}
961
// Unit tests covering fitting, prediction, error paths, determinism, and
// the builder/default APIs of both ensemble estimators.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // ---------------- classifier ----------------

    #[test]
    fn test_ensemble_classifier_simple() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        // Linearly separable training data: the ensemble should fit it exactly.
        assert_eq!(preds, y);
    }

    #[test]
    fn test_ensemble_classifier_no_bootstrap() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        // Bootstrap is off by default for extra-trees.
        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        assert!(!model.bootstrap);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds, y);
    }

    #[test]
    fn test_ensemble_classifier_with_bootstrap() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_bootstrap(true)
            .with_random_state(42);
        assert!(model.bootstrap);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 8);
    }

    #[test]
    fn test_ensemble_classifier_predict_proba() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();

        // Each row of probabilities must sum to 1.
        assert_eq!(proba.dim(), (6, 2));
        for i in 0..6 {
            let row_sum = proba.row(i).sum();
            assert_relative_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_ensemble_classifier_feature_importances() {
        // Feature 1 is constant, so only feature 0 can carry importance.
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0, 1.0, 5.0, 1.0, 6.0, 1.0, 7.0, 1.0, 8.0, 1.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        let total: f64 = importances.sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-10);
        assert!(importances[0] > importances[1]);
    }

    #[test]
    fn test_ensemble_classifier_n_estimators() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(15)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 15);
    }

    #[test]
    fn test_ensemble_classifier_classes() {
        // Non-contiguous labels must be preserved in sorted order.
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 3, 3, 3]; let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0, 3]);
        assert_eq!(fitted.n_classes(), 2);
    }

    #[test]
    fn test_ensemble_classifier_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0];
        let model = ExtraTreesClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_ensemble_classifier_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);
        let model = ExtraTreesClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_ensemble_classifier_zero_estimators() {
        let x = Array2::from_shape_vec((2, 1), vec![1.0, 2.0]).unwrap();
        let y = array![0, 1];
        let model = ExtraTreesClassifier::<f64>::new().with_n_estimators(0);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_ensemble_classifier_deterministic() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        // Same seed => identical predictions.
        let model1 = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(123);
        let model2 = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(123);

        let preds1 = model1.fit(&x, &y).unwrap().predict(&x).unwrap();
        let preds2 = model2.fit(&x, &y).unwrap().predict(&x).unwrap();
        assert_eq!(preds1, preds2);
    }

    #[test]
    fn test_ensemble_classifier_predict_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        let x_wrong = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_wrong).is_err());
    }

    // ---------------- regressor ----------------

    #[test]
    fn test_ensemble_regressor_simple() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(20)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 6);
        // Random split points mean predictions are close, not exact.
        for i in 0..6 {
            assert_relative_eq!(preds[i], y[i], epsilon = 1.0);
        }
    }

    #[test]
    fn test_ensemble_regressor_constant_target() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![5.0, 5.0, 5.0, 5.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for &p in &preds {
            assert_relative_eq!(p, 5.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_ensemble_regressor_no_bootstrap() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        assert!(!model.bootstrap);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_ensemble_regressor_with_bootstrap() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_bootstrap(true)
            .with_random_state(42);
        assert!(model.bootstrap);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_ensemble_regressor_feature_importances() {
        // Feature 1 is constant, so only feature 0 can carry importance.
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(20)
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        let total: f64 = importances.sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-10);
        assert!(importances[0] > importances[1]);
    }

    #[test]
    fn test_ensemble_regressor_n_estimators() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(7)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 7);
    }

    #[test]
    fn test_ensemble_regressor_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0];
        let model = ExtraTreesRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_ensemble_regressor_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);
        let model = ExtraTreesRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_ensemble_regressor_zero_estimators() {
        let x = Array2::from_shape_vec((2, 1), vec![1.0, 2.0]).unwrap();
        let y = array![1.0, 2.0];
        let model = ExtraTreesRegressor::<f64>::new().with_n_estimators(0);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_ensemble_regressor_deterministic() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        // Same seed => identical predictions.
        let model1 = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(99);
        let model2 = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(99);

        let preds1 = model1.fit(&x, &y).unwrap().predict(&x).unwrap();
        let preds2 = model2.fit(&x, &y).unwrap().predict(&x).unwrap();

        for i in 0..6 {
            assert_relative_eq!(preds1[i], preds2[i], epsilon = 1e-12);
        }
    }

    #[test]
    fn test_ensemble_regressor_predict_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        let x_wrong = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_wrong).is_err());
    }

    // ---------------- builders and defaults ----------------

    #[test]
    fn test_ensemble_classifier_builder() {
        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(50)
            .with_max_depth(Some(5))
            .with_min_samples_split(10)
            .with_min_samples_leaf(3)
            .with_max_features(MaxFeatures::Log2)
            .with_bootstrap(true)
            .with_criterion(ClassificationCriterion::Entropy)
            .with_random_state(42)
            .with_n_jobs(4);

        assert_eq!(model.n_estimators, 50);
        assert_eq!(model.max_depth, Some(5));
        assert_eq!(model.min_samples_split, 10);
        assert_eq!(model.min_samples_leaf, 3);
        assert_eq!(model.max_features, MaxFeatures::Log2);
        assert!(model.bootstrap);
        assert_eq!(model.criterion, ClassificationCriterion::Entropy);
        assert_eq!(model.random_state, Some(42));
        assert_eq!(model.n_jobs, Some(4));
    }

    #[test]
    fn test_ensemble_regressor_builder() {
        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(25)
            .with_max_depth(Some(8))
            .with_min_samples_split(5)
            .with_min_samples_leaf(2)
            .with_max_features(MaxFeatures::Fraction(0.5))
            .with_bootstrap(true)
            .with_random_state(99)
            .with_n_jobs(2);

        assert_eq!(model.n_estimators, 25);
        assert_eq!(model.max_depth, Some(8));
        assert_eq!(model.min_samples_split, 5);
        assert_eq!(model.min_samples_leaf, 2);
        assert_eq!(model.max_features, MaxFeatures::Fraction(0.5));
        assert!(model.bootstrap);
        assert_eq!(model.random_state, Some(99));
        assert_eq!(model.n_jobs, Some(2));
    }

    #[test]
    fn test_ensemble_classifier_default() {
        let model = ExtraTreesClassifier::<f64>::default();
        assert_eq!(model.n_estimators, 100);
        assert_eq!(model.max_depth, None);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
        assert_eq!(model.max_features, MaxFeatures::Sqrt);
        assert!(!model.bootstrap);
        assert_eq!(model.criterion, ClassificationCriterion::Gini);
        assert_eq!(model.random_state, None);
        assert_eq!(model.n_jobs, None);
    }

    #[test]
    fn test_ensemble_regressor_default() {
        let model = ExtraTreesRegressor::<f64>::default();
        assert_eq!(model.n_estimators, 100);
        assert_eq!(model.max_depth, None);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
        assert_eq!(model.max_features, MaxFeatures::All);
        assert!(!model.bootstrap);
        assert_eq!(model.random_state, None);
        assert_eq!(model.n_jobs, None);
    }
}