Skip to main content

anofox_ml_ensemble/
extra_trees_classifier.rs

1use anofox_ml_core::{Fit, Float, Predict, Result, RustMlError};
2use anofox_ml_trees::node::TreeNode;
3use anofox_ml_trees::split::{
4    compute_impurity, count_classes, find_random_split, leaf_value, SplitCriterion,
5};
6use ndarray::{Array1, Array2};
7use rand::rngs::StdRng;
8use rand::{Rng, SeedableRng};
9use rayon::prelude::*;
10
11/// Extra-Trees (Extremely Randomized Trees) classifier parameters (unfitted state).
12///
13/// Trains an ensemble of decision trees using random split thresholds instead of
14/// the best possible split at each node. Unlike Random Forests, Extra-Trees does
15/// **not** use bootstrap sampling — each tree is trained on the full dataset.
16/// However, each tree still considers a random subset of features at each split.
17///
18/// The randomization in split thresholds reduces variance further than Random
19/// Forests and can lead to smoother decision boundaries.
20#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
21pub struct ExtraTreesClassifier {
22    /// Number of trees in the forest.
23    pub n_estimators: usize,
24    /// Maximum depth of each tree.
25    pub max_depth: Option<usize>,
26    /// Minimum samples required to split a node.
27    pub min_samples_split: usize,
28    /// Minimum samples required in a leaf node.
29    pub min_samples_leaf: usize,
30    /// Number of features to consider per tree. If `None`, all features are used.
31    pub max_features: Option<usize>,
32    /// Random seed for reproducibility.
33    pub seed: u64,
34}
35
36impl ExtraTreesClassifier {
37    /// Create a new `ExtraTreesClassifier` with the given number of trees and default parameters.
38    pub fn new(n_estimators: usize) -> Self {
39        Self {
40            n_estimators,
41            max_depth: None,
42            min_samples_split: 2,
43            min_samples_leaf: 1,
44            max_features: None,
45            seed: 0,
46        }
47    }
48
49    /// Set the maximum depth of each tree.
50    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
51        self.max_depth = max_depth;
52        self
53    }
54
55    /// Set the minimum number of samples required to split a node.
56    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
57        self.min_samples_split = min_samples_split;
58        self
59    }
60
61    /// Set the minimum number of samples required in a leaf node.
62    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
63        self.min_samples_leaf = min_samples_leaf;
64        self
65    }
66
67    /// Set the number of features to consider per tree.
68    pub fn with_max_features(mut self, max_features: Option<usize>) -> Self {
69        self.max_features = max_features;
70        self
71    }
72
73    /// Set the random seed for reproducibility.
74    pub fn with_seed(mut self, seed: u64) -> Self {
75        self.seed = seed;
76        self
77    }
78}
79
80impl Default for ExtraTreesClassifier {
81    fn default() -> Self {
82        Self::new(100)
83    }
84}
85
86/// A single tree in the Extra-Trees ensemble together with its selected feature indices.
87#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
88#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
89struct ExtraForestTree<F: Float> {
90    tree: TreeNode<F>,
91    /// Indices of the features this tree was trained on (relative to the
92    /// original feature matrix). When `max_features` is `None` this contains
93    /// `0..n_features`.
94    feature_indices: Vec<usize>,
95    /// Number of features the tree was trained on.
96    n_features_tree: usize,
97}
98
99/// Fitted Extra-Trees classifier.
100#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
101#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
102pub struct FittedExtraTreesClassifier<F: Float> {
103    trees: Vec<ExtraForestTree<F>>,
104    n_features: usize,
105}
106
107impl<F: Float> Fit<F> for ExtraTreesClassifier {
108    type Fitted = FittedExtraTreesClassifier<F>;
109
110    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Self::Fitted> {
111        if x.nrows() != y.len() {
112            return Err(RustMlError::ShapeMismatch(format!(
113                "X has {} rows but y has {} elements",
114                x.nrows(),
115                y.len()
116            )));
117        }
118        if x.is_empty() {
119            return Err(RustMlError::EmptyInput("training data is empty".into()));
120        }
121        if self.n_estimators == 0 {
122            return Err(RustMlError::InvalidParameter(
123                "n_estimators must be > 0".into(),
124            ));
125        }
126
127        let n_features = x.ncols();
128
129        if let Some(k) = self.max_features {
130            if k == 0 || k > n_features {
131                return Err(RustMlError::InvalidParameter(format!(
132                    "max_features={k} is invalid for data with {n_features} features"
133                )));
134            }
135        }
136
137        let mut rng = StdRng::seed_from_u64(self.seed);
138
139        // Pre-generate feature indices and per-tree seeds for determinism.
140        // ExtraTrees does NOT use bootstrap — each tree trains on the full dataset.
141        let tree_plans: Vec<(Vec<usize>, u64)> = (0..self.n_estimators)
142            .map(|_| {
143                let feature_indices = select_features(n_features, self.max_features, &mut rng);
144                let tree_seed: u64 = rng.gen();
145                (feature_indices, tree_seed)
146            })
147            .collect();
148
149        let max_depth = self.max_depth;
150        let min_samples_split = self.min_samples_split;
151        let min_samples_leaf = self.min_samples_leaf;
152
153        // Train trees in parallel
154        let trees: Vec<ExtraForestTree<F>> = tree_plans
155            .into_par_iter()
156            .map(|(feature_indices, tree_seed)| {
157                // Build sub-matrix with only selected features (all rows — no bootstrap)
158                let x_sub = build_sub_matrix_cols(x, &feature_indices);
159                let n_features_tree = feature_indices.len();
160                let indices: Vec<usize> = (0..x.nrows()).collect();
161
162                let tree = build_extra_tree(
163                    &x_sub,
164                    y,
165                    &indices,
166                    0,
167                    max_depth,
168                    min_samples_split,
169                    min_samples_leaf,
170                    SplitCriterion::Gini,
171                    tree_seed,
172                );
173
174                ExtraForestTree {
175                    tree,
176                    feature_indices,
177                    n_features_tree,
178                }
179            })
180            .collect();
181
182        Ok(FittedExtraTreesClassifier { trees, n_features })
183    }
184}
185
186impl<F: Float> Predict<F> for FittedExtraTreesClassifier<F> {
187    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>> {
188        if x.ncols() != self.n_features {
189            return Err(RustMlError::ShapeMismatch(format!(
190                "expected {} features, got {}",
191                self.n_features,
192                x.ncols()
193            )));
194        }
195
196        let n_samples = x.nrows();
197        let n_trees = self.trees.len();
198
199        // Collect all tree predictions in parallel
200        let all_preds: Vec<Array1<F>> = self
201            .trees
202            .par_iter()
203            .map(|forest_tree| {
204                let sub_x = build_sub_matrix_cols(x, &forest_tree.feature_indices);
205                let preds: Vec<F> = sub_x
206                    .rows()
207                    .into_iter()
208                    .map(|row| forest_tree.tree.predict_one(row.as_slice().unwrap()))
209                    .collect();
210                Array1::from_vec(preds)
211            })
212            .collect();
213
214        // Aggregate votes per sample (majority vote)
215        let mut predictions = Vec::with_capacity(n_samples);
216        let mut votes = Vec::with_capacity(n_trees);
217        for i in 0..n_samples {
218            votes.clear();
219            for tree_pred in &all_preds {
220                votes.push(tree_pred[i]);
221            }
222            predictions.push(majority_vote(&votes));
223        }
224
225        Ok(Array1::from_vec(predictions))
226    }
227}
228
229impl<F: Float> FittedExtraTreesClassifier<F> {
230    /// Feature importances averaged across all trees and normalized to sum to 1.
231    ///
232    /// Each tree's importances are computed in its own (possibly reduced)
233    /// feature space, then mapped back to the original feature indices and
234    /// averaged.
235    pub fn feature_importances(&self) -> Array1<F> {
236        let mut importances = vec![F::zero(); self.n_features];
237        let n_trees = F::from_usize(self.trees.len()).unwrap();
238
239        for forest_tree in &self.trees {
240            let total_samples = tree_n_samples(&forest_tree.tree);
241            let tree_raw = forest_tree
242                .tree
243                .feature_importances(forest_tree.n_features_tree, total_samples);
244            // Normalize individual tree importances
245            let sum: F = tree_raw.iter().copied().fold(F::zero(), |a, b| a + b);
246            for (local_idx, &original_idx) in forest_tree.feature_indices.iter().enumerate() {
247                if sum > F::zero() {
248                    importances[original_idx] += (tree_raw[local_idx] / sum) / n_trees;
249                }
250            }
251        }
252
253        // Normalize so importances sum to 1
254        let sum: F = importances.iter().copied().fold(F::zero(), |a, b| a + b);
255        if sum > F::zero() {
256            Array1::from_vec(importances.into_iter().map(|v| v / sum).collect())
257        } else {
258            Array1::zeros(self.n_features)
259        }
260    }
261
262    /// Predict class probabilities for each sample.
263    ///
264    /// Returns a vector of vectors, where each inner vector contains
265    /// `(class_label, probability)` pairs sorted by class label.
266    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Vec<Vec<(F, F)>>> {
267        if x.ncols() != self.n_features {
268            return Err(RustMlError::ShapeMismatch(format!(
269                "expected {} features, got {}",
270                self.n_features,
271                x.ncols()
272            )));
273        }
274
275        let n_samples = x.nrows();
276        let n_trees = self.trees.len();
277        let n_trees_f = F::from_usize(n_trees).unwrap();
278
279        // Collect all tree predictions in parallel
280        let all_preds: Vec<Array1<F>> = self
281            .trees
282            .par_iter()
283            .map(|forest_tree| {
284                let sub_x = build_sub_matrix_cols(x, &forest_tree.feature_indices);
285                let preds: Vec<F> = sub_x
286                    .rows()
287                    .into_iter()
288                    .map(|row| forest_tree.tree.predict_one(row.as_slice().unwrap()))
289                    .collect();
290                Array1::from_vec(preds)
291            })
292            .collect();
293
294        // For each sample, count votes per class and convert to probabilities
295        let mut result = Vec::with_capacity(n_samples);
296        for i in 0..n_samples {
297            let mut class_votes: std::collections::HashMap<u64, (F, usize)> =
298                std::collections::HashMap::new();
299            for tree_pred in &all_preds {
300                let v = tree_pred[i];
301                let key = v.to_f64().unwrap().to_bits();
302                class_votes
303                    .entry(key)
304                    .and_modify(|e| e.1 += 1)
305                    .or_insert((v, 1));
306            }
307
308            let mut probs: Vec<(F, F)> = class_votes
309                .into_values()
310                .map(|(class, count)| (class, F::from_usize(count).unwrap() / n_trees_f))
311                .collect();
312            probs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
313            result.push(probs);
314        }
315
316        Ok(result)
317    }
318
319    /// Number of trees in the ensemble.
320    pub fn n_estimators(&self) -> usize {
321        self.trees.len()
322    }
323}
324
325// ---------------------------------------------------------------------------
326// Tree-building with random splits
327// ---------------------------------------------------------------------------
328
329/// Bundled parameters for recursive extra-tree building.
330#[allow(clippy::too_many_arguments)]
331fn build_extra_tree<F: Float>(
332    x: &Array2<F>,
333    y: &Array1<F>,
334    indices: &[usize],
335    depth: usize,
336    max_depth: Option<usize>,
337    min_samples_split: usize,
338    min_samples_leaf: usize,
339    criterion: SplitCriterion,
340    seed: u64,
341) -> TreeNode<F> {
342    let n_samples = indices.len();
343    let impurity = compute_impurity(y, indices, criterion);
344
345    // Check stopping criteria
346    let should_stop = n_samples < min_samples_split
347        || max_depth.is_some_and(|d| depth >= d)
348        || impurity < F::from_f64(1e-15).unwrap();
349
350    if should_stop {
351        return make_leaf(y, indices, criterion);
352    }
353
354    // Use a depth-dependent seed so left/right children get different randomness
355    let split_seed = seed
356        .wrapping_add(depth as u64)
357        .wrapping_mul(0x517CC1B727220A95);
358
359    match find_random_split(x, y, indices, criterion, min_samples_leaf, split_seed) {
360        Some(split) => {
361            let left = build_extra_tree(
362                x,
363                y,
364                &split.left_indices,
365                depth + 1,
366                max_depth,
367                min_samples_split,
368                min_samples_leaf,
369                criterion,
370                seed.wrapping_add(1),
371            );
372            let right = build_extra_tree(
373                x,
374                y,
375                &split.right_indices,
376                depth + 1,
377                max_depth,
378                min_samples_split,
379                min_samples_leaf,
380                criterion,
381                seed.wrapping_add(2),
382            );
383
384            TreeNode::Split {
385                feature_index: split.feature_index,
386                threshold: split.threshold,
387                left: Box::new(left),
388                right: Box::new(right),
389                n_samples,
390                impurity,
391            }
392        }
393        None => make_leaf(y, indices, criterion),
394    }
395}
396
397fn make_leaf<F: Float>(y: &Array1<F>, indices: &[usize], criterion: SplitCriterion) -> TreeNode<F> {
398    let value = leaf_value(y, indices, criterion);
399    let class_counts = match criterion {
400        SplitCriterion::Gini | SplitCriterion::Entropy => Some(count_classes(y, indices)),
401        SplitCriterion::Mse => None,
402    };
403    TreeNode::Leaf {
404        value,
405        n_samples: indices.len(),
406        class_counts,
407    }
408}
409
410fn tree_n_samples<F: Float>(node: &TreeNode<F>) -> usize {
411    match node {
412        TreeNode::Leaf { n_samples, .. } => *n_samples,
413        TreeNode::Split { n_samples, .. } => *n_samples,
414    }
415}
416
417// ---------------------------------------------------------------------------
418// Helper functions (same as in random_forest_classifier)
419// ---------------------------------------------------------------------------
420
421/// Select `k` distinct feature indices from `0..n_features` without replacement.
422/// If `max_features` is `None`, returns all feature indices.
423fn select_features(n_features: usize, max_features: Option<usize>, rng: &mut StdRng) -> Vec<usize> {
424    match max_features {
425        None => (0..n_features).collect(),
426        Some(k) => {
427            // Fisher-Yates partial shuffle
428            let mut indices: Vec<usize> = (0..n_features).collect();
429            for i in 0..k {
430                let j = rng.gen_range(i..n_features);
431                indices.swap(i, j);
432            }
433            indices.truncate(k);
434            indices.sort_unstable();
435            indices
436        }
437    }
438}
439
440/// Build a sub-matrix selecting all rows but only specific columns from `x`.
441/// Produces a guaranteed C-contiguous (standard layout) array so that
442/// `row.as_slice()` works in downstream predict calls.
443fn build_sub_matrix_cols<F: Float>(x: &Array2<F>, col_indices: &[usize]) -> Array2<F> {
444    let n_rows = x.nrows();
445    let n_cols = col_indices.len();
446    let mut data = Vec::with_capacity(n_rows * n_cols);
447    for i in 0..n_rows {
448        for &ci in col_indices {
449            data.push(x[[i, ci]]);
450        }
451    }
452    Array2::from_shape_vec((n_rows, n_cols), data).expect("shape matches data length")
453}
454
455/// Return the class that appears most frequently in `votes`.
456/// Uses HashMap with f64 bit representation for O(1) lookup per vote.
457#[inline]
458fn majority_vote<F: Float>(votes: &[F]) -> F {
459    use std::collections::HashMap;
460    let mut counts: HashMap<u64, (F, usize)> = HashMap::new();
461    for &v in votes {
462        let key = v.to_f64().unwrap().to_bits();
463        counts.entry(key).and_modify(|e| e.1 += 1).or_insert((v, 1));
464    }
465    counts
466        .into_values()
467        .max_by_key(|&(_, count)| count)
468        .unwrap()
469        .0
470}
471
#[cfg(test)]
mod tests {
    // Unit and property tests for the Extra-Trees classifier: fitting,
    // prediction, probabilities, feature importances and parameter validation.
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    #[test]
    fn test_basic_classification() {
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let et = ExtraTreesClassifier {
            n_estimators: 20,
            max_depth: Some(3),
            seed: 42,
            ..Default::default()
        };
        let fitted: FittedExtraTreesClassifier<f64> = et.fit(&x, &y).unwrap();

        let preds = fitted.predict(&x).unwrap();
        for (p, t) in preds.iter().zip(y.iter()) {
            assert_abs_diff_eq!(*p, *t, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_reproducibility() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let et = ExtraTreesClassifier {
            n_estimators: 10,
            seed: 123,
            ..Default::default()
        };

        let fitted1: FittedExtraTreesClassifier<f64> = et.fit(&x, &y).unwrap();
        let fitted2: FittedExtraTreesClassifier<f64> = et.fit(&x, &y).unwrap();

        let preds1 = fitted1.predict(&x).unwrap();
        let preds2 = fitted2.predict(&x).unwrap();

        for (a, b) in preds1.iter().zip(preds2.iter()) {
            assert_abs_diff_eq!(*a, *b, epsilon = 1e-15);
        }
    }

    #[test]
    fn test_max_features() {
        let x = array![
            [1.0, 100.0, 0.5],
            [2.0, 200.0, 0.6],
            [3.0, 300.0, 0.7],
            [10.0, 400.0, 0.8],
            [11.0, 500.0, 0.9],
            [12.0, 600.0, 1.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let et = ExtraTreesClassifier {
            n_estimators: 30,
            max_features: Some(2),
            seed: 99,
            ..Default::default()
        };
        let fitted: FittedExtraTreesClassifier<f64> = et.fit(&x, &y).unwrap();

        // Training accuracy should be high
        let preds = fitted.predict(&x).unwrap();
        for (p, t) in preds.iter().zip(y.iter()) {
            assert_abs_diff_eq!(*p, *t, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_feature_importances_sum_to_one() {
        let x = array![
            [1.0, 100.0],
            [2.0, 200.0],
            [3.0, 300.0],
            [4.0, 400.0],
            [5.0, 500.0],
            [6.0, 600.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let et = ExtraTreesClassifier {
            n_estimators: 20,
            seed: 7,
            ..Default::default()
        };
        let fitted: FittedExtraTreesClassifier<f64> = et.fit(&x, &y).unwrap();

        let importances = fitted.feature_importances();
        let sum: f64 = importances.iter().sum();
        assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_feature_importances_non_negative() {
        let x = array![
            [1.0, 100.0, 0.5],
            [2.0, 200.0, 0.6],
            [3.0, 300.0, 0.7],
            [10.0, 400.0, 0.8],
            [11.0, 500.0, 0.9],
            [12.0, 600.0, 1.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let et = ExtraTreesClassifier {
            n_estimators: 20,
            seed: 7,
            ..Default::default()
        };
        let fitted: FittedExtraTreesClassifier<f64> = et.fit(&x, &y).unwrap();

        let importances = fitted.feature_importances();
        for &imp in importances.iter() {
            assert!(
                imp >= 0.0,
                "feature importance must be non-negative, got {imp}"
            );
        }
    }

    #[test]
    fn test_n_estimators() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![0.0, 0.0, 1.0, 1.0];

        let et = ExtraTreesClassifier {
            n_estimators: 7,
            seed: 0,
            ..Default::default()
        };
        let fitted: FittedExtraTreesClassifier<f64> = et.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 7);
    }

    #[test]
    fn test_shape_mismatch_error() {
        let x = array![[1.0], [2.0]];
        let y = array![0.0, 1.0, 2.0];

        let et = ExtraTreesClassifier::default();
        let result: std::result::Result<FittedExtraTreesClassifier<f64>, _> = et.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_predict_wrong_features_error() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0.0, 1.0];

        let et = ExtraTreesClassifier {
            n_estimators: 5,
            seed: 0,
            ..Default::default()
        };
        let fitted: FittedExtraTreesClassifier<f64> = et.fit(&x, &y).unwrap();

        let x_bad = array![[1.0], [2.0]];
        let result = fitted.predict(&x_bad);
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_max_features() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0.0, 1.0];

        let et = ExtraTreesClassifier {
            n_estimators: 5,
            max_features: Some(5),
            seed: 0,
            ..Default::default()
        };
        let result: std::result::Result<FittedExtraTreesClassifier<f64>, _> = et.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_zero_estimators_error() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0.0, 1.0];

        let et = ExtraTreesClassifier {
            n_estimators: 0,
            seed: 0,
            ..Default::default()
        };
        let result: std::result::Result<FittedExtraTreesClassifier<f64>, _> = et.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_input_error() {
        let x: Array2<f64> = Array2::zeros((0, 2));
        let y: Array1<f64> = Array1::zeros(0);

        let et = ExtraTreesClassifier::default();
        let result: std::result::Result<FittedExtraTreesClassifier<f64>, _> = et.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_n_estimators_one() {
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let et = ExtraTreesClassifier {
            n_estimators: 1,
            max_depth: Some(3),
            seed: 42,
            ..Default::default()
        };
        let fitted: FittedExtraTreesClassifier<f64> = et.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 1);

        // A single tree should still produce valid predictions.
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), y.len());
    }

    #[test]
    fn test_predictions_are_valid_labels() {
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0],
            [20.0, 2.0],
            [21.0, 2.0],
            [22.0, 2.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0];

        let et = ExtraTreesClassifier {
            n_estimators: 30,
            max_depth: Some(5),
            seed: 42,
            ..Default::default()
        };
        let fitted: FittedExtraTreesClassifier<f64> = et.fit(&x, &y).unwrap();

        let preds = fitted.predict(&x).unwrap();
        let valid_labels: std::collections::HashSet<u64> = y.iter().map(|v| v.to_bits()).collect();
        for &p in preds.iter() {
            assert!(
                valid_labels.contains(&p.to_bits()),
                "prediction {p} is not a valid training label"
            );
        }
    }

    #[test]
    fn test_predict_proba() {
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let et = ExtraTreesClassifier {
            n_estimators: 20,
            max_depth: Some(3),
            seed: 42,
            ..Default::default()
        };
        let fitted: FittedExtraTreesClassifier<f64> = et.fit(&x, &y).unwrap();

        let proba = fitted.predict_proba(&x).unwrap();
        assert_eq!(proba.len(), x.nrows());

        // Each sample's probabilities should sum to 1
        for sample_probs in &proba {
            let sum: f64 = sample_probs.iter().map(|&(_, p)| p).sum();
            assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-10);
        }

        // For clearly separable data, class 0 samples should have high P(class=0)
        for sample_probs in &proba[..3] {
            let p_class0 = sample_probs
                .iter()
                .find(|&&(c, _)| (c - 0.0).abs() < 1e-10)
                .map(|&(_, p)| p)
                .unwrap_or(0.0);
            assert!(p_class0 > 0.5, "expected P(class=0) > 0.5, got {p_class0}");
        }
    }

    #[test]
    fn test_predict_proba_wrong_features_error() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0.0, 1.0];

        let et = ExtraTreesClassifier {
            n_estimators: 5,
            seed: 0,
            ..Default::default()
        };
        let fitted: FittedExtraTreesClassifier<f64> = et.fit(&x, &y).unwrap();

        let x_bad = array![[1.0], [2.0]];
        let result = fitted.predict_proba(&x_bad);
        assert!(result.is_err());
    }

    mod prop_tests {
        use super::*;
        use proptest::prelude::*;
        use std::collections::HashSet;

        /// Generate deterministic training data for classification.
        ///
        /// Feature values lie in roughly `[-10, 10]`. Labels are in
        /// `0..n_classes`. Both are derived by hashing `(seed, row, column)`.
        fn make_classification_data(
            n_samples: usize,
            n_features: usize,
            n_classes: usize,
            seed: u64,
        ) -> (Array2<f64>, Array1<f64>) {
            use std::collections::hash_map::DefaultHasher;
            use std::hash::{Hash, Hasher};

            let mut x_data = Vec::with_capacity(n_samples * n_features);
            let mut y_data = Vec::with_capacity(n_samples);

            for i in 0..n_samples {
                for j in 0..n_features {
                    let mut h = DefaultHasher::new();
                    seed.hash(&mut h);
                    (i as u64).hash(&mut h);
                    (j as u64).hash(&mut h);
                    let bits = h.finish();
                    let v = (bits as f64 / u64::MAX as f64) * 20.0 - 10.0;
                    x_data.push(v);
                }
                let mut h = DefaultHasher::new();
                seed.hash(&mut h);
                (i as u64).hash(&mut h);
                0xDEAD_BEEFu64.hash(&mut h);
                let label = (h.finish() % n_classes as u64) as f64;
                y_data.push(label);
            }

            let x = Array2::from_shape_vec((n_samples, n_features), x_data).unwrap();
            let y = Array1::from_vec(y_data);
            (x, y)
        }

        proptest! {
            #[test]
            fn predictions_are_valid_labels(
                n_samples in 6..30usize,
                n_features in 1..5usize,
                n_classes in 2..5usize,
                seed in 0u64..1000,
            ) {
                let (x, y) = make_classification_data(n_samples, n_features, n_classes, seed);

                let train_labels: HashSet<u64> = y.iter()
                    .map(|&v| v.to_bits())
                    .collect();

                let et = ExtraTreesClassifier {
                    n_estimators: 10,
                    max_depth: Some(5),
                    // `seed` is already a u64; no cast needed.
                    seed,
                    ..Default::default()
                };
                let fitted: FittedExtraTreesClassifier<f64> = et.fit(&x, &y).unwrap();
                let preds = fitted.predict(&x).unwrap();

                for (i, &p) in preds.iter().enumerate() {
                    prop_assert!(
                        train_labels.contains(&p.to_bits()),
                        "prediction {} at index {} is not a valid training label",
                        p, i
                    );
                }
            }

            #[test]
            fn feature_importances_sum_to_one(
                n_samples in 6..30usize,
                n_features in 1..5usize,
                seed in 0u64..1000,
            ) {
                let n_classes = 3;
                let (x, y) = make_classification_data(n_samples, n_features, n_classes, seed);

                let et = ExtraTreesClassifier {
                    n_estimators: 10,
                    max_depth: Some(5),
                    seed,
                    ..Default::default()
                };
                let fitted: FittedExtraTreesClassifier<f64> = et.fit(&x, &y).unwrap();
                let importances = fitted.feature_importances();
                let sum: f64 = importances.iter().sum();

                // Valid outcomes: (1) importances form a valid probability
                // distribution (sum to 1), or (2) every tree in the forest
                // is a pure leaf with no splits and all importances are zero.
                prop_assert!(
                    (sum - 1.0).abs() < 1e-10 || sum == 0.0,
                    "feature importances sum to {} (expected ~1.0 or 0.0 for no-split case), n_samples={}, n_features={}, seed={}",
                    sum, n_samples, n_features, seed
                );
                for (i, &imp) in importances.iter().enumerate() {
                    prop_assert!(
                        imp >= 0.0,
                        "importance[{}] = {} is negative",
                        i, imp
                    );
                }
            }
        }
    }
}