1use anofox_ml_core::{Fit, Float, Predict, PredictProba, Result, RustMlError};
2use anofox_ml_trees::{
3 ClassWeight, DecisionTreeClassifier, FittedDecisionTreeClassifier, SplitCriterion,
4};
5use ndarray::{Array1, Array2};
6use rand::rngs::StdRng;
7use rand::{Rng, SeedableRng};
8use rayon::prelude::*;
9
10#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
15pub struct RandomForestClassifier {
16 pub n_estimators: usize,
18 pub max_depth: Option<usize>,
20 pub min_samples_split: usize,
22 pub min_samples_leaf: usize,
24 pub max_features: Option<usize>,
26 pub criterion: SplitCriterion,
28 pub bootstrap: bool,
30 pub max_samples: Option<f64>,
33 pub class_weight: Option<ClassWeight>,
35 pub oob_score: bool,
37 pub seed: u64,
39}
40
41impl RandomForestClassifier {
42 pub fn new(n_estimators: usize) -> Self {
44 Self {
45 n_estimators,
46 max_depth: None,
47 min_samples_split: 2,
48 min_samples_leaf: 1,
49 max_features: None,
50 criterion: SplitCriterion::Gini,
51 bootstrap: true,
52 max_samples: None,
53 class_weight: None,
54 oob_score: false,
55 seed: 0,
56 }
57 }
58
59 pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
60 self.max_depth = max_depth;
61 self
62 }
63 pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
64 self.min_samples_split = min_samples_split;
65 self
66 }
67 pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
68 self.min_samples_leaf = min_samples_leaf;
69 self
70 }
71 pub fn with_max_features(mut self, max_features: Option<usize>) -> Self {
72 self.max_features = max_features;
73 self
74 }
75 pub fn with_criterion(mut self, criterion: SplitCriterion) -> Self {
76 self.criterion = criterion;
77 self
78 }
79 pub fn with_bootstrap(mut self, bootstrap: bool) -> Self {
80 self.bootstrap = bootstrap;
81 self
82 }
83 pub fn with_max_samples(mut self, max_samples: Option<f64>) -> Self {
84 self.max_samples = max_samples;
85 self
86 }
87 pub fn with_class_weight(mut self, class_weight: Option<ClassWeight>) -> Self {
88 self.class_weight = class_weight;
89 self
90 }
91 pub fn with_oob_score(mut self, oob_score: bool) -> Self {
92 self.oob_score = oob_score;
93 self
94 }
95 pub fn with_seed(mut self, seed: u64) -> Self {
96 self.seed = seed;
97 self
98 }
99}
100
101impl Default for RandomForestClassifier {
102 fn default() -> Self {
103 Self::new(100)
104 }
105}
106
107#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
109#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
110struct ForestTree<F: Float> {
111 tree: FittedDecisionTreeClassifier<F>,
112 feature_indices: Vec<usize>,
116}
117
118#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
120#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
121pub struct FittedRandomForestClassifier<F: Float> {
122 trees: Vec<ForestTree<F>>,
123 n_features: usize,
124 oob_score_value: Option<f64>,
126}
127
128impl<F: Float> Fit<F> for RandomForestClassifier {
129 type Fitted = FittedRandomForestClassifier<F>;
130
131 fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Self::Fitted> {
132 if x.nrows() != y.len() {
133 return Err(RustMlError::ShapeMismatch(format!(
134 "X has {} rows but y has {} elements",
135 x.nrows(),
136 y.len()
137 )));
138 }
139 if x.is_empty() {
140 return Err(RustMlError::EmptyInput("training data is empty".into()));
141 }
142 if self.n_estimators == 0 {
143 return Err(RustMlError::InvalidParameter(
144 "n_estimators must be > 0".into(),
145 ));
146 }
147
148 let n_samples = x.nrows();
149 let n_features = x.ncols();
150
151 if let Some(k) = self.max_features {
152 if k == 0 || k > n_features {
153 return Err(RustMlError::InvalidParameter(format!(
154 "max_features={k} is invalid for data with {n_features} features"
155 )));
156 }
157 }
158
159 let mut rng = StdRng::seed_from_u64(self.seed);
160
161 let draw_size = if let Some(frac) = self.max_samples {
163 if frac <= 0.0 || frac > 1.0 {
164 return Err(RustMlError::InvalidParameter(
165 "max_samples must be in (0, 1]".into(),
166 ));
167 }
168 (n_samples as f64 * frac).ceil() as usize
169 } else {
170 n_samples
171 };
172
173 let tree_params = DecisionTreeClassifier {
174 max_depth: self.max_depth,
175 min_samples_split: self.min_samples_split,
176 min_samples_leaf: self.min_samples_leaf,
177 criterion: self.criterion,
178 max_features: None,
179 sample_weight: None,
180 class_weight: self.class_weight.clone(),
181 };
182
183 let sample_plans: Vec<(Vec<usize>, Vec<usize>)> = (0..self.n_estimators)
185 .map(|_| {
186 let row_indices: Vec<usize> = if self.bootstrap {
187 (0..draw_size)
188 .map(|_| rng.gen_range(0..n_samples))
189 .collect()
190 } else {
191 (0..n_samples).collect()
192 };
193 let feature_indices = select_features(n_features, self.max_features, &mut rng);
194 (row_indices, feature_indices)
195 })
196 .collect();
197
198 let all_row_indices: Vec<Vec<usize>> = if self.oob_score {
200 sample_plans.iter().map(|(ri, _)| ri.clone()).collect()
201 } else {
202 Vec::new()
203 };
204
205 let trees: Result<Vec<ForestTree<F>>> = sample_plans
207 .into_par_iter()
208 .map(|(row_indices, feature_indices)| {
209 let x_sub = build_sub_matrix(x, &row_indices, &feature_indices);
210 let y_sub = Array1::from_vec(row_indices.iter().map(|&i| y[i]).collect::<Vec<F>>());
211 let fitted_tree: FittedDecisionTreeClassifier<F> =
212 tree_params.fit(&x_sub, &y_sub)?;
213 Ok(ForestTree {
214 tree: fitted_tree,
215 feature_indices,
216 })
217 })
218 .collect();
219 let trees = trees?;
220
221 let oob_score_value = if self.oob_score && self.bootstrap {
223 compute_oob_score_classification(&trees, x, y, n_samples, &all_row_indices)
224 } else {
225 None
226 };
227
228 Ok(FittedRandomForestClassifier {
229 trees,
230 n_features,
231 oob_score_value,
232 })
233 }
234}
235
236impl<F: Float> Predict<F> for FittedRandomForestClassifier<F> {
237 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>> {
238 if x.ncols() != self.n_features {
239 return Err(RustMlError::ShapeMismatch(format!(
240 "expected {} features, got {}",
241 self.n_features,
242 x.ncols()
243 )));
244 }
245
246 let n_samples = x.nrows();
247 let n_trees = self.trees.len();
248
249 let all_preds: Result<Vec<Array1<F>>> = self
251 .trees
252 .par_iter()
253 .map(|forest_tree| {
254 let sub_x = build_sub_matrix_cols(x, &forest_tree.feature_indices);
255 forest_tree.tree.predict(&sub_x)
256 })
257 .collect();
258 let all_preds = all_preds?;
259
260 let mut predictions = Vec::with_capacity(n_samples);
262 let mut votes = Vec::with_capacity(n_trees);
263 for i in 0..n_samples {
264 votes.clear();
265 for tree_pred in &all_preds {
266 votes.push(tree_pred[i]);
267 }
268 predictions.push(majority_vote(&votes));
269 }
270
271 Ok(Array1::from_vec(predictions))
272 }
273}
274
275impl<F: Float> FittedRandomForestClassifier<F> {
276 pub fn feature_importances(&self) -> Array1<F> {
282 let mut importances = vec![F::zero(); self.n_features];
283 let n_trees = F::from_usize(self.trees.len()).unwrap();
284
285 for forest_tree in &self.trees {
286 let tree_importances = forest_tree.tree.feature_importances();
287 for (local_idx, &original_idx) in forest_tree.feature_indices.iter().enumerate() {
288 importances[original_idx] += tree_importances[local_idx] / n_trees;
289 }
290 }
291
292 let sum: F = importances.iter().copied().fold(F::zero(), |a, b| a + b);
294 if sum > F::zero() {
295 Array1::from_vec(importances.into_iter().map(|v| v / sum).collect())
296 } else {
297 Array1::zeros(self.n_features)
298 }
299 }
300
301 pub fn n_estimators(&self) -> usize {
303 self.trees.len()
304 }
305
306 pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>> {
311 if x.ncols() != self.n_features {
312 return Err(RustMlError::ShapeMismatch(format!(
313 "expected {} features, got {}",
314 self.n_features,
315 x.ncols()
316 )));
317 }
318
319 let all_proba: Result<Vec<Array2<F>>> = self
321 .trees
322 .par_iter()
323 .map(|forest_tree| {
324 let sub_x = build_sub_matrix_cols(x, &forest_tree.feature_indices);
325 forest_tree.tree.predict_proba(&sub_x)
326 })
327 .collect();
328 let all_proba = all_proba?;
329
330 let mut global_classes: Vec<F> = Vec::new();
332 let eps = F::from_f64(1e-9).unwrap();
333 for forest_tree in &self.trees {
334 for c in forest_tree.tree.classes() {
335 if !global_classes.iter().any(|&gc| (gc - c).abs() < eps) {
336 global_classes.push(c);
337 }
338 }
339 }
340 global_classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
341
342 let n_samples = x.nrows();
343 let n_classes = global_classes.len();
344 let n_trees_f = F::from_usize(self.trees.len()).unwrap();
345 let mut avg_proba = Array2::<F>::zeros((n_samples, n_classes));
346
347 for (tree_idx, forest_tree) in self.trees.iter().enumerate() {
349 let tree_classes = forest_tree.tree.classes();
350 let tree_proba = &all_proba[tree_idx];
351
352 for (local_ci, &tc) in tree_classes.iter().enumerate() {
353 if let Some(global_ci) = global_classes.iter().position(|&gc| (gc - tc).abs() < eps)
354 {
355 for i in 0..n_samples {
356 avg_proba[[i, global_ci]] += tree_proba[[i, local_ci]] / n_trees_f;
357 }
358 }
359 }
360 }
361
362 Ok(avg_proba)
363 }
364
365 pub fn oob_score(&self) -> Option<f64> {
367 self.oob_score_value
368 }
369
370 pub fn score(&self, x: &Array2<F>, y: &Array1<F>) -> Result<f64> {
372 let preds = self.predict(x)?;
373 let n = y.len();
374 let correct = preds
375 .iter()
376 .zip(y.iter())
377 .filter(|(&p, &t)| (p - t).abs() < F::from_f64(1e-9).unwrap())
378 .count();
379 Ok(correct as f64 / n as f64)
380 }
381
382 pub fn classes(&self) -> Vec<F> {
384 let eps = F::from_f64(1e-9).unwrap();
385 let mut classes: Vec<F> = Vec::new();
386 for forest_tree in &self.trees {
387 for c in forest_tree.tree.classes() {
388 if !classes.iter().any(|&gc| (gc - c).abs() < eps) {
389 classes.push(c);
390 }
391 }
392 }
393 classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
394 classes
395 }
396}
397
398fn select_features(n_features: usize, max_features: Option<usize>, rng: &mut StdRng) -> Vec<usize> {
401 match max_features {
402 None => (0..n_features).collect(),
403 Some(k) => {
404 let mut indices: Vec<usize> = (0..n_features).collect();
406 for i in 0..k {
407 let j = rng.gen_range(i..n_features);
408 indices.swap(i, j);
409 }
410 indices.truncate(k);
411 indices.sort_unstable();
412 indices
413 }
414 }
415}
416
417fn build_sub_matrix<F: Float>(
419 x: &Array2<F>,
420 row_indices: &[usize],
421 col_indices: &[usize],
422) -> Array2<F> {
423 let n_rows = row_indices.len();
425 let n_cols = col_indices.len();
426 let mut data = Vec::with_capacity(n_rows * n_cols);
427 for &ri in row_indices {
428 for &ci in col_indices {
429 data.push(x[[ri, ci]]);
430 }
431 }
432 Array2::from_shape_vec((n_rows, n_cols), data).expect("shape matches data length")
433}
434
435fn build_sub_matrix_cols<F: Float>(x: &Array2<F>, col_indices: &[usize]) -> Array2<F> {
439 let n_rows = x.nrows();
440 let n_cols = col_indices.len();
441 let mut data = Vec::with_capacity(n_rows * n_cols);
442 for i in 0..n_rows {
443 for &ci in col_indices {
444 data.push(x[[i, ci]]);
445 }
446 }
447 Array2::from_shape_vec((n_rows, n_cols), data).expect("shape matches data length")
448}
449
450fn compute_oob_score_classification<F: Float>(
453 trees: &[ForestTree<F>],
454 x: &Array2<F>,
455 y: &Array1<F>,
456 n_samples: usize,
457 bootstrap_indices: &[Vec<usize>],
458) -> Option<f64> {
459 use std::collections::HashSet;
460
461 let mut correct = 0usize;
462 let mut evaluated = 0usize;
463
464 for i in 0..n_samples {
465 let mut votes: Vec<F> = Vec::new();
466 for (t_idx, forest_tree) in trees.iter().enumerate() {
467 let in_bag: HashSet<usize> = bootstrap_indices[t_idx].iter().copied().collect();
469 if in_bag.contains(&i) {
470 continue;
471 }
472 let sub_x = build_sub_matrix_cols_single(x, i, &forest_tree.feature_indices);
474 if let Ok(pred) = forest_tree.tree.predict(&sub_x) {
475 votes.push(pred[0]);
476 }
477 }
478 if !votes.is_empty() {
479 let oob_pred = majority_vote(&votes);
480 if (oob_pred - y[i]).abs() < F::from_f64(1e-9).unwrap() {
481 correct += 1;
482 }
483 evaluated += 1;
484 }
485 }
486
487 if evaluated > 0 {
488 Some(correct as f64 / evaluated as f64)
489 } else {
490 None
491 }
492}
493
494fn build_sub_matrix_cols_single<F: Float>(
496 x: &Array2<F>,
497 row: usize,
498 col_indices: &[usize],
499) -> Array2<F> {
500 let n_cols = col_indices.len();
501 let data: Vec<F> = col_indices.iter().map(|&ci| x[[row, ci]]).collect();
502 Array2::from_shape_vec((1, n_cols), data).expect("shape matches data length")
503}
504
505#[inline]
508fn majority_vote<F: Float>(votes: &[F]) -> F {
509 use std::collections::HashMap;
510 let mut counts: HashMap<u64, (F, usize)> = HashMap::new();
511 for &v in votes {
512 let key = v.to_f64().unwrap().to_bits();
513 counts.entry(key).and_modify(|e| e.1 += 1).or_insert((v, 1));
514 }
515 counts
516 .into_values()
517 .max_by_key(|&(_, count)| count)
518 .unwrap()
519 .0
520}
521
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Feature matrix plus label vector.
    type Dataset = (Array2<f64>, Array1<f64>);

    /// Six samples, two features, two clearly separated classes.
    fn binary_two_features() -> Dataset {
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0]
        ];
        (x, array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0])
    }

    /// Six samples, three features, two clearly separated classes.
    fn binary_three_features() -> Dataset {
        let x = array![
            [1.0, 100.0, 0.5],
            [2.0, 200.0, 0.6],
            [3.0, 300.0, 0.7],
            [10.0, 400.0, 0.8],
            [11.0, 500.0, 0.9],
            [12.0, 600.0, 1.0]
        ];
        (x, array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0])
    }

    /// Smallest usable training set: two samples, two features.
    fn two_by_two() -> Dataset {
        (array![[1.0, 2.0], [3.0, 4.0]], array![0.0, 1.0])
    }

    /// Fits with `f64` data and returns the raw result.
    fn fit_result(
        rf: &RandomForestClassifier,
        x: &Array2<f64>,
        y: &Array1<f64>,
    ) -> Result<FittedRandomForestClassifier<f64>> {
        rf.fit(x, y)
    }

    /// Fits with `f64` data, panicking if fitting fails.
    fn fit_ok(
        rf: &RandomForestClassifier,
        x: &Array2<f64>,
        y: &Array1<f64>,
    ) -> FittedRandomForestClassifier<f64> {
        fit_result(rf, x, y).unwrap()
    }

    /// Asserts that every predicted label equals the expected one.
    fn assert_same_labels(predicted: &Array1<f64>, expected: &Array1<f64>) {
        for (p, t) in predicted.iter().zip(expected.iter()) {
            assert_abs_diff_eq!(*p, *t, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_basic_classification() {
        let (x, y) = binary_two_features();
        let rf = RandomForestClassifier::new(20)
            .with_max_depth(Some(3))
            .with_seed(42);

        let preds = fit_ok(&rf, &x, &y).predict(&x).unwrap();
        assert_same_labels(&preds, &y);
    }

    #[test]
    fn test_reproducibility() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let rf = RandomForestClassifier::new(10).with_seed(123);

        // Two fits from the same configuration must agree exactly.
        let first = fit_ok(&rf, &x, &y).predict(&x).unwrap();
        let second = fit_ok(&rf, &x, &y).predict(&x).unwrap();

        for (a, b) in first.iter().zip(second.iter()) {
            assert_abs_diff_eq!(*a, *b, epsilon = 1e-15);
        }
    }

    #[test]
    fn test_max_features() {
        let (x, y) = binary_three_features();
        let rf = RandomForestClassifier::new(30)
            .with_max_features(Some(2))
            .with_seed(99);

        let preds = fit_ok(&rf, &x, &y).predict(&x).unwrap();
        assert_same_labels(&preds, &y);
    }

    #[test]
    fn test_feature_importances_sum_to_one() {
        let x = array![
            [1.0, 100.0],
            [2.0, 200.0],
            [3.0, 300.0],
            [4.0, 400.0],
            [5.0, 500.0],
            [6.0, 600.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let rf = RandomForestClassifier::new(20).with_seed(7);

        let total: f64 = fit_ok(&rf, &x, &y).feature_importances().iter().sum();
        assert_abs_diff_eq!(total, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_n_estimators() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![0.0, 0.0, 1.0, 1.0];
        let rf = RandomForestClassifier::new(7).with_seed(0);

        assert_eq!(fit_ok(&rf, &x, &y).n_estimators(), 7);
    }

    #[test]
    fn test_shape_mismatch_error() {
        let x = array![[1.0], [2.0]];
        let y = array![0.0, 1.0, 2.0];

        let outcome = fit_result(&RandomForestClassifier::default(), &x, &y);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_predict_wrong_features_error() {
        let (x, y) = two_by_two();
        let rf = RandomForestClassifier::new(5).with_seed(0);
        let fitted = fit_ok(&rf, &x, &y);

        // One column instead of the two seen during fitting.
        let x_bad = array![[1.0], [2.0]];
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_invalid_max_features() {
        let (x, y) = two_by_two();
        let rf = RandomForestClassifier::new(5)
            .with_max_features(Some(5))
            .with_seed(0);

        assert!(fit_result(&rf, &x, &y).is_err());
    }

    #[test]
    fn test_feature_importances_non_negative() {
        let (x, y) = binary_three_features();
        let rf = RandomForestClassifier::new(20).with_seed(7);

        let importances = fit_ok(&rf, &x, &y).feature_importances();
        for &imp in importances.iter() {
            assert!(
                imp >= 0.0,
                "feature importance must be non-negative, got {imp}"
            );
        }
    }

    #[test]
    fn test_n_estimators_one() {
        let (x, y) = binary_two_features();
        let rf = RandomForestClassifier::new(1)
            .with_max_depth(Some(3))
            .with_seed(42);

        let fitted = fit_ok(&rf, &x, &y);
        assert_eq!(fitted.n_estimators(), 1);

        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), y.len());
    }

    #[test]
    fn test_predictions_are_valid_labels() {
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0],
            [20.0, 2.0],
            [21.0, 2.0],
            [22.0, 2.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0];
        let rf = RandomForestClassifier::new(30)
            .with_max_depth(Some(5))
            .with_seed(42);

        let preds = fit_ok(&rf, &x, &y).predict(&x).unwrap();
        let valid_labels: std::collections::HashSet<u64> = y.iter().map(|v| v.to_bits()).collect();
        for &p in preds.iter() {
            assert!(
                valid_labels.contains(&p.to_bits()),
                "prediction {p} is not a valid training label"
            );
        }
    }

    #[test]
    fn test_empty_input_error() {
        let x: Array2<f64> = Array2::zeros((0, 2));
        let y: Array1<f64> = Array1::zeros(0);

        let outcome = fit_result(&RandomForestClassifier::default(), &x, &y);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_zero_estimators_error() {
        let (x, y) = two_by_two();
        let rf = RandomForestClassifier::new(0).with_seed(0);

        assert!(fit_result(&rf, &x, &y).is_err());
    }

    mod prop_tests {
        use super::*;
        use proptest::prelude::*;
        use std::collections::HashSet;

        /// Builds a deterministic pseudo-random dataset: features in
        /// `[-10, 10]` and labels in `0..n_classes`, all derived by hashing
        /// `seed` together with the cell coordinates.
        fn make_classification_data(
            n_samples: usize,
            n_features: usize,
            n_classes: usize,
            seed: u64,
        ) -> (Array2<f64>, Array1<f64>) {
            use std::collections::hash_map::DefaultHasher;
            use std::hash::{Hash, Hasher};

            // Hashes the given words one after another into a fresh hasher.
            let digest = |words: &[u64]| -> u64 {
                let mut hasher = DefaultHasher::new();
                for word in words {
                    word.hash(&mut hasher);
                }
                hasher.finish()
            };

            let mut features = Vec::with_capacity(n_samples * n_features);
            let mut labels = Vec::with_capacity(n_samples);

            for i in 0..n_samples {
                for j in 0..n_features {
                    let bits = digest(&[seed, i as u64, j as u64]);
                    features.push((bits as f64 / u64::MAX as f64) * 20.0 - 10.0);
                }
                let label_bits = digest(&[seed, i as u64, 0xDEAD_BEEFu64]);
                labels.push((label_bits % n_classes as u64) as f64);
            }

            let x = Array2::from_shape_vec((n_samples, n_features), features).unwrap();
            (x, Array1::from_vec(labels))
        }

        proptest! {
            #[test]
            fn predictions_are_valid_labels(
                n_samples in 6..30usize,
                n_features in 1..5usize,
                n_classes in 2..5usize,
                seed in 0u64..1000,
            ) {
                let (x, y) = make_classification_data(n_samples, n_features, n_classes, seed);
                let train_labels: HashSet<u64> = y.iter().map(|&v| v.to_bits()).collect();

                let rf = RandomForestClassifier::new(10)
                    .with_max_depth(Some(5))
                    .with_seed(seed);
                let preds = fit_ok(&rf, &x, &y).predict(&x).unwrap();

                for (i, &p) in preds.iter().enumerate() {
                    prop_assert!(
                        train_labels.contains(&p.to_bits()),
                        "prediction {} at index {} is not a valid training label",
                        p, i
                    );
                }
            }

            #[test]
            fn feature_importances_sum_to_one(
                n_samples in 6..30usize,
                n_features in 1..5usize,
                seed in 0u64..1000,
            ) {
                let n_classes = 3;
                let (x, y) = make_classification_data(n_samples, n_features, n_classes, seed);

                let rf = RandomForestClassifier::new(10)
                    .with_max_depth(Some(5))
                    .with_seed(seed);
                let importances = fit_ok(&rf, &x, &y).feature_importances();
                let sum: f64 = importances.iter().sum();

                // A forest in which no tree ever split reports all-zero importances.
                prop_assert!(
                    (sum - 1.0).abs() < 1e-10 || sum == 0.0,
                    "feature importances sum to {} (expected ~1.0 or 0.0 for no-split case), n_samples={}, n_features={}, seed={}",
                    sum, n_samples, n_features, seed
                );
                for (i, &imp) in importances.iter().enumerate() {
                    prop_assert!(
                        imp >= 0.0,
                        "importance[{}] = {} is negative",
                        i, imp
                    );
                }
            }
        }
    }
}
907
908impl<F: Float> PredictProba<F> for FittedRandomForestClassifier<F> {
909 fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>> {
910 Self::predict_proba(self, x)
911 }
912}