Skip to main content

anofox_ml_trees/
classifier.rs

1use anofox_ml_core::{Fit, Float, Predict, PredictProba, Result, RustMlError};
2use ndarray::{Array1, Array2};
3
4use crate::node::TreeNode;
5use crate::split::{
6    compute_impurity, compute_sample_weights_from_class_weight, compute_weighted_impurity,
7    count_classes, find_best_split_weighted, find_best_split_with_features, leaf_value,
8    select_feature_subset, weighted_count_classes, weighted_leaf_value, ClassWeight, MaxFeatures,
9    SplitCriterion,
10};
11
/// Decision tree classifier parameters (unfitted state).
///
/// Configure with [`DecisionTreeClassifier::new`] (or `Default`) and the `with_*` builder
/// methods, then call [`Fit::fit`] to obtain a [`FittedDecisionTreeClassifier`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct DecisionTreeClassifier {
    /// Maximum depth of the tree (the root is depth 0); `None` means no depth limit.
    pub max_depth: Option<usize>,
    /// A node with fewer samples than this becomes a leaf instead of being split.
    pub min_samples_split: usize,
    /// Minimum number of samples each side of a split must keep; forwarded to the split search.
    pub min_samples_leaf: usize,
    /// Impurity measure used to grow the tree. `Gini` and `Entropy` leaves also record
    /// per-class counts (used by `predict_proba`); `Mse` leaves do not.
    pub criterion: SplitCriterion,
    /// Maximum number of features to consider at each split.
    ///
    /// Resolved against the number of training columns at fit time; `None` considers
    /// every feature at every split.
    pub max_features: Option<MaxFeatures>,
    /// Per-sample weights, expected to hold one entry per training row.
    ///
    /// When `class_weight` is also set, the two are multiplied element-wise. This field is
    /// skipped by serde, so a deserialized classifier always has `None` here.
    #[serde(skip)]
    pub sample_weight: Option<Array1<f64>>,
    /// Class weighting strategy; converted to per-sample weights from the labels at fit time.
    pub class_weight: Option<ClassWeight>,
}
27
28impl DecisionTreeClassifier {
29    /// Create a new `DecisionTreeClassifier` with sensible defaults.
30    pub fn new() -> Self {
31        Self {
32            max_depth: None,
33            min_samples_split: 2,
34            min_samples_leaf: 1,
35            criterion: SplitCriterion::Gini,
36            max_features: None,
37            sample_weight: None,
38            class_weight: None,
39        }
40    }
41
42    /// Set the maximum depth of the tree.
43    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
44        self.max_depth = max_depth;
45        self
46    }
47
48    /// Set the minimum number of samples required to split a node.
49    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
50        self.min_samples_split = min_samples_split;
51        self
52    }
53
54    /// Set the minimum number of samples required in a leaf node.
55    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
56        self.min_samples_leaf = min_samples_leaf;
57        self
58    }
59
60    /// Set the split quality criterion.
61    pub fn with_criterion(mut self, criterion: SplitCriterion) -> Self {
62        self.criterion = criterion;
63        self
64    }
65
66    /// Set the maximum number of features to consider at each split.
67    pub fn with_max_features(mut self, max_features: Option<MaxFeatures>) -> Self {
68        self.max_features = max_features;
69        self
70    }
71
72    /// Set per-sample weights.
73    pub fn with_sample_weight(mut self, sample_weight: Option<Array1<f64>>) -> Self {
74        self.sample_weight = sample_weight;
75        self
76    }
77
78    /// Set class weighting strategy.
79    pub fn with_class_weight(mut self, class_weight: Option<ClassWeight>) -> Self {
80        self.class_weight = class_weight;
81        self
82    }
83}
84
85impl Default for DecisionTreeClassifier {
86    fn default() -> Self {
87        Self::new()
88    }
89}
90
/// Fitted decision tree classifier.
///
/// Produced by [`Fit::fit`] on a [`DecisionTreeClassifier`]. Serializable with serde;
/// deserialization requires `F: DeserializeOwned`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
pub struct FittedDecisionTreeClassifier<F: Float> {
    /// Root node of the learned tree.
    tree: TreeNode<F>,
    /// Number of feature columns seen during fitting; prediction rejects inputs whose
    /// column count differs.
    n_features: usize,
}
98
99impl<F: Float> Fit<F> for DecisionTreeClassifier {
100    type Fitted = FittedDecisionTreeClassifier<F>;
101
102    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Self::Fitted> {
103        if x.nrows() != y.len() {
104            return Err(RustMlError::ShapeMismatch(format!(
105                "X has {} rows but y has {} elements",
106                x.nrows(),
107                y.len()
108            )));
109        }
110        if x.is_empty() {
111            return Err(RustMlError::EmptyInput("training data is empty".into()));
112        }
113
114        let indices: Vec<usize> = (0..x.nrows()).collect();
115        let n_features = x.ncols();
116        let max_features_k = self.max_features.map(|mf| mf.resolve(n_features));
117
118        // Compute effective sample weights (merge class_weight and sample_weight)
119        let effective_weights: Option<Array1<F>> = {
120            let class_w = self
121                .class_weight
122                .as_ref()
123                .map(|cw| compute_sample_weights_from_class_weight(y, cw));
124            let sample_w = self
125                .sample_weight
126                .as_ref()
127                .map(|sw| sw.mapv(|v| F::from_f64(v).unwrap()));
128            match (class_w, sample_w) {
129                (Some(cw), Some(sw)) => Some(cw * sw),
130                (Some(cw), None) => Some(cw),
131                (None, Some(sw)) => Some(sw),
132                (None, None) => None,
133            }
134        };
135
136        let params = TreeBuildParams {
137            max_depth: self.max_depth,
138            min_samples_split: self.min_samples_split,
139            min_samples_leaf: self.min_samples_leaf,
140            criterion: self.criterion,
141            max_features_k,
142            n_features,
143        };
144        let tree = build_tree(x, y, &indices, 0, &params, 0, effective_weights.as_ref());
145
146        Ok(FittedDecisionTreeClassifier {
147            tree,
148            n_features: x.ncols(),
149        })
150    }
151}
152
153impl<F: Float> Predict<F> for FittedDecisionTreeClassifier<F> {
154    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>> {
155        if x.ncols() != self.n_features {
156            return Err(RustMlError::ShapeMismatch(format!(
157                "expected {} features, got {}",
158                self.n_features,
159                x.ncols()
160            )));
161        }
162
163        let predictions: Vec<F> = x
164            .rows()
165            .into_iter()
166            .map(|row| self.tree.predict_one(row.as_slice().unwrap()))
167            .collect();
168
169        Ok(Array1::from_vec(predictions))
170    }
171}
172
173impl<F: Float> FittedDecisionTreeClassifier<F> {
174    /// Feature importances (normalized to sum to 1).
175    pub fn feature_importances(&self) -> Array1<F> {
176        let n_samples = tree_n_samples(&self.tree);
177        let raw = self.tree.feature_importances(self.n_features, n_samples);
178        let sum: F = raw.iter().copied().fold(F::zero(), |a, b| a + b);
179        if sum > F::zero() {
180            Array1::from_vec(raw.into_iter().map(|v| v / sum).collect())
181        } else {
182            Array1::zeros(self.n_features)
183        }
184    }
185
186    pub fn tree(&self) -> &TreeNode<F> {
187        &self.tree
188    }
189
190    /// Predict class probabilities for each sample.
191    ///
192    /// Returns an `Array2<F>` of shape `(n_samples, n_classes)` where each row
193    /// sums to 1.0. Classes are sorted in ascending order.
194    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>> {
195        if x.ncols() != self.n_features {
196            return Err(RustMlError::ShapeMismatch(format!(
197                "expected {} features, got {}",
198                self.n_features,
199                x.ncols()
200            )));
201        }
202
203        // Collect all unique classes from the tree
204        let classes = collect_classes(&self.tree);
205
206        let n_samples = x.nrows();
207        let n_classes = classes.len();
208        let mut proba = Array2::<F>::zeros((n_samples, n_classes));
209
210        for (i, row) in x.rows().into_iter().enumerate() {
211            let leaf = find_leaf(&self.tree, row.as_slice().unwrap());
212            if let TreeNode::Leaf {
213                class_counts: Some(counts),
214                ..
215            } = leaf
216            {
217                let total: usize = counts.iter().map(|&(_, c)| c).sum();
218                let total_f = F::from_usize(total).unwrap();
219                for &(class_val, count) in counts {
220                    if let Some(ci) = classes
221                        .iter()
222                        .position(|&c| (c - class_val).abs() < F::from_f64(1e-9).unwrap())
223                    {
224                        proba[[i, ci]] = F::from_usize(count).unwrap() / total_f;
225                    }
226                }
227            } else {
228                // Regression leaf or no counts — put all weight on predicted class
229                let pred = self.tree.predict_one(row.as_slice().unwrap());
230                if let Some(ci) = classes
231                    .iter()
232                    .position(|&c| (c - pred).abs() < F::from_f64(1e-9).unwrap())
233                {
234                    proba[[i, ci]] = F::one();
235                }
236            }
237        }
238
239        Ok(proba)
240    }
241
242    /// Returns the unique sorted class labels learned during fitting.
243    pub fn classes(&self) -> Vec<F> {
244        collect_classes(&self.tree)
245    }
246
247    /// Number of features expected at prediction time.
248    pub fn n_features(&self) -> usize {
249        self.n_features
250    }
251}
252
253impl<F: Float> PredictProba<F> for FittedDecisionTreeClassifier<F> {
254    fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>> {
255        // Delegate to the inherent method.
256        Self::predict_proba(self, x)
257    }
258}
259
260/// Bundled parameters for recursive tree building (avoids too many function args).
261struct TreeBuildParams {
262    max_depth: Option<usize>,
263    min_samples_split: usize,
264    min_samples_leaf: usize,
265    criterion: SplitCriterion,
266    /// Resolved max features count per split (None = use all).
267    max_features_k: Option<usize>,
268    /// Total number of features in the dataset.
269    n_features: usize,
270}
271
272fn build_tree<F: Float>(
273    x: &Array2<F>,
274    y: &Array1<F>,
275    indices: &[usize],
276    depth: usize,
277    params: &TreeBuildParams,
278    node_id: u64,
279    weights: Option<&Array1<F>>,
280) -> TreeNode<F> {
281    let n_samples = indices.len();
282    let impurity = match weights {
283        Some(w) => compute_weighted_impurity(y, indices, w, params.criterion),
284        None => compute_impurity(y, indices, params.criterion),
285    };
286
287    // Check stopping criteria
288    let should_stop = n_samples < params.min_samples_split
289        || params.max_depth.is_some_and(|d| depth >= d)
290        || impurity < F::from_f64(1e-15).unwrap();
291
292    if should_stop {
293        return make_leaf(y, indices, params.criterion, weights);
294    }
295
296    let feature_subset;
297    let feature_indices: &[usize] = if let Some(k) = params.max_features_k {
298        let seed = node_id
299            .wrapping_mul(0x517CC1B727220A95)
300            .wrapping_add(depth as u64);
301        feature_subset = select_feature_subset(params.n_features, k, seed);
302        &feature_subset
303    } else {
304        feature_subset = (0..params.n_features).collect();
305        &feature_subset
306    };
307
308    let split_result = match weights {
309        Some(w) => find_best_split_weighted(
310            x,
311            y,
312            indices,
313            w,
314            params.criterion,
315            params.min_samples_leaf,
316            feature_indices,
317        ),
318        None => find_best_split_with_features(
319            x,
320            y,
321            indices,
322            params.criterion,
323            params.min_samples_leaf,
324            feature_indices,
325        ),
326    };
327
328    match split_result {
329        Some(split) => {
330            let left = build_tree(
331                x,
332                y,
333                &split.left_indices,
334                depth + 1,
335                params,
336                node_id.wrapping_mul(2).wrapping_add(1),
337                weights,
338            );
339            let right = build_tree(
340                x,
341                y,
342                &split.right_indices,
343                depth + 1,
344                params,
345                node_id.wrapping_mul(2).wrapping_add(2),
346                weights,
347            );
348
349            TreeNode::Split {
350                feature_index: split.feature_index,
351                threshold: split.threshold,
352                left: Box::new(left),
353                right: Box::new(right),
354                n_samples,
355                impurity,
356            }
357        }
358        None => make_leaf(y, indices, params.criterion, weights),
359    }
360}
361
362fn make_leaf<F: Float>(
363    y: &Array1<F>,
364    indices: &[usize],
365    criterion: SplitCriterion,
366    weights: Option<&Array1<F>>,
367) -> TreeNode<F> {
368    let value = match weights {
369        Some(w) => weighted_leaf_value(y, indices, w, criterion),
370        None => leaf_value(y, indices, criterion),
371    };
372    let class_counts = match criterion {
373        SplitCriterion::Gini | SplitCriterion::Entropy => match weights {
374            Some(w) => {
375                // Store weighted counts as approximate integer counts for predict_proba compat
376                let wc = weighted_count_classes(y, indices, w);
377                Some(
378                    wc.into_iter()
379                        .map(|(class, weight)| {
380                            // Scale weight to integer-like count (multiply by 1000 for precision)
381                            (class, (weight.to_f64().unwrap() * 1000.0).round() as usize)
382                        })
383                        .collect(),
384                )
385            }
386            None => Some(count_classes(y, indices)),
387        },
388        SplitCriterion::Mse => None,
389    };
390    TreeNode::Leaf {
391        value,
392        n_samples: indices.len(),
393        class_counts,
394    }
395}
396
397/// Traverse the tree to find the leaf node for a given sample.
398fn find_leaf<'a, F: Float>(node: &'a TreeNode<F>, features: &[F]) -> &'a TreeNode<F> {
399    match node {
400        TreeNode::Leaf { .. } => node,
401        TreeNode::Split {
402            feature_index,
403            threshold,
404            left,
405            right,
406            ..
407        } => {
408            if features[*feature_index] <= *threshold {
409                find_leaf(left, features)
410            } else {
411                find_leaf(right, features)
412            }
413        }
414    }
415}
416
417/// Collect all unique sorted class labels from the tree's leaf nodes.
418fn collect_classes<F: Float>(node: &TreeNode<F>) -> Vec<F> {
419    let mut classes = Vec::new();
420    collect_classes_recursive(node, &mut classes);
421    classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
422    classes.dedup_by(|a, b| (*a - *b).abs() < F::from_f64(1e-9).unwrap());
423    classes
424}
425
426fn collect_classes_recursive<F: Float>(node: &TreeNode<F>, classes: &mut Vec<F>) {
427    match node {
428        TreeNode::Leaf {
429            class_counts: Some(counts),
430            ..
431        } => {
432            for &(class_val, _) in counts {
433                classes.push(class_val);
434            }
435        }
436        TreeNode::Leaf { value, .. } => {
437            classes.push(*value);
438        }
439        TreeNode::Split { left, right, .. } => {
440            collect_classes_recursive(left, classes);
441            collect_classes_recursive(right, classes);
442        }
443    }
444}
445
446fn tree_n_samples<F: Float>(node: &TreeNode<F>) -> usize {
447    match node {
448        TreeNode::Leaf { n_samples, .. } => *n_samples,
449        TreeNode::Split { n_samples, .. } => *n_samples,
450    }
451}
452
#[cfg(test)]
mod tests {
    //! Unit and property tests for the decision tree classifier, covering fitting,
    //! prediction, probability output, and input validation.
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    #[test]
    fn test_simple_classification() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let tree = DecisionTreeClassifier::default();
        let fitted = Fit::fit(&tree, &x, &y).unwrap();

        let preds = fitted.predict(&array![[1.5], [5.5]]).unwrap();
        assert_abs_diff_eq!(preds[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(preds[1], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_max_depth() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![0.0, 0.0, 1.0, 1.0];

        let tree = DecisionTreeClassifier {
            max_depth: Some(1),
            ..Default::default()
        };
        let fitted = Fit::fit(&tree, &x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        // With max_depth=1, should still separate the two classes
        assert_abs_diff_eq!(preds[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(preds[3], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_feature_importances() {
        let x = array![[1.0, 100.0], [2.0, 200.0], [3.0, 300.0], [4.0, 400.0]];
        let y = array![0.0, 0.0, 1.0, 1.0];

        let tree = DecisionTreeClassifier::default();
        let fitted = Fit::fit(&tree, &x, &y).unwrap();

        let importances = fitted.feature_importances();
        // Sum should be 1.0
        let sum: f64 = importances.iter().sum();
        assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_min_samples_split_constraint() {
        // 4 samples with min_samples_split=5 means the root can never split
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![0.0, 0.0, 1.0, 1.0];

        let tree = DecisionTreeClassifier::new().with_min_samples_split(5);
        let fitted = Fit::fit(&tree, &x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        // All predictions should be the same (single leaf)
        let first = preds[0];
        for &p in preds.iter() {
            assert_abs_diff_eq!(p, first, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_min_samples_leaf_constraint() {
        // 4 samples, 2 of each class. min_samples_leaf=2 means leaves need >= 2 samples.
        // A split into [0,0] and [1,1] satisfies this, but min_samples_leaf=3 would not.
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![0.0, 0.0, 1.0, 1.0];

        let tree = DecisionTreeClassifier::new().with_min_samples_leaf(3);
        let fitted = Fit::fit(&tree, &x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        // With min_samples_leaf=3 on 4 samples, no valid split exists (each side would
        // have at most 2 samples), so the tree degenerates to a single leaf.
        let first = preds[0];
        for &p in preds.iter() {
            assert_abs_diff_eq!(p, first, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_multiclass_three_classes() {
        // 9 data points, 3 classes separated by feature value
        let x = array![
            [1.0],
            [2.0],
            [3.0], // class 0
            [5.0],
            [6.0],
            [7.0], // class 1
            [9.0],
            [10.0],
            [11.0] // class 2
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0];

        let tree = DecisionTreeClassifier::default();
        let fitted = Fit::fit(&tree, &x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for (pred, target) in preds.iter().zip(y.iter()) {
            assert_abs_diff_eq!(pred, target, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_single_class_input() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![7.0, 7.0, 7.0, 7.0];

        let tree = DecisionTreeClassifier::default();
        let fitted = Fit::fit(&tree, &x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for &p in preds.iter() {
            assert_abs_diff_eq!(p, 7.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_single_feature() {
        // Simple binary split on one feature
        let x = array![[0.0], [1.0], [2.0], [10.0], [11.0], [12.0]];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let tree = DecisionTreeClassifier::default();
        let fitted = Fit::fit(&tree, &x, &y).unwrap();

        let test_x = array![[0.5], [11.5]];
        let preds = fitted.predict(&test_x).unwrap();
        assert_abs_diff_eq!(preds[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(preds[1], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_stump_depth_one() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![0.0, 0.0, 1.0, 1.0];

        let tree = DecisionTreeClassifier::new().with_max_depth(Some(1));
        let fitted = Fit::fit(&tree, &x, &y).unwrap();

        // The root should be a Split whose children are both Leaves
        match fitted.tree() {
            TreeNode::Split { left, right, .. } => {
                assert!(matches!(**left, TreeNode::Leaf { .. }));
                assert!(matches!(**right, TreeNode::Leaf { .. }));
            }
            TreeNode::Leaf { .. } => panic!("expected a stump (Split node), got Leaf"),
        }
    }

    #[test]
    fn test_shape_mismatch_error() {
        let x = array![[1.0], [2.0], [3.0]];
        let y = array![0.0, 1.0]; // 3 rows vs 2 labels

        let tree = DecisionTreeClassifier::default();
        let result = Fit::<f64>::fit(&tree, &x, &y);
        assert!(result.is_err());
        match result.unwrap_err() {
            RustMlError::ShapeMismatch(_) => {} // expected
            other => panic!("expected ShapeMismatch, got {:?}", other),
        }
    }

    #[test]
    fn test_empty_input_error() {
        let x: Array2<f64> = Array2::zeros((0, 0));
        let y: Array1<f64> = array![];

        let tree = DecisionTreeClassifier::default();
        let result = Fit::<f64>::fit(&tree, &x, &y);
        assert!(result.is_err());
        match result.unwrap_err() {
            RustMlError::EmptyInput(_) => {} // expected
            other => panic!("expected EmptyInput, got {:?}", other),
        }
    }

    #[test]
    fn test_predict_wrong_features() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![0.0, 0.0, 1.0, 1.0];

        let tree = DecisionTreeClassifier::default();
        let fitted = Fit::fit(&tree, &x, &y).unwrap();

        // Predict with 3 features instead of 2
        let bad_x = array![[1.0, 2.0, 3.0]];
        let result = fitted.predict(&bad_x);
        assert!(result.is_err());
        match result.unwrap_err() {
            RustMlError::ShapeMismatch(_) => {} // expected
            other => panic!("expected ShapeMismatch, got {:?}", other),
        }
    }

    #[test]
    fn test_large_feature_values() {
        // Very large feature values should not cause panics or NaN
        let x = array![
            [1e10_f64, -1e10],
            [2e10, -2e10],
            [3e10, -3e10],
            [4e10, -4e10],
        ];
        let y = array![0.0_f64, 0.0, 1.0, 1.0];

        let tree = DecisionTreeClassifier::default();
        let fitted = Fit::fit(&tree, &x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        for &p in preds.iter() {
            assert!(p.is_finite(), "prediction should be finite, got {}", p);
        }
    }

    #[test]
    fn test_small_feature_values() {
        // Very small feature values should still produce valid splits
        let x = array![[1e-10], [2e-10], [3e-10], [4e-10],];
        let y = array![0.0, 0.0, 1.0, 1.0];

        let tree = DecisionTreeClassifier::default();
        let fitted = Fit::fit(&tree, &x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        // Should separate the two classes
        assert_abs_diff_eq!(preds[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(preds[3], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_near_identical_feature_values() {
        // Features that differ by tiny amounts (near machine epsilon)
        let x = array![[1.0 + 1e-14], [1.0 + 2e-14], [1.0 + 3e-14], [1.0 + 4e-14],];
        let y = array![0.0, 0.0, 1.0, 1.0];

        let tree = DecisionTreeClassifier::default();
        let fitted = Fit::fit(&tree, &x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        // Should not panic; predictions should be valid labels
        for &p in preds.iter() {
            assert!(
                p == 0.0 || p == 1.0,
                "prediction should be 0 or 1, got {}",
                p
            );
        }
    }

    #[test]
    fn test_predict_proba_pure_leaves() {
        // Perfectly separable data: every leaf is pure, so each probability row is one-hot.
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![0.0, 0.0, 1.0, 1.0];

        let tree = DecisionTreeClassifier::default();
        let fitted = Fit::fit(&tree, &x, &y).unwrap();
        assert_eq!(fitted.classes(), vec![0.0, 1.0]);

        let proba = fitted.predict_proba(&x).unwrap();
        assert_eq!(proba.dim(), (4, 2));
        assert_abs_diff_eq!(proba[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(proba[[1, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(proba[[2, 1]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(proba[[3, 1]], 1.0, epsilon = 1e-10);
        for row in proba.rows() {
            assert_abs_diff_eq!(row.sum(), 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_predict_proba_single_leaf_matches_class_frequencies() {
        // min_samples_split=5 > 4 samples keeps the tree a single leaf with 3x class 0
        // and 1x class 1, so every row should be [0.75, 0.25].
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![0.0, 0.0, 0.0, 1.0];

        let tree = DecisionTreeClassifier::new().with_min_samples_split(5);
        let fitted = Fit::fit(&tree, &x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();

        for row in proba.rows() {
            assert_abs_diff_eq!(row[0], 0.75, epsilon = 1e-10);
            assert_abs_diff_eq!(row[1], 0.25, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_predict_proba_trait_matches_inherent() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![0.0, 0.0, 1.0, 1.0];

        let tree = DecisionTreeClassifier::default();
        let fitted = Fit::fit(&tree, &x, &y).unwrap();

        let via_trait = PredictProba::predict_proba(&fitted, &x).unwrap();
        let via_inherent = fitted.predict_proba(&x).unwrap();
        assert_eq!(via_trait, via_inherent);
    }

    #[test]
    fn test_uniform_sample_weight_matches_unweighted() {
        // Equal weights must not change which split is chosen.
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![0.0, 0.0, 1.0, 1.0];

        let tree = DecisionTreeClassifier::new()
            .with_sample_weight(Some(array![1.0, 1.0, 1.0, 1.0]));
        let fitted = Fit::fit(&tree, &x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for (pred, target) in preds.iter().zip(y.iter()) {
            assert_abs_diff_eq!(pred, target, epsilon = 1e-10);
        }
    }

    mod prop_tests {
        use super::*;
        use proptest::prelude::*;
        use std::collections::HashSet;

        /// Generate deterministic training data for classification.
        fn make_classification_data(
            n_samples: usize,
            n_features: usize,
            n_classes: usize,
            seed: u64,
        ) -> (Array2<f64>, Array1<f64>) {
            use std::collections::hash_map::DefaultHasher;
            use std::hash::{Hash, Hasher};

            let mut x_data = Vec::with_capacity(n_samples * n_features);
            let mut y_data = Vec::with_capacity(n_samples);

            for i in 0..n_samples {
                for j in 0..n_features {
                    let mut h = DefaultHasher::new();
                    seed.hash(&mut h);
                    (i as u64).hash(&mut h);
                    (j as u64).hash(&mut h);
                    let bits = h.finish();
                    let v = (bits as f64 / u64::MAX as f64) * 20.0 - 10.0;
                    x_data.push(v);
                }
                let mut h = DefaultHasher::new();
                seed.hash(&mut h);
                (i as u64).hash(&mut h);
                0xDEAD_BEEFu64.hash(&mut h);
                let label = (h.finish() % n_classes as u64) as f64;
                y_data.push(label);
            }

            let x = Array2::from_shape_vec((n_samples, n_features), x_data).unwrap();
            let y = Array1::from_vec(y_data);
            (x, y)
        }

        proptest! {
            #[test]
            fn tree_predictions_are_valid_labels(
                n_samples in 4..30usize,
                n_features in 1..5usize,
                seed in 0u64..1000,
            ) {
                let n_classes = 3;
                let (x, y) = make_classification_data(n_samples, n_features, n_classes, seed);

                // Collect unique training labels
                let train_labels: HashSet<u64> = y.iter()
                    .map(|&v| v.to_bits())
                    .collect();

                let tree = DecisionTreeClassifier::new()
                    .with_max_depth(Some(5));
                let fitted = Fit::fit(&tree, &x, &y).unwrap();
                let preds = fitted.predict(&x).unwrap();

                for (i, &p) in preds.iter().enumerate() {
                    prop_assert!(
                        train_labels.contains(&p.to_bits()),
                        "prediction {} at index {} is not a valid training label",
                        p, i
                    );
                }
            }

            #[test]
            fn tree_deterministic(seed in 0u64..1000) {
                let (x, y) = make_classification_data(20, 3, 3, seed);

                let tree = DecisionTreeClassifier::new()
                    .with_max_depth(Some(4));

                let fitted1 = Fit::fit(&tree, &x, &y).unwrap();
                let fitted2 = Fit::fit(&tree, &x, &y).unwrap();

                let preds1 = fitted1.predict(&x).unwrap();
                let preds2 = fitted2.predict(&x).unwrap();

                for (i, (&a, &b)) in preds1.iter().zip(preds2.iter()).enumerate() {
                    prop_assert!((a - b).abs() < 1e-15,
                        "non-deterministic prediction at index {}: {} vs {}", i, a, b);
                }
            }
        }
    }
}