use ferrolearn_core::error::FerroError;
use ferrolearn_core::introspection::{HasClasses, HasFeatureImportances};
use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
use ferrolearn_core::traits::{Fit, Predict};
use ndarray::{Array1, Array2};
use num_traits::{Float, FromPrimitive, ToPrimitive};
use rand::SeedableRng;
use rand::rngs::StdRng;
use rand::seq::index::sample as rand_sample_indices;
use serde::{Deserialize, Serialize};

use crate::decision_tree::{
    ClassificationCriterion, Node, RegressionCriterion, TreeParams, compute_feature_importances,
    traverse,
};
use crate::random_forest::MaxFeatures;

/// Borrowed view of the classification training data used while growing a tree.
struct ClassificationData<'a, F> {
    x: &'a Array2<F>,
    y: &'a [usize],
    n_classes: usize,
    feature_indices: Option<&'a [usize]>,
    criterion: ClassificationCriterion,
}

/// Borrowed view of the regression training data used while growing a tree.
struct RegressionData<'a, F> {
    x: &'a Array2<F>,
    y: &'a Array1<F>,
    feature_indices: Option<&'a [usize]>,
}

/// An extremely randomized decision tree classifier.
///
/// Unlike a classic decision tree, each candidate split uses a threshold drawn
/// uniformly at random between the feature's minimum and maximum; the best of
/// these random splits is kept.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtraTreeClassifier<F> {
    /// Maximum tree depth (`None` means unlimited).
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required in each leaf.
    pub min_samples_leaf: usize,
    /// Strategy for the number of features considered at each split.
    pub max_features: MaxFeatures,
    /// Impurity criterion used to score candidate splits.
    pub criterion: ClassificationCriterion,
    /// Seed for the random number generator (`None` uses OS entropy).
    pub random_state: Option<u64>,
    _marker: std::marker::PhantomData<F>,
}

impl<F: Float> ExtraTreeClassifier<F> {
    /// Creates a classifier with default hyperparameters.
    #[must_use]
    pub fn new() -> Self {
        Self {
            max_depth: None,
            min_samples_split: 2,
            min_samples_leaf: 1,
            max_features: MaxFeatures::Sqrt,
            criterion: ClassificationCriterion::Gini,
            random_state: None,
            _marker: std::marker::PhantomData,
        }
    }

    /// Sets the maximum tree depth.
    #[must_use]
    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
        self.max_depth = max_depth;
        self
    }

    /// Sets the minimum number of samples required to split a node.
    #[must_use]
    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.min_samples_split = min_samples_split;
        self
    }

    /// Sets the minimum number of samples required in each leaf.
    #[must_use]
    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.min_samples_leaf = min_samples_leaf;
        self
    }

    /// Sets the strategy for the number of features considered per split.
    #[must_use]
    pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
        self.max_features = max_features;
        self
    }

    /// Sets the split criterion.
    #[must_use]
    pub fn with_criterion(mut self, criterion: ClassificationCriterion) -> Self {
        self.criterion = criterion;
        self
    }

    /// Sets the random seed for reproducible trees.
    #[must_use]
    pub fn with_random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }
}

impl<F: Float> Default for ExtraTreeClassifier<F> {
    fn default() -> Self {
        Self::new()
    }
}

/// A fitted [`ExtraTreeClassifier`].
#[derive(Debug, Clone)]
pub struct FittedExtraTreeClassifier<F> {
    /// Flattened tree nodes; index 0 is the root.
    nodes: Vec<Node<F>>,
    /// Sorted unique class labels seen during fitting.
    classes: Vec<usize>,
    /// Number of features the model was fitted on.
    n_features: usize,
    /// Normalized impurity-based feature importances.
    feature_importances: Array1<F>,
}

impl<F: Float + Send + Sync + 'static> FittedExtraTreeClassifier<F> {
    /// Returns the flattened tree nodes.
    #[must_use]
    pub fn nodes(&self) -> &[Node<F>] {
        &self.nodes
    }

    /// Returns the number of features the model was fitted on.
    #[must_use]
    pub fn n_features(&self) -> usize {
        self.n_features
    }

    /// Predicts class probabilities for each sample in `x`.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `x` does not have the same
    /// number of features as the fitted model.
    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }
        let n_samples = x.nrows();
        let n_classes = self.classes.len();
        let mut proba = Array2::zeros((n_samples, n_classes));
        for i in 0..n_samples {
            let row = x.row(i);
            let leaf = traverse(&self.nodes, &row);
            if let Node::Leaf {
                class_distribution: Some(ref dist),
                ..
            } = self.nodes[leaf]
            {
                for (j, &p) in dist.iter().enumerate() {
                    proba[[i, j]] = p;
                }
            }
        }
        Ok(proba)
    }

    /// Returns the mean accuracy of the predictions on `x` against `y`.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `y` does not have one label
    /// per row of `x`.
    pub fn score(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<F, FerroError> {
        if x.nrows() != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![x.nrows()],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        let preds = self.predict(x)?;
        Ok(crate::mean_accuracy(&preds, y))
    }

    /// Returns the logarithm of the predicted class probabilities.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `x` does not have the same
    /// number of features as the fitted model.
    pub fn predict_log_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let proba = self.predict_proba(x)?;
        Ok(crate::log_proba(&proba))
    }
}

impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for ExtraTreeClassifier<F> {
    type Fitted = FittedExtraTreeClassifier<F>;
    type Error = FerroError;

    fn fit(
        &self,
        x: &Array2<F>,
        y: &Array1<usize>,
    ) -> Result<FittedExtraTreeClassifier<F>, FerroError> {
        let (n_samples, n_features) = x.dim();

        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        if n_samples == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "ExtraTreeClassifier requires at least one sample".into(),
            });
        }
        if self.min_samples_split < 2 {
            return Err(FerroError::InvalidParameter {
                name: "min_samples_split".into(),
                reason: "must be at least 2".into(),
            });
        }
        if self.min_samples_leaf < 1 {
            return Err(FerroError::InvalidParameter {
                name: "min_samples_leaf".into(),
                reason: "must be at least 1".into(),
            });
        }

        // Collect the sorted, deduplicated set of class labels.
        let mut classes: Vec<usize> = y.iter().copied().collect();
        classes.sort_unstable();
        classes.dedup();
        let n_classes = classes.len();

        // Remap labels onto contiguous indices 0..n_classes.
        let y_mapped: Vec<usize> = y
            .iter()
            .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
            .collect();

        let indices: Vec<usize> = (0..n_samples).collect();

        let max_features_n = resolve_max_features(self.max_features, n_features);

        let mut rng = if let Some(seed) = self.random_state {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_os_rng()
        };

        let data = ClassificationData {
            x,
            y: &y_mapped,
            n_classes,
            feature_indices: None,
            criterion: self.criterion,
        };
        let params = TreeParams {
            max_depth: self.max_depth,
            min_samples_split: self.min_samples_split,
            min_samples_leaf: self.min_samples_leaf,
        };

        let mut nodes: Vec<Node<F>> = Vec::new();
        build_extra_classification_tree(
            &data,
            &indices,
            &mut nodes,
            0,
            &params,
            n_features,
            max_features_n,
            &mut rng,
        );

        let feature_importances = compute_feature_importances(&nodes, n_features, n_samples);

        Ok(FittedExtraTreeClassifier {
            nodes,
            classes,
            n_features,
            feature_importances,
        })
    }
}

impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedExtraTreeClassifier<F> {
    type Output = Array1<usize>;
    type Error = FerroError;

    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }
        let n_samples = x.nrows();
        let mut predictions = Array1::zeros(n_samples);
        for i in 0..n_samples {
            let row = x.row(i);
            let leaf = traverse(&self.nodes, &row);
            if let Node::Leaf { value, .. } = self.nodes[leaf] {
                // Leaves store the internal class index; map it back to the
                // original label so predictions are comparable with `y`.
                predictions[i] = self.classes[float_to_usize(value)];
            }
        }
        Ok(predictions)
    }
}

impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedExtraTreeClassifier<F> {
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}

impl<F: Float + Send + Sync + 'static> HasClasses for FittedExtraTreeClassifier<F> {
    fn classes(&self) -> &[usize] {
        &self.classes
    }

    fn n_classes(&self) -> usize {
        self.classes.len()
    }
}

impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
    for ExtraTreeClassifier<F>
{
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        // Pipeline targets arrive as floats; interpret them as class labels.
        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
        let fitted = self.fit(x, &y_usize)?;
        Ok(Box::new(FittedExtraTreeClassifierPipelineAdapter(fitted)))
    }
}

/// Adapter exposing a fitted classifier through the float-valued pipeline API.
struct FittedExtraTreeClassifierPipelineAdapter<F: Float + Send + Sync + 'static>(
    FittedExtraTreeClassifier<F>,
);

impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedExtraTreeClassifierPipelineAdapter<F>
{
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        let preds = self.0.predict(x)?;
        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or_else(F::nan)))
    }
}

/// An extremely randomized decision tree regressor.
///
/// As in the classifier, each candidate split uses a threshold drawn uniformly
/// at random between the feature's minimum and maximum; the best of these
/// random splits is kept.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtraTreeRegressor<F> {
    /// Maximum tree depth (`None` means unlimited).
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required in each leaf.
    pub min_samples_leaf: usize,
    /// Strategy for the number of features considered at each split.
    pub max_features: MaxFeatures,
    /// Criterion used to score candidate splits.
    pub criterion: RegressionCriterion,
    /// Seed for the random number generator (`None` uses OS entropy).
    pub random_state: Option<u64>,
    _marker: std::marker::PhantomData<F>,
}

impl<F: Float> ExtraTreeRegressor<F> {
    /// Creates a regressor with default hyperparameters.
    #[must_use]
    pub fn new() -> Self {
        Self {
            max_depth: None,
            min_samples_split: 2,
            min_samples_leaf: 1,
            max_features: MaxFeatures::All,
            criterion: RegressionCriterion::Mse,
            random_state: None,
            _marker: std::marker::PhantomData,
        }
    }

    /// Sets the maximum tree depth.
    #[must_use]
    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
        self.max_depth = max_depth;
        self
    }

    /// Sets the minimum number of samples required to split a node.
    #[must_use]
    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.min_samples_split = min_samples_split;
        self
    }

    /// Sets the minimum number of samples required in each leaf.
    #[must_use]
    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.min_samples_leaf = min_samples_leaf;
        self
    }

    /// Sets the strategy for the number of features considered per split.
    #[must_use]
    pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
        self.max_features = max_features;
        self
    }

    /// Sets the split criterion.
    #[must_use]
    pub fn with_criterion(mut self, criterion: RegressionCriterion) -> Self {
        self.criterion = criterion;
        self
    }

    /// Sets the random seed for reproducible trees.
    #[must_use]
    pub fn with_random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }
}

impl<F: Float> Default for ExtraTreeRegressor<F> {
    fn default() -> Self {
        Self::new()
    }
}

/// A fitted [`ExtraTreeRegressor`].
#[derive(Debug, Clone)]
pub struct FittedExtraTreeRegressor<F> {
    /// Flattened tree nodes; index 0 is the root.
    nodes: Vec<Node<F>>,
    /// Number of features the model was fitted on.
    n_features: usize,
    /// Normalized impurity-based feature importances.
    feature_importances: Array1<F>,
}

impl<F: Float + Send + Sync + 'static> FittedExtraTreeRegressor<F> {
    /// Returns the flattened tree nodes.
    #[must_use]
    pub fn nodes(&self) -> &[Node<F>] {
        &self.nodes
    }

    /// Returns the number of features the model was fitted on.
    #[must_use]
    pub fn n_features(&self) -> usize {
        self.n_features
    }

    /// Returns the R² (coefficient of determination) of the predictions on
    /// `x` against `y`.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `y` does not have one target
    /// per row of `x`.
    pub fn score(&self, x: &Array2<F>, y: &Array1<F>) -> Result<F, FerroError> {
        if x.nrows() != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![x.nrows()],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        let preds = self.predict(x)?;
        Ok(crate::r2_score(&preds, y))
    }
}

impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for ExtraTreeRegressor<F> {
    type Fitted = FittedExtraTreeRegressor<F>;
    type Error = FerroError;

    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedExtraTreeRegressor<F>, FerroError> {
        let (n_samples, n_features) = x.dim();

        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        if n_samples == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "ExtraTreeRegressor requires at least one sample".into(),
            });
        }
        if self.min_samples_split < 2 {
            return Err(FerroError::InvalidParameter {
                name: "min_samples_split".into(),
                reason: "must be at least 2".into(),
            });
        }
        if self.min_samples_leaf < 1 {
            return Err(FerroError::InvalidParameter {
                name: "min_samples_leaf".into(),
                reason: "must be at least 1".into(),
            });
        }

        let indices: Vec<usize> = (0..n_samples).collect();
        let max_features_n = resolve_max_features(self.max_features, n_features);

        let mut rng = if let Some(seed) = self.random_state {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_os_rng()
        };

        let data = RegressionData {
            x,
            y,
            feature_indices: None,
        };
        let params = TreeParams {
            max_depth: self.max_depth,
            min_samples_split: self.min_samples_split,
            min_samples_leaf: self.min_samples_leaf,
        };

        let mut nodes: Vec<Node<F>> = Vec::new();
        build_extra_regression_tree(
            &data,
            &indices,
            &mut nodes,
            0,
            &params,
            n_features,
            max_features_n,
            &mut rng,
        );

        let feature_importances = compute_feature_importances(&nodes, n_features, n_samples);

        Ok(FittedExtraTreeRegressor {
            nodes,
            n_features,
            feature_importances,
        })
    }
}

impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedExtraTreeRegressor<F> {
    type Output = Array1<F>;
    type Error = FerroError;

    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }
        let n_samples = x.nrows();
        let mut predictions = Array1::zeros(n_samples);
        for i in 0..n_samples {
            let row = x.row(i);
            let leaf = traverse(&self.nodes, &row);
            if let Node::Leaf { value, .. } = self.nodes[leaf] {
                predictions[i] = value;
            }
        }
        Ok(predictions)
    }
}

impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedExtraTreeRegressor<F> {
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}

impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for ExtraTreeRegressor<F> {
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let fitted = self.fit(x, y)?;
        Ok(Box::new(fitted))
    }
}

impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F> for FittedExtraTreeRegressor<F> {
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        self.predict(x)
    }
}

/// Resolves a [`MaxFeatures`] strategy to a concrete feature count, clamped to
/// `1..=n_features`.
fn resolve_max_features(strategy: MaxFeatures, n_features: usize) -> usize {
    let result = match strategy {
        MaxFeatures::Sqrt => (n_features as f64).sqrt().ceil() as usize,
        MaxFeatures::Log2 => (n_features as f64).log2().ceil().max(1.0) as usize,
        MaxFeatures::All => n_features,
        MaxFeatures::Fixed(n) => n.min(n_features),
        MaxFeatures::Fraction(f) => ((n_features as f64) * f).ceil() as usize,
    };
    result.clamp(1, n_features)
}

/// Rounds a leaf value to the nearest integer class index.
fn float_to_usize<F: Float>(v: F) -> usize {
    v.to_f64().map_or(0, |f| f.round() as usize)
}

/// Draws a threshold uniformly at random from `[min_val, max_val]`.
fn random_threshold<F: Float>(rng: &mut StdRng, min_val: F, max_val: F) -> F {
    use rand::RngCore;
    // Map a u64 to a uniform value in [0, 1].
    let u = (rng.next_u64() as f64) / (u64::MAX as f64);
    let range = max_val - min_val;
    min_val + F::from(u).unwrap() * range
}

/// Gini impurity `1 - Σ pᵢ²` of a class-count histogram.
fn gini_impurity<F: Float>(class_counts: &[usize], total: usize) -> F {
    if total == 0 {
        return F::zero();
    }
    let total_f = F::from(total).unwrap();
    let mut impurity = F::one();
    for &count in class_counts {
        let p = F::from(count).unwrap() / total_f;
        impurity = impurity - p * p;
    }
    impurity
}

/// Entropy `-Σ pᵢ ln pᵢ` (natural log) of a class-count histogram.
fn entropy_impurity<F: Float>(class_counts: &[usize], total: usize) -> F {
    if total == 0 {
        return F::zero();
    }
    let total_f = F::from(total).unwrap();
    let mut ent = F::zero();
    for &count in class_counts {
        if count > 0 {
            let p = F::from(count).unwrap() / total_f;
            ent = ent - p * p.ln();
        }
    }
    ent
}

/// Dispatches to the impurity function selected by `criterion`.
fn compute_impurity<F: Float>(
    class_counts: &[usize],
    total: usize,
    criterion: ClassificationCriterion,
) -> F {
    match criterion {
        ClassificationCriterion::Gini => gini_impurity(class_counts, total),
        ClassificationCriterion::Entropy => entropy_impurity(class_counts, total),
    }
}

/// Appends a leaf holding the majority class and the full class distribution,
/// returning its index in `nodes`.
fn make_classification_leaf<F: Float>(
    nodes: &mut Vec<Node<F>>,
    class_counts: &[usize],
    n_classes: usize,
    n_samples: usize,
) -> usize {
    let majority_class = class_counts
        .iter()
        .enumerate()
        .max_by_key(|&(_, &count)| count)
        .map_or(0, |(i, _)| i);

    let total_f = if n_samples > 0 {
        F::from(n_samples).unwrap()
    } else {
        F::one()
    };
    let distribution: Vec<F> = (0..n_classes)
        .map(|c| F::from(class_counts[c]).unwrap() / total_f)
        .collect();

    let idx = nodes.len();
    nodes.push(Node::Leaf {
        value: F::from(majority_class).unwrap(),
        class_distribution: Some(distribution),
        n_samples,
    });
    idx
}

/// Mean of `y` over the given sample indices (zero for an empty slice).
fn mean_value<F: Float>(y: &Array1<F>, indices: &[usize]) -> F {
    if indices.is_empty() {
        return F::zero();
    }
    let sum: F = indices.iter().map(|&i| y[i]).fold(F::zero(), |a, b| a + b);
    sum / F::from(indices.len()).unwrap()
}

/// Recursively builds an extremely randomized classification subtree,
/// returning the index of its root in `nodes`.
#[allow(clippy::too_many_arguments)]
fn build_extra_classification_tree<F: Float>(
    data: &ClassificationData<'_, F>,
    indices: &[usize],
    nodes: &mut Vec<Node<F>>,
    depth: usize,
    params: &TreeParams,
    n_features: usize,
    max_features_n: usize,
    rng: &mut StdRng,
) -> usize {
    let n = indices.len();

    let mut class_counts = vec![0usize; data.n_classes];
    for &i in indices {
        class_counts[data.y[i]] += 1;
    }

    // Stop if the node is too small, too deep, or already pure.
    let should_stop = n < params.min_samples_split
        || params.max_depth.is_some_and(|d| depth >= d)
        || class_counts.iter().filter(|&&c| c > 0).count() <= 1;

    if should_stop {
        return make_classification_leaf(nodes, &class_counts, data.n_classes, n);
    }

    let best = find_random_classification_split(
        data,
        indices,
        params.min_samples_leaf,
        n_features,
        max_features_n,
        rng,
    );

    if let Some((best_feature, best_threshold, best_impurity_decrease)) = best {
        let (left_indices, right_indices): (Vec<usize>, Vec<usize>) = indices
            .iter()
            .partition(|&&i| data.x[[i, best_feature]] <= best_threshold);

        if left_indices.len() < params.min_samples_leaf
            || right_indices.len() < params.min_samples_leaf
        {
            return make_classification_leaf(nodes, &class_counts, data.n_classes, n);
        }

        let node_idx = nodes.len();
        // Push a placeholder; it is overwritten with the real split node once
        // both children have been built.
        nodes.push(Node::Leaf {
            value: F::zero(),
            class_distribution: None,
            n_samples: 0,
        });

        let left_idx = build_extra_classification_tree(
            data,
            &left_indices,
            nodes,
            depth + 1,
            params,
            n_features,
            max_features_n,
            rng,
        );
        let right_idx = build_extra_classification_tree(
            data,
            &right_indices,
            nodes,
            depth + 1,
            params,
            n_features,
            max_features_n,
            rng,
        );

        nodes[node_idx] = Node::Split {
            feature: best_feature,
            threshold: best_threshold,
            left: left_idx,
            right: right_idx,
            impurity_decrease: best_impurity_decrease,
            n_samples: n,
        };

        node_idx
    } else {
        make_classification_leaf(nodes, &class_counts, data.n_classes, n)
    }
}

/// Draws one uniformly random threshold per sampled feature and returns the
/// best `(feature, threshold, impurity_decrease * n)` triple, or `None` if no
/// candidate split reduces the parent impurity.
#[allow(clippy::too_many_arguments)]
fn find_random_classification_split<F: Float>(
    data: &ClassificationData<'_, F>,
    indices: &[usize],
    min_samples_leaf: usize,
    n_features: usize,
    max_features_n: usize,
    rng: &mut StdRng,
) -> Option<(usize, F, F)> {
    let n = indices.len();
    let n_f = F::from(n).unwrap();

    let mut parent_counts = vec![0usize; data.n_classes];
    for &i in indices {
        parent_counts[data.y[i]] += 1;
    }
    let parent_impurity = compute_impurity::<F>(&parent_counts, n, data.criterion);

    let mut best_score = F::neg_infinity();
    let mut best_feature = 0;
    let mut best_threshold = F::zero();

    // Sample the candidate features (restricted to `feature_indices` when the
    // tree is grown inside an ensemble).
    let feature_subset: Vec<usize> = if let Some(feat_indices) = data.feature_indices {
        let k = max_features_n.min(feat_indices.len());
        rand_sample_indices(rng, feat_indices.len(), k)
            .into_vec()
            .into_iter()
            .map(|i| feat_indices[i])
            .collect()
    } else {
        let k = max_features_n.min(n_features);
        rand_sample_indices(rng, n_features, k).into_vec()
    };

    for feat in feature_subset {
        // Find the feature's value range on this node's samples.
        let mut feat_min = F::infinity();
        let mut feat_max = F::neg_infinity();
        for &i in indices {
            let val = data.x[[i, feat]];
            if val < feat_min {
                feat_min = val;
            }
            if val > feat_max {
                feat_max = val;
            }
        }

        // A constant feature cannot be split.
        if feat_min >= feat_max {
            continue;
        }

        let threshold = random_threshold(rng, feat_min, feat_max);

        let mut left_counts = vec![0usize; data.n_classes];
        let mut right_counts = vec![0usize; data.n_classes];
        let mut left_n = 0usize;

        for &i in indices {
            let cls = data.y[i];
            if data.x[[i, feat]] <= threshold {
                left_counts[cls] += 1;
                left_n += 1;
            } else {
                right_counts[cls] += 1;
            }
        }

        let right_n = n - left_n;
        if left_n < min_samples_leaf || right_n < min_samples_leaf {
            continue;
        }

        let left_impurity = compute_impurity::<F>(&left_counts, left_n, data.criterion);
        let right_impurity = compute_impurity::<F>(&right_counts, right_n, data.criterion);
        let left_weight = F::from(left_n).unwrap() / n_f;
        let right_weight = F::from(right_n).unwrap() / n_f;
        let weighted_child_impurity = left_weight * left_impurity + right_weight * right_impurity;
        let impurity_decrease = parent_impurity - weighted_child_impurity;

        if impurity_decrease > best_score {
            best_score = impurity_decrease;
            best_feature = feat;
            best_threshold = threshold;
        }
    }

    if best_score > F::zero() {
        Some((best_feature, best_threshold, best_score * n_f))
    } else {
        None
    }
}

/// Recursively builds an extremely randomized regression subtree, returning
/// the index of its root in `nodes`.
#[allow(clippy::too_many_arguments)]
fn build_extra_regression_tree<F: Float>(
    data: &RegressionData<'_, F>,
    indices: &[usize],
    nodes: &mut Vec<Node<F>>,
    depth: usize,
    params: &TreeParams,
    n_features: usize,
    max_features_n: usize,
    rng: &mut StdRng,
) -> usize {
    let n = indices.len();
    let mean = mean_value(data.y, indices);

    let should_stop = n < params.min_samples_split || params.max_depth.is_some_and(|d| depth >= d);

    if should_stop {
        let idx = nodes.len();
        nodes.push(Node::Leaf {
            value: mean,
            class_distribution: None,
            n_samples: n,
        });
        return idx;
    }

    let parent_sum_sq: F = indices
        .iter()
        .map(|&i| {
            let diff = data.y[i] - mean;
            diff * diff
        })
        .fold(F::zero(), |a, b| a + b);
    let parent_mse = parent_sum_sq / F::from(n).unwrap();

    // A (near-)constant target needs no further splitting.
    if parent_mse <= F::epsilon() {
        let idx = nodes.len();
        nodes.push(Node::Leaf {
            value: mean,
            class_distribution: None,
            n_samples: n,
        });
        return idx;
    }

    let best = find_random_regression_split(
        data,
        indices,
        params.min_samples_leaf,
        n_features,
        max_features_n,
        rng,
    );

    if let Some((best_feature, best_threshold, best_impurity_decrease)) = best {
        let (left_indices, right_indices): (Vec<usize>, Vec<usize>) = indices
            .iter()
            .partition(|&&i| data.x[[i, best_feature]] <= best_threshold);

        if left_indices.len() < params.min_samples_leaf
            || right_indices.len() < params.min_samples_leaf
        {
            let idx = nodes.len();
            nodes.push(Node::Leaf {
                value: mean,
                class_distribution: None,
                n_samples: n,
            });
            return idx;
        }

        let node_idx = nodes.len();
        // Push a placeholder; it is overwritten with the real split node once
        // both children have been built.
        nodes.push(Node::Leaf {
            value: F::zero(),
            class_distribution: None,
            n_samples: 0,
        });

        let left_idx = build_extra_regression_tree(
            data,
            &left_indices,
            nodes,
            depth + 1,
            params,
            n_features,
            max_features_n,
            rng,
        );
        let right_idx = build_extra_regression_tree(
            data,
            &right_indices,
            nodes,
            depth + 1,
            params,
            n_features,
            max_features_n,
            rng,
        );

        nodes[node_idx] = Node::Split {
            feature: best_feature,
            threshold: best_threshold,
            left: left_idx,
            right: right_idx,
            impurity_decrease: best_impurity_decrease,
            n_samples: n,
        };

        node_idx
    } else {
        let idx = nodes.len();
        nodes.push(Node::Leaf {
            value: mean,
            class_distribution: None,
            n_samples: n,
        });
        idx
    }
}

/// Draws one uniformly random threshold per sampled feature and returns the
/// best `(feature, threshold, mse_decrease * n)` triple, or `None` if no
/// candidate split reduces the parent MSE.
#[allow(clippy::too_many_arguments)]
fn find_random_regression_split<F: Float>(
    data: &RegressionData<'_, F>,
    indices: &[usize],
    min_samples_leaf: usize,
    n_features: usize,
    max_features_n: usize,
    rng: &mut StdRng,
) -> Option<(usize, F, F)> {
    let n = indices.len();
    let n_f = F::from(n).unwrap();

    // Totals over the parent node; the right child's sums are derived from the
    // left-side running sums, so each candidate split needs only one pass.
    let parent_sum: F = indices
        .iter()
        .map(|&i| data.y[i])
        .fold(F::zero(), |a, b| a + b);
    let parent_sum_sq: F = indices
        .iter()
        .map(|&i| data.y[i] * data.y[i])
        .fold(F::zero(), |a, b| a + b);
    let parent_mse = parent_sum_sq / n_f - (parent_sum / n_f) * (parent_sum / n_f);

    let mut best_score = F::neg_infinity();
    let mut best_feature = 0;
    let mut best_threshold = F::zero();

    // Sample the candidate features (restricted to `feature_indices` when the
    // tree is grown inside an ensemble).
    let feature_subset: Vec<usize> = if let Some(feat_indices) = data.feature_indices {
        let k = max_features_n.min(feat_indices.len());
        rand_sample_indices(rng, feat_indices.len(), k)
            .into_vec()
            .into_iter()
            .map(|i| feat_indices[i])
            .collect()
    } else {
        let k = max_features_n.min(n_features);
        rand_sample_indices(rng, n_features, k).into_vec()
    };

    for feat in feature_subset {
        // Find the feature's value range on this node's samples.
        let mut feat_min = F::infinity();
        let mut feat_max = F::neg_infinity();
        for &i in indices {
            let val = data.x[[i, feat]];
            if val < feat_min {
                feat_min = val;
            }
            if val > feat_max {
                feat_max = val;
            }
        }

        // A constant feature cannot be split.
        if feat_min >= feat_max {
            continue;
        }

        let threshold = random_threshold(rng, feat_min, feat_max);

        let mut left_sum = F::zero();
        let mut left_sum_sq = F::zero();
        let mut left_n: usize = 0;

        for &i in indices {
            if data.x[[i, feat]] <= threshold {
                let val = data.y[i];
                left_sum = left_sum + val;
                left_sum_sq = left_sum_sq + val * val;
                left_n += 1;
            }
        }

        let right_n = n - left_n;
        if left_n < min_samples_leaf || right_n < min_samples_leaf {
            continue;
        }

        let left_n_f = F::from(left_n).unwrap();
        let right_n_f = F::from(right_n).unwrap();

        let left_mean = left_sum / left_n_f;
        let left_mse = left_sum_sq / left_n_f - left_mean * left_mean;

        let right_sum = parent_sum - left_sum;
        let right_sum_sq = parent_sum_sq - left_sum_sq;
        let right_mean = right_sum / right_n_f;
        let right_mse = right_sum_sq / right_n_f - right_mean * right_mean;

        let weighted_child_mse = (left_n_f * left_mse + right_n_f * right_mse) / n_f;
        let mse_decrease = parent_mse - weighted_child_mse;

        if mse_decrease > best_score {
            best_score = mse_decrease;
            best_feature = feat;
            best_threshold = threshold;
        }
    }

    if best_score > F::zero() {
        Some((best_feature, best_threshold, best_score * n_f))
    } else {
        None
    }
}

/// Builds a single extremely randomized classification tree for use inside an
/// ensemble, returning the flattened node list.
#[allow(clippy::too_many_arguments)]
pub(crate) fn build_extra_classification_tree_for_ensemble<F: Float>(
    x: &Array2<F>,
    y: &[usize],
    n_classes: usize,
    indices: &[usize],
    feature_indices: Option<&[usize]>,
    params: &TreeParams,
    criterion: ClassificationCriterion,
    n_features: usize,
    max_features_n: usize,
    rng: &mut StdRng,
) -> Vec<Node<F>> {
    let data = ClassificationData {
        x,
        y,
        n_classes,
        feature_indices,
        criterion,
    };
    let mut nodes = Vec::new();
    build_extra_classification_tree(
        &data,
        indices,
        &mut nodes,
        0,
        params,
        n_features,
        max_features_n,
        rng,
    );
    nodes
}

/// Builds a single extremely randomized regression tree for use inside an
/// ensemble, returning the flattened node list.
#[allow(clippy::too_many_arguments)]
pub(crate) fn build_extra_regression_tree_for_ensemble<F: Float>(
    x: &Array2<F>,
    y: &Array1<F>,
    indices: &[usize],
    feature_indices: Option<&[usize]>,
    params: &TreeParams,
    n_features: usize,
    max_features_n: usize,
    rng: &mut StdRng,
) -> Vec<Node<F>> {
    let data = RegressionData {
        x,
        y,
        feature_indices,
    };
    let mut nodes = Vec::new();
    build_extra_regression_tree(
        &data,
        indices,
        &mut nodes,
        0,
        params,
        n_features,
        max_features_n,
        rng,
    );
    nodes
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    #[test]
    fn test_extra_classifier_simple_binary() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = ExtraTreeClassifier::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 6);
        for i in 0..3 {
            assert_eq!(preds[i], 0, "sample {i} should be class 0");
        }
        for i in 3..6 {
            assert_eq!(preds[i], 1, "sample {i} should be class 1");
        }
    }
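
    // Added sketch (not from the original suite): with the default settings the
    // tree keeps splitting until its leaves are pure, so `score` (mean
    // accuracy) on this separable training set should be exactly 1.0.
    #[test]
    fn test_extra_classifier_score_training_accuracy() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = ExtraTreeClassifier::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let score = fitted.score(&x, &y).unwrap();
        assert_relative_eq!(score, 1.0, epsilon = 1e-12);
    }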

    #[test]
    fn test_extra_classifier_single_class() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0];

        let model = ExtraTreeClassifier::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds, array![0, 0, 0]);
    }

    #[test]
    fn test_extra_classifier_max_depth_1() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = ExtraTreeClassifier::<f64>::new()
            .with_max_depth(Some(1))
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let _preds = fitted.predict(&x).unwrap();

        // A depth-1 tree is one split node plus two leaves.
        assert_eq!(fitted.nodes().len(), 3);
    }

    #[test]
    fn test_extra_classifier_predict_proba() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = ExtraTreeClassifier::<f64>::new()
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();

        assert_eq!(proba.dim(), (6, 2));
        // Each row of probabilities must sum to one.
        for i in 0..6 {
            let row_sum = proba.row(i).sum();
            assert_relative_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }
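
    // Added sketch, assuming `predict_log_proba` is the element-wise (natural)
    // log of `predict_proba`: exponentiating it should recover the
    // probabilities within floating-point tolerance.
    #[test]
    fn test_extra_classifier_log_proba_matches_proba() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = ExtraTreeClassifier::<f64>::new()
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();
        let log_proba = fitted.predict_log_proba(&x).unwrap();

        for i in 0..6 {
            for j in 0..2 {
                assert_relative_eq!(log_proba[[i, j]].exp(), proba[[i, j]], epsilon = 1e-9);
            }
        }
    }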

    #[test]
    fn test_extra_classifier_feature_importances() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0, 1.0, 5.0, 1.0, 6.0, 1.0, 7.0, 1.0, 8.0, 1.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = ExtraTreeClassifier::<f64>::new()
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        let total: f64 = importances.sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-10);
        // Feature 0 is informative; feature 1 is constant.
        assert!(importances[0] > importances[1]);
    }

    #[test]
    fn test_extra_classifier_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0];

        let model = ExtraTreeClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_extra_classifier_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);

        let model = ExtraTreeClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_extra_classifier_invalid_min_samples_split() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 1];

        let model = ExtraTreeClassifier::<f64>::new().with_min_samples_split(1);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_extra_classifier_classes() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 2, 2, 2];

        let model = ExtraTreeClassifier::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.classes(), &[0, 2]);
        assert_eq!(fitted.n_classes(), 2);
    }

    #[test]
    fn test_extra_classifier_predict_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = ExtraTreeClassifier::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        let x_wrong = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_wrong).is_err());
    }

    #[test]
    fn test_extra_classifier_f32() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![
                1.0f32, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = ExtraTreeClassifier::<f32>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_extra_classifier_deterministic() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model1 = ExtraTreeClassifier::<f64>::new().with_random_state(123);
        let model2 = ExtraTreeClassifier::<f64>::new().with_random_state(123);

        let fitted1 = model1.fit(&x, &y).unwrap();
        let fitted2 = model2.fit(&x, &y).unwrap();

        let preds1 = fitted1.predict(&x).unwrap();
        let preds2 = fitted2.predict(&x).unwrap();

        assert_eq!(preds1, preds2);
    }
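
    // Added sketch of the `PipelineEstimator` path: `fit_pipeline` interprets
    // the float targets as class labels and `predict_pipeline` returns the
    // predicted labels as floats.
    #[test]
    fn test_extra_classifier_pipeline_roundtrip() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let model = ExtraTreeClassifier::<f64>::new()
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();

        assert_eq!(preds.len(), 6);
        for i in 0..6 {
            assert_relative_eq!(preds[i], y[i], epsilon = 1e-12);
        }
    }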

    #[test]
    fn test_extra_regressor_simple() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let model = ExtraTreeRegressor::<f64>::new()
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 6);
        for i in 0..6 {
            assert_relative_eq!(preds[i], y[i], epsilon = 1.0);
        }
    }

    #[test]
    fn test_extra_regressor_constant_target() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![5.0, 5.0, 5.0, 5.0];

        let model = ExtraTreeRegressor::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for &p in &preds {
            assert_relative_eq!(p, 5.0, epsilon = 1e-10);
        }
    }
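
    // Added sketch: a fully grown tree reproduces its training targets exactly,
    // so the R² score on the training data should be 1.0.
    #[test]
    fn test_extra_regressor_score_training_r2() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let model = ExtraTreeRegressor::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let score = fitted.score(&x, &y).unwrap();
        assert_relative_eq!(score, 1.0, epsilon = 1e-9);
    }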

    #[test]
    fn test_extra_regressor_feature_importances() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let model = ExtraTreeRegressor::<f64>::new()
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        let total: f64 = importances.sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-10);
        // Feature 0 is informative; feature 1 is constant.
        assert!(importances[0] > importances[1]);
    }

    #[test]
    fn test_extra_regressor_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0];

        let model = ExtraTreeRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_extra_regressor_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);

        let model = ExtraTreeRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_extra_regressor_predict_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = ExtraTreeRegressor::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        let x_wrong = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_wrong).is_err());
    }

    #[test]
    fn test_extra_regressor_max_depth() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let model = ExtraTreeRegressor::<f64>::new()
            .with_max_depth(Some(1))
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        // A depth-1 tree is one split node plus two leaves.
        assert_eq!(fitted.nodes().len(), 3);
    }

    #[test]
    fn test_extra_regressor_deterministic() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let model1 = ExtraTreeRegressor::<f64>::new().with_random_state(99);
        let model2 = ExtraTreeRegressor::<f64>::new().with_random_state(99);

        let fitted1 = model1.fit(&x, &y).unwrap();
        let fitted2 = model2.fit(&x, &y).unwrap();

        let preds1 = fitted1.predict(&x).unwrap();
        let preds2 = fitted2.predict(&x).unwrap();

        for i in 0..6 {
            assert_relative_eq!(preds1[i], preds2[i], epsilon = 1e-12);
        }
    }
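
    // Added sketch of the regressor's pipeline path, which delegates directly
    // to `fit` and `predict`.
    #[test]
    fn test_extra_regressor_pipeline_roundtrip() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let model = ExtraTreeRegressor::<f64>::new().with_random_state(7);
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();

        assert_eq!(preds.len(), 4);
        for i in 0..4 {
            assert_relative_eq!(preds[i], y[i], epsilon = 1e-9);
        }
    }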

    #[test]
    fn test_extra_regressor_f32() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0f32, 2.0, 3.0, 4.0];

        let model = ExtraTreeRegressor::<f32>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }
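
    // Added sketch pinning down `resolve_max_features`; the expected counts
    // follow the ceil-based formulas in this file, clamped to 1..=n_features.
    #[test]
    fn test_resolve_max_features_strategies() {
        assert_eq!(resolve_max_features(MaxFeatures::Sqrt, 10), 4); // ceil(sqrt(10))
        assert_eq!(resolve_max_features(MaxFeatures::Log2, 10), 4); // ceil(log2(10))
        assert_eq!(resolve_max_features(MaxFeatures::All, 10), 10);
        assert_eq!(resolve_max_features(MaxFeatures::Fixed(3), 10), 3);
        assert_eq!(resolve_max_features(MaxFeatures::Fixed(20), 10), 10); // capped at n_features
        assert_eq!(resolve_max_features(MaxFeatures::Fraction(0.25), 10), 3); // ceil(2.5)
        assert_eq!(resolve_max_features(MaxFeatures::Log2, 1), 1); // never below 1
    }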

    #[test]
    fn test_classifier_builder_methods() {
        let model = ExtraTreeClassifier::<f64>::new()
            .with_max_depth(Some(5))
            .with_min_samples_split(10)
            .with_min_samples_leaf(3)
            .with_max_features(MaxFeatures::Log2)
            .with_criterion(ClassificationCriterion::Entropy)
            .with_random_state(42);

        assert_eq!(model.max_depth, Some(5));
        assert_eq!(model.min_samples_split, 10);
        assert_eq!(model.min_samples_leaf, 3);
        assert_eq!(model.max_features, MaxFeatures::Log2);
        assert_eq!(model.criterion, ClassificationCriterion::Entropy);
        assert_eq!(model.random_state, Some(42));
    }

    #[test]
    fn test_regressor_builder_methods() {
        let model = ExtraTreeRegressor::<f64>::new()
            .with_max_depth(Some(10))
            .with_min_samples_split(5)
            .with_min_samples_leaf(2)
            .with_max_features(MaxFeatures::Fixed(3))
            .with_criterion(RegressionCriterion::Mse)
            .with_random_state(99);

        assert_eq!(model.max_depth, Some(10));
        assert_eq!(model.min_samples_split, 5);
        assert_eq!(model.min_samples_leaf, 2);
        assert_eq!(model.max_features, MaxFeatures::Fixed(3));
        assert_eq!(model.criterion, RegressionCriterion::Mse);
        assert_eq!(model.random_state, Some(99));
    }

    #[test]
    fn test_classifier_default() {
        let model = ExtraTreeClassifier::<f64>::default();
        assert_eq!(model.max_depth, None);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
        assert_eq!(model.max_features, MaxFeatures::Sqrt);
        assert_eq!(model.criterion, ClassificationCriterion::Gini);
        assert_eq!(model.random_state, None);
    }
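
    // Added sketch of the impurity helpers on hand-computed values: a 50/50
    // node has Gini impurity 0.5 and entropy ln(2); a pure node has zero of
    // both.
    #[test]
    fn test_impurity_helpers() {
        assert_relative_eq!(gini_impurity::<f64>(&[5, 5], 10), 0.5, epsilon = 1e-12);
        assert_relative_eq!(
            entropy_impurity::<f64>(&[5, 5], 10),
            std::f64::consts::LN_2,
            epsilon = 1e-12
        );
        assert_relative_eq!(gini_impurity::<f64>(&[10, 0], 10), 0.0, epsilon = 1e-12);
        assert_relative_eq!(entropy_impurity::<f64>(&[10, 0], 10), 0.0, epsilon = 1e-12);
    }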

    #[test]
    fn test_regressor_default() {
        let model = ExtraTreeRegressor::<f64>::default();
        assert_eq!(model.max_depth, None);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
        assert_eq!(model.max_features, MaxFeatures::All);
        assert_eq!(model.criterion, RegressionCriterion::Mse);
        assert_eq!(model.random_state, None);
    }
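
    // Added sketch: `random_threshold` must stay inside [min, max], and
    // `float_to_usize` rounds to the nearest integer.
    #[test]
    fn test_threshold_and_rounding_helpers() {
        let mut rng = StdRng::seed_from_u64(7);
        for _ in 0..100 {
            let t: f64 = random_threshold(&mut rng, 1.0, 2.0);
            assert!((1.0..=2.0).contains(&t));
        }
        assert_eq!(float_to_usize(2.6f64), 3);
        assert_eq!(float_to_usize(0.4f64), 0);
    }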
}