1use anofox_ml_core::{Fit, Float, Predict, PredictProba, Result, RustMlError};
2use ndarray::{Array1, Array2};
3
4use crate::node::TreeNode;
5use crate::split::{
6 compute_impurity, compute_sample_weights_from_class_weight, compute_weighted_impurity,
7 count_classes, find_best_split_weighted, find_best_split_with_features, leaf_value,
8 select_feature_subset, weighted_count_classes, weighted_leaf_value, ClassWeight, MaxFeatures,
9 SplitCriterion,
10};
11
/// Hyper-parameters for a binary decision-tree classifier.
///
/// Build one with [`DecisionTreeClassifier::new`] (or `Default`) and the
/// `with_*` builder methods, then call `Fit::fit` to obtain a
/// `FittedDecisionTreeClassifier`. Labels are passed as floating-point values.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct DecisionTreeClassifier {
    /// Maximum depth of the tree; the root is at depth 0, so `Some(0)` yields a
    /// single leaf. `None` means no depth limit.
    pub max_depth: Option<usize>,
    /// Nodes holding fewer samples than this are turned into leaves without
    /// attempting a split.
    pub min_samples_split: usize,
    /// Forwarded to the split search; presumably the minimum number of samples
    /// each child of a split must contain (see `crate::split`).
    pub min_samples_leaf: usize,
    /// Impurity criterion used for node impurity, split search and leaf values.
    /// With `SplitCriterion::Mse` leaves store no per-class counts.
    pub criterion: SplitCriterion,
    /// When set, it is resolved to a feature count `k`, and each node only
    /// considers a deterministically seeded subset of `k` features.
    pub max_features: Option<MaxFeatures>,
    // Not serialized: a deserialized classifier always has `sample_weight == None`
    // (serde fills skipped fields with `Default::default()`).
    #[serde(skip)]
    /// Optional per-sample weights, one entry per training row.
    pub sample_weight: Option<Array1<f64>>,
    /// Optional class weighting; converted to per-sample weights at fit time and
    /// multiplied element-wise with `sample_weight` when both are present.
    pub class_weight: Option<ClassWeight>,
}
27
28impl DecisionTreeClassifier {
29 pub fn new() -> Self {
31 Self {
32 max_depth: None,
33 min_samples_split: 2,
34 min_samples_leaf: 1,
35 criterion: SplitCriterion::Gini,
36 max_features: None,
37 sample_weight: None,
38 class_weight: None,
39 }
40 }
41
42 pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
44 self.max_depth = max_depth;
45 self
46 }
47
48 pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
50 self.min_samples_split = min_samples_split;
51 self
52 }
53
54 pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
56 self.min_samples_leaf = min_samples_leaf;
57 self
58 }
59
60 pub fn with_criterion(mut self, criterion: SplitCriterion) -> Self {
62 self.criterion = criterion;
63 self
64 }
65
66 pub fn with_max_features(mut self, max_features: Option<MaxFeatures>) -> Self {
68 self.max_features = max_features;
69 self
70 }
71
72 pub fn with_sample_weight(mut self, sample_weight: Option<Array1<f64>>) -> Self {
74 self.sample_weight = sample_weight;
75 self
76 }
77
78 pub fn with_class_weight(mut self, class_weight: Option<ClassWeight>) -> Self {
80 self.class_weight = class_weight;
81 self
82 }
83}
84
85impl Default for DecisionTreeClassifier {
86 fn default() -> Self {
87 Self::new()
88 }
89}
90
/// A trained decision-tree classifier, produced by `Fit::fit` on a
/// [`DecisionTreeClassifier`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
// Explicit bound: deserialization only requires `F: DeserializeOwned` instead of
// the bound serde's derive would otherwise infer for the generic parameter.
#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
pub struct FittedDecisionTreeClassifier<F: Float> {
    /// Root node of the learned tree.
    tree: TreeNode<F>,
    /// Number of feature columns seen during fitting; prediction methods reject
    /// inputs with a different column count.
    n_features: usize,
}
98
99impl<F: Float> Fit<F> for DecisionTreeClassifier {
100 type Fitted = FittedDecisionTreeClassifier<F>;
101
102 fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Self::Fitted> {
103 if x.nrows() != y.len() {
104 return Err(RustMlError::ShapeMismatch(format!(
105 "X has {} rows but y has {} elements",
106 x.nrows(),
107 y.len()
108 )));
109 }
110 if x.is_empty() {
111 return Err(RustMlError::EmptyInput("training data is empty".into()));
112 }
113
114 let indices: Vec<usize> = (0..x.nrows()).collect();
115 let n_features = x.ncols();
116 let max_features_k = self.max_features.map(|mf| mf.resolve(n_features));
117
118 let effective_weights: Option<Array1<F>> = {
120 let class_w = self
121 .class_weight
122 .as_ref()
123 .map(|cw| compute_sample_weights_from_class_weight(y, cw));
124 let sample_w = self
125 .sample_weight
126 .as_ref()
127 .map(|sw| sw.mapv(|v| F::from_f64(v).unwrap()));
128 match (class_w, sample_w) {
129 (Some(cw), Some(sw)) => Some(cw * sw),
130 (Some(cw), None) => Some(cw),
131 (None, Some(sw)) => Some(sw),
132 (None, None) => None,
133 }
134 };
135
136 let params = TreeBuildParams {
137 max_depth: self.max_depth,
138 min_samples_split: self.min_samples_split,
139 min_samples_leaf: self.min_samples_leaf,
140 criterion: self.criterion,
141 max_features_k,
142 n_features,
143 };
144 let tree = build_tree(x, y, &indices, 0, ¶ms, 0, effective_weights.as_ref());
145
146 Ok(FittedDecisionTreeClassifier {
147 tree,
148 n_features: x.ncols(),
149 })
150 }
151}
152
153impl<F: Float> Predict<F> for FittedDecisionTreeClassifier<F> {
154 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>> {
155 if x.ncols() != self.n_features {
156 return Err(RustMlError::ShapeMismatch(format!(
157 "expected {} features, got {}",
158 self.n_features,
159 x.ncols()
160 )));
161 }
162
163 let predictions: Vec<F> = x
164 .rows()
165 .into_iter()
166 .map(|row| self.tree.predict_one(row.as_slice().unwrap()))
167 .collect();
168
169 Ok(Array1::from_vec(predictions))
170 }
171}
172
173impl<F: Float> FittedDecisionTreeClassifier<F> {
174 pub fn feature_importances(&self) -> Array1<F> {
176 let n_samples = tree_n_samples(&self.tree);
177 let raw = self.tree.feature_importances(self.n_features, n_samples);
178 let sum: F = raw.iter().copied().fold(F::zero(), |a, b| a + b);
179 if sum > F::zero() {
180 Array1::from_vec(raw.into_iter().map(|v| v / sum).collect())
181 } else {
182 Array1::zeros(self.n_features)
183 }
184 }
185
186 pub fn tree(&self) -> &TreeNode<F> {
187 &self.tree
188 }
189
190 pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>> {
195 if x.ncols() != self.n_features {
196 return Err(RustMlError::ShapeMismatch(format!(
197 "expected {} features, got {}",
198 self.n_features,
199 x.ncols()
200 )));
201 }
202
203 let classes = collect_classes(&self.tree);
205
206 let n_samples = x.nrows();
207 let n_classes = classes.len();
208 let mut proba = Array2::<F>::zeros((n_samples, n_classes));
209
210 for (i, row) in x.rows().into_iter().enumerate() {
211 let leaf = find_leaf(&self.tree, row.as_slice().unwrap());
212 if let TreeNode::Leaf {
213 class_counts: Some(counts),
214 ..
215 } = leaf
216 {
217 let total: usize = counts.iter().map(|&(_, c)| c).sum();
218 let total_f = F::from_usize(total).unwrap();
219 for &(class_val, count) in counts {
220 if let Some(ci) = classes
221 .iter()
222 .position(|&c| (c - class_val).abs() < F::from_f64(1e-9).unwrap())
223 {
224 proba[[i, ci]] = F::from_usize(count).unwrap() / total_f;
225 }
226 }
227 } else {
228 let pred = self.tree.predict_one(row.as_slice().unwrap());
230 if let Some(ci) = classes
231 .iter()
232 .position(|&c| (c - pred).abs() < F::from_f64(1e-9).unwrap())
233 {
234 proba[[i, ci]] = F::one();
235 }
236 }
237 }
238
239 Ok(proba)
240 }
241
242 pub fn classes(&self) -> Vec<F> {
244 collect_classes(&self.tree)
245 }
246
247 pub fn n_features(&self) -> usize {
249 self.n_features
250 }
251}
252
253impl<F: Float> PredictProba<F> for FittedDecisionTreeClassifier<F> {
254 fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>> {
255 Self::predict_proba(self, x)
257 }
258}
259
260struct TreeBuildParams {
262 max_depth: Option<usize>,
263 min_samples_split: usize,
264 min_samples_leaf: usize,
265 criterion: SplitCriterion,
266 max_features_k: Option<usize>,
268 n_features: usize,
270}
271
272fn build_tree<F: Float>(
273 x: &Array2<F>,
274 y: &Array1<F>,
275 indices: &[usize],
276 depth: usize,
277 params: &TreeBuildParams,
278 node_id: u64,
279 weights: Option<&Array1<F>>,
280) -> TreeNode<F> {
281 let n_samples = indices.len();
282 let impurity = match weights {
283 Some(w) => compute_weighted_impurity(y, indices, w, params.criterion),
284 None => compute_impurity(y, indices, params.criterion),
285 };
286
287 let should_stop = n_samples < params.min_samples_split
289 || params.max_depth.is_some_and(|d| depth >= d)
290 || impurity < F::from_f64(1e-15).unwrap();
291
292 if should_stop {
293 return make_leaf(y, indices, params.criterion, weights);
294 }
295
296 let feature_subset;
297 let feature_indices: &[usize] = if let Some(k) = params.max_features_k {
298 let seed = node_id
299 .wrapping_mul(0x517CC1B727220A95)
300 .wrapping_add(depth as u64);
301 feature_subset = select_feature_subset(params.n_features, k, seed);
302 &feature_subset
303 } else {
304 feature_subset = (0..params.n_features).collect();
305 &feature_subset
306 };
307
308 let split_result = match weights {
309 Some(w) => find_best_split_weighted(
310 x,
311 y,
312 indices,
313 w,
314 params.criterion,
315 params.min_samples_leaf,
316 feature_indices,
317 ),
318 None => find_best_split_with_features(
319 x,
320 y,
321 indices,
322 params.criterion,
323 params.min_samples_leaf,
324 feature_indices,
325 ),
326 };
327
328 match split_result {
329 Some(split) => {
330 let left = build_tree(
331 x,
332 y,
333 &split.left_indices,
334 depth + 1,
335 params,
336 node_id.wrapping_mul(2).wrapping_add(1),
337 weights,
338 );
339 let right = build_tree(
340 x,
341 y,
342 &split.right_indices,
343 depth + 1,
344 params,
345 node_id.wrapping_mul(2).wrapping_add(2),
346 weights,
347 );
348
349 TreeNode::Split {
350 feature_index: split.feature_index,
351 threshold: split.threshold,
352 left: Box::new(left),
353 right: Box::new(right),
354 n_samples,
355 impurity,
356 }
357 }
358 None => make_leaf(y, indices, params.criterion, weights),
359 }
360}
361
362fn make_leaf<F: Float>(
363 y: &Array1<F>,
364 indices: &[usize],
365 criterion: SplitCriterion,
366 weights: Option<&Array1<F>>,
367) -> TreeNode<F> {
368 let value = match weights {
369 Some(w) => weighted_leaf_value(y, indices, w, criterion),
370 None => leaf_value(y, indices, criterion),
371 };
372 let class_counts = match criterion {
373 SplitCriterion::Gini | SplitCriterion::Entropy => match weights {
374 Some(w) => {
375 let wc = weighted_count_classes(y, indices, w);
377 Some(
378 wc.into_iter()
379 .map(|(class, weight)| {
380 (class, (weight.to_f64().unwrap() * 1000.0).round() as usize)
382 })
383 .collect(),
384 )
385 }
386 None => Some(count_classes(y, indices)),
387 },
388 SplitCriterion::Mse => None,
389 };
390 TreeNode::Leaf {
391 value,
392 n_samples: indices.len(),
393 class_counts,
394 }
395}
396
397fn find_leaf<'a, F: Float>(node: &'a TreeNode<F>, features: &[F]) -> &'a TreeNode<F> {
399 match node {
400 TreeNode::Leaf { .. } => node,
401 TreeNode::Split {
402 feature_index,
403 threshold,
404 left,
405 right,
406 ..
407 } => {
408 if features[*feature_index] <= *threshold {
409 find_leaf(left, features)
410 } else {
411 find_leaf(right, features)
412 }
413 }
414 }
415}
416
417fn collect_classes<F: Float>(node: &TreeNode<F>) -> Vec<F> {
419 let mut classes = Vec::new();
420 collect_classes_recursive(node, &mut classes);
421 classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
422 classes.dedup_by(|a, b| (*a - *b).abs() < F::from_f64(1e-9).unwrap());
423 classes
424}
425
426fn collect_classes_recursive<F: Float>(node: &TreeNode<F>, classes: &mut Vec<F>) {
427 match node {
428 TreeNode::Leaf {
429 class_counts: Some(counts),
430 ..
431 } => {
432 for &(class_val, _) in counts {
433 classes.push(class_val);
434 }
435 }
436 TreeNode::Leaf { value, .. } => {
437 classes.push(*value);
438 }
439 TreeNode::Split { left, right, .. } => {
440 collect_classes_recursive(left, classes);
441 collect_classes_recursive(right, classes);
442 }
443 }
444}
445
446fn tree_n_samples<F: Float>(node: &TreeNode<F>) -> usize {
447 match node {
448 TreeNode::Leaf { n_samples, .. } => *n_samples,
449 TreeNode::Split { n_samples, .. } => *n_samples,
450 }
451}
452
// Unit and property tests for `DecisionTreeClassifier` and its fitted model.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Two separable groups on one feature: points on either side of the
    /// boundary must receive the matching label.
    #[test]
    fn test_simple_classification() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let tree = DecisionTreeClassifier::default();
        let fitted = Fit::fit(&tree, &x, &y).unwrap();

        let preds = fitted.predict(&array![[1.5], [5.5]]).unwrap();
        assert_abs_diff_eq!(preds[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(preds[1], 1.0, epsilon = 1e-10);
    }

    /// A depth-1 tree (a single split) is enough for this separable data.
    #[test]
    fn test_max_depth() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![0.0, 0.0, 1.0, 1.0];

        let tree = DecisionTreeClassifier {
            max_depth: Some(1),
            ..Default::default()
        };
        let fitted = Fit::fit(&tree, &x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_abs_diff_eq!(preds[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(preds[3], 1.0, epsilon = 1e-10);
    }

    /// Importances are normalised, so they must sum to one when any split exists.
    #[test]
    fn test_feature_importances() {
        let x = array![[1.0, 100.0], [2.0, 200.0], [3.0, 300.0], [4.0, 400.0]];
        let y = array![0.0, 0.0, 1.0, 1.0];

        let tree = DecisionTreeClassifier::default();
        let fitted = Fit::fit(&tree, &x, &y).unwrap();

        let importances = fitted.feature_importances();
        let sum: f64 = importances.iter().sum();
        assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-10);
    }

    /// With `min_samples_split` larger than the dataset, the root stays a leaf,
    /// so every prediction is identical.
    #[test]
    fn test_min_samples_split_constraint() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![0.0, 0.0, 1.0, 1.0];

        let tree = DecisionTreeClassifier::new().with_min_samples_split(5);
        let fitted = Fit::fit(&tree, &x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let first = preds[0];
        for &p in preds.iter() {
            assert_abs_diff_eq!(p, first, epsilon = 1e-10);
        }
    }

    /// With 4 samples and `min_samples_leaf = 3`, presumably no split can give
    /// both children enough samples, so the tree is a single leaf.
    #[test]
    fn test_min_samples_leaf_constraint() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![0.0, 0.0, 1.0, 1.0];

        let tree = DecisionTreeClassifier::new().with_min_samples_leaf(3);
        let fitted = Fit::fit(&tree, &x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let first = preds[0];
        for &p in preds.iter() {
            assert_abs_diff_eq!(p, first, epsilon = 1e-10);
        }
    }

    /// Three separated clusters (gaps at 4 and 8) must be fit exactly.
    #[test]
    fn test_multiclass_three_classes() {
        let x = array![
            [1.0],
            [2.0],
            [3.0],
            [5.0],
            [6.0],
            [7.0],
            [9.0],
            [10.0],
            [11.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0];

        let tree = DecisionTreeClassifier::default();
        let fitted = Fit::fit(&tree, &x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for (pred, target) in preds.iter().zip(y.iter()) {
            assert_abs_diff_eq!(pred, target, epsilon = 1e-10);
        }
    }

    /// A constant-label dataset is pure at the root and predicts that label.
    #[test]
    fn test_single_class_input() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![7.0, 7.0, 7.0, 7.0];

        let tree = DecisionTreeClassifier::default();
        let fitted = Fit::fit(&tree, &x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for &p in preds.iter() {
            assert_abs_diff_eq!(p, 7.0, epsilon = 1e-10);
        }
    }

    /// Unseen points inside each cluster get that cluster's label.
    #[test]
    fn test_single_feature() {
        let x = array![[0.0], [1.0], [2.0], [10.0], [11.0], [12.0]];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let tree = DecisionTreeClassifier::default();
        let fitted = Fit::fit(&tree, &x, &y).unwrap();

        let test_x = array![[0.5], [11.5]];
        let preds = fitted.predict(&test_x).unwrap();
        assert_abs_diff_eq!(preds[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(preds[1], 1.0, epsilon = 1e-10);
    }

    /// `max_depth = 1` yields a root split whose two children are leaves.
    #[test]
    fn test_stump_depth_one() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![0.0, 0.0, 1.0, 1.0];

        let tree = DecisionTreeClassifier::new().with_max_depth(Some(1));
        let fitted = Fit::fit(&tree, &x, &y).unwrap();

        match fitted.tree() {
            TreeNode::Split { left, right, .. } => {
                assert!(matches!(**left, TreeNode::Leaf { .. }));
                assert!(matches!(**right, TreeNode::Leaf { .. }));
            }
            TreeNode::Leaf { .. } => panic!("expected a stump (Split node), got Leaf"),
        }
    }

    /// Row count of `x` differing from `y.len()` is reported as `ShapeMismatch`.
    #[test]
    fn test_shape_mismatch_error() {
        let x = array![[1.0], [2.0], [3.0]];
        let y = array![0.0, 1.0];
        let tree = DecisionTreeClassifier::default();
        let result = Fit::<f64>::fit(&tree, &x, &y);
        assert!(result.is_err());
        match result.unwrap_err() {
            RustMlError::ShapeMismatch(_) => {}
            other => panic!("expected ShapeMismatch, got {:?}", other),
        }
    }

    /// An empty training matrix is reported as `EmptyInput`.
    #[test]
    fn test_empty_input_error() {
        let x: Array2<f64> = Array2::zeros((0, 0));
        let y: Array1<f64> = array![];

        let tree = DecisionTreeClassifier::default();
        let result = Fit::<f64>::fit(&tree, &x, &y);
        assert!(result.is_err());
        match result.unwrap_err() {
            RustMlError::EmptyInput(_) => {}
            other => panic!("expected EmptyInput, got {:?}", other),
        }
    }

    /// Predicting with a different column count than used in fitting fails.
    #[test]
    fn test_predict_wrong_features() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![0.0, 0.0, 1.0, 1.0];

        let tree = DecisionTreeClassifier::default();
        let fitted = Fit::fit(&tree, &x, &y).unwrap();

        let bad_x = array![[1.0, 2.0, 3.0]];
        let result = fitted.predict(&bad_x);
        assert!(result.is_err());
        match result.unwrap_err() {
            RustMlError::ShapeMismatch(_) => {}
            other => panic!("expected ShapeMismatch, got {:?}", other),
        }
    }

    /// Very large feature magnitudes must not produce non-finite predictions.
    #[test]
    fn test_large_feature_values() {
        let x = array![
            [1e10_f64, -1e10],
            [2e10, -2e10],
            [3e10, -3e10],
            [4e10, -4e10],
        ];
        let y = array![0.0_f64, 0.0, 1.0, 1.0];

        let tree = DecisionTreeClassifier::default();
        let fitted = Fit::fit(&tree, &x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        for &p in preds.iter() {
            assert!(p.is_finite(), "prediction should be finite, got {}", p);
        }
    }

    /// Tiny but distinct feature values are still separable.
    #[test]
    fn test_small_feature_values() {
        let x = array![[1e-10], [2e-10], [3e-10], [4e-10],];
        let y = array![0.0, 0.0, 1.0, 1.0];

        let tree = DecisionTreeClassifier::default();
        let fitted = Fit::fit(&tree, &x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_abs_diff_eq!(preds[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(preds[3], 1.0, epsilon = 1e-10);
    }

    /// Features differing only near machine epsilon must still produce one of
    /// the training labels (no interpolated or garbage values).
    #[test]
    fn test_near_identical_feature_values() {
        let x = array![[1.0 + 1e-14], [1.0 + 2e-14], [1.0 + 3e-14], [1.0 + 4e-14],];
        let y = array![0.0, 0.0, 1.0, 1.0];

        let tree = DecisionTreeClassifier::default();
        let fitted = Fit::fit(&tree, &x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        for &p in preds.iter() {
            assert!(
                p == 0.0 || p == 1.0,
                "prediction should be 0 or 1, got {}",
                p
            );
        }
    }

    // Property-based tests over hash-generated datasets.
    mod prop_tests {
        use super::*;
        use proptest::prelude::*;
        use std::collections::HashSet;

        /// Builds a pseudo-random dataset from `seed` using `DefaultHasher`.
        ///
        /// Features lie in `[-10, 10]` and labels in `0..n_classes`. Output is
        /// deterministic for a given seed within one build; the hasher algorithm
        /// is not guaranteed to be stable across Rust releases.
        fn make_classification_data(
            n_samples: usize,
            n_features: usize,
            n_classes: usize,
            seed: u64,
        ) -> (Array2<f64>, Array1<f64>) {
            use std::collections::hash_map::DefaultHasher;
            use std::hash::{Hash, Hasher};

            let mut x_data = Vec::with_capacity(n_samples * n_features);
            let mut y_data = Vec::with_capacity(n_samples);

            for i in 0..n_samples {
                for j in 0..n_features {
                    let mut h = DefaultHasher::new();
                    seed.hash(&mut h);
                    (i as u64).hash(&mut h);
                    (j as u64).hash(&mut h);
                    let bits = h.finish();
                    // Map the 64-bit hash uniformly onto [-10, 10].
                    let v = (bits as f64 / u64::MAX as f64) * 20.0 - 10.0;
                    x_data.push(v);
                }
                let mut h = DefaultHasher::new();
                seed.hash(&mut h);
                (i as u64).hash(&mut h);
                // Distinct salt so labels are independent of the feature hashes.
                0xDEAD_BEEFu64.hash(&mut h);
                let label = (h.finish() % n_classes as u64) as f64;
                y_data.push(label);
            }

            let x = Array2::from_shape_vec((n_samples, n_features), x_data).unwrap();
            let y = Array1::from_vec(y_data);
            (x, y)
        }

        proptest! {
            // Every prediction must be bit-identical to some training label.
            #[test]
            fn tree_predictions_are_valid_labels(
                n_samples in 4..30usize,
                n_features in 1..5usize,
                seed in 0u64..1000,
            ) {
                let n_classes = 3;
                let (x, y) = make_classification_data(n_samples, n_features, n_classes, seed);

                let train_labels: HashSet<u64> = y.iter()
                    .map(|&v| v.to_bits())
                    .collect();

                let tree = DecisionTreeClassifier::new()
                    .with_max_depth(Some(5));
                let fitted = Fit::fit(&tree, &x, &y).unwrap();
                let preds = fitted.predict(&x).unwrap();

                for (i, &p) in preds.iter().enumerate() {
                    prop_assert!(
                        train_labels.contains(&p.to_bits()),
                        "prediction {} at index {} is not a valid training label",
                        p, i
                    );
                }
            }

            // Fitting twice on the same data must give the same predictions.
            #[test]
            fn tree_deterministic(seed in 0u64..1000) {
                let (x, y) = make_classification_data(20, 3, 3, seed);

                let tree = DecisionTreeClassifier::new()
                    .with_max_depth(Some(4));

                let fitted1 = Fit::fit(&tree, &x, &y).unwrap();
                let fitted2 = Fit::fit(&tree, &x, &y).unwrap();

                let preds1 = fitted1.predict(&x).unwrap();
                let preds2 = fitted2.predict(&x).unwrap();

                for (i, (&a, &b)) in preds1.iter().zip(preds2.iter()).enumerate() {
                    prop_assert!((a - b).abs() < 1e-15,
                        "non-deterministic prediction at index {}: {} vs {}", i, a, b);
                }
            }
        }
    }
}