1use anofox_ml_core::{Fit, Float, Predict, Result, RustMlError};
2use anofox_ml_trees::node::TreeNode;
3use anofox_ml_trees::split::{
4 compute_impurity, count_classes, find_random_split, leaf_value, SplitCriterion,
5};
6use ndarray::{Array1, Array2};
7use rand::rngs::StdRng;
8use rand::{Rng, SeedableRng};
9use rayon::prelude::*;
10
/// Hyper-parameters of an extremely randomized trees (Extra-Trees) classifier.
///
/// Every tree is grown on all training rows (no bootstrap resampling) and splits
/// are produced by `find_random_split` (NOTE(review): presumably with randomly drawn
/// thresholds, as in Extra-Trees — confirm in `anofox_ml_trees::split`).
/// Calling `fit` yields a [`FittedExtraTreesClassifier`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ExtraTreesClassifier {
    /// Number of trees in the ensemble; `fit` rejects `0`.
    pub n_estimators: usize,
    /// Maximum depth of each tree; `None` means no depth limit.
    pub max_depth: Option<usize>,
    /// Nodes holding fewer samples than this are turned into leaves.
    pub min_samples_split: usize,
    /// Forwarded to `find_random_split` (presumably the minimum number of samples
    /// each child must keep — confirm).
    pub min_samples_leaf: usize,
    /// Number of columns randomly assigned to each tree (drawn once per tree, not
    /// per split); `None` uses every column. When set it must be in `1..=n_features`.
    pub max_features: Option<usize>,
    /// Seed for the RNG that draws each tree's column subset and tree seed.
    pub seed: u64,
}
35
36impl ExtraTreesClassifier {
37 pub fn new(n_estimators: usize) -> Self {
39 Self {
40 n_estimators,
41 max_depth: None,
42 min_samples_split: 2,
43 min_samples_leaf: 1,
44 max_features: None,
45 seed: 0,
46 }
47 }
48
49 pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
51 self.max_depth = max_depth;
52 self
53 }
54
55 pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
57 self.min_samples_split = min_samples_split;
58 self
59 }
60
61 pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
63 self.min_samples_leaf = min_samples_leaf;
64 self
65 }
66
67 pub fn with_max_features(mut self, max_features: Option<usize>) -> Self {
69 self.max_features = max_features;
70 self
71 }
72
73 pub fn with_seed(mut self, seed: u64) -> Self {
75 self.seed = seed;
76 self
77 }
78}
79
80impl Default for ExtraTreesClassifier {
81 fn default() -> Self {
82 Self::new(100)
83 }
84}
85
/// One fitted tree together with the original columns it was trained on.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
// Replaces serde's inferred `Deserialize` bound with `F: DeserializeOwned`.
#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
struct ExtraForestTree<F: Float> {
    /// Tree grown on the column sub-matrix; the feature indices it stores are
    /// positions into `feature_indices`, not original column indices.
    tree: TreeNode<F>,
    /// Original column indices visible to this tree, in ascending order.
    feature_indices: Vec<usize>,
    /// Number of columns the tree was trained on (`feature_indices.len()`).
    n_features_tree: usize,
}
98
/// A fitted Extra-Trees classifier returned by `ExtraTreesClassifier::fit`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
// Replaces serde's inferred `Deserialize` bound with `F: DeserializeOwned`.
#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
pub struct FittedExtraTreesClassifier<F: Float> {
    /// The fitted trees, in estimator order.
    trees: Vec<ExtraForestTree<F>>,
    /// Number of columns in the training data; prediction inputs must match it.
    n_features: usize,
}
106
107impl<F: Float> Fit<F> for ExtraTreesClassifier {
108 type Fitted = FittedExtraTreesClassifier<F>;
109
110 fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Self::Fitted> {
111 if x.nrows() != y.len() {
112 return Err(RustMlError::ShapeMismatch(format!(
113 "X has {} rows but y has {} elements",
114 x.nrows(),
115 y.len()
116 )));
117 }
118 if x.is_empty() {
119 return Err(RustMlError::EmptyInput("training data is empty".into()));
120 }
121 if self.n_estimators == 0 {
122 return Err(RustMlError::InvalidParameter(
123 "n_estimators must be > 0".into(),
124 ));
125 }
126
127 let n_features = x.ncols();
128
129 if let Some(k) = self.max_features {
130 if k == 0 || k > n_features {
131 return Err(RustMlError::InvalidParameter(format!(
132 "max_features={k} is invalid for data with {n_features} features"
133 )));
134 }
135 }
136
137 let mut rng = StdRng::seed_from_u64(self.seed);
138
139 let tree_plans: Vec<(Vec<usize>, u64)> = (0..self.n_estimators)
142 .map(|_| {
143 let feature_indices = select_features(n_features, self.max_features, &mut rng);
144 let tree_seed: u64 = rng.gen();
145 (feature_indices, tree_seed)
146 })
147 .collect();
148
149 let max_depth = self.max_depth;
150 let min_samples_split = self.min_samples_split;
151 let min_samples_leaf = self.min_samples_leaf;
152
153 let trees: Vec<ExtraForestTree<F>> = tree_plans
155 .into_par_iter()
156 .map(|(feature_indices, tree_seed)| {
157 let x_sub = build_sub_matrix_cols(x, &feature_indices);
159 let n_features_tree = feature_indices.len();
160 let indices: Vec<usize> = (0..x.nrows()).collect();
161
162 let tree = build_extra_tree(
163 &x_sub,
164 y,
165 &indices,
166 0,
167 max_depth,
168 min_samples_split,
169 min_samples_leaf,
170 SplitCriterion::Gini,
171 tree_seed,
172 );
173
174 ExtraForestTree {
175 tree,
176 feature_indices,
177 n_features_tree,
178 }
179 })
180 .collect();
181
182 Ok(FittedExtraTreesClassifier { trees, n_features })
183 }
184}
185
186impl<F: Float> Predict<F> for FittedExtraTreesClassifier<F> {
187 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>> {
188 if x.ncols() != self.n_features {
189 return Err(RustMlError::ShapeMismatch(format!(
190 "expected {} features, got {}",
191 self.n_features,
192 x.ncols()
193 )));
194 }
195
196 let n_samples = x.nrows();
197 let n_trees = self.trees.len();
198
199 let all_preds: Vec<Array1<F>> = self
201 .trees
202 .par_iter()
203 .map(|forest_tree| {
204 let sub_x = build_sub_matrix_cols(x, &forest_tree.feature_indices);
205 let preds: Vec<F> = sub_x
206 .rows()
207 .into_iter()
208 .map(|row| forest_tree.tree.predict_one(row.as_slice().unwrap()))
209 .collect();
210 Array1::from_vec(preds)
211 })
212 .collect();
213
214 let mut predictions = Vec::with_capacity(n_samples);
216 let mut votes = Vec::with_capacity(n_trees);
217 for i in 0..n_samples {
218 votes.clear();
219 for tree_pred in &all_preds {
220 votes.push(tree_pred[i]);
221 }
222 predictions.push(majority_vote(&votes));
223 }
224
225 Ok(Array1::from_vec(predictions))
226 }
227}
228
229impl<F: Float> FittedExtraTreesClassifier<F> {
230 pub fn feature_importances(&self) -> Array1<F> {
236 let mut importances = vec![F::zero(); self.n_features];
237 let n_trees = F::from_usize(self.trees.len()).unwrap();
238
239 for forest_tree in &self.trees {
240 let total_samples = tree_n_samples(&forest_tree.tree);
241 let tree_raw = forest_tree
242 .tree
243 .feature_importances(forest_tree.n_features_tree, total_samples);
244 let sum: F = tree_raw.iter().copied().fold(F::zero(), |a, b| a + b);
246 for (local_idx, &original_idx) in forest_tree.feature_indices.iter().enumerate() {
247 if sum > F::zero() {
248 importances[original_idx] += (tree_raw[local_idx] / sum) / n_trees;
249 }
250 }
251 }
252
253 let sum: F = importances.iter().copied().fold(F::zero(), |a, b| a + b);
255 if sum > F::zero() {
256 Array1::from_vec(importances.into_iter().map(|v| v / sum).collect())
257 } else {
258 Array1::zeros(self.n_features)
259 }
260 }
261
262 pub fn predict_proba(&self, x: &Array2<F>) -> Result<Vec<Vec<(F, F)>>> {
267 if x.ncols() != self.n_features {
268 return Err(RustMlError::ShapeMismatch(format!(
269 "expected {} features, got {}",
270 self.n_features,
271 x.ncols()
272 )));
273 }
274
275 let n_samples = x.nrows();
276 let n_trees = self.trees.len();
277 let n_trees_f = F::from_usize(n_trees).unwrap();
278
279 let all_preds: Vec<Array1<F>> = self
281 .trees
282 .par_iter()
283 .map(|forest_tree| {
284 let sub_x = build_sub_matrix_cols(x, &forest_tree.feature_indices);
285 let preds: Vec<F> = sub_x
286 .rows()
287 .into_iter()
288 .map(|row| forest_tree.tree.predict_one(row.as_slice().unwrap()))
289 .collect();
290 Array1::from_vec(preds)
291 })
292 .collect();
293
294 let mut result = Vec::with_capacity(n_samples);
296 for i in 0..n_samples {
297 let mut class_votes: std::collections::HashMap<u64, (F, usize)> =
298 std::collections::HashMap::new();
299 for tree_pred in &all_preds {
300 let v = tree_pred[i];
301 let key = v.to_f64().unwrap().to_bits();
302 class_votes
303 .entry(key)
304 .and_modify(|e| e.1 += 1)
305 .or_insert((v, 1));
306 }
307
308 let mut probs: Vec<(F, F)> = class_votes
309 .into_values()
310 .map(|(class, count)| (class, F::from_usize(count).unwrap() / n_trees_f))
311 .collect();
312 probs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
313 result.push(probs);
314 }
315
316 Ok(result)
317 }
318
319 pub fn n_estimators(&self) -> usize {
321 self.trees.len()
322 }
323}
324
325#[allow(clippy::too_many_arguments)]
331fn build_extra_tree<F: Float>(
332 x: &Array2<F>,
333 y: &Array1<F>,
334 indices: &[usize],
335 depth: usize,
336 max_depth: Option<usize>,
337 min_samples_split: usize,
338 min_samples_leaf: usize,
339 criterion: SplitCriterion,
340 seed: u64,
341) -> TreeNode<F> {
342 let n_samples = indices.len();
343 let impurity = compute_impurity(y, indices, criterion);
344
345 let should_stop = n_samples < min_samples_split
347 || max_depth.is_some_and(|d| depth >= d)
348 || impurity < F::from_f64(1e-15).unwrap();
349
350 if should_stop {
351 return make_leaf(y, indices, criterion);
352 }
353
354 let split_seed = seed
356 .wrapping_add(depth as u64)
357 .wrapping_mul(0x517CC1B727220A95);
358
359 match find_random_split(x, y, indices, criterion, min_samples_leaf, split_seed) {
360 Some(split) => {
361 let left = build_extra_tree(
362 x,
363 y,
364 &split.left_indices,
365 depth + 1,
366 max_depth,
367 min_samples_split,
368 min_samples_leaf,
369 criterion,
370 seed.wrapping_add(1),
371 );
372 let right = build_extra_tree(
373 x,
374 y,
375 &split.right_indices,
376 depth + 1,
377 max_depth,
378 min_samples_split,
379 min_samples_leaf,
380 criterion,
381 seed.wrapping_add(2),
382 );
383
384 TreeNode::Split {
385 feature_index: split.feature_index,
386 threshold: split.threshold,
387 left: Box::new(left),
388 right: Box::new(right),
389 n_samples,
390 impurity,
391 }
392 }
393 None => make_leaf(y, indices, criterion),
394 }
395}
396
397fn make_leaf<F: Float>(y: &Array1<F>, indices: &[usize], criterion: SplitCriterion) -> TreeNode<F> {
398 let value = leaf_value(y, indices, criterion);
399 let class_counts = match criterion {
400 SplitCriterion::Gini | SplitCriterion::Entropy => Some(count_classes(y, indices)),
401 SplitCriterion::Mse => None,
402 };
403 TreeNode::Leaf {
404 value,
405 n_samples: indices.len(),
406 class_counts,
407 }
408}
409
410fn tree_n_samples<F: Float>(node: &TreeNode<F>) -> usize {
411 match node {
412 TreeNode::Leaf { n_samples, .. } => *n_samples,
413 TreeNode::Split { n_samples, .. } => *n_samples,
414 }
415}
416
417fn select_features(n_features: usize, max_features: Option<usize>, rng: &mut StdRng) -> Vec<usize> {
424 match max_features {
425 None => (0..n_features).collect(),
426 Some(k) => {
427 let mut indices: Vec<usize> = (0..n_features).collect();
429 for i in 0..k {
430 let j = rng.gen_range(i..n_features);
431 indices.swap(i, j);
432 }
433 indices.truncate(k);
434 indices.sort_unstable();
435 indices
436 }
437 }
438}
439
440fn build_sub_matrix_cols<F: Float>(x: &Array2<F>, col_indices: &[usize]) -> Array2<F> {
444 let n_rows = x.nrows();
445 let n_cols = col_indices.len();
446 let mut data = Vec::with_capacity(n_rows * n_cols);
447 for i in 0..n_rows {
448 for &ci in col_indices {
449 data.push(x[[i, ci]]);
450 }
451 }
452 Array2::from_shape_vec((n_rows, n_cols), data).expect("shape matches data length")
453}
454
455#[inline]
458fn majority_vote<F: Float>(votes: &[F]) -> F {
459 use std::collections::HashMap;
460 let mut counts: HashMap<u64, (F, usize)> = HashMap::new();
461 for &v in votes {
462 let key = v.to_f64().unwrap().to_bits();
463 counts.entry(key).and_modify(|e| e.1 += 1).or_insert((v, 1));
464 }
465 counts
466 .into_values()
467 .max_by_key(|&(_, count)| count)
468 .unwrap()
469 .0
470}
471
// Unit tests use tiny hand-made datasets with clearly separable classes, so the
// expected labels are unambiguous. `prop_tests` checks invariants on pseudo-random data.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    // Two well-separated clusters: the forest must reproduce the training labels.
    #[test]
    fn test_basic_classification() {
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let et = ExtraTreesClassifier {
            n_estimators: 20,
            max_depth: Some(3),
            seed: 42,
            ..Default::default()
        };
        let fitted: FittedExtraTreesClassifier<f64> = et.fit(&x, &y).unwrap();

        let preds = fitted.predict(&x).unwrap();
        for (p, t) in preds.iter().zip(y.iter()) {
            assert_abs_diff_eq!(*p, *t, epsilon = 1e-10);
        }
    }

    // Same configuration and seed must give identical predictions.
    #[test]
    fn test_reproducibility() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let et = ExtraTreesClassifier {
            n_estimators: 10,
            seed: 123,
            ..Default::default()
        };

        let fitted1: FittedExtraTreesClassifier<f64> = et.fit(&x, &y).unwrap();
        let fitted2: FittedExtraTreesClassifier<f64> = et.fit(&x, &y).unwrap();

        let preds1 = fitted1.predict(&x).unwrap();
        let preds2 = fitted2.predict(&x).unwrap();

        for (a, b) in preds1.iter().zip(preds2.iter()) {
            assert_abs_diff_eq!(*a, *b, epsilon = 1e-15);
        }
    }

    // Column subsampling (2 of 3 columns per tree) still fits separable data.
    #[test]
    fn test_max_features() {
        let x = array![
            [1.0, 100.0, 0.5],
            [2.0, 200.0, 0.6],
            [3.0, 300.0, 0.7],
            [10.0, 400.0, 0.8],
            [11.0, 500.0, 0.9],
            [12.0, 600.0, 1.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let et = ExtraTreesClassifier {
            n_estimators: 30,
            max_features: Some(2),
            seed: 99,
            ..Default::default()
        };
        let fitted: FittedExtraTreesClassifier<f64> = et.fit(&x, &y).unwrap();

        let preds = fitted.predict(&x).unwrap();
        for (p, t) in preds.iter().zip(y.iter()) {
            assert_abs_diff_eq!(*p, *t, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_feature_importances_sum_to_one() {
        let x = array![
            [1.0, 100.0],
            [2.0, 200.0],
            [3.0, 300.0],
            [4.0, 400.0],
            [5.0, 500.0],
            [6.0, 600.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let et = ExtraTreesClassifier {
            n_estimators: 20,
            seed: 7,
            ..Default::default()
        };
        let fitted: FittedExtraTreesClassifier<f64> = et.fit(&x, &y).unwrap();

        let importances = fitted.feature_importances();
        let sum: f64 = importances.iter().sum();
        assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_feature_importances_non_negative() {
        let x = array![
            [1.0, 100.0, 0.5],
            [2.0, 200.0, 0.6],
            [3.0, 300.0, 0.7],
            [10.0, 400.0, 0.8],
            [11.0, 500.0, 0.9],
            [12.0, 600.0, 1.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let et = ExtraTreesClassifier {
            n_estimators: 20,
            seed: 7,
            ..Default::default()
        };
        let fitted: FittedExtraTreesClassifier<f64> = et.fit(&x, &y).unwrap();

        let importances = fitted.feature_importances();
        for &imp in importances.iter() {
            assert!(
                imp >= 0.0,
                "feature importance must be non-negative, got {imp}"
            );
        }
    }

    #[test]
    fn test_n_estimators() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![0.0, 0.0, 1.0, 1.0];

        let et = ExtraTreesClassifier {
            n_estimators: 7,
            seed: 0,
            ..Default::default()
        };
        let fitted: FittedExtraTreesClassifier<f64> = et.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 7);
    }

    // Input validation in `fit` and `predict`.
    #[test]
    fn test_shape_mismatch_error() {
        let x = array![[1.0], [2.0]];
        let y = array![0.0, 1.0, 2.0];

        let et = ExtraTreesClassifier::default();
        let result: std::result::Result<FittedExtraTreesClassifier<f64>, _> = et.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_predict_wrong_features_error() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0.0, 1.0];

        let et = ExtraTreesClassifier {
            n_estimators: 5,
            seed: 0,
            ..Default::default()
        };
        let fitted: FittedExtraTreesClassifier<f64> = et.fit(&x, &y).unwrap();

        let x_bad = array![[1.0], [2.0]];
        let result = fitted.predict(&x_bad);
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_max_features() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0.0, 1.0];

        let et = ExtraTreesClassifier {
            n_estimators: 5,
            max_features: Some(5),
            seed: 0,
            ..Default::default()
        };
        let result: std::result::Result<FittedExtraTreesClassifier<f64>, _> = et.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_zero_estimators_error() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0.0, 1.0];

        let et = ExtraTreesClassifier {
            n_estimators: 0,
            seed: 0,
            ..Default::default()
        };
        let result: std::result::Result<FittedExtraTreesClassifier<f64>, _> = et.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_input_error() {
        let x: Array2<f64> = Array2::zeros((0, 2));
        let y: Array1<f64> = Array1::zeros(0);

        let et = ExtraTreesClassifier::default();
        let result: std::result::Result<FittedExtraTreesClassifier<f64>, _> = et.fit(&x, &y);
        assert!(result.is_err());
    }

    // A single-tree "forest" is valid and predicts one label per row.
    #[test]
    fn test_n_estimators_one() {
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let et = ExtraTreesClassifier {
            n_estimators: 1,
            max_depth: Some(3),
            seed: 42,
            ..Default::default()
        };
        let fitted: FittedExtraTreesClassifier<f64> = et.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 1);

        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), y.len());
    }

    // Every prediction must be one of the labels seen during training (bitwise).
    #[test]
    fn test_predictions_are_valid_labels() {
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0],
            [20.0, 2.0],
            [21.0, 2.0],
            [22.0, 2.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0];

        let et = ExtraTreesClassifier {
            n_estimators: 30,
            max_depth: Some(5),
            seed: 42,
            ..Default::default()
        };
        let fitted: FittedExtraTreesClassifier<f64> = et.fit(&x, &y).unwrap();

        let preds = fitted.predict(&x).unwrap();
        let valid_labels: std::collections::HashSet<u64> = y.iter().map(|v| v.to_bits()).collect();
        for &p in preds.iter() {
            assert!(
                valid_labels.contains(&p.to_bits()),
                "prediction {p} is not a valid training label"
            );
        }
    }

    // Vote fractions sum to one per row, and the first cluster favours class 0.
    #[test]
    fn test_predict_proba() {
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let et = ExtraTreesClassifier {
            n_estimators: 20,
            max_depth: Some(3),
            seed: 42,
            ..Default::default()
        };
        let fitted: FittedExtraTreesClassifier<f64> = et.fit(&x, &y).unwrap();

        let proba = fitted.predict_proba(&x).unwrap();
        assert_eq!(proba.len(), x.nrows());

        for sample_probs in &proba {
            let sum: f64 = sample_probs.iter().map(|&(_, p)| p).sum();
            assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-10);
        }

        for sample_probs in &proba[..3] {
            let p_class0 = sample_probs
                .iter()
                .find(|&&(c, _)| (c - 0.0).abs() < 1e-10)
                .map(|&(_, p)| p)
                .unwrap_or(0.0);
            assert!(p_class0 > 0.5, "expected P(class=0) > 0.5, got {p_class0}");
        }
    }

    #[test]
    fn test_predict_proba_wrong_features_error() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0.0, 1.0];

        let et = ExtraTreesClassifier {
            n_estimators: 5,
            seed: 0,
            ..Default::default()
        };
        let fitted: FittedExtraTreesClassifier<f64> = et.fit(&x, &y).unwrap();

        let x_bad = array![[1.0], [2.0]];
        let result = fitted.predict_proba(&x_bad);
        assert!(result.is_err());
    }

    mod prop_tests {
        use super::*;
        use proptest::prelude::*;
        use std::collections::HashSet;

        // Deterministic pseudo-random dataset: features in [-10, 10], labels in
        // `0..n_classes`, both derived by hashing (seed, row, column).
        fn make_classification_data(
            n_samples: usize,
            n_features: usize,
            n_classes: usize,
            seed: u64,
        ) -> (Array2<f64>, Array1<f64>) {
            use std::collections::hash_map::DefaultHasher;
            use std::hash::{Hash, Hasher};

            let mut x_data = Vec::with_capacity(n_samples * n_features);
            let mut y_data = Vec::with_capacity(n_samples);

            for i in 0..n_samples {
                for j in 0..n_features {
                    let mut h = DefaultHasher::new();
                    seed.hash(&mut h);
                    (i as u64).hash(&mut h);
                    (j as u64).hash(&mut h);
                    let bits = h.finish();
                    let v = (bits as f64 / u64::MAX as f64) * 20.0 - 10.0;
                    x_data.push(v);
                }
                let mut h = DefaultHasher::new();
                seed.hash(&mut h);
                (i as u64).hash(&mut h);
                0xDEAD_BEEFu64.hash(&mut h);
                let label = (h.finish() % n_classes as u64) as f64;
                y_data.push(label);
            }

            let x = Array2::from_shape_vec((n_samples, n_features), x_data).unwrap();
            let y = Array1::from_vec(y_data);
            (x, y)
        }

        proptest! {
            #[test]
            fn predictions_are_valid_labels(
                n_samples in 6..30usize,
                n_features in 1..5usize,
                n_classes in 2..5usize,
                seed in 0u64..1000,
            ) {
                let (x, y) = make_classification_data(n_samples, n_features, n_classes, seed);

                let train_labels: HashSet<u64> = y.iter()
                    .map(|&v| v.to_bits())
                    .collect();

                // `seed` is already a `u64`; no cast needed.
                let et = ExtraTreesClassifier {
                    n_estimators: 10,
                    max_depth: Some(5),
                    seed,
                    ..Default::default()
                };
                let fitted: FittedExtraTreesClassifier<f64> = et.fit(&x, &y).unwrap();
                let preds = fitted.predict(&x).unwrap();

                for (i, &p) in preds.iter().enumerate() {
                    prop_assert!(
                        train_labels.contains(&p.to_bits()),
                        "prediction {} at index {} is not a valid training label",
                        p, i
                    );
                }
            }

            #[test]
            fn feature_importances_sum_to_one(
                n_samples in 6..30usize,
                n_features in 1..5usize,
                seed in 0u64..1000,
            ) {
                let n_classes = 3;
                let (x, y) = make_classification_data(n_samples, n_features, n_classes, seed);

                let et = ExtraTreesClassifier {
                    n_estimators: 10,
                    max_depth: Some(5),
                    seed,
                    ..Default::default()
                };
                let fitted: FittedExtraTreesClassifier<f64> = et.fit(&x, &y).unwrap();
                let importances = fitted.feature_importances();
                let sum: f64 = importances.iter().sum();

                // A forest in which no tree ever split has all-zero importances.
                prop_assert!(
                    (sum - 1.0).abs() < 1e-10 || sum == 0.0,
                    "feature importances sum to {} (expected ~1.0 or 0.0 for no-split case), n_samples={}, n_features={}, seed={}",
                    sum, n_samples, n_features, seed
                );
                for (i, &imp) in importances.iter().enumerate() {
                    prop_assert!(
                        imp >= 0.0,
                        "importance[{}] = {} is negative",
                        i, imp
                    );
                }
            }
        }
    }
}