use ferrolearn_core::error::FerroError;
use ferrolearn_core::introspection::{HasClasses, HasFeatureImportances};
use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
use ferrolearn_core::traits::{Fit, Predict};
use ndarray::{Array1, Array2};
use num_traits::{Float, FromPrimitive, ToPrimitive};
use rand::SeedableRng;
use rand::rngs::StdRng;
use rand::seq::index::sample as rand_sample_indices;
use serde::{Deserialize, Serialize};

use crate::decision_tree::{
    ClassificationCriterion, Node, RegressionCriterion, TreeParams, compute_feature_importances,
    traverse,
};
use crate::random_forest::MaxFeatures;

/// Borrowed training data and split settings for classification tree building.
struct ClassificationData<'a, F> {
    x: &'a Array2<F>,
    y: &'a [usize],
    n_classes: usize,
    feature_indices: Option<&'a [usize]>,
    criterion: ClassificationCriterion,
}

/// Borrowed training data for regression tree building.
struct RegressionData<'a, F> {
    x: &'a Array2<F>,
    y: &'a Array1<F>,
    feature_indices: Option<&'a [usize]>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
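/// An extremely randomized ("extra") decision tree classifier.
///
/// Instead of searching for an optimal cut point, an extra tree draws one
/// threshold uniformly at random per sampled feature and keeps the best of
/// those candidates, which decorrelates trees in an ensemble (Geurts,
/// Ernst & Wehenkel, 2006).
///
/// A minimal usage sketch; the import paths are assumptions for
/// illustration:
///
/// ```ignore
/// use ferrolearn_core::traits::{Fit, Predict};
/// use ndarray::array;
///
/// let x = array![[1.0, 2.0], [2.0, 3.0], [5.0, 6.0], [6.0, 7.0]];
/// let y = array![0usize, 0, 1, 1];
/// let model = ExtraTreeClassifier::<f64>::new().with_random_state(42);
/// let fitted = model.fit(&x, &y)?;
/// let preds = fitted.predict(&x)?;
/// ```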
pub struct ExtraTreeClassifier<F> {
    /// Maximum tree depth; `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required at each leaf.
    pub min_samples_leaf: usize,
    /// Number of features sampled at each split.
    pub max_features: MaxFeatures,
    /// Impurity criterion used to score candidate splits.
    pub criterion: ClassificationCriterion,
    /// Seed for the random number generator; `None` draws from OS entropy.
    pub random_state: Option<u64>,
    _marker: std::marker::PhantomData<F>,
}

impl<F: Float> ExtraTreeClassifier<F> {
    #[must_use]
    /// Creates a classifier with the default hyperparameters: unlimited
    /// depth, `min_samples_split = 2`, `min_samples_leaf = 1`,
    /// `MaxFeatures::Sqrt`, the Gini criterion, and no fixed seed.
    pub fn new() -> Self {
        Self {
            max_depth: None,
            min_samples_split: 2,
            min_samples_leaf: 1,
            max_features: MaxFeatures::Sqrt,
            criterion: ClassificationCriterion::Gini,
            random_state: None,
            _marker: std::marker::PhantomData,
        }
    }

    #[must_use]
    /// Sets the maximum tree depth.
    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
        self.max_depth = max_depth;
        self
    }

    #[must_use]
    /// Sets the minimum number of samples required to split a node.
    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.min_samples_split = min_samples_split;
        self
    }

    #[must_use]
    /// Sets the minimum number of samples required at each leaf.
    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.min_samples_leaf = min_samples_leaf;
        self
    }

    #[must_use]
    /// Sets the number of features sampled at each split.
    pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
        self.max_features = max_features;
        self
    }

    #[must_use]
    /// Sets the split-quality criterion.
    pub fn with_criterion(mut self, criterion: ClassificationCriterion) -> Self {
        self.criterion = criterion;
        self
    }

    #[must_use]
    /// Sets the random seed for reproducible fits.
    pub fn with_random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }
}

impl<F: Float> Default for ExtraTreeClassifier<F> {
    fn default() -> Self {
        Self::new()
    }
}

#[derive(Debug, Clone)]
/// A fitted extra tree classifier produced by [`ExtraTreeClassifier::fit`].
///
/// The tree is stored as a flat vector of nodes, with the root at index 0.
pub struct FittedExtraTreeClassifier<F> {
    /// Flat node storage; the root is at index 0.
    nodes: Vec<Node<F>>,
    /// Sorted, deduplicated class labels observed during fitting.
    classes: Vec<usize>,
    /// Number of features the model was fitted on.
    n_features: usize,
    /// Normalized impurity-based feature importances.
    feature_importances: Array1<F>,
}

impl<F: Float + Send + Sync + 'static> FittedExtraTreeClassifier<F> {
    #[must_use]
    /// Returns the flat node storage of the fitted tree.
    pub fn nodes(&self) -> &[Node<F>] {
        &self.nodes
    }

    #[must_use]
    /// Returns the number of features the model was fitted on.
    pub fn n_features(&self) -> usize {
        self.n_features
    }

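    /// Predicts class probabilities from the stored leaf class
    /// distributions; columns follow the order of `classes`, and each row
    /// sums to 1.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `x` has a different number
    /// of columns than the fitted model.
    ///
    /// A minimal sketch, continuing the fit example above:
    ///
    /// ```ignore
    /// let proba = fitted.predict_proba(&x)?;
    /// assert_eq!(proba.nrows(), x.nrows());
    /// ```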
    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }
        let n_samples = x.nrows();
        let n_classes = self.classes.len();
        let mut proba = Array2::zeros((n_samples, n_classes));
        for i in 0..n_samples {
            let row = x.row(i);
            let leaf = traverse(&self.nodes, &row);
            if let Node::Leaf {
                class_distribution: Some(ref dist),
                ..
            } = self.nodes[leaf]
            {
                for (j, &p) in dist.iter().enumerate() {
                    proba[[i, j]] = p;
                }
            }
        }
        Ok(proba)
    }
}

impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for ExtraTreeClassifier<F> {
    type Fitted = FittedExtraTreeClassifier<F>;
    type Error = FerroError;

    /// Fits an extremely randomized classification tree.
    ///
    /// # Errors
    ///
    /// Returns an error if `x` and `y` disagree on the number of samples,
    /// if `x` is empty, or if `min_samples_split` / `min_samples_leaf` are
    /// out of range.
    fn fit(
        &self,
        x: &Array2<F>,
        y: &Array1<usize>,
    ) -> Result<FittedExtraTreeClassifier<F>, FerroError> {
        let (n_samples, n_features) = x.dim();

        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        if n_samples == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "ExtraTreeClassifier requires at least one sample".into(),
            });
        }
        if self.min_samples_split < 2 {
            return Err(FerroError::InvalidParameter {
                name: "min_samples_split".into(),
                reason: "must be at least 2".into(),
            });
        }
        if self.min_samples_leaf < 1 {
            return Err(FerroError::InvalidParameter {
                name: "min_samples_leaf".into(),
                reason: "must be at least 1".into(),
            });
        }

        // Collect the sorted, deduplicated class labels.
        let mut classes: Vec<usize> = y.iter().copied().collect();
        classes.sort_unstable();
        classes.dedup();
        let n_classes = classes.len();

        // Remap labels to contiguous indices 0..n_classes.
        let y_mapped: Vec<usize> = y
            .iter()
            .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
            .collect();

        let indices: Vec<usize> = (0..n_samples).collect();

        let max_features_n = resolve_max_features(self.max_features, n_features);

        let mut rng = if let Some(seed) = self.random_state {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_os_rng()
        };

        let data = ClassificationData {
            x,
            y: &y_mapped,
            n_classes,
            feature_indices: None,
            criterion: self.criterion,
        };
        let params = TreeParams {
            max_depth: self.max_depth,
            min_samples_split: self.min_samples_split,
            min_samples_leaf: self.min_samples_leaf,
        };

        let mut nodes: Vec<Node<F>> = Vec::new();
        build_extra_classification_tree(
            &data,
            &indices,
            &mut nodes,
            0,
            &params,
            n_features,
            max_features_n,
            &mut rng,
        );

        let feature_importances = compute_feature_importances(&nodes, n_features, n_samples);

        Ok(FittedExtraTreeClassifier {
            nodes,
            classes,
            n_features,
            feature_importances,
        })
    }
}

impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedExtraTreeClassifier<F> {
    type Output = Array1<usize>;
    type Error = FerroError;

    /// Predicts class labels (in the original label space) by routing each
    /// row to a leaf.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `x` has a different number
    /// of columns than the fitted model.
    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }
        let n_samples = x.nrows();
        let mut predictions = Array1::zeros(n_samples);
        for i in 0..n_samples {
            let row = x.row(i);
            let leaf = traverse(&self.nodes, &row);
            if let Node::Leaf { value, .. } = self.nodes[leaf] {
                // Map the leaf's internal class index back to the original
                // label (the two differ when labels are non-contiguous).
                predictions[i] = self.classes[float_to_usize(value)];
            }
        }
        Ok(predictions)
    }
}

impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedExtraTreeClassifier<F> {
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}

impl<F: Float + Send + Sync + 'static> HasClasses for FittedExtraTreeClassifier<F> {
    fn classes(&self) -> &[usize] {
        &self.classes
    }

    fn n_classes(&self) -> usize {
        self.classes.len()
    }
}

impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
    for ExtraTreeClassifier<F>
{
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        // Pipelines carry targets as floats; convert to integer labels.
        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
        let fitted = self.fit(x, &y_usize)?;
        Ok(Box::new(FittedExtraTreeClassifierPipelineAdapter(fitted)))
    }
}

/// Adapter exposing a fitted classifier through the float-typed pipeline API.
struct FittedExtraTreeClassifierPipelineAdapter<F: Float + Send + Sync + 'static>(
    FittedExtraTreeClassifier<F>,
);

impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedExtraTreeClassifierPipelineAdapter<F>
{
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        let preds = self.0.predict(x)?;
        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or_else(F::nan)))
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
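/// An extremely randomized ("extra") decision tree regressor.
///
/// Split thresholds are drawn uniformly at random within each sampled
/// feature's range; the candidate with the largest MSE decrease is kept.
///
/// A minimal usage sketch; the import paths are assumptions for
/// illustration:
///
/// ```ignore
/// use ferrolearn_core::traits::{Fit, Predict};
/// use ndarray::array;
///
/// let x = array![[1.0], [2.0], [3.0], [4.0]];
/// let y = array![1.0, 2.0, 3.0, 4.0];
/// let model = ExtraTreeRegressor::<f64>::new().with_random_state(42);
/// let fitted = model.fit(&x, &y)?;
/// let preds = fitted.predict(&x)?;
/// ```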
pub struct ExtraTreeRegressor<F> {
    /// Maximum tree depth; `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required at each leaf.
    pub min_samples_leaf: usize,
    /// Number of features sampled at each split.
    pub max_features: MaxFeatures,
    /// Criterion used to score candidate splits.
    pub criterion: RegressionCriterion,
    /// Seed for the random number generator; `None` draws from OS entropy.
    pub random_state: Option<u64>,
    _marker: std::marker::PhantomData<F>,
}

impl<F: Float> ExtraTreeRegressor<F> {
    #[must_use]
    /// Creates a regressor with the default hyperparameters: unlimited
    /// depth, `min_samples_split = 2`, `min_samples_leaf = 1`,
    /// `MaxFeatures::All`, the MSE criterion, and no fixed seed.
    pub fn new() -> Self {
        Self {
            max_depth: None,
            min_samples_split: 2,
            min_samples_leaf: 1,
            max_features: MaxFeatures::All,
            criterion: RegressionCriterion::Mse,
            random_state: None,
            _marker: std::marker::PhantomData,
        }
    }

    #[must_use]
    /// Sets the maximum tree depth.
    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
        self.max_depth = max_depth;
        self
    }

    #[must_use]
    /// Sets the minimum number of samples required to split a node.
    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.min_samples_split = min_samples_split;
        self
    }

    #[must_use]
    /// Sets the minimum number of samples required at each leaf.
    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.min_samples_leaf = min_samples_leaf;
        self
    }

    #[must_use]
    /// Sets the number of features sampled at each split.
    pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
        self.max_features = max_features;
        self
    }

    #[must_use]
    /// Sets the split-quality criterion.
    pub fn with_criterion(mut self, criterion: RegressionCriterion) -> Self {
        self.criterion = criterion;
        self
    }

    #[must_use]
    /// Sets the random seed for reproducible fits.
    pub fn with_random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }
}

impl<F: Float> Default for ExtraTreeRegressor<F> {
    fn default() -> Self {
        Self::new()
    }
}

#[derive(Debug, Clone)]
/// A fitted extra tree regressor produced by [`ExtraTreeRegressor::fit`].
pub struct FittedExtraTreeRegressor<F> {
    /// Flat node storage; the root is at index 0.
    nodes: Vec<Node<F>>,
    /// Number of features the model was fitted on.
    n_features: usize,
    /// Normalized impurity-based feature importances.
    feature_importances: Array1<F>,
}

impl<F: Float + Send + Sync + 'static> FittedExtraTreeRegressor<F> {
    #[must_use]
    /// Returns the flat node storage of the fitted tree.
    pub fn nodes(&self) -> &[Node<F>] {
        &self.nodes
    }

    #[must_use]
    /// Returns the number of features the model was fitted on.
    pub fn n_features(&self) -> usize {
        self.n_features
    }
}

impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for ExtraTreeRegressor<F> {
    type Fitted = FittedExtraTreeRegressor<F>;
    type Error = FerroError;

    /// Fits an extremely randomized regression tree.
    ///
    /// # Errors
    ///
    /// Returns an error if `x` and `y` disagree on the number of samples,
    /// if `x` is empty, or if `min_samples_split` / `min_samples_leaf` are
    /// out of range.
    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedExtraTreeRegressor<F>, FerroError> {
        let (n_samples, n_features) = x.dim();

        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        if n_samples == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "ExtraTreeRegressor requires at least one sample".into(),
            });
        }
        if self.min_samples_split < 2 {
            return Err(FerroError::InvalidParameter {
                name: "min_samples_split".into(),
                reason: "must be at least 2".into(),
            });
        }
        if self.min_samples_leaf < 1 {
            return Err(FerroError::InvalidParameter {
                name: "min_samples_leaf".into(),
                reason: "must be at least 1".into(),
            });
        }

        let indices: Vec<usize> = (0..n_samples).collect();
        let max_features_n = resolve_max_features(self.max_features, n_features);

        let mut rng = if let Some(seed) = self.random_state {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_os_rng()
        };

        let data = RegressionData {
            x,
            y,
            feature_indices: None,
        };
        let params = TreeParams {
            max_depth: self.max_depth,
            min_samples_split: self.min_samples_split,
            min_samples_leaf: self.min_samples_leaf,
        };

        let mut nodes: Vec<Node<F>> = Vec::new();
        build_extra_regression_tree(
            &data,
            &indices,
            &mut nodes,
            0,
            &params,
            n_features,
            max_features_n,
            &mut rng,
        );

        let feature_importances = compute_feature_importances(&nodes, n_features, n_samples);

        Ok(FittedExtraTreeRegressor {
            nodes,
            n_features,
            feature_importances,
        })
    }
}

impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedExtraTreeRegressor<F> {
    type Output = Array1<F>;
    type Error = FerroError;

    /// Predicts target values by routing each row to a leaf and returning
    /// the leaf mean.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `x` has a different number
    /// of columns than the fitted model.
    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }
        let n_samples = x.nrows();
        let mut predictions = Array1::zeros(n_samples);
        for i in 0..n_samples {
            let row = x.row(i);
            let leaf = traverse(&self.nodes, &row);
            if let Node::Leaf { value, .. } = self.nodes[leaf] {
                predictions[i] = value;
            }
        }
        Ok(predictions)
    }
}

impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedExtraTreeRegressor<F> {
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}

impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for ExtraTreeRegressor<F> {
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let fitted = self.fit(x, y)?;
        Ok(Box::new(fitted))
    }
}

impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F> for FittedExtraTreeRegressor<F> {
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        self.predict(x)
    }
}

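/// Resolves a [`MaxFeatures`] strategy to a concrete count, clamped to
/// `1..=n_features`.
///
/// Illustrative resolutions, derived from the match arms below:
///
/// ```ignore
/// assert_eq!(resolve_max_features(MaxFeatures::Sqrt, 100), 10);
/// assert_eq!(resolve_max_features(MaxFeatures::Log2, 100), 7);
/// assert_eq!(resolve_max_features(MaxFeatures::Fraction(0.25), 100), 25);
/// ```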
fn resolve_max_features(strategy: MaxFeatures, n_features: usize) -> usize {
    // Each strategy maps to a count; the final clamp guarantees at least
    // one and at most `n_features` features are sampled.
    let result = match strategy {
        MaxFeatures::Sqrt => (n_features as f64).sqrt().ceil() as usize,
        MaxFeatures::Log2 => (n_features as f64).log2().ceil().max(1.0) as usize,
        MaxFeatures::All => n_features,
        MaxFeatures::Fixed(n) => n.min(n_features),
        MaxFeatures::Fraction(f) => ((n_features as f64) * f).ceil() as usize,
    };
    result.max(1).min(n_features)
}

/// Converts a float-encoded class index back to `usize`, rounding to the
/// nearest integer.
fn float_to_usize<F: Float>(v: F) -> usize {
    v.to_f64().map_or(0, |f| f.round() as usize)
}

/// Draws a uniformly distributed random threshold in `[min_val, max_val]`.
fn random_threshold<F: Float>(rng: &mut StdRng, min_val: F, max_val: F) -> F {
    use rand::RngCore;
    // Map a raw u64 draw onto [0, 1], then scale into the value range.
    let u = (rng.next_u64() as f64) / (u64::MAX as f64);
    let range = max_val - min_val;
    min_val + F::from(u).unwrap() * range
}

/// Gini impurity of a class-count vector: `1 - Σ_c p_c^2`.
fn gini_impurity<F: Float>(class_counts: &[usize], total: usize) -> F {
    if total == 0 {
        return F::zero();
    }
    let total_f = F::from(total).unwrap();
    let mut impurity = F::one();
    for &count in class_counts {
        let p = F::from(count).unwrap() / total_f;
        impurity = impurity - p * p;
    }
    impurity
}

/// Shannon entropy (natural log) of a class-count vector: `-Σ_c p_c ln p_c`.
fn entropy_impurity<F: Float>(class_counts: &[usize], total: usize) -> F {
    if total == 0 {
        return F::zero();
    }
    let total_f = F::from(total).unwrap();
    let mut ent = F::zero();
    for &count in class_counts {
        if count > 0 {
            let p = F::from(count).unwrap() / total_f;
            ent = ent - p * p.ln();
        }
    }
    ent
}

/// Computes the impurity selected by `criterion`; e.g. counts `[2, 2]` give
/// a Gini impurity of `1 - 0.5^2 - 0.5^2 = 0.5`.
fn compute_impurity<F: Float>(
    class_counts: &[usize],
    total: usize,
    criterion: ClassificationCriterion,
) -> F {
    match criterion {
        ClassificationCriterion::Gini => gini_impurity(class_counts, total),
        ClassificationCriterion::Entropy => entropy_impurity(class_counts, total),
    }
}

/// Pushes a classification leaf holding the majority class and the full
/// class distribution; returns its node index.
fn make_classification_leaf<F: Float>(
    nodes: &mut Vec<Node<F>>,
    class_counts: &[usize],
    n_classes: usize,
    n_samples: usize,
) -> usize {
    let majority_class = class_counts
        .iter()
        .enumerate()
        .max_by_key(|&(_, &count)| count)
        .map_or(0, |(i, _)| i);

    let total_f = if n_samples > 0 {
        F::from(n_samples).unwrap()
    } else {
        F::one()
    };
    let distribution: Vec<F> = (0..n_classes)
        .map(|c| F::from(class_counts[c]).unwrap() / total_f)
        .collect();

    let idx = nodes.len();
    nodes.push(Node::Leaf {
        value: F::from(majority_class).unwrap(),
        class_distribution: Some(distribution),
        n_samples,
    });
    idx
}

/// Mean of `y` over the given sample indices; zero when empty.
fn mean_value<F: Float>(y: &Array1<F>, indices: &[usize]) -> F {
    if indices.is_empty() {
        return F::zero();
    }
    let sum: F = indices.iter().map(|&i| y[i]).fold(F::zero(), |a, b| a + b);
    sum / F::from(indices.len()).unwrap()
}

#[allow(clippy::too_many_arguments)]
/// Recursively builds an extremely randomized classification tree.
///
/// At each node, one threshold is drawn uniformly at random per sampled
/// feature, and the best of those candidates (by impurity decrease) is
/// kept; no exhaustive threshold search is performed.
fn build_extra_classification_tree<F: Float>(
    data: &ClassificationData<'_, F>,
    indices: &[usize],
    nodes: &mut Vec<Node<F>>,
    depth: usize,
    params: &TreeParams,
    n_features: usize,
    max_features_n: usize,
    rng: &mut StdRng,
) -> usize {
    let n = indices.len();

    let mut class_counts = vec![0usize; data.n_classes];
    for &i in indices {
        class_counts[data.y[i]] += 1;
    }

    // Stop on too few samples, depth limit, or a pure node.
    let should_stop = n < params.min_samples_split
        || params.max_depth.is_some_and(|d| depth >= d)
        || class_counts.iter().filter(|&&c| c > 0).count() <= 1;

    if should_stop {
        return make_classification_leaf(nodes, &class_counts, data.n_classes, n);
    }

    let best = find_random_classification_split(
        data,
        indices,
        params.min_samples_leaf,
        n_features,
        max_features_n,
        rng,
    );

    if let Some((best_feature, best_threshold, best_impurity_decrease)) = best {
        let (left_indices, right_indices): (Vec<usize>, Vec<usize>) = indices
            .iter()
            .partition(|&&i| data.x[[i, best_feature]] <= best_threshold);

        if left_indices.len() < params.min_samples_leaf
            || right_indices.len() < params.min_samples_leaf
        {
            return make_classification_leaf(nodes, &class_counts, data.n_classes, n);
        }

        let node_idx = nodes.len();
        nodes.push(Node::Leaf {
            value: F::zero(),
            class_distribution: None,
            n_samples: 0,
        }); // Placeholder; overwritten with the split node below.

        let left_idx = build_extra_classification_tree(
            data,
            &left_indices,
            nodes,
            depth + 1,
            params,
            n_features,
            max_features_n,
            rng,
        );
        let right_idx = build_extra_classification_tree(
            data,
            &right_indices,
            nodes,
            depth + 1,
            params,
            n_features,
            max_features_n,
            rng,
        );

        nodes[node_idx] = Node::Split {
            feature: best_feature,
            threshold: best_threshold,
            left: left_idx,
            right: right_idx,
            impurity_decrease: best_impurity_decrease,
            n_samples: n,
        };

        node_idx
    } else {
        make_classification_leaf(nodes, &class_counts, data.n_classes, n)
    }
}

#[allow(clippy::too_many_arguments)]
/// Picks the best of the random candidate splits for a classification node.
///
/// Returns `Some((feature, threshold, weighted impurity decrease))`, or
/// `None` if no candidate split reduces the impurity.
fn find_random_classification_split<F: Float>(
    data: &ClassificationData<'_, F>,
    indices: &[usize],
    min_samples_leaf: usize,
    n_features: usize,
    max_features_n: usize,
    rng: &mut StdRng,
) -> Option<(usize, F, F)> {
    let n = indices.len();
    let n_f = F::from(n).unwrap();

    let mut parent_counts = vec![0usize; data.n_classes];
    for &i in indices {
        parent_counts[data.y[i]] += 1;
    }
    let parent_impurity = compute_impurity::<F>(&parent_counts, n, data.criterion);

    let mut best_score = F::neg_infinity();
    let mut best_feature = 0;
    let mut best_threshold = F::zero();

    // Sample `max_features_n` features without replacement, restricted to
    // `feature_indices` when an ensemble supplies a feature subset.
    let feature_subset: Vec<usize> = if let Some(feat_indices) = data.feature_indices {
        let k = max_features_n.min(feat_indices.len());
        rand_sample_indices(rng, feat_indices.len(), k)
            .into_vec()
            .into_iter()
            .map(|i| feat_indices[i])
            .collect()
    } else {
        let k = max_features_n.min(n_features);
        rand_sample_indices(rng, n_features, k).into_vec()
    };

    for feat in feature_subset {
        // Find the feature's value range over the node's samples.
        let mut feat_min = F::infinity();
        let mut feat_max = F::neg_infinity();
        for &i in indices {
            let val = data.x[[i, feat]];
            if val < feat_min {
                feat_min = val;
            }
            if val > feat_max {
                feat_max = val;
            }
        }

        // A constant feature cannot be split.
        if feat_min >= feat_max {
            continue;
        }

        let threshold = random_threshold(rng, feat_min, feat_max);

        // Count class memberships on each side of the threshold.
        let mut left_counts = vec![0usize; data.n_classes];
        let mut right_counts = vec![0usize; data.n_classes];
        let mut left_n = 0usize;

        for &i in indices {
            let cls = data.y[i];
            if data.x[[i, feat]] <= threshold {
                left_counts[cls] += 1;
                left_n += 1;
            } else {
                right_counts[cls] += 1;
            }
        }

        let right_n = n - left_n;
        if left_n < min_samples_leaf || right_n < min_samples_leaf {
            continue;
        }

        let left_impurity = compute_impurity::<F>(&left_counts, left_n, data.criterion);
        let right_impurity = compute_impurity::<F>(&right_counts, right_n, data.criterion);
        let left_weight = F::from(left_n).unwrap() / n_f;
        let right_weight = F::from(right_n).unwrap() / n_f;
        let weighted_child_impurity = left_weight * left_impurity + right_weight * right_impurity;
        let impurity_decrease = parent_impurity - weighted_child_impurity;

        if impurity_decrease > best_score {
            best_score = impurity_decrease;
            best_feature = feat;
            best_threshold = threshold;
        }
    }

    if best_score > F::zero() {
        Some((best_feature, best_threshold, best_score * n_f))
    } else {
        None
    }
}

#[allow(clippy::too_many_arguments)]
/// Recursively builds an extremely randomized regression tree; leaves store
/// the mean target of their samples.
fn build_extra_regression_tree<F: Float>(
    data: &RegressionData<'_, F>,
    indices: &[usize],
    nodes: &mut Vec<Node<F>>,
    depth: usize,
    params: &TreeParams,
    n_features: usize,
    max_features_n: usize,
    rng: &mut StdRng,
) -> usize {
    let n = indices.len();
    let mean = mean_value(data.y, indices);

    let should_stop = n < params.min_samples_split || params.max_depth.is_some_and(|d| depth >= d);

    if should_stop {
        let idx = nodes.len();
        nodes.push(Node::Leaf {
            value: mean,
            class_distribution: None,
            n_samples: n,
        });
        return idx;
    }

    // Sum of squared deviations from the node mean.
    let parent_sum_sq: F = indices
        .iter()
        .map(|&i| {
            let diff = data.y[i] - mean;
            diff * diff
        })
        .fold(F::zero(), |a, b| a + b);
    let parent_mse = parent_sum_sq / F::from(n).unwrap();

    // A (near-)constant target cannot be improved by splitting.
    if parent_mse <= F::epsilon() {
        let idx = nodes.len();
        nodes.push(Node::Leaf {
            value: mean,
            class_distribution: None,
            n_samples: n,
        });
        return idx;
    }

    let best = find_random_regression_split(
        data,
        indices,
        params.min_samples_leaf,
        n_features,
        max_features_n,
        rng,
    );

    if let Some((best_feature, best_threshold, best_impurity_decrease)) = best {
        let (left_indices, right_indices): (Vec<usize>, Vec<usize>) = indices
            .iter()
            .partition(|&&i| data.x[[i, best_feature]] <= best_threshold);

        if left_indices.len() < params.min_samples_leaf
            || right_indices.len() < params.min_samples_leaf
        {
            let idx = nodes.len();
            nodes.push(Node::Leaf {
                value: mean,
                class_distribution: None,
                n_samples: n,
            });
            return idx;
        }

        let node_idx = nodes.len();
        nodes.push(Node::Leaf {
            value: F::zero(),
            class_distribution: None,
            n_samples: 0,
        }); // Placeholder; overwritten with the split node below.

        let left_idx = build_extra_regression_tree(
            data,
            &left_indices,
            nodes,
            depth + 1,
            params,
            n_features,
            max_features_n,
            rng,
        );
        let right_idx = build_extra_regression_tree(
            data,
            &right_indices,
            nodes,
            depth + 1,
            params,
            n_features,
            max_features_n,
            rng,
        );

        nodes[node_idx] = Node::Split {
            feature: best_feature,
            threshold: best_threshold,
            left: left_idx,
            right: right_idx,
            impurity_decrease: best_impurity_decrease,
            n_samples: n,
        };

        node_idx
    } else {
        let idx = nodes.len();
        nodes.push(Node::Leaf {
            value: mean,
            class_distribution: None,
            n_samples: n,
        });
        idx
    }
}

#[allow(clippy::too_many_arguments)]
/// Picks the best of the random candidate splits for a regression node.
///
/// Uses the identity `Var(y) = E[y^2] - E[y]^2` so every candidate can be
/// scored from running sums; returns `Some((feature, threshold, weighted
/// MSE decrease))`, or `None` if no candidate reduces the MSE.
fn find_random_regression_split<F: Float>(
    data: &RegressionData<'_, F>,
    indices: &[usize],
    min_samples_leaf: usize,
    n_features: usize,
    max_features_n: usize,
    rng: &mut StdRng,
) -> Option<(usize, F, F)> {
    let n = indices.len();
    let n_f = F::from(n).unwrap();

    let parent_sum: F = indices
        .iter()
        .map(|&i| data.y[i])
        .fold(F::zero(), |a, b| a + b);
    let parent_sum_sq: F = indices
        .iter()
        .map(|&i| data.y[i] * data.y[i])
        .fold(F::zero(), |a, b| a + b);
    let parent_mse = parent_sum_sq / n_f - (parent_sum / n_f) * (parent_sum / n_f);

    let mut best_score = F::neg_infinity();
    let mut best_feature = 0;
    let mut best_threshold = F::zero();

    // Sample `max_features_n` features without replacement, restricted to
    // `feature_indices` when an ensemble supplies a feature subset.
    let feature_subset: Vec<usize> = if let Some(feat_indices) = data.feature_indices {
        let k = max_features_n.min(feat_indices.len());
        rand_sample_indices(rng, feat_indices.len(), k)
            .into_vec()
            .into_iter()
            .map(|i| feat_indices[i])
            .collect()
    } else {
        let k = max_features_n.min(n_features);
        rand_sample_indices(rng, n_features, k).into_vec()
    };

    for feat in feature_subset {
        // Find the feature's value range over the node's samples.
        let mut feat_min = F::infinity();
        let mut feat_max = F::neg_infinity();
        for &i in indices {
            let val = data.x[[i, feat]];
            if val < feat_min {
                feat_min = val;
            }
            if val > feat_max {
                feat_max = val;
            }
        }

        // A constant feature cannot be split.
        if feat_min >= feat_max {
            continue;
        }

        let threshold = random_threshold(rng, feat_min, feat_max);

        // Accumulate left-side sums; the right side follows by subtraction.
        let mut left_sum = F::zero();
        let mut left_sum_sq = F::zero();
        let mut left_n: usize = 0;

        for &i in indices {
            if data.x[[i, feat]] <= threshold {
                let val = data.y[i];
                left_sum = left_sum + val;
                left_sum_sq = left_sum_sq + val * val;
                left_n += 1;
            }
        }

        let right_n = n - left_n;
        if left_n < min_samples_leaf || right_n < min_samples_leaf {
            continue;
        }

        let left_n_f = F::from(left_n).unwrap();
        let right_n_f = F::from(right_n).unwrap();

        let left_mean = left_sum / left_n_f;
        let left_mse = left_sum_sq / left_n_f - left_mean * left_mean;

        let right_sum = parent_sum - left_sum;
        let right_sum_sq = parent_sum_sq - left_sum_sq;
        let right_mean = right_sum / right_n_f;
        let right_mse = right_sum_sq / right_n_f - right_mean * right_mean;

        let weighted_child_mse = (left_n_f * left_mse + right_n_f * right_mse) / n_f;
        let mse_decrease = parent_mse - weighted_child_mse;

        if mse_decrease > best_score {
            best_score = mse_decrease;
            best_feature = feat;
            best_threshold = threshold;
        }
    }

    if best_score > F::zero() {
        Some((best_feature, best_threshold, best_score * n_f))
    } else {
        None
    }
}

#[allow(clippy::too_many_arguments)]
/// Builds a single extremely randomized classification tree for an ensemble,
/// using caller-supplied sample indices, optional feature subset, and RNG.
pub(crate) fn build_extra_classification_tree_for_ensemble<F: Float>(
    x: &Array2<F>,
    y: &[usize],
    n_classes: usize,
    indices: &[usize],
    feature_indices: Option<&[usize]>,
    params: &TreeParams,
    criterion: ClassificationCriterion,
    n_features: usize,
    max_features_n: usize,
    rng: &mut StdRng,
) -> Vec<Node<F>> {
    let data = ClassificationData {
        x,
        y,
        n_classes,
        feature_indices,
        criterion,
    };
    let mut nodes = Vec::new();
    build_extra_classification_tree(
        &data,
        indices,
        &mut nodes,
        0,
        params,
        n_features,
        max_features_n,
        rng,
    );
    nodes
}

#[allow(clippy::too_many_arguments)]
/// Builds a single extremely randomized regression tree for an ensemble,
/// using caller-supplied sample indices, optional feature subset, and RNG.
pub(crate) fn build_extra_regression_tree_for_ensemble<F: Float>(
    x: &Array2<F>,
    y: &Array1<F>,
    indices: &[usize],
    feature_indices: Option<&[usize]>,
    params: &TreeParams,
    n_features: usize,
    max_features_n: usize,
    rng: &mut StdRng,
) -> Vec<Node<F>> {
    let data = RegressionData {
        x,
        y,
        feature_indices,
    };
    let mut nodes = Vec::new();
    build_extra_regression_tree(
        &data,
        indices,
        &mut nodes,
        0,
        params,
        n_features,
        max_features_n,
        rng,
    );
    nodes
}

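// Sketch of how an ensemble is expected to drive the builders above. This is
// a hypothetical call site (`seed`, `tree_idx`, `bootstrap_indices`, and the
// surrounding data are illustrative, not names from this crate): each tree
// gets its own seeded RNG, and bagging / feature subsetting enter through
// `indices` and `feature_indices`.
//
// let mut rng = StdRng::seed_from_u64(seed.wrapping_add(tree_idx as u64));
// let nodes = build_extra_classification_tree_for_ensemble(
//     &x, &y_mapped, n_classes, &bootstrap_indices, None, &params,
//     criterion, n_features, max_features_n, &mut rng,
// );
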
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // ----- ExtraTreeClassifier -----

    #[test]
    fn test_extra_classifier_simple_binary() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = ExtraTreeClassifier::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 6);
        for i in 0..3 {
            assert_eq!(preds[i], 0, "sample {i} should be class 0");
        }
        for i in 3..6 {
            assert_eq!(preds[i], 1, "sample {i} should be class 1");
        }
    }

    #[test]
    fn test_extra_classifier_single_class() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0];

        let model = ExtraTreeClassifier::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds, array![0, 0, 0]);
    }

    #[test]
    fn test_extra_classifier_max_depth_1() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = ExtraTreeClassifier::<f64>::new()
            .with_max_depth(Some(1))
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let _preds = fitted.predict(&x).unwrap();

        // A depth-1 tree is one split node plus two leaves.
        assert_eq!(fitted.nodes().len(), 3);
    }

    #[test]
    fn test_extra_classifier_predict_proba() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = ExtraTreeClassifier::<f64>::new()
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();

        assert_eq!(proba.dim(), (6, 2));
        // Each probability row must sum to 1.
        for i in 0..6 {
            let row_sum = proba.row(i).sum();
            assert_relative_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_extra_classifier_feature_importances() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0, 1.0, 5.0, 1.0, 6.0, 1.0, 7.0, 1.0, 8.0, 1.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = ExtraTreeClassifier::<f64>::new()
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        // Importances are normalized to sum to 1.
        let total: f64 = importances.sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-10);
        // Feature 0 separates the classes; feature 1 is constant.
        assert!(importances[0] > importances[1]);
    }

    #[test]
    fn test_extra_classifier_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        // One label fewer than the number of rows.
        let y = array![0, 0];

        let model = ExtraTreeClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_extra_classifier_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);

        let model = ExtraTreeClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_extra_classifier_invalid_min_samples_split() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 1];

        let model = ExtraTreeClassifier::<f64>::new().with_min_samples_split(1);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_extra_classifier_classes() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        // Non-contiguous class labels.
        let y = array![0, 0, 0, 2, 2, 2];

        let model = ExtraTreeClassifier::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.classes(), &[0, 2]);
        assert_eq!(fitted.n_classes(), 2);
    }

    #[test]
    fn test_extra_classifier_predict_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = ExtraTreeClassifier::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        let x_wrong = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_wrong).is_err());
    }

    #[test]
    fn test_extra_classifier_f32() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![
                1.0f32, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = ExtraTreeClassifier::<f32>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_extra_classifier_deterministic() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model1 = ExtraTreeClassifier::<f64>::new().with_random_state(123);
        let model2 = ExtraTreeClassifier::<f64>::new().with_random_state(123);

        let fitted1 = model1.fit(&x, &y).unwrap();
        let fitted2 = model2.fit(&x, &y).unwrap();

        let preds1 = fitted1.predict(&x).unwrap();
        let preds2 = fitted2.predict(&x).unwrap();

        assert_eq!(preds1, preds2);
    }

    // ----- ExtraTreeRegressor -----

    #[test]
    fn test_extra_regressor_simple() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let model = ExtraTreeRegressor::<f64>::new()
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 6);
        // Predictions should track the monotone target closely.
        for i in 0..6 {
            assert_relative_eq!(preds[i], y[i], epsilon = 1.0);
        }
    }

    #[test]
    fn test_extra_regressor_constant_target() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![5.0, 5.0, 5.0, 5.0];

        let model = ExtraTreeRegressor::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for &p in &preds {
            assert_relative_eq!(p, 5.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_extra_regressor_feature_importances() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let model = ExtraTreeRegressor::<f64>::new()
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        let total: f64 = importances.sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-10);
        // Feature 0 carries the signal; feature 1 is constant.
        assert!(importances[0] > importances[1]);
    }

    #[test]
    fn test_extra_regressor_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        // One target fewer than the number of rows.
        let y = array![1.0, 2.0];

        let model = ExtraTreeRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_extra_regressor_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);

        let model = ExtraTreeRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_extra_regressor_predict_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = ExtraTreeRegressor::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        let x_wrong = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_wrong).is_err());
    }

    #[test]
    fn test_extra_regressor_max_depth() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let model = ExtraTreeRegressor::<f64>::new()
            .with_max_depth(Some(1))
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        // A depth-1 tree is one split node plus two leaves.
        assert_eq!(fitted.nodes().len(), 3);
    }

    #[test]
    fn test_extra_regressor_deterministic() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let model1 = ExtraTreeRegressor::<f64>::new().with_random_state(99);
        let model2 = ExtraTreeRegressor::<f64>::new().with_random_state(99);

        let fitted1 = model1.fit(&x, &y).unwrap();
        let fitted2 = model2.fit(&x, &y).unwrap();

        let preds1 = fitted1.predict(&x).unwrap();
        let preds2 = fitted2.predict(&x).unwrap();

        for i in 0..6 {
            assert_relative_eq!(preds1[i], preds2[i], epsilon = 1e-12);
        }
    }

    #[test]
    fn test_extra_regressor_f32() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0f32, 2.0, 3.0, 4.0];

        let model = ExtraTreeRegressor::<f32>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    // ----- Builders and defaults -----

    #[test]
    fn test_classifier_builder_methods() {
        let model = ExtraTreeClassifier::<f64>::new()
            .with_max_depth(Some(5))
            .with_min_samples_split(10)
            .with_min_samples_leaf(3)
            .with_max_features(MaxFeatures::Log2)
            .with_criterion(ClassificationCriterion::Entropy)
            .with_random_state(42);

        assert_eq!(model.max_depth, Some(5));
        assert_eq!(model.min_samples_split, 10);
        assert_eq!(model.min_samples_leaf, 3);
        assert_eq!(model.max_features, MaxFeatures::Log2);
        assert_eq!(model.criterion, ClassificationCriterion::Entropy);
        assert_eq!(model.random_state, Some(42));
    }

    #[test]
    fn test_regressor_builder_methods() {
        let model = ExtraTreeRegressor::<f64>::new()
            .with_max_depth(Some(10))
            .with_min_samples_split(5)
            .with_min_samples_leaf(2)
            .with_max_features(MaxFeatures::Fixed(3))
            .with_criterion(RegressionCriterion::Mse)
            .with_random_state(99);

        assert_eq!(model.max_depth, Some(10));
        assert_eq!(model.min_samples_split, 5);
        assert_eq!(model.min_samples_leaf, 2);
        assert_eq!(model.max_features, MaxFeatures::Fixed(3));
        assert_eq!(model.criterion, RegressionCriterion::Mse);
        assert_eq!(model.random_state, Some(99));
    }

    #[test]
    fn test_classifier_default() {
        let model = ExtraTreeClassifier::<f64>::default();
        assert_eq!(model.max_depth, None);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
        assert_eq!(model.max_features, MaxFeatures::Sqrt);
        assert_eq!(model.criterion, ClassificationCriterion::Gini);
        assert_eq!(model.random_state, None);
    }

    #[test]
    fn test_regressor_default() {
        let model = ExtraTreeRegressor::<f64>::default();
        assert_eq!(model.max_depth, None);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
        assert_eq!(model.max_features, MaxFeatures::All);
        assert_eq!(model.criterion, RegressionCriterion::Mse);
        assert_eq!(model.random_state, None);
    }
}