1use ferrolearn_core::error::FerroError;
29use ferrolearn_core::introspection::{HasClasses, HasFeatureImportances};
30use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
31use ferrolearn_core::traits::{Fit, Predict};
32use ndarray::{Array1, Array2};
33use num_traits::{Float, FromPrimitive, ToPrimitive};
34use rand::SeedableRng;
35use rand::rngs::StdRng;
36use rand::seq::index::sample as rand_sample_indices;
37use serde::{Deserialize, Serialize};
38
39use crate::decision_tree::{
40 ClassificationCriterion, Node, RegressionCriterion, TreeParams, compute_feature_importances,
41 traverse,
42};
43use crate::random_forest::MaxFeatures;
44
/// Borrowed view of the training data used while growing a classification tree.
struct ClassificationData<'a, F> {
    /// Feature matrix (samples x features).
    x: &'a Array2<F>,
    /// Class labels remapped to contiguous internal indices `0..n_classes`.
    y: &'a [usize],
    /// Number of distinct classes present in `y`.
    n_classes: usize,
    /// Optional restriction of candidate feature columns (used by ensembles);
    /// `None` means every feature may be considered.
    feature_indices: Option<&'a [usize]>,
    /// Impurity criterion used to score candidate splits.
    criterion: ClassificationCriterion,
}
57
/// Borrowed view of the training data used while growing a regression tree.
struct RegressionData<'a, F> {
    /// Feature matrix (samples x features).
    x: &'a Array2<F>,
    /// Continuous targets, one per sample.
    y: &'a Array1<F>,
    /// Optional restriction of candidate feature columns (used by ensembles);
    /// `None` means every feature may be considered.
    feature_indices: Option<&'a [usize]>,
}
64
/// Extremely randomized decision-tree classifier (Extra-Tree).
///
/// Split thresholds are drawn uniformly at random from each candidate
/// feature's value range instead of being exhaustively optimized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtraTreeClassifier<F> {
    /// Maximum tree depth; `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to attempt a split (must be >= 2).
    pub min_samples_split: usize,
    /// Minimum number of samples each child must retain (must be >= 1).
    pub min_samples_leaf: usize,
    /// Strategy deciding how many features are candidates at each split.
    pub max_features: MaxFeatures,
    /// Impurity criterion used to score candidate splits.
    pub criterion: ClassificationCriterion,
    /// RNG seed for reproducible trees; `None` seeds from the OS.
    pub random_state: Option<u64>,
    /// Ties the float type parameter `F` to the struct without storing data.
    _marker: std::marker::PhantomData<F>,
}
95
96impl<F: Float> ExtraTreeClassifier<F> {
97 #[must_use]
103 pub fn new() -> Self {
104 Self {
105 max_depth: None,
106 min_samples_split: 2,
107 min_samples_leaf: 1,
108 max_features: MaxFeatures::Sqrt,
109 criterion: ClassificationCriterion::Gini,
110 random_state: None,
111 _marker: std::marker::PhantomData,
112 }
113 }
114
115 #[must_use]
117 pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
118 self.max_depth = max_depth;
119 self
120 }
121
122 #[must_use]
124 pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
125 self.min_samples_split = min_samples_split;
126 self
127 }
128
129 #[must_use]
131 pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
132 self.min_samples_leaf = min_samples_leaf;
133 self
134 }
135
136 #[must_use]
138 pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
139 self.max_features = max_features;
140 self
141 }
142
143 #[must_use]
145 pub fn with_criterion(mut self, criterion: ClassificationCriterion) -> Self {
146 self.criterion = criterion;
147 self
148 }
149
150 #[must_use]
152 pub fn with_random_state(mut self, seed: u64) -> Self {
153 self.random_state = Some(seed);
154 self
155 }
156}
157
impl<F: Float> Default for ExtraTreeClassifier<F> {
    /// Equivalent to [`ExtraTreeClassifier::new`].
    fn default() -> Self {
        Self::new()
    }
}
163
/// A fitted [`ExtraTreeClassifier`] ready for prediction.
#[derive(Debug, Clone)]
pub struct FittedExtraTreeClassifier<F> {
    /// Flattened tree; node 0 is the root, children referenced by index.
    nodes: Vec<Node<F>>,
    /// Original class labels, sorted ascending and deduplicated; leaf
    /// values store positions into this vector (internal class indices).
    classes: Vec<usize>,
    /// Number of features seen during fitting.
    n_features: usize,
    /// Impurity-based feature importances computed at fit time.
    feature_importances: Array1<F>,
}
184
185impl<F: Float + Send + Sync + 'static> FittedExtraTreeClassifier<F> {
186 #[must_use]
188 pub fn nodes(&self) -> &[Node<F>] {
189 &self.nodes
190 }
191
192 #[must_use]
194 pub fn n_features(&self) -> usize {
195 self.n_features
196 }
197
198 pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
207 if x.ncols() != self.n_features {
208 return Err(FerroError::ShapeMismatch {
209 expected: vec![self.n_features],
210 actual: vec![x.ncols()],
211 context: "number of features must match fitted model".into(),
212 });
213 }
214 let n_samples = x.nrows();
215 let n_classes = self.classes.len();
216 let mut proba = Array2::zeros((n_samples, n_classes));
217 for i in 0..n_samples {
218 let row = x.row(i);
219 let leaf = traverse(&self.nodes, &row);
220 if let Node::Leaf {
221 class_distribution: Some(ref dist),
222 ..
223 } = self.nodes[leaf]
224 {
225 for (j, &p) in dist.iter().enumerate() {
226 proba[[i, j]] = p;
227 }
228 }
229 }
230 Ok(proba)
231 }
232}
233
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for ExtraTreeClassifier<F> {
    type Fitted = FittedExtraTreeClassifier<F>;
    type Error = FerroError;

    /// Fits a single extremely randomized classification tree.
    ///
    /// Labels may be arbitrary `usize` values; they are sorted, deduplicated
    /// and remapped to contiguous internal indices `0..n_classes` before the
    /// tree is grown.
    ///
    /// # Errors
    /// - [`FerroError::ShapeMismatch`] if `y.len() != x.nrows()`.
    /// - [`FerroError::InsufficientSamples`] if `x` has no rows.
    /// - [`FerroError::InvalidParameter`] if `min_samples_split < 2` or
    ///   `min_samples_leaf < 1`.
    fn fit(
        &self,
        x: &Array2<F>,
        y: &Array1<usize>,
    ) -> Result<FittedExtraTreeClassifier<F>, FerroError> {
        let (n_samples, n_features) = x.dim();

        // --- input validation ---
        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        if n_samples == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "ExtraTreeClassifier requires at least one sample".into(),
            });
        }
        if self.min_samples_split < 2 {
            return Err(FerroError::InvalidParameter {
                name: "min_samples_split".into(),
                reason: "must be at least 2".into(),
            });
        }
        if self.min_samples_leaf < 1 {
            return Err(FerroError::InvalidParameter {
                name: "min_samples_leaf".into(),
                reason: "must be at least 1".into(),
            });
        }

        // Sorted, deduplicated label set; leaf values refer to positions in
        // this vector rather than to the raw labels.
        let mut classes: Vec<usize> = y.iter().copied().collect();
        classes.sort_unstable();
        classes.dedup();
        let n_classes = classes.len();

        // Remap raw labels to contiguous internal indices 0..n_classes.
        let y_mapped: Vec<usize> = y
            .iter()
            .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
            .collect();

        // The root node sees every sample.
        let indices: Vec<usize> = (0..n_samples).collect();

        let max_features_n = resolve_max_features(self.max_features, n_features);

        // Seeded RNG for reproducibility; otherwise seeded from the OS.
        let mut rng = if let Some(seed) = self.random_state {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_os_rng()
        };

        let data = ClassificationData {
            x,
            y: &y_mapped,
            n_classes,
            feature_indices: None,
            criterion: self.criterion,
        };
        let params = TreeParams {
            max_depth: self.max_depth,
            min_samples_split: self.min_samples_split,
            min_samples_leaf: self.min_samples_leaf,
        };

        let mut nodes: Vec<Node<F>> = Vec::new();
        build_extra_classification_tree(
            &data,
            &indices,
            &mut nodes,
            0,
            &params,
            n_features,
            max_features_n,
            &mut rng,
        );

        let feature_importances = compute_feature_importances(&nodes, n_features, n_samples);

        Ok(FittedExtraTreeClassifier {
            nodes,
            classes,
            n_features,
            feature_importances,
        })
    }
}
337
338impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedExtraTreeClassifier<F> {
339 type Output = Array1<usize>;
340 type Error = FerroError;
341
342 fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
349 if x.ncols() != self.n_features {
350 return Err(FerroError::ShapeMismatch {
351 expected: vec![self.n_features],
352 actual: vec![x.ncols()],
353 context: "number of features must match fitted model".into(),
354 });
355 }
356 let n_samples = x.nrows();
357 let mut predictions = Array1::zeros(n_samples);
358 for i in 0..n_samples {
359 let row = x.row(i);
360 let leaf = traverse(&self.nodes, &row);
361 if let Node::Leaf { value, .. } = self.nodes[leaf] {
362 predictions[i] = float_to_usize(value);
363 }
364 }
365 Ok(predictions)
366 }
367}
368
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedExtraTreeClassifier<F> {
    /// Returns the impurity-based feature importances computed at fit time.
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}
374
impl<F: Float + Send + Sync + 'static> HasClasses for FittedExtraTreeClassifier<F> {
    /// Returns the original class labels, sorted ascending.
    fn classes(&self) -> &[usize] {
        &self.classes
    }

    /// Number of distinct classes seen during fitting.
    fn n_classes(&self) -> usize {
        self.classes.len()
    }
}
384
impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
    for ExtraTreeClassifier<F>
{
    /// Fits the classifier inside a pipeline whose targets are floats.
    ///
    /// NOTE(review): targets are converted with `to_usize().unwrap_or(0)`, so
    /// negative or non-finite values silently become class 0 — confirm this
    /// is the intended pipeline contract.
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
        let fitted = self.fit(x, &y_usize)?;
        Ok(Box::new(FittedExtraTreeClassifierPipelineAdapter(fitted)))
    }
}
399
/// Newtype adapter exposing a [`FittedExtraTreeClassifier`] through the
/// float-based [`FittedPipelineEstimator`] interface.
struct FittedExtraTreeClassifierPipelineAdapter<F: Float + Send + Sync + 'static>(
    FittedExtraTreeClassifier<F>,
);
404
impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedExtraTreeClassifierPipelineAdapter<F>
{
    /// Predicts class labels and converts them to `F`; labels that cannot be
    /// represented in `F` become NaN.
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        let preds = self.0.predict(x)?;
        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or(F::nan())))
    }
}
413
/// Extremely randomized decision-tree regressor (Extra-Tree).
///
/// Split thresholds are drawn uniformly at random from each candidate
/// feature's value range instead of being exhaustively optimized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtraTreeRegressor<F> {
    /// Maximum tree depth; `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to attempt a split (must be >= 2).
    pub min_samples_split: usize,
    /// Minimum number of samples each child must retain (must be >= 1).
    pub min_samples_leaf: usize,
    /// Strategy deciding how many features are candidates at each split.
    pub max_features: MaxFeatures,
    /// Split-quality criterion for regression.
    pub criterion: RegressionCriterion,
    /// RNG seed for reproducible trees; `None` seeds from the OS.
    pub random_state: Option<u64>,
    /// Ties the float type parameter `F` to the struct without storing data.
    _marker: std::marker::PhantomData<F>,
}
444
445impl<F: Float> ExtraTreeRegressor<F> {
446 #[must_use]
452 pub fn new() -> Self {
453 Self {
454 max_depth: None,
455 min_samples_split: 2,
456 min_samples_leaf: 1,
457 max_features: MaxFeatures::All,
458 criterion: RegressionCriterion::Mse,
459 random_state: None,
460 _marker: std::marker::PhantomData,
461 }
462 }
463
464 #[must_use]
466 pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
467 self.max_depth = max_depth;
468 self
469 }
470
471 #[must_use]
473 pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
474 self.min_samples_split = min_samples_split;
475 self
476 }
477
478 #[must_use]
480 pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
481 self.min_samples_leaf = min_samples_leaf;
482 self
483 }
484
485 #[must_use]
487 pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
488 self.max_features = max_features;
489 self
490 }
491
492 #[must_use]
494 pub fn with_criterion(mut self, criterion: RegressionCriterion) -> Self {
495 self.criterion = criterion;
496 self
497 }
498
499 #[must_use]
501 pub fn with_random_state(mut self, seed: u64) -> Self {
502 self.random_state = Some(seed);
503 self
504 }
505}
506
impl<F: Float> Default for ExtraTreeRegressor<F> {
    /// Equivalent to [`ExtraTreeRegressor::new`].
    fn default() -> Self {
        Self::new()
    }
}
512
/// A fitted [`ExtraTreeRegressor`] ready for prediction.
#[derive(Debug, Clone)]
pub struct FittedExtraTreeRegressor<F> {
    /// Flattened tree; node 0 is the root, children referenced by index.
    nodes: Vec<Node<F>>,
    /// Number of features seen during fitting.
    n_features: usize,
    /// Impurity-based feature importances computed at fit time.
    feature_importances: Array1<F>,
}
529
impl<F: Float + Send + Sync + 'static> FittedExtraTreeRegressor<F> {
    /// Returns the flattened tree nodes (root at index 0).
    #[must_use]
    pub fn nodes(&self) -> &[Node<F>] {
        &self.nodes
    }

    /// Number of features the model was fitted on.
    #[must_use]
    pub fn n_features(&self) -> usize {
        self.n_features
    }
}
543
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for ExtraTreeRegressor<F> {
    type Fitted = FittedExtraTreeRegressor<F>;
    type Error = FerroError;

    /// Fits a single extremely randomized regression tree.
    ///
    /// # Errors
    /// - [`FerroError::ShapeMismatch`] if `y.len() != x.nrows()`.
    /// - [`FerroError::InsufficientSamples`] if `x` has no rows.
    /// - [`FerroError::InvalidParameter`] if `min_samples_split < 2` or
    ///   `min_samples_leaf < 1`.
    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedExtraTreeRegressor<F>, FerroError> {
        let (n_samples, n_features) = x.dim();

        // --- input validation ---
        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        if n_samples == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "ExtraTreeRegressor requires at least one sample".into(),
            });
        }
        if self.min_samples_split < 2 {
            return Err(FerroError::InvalidParameter {
                name: "min_samples_split".into(),
                reason: "must be at least 2".into(),
            });
        }
        if self.min_samples_leaf < 1 {
            return Err(FerroError::InvalidParameter {
                name: "min_samples_leaf".into(),
                reason: "must be at least 1".into(),
            });
        }

        // The root node sees every sample.
        let indices: Vec<usize> = (0..n_samples).collect();
        let max_features_n = resolve_max_features(self.max_features, n_features);

        // Seeded RNG for reproducibility; otherwise seeded from the OS.
        let mut rng = if let Some(seed) = self.random_state {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_os_rng()
        };

        let data = RegressionData {
            x,
            y,
            feature_indices: None,
        };
        let params = TreeParams {
            max_depth: self.max_depth,
            min_samples_split: self.min_samples_split,
            min_samples_leaf: self.min_samples_leaf,
        };

        let mut nodes: Vec<Node<F>> = Vec::new();
        build_extra_regression_tree(
            &data,
            &indices,
            &mut nodes,
            0,
            &params,
            n_features,
            max_features_n,
            &mut rng,
        );

        let feature_importances = compute_feature_importances(&nodes, n_features, n_samples);

        Ok(FittedExtraTreeRegressor {
            nodes,
            n_features,
            feature_importances,
        })
    }
}
627
628impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedExtraTreeRegressor<F> {
629 type Output = Array1<F>;
630 type Error = FerroError;
631
632 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
639 if x.ncols() != self.n_features {
640 return Err(FerroError::ShapeMismatch {
641 expected: vec![self.n_features],
642 actual: vec![x.ncols()],
643 context: "number of features must match fitted model".into(),
644 });
645 }
646 let n_samples = x.nrows();
647 let mut predictions = Array1::zeros(n_samples);
648 for i in 0..n_samples {
649 let row = x.row(i);
650 let leaf = traverse(&self.nodes, &row);
651 if let Node::Leaf { value, .. } = self.nodes[leaf] {
652 predictions[i] = value;
653 }
654 }
655 Ok(predictions)
656 }
657}
658
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedExtraTreeRegressor<F> {
    /// Returns the impurity-based feature importances computed at fit time.
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}
664
impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for ExtraTreeRegressor<F> {
    /// Fits the regressor inside a pipeline; the fitted model implements
    /// [`FittedPipelineEstimator`] directly, so no adapter is needed.
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let fitted = self.fit(x, y)?;
        Ok(Box::new(fitted))
    }
}
676
impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F> for FittedExtraTreeRegressor<F> {
    /// Delegates straight to [`Predict::predict`]; outputs are already `F`.
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        self.predict(x)
    }
}
682
683fn resolve_max_features(strategy: MaxFeatures, n_features: usize) -> usize {
689 let result = match strategy {
690 MaxFeatures::Sqrt => (n_features as f64).sqrt().ceil() as usize,
691 MaxFeatures::Log2 => (n_features as f64).log2().ceil().max(1.0) as usize,
692 MaxFeatures::All => n_features,
693 MaxFeatures::Fixed(n) => n.min(n_features),
694 MaxFeatures::Fraction(f) => ((n_features as f64) * f).ceil() as usize,
695 };
696 result.max(1).min(n_features)
697}
698
699fn float_to_usize<F: Float>(v: F) -> usize {
701 v.to_f64().map(|f| f.round() as usize).unwrap_or(0)
702}
703
704fn random_threshold<F: Float>(rng: &mut StdRng, min_val: F, max_val: F) -> F {
706 use rand::RngCore;
707 let u = (rng.next_u64() as f64) / (u64::MAX as f64);
709 let range = max_val - min_val;
710 min_val + F::from(u).unwrap() * range
711}
712
713fn gini_impurity<F: Float>(class_counts: &[usize], total: usize) -> F {
715 if total == 0 {
716 return F::zero();
717 }
718 let total_f = F::from(total).unwrap();
719 let mut impurity = F::one();
720 for &count in class_counts {
721 let p = F::from(count).unwrap() / total_f;
722 impurity = impurity - p * p;
723 }
724 impurity
725}
726
727fn entropy_impurity<F: Float>(class_counts: &[usize], total: usize) -> F {
729 if total == 0 {
730 return F::zero();
731 }
732 let total_f = F::from(total).unwrap();
733 let mut ent = F::zero();
734 for &count in class_counts {
735 if count > 0 {
736 let p = F::from(count).unwrap() / total_f;
737 ent = ent - p * p.ln();
738 }
739 }
740 ent
741}
742
/// Dispatches to the impurity function selected by `criterion`.
fn compute_impurity<F: Float>(
    class_counts: &[usize],
    total: usize,
    criterion: ClassificationCriterion,
) -> F {
    match criterion {
        ClassificationCriterion::Gini => gini_impurity(class_counts, total),
        ClassificationCriterion::Entropy => entropy_impurity(class_counts, total),
    }
}
754
755fn make_classification_leaf<F: Float>(
757 nodes: &mut Vec<Node<F>>,
758 class_counts: &[usize],
759 n_classes: usize,
760 n_samples: usize,
761) -> usize {
762 let majority_class = class_counts
763 .iter()
764 .enumerate()
765 .max_by_key(|&(_, &count)| count)
766 .map(|(i, _)| i)
767 .unwrap_or(0);
768
769 let total_f = if n_samples > 0 {
770 F::from(n_samples).unwrap()
771 } else {
772 F::one()
773 };
774 let distribution: Vec<F> = (0..n_classes)
775 .map(|c| F::from(class_counts[c]).unwrap() / total_f)
776 .collect();
777
778 let idx = nodes.len();
779 nodes.push(Node::Leaf {
780 value: F::from(majority_class).unwrap(),
781 class_distribution: Some(distribution),
782 n_samples,
783 });
784 idx
785}
786
787fn mean_value<F: Float>(y: &Array1<F>, indices: &[usize]) -> F {
789 if indices.is_empty() {
790 return F::zero();
791 }
792 let sum: F = indices.iter().map(|&i| y[i]).fold(F::zero(), |a, b| a + b);
793 sum / F::from(indices.len()).unwrap()
794}
795
/// Recursively grows an extremely randomized classification subtree over the
/// sample `indices`, appending its nodes to `nodes`.
///
/// Returns the index (within `nodes`) of the subtree root. A placeholder
/// leaf is pushed before recursing so the parent's slot index is known, then
/// overwritten with the real `Node::Split` once both children exist.
#[allow(clippy::too_many_arguments)]
fn build_extra_classification_tree<F: Float>(
    data: &ClassificationData<'_, F>,
    indices: &[usize],
    nodes: &mut Vec<Node<F>>,
    depth: usize,
    params: &TreeParams,
    n_features: usize,
    max_features_n: usize,
    rng: &mut StdRng,
) -> usize {
    let n = indices.len();

    // Class histogram for the samples reaching this node.
    let mut class_counts = vec![0usize; data.n_classes];
    for &i in indices {
        class_counts[data.y[i]] += 1;
    }

    // Stop when: too few samples to split, the depth limit is reached, or
    // the node is pure (at most one non-empty class).
    let should_stop = n < params.min_samples_split
        || params.max_depth.is_some_and(|d| depth >= d)
        || class_counts.iter().filter(|&&c| c > 0).count() <= 1;

    if should_stop {
        return make_classification_leaf(nodes, &class_counts, data.n_classes, n);
    }

    let best = find_random_classification_split(
        data,
        indices,
        params.min_samples_leaf,
        n_features,
        max_features_n,
        rng,
    );

    if let Some((best_feature, best_threshold, best_impurity_decrease)) = best {
        // Samples with feature value <= threshold go left, the rest right.
        let (left_indices, right_indices): (Vec<usize>, Vec<usize>) = indices
            .iter()
            .partition(|&&i| data.x[[i, best_feature]] <= best_threshold);

        // The random threshold can still produce an undersized child; fall
        // back to a leaf rather than violating `min_samples_leaf`.
        if left_indices.len() < params.min_samples_leaf
            || right_indices.len() < params.min_samples_leaf
        {
            return make_classification_leaf(nodes, &class_counts, data.n_classes, n);
        }

        // Reserve this node's slot with a placeholder leaf, then recurse.
        let node_idx = nodes.len();
        nodes.push(Node::Leaf {
            value: F::zero(),
            class_distribution: None,
            n_samples: 0,
        });
        let left_idx = build_extra_classification_tree(
            data,
            &left_indices,
            nodes,
            depth + 1,
            params,
            n_features,
            max_features_n,
            rng,
        );
        let right_idx = build_extra_classification_tree(
            data,
            &right_indices,
            nodes,
            depth + 1,
            params,
            n_features,
            max_features_n,
            rng,
        );

        // Patch the placeholder with the real split node.
        nodes[node_idx] = Node::Split {
            feature: best_feature,
            threshold: best_threshold,
            left: left_idx,
            right: right_idx,
            impurity_decrease: best_impurity_decrease,
            n_samples: n,
        };

        node_idx
    } else {
        // No impurity-reducing split was found: emit a leaf.
        make_classification_leaf(nodes, &class_counts, data.n_classes, n)
    }
}
894
/// Evaluates one uniformly random threshold per sampled feature (the
/// Extra-Trees strategy — no exhaustive threshold search) and returns the
/// best candidate as `(feature, threshold, scaled_impurity_decrease)`.
///
/// The impurity decrease is multiplied by the node's sample count before
/// being returned so it can be accumulated into feature importances.
/// Returns `None` when no candidate yields a strictly positive decrease.
#[allow(clippy::too_many_arguments)]
fn find_random_classification_split<F: Float>(
    data: &ClassificationData<'_, F>,
    indices: &[usize],
    min_samples_leaf: usize,
    n_features: usize,
    max_features_n: usize,
    rng: &mut StdRng,
) -> Option<(usize, F, F)> {
    let n = indices.len();
    let n_f = F::from(n).unwrap();

    // Impurity of the node before splitting.
    let mut parent_counts = vec![0usize; data.n_classes];
    for &i in indices {
        parent_counts[data.y[i]] += 1;
    }
    let parent_impurity = compute_impurity::<F>(&parent_counts, n, data.criterion);

    let mut best_score = F::neg_infinity();
    let mut best_feature = 0;
    let mut best_threshold = F::zero();

    // Sample candidate features without replacement, restricted to
    // `feature_indices` when an ensemble limits the feature pool.
    let feature_subset: Vec<usize> = if let Some(feat_indices) = data.feature_indices {
        let k = max_features_n.min(feat_indices.len());
        rand_sample_indices(rng, feat_indices.len(), k)
            .into_vec()
            .into_iter()
            .map(|i| feat_indices[i])
            .collect()
    } else {
        let k = max_features_n.min(n_features);
        rand_sample_indices(rng, n_features, k).into_vec()
    };

    for feat in feature_subset {
        // Value range of this feature over the node's samples.
        let mut feat_min = F::infinity();
        let mut feat_max = F::neg_infinity();
        for &i in indices {
            let val = data.x[[i, feat]];
            if val < feat_min {
                feat_min = val;
            }
            if val > feat_max {
                feat_max = val;
            }
        }

        // Constant feature: no threshold can separate the samples.
        if feat_min >= feat_max {
            continue;
        }

        // One random threshold per feature.
        let threshold = random_threshold(rng, feat_min, feat_max);

        // Class histograms on each side of the threshold.
        let mut left_counts = vec![0usize; data.n_classes];
        let mut right_counts = vec![0usize; data.n_classes];
        let mut left_n = 0usize;

        for &i in indices {
            let cls = data.y[i];
            if data.x[[i, feat]] <= threshold {
                left_counts[cls] += 1;
                left_n += 1;
            } else {
                right_counts[cls] += 1;
            }
        }

        let right_n = n - left_n;
        if left_n < min_samples_leaf || right_n < min_samples_leaf {
            continue;
        }

        // Decrease = parent impurity minus size-weighted child impurity.
        let left_impurity = compute_impurity::<F>(&left_counts, left_n, data.criterion);
        let right_impurity = compute_impurity::<F>(&right_counts, right_n, data.criterion);
        let left_weight = F::from(left_n).unwrap() / n_f;
        let right_weight = F::from(right_n).unwrap() / n_f;
        let weighted_child_impurity = left_weight * left_impurity + right_weight * right_impurity;
        let impurity_decrease = parent_impurity - weighted_child_impurity;

        if impurity_decrease > best_score {
            best_score = impurity_decrease;
            best_feature = feat;
            best_threshold = threshold;
        }
    }

    // Only report splits that actually reduce impurity.
    if best_score > F::zero() {
        Some((best_feature, best_threshold, best_score * n_f))
    } else {
        None
    }
}
998
/// Recursively grows an extremely randomized regression subtree over the
/// sample `indices`, appending its nodes to `nodes`.
///
/// Leaf values store the mean target of the samples reaching the leaf.
/// Returns the index (within `nodes`) of the subtree root; a placeholder
/// leaf reserves the parent's slot before recursion, then is overwritten
/// with the real `Node::Split`.
#[allow(clippy::too_many_arguments)]
fn build_extra_regression_tree<F: Float>(
    data: &RegressionData<'_, F>,
    indices: &[usize],
    nodes: &mut Vec<Node<F>>,
    depth: usize,
    params: &TreeParams,
    n_features: usize,
    max_features_n: usize,
    rng: &mut StdRng,
) -> usize {
    let n = indices.len();
    let mean = mean_value(data.y, indices);

    // Stop on too few samples or when the depth limit is reached.
    let should_stop = n < params.min_samples_split || params.max_depth.is_some_and(|d| depth >= d);

    if should_stop {
        let idx = nodes.len();
        nodes.push(Node::Leaf {
            value: mean,
            class_distribution: None,
            n_samples: n,
        });
        return idx;
    }

    // Node MSE (variance of targets around the node mean).
    let parent_sum_sq: F = indices
        .iter()
        .map(|&i| {
            let diff = data.y[i] - mean;
            diff * diff
        })
        .fold(F::zero(), |a, b| a + b);
    let parent_mse = parent_sum_sq / F::from(n).unwrap();

    // Targets are (numerically) constant: nothing left to reduce.
    if parent_mse <= F::epsilon() {
        let idx = nodes.len();
        nodes.push(Node::Leaf {
            value: mean,
            class_distribution: None,
            n_samples: n,
        });
        return idx;
    }

    let best = find_random_regression_split(
        data,
        indices,
        params.min_samples_leaf,
        n_features,
        max_features_n,
        rng,
    );

    if let Some((best_feature, best_threshold, best_impurity_decrease)) = best {
        // Samples with feature value <= threshold go left, the rest right.
        let (left_indices, right_indices): (Vec<usize>, Vec<usize>) = indices
            .iter()
            .partition(|&&i| data.x[[i, best_feature]] <= best_threshold);

        // The random threshold can still produce an undersized child; fall
        // back to a leaf rather than violating `min_samples_leaf`.
        if left_indices.len() < params.min_samples_leaf
            || right_indices.len() < params.min_samples_leaf
        {
            let idx = nodes.len();
            nodes.push(Node::Leaf {
                value: mean,
                class_distribution: None,
                n_samples: n,
            });
            return idx;
        }

        // Reserve this node's slot with a placeholder leaf, then recurse.
        let node_idx = nodes.len();
        nodes.push(Node::Leaf {
            value: F::zero(),
            class_distribution: None,
            n_samples: 0,
        });
        let left_idx = build_extra_regression_tree(
            data,
            &left_indices,
            nodes,
            depth + 1,
            params,
            n_features,
            max_features_n,
            rng,
        );
        let right_idx = build_extra_regression_tree(
            data,
            &right_indices,
            nodes,
            depth + 1,
            params,
            n_features,
            max_features_n,
            rng,
        );

        // Patch the placeholder with the real split node.
        nodes[node_idx] = Node::Split {
            feature: best_feature,
            threshold: best_threshold,
            left: left_idx,
            right: right_idx,
            impurity_decrease: best_impurity_decrease,
            n_samples: n,
        };

        node_idx
    } else {
        // No MSE-reducing split was found: emit a leaf.
        let idx = nodes.len();
        nodes.push(Node::Leaf {
            value: mean,
            class_distribution: None,
            n_samples: n,
        });
        idx
    }
}
1125
/// Evaluates one uniformly random threshold per sampled feature and returns
/// the best candidate as `(feature, threshold, scaled_mse_decrease)`.
///
/// The MSE decrease is multiplied by the node's sample count before being
/// returned so it can be accumulated into feature importances. Returns
/// `None` when no candidate yields a strictly positive decrease.
///
/// NOTE(review): variances are computed with the one-pass formula
/// `E[y^2] - E[y]^2`, which can lose precision when targets have a large
/// common offset — confirm this is acceptable for the intended data ranges.
#[allow(clippy::too_many_arguments)]
fn find_random_regression_split<F: Float>(
    data: &RegressionData<'_, F>,
    indices: &[usize],
    min_samples_leaf: usize,
    n_features: usize,
    max_features_n: usize,
    rng: &mut StdRng,
) -> Option<(usize, F, F)> {
    let n = indices.len();
    let n_f = F::from(n).unwrap();

    // Parent sums; child sums on the right side are derived by subtraction.
    let parent_sum: F = indices
        .iter()
        .map(|&i| data.y[i])
        .fold(F::zero(), |a, b| a + b);
    let parent_sum_sq: F = indices
        .iter()
        .map(|&i| data.y[i] * data.y[i])
        .fold(F::zero(), |a, b| a + b);
    let parent_mse = parent_sum_sq / n_f - (parent_sum / n_f) * (parent_sum / n_f);

    let mut best_score = F::neg_infinity();
    let mut best_feature = 0;
    let mut best_threshold = F::zero();

    // Sample candidate features without replacement, restricted to
    // `feature_indices` when an ensemble limits the feature pool.
    let feature_subset: Vec<usize> = if let Some(feat_indices) = data.feature_indices {
        let k = max_features_n.min(feat_indices.len());
        rand_sample_indices(rng, feat_indices.len(), k)
            .into_vec()
            .into_iter()
            .map(|i| feat_indices[i])
            .collect()
    } else {
        let k = max_features_n.min(n_features);
        rand_sample_indices(rng, n_features, k).into_vec()
    };

    for feat in feature_subset {
        // Value range of this feature over the node's samples.
        let mut feat_min = F::infinity();
        let mut feat_max = F::neg_infinity();
        for &i in indices {
            let val = data.x[[i, feat]];
            if val < feat_min {
                feat_min = val;
            }
            if val > feat_max {
                feat_max = val;
            }
        }

        // Constant feature: no threshold can separate the samples.
        if feat_min >= feat_max {
            continue;
        }

        // One random threshold per feature (Extra-Trees strategy).
        let threshold = random_threshold(rng, feat_min, feat_max);

        // Accumulate left-side sums; the right side follows by subtraction.
        let mut left_sum = F::zero();
        let mut left_sum_sq = F::zero();
        let mut left_n: usize = 0;

        for &i in indices {
            if data.x[[i, feat]] <= threshold {
                let val = data.y[i];
                left_sum = left_sum + val;
                left_sum_sq = left_sum_sq + val * val;
                left_n += 1;
            }
        }

        let right_n = n - left_n;
        if left_n < min_samples_leaf || right_n < min_samples_leaf {
            continue;
        }

        let left_n_f = F::from(left_n).unwrap();
        let right_n_f = F::from(right_n).unwrap();

        let left_mean = left_sum / left_n_f;
        let left_mse = left_sum_sq / left_n_f - left_mean * left_mean;

        let right_sum = parent_sum - left_sum;
        let right_sum_sq = parent_sum_sq - left_sum_sq;
        let right_mean = right_sum / right_n_f;
        let right_mse = right_sum_sq / right_n_f - right_mean * right_mean;

        // Decrease = parent MSE minus size-weighted child MSE.
        let weighted_child_mse = (left_n_f * left_mse + right_n_f * right_mse) / n_f;
        let mse_decrease = parent_mse - weighted_child_mse;

        if mse_decrease > best_score {
            best_score = mse_decrease;
            best_feature = feat;
            best_threshold = threshold;
        }
    }

    // Only report splits that actually reduce the MSE.
    if best_score > F::zero() {
        Some((best_feature, best_threshold, best_score * n_f))
    } else {
        None
    }
}
1238
1239#[allow(clippy::too_many_arguments)]
1247pub(crate) fn build_extra_classification_tree_for_ensemble<F: Float>(
1248 x: &Array2<F>,
1249 y: &[usize],
1250 n_classes: usize,
1251 indices: &[usize],
1252 feature_indices: Option<&[usize]>,
1253 params: &TreeParams,
1254 criterion: ClassificationCriterion,
1255 n_features: usize,
1256 max_features_n: usize,
1257 rng: &mut StdRng,
1258) -> Vec<Node<F>> {
1259 let data = ClassificationData {
1260 x,
1261 y,
1262 n_classes,
1263 feature_indices,
1264 criterion,
1265 };
1266 let mut nodes = Vec::new();
1267 build_extra_classification_tree(
1268 &data,
1269 indices,
1270 &mut nodes,
1271 0,
1272 params,
1273 n_features,
1274 max_features_n,
1275 rng,
1276 );
1277 nodes
1278}
1279
1280#[allow(clippy::too_many_arguments)]
1284pub(crate) fn build_extra_regression_tree_for_ensemble<F: Float>(
1285 x: &Array2<F>,
1286 y: &Array1<F>,
1287 indices: &[usize],
1288 feature_indices: Option<&[usize]>,
1289 params: &TreeParams,
1290 n_features: usize,
1291 max_features_n: usize,
1292 rng: &mut StdRng,
1293) -> Vec<Node<F>> {
1294 let data = RegressionData {
1295 x,
1296 y,
1297 feature_indices,
1298 };
1299 let mut nodes = Vec::new();
1300 build_extra_regression_tree(
1301 &data,
1302 indices,
1303 &mut nodes,
1304 0,
1305 params,
1306 n_features,
1307 max_features_n,
1308 rng,
1309 );
1310 nodes
1311}
1312
1313#[cfg(test)]
1318mod tests {
1319 use super::*;
1320 use approx::assert_relative_eq;
1321 use ndarray::array;
1322
1323 #[test]
1326 fn test_extra_classifier_simple_binary() {
1327 let x = Array2::from_shape_vec(
1328 (6, 2),
1329 vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
1330 )
1331 .unwrap();
1332 let y = array![0, 0, 0, 1, 1, 1];
1333
1334 let model = ExtraTreeClassifier::<f64>::new().with_random_state(42);
1335 let fitted = model.fit(&x, &y).unwrap();
1336 let preds = fitted.predict(&x).unwrap();
1337
1338 assert_eq!(preds.len(), 6);
1339 for i in 0..3 {
1341 assert_eq!(preds[i], 0, "sample {i} should be class 0");
1342 }
1343 for i in 3..6 {
1344 assert_eq!(preds[i], 1, "sample {i} should be class 1");
1345 }
1346 }
1347
1348 #[test]
1349 fn test_extra_classifier_single_class() {
1350 let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
1351 let y = array![0, 0, 0];
1352
1353 let model = ExtraTreeClassifier::<f64>::new().with_random_state(42);
1354 let fitted = model.fit(&x, &y).unwrap();
1355 let preds = fitted.predict(&x).unwrap();
1356
1357 assert_eq!(preds, array![0, 0, 0]);
1358 }
1359
1360 #[test]
1361 fn test_extra_classifier_max_depth_1() {
1362 let x =
1363 Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
1364 let y = array![0, 0, 0, 0, 1, 1, 1, 1];
1365
1366 let model = ExtraTreeClassifier::<f64>::new()
1367 .with_max_depth(Some(1))
1368 .with_max_features(MaxFeatures::All)
1369 .with_random_state(42);
1370 let fitted = model.fit(&x, &y).unwrap();
1371 let preds = fitted.predict(&x).unwrap();
1372
1373 assert_eq!(fitted.nodes().len(), 3);
1376 }
1377
1378 #[test]
1379 fn test_extra_classifier_predict_proba() {
1380 let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1381 let y = array![0, 0, 0, 1, 1, 1];
1382
1383 let model = ExtraTreeClassifier::<f64>::new()
1384 .with_max_features(MaxFeatures::All)
1385 .with_random_state(42);
1386 let fitted = model.fit(&x, &y).unwrap();
1387 let proba = fitted.predict_proba(&x).unwrap();
1388
1389 assert_eq!(proba.dim(), (6, 2));
1390 for i in 0..6 {
1392 let row_sum = proba.row(i).sum();
1393 assert_relative_eq!(row_sum, 1.0, epsilon = 1e-10);
1394 }
1395 }
1396
1397 #[test]
1398 fn test_extra_classifier_feature_importances() {
1399 let x = Array2::from_shape_vec(
1400 (8, 2),
1401 vec![
1402 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0, 1.0, 5.0, 1.0, 6.0, 1.0, 7.0, 1.0, 8.0, 1.0,
1403 ],
1404 )
1405 .unwrap();
1406 let y = array![0, 0, 0, 0, 1, 1, 1, 1];
1407
1408 let model = ExtraTreeClassifier::<f64>::new()
1409 .with_max_features(MaxFeatures::All)
1410 .with_random_state(42);
1411 let fitted = model.fit(&x, &y).unwrap();
1412 let importances = fitted.feature_importances();
1413
1414 assert_eq!(importances.len(), 2);
1415 let total: f64 = importances.sum();
1417 assert_relative_eq!(total, 1.0, epsilon = 1e-10);
1418 assert!(importances[0] > importances[1]);
1420 }
1421
1422 #[test]
1423 fn test_extra_classifier_shape_mismatch() {
1424 let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1425 let y = array![0, 0]; let model = ExtraTreeClassifier::<f64>::new();
1428 assert!(model.fit(&x, &y).is_err());
1429 }
1430
1431 #[test]
1432 fn test_extra_classifier_empty_data() {
1433 let x = Array2::<f64>::zeros((0, 2));
1434 let y = Array1::<usize>::zeros(0);
1435
1436 let model = ExtraTreeClassifier::<f64>::new();
1437 assert!(model.fit(&x, &y).is_err());
1438 }
1439
1440 #[test]
1441 fn test_extra_classifier_invalid_min_samples_split() {
1442 let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
1443 let y = array![0, 0, 1];
1444
1445 let model = ExtraTreeClassifier::<f64>::new().with_min_samples_split(1);
1446 assert!(model.fit(&x, &y).is_err());
1447 }
1448
1449 #[test]
1450 fn test_extra_classifier_classes() {
1451 let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1452 let y = array![0, 0, 0, 2, 2, 2]; let model = ExtraTreeClassifier::<f64>::new().with_random_state(42);
1455 let fitted = model.fit(&x, &y).unwrap();
1456
1457 assert_eq!(fitted.classes(), &[0, 2]);
1458 assert_eq!(fitted.n_classes(), 2);
1459 }
1460
1461 #[test]
1462 fn test_extra_classifier_predict_shape_mismatch() {
1463 let x =
1464 Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
1465 let y = array![0, 0, 1, 1];
1466
1467 let model = ExtraTreeClassifier::<f64>::new().with_random_state(42);
1468 let fitted = model.fit(&x, &y).unwrap();
1469
1470 let x_wrong = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1471 assert!(fitted.predict(&x_wrong).is_err());
1472 }
1473
1474 #[test]
1475 fn test_extra_classifier_f32() {
1476 let x = Array2::from_shape_vec(
1477 (6, 2),
1478 vec![
1479 1.0f32, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0,
1480 ],
1481 )
1482 .unwrap();
1483 let y = array![0, 0, 0, 1, 1, 1];
1484
1485 let model = ExtraTreeClassifier::<f32>::new().with_random_state(42);
1486 let fitted = model.fit(&x, &y).unwrap();
1487 let preds = fitted.predict(&x).unwrap();
1488 assert_eq!(preds.len(), 6);
1489 }
1490
1491 #[test]
1492 fn test_extra_classifier_deterministic() {
1493 let x = Array2::from_shape_vec(
1494 (8, 2),
1495 vec![
1496 1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
1497 ],
1498 )
1499 .unwrap();
1500 let y = array![0, 0, 0, 0, 1, 1, 1, 1];
1501
1502 let model1 = ExtraTreeClassifier::<f64>::new().with_random_state(123);
1503 let model2 = ExtraTreeClassifier::<f64>::new().with_random_state(123);
1504
1505 let fitted1 = model1.fit(&x, &y).unwrap();
1506 let fitted2 = model2.fit(&x, &y).unwrap();
1507
1508 let preds1 = fitted1.predict(&x).unwrap();
1509 let preds2 = fitted2.predict(&x).unwrap();
1510
1511 assert_eq!(preds1, preds2);
1512 }
1513
1514 #[test]
1517 fn test_extra_regressor_simple() {
1518 let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1519 let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
1520
1521 let model = ExtraTreeRegressor::<f64>::new()
1522 .with_max_features(MaxFeatures::All)
1523 .with_random_state(42);
1524 let fitted = model.fit(&x, &y).unwrap();
1525 let preds = fitted.predict(&x).unwrap();
1526
1527 assert_eq!(preds.len(), 6);
1529 for i in 0..6 {
1530 assert_relative_eq!(preds[i], y[i], epsilon = 1.0);
1531 }
1532 }
1533
1534 #[test]
1535 fn test_extra_regressor_constant_target() {
1536 let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
1537 let y = array![5.0, 5.0, 5.0, 5.0];
1538
1539 let model = ExtraTreeRegressor::<f64>::new().with_random_state(42);
1540 let fitted = model.fit(&x, &y).unwrap();
1541 let preds = fitted.predict(&x).unwrap();
1542
1543 for &p in preds.iter() {
1544 assert_relative_eq!(p, 5.0, epsilon = 1e-10);
1545 }
1546 }
1547
1548 #[test]
1549 fn test_extra_regressor_feature_importances() {
1550 let x = Array2::from_shape_vec(
1551 (8, 2),
1552 vec![
1553 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0, 0.0,
1554 ],
1555 )
1556 .unwrap();
1557 let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
1558
1559 let model = ExtraTreeRegressor::<f64>::new()
1560 .with_max_features(MaxFeatures::All)
1561 .with_random_state(42);
1562 let fitted = model.fit(&x, &y).unwrap();
1563 let importances = fitted.feature_importances();
1564
1565 assert_eq!(importances.len(), 2);
1566 let total: f64 = importances.sum();
1567 assert_relative_eq!(total, 1.0, epsilon = 1e-10);
1568 assert!(importances[0] > importances[1]);
1570 }
1571
1572 #[test]
1573 fn test_extra_regressor_shape_mismatch() {
1574 let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1575 let y = array![1.0, 2.0]; let model = ExtraTreeRegressor::<f64>::new();
1578 assert!(model.fit(&x, &y).is_err());
1579 }
1580
1581 #[test]
1582 fn test_extra_regressor_empty_data() {
1583 let x = Array2::<f64>::zeros((0, 2));
1584 let y = Array1::<f64>::zeros(0);
1585
1586 let model = ExtraTreeRegressor::<f64>::new();
1587 assert!(model.fit(&x, &y).is_err());
1588 }
1589
1590 #[test]
1591 fn test_extra_regressor_predict_shape_mismatch() {
1592 let x =
1593 Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
1594 let y = array![1.0, 2.0, 3.0, 4.0];
1595
1596 let model = ExtraTreeRegressor::<f64>::new().with_random_state(42);
1597 let fitted = model.fit(&x, &y).unwrap();
1598
1599 let x_wrong = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1600 assert!(fitted.predict(&x_wrong).is_err());
1601 }
1602
1603 #[test]
1604 fn test_extra_regressor_max_depth() {
1605 let x =
1606 Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
1607 let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
1608
1609 let model = ExtraTreeRegressor::<f64>::new()
1610 .with_max_depth(Some(1))
1611 .with_max_features(MaxFeatures::All)
1612 .with_random_state(42);
1613 let fitted = model.fit(&x, &y).unwrap();
1614
1615 assert_eq!(fitted.nodes().len(), 3);
1617 }
1618
1619 #[test]
1620 fn test_extra_regressor_deterministic() {
1621 let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1622 let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
1623
1624 let model1 = ExtraTreeRegressor::<f64>::new().with_random_state(99);
1625 let model2 = ExtraTreeRegressor::<f64>::new().with_random_state(99);
1626
1627 let fitted1 = model1.fit(&x, &y).unwrap();
1628 let fitted2 = model2.fit(&x, &y).unwrap();
1629
1630 let preds1 = fitted1.predict(&x).unwrap();
1631 let preds2 = fitted2.predict(&x).unwrap();
1632
1633 for i in 0..6 {
1634 assert_relative_eq!(preds1[i], preds2[i], epsilon = 1e-12);
1635 }
1636 }
1637
1638 #[test]
1639 fn test_extra_regressor_f32() {
1640 let x = Array2::from_shape_vec((4, 1), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
1641 let y = array![1.0f32, 2.0, 3.0, 4.0];
1642
1643 let model = ExtraTreeRegressor::<f32>::new().with_random_state(42);
1644 let fitted = model.fit(&x, &y).unwrap();
1645 let preds = fitted.predict(&x).unwrap();
1646 assert_eq!(preds.len(), 4);
1647 }
1648
1649 #[test]
1652 fn test_classifier_builder_methods() {
1653 let model = ExtraTreeClassifier::<f64>::new()
1654 .with_max_depth(Some(5))
1655 .with_min_samples_split(10)
1656 .with_min_samples_leaf(3)
1657 .with_max_features(MaxFeatures::Log2)
1658 .with_criterion(ClassificationCriterion::Entropy)
1659 .with_random_state(42);
1660
1661 assert_eq!(model.max_depth, Some(5));
1662 assert_eq!(model.min_samples_split, 10);
1663 assert_eq!(model.min_samples_leaf, 3);
1664 assert_eq!(model.max_features, MaxFeatures::Log2);
1665 assert_eq!(model.criterion, ClassificationCriterion::Entropy);
1666 assert_eq!(model.random_state, Some(42));
1667 }
1668
1669 #[test]
1670 fn test_regressor_builder_methods() {
1671 let model = ExtraTreeRegressor::<f64>::new()
1672 .with_max_depth(Some(10))
1673 .with_min_samples_split(5)
1674 .with_min_samples_leaf(2)
1675 .with_max_features(MaxFeatures::Fixed(3))
1676 .with_criterion(RegressionCriterion::Mse)
1677 .with_random_state(99);
1678
1679 assert_eq!(model.max_depth, Some(10));
1680 assert_eq!(model.min_samples_split, 5);
1681 assert_eq!(model.min_samples_leaf, 2);
1682 assert_eq!(model.max_features, MaxFeatures::Fixed(3));
1683 assert_eq!(model.criterion, RegressionCriterion::Mse);
1684 assert_eq!(model.random_state, Some(99));
1685 }
1686
1687 #[test]
1688 fn test_classifier_default() {
1689 let model = ExtraTreeClassifier::<f64>::default();
1690 assert_eq!(model.max_depth, None);
1691 assert_eq!(model.min_samples_split, 2);
1692 assert_eq!(model.min_samples_leaf, 1);
1693 assert_eq!(model.max_features, MaxFeatures::Sqrt);
1694 assert_eq!(model.criterion, ClassificationCriterion::Gini);
1695 assert_eq!(model.random_state, None);
1696 }
1697
1698 #[test]
1699 fn test_regressor_default() {
1700 let model = ExtraTreeRegressor::<f64>::default();
1701 assert_eq!(model.max_depth, None);
1702 assert_eq!(model.min_samples_split, 2);
1703 assert_eq!(model.min_samples_leaf, 1);
1704 assert_eq!(model.max_features, MaxFeatures::All);
1705 assert_eq!(model.criterion, RegressionCriterion::Mse);
1706 assert_eq!(model.random_state, None);
1707 }
1708}