1use ferrolearn_core::error::FerroError;
31use ferrolearn_core::introspection::{HasClasses, HasFeatureImportances};
32use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
33use ferrolearn_core::traits::{Fit, Predict};
34use ndarray::{Array1, Array2};
35use num_traits::{Float, FromPrimitive, ToPrimitive};
36use rand::SeedableRng;
37use rand::rngs::StdRng;
38use rayon::prelude::*;
39use serde::{Deserialize, Serialize};
40
41use crate::decision_tree::{
42 ClassificationCriterion, Node, TreeParams, compute_feature_importances, traverse,
43};
44use crate::extra_tree::{
45 build_extra_classification_tree_for_ensemble, build_extra_regression_tree_for_ensemble,
46};
47use crate::random_forest::MaxFeatures;
48
49fn resolve_max_features(strategy: MaxFeatures, n_features: usize) -> usize {
51 let result = match strategy {
52 MaxFeatures::Sqrt => (n_features as f64).sqrt().ceil() as usize,
53 MaxFeatures::Log2 => (n_features as f64).log2().ceil().max(1.0) as usize,
54 MaxFeatures::All => n_features,
55 MaxFeatures::Fixed(n) => n.min(n_features),
56 MaxFeatures::Fraction(f) => ((n_features as f64) * f).ceil() as usize,
57 };
58 result.max(1).min(n_features)
59}
60
61fn make_tree_params(
63 max_depth: Option<usize>,
64 min_samples_split: usize,
65 min_samples_leaf: usize,
66) -> TreeParams {
67 TreeParams {
68 max_depth,
69 min_samples_split,
70 min_samples_leaf,
71 }
72}
73
/// Extra-Trees (extremely randomized trees) classifier configuration.
///
/// Holds hyper-parameters only; training happens through the [`Fit`] impl,
/// which produces a [`FittedExtraTreesClassifier`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtraTreesClassifier<F> {
    /// Number of trees in the ensemble.
    pub n_estimators: usize,
    /// Maximum tree depth; `None` means no depth limit.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required at a leaf node.
    pub min_samples_leaf: usize,
    /// Strategy for how many candidate features to consider per split.
    pub max_features: MaxFeatures,
    /// When `true`, each tree trains on rows sampled with replacement.
    pub bootstrap: bool,
    /// Split-quality criterion used when growing the trees.
    pub criterion: ClassificationCriterion,
    /// Seed for deterministic tree construction; `None` uses a thread RNG.
    pub random_state: Option<u64>,
    /// Thread count for parallel tree building; `None` uses rayon's default.
    pub n_jobs: Option<usize>,
    // Ties the otherwise-unused sample dtype `F` to the struct.
    _marker: std::marker::PhantomData<F>,
}
113
impl<F: Float> ExtraTreesClassifier<F> {
    /// Creates a classifier with the default hyper-parameters:
    /// 100 trees, unlimited depth, `min_samples_split = 2`,
    /// `min_samples_leaf = 1`, `MaxFeatures::Sqrt`, no bootstrap,
    /// Gini criterion, no fixed seed, default parallelism.
    #[must_use]
    pub fn new() -> Self {
        Self {
            n_estimators: 100,
            max_depth: None,
            min_samples_split: 2,
            min_samples_leaf: 1,
            max_features: MaxFeatures::Sqrt,
            bootstrap: false,
            criterion: ClassificationCriterion::Gini,
            random_state: None,
            n_jobs: None,
            _marker: std::marker::PhantomData,
        }
    }

    /// Sets the number of trees in the ensemble.
    #[must_use]
    pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
        self.n_estimators = n_estimators;
        self
    }

    /// Sets the maximum tree depth (`None` = unlimited).
    #[must_use]
    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
        self.max_depth = max_depth;
        self
    }

    /// Sets the minimum number of samples required to split a node.
    #[must_use]
    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.min_samples_split = min_samples_split;
        self
    }

    /// Sets the minimum number of samples required at a leaf.
    #[must_use]
    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.min_samples_leaf = min_samples_leaf;
        self
    }

    /// Sets the per-split candidate-feature strategy.
    #[must_use]
    pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
        self.max_features = max_features;
        self
    }

    /// Enables or disables bootstrap row sampling per tree.
    #[must_use]
    pub fn with_bootstrap(mut self, bootstrap: bool) -> Self {
        self.bootstrap = bootstrap;
        self
    }

    /// Sets the split-quality criterion.
    #[must_use]
    pub fn with_criterion(mut self, criterion: ClassificationCriterion) -> Self {
        self.criterion = criterion;
        self
    }

    /// Sets a fixed seed for deterministic fitting.
    #[must_use]
    pub fn with_random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Sets the number of worker threads used during fitting.
    #[must_use]
    pub fn with_n_jobs(mut self, n_jobs: usize) -> Self {
        self.n_jobs = Some(n_jobs);
        self
    }
}
200
201impl<F: Float> Default for ExtraTreesClassifier<F> {
202 fn default() -> Self {
203 Self::new()
204 }
205}
206
/// Trained extra-trees classifier produced by [`ExtraTreesClassifier`]'s fit.
#[derive(Debug, Clone)]
pub struct FittedExtraTreesClassifier<F> {
    // One flat node arena per tree.
    trees: Vec<Vec<Node<F>>>,
    // Sorted, deduplicated class labels observed during fit.
    classes: Vec<usize>,
    // Number of columns the model was trained on.
    n_features: usize,
    // Importances summed over trees, normalized to sum to 1 when nonzero.
    feature_importances: Array1<F>,
}
226
227impl<F: Float + Send + Sync + 'static> FittedExtraTreesClassifier<F> {
228 #[must_use]
230 pub fn trees(&self) -> &[Vec<Node<F>>] {
231 &self.trees
232 }
233
234 #[must_use]
236 pub fn n_features(&self) -> usize {
237 self.n_features
238 }
239
240 #[must_use]
242 pub fn n_estimators(&self) -> usize {
243 self.trees.len()
244 }
245
246 pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
255 if x.ncols() != self.n_features {
256 return Err(FerroError::ShapeMismatch {
257 expected: vec![self.n_features],
258 actual: vec![x.ncols()],
259 context: "number of features must match fitted model".into(),
260 });
261 }
262 let n_samples = x.nrows();
263 let n_classes = self.classes.len();
264 let n_trees_f = F::from(self.trees.len()).unwrap();
265 let mut proba = Array2::zeros((n_samples, n_classes));
266
267 for i in 0..n_samples {
268 let row = x.row(i);
269 for tree_nodes in &self.trees {
270 let leaf_idx = traverse(tree_nodes, &row);
271 if let Node::Leaf {
272 class_distribution: Some(ref dist),
273 ..
274 } = tree_nodes[leaf_idx]
275 {
276 for (j, &p) in dist.iter().enumerate() {
277 proba[[i, j]] = proba[[i, j]] + p;
278 }
279 }
280 }
281 for j in 0..n_classes {
283 proba[[i, j]] = proba[[i, j]] / n_trees_f;
284 }
285 }
286 Ok(proba)
287 }
288}
289
290impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for ExtraTreesClassifier<F> {
291 type Fitted = FittedExtraTreesClassifier<F>;
292 type Error = FerroError;
293
294 fn fit(
307 &self,
308 x: &Array2<F>,
309 y: &Array1<usize>,
310 ) -> Result<FittedExtraTreesClassifier<F>, FerroError> {
311 let (n_samples, n_features) = x.dim();
312
313 if n_samples != y.len() {
314 return Err(FerroError::ShapeMismatch {
315 expected: vec![n_samples],
316 actual: vec![y.len()],
317 context: "y length must match number of samples in X".into(),
318 });
319 }
320 if n_samples == 0 {
321 return Err(FerroError::InsufficientSamples {
322 required: 1,
323 actual: 0,
324 context: "ExtraTreesClassifier requires at least one sample".into(),
325 });
326 }
327 if self.n_estimators == 0 {
328 return Err(FerroError::InvalidParameter {
329 name: "n_estimators".into(),
330 reason: "must be at least 1".into(),
331 });
332 }
333
334 let mut classes: Vec<usize> = y.iter().copied().collect();
336 classes.sort_unstable();
337 classes.dedup();
338 let n_classes = classes.len();
339
340 let y_mapped: Vec<usize> = y
341 .iter()
342 .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
343 .collect();
344
345 let max_features_n = resolve_max_features(self.max_features, n_features);
346 let params = make_tree_params(
347 self.max_depth,
348 self.min_samples_split,
349 self.min_samples_leaf,
350 );
351 let criterion = self.criterion;
352 let bootstrap = self.bootstrap;
353
354 let tree_seeds: Vec<u64> = if let Some(seed) = self.random_state {
356 let mut master_rng = StdRng::seed_from_u64(seed);
357 (0..self.n_estimators)
358 .map(|_| {
359 use rand::RngCore;
360 master_rng.next_u64()
361 })
362 .collect()
363 } else {
364 (0..self.n_estimators)
365 .map(|_| {
366 use rand::RngCore;
367 rand::rng().next_u64()
368 })
369 .collect()
370 };
371
372 let trees: Vec<Vec<Node<F>>> = if let Some(n_jobs) = self.n_jobs {
374 let pool = rayon::ThreadPoolBuilder::new()
375 .num_threads(n_jobs)
376 .build()
377 .unwrap_or_else(|_| rayon::ThreadPoolBuilder::new().build().unwrap());
378 pool.install(|| {
379 tree_seeds
380 .par_iter()
381 .map(|&seed| {
382 build_single_classification_tree(
383 x,
384 &y_mapped,
385 n_classes,
386 n_samples,
387 n_features,
388 max_features_n,
389 ¶ms,
390 criterion,
391 bootstrap,
392 seed,
393 )
394 })
395 .collect()
396 })
397 } else {
398 tree_seeds
399 .par_iter()
400 .map(|&seed| {
401 build_single_classification_tree(
402 x,
403 &y_mapped,
404 n_classes,
405 n_samples,
406 n_features,
407 max_features_n,
408 ¶ms,
409 criterion,
410 bootstrap,
411 seed,
412 )
413 })
414 .collect()
415 };
416
417 let mut total_importances = Array1::<F>::zeros(n_features);
419 for tree_nodes in &trees {
420 let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
421 total_importances = total_importances + tree_imp;
422 }
423 let imp_sum: F = total_importances
424 .iter()
425 .copied()
426 .fold(F::zero(), |a, b| a + b);
427 if imp_sum > F::zero() {
428 total_importances.mapv_inplace(|v| v / imp_sum);
429 }
430
431 Ok(FittedExtraTreesClassifier {
432 trees,
433 classes,
434 n_features,
435 feature_importances: total_importances,
436 })
437 }
438}
439
440#[allow(clippy::too_many_arguments)]
442fn build_single_classification_tree<F: Float>(
443 x: &Array2<F>,
444 y_mapped: &[usize],
445 n_classes: usize,
446 n_samples: usize,
447 n_features: usize,
448 max_features_n: usize,
449 params: &TreeParams,
450 criterion: ClassificationCriterion,
451 bootstrap: bool,
452 seed: u64,
453) -> Vec<Node<F>> {
454 let mut rng = StdRng::seed_from_u64(seed);
455
456 let indices: Vec<usize> = if bootstrap {
457 use rand::RngCore;
458 (0..n_samples)
459 .map(|_| (rng.next_u64() as usize) % n_samples)
460 .collect()
461 } else {
462 (0..n_samples).collect()
463 };
464
465 build_extra_classification_tree_for_ensemble(
466 x,
467 y_mapped,
468 n_classes,
469 &indices,
470 None, params,
472 criterion,
473 n_features,
474 max_features_n,
475 &mut rng,
476 )
477}
478
479impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedExtraTreesClassifier<F> {
480 type Output = Array1<usize>;
481 type Error = FerroError;
482
483 fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
490 if x.ncols() != self.n_features {
491 return Err(FerroError::ShapeMismatch {
492 expected: vec![self.n_features],
493 actual: vec![x.ncols()],
494 context: "number of features must match fitted model".into(),
495 });
496 }
497
498 let n_samples = x.nrows();
499 let n_classes = self.classes.len();
500 let mut predictions = Array1::zeros(n_samples);
501
502 for i in 0..n_samples {
503 let row = x.row(i);
504 let mut votes = vec![0usize; n_classes];
505
506 for tree_nodes in &self.trees {
507 let leaf_idx = traverse(tree_nodes, &row);
508 if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
509 let class_idx = value.to_f64().map(|f| f.round() as usize).unwrap_or(0);
510 if class_idx < n_classes {
511 votes[class_idx] += 1;
512 }
513 }
514 }
515
516 let winner = votes
517 .iter()
518 .enumerate()
519 .max_by_key(|&(_, &count)| count)
520 .map(|(idx, _)| idx)
521 .unwrap_or(0);
522 predictions[i] = self.classes[winner];
523 }
524
525 Ok(predictions)
526 }
527}
528
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedExtraTreesClassifier<F> {
    /// Normalized feature importances accumulated over all trees during fit.
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}
534
impl<F: Float + Send + Sync + 'static> HasClasses for FittedExtraTreesClassifier<F> {
    /// Sorted unique class labels observed during fit.
    fn classes(&self) -> &[usize] {
        &self.classes
    }

    /// Number of distinct classes.
    fn n_classes(&self) -> usize {
        self.classes.len()
    }
}
544
impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
    for ExtraTreesClassifier<F>
{
    /// Pipeline entry point: casts float targets to `usize` labels and fits.
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        // NOTE(review): targets that fail the usize conversion (negative,
        // NaN, out of range) are silently mapped to class 0 — confirm this
        // is the intended behavior rather than an error.
        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
        let fitted = self.fit(x, &y_usize)?;
        Ok(Box::new(FittedExtraTreesClassifierPipelineAdapter(fitted)))
    }
}
559
/// Newtype adapter exposing a [`FittedExtraTreesClassifier`] through the
/// float-typed pipeline prediction interface.
struct FittedExtraTreesClassifierPipelineAdapter<F: Float + Send + Sync + 'static>(
    FittedExtraTreesClassifier<F>,
);
564
565impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
566 for FittedExtraTreesClassifierPipelineAdapter<F>
567{
568 fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
569 let preds = self.0.predict(x)?;
570 Ok(preds.mapv(|v| F::from_usize(v).unwrap_or(F::nan())))
571 }
572}
573
/// Extra-Trees (extremely randomized trees) regressor configuration.
///
/// Holds hyper-parameters only; training happens through the [`Fit`] impl,
/// which produces a [`FittedExtraTreesRegressor`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtraTreesRegressor<F> {
    /// Number of trees in the ensemble.
    pub n_estimators: usize,
    /// Maximum tree depth; `None` means no depth limit.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required at a leaf node.
    pub min_samples_leaf: usize,
    /// Strategy for how many candidate features to consider per split.
    pub max_features: MaxFeatures,
    /// When `true`, each tree trains on rows sampled with replacement.
    pub bootstrap: bool,
    /// Seed for deterministic tree construction; `None` uses a thread RNG.
    pub random_state: Option<u64>,
    /// Thread count for parallel tree building; `None` uses rayon's default.
    pub n_jobs: Option<usize>,
    // Ties the otherwise-unused sample dtype `F` to the struct.
    _marker: std::marker::PhantomData<F>,
}
610
impl<F: Float> ExtraTreesRegressor<F> {
    /// Creates a regressor with the default hyper-parameters:
    /// 100 trees, unlimited depth, `min_samples_split = 2`,
    /// `min_samples_leaf = 1`, `MaxFeatures::All`, no bootstrap,
    /// no fixed seed, default parallelism.
    #[must_use]
    pub fn new() -> Self {
        Self {
            n_estimators: 100,
            max_depth: None,
            min_samples_split: 2,
            min_samples_leaf: 1,
            max_features: MaxFeatures::All,
            bootstrap: false,
            random_state: None,
            n_jobs: None,
            _marker: std::marker::PhantomData,
        }
    }

    /// Sets the number of trees in the ensemble.
    #[must_use]
    pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
        self.n_estimators = n_estimators;
        self
    }

    /// Sets the maximum tree depth (`None` = unlimited).
    #[must_use]
    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
        self.max_depth = max_depth;
        self
    }

    /// Sets the minimum number of samples required to split a node.
    #[must_use]
    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.min_samples_split = min_samples_split;
        self
    }

    /// Sets the minimum number of samples required at a leaf.
    #[must_use]
    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.min_samples_leaf = min_samples_leaf;
        self
    }

    /// Sets the per-split candidate-feature strategy.
    #[must_use]
    pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
        self.max_features = max_features;
        self
    }

    /// Enables or disables bootstrap row sampling per tree.
    #[must_use]
    pub fn with_bootstrap(mut self, bootstrap: bool) -> Self {
        self.bootstrap = bootstrap;
        self
    }

    /// Sets a fixed seed for deterministic fitting.
    #[must_use]
    pub fn with_random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Sets the number of worker threads used during fitting.
    #[must_use]
    pub fn with_n_jobs(mut self, n_jobs: usize) -> Self {
        self.n_jobs = Some(n_jobs);
        self
    }
}
689
690impl<F: Float> Default for ExtraTreesRegressor<F> {
691 fn default() -> Self {
692 Self::new()
693 }
694}
695
/// Trained extra-trees regressor produced by [`ExtraTreesRegressor`]'s fit.
#[derive(Debug, Clone)]
pub struct FittedExtraTreesRegressor<F> {
    // One flat node arena per tree.
    trees: Vec<Vec<Node<F>>>,
    // Number of columns the model was trained on.
    n_features: usize,
    // Importances summed over trees, normalized to sum to 1 when nonzero.
    feature_importances: Array1<F>,
}
713
714impl<F: Float + Send + Sync + 'static> FittedExtraTreesRegressor<F> {
715 #[must_use]
717 pub fn trees(&self) -> &[Vec<Node<F>>] {
718 &self.trees
719 }
720
721 #[must_use]
723 pub fn n_features(&self) -> usize {
724 self.n_features
725 }
726
727 #[must_use]
729 pub fn n_estimators(&self) -> usize {
730 self.trees.len()
731 }
732}
733
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for ExtraTreesRegressor<F> {
    type Fitted = FittedExtraTreesRegressor<F>;
    type Error = FerroError;

    /// Fits the ensemble: validates inputs, derives one RNG seed per tree,
    /// builds the trees (optionally on a dedicated rayon pool), then
    /// aggregates and normalizes feature importances.
    ///
    /// # Errors
    /// * `ShapeMismatch` — `y.len()` differs from the number of rows of `x`.
    /// * `InsufficientSamples` — `x` has no rows.
    /// * `InvalidParameter` — `n_estimators` is zero.
    fn fit(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<FittedExtraTreesRegressor<F>, FerroError> {
        let (n_samples, n_features) = x.dim();

        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        if n_samples == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "ExtraTreesRegressor requires at least one sample".into(),
            });
        }
        if self.n_estimators == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_estimators".into(),
                reason: "must be at least 1".into(),
            });
        }

        let max_features_n = resolve_max_features(self.max_features, n_features);
        let params = make_tree_params(
            self.max_depth,
            self.min_samples_split,
            self.min_samples_leaf,
        );
        let bootstrap = self.bootstrap;

        // One independent seed per tree so parallel construction stays
        // deterministic under `random_state`.
        let tree_seeds: Vec<u64> = if let Some(seed) = self.random_state {
            let mut master_rng = StdRng::seed_from_u64(seed);
            (0..self.n_estimators)
                .map(|_| {
                    use rand::RngCore;
                    master_rng.next_u64()
                })
                .collect()
        } else {
            (0..self.n_estimators)
                .map(|_| {
                    use rand::RngCore;
                    rand::rng().next_u64()
                })
                .collect()
        };

        let trees: Vec<Vec<Node<F>>> = if let Some(n_jobs) = self.n_jobs {
            // Dedicated pool caps parallelism at `n_jobs`; fall back to a
            // default-configured pool if the builder fails.
            let pool = rayon::ThreadPoolBuilder::new()
                .num_threads(n_jobs)
                .build()
                .unwrap_or_else(|_| rayon::ThreadPoolBuilder::new().build().unwrap());
            pool.install(|| {
                tree_seeds
                    .par_iter()
                    .map(|&seed| {
                        build_single_regression_tree(
                            x,
                            y,
                            n_samples,
                            n_features,
                            max_features_n,
                            &params,
                            bootstrap,
                            seed,
                        )
                    })
                    .collect()
            })
        } else {
            tree_seeds
                .par_iter()
                .map(|&seed| {
                    build_single_regression_tree(
                        x,
                        y,
                        n_samples,
                        n_features,
                        max_features_n,
                        &params,
                        bootstrap,
                        seed,
                    )
                })
                .collect()
        };

        // Sum per-tree importances, then normalize to sum to 1
        // (left untouched when the sum is zero).
        let mut total_importances = Array1::<F>::zeros(n_features);
        for tree_nodes in &trees {
            let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
            total_importances = total_importances + tree_imp;
        }
        let imp_sum: F = total_importances
            .iter()
            .copied()
            .fold(F::zero(), |a, b| a + b);
        if imp_sum > F::zero() {
            total_importances.mapv_inplace(|v| v / imp_sum);
        }

        Ok(FittedExtraTreesRegressor {
            trees,
            n_features,
            feature_importances: total_importances,
        })
    }
}
862
863#[allow(clippy::too_many_arguments)]
865fn build_single_regression_tree<F: Float>(
866 x: &Array2<F>,
867 y: &Array1<F>,
868 n_samples: usize,
869 n_features: usize,
870 max_features_n: usize,
871 params: &TreeParams,
872 bootstrap: bool,
873 seed: u64,
874) -> Vec<Node<F>> {
875 let mut rng = StdRng::seed_from_u64(seed);
876
877 let indices: Vec<usize> = if bootstrap {
878 use rand::RngCore;
879 (0..n_samples)
880 .map(|_| (rng.next_u64() as usize) % n_samples)
881 .collect()
882 } else {
883 (0..n_samples).collect()
884 };
885
886 build_extra_regression_tree_for_ensemble(
887 x,
888 y,
889 &indices,
890 None, params,
892 n_features,
893 max_features_n,
894 &mut rng,
895 )
896}
897
898impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedExtraTreesRegressor<F> {
899 type Output = Array1<F>;
900 type Error = FerroError;
901
902 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
909 if x.ncols() != self.n_features {
910 return Err(FerroError::ShapeMismatch {
911 expected: vec![self.n_features],
912 actual: vec![x.ncols()],
913 context: "number of features must match fitted model".into(),
914 });
915 }
916
917 let n_samples = x.nrows();
918 let n_trees_f = F::from(self.trees.len()).unwrap();
919 let mut predictions = Array1::zeros(n_samples);
920
921 for i in 0..n_samples {
922 let row = x.row(i);
923 let mut sum = F::zero();
924
925 for tree_nodes in &self.trees {
926 let leaf_idx = traverse(tree_nodes, &row);
927 if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
928 sum = sum + value;
929 }
930 }
931
932 predictions[i] = sum / n_trees_f;
933 }
934
935 Ok(predictions)
936 }
937}
938
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedExtraTreesRegressor<F> {
    /// Normalized feature importances accumulated over all trees during fit.
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}
944
945impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for ExtraTreesRegressor<F> {
947 fn fit_pipeline(
948 &self,
949 x: &Array2<F>,
950 y: &Array1<F>,
951 ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
952 let fitted = self.fit(x, y)?;
953 Ok(Box::new(fitted))
954 }
955}
956
impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F> for FittedExtraTreesRegressor<F> {
    /// Pipeline predictions delegate directly to [`Predict::predict`].
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        self.predict(x)
    }
}
962
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // ---------- classifier ----------

    #[test]
    fn test_ensemble_classifier_simple() {
        // Two linearly separable clusters in 2-D; the ensemble should
        // reproduce the training labels exactly.
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds, y);
    }

    #[test]
    fn test_ensemble_classifier_no_bootstrap() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        // Bootstrap defaults to off for extra-trees.
        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        assert!(!model.bootstrap);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds, y);
    }

    #[test]
    fn test_ensemble_classifier_with_bootstrap() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        // With bootstrap only the output shape is asserted (subsampling may
        // change individual predictions).
        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_bootstrap(true)
            .with_random_state(42);
        assert!(model.bootstrap);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 8);
    }

    #[test]
    fn test_ensemble_classifier_predict_proba() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();

        // Each row of the probability matrix must sum to 1.
        assert_eq!(proba.dim(), (6, 2));
        for i in 0..6 {
            let row_sum = proba.row(i).sum();
            assert_relative_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_ensemble_classifier_feature_importances() {
        // Feature 1 is constant, so all importance should go to feature 0.
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0, 1.0, 5.0, 1.0, 6.0, 1.0, 7.0, 1.0, 8.0, 1.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        let total: f64 = importances.sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-10);
        assert!(importances[0] > importances[1]);
    }

    #[test]
    fn test_ensemble_classifier_n_estimators() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(15)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 15);
    }

    #[test]
    fn test_ensemble_classifier_classes() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        // Non-contiguous labels must be preserved as-is (not remapped).
        let y = array![0, 0, 0, 3, 3, 3];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0, 3]);
        assert_eq!(fitted.n_classes(), 2);
    }

    #[test]
    fn test_ensemble_classifier_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0];
        let model = ExtraTreesClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_ensemble_classifier_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);
        let model = ExtraTreesClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_ensemble_classifier_zero_estimators() {
        let x = Array2::from_shape_vec((2, 1), vec![1.0, 2.0]).unwrap();
        let y = array![0, 1];
        let model = ExtraTreesClassifier::<f64>::new().with_n_estimators(0);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_ensemble_classifier_deterministic() {
        // Two fits with the same seed must produce identical predictions.
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model1 = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(123);
        let model2 = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(123);

        let preds1 = model1.fit(&x, &y).unwrap().predict(&x).unwrap();
        let preds2 = model2.fit(&x, &y).unwrap().predict(&x).unwrap();
        assert_eq!(preds1, preds2);
    }

    #[test]
    fn test_ensemble_classifier_predict_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        // Predicting with the wrong column count must fail.
        let x_wrong = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_wrong).is_err());
    }

    // ---------- regressor ----------

    #[test]
    fn test_ensemble_regressor_simple() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(20)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 6);
        // Loose tolerance: randomized splits only approximate the target.
        for i in 0..6 {
            assert_relative_eq!(preds[i], y[i], epsilon = 1.0);
        }
    }

    #[test]
    fn test_ensemble_regressor_constant_target() {
        // A constant target must be reproduced exactly by every leaf.
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![5.0, 5.0, 5.0, 5.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for &p in preds.iter() {
            assert_relative_eq!(p, 5.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_ensemble_regressor_no_bootstrap() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        assert!(!model.bootstrap);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_ensemble_regressor_with_bootstrap() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_bootstrap(true)
            .with_random_state(42);
        assert!(model.bootstrap);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_ensemble_regressor_feature_importances() {
        // Feature 1 is constant, so all importance should go to feature 0.
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(20)
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        let total: f64 = importances.sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-10);
        assert!(importances[0] > importances[1]);
    }

    #[test]
    fn test_ensemble_regressor_n_estimators() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(7)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 7);
    }

    #[test]
    fn test_ensemble_regressor_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0];
        let model = ExtraTreesRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_ensemble_regressor_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);
        let model = ExtraTreesRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_ensemble_regressor_zero_estimators() {
        let x = Array2::from_shape_vec((2, 1), vec![1.0, 2.0]).unwrap();
        let y = array![1.0, 2.0];
        let model = ExtraTreesRegressor::<f64>::new().with_n_estimators(0);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_ensemble_regressor_deterministic() {
        // Two fits with the same seed must produce identical predictions.
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let model1 = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(99);
        let model2 = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(99);

        let preds1 = model1.fit(&x, &y).unwrap().predict(&x).unwrap();
        let preds2 = model2.fit(&x, &y).unwrap().predict(&x).unwrap();

        for i in 0..6 {
            assert_relative_eq!(preds1[i], preds2[i], epsilon = 1e-12);
        }
    }

    #[test]
    fn test_ensemble_regressor_predict_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        let x_wrong = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_wrong).is_err());
    }

    // ---------- builder / default configurations ----------

    #[test]
    fn test_ensemble_classifier_builder() {
        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(50)
            .with_max_depth(Some(5))
            .with_min_samples_split(10)
            .with_min_samples_leaf(3)
            .with_max_features(MaxFeatures::Log2)
            .with_bootstrap(true)
            .with_criterion(ClassificationCriterion::Entropy)
            .with_random_state(42)
            .with_n_jobs(4);

        assert_eq!(model.n_estimators, 50);
        assert_eq!(model.max_depth, Some(5));
        assert_eq!(model.min_samples_split, 10);
        assert_eq!(model.min_samples_leaf, 3);
        assert_eq!(model.max_features, MaxFeatures::Log2);
        assert!(model.bootstrap);
        assert_eq!(model.criterion, ClassificationCriterion::Entropy);
        assert_eq!(model.random_state, Some(42));
        assert_eq!(model.n_jobs, Some(4));
    }

    #[test]
    fn test_ensemble_regressor_builder() {
        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(25)
            .with_max_depth(Some(8))
            .with_min_samples_split(5)
            .with_min_samples_leaf(2)
            .with_max_features(MaxFeatures::Fraction(0.5))
            .with_bootstrap(true)
            .with_random_state(99)
            .with_n_jobs(2);

        assert_eq!(model.n_estimators, 25);
        assert_eq!(model.max_depth, Some(8));
        assert_eq!(model.min_samples_split, 5);
        assert_eq!(model.min_samples_leaf, 2);
        assert_eq!(model.max_features, MaxFeatures::Fraction(0.5));
        assert!(model.bootstrap);
        assert_eq!(model.random_state, Some(99));
        assert_eq!(model.n_jobs, Some(2));
    }

    #[test]
    fn test_ensemble_classifier_default() {
        let model = ExtraTreesClassifier::<f64>::default();
        assert_eq!(model.n_estimators, 100);
        assert_eq!(model.max_depth, None);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
        assert_eq!(model.max_features, MaxFeatures::Sqrt);
        assert!(!model.bootstrap);
        assert_eq!(model.criterion, ClassificationCriterion::Gini);
        assert_eq!(model.random_state, None);
        assert_eq!(model.n_jobs, None);
    }

    #[test]
    fn test_ensemble_regressor_default() {
        let model = ExtraTreesRegressor::<f64>::default();
        assert_eq!(model.n_estimators, 100);
        assert_eq!(model.max_depth, None);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
        assert_eq!(model.max_features, MaxFeatures::All);
        assert!(!model.bootstrap);
        assert_eq!(model.random_state, None);
        assert_eq!(model.n_jobs, None);
    }
}