1use num_traits::cast::AsPrimitive;
2use num_traits::{Float, FromPrimitive, NumCast, ToPrimitive};
3use rayon::prelude::*;
4use std::cmp::Ordering;
5use std::collections::HashMap;
6use std::fmt::{self, Debug};
7use std::hash::Hash;
8
/// The kind of prediction task a decision tree is built for.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so the value can be
/// used as a map/set key and compared without restriction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TreeType {
    /// Leaves predict the mean of the target values that reach them.
    Regression,
    /// Leaves predict the most frequent class among the samples that reach them.
    Classification,
}
17
/// Impurity measure used to score candidate splits.
///
/// `Mse` and `Mae` are only valid for regression trees, `Gini` and `Entropy`
/// only for classification trees; any other pairing panics when a split is
/// scored. `Eq` and `Hash` are derived so the value can be used as a key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SplitCriterion {
    /// Mean squared deviation from the mean (regression).
    Mse,
    /// Mean absolute deviation from the mean (regression).
    Mae,
    /// Gini impurity, `1 - sum(p_i^2)` (classification).
    Gini,
    /// Shannon entropy using the natural logarithm (classification).
    Entropy,
}
30
31#[derive(Debug, Clone)]
33struct Node<T, F>
34where
35 T: Clone + PartialOrd + Debug + ToPrimitive,
36 F: Float,
37{
38 feature_idx: Option<usize>,
40 threshold: Option<T>,
42 value: Option<T>,
44 class_distribution: Option<HashMap<T, usize>>,
46 left: Option<usize>,
48 right: Option<usize>,
50 _phantom: std::marker::PhantomData<F>,
52}
53
54impl<T, F> Node<T, F>
55where
56 T: Clone + PartialOrd + Eq + Hash + Debug + ToPrimitive,
57 F: Float,
58{
59 fn new_split(feature_idx: usize, threshold: T) -> Self {
61 Node {
62 feature_idx: Some(feature_idx),
63 threshold: Some(threshold),
64 value: None,
65 class_distribution: None,
66 left: None,
67 right: None,
68 _phantom: std::marker::PhantomData,
69 }
70 }
71
72 fn new_leaf_regression(value: T) -> Self {
74 Node {
75 feature_idx: None,
76 threshold: None,
77 value: Some(value),
78 class_distribution: None,
79 left: None,
80 right: None,
81 _phantom: std::marker::PhantomData,
82 }
83 }
84
85 fn new_leaf_classification(value: T, class_distribution: HashMap<T, usize>) -> Self {
87 Node {
88 feature_idx: None,
89 threshold: None,
90 value: Some(value),
91 class_distribution: Some(class_distribution),
92 left: None,
93 right: None,
94 _phantom: std::marker::PhantomData,
95 }
96 }
97
98 fn is_leaf(&self) -> bool {
100 self.feature_idx.is_none()
101 }
102}
103
104#[derive(Debug, Clone)]
110pub struct DecisionTree<T, F>
111where
112 T: Clone + PartialOrd + Debug + ToPrimitive,
113 F: Float,
114{
115 tree_type: TreeType,
117 criterion: SplitCriterion,
119 max_depth: usize,
121 min_samples_split: usize,
123 min_samples_leaf: usize,
125 nodes: Vec<Node<T, F>>,
127}
128
129impl<T, F> DecisionTree<T, F>
130where
131 T: Clone + PartialOrd + Eq + Hash + Send + Sync + NumCast + ToPrimitive + Debug,
132 F: Float + Send + Sync + NumCast + FromPrimitive + 'static,
133 f64: AsPrimitive<F>,
134 usize: AsPrimitive<F>,
135 T: AsPrimitive<F>,
136 F: AsPrimitive<T>,
137{
138 pub fn new(
140 tree_type: TreeType,
141 criterion: SplitCriterion,
142 max_depth: usize,
143 min_samples_split: usize,
144 min_samples_leaf: usize,
145 ) -> Self {
146 Self {
147 tree_type,
148 criterion,
149 max_depth,
150 min_samples_split,
151 min_samples_leaf,
152 nodes: Vec::new(),
153 }
154 }
155
156 pub fn fit<D>(&mut self, features: &[Vec<D>], target: &[T])
158 where
159 D: Clone + PartialOrd + NumCast + ToPrimitive + AsPrimitive<F> + Send + Sync,
160 T: FromPrimitive,
161 {
162 assert!(!features.is_empty(), "Features cannot be empty");
163 assert!(!target.is_empty(), "Target cannot be empty");
164 assert_eq!(
165 features.len(),
166 target.len(),
167 "Features and target must have the same length"
168 );
169
170 let n_features = features[0].len();
172 for feature_vec in features {
173 assert_eq!(
174 feature_vec.len(),
175 n_features,
176 "All feature vectors must have the same length"
177 );
178 }
179
180 self.nodes = Vec::new();
182
183 let indices: Vec<usize> = (0..features.len()).collect();
185
186 self.build_tree(features, target, &indices, 0);
188 }
189
190 fn build_tree<D>(
192 &mut self,
193 features: &[Vec<D>],
194 target: &[T],
195 indices: &[usize],
196 depth: usize,
197 ) -> usize
198 where
199 D: Clone + PartialOrd + NumCast + ToPrimitive + AsPrimitive<F> + Send + Sync,
200 {
201 if depth >= self.max_depth
203 || indices.len() < self.min_samples_split
204 || self.is_pure(target, indices)
205 {
206 let node_idx = self.nodes.len();
207 if self.tree_type == TreeType::Regression {
208 let value = self.calculate_mean(target, indices);
210 self.nodes.push(Node::new_leaf_regression(value));
211 } else {
212 let (value, class_counts) = self.calculate_class_distribution(target, indices);
214 self.nodes
215 .push(Node::new_leaf_classification(value, class_counts));
216 }
217 return node_idx;
218 }
219
220 let (feature_idx, threshold, left_indices, right_indices) =
222 self.find_best_split(features, target, indices);
223
224 if left_indices.is_empty() || right_indices.is_empty() {
226 let node_idx = self.nodes.len();
227 if self.tree_type == TreeType::Regression {
228 let value = self.calculate_mean(target, indices);
229 self.nodes.push(Node::new_leaf_regression(value));
230 } else {
231 let (value, class_counts) = self.calculate_class_distribution(target, indices);
232 self.nodes
233 .push(Node::new_leaf_classification(value, class_counts));
234 }
235 return node_idx;
236 }
237
238 let node_idx = self.nodes.len();
240
241 let t_threshold = NumCast::from(threshold).unwrap_or_else(|| {
243 panic!("Failed to convert threshold to the feature type");
244 });
245
246 self.nodes.push(Node::new_split(feature_idx, t_threshold));
247
248 let left_idx = self.build_tree(features, target, &left_indices, depth + 1);
250 let right_idx = self.build_tree(features, target, &right_indices, depth + 1);
251
252 self.nodes[node_idx].left = Some(left_idx);
254 self.nodes[node_idx].right = Some(right_idx);
255
256 node_idx
257 }
258
259 fn find_best_split<D>(
261 &self,
262 features: &[Vec<D>],
263 target: &[T],
264 indices: &[usize],
265 ) -> (usize, D, Vec<usize>, Vec<usize>)
266 where
267 D: Clone + PartialOrd + NumCast + ToPrimitive + AsPrimitive<F> + Send + Sync,
268 {
269 let n_features = features[0].len();
270
271 let mut best_impurity = F::infinity();
273 let mut best_feature = 0;
274 let mut best_threshold = features[indices[0]][0].clone();
275 let mut best_left = Vec::new();
276 let mut best_right = Vec::new();
277
278 let results: Vec<_> = (0..n_features)
280 .into_par_iter()
281 .filter_map(|feature_idx| {
282 let mut feature_values: Vec<(usize, D)> = indices
284 .iter()
285 .map(|&idx| (idx, features[idx][feature_idx].clone()))
286 .collect();
287
288 feature_values.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
290
291 let mut values: Vec<D> = Vec::new();
293 let mut prev_val: Option<&D> = None;
294
295 for (_, val) in &feature_values {
296 if prev_val.is_none()
297 || prev_val
298 .unwrap()
299 .partial_cmp(val)
300 .unwrap_or(Ordering::Equal)
301 != Ordering::Equal
302 {
303 values.push(val.clone());
304 prev_val = Some(val);
305 }
306 }
307
308 if values.len() <= 1 {
310 return None;
311 }
312
313 let mut feature_best_impurity = F::infinity();
315 let mut feature_best_threshold = values[0].clone();
316 let mut feature_best_left = Vec::new();
317 let mut feature_best_right = Vec::new();
318
319 for i in 0..values.len() - 1 {
320 let val1: F = values[i].as_();
322 let val2: F = values[i + 1].as_();
323
324 let mid_value = (val1 + val2) / F::from(2.0).unwrap();
326
327 let threshold = NumCast::from(mid_value).unwrap_or_else(|| {
329 panic!("Failed to convert threshold to the feature type");
330 });
331
332 let mut left_indices = Vec::new();
334 let mut right_indices = Vec::new();
335
336 for &idx in indices {
337 let feature_value = &features[idx][feature_idx];
338 if feature_value
339 .partial_cmp(&threshold)
340 .unwrap_or(Ordering::Equal)
341 != Ordering::Greater
342 {
343 left_indices.push(idx);
344 } else {
345 right_indices.push(idx);
346 }
347 }
348
349 if left_indices.len() < self.min_samples_leaf
351 || right_indices.len() < self.min_samples_leaf
352 {
353 continue;
354 }
355
356 let impurity =
358 self.calculate_split_impurity(target, &left_indices, &right_indices);
359
360 if impurity < feature_best_impurity {
362 feature_best_impurity = impurity;
363 feature_best_threshold = threshold;
364 feature_best_left = left_indices;
365 feature_best_right = right_indices;
366 }
367 }
368
369 if !feature_best_left.is_empty() && !feature_best_right.is_empty() {
371 Some((
372 feature_idx,
373 feature_best_impurity,
374 feature_best_threshold,
375 feature_best_left,
376 feature_best_right,
377 ))
378 } else {
379 None
380 }
381 })
382 .collect();
383
384 for (feature_idx, impurity, threshold, left, right) in results {
386 if impurity < best_impurity {
387 best_impurity = impurity;
388 best_feature = feature_idx;
389 best_threshold = threshold;
390 best_left = left;
391 best_right = right;
392 }
393 }
394
395 (best_feature, best_threshold, best_left, best_right)
396 }
397
398 fn calculate_split_impurity(
400 &self,
401 target: &[T],
402 left_indices: &[usize],
403 right_indices: &[usize],
404 ) -> F {
405 let n_left = left_indices.len();
406 let n_right = right_indices.len();
407 let n_total = n_left + n_right;
408
409 if n_left == 0 || n_right == 0 {
410 return F::infinity();
411 }
412
413 let left_weight: F = (n_left as f64).as_();
414 let right_weight: F = (n_right as f64).as_();
415 let total: F = (n_total as f64).as_();
416
417 let left_ratio = left_weight / total;
418 let right_ratio = right_weight / total;
419
420 match (self.tree_type, self.criterion) {
421 (TreeType::Regression, SplitCriterion::Mse) => {
422 let left_mse = self.calculate_mse(target, left_indices);
424 let right_mse = self.calculate_mse(target, right_indices);
425 left_ratio * left_mse + right_ratio * right_mse
426 }
427 (TreeType::Regression, SplitCriterion::Mae) => {
428 let left_mae = self.calculate_mae(target, left_indices);
430 let right_mae = self.calculate_mae(target, right_indices);
431 left_ratio * left_mae + right_ratio * right_mae
432 }
433 (TreeType::Classification, SplitCriterion::Gini) => {
434 let left_gini = self.calculate_gini(target, left_indices);
436 let right_gini = self.calculate_gini(target, right_indices);
437 left_ratio * left_gini + right_ratio * right_gini
438 }
439 (TreeType::Classification, SplitCriterion::Entropy) => {
440 let left_entropy = self.calculate_entropy(target, left_indices);
442 let right_entropy = self.calculate_entropy(target, right_indices);
443 left_ratio * left_entropy + right_ratio * right_entropy
444 }
445 _ => panic!("Invalid combination of tree_type and criterion"),
446 }
447 }
448
449 fn calculate_mse(&self, target: &[T], indices: &[usize]) -> F {
451 if indices.is_empty() {
452 return F::zero();
453 }
454
455 let mean = self.calculate_mean(target, indices);
456 let mean_f: F = mean.as_();
457
458 let sum_squared_error: F = indices
459 .iter()
460 .map(|&idx| {
461 let error: F = target[idx].as_() - mean_f;
462 error * error
463 })
464 .fold(F::zero(), |a, b| a + b);
465
466 sum_squared_error / F::from(indices.len()).unwrap()
467 }
468
469 fn calculate_mae(&self, target: &[T], indices: &[usize]) -> F {
471 if indices.is_empty() {
472 return F::zero();
473 }
474
475 let mean = self.calculate_mean(target, indices);
476 let mean_f: F = mean.as_();
477
478 let sum_absolute_error: F = indices
479 .iter()
480 .map(|&idx| {
481 let error: F = target[idx].as_() - mean_f;
482 error.abs()
483 })
484 .fold(F::zero(), |a, b| a + b);
485
486 sum_absolute_error / F::from(indices.len()).unwrap()
487 }
488
489 fn calculate_gini(&self, target: &[T], indices: &[usize]) -> F {
491 if indices.is_empty() {
492 return F::zero();
493 }
494
495 let (_, class_counts) = self.calculate_class_distribution(target, indices);
496 let n_samples = indices.len();
497
498 let gini = F::one()
499 - class_counts
500 .values()
501 .map(|&count| {
502 let probability: F = (count as f64 / n_samples as f64).as_();
503 probability * probability
504 })
505 .fold(F::zero(), |a, b| a + b);
506
507 gini
508 }
509
510 fn calculate_entropy(&self, target: &[T], indices: &[usize]) -> F {
512 if indices.is_empty() {
513 return F::zero();
514 }
515
516 let (_, class_counts) = self.calculate_class_distribution(target, indices);
517 let n_samples = indices.len();
518
519 let entropy = -class_counts
520 .values()
521 .map(|&count| {
522 let probability: F = (count as f64 / n_samples as f64).as_();
523 if probability > F::zero() {
524 probability * probability.ln()
525 } else {
526 F::zero()
527 }
528 })
529 .fold(F::zero(), |a, b| a + b);
530
531 entropy
532 }
533
534 fn calculate_mean(&self, target: &[T], indices: &[usize]) -> T {
536 if indices.is_empty() {
537 return NumCast::from(0.0).unwrap_or_else(|| {
538 panic!("Failed to convert 0.0 to the target type");
539 });
540 }
541
542 let sum: F = indices
545 .iter()
546 .map(|&idx| target[idx].as_())
547 .fold(F::zero(), |a, b| a + b);
548
549 let count: F = F::from(indices.len()).unwrap();
550 let mean_f = sum / count;
551
552 NumCast::from(mean_f).unwrap_or_else(|| {
554 panic!("Failed to convert mean to the target type");
555 })
556 }
557
558 fn calculate_class_distribution(
560 &self,
561 target: &[T],
562 indices: &[usize],
563 ) -> (T, HashMap<T, usize>) {
564 let mut class_counts: HashMap<T, usize> = HashMap::new();
565
566 for &idx in indices {
567 let class = target[idx].clone();
568 *class_counts.entry(class).or_insert(0) += 1;
569 }
570
571 let (majority_class, _) = class_counts
573 .iter()
574 .max_by_key(|&(_, count)| *count)
575 .map(|(class, count)| (class.clone(), *count))
576 .unwrap_or_else(|| {
577 (NumCast::from(0.0).unwrap(), 0)
579 });
580
581 (majority_class, class_counts)
582 }
583
584 fn is_pure(&self, target: &[T], indices: &[usize]) -> bool {
586 if indices.is_empty() {
587 return true;
588 }
589
590 let first_value = &target[indices[0]];
591 indices.iter().all(|&idx| {
592 target[idx]
593 .partial_cmp(first_value)
594 .unwrap_or(Ordering::Equal)
595 == Ordering::Equal
596 })
597 }
598
599 pub fn predict<D>(&self, features: &[Vec<D>]) -> Vec<T>
601 where
602 D: Clone + PartialOrd + NumCast,
603 {
604 features
605 .iter()
606 .map(|feature_vec| self.predict_single(feature_vec))
607 .collect()
608 }
609
610 fn predict_single<D>(&self, features: &[D]) -> T
612 where
613 D: Clone + PartialOrd + NumCast,
614 T: NumCast,
615 {
616 if self.nodes.is_empty() {
617 panic!("Decision tree has not been trained yet");
618 }
619
620 let mut node_idx = 0;
621 loop {
622 let node = &self.nodes[node_idx];
623
624 if node.is_leaf() {
625 return node.value.as_ref().unwrap().clone();
626 }
627
628 let feature_idx = node.feature_idx.unwrap();
629 let threshold = node.threshold.as_ref().unwrap();
630
631 let feature_val = &features[feature_idx];
632
633 let threshold_d = D::from(threshold.clone())
636 .unwrap_or_else(|| panic!("Failed to convert threshold to feature type"));
637
638 let comparison = feature_val
639 .partial_cmp(&threshold_d)
640 .unwrap_or(Ordering::Equal);
641
642 if comparison != Ordering::Greater {
643 node_idx = node.left.unwrap();
644 } else {
645 node_idx = node.right.unwrap();
646 }
647 }
648 }
649
650 pub fn feature_importances(&self) -> Vec<F> {
652 if self.nodes.is_empty() {
653 return Vec::new();
654 }
655
656 let n_features = self
658 .nodes
659 .iter()
660 .find(|node| !node.is_leaf())
661 .and_then(|node| node.feature_idx)
662 .map(|idx| idx + 1)
663 .unwrap_or(0);
664
665 if n_features == 0 {
666 return Vec::new();
667 }
668
669 let mut feature_counts = vec![0; n_features];
671 for node in &self.nodes {
672 if let Some(feature_idx) = node.feature_idx {
673 feature_counts[feature_idx] += 1;
674 }
675 }
676
677 let total_count: f64 = feature_counts.iter().sum::<usize>() as f64;
679 if total_count > 0.0 {
680 feature_counts
681 .iter()
682 .map(|&count| (count as f64 / total_count).as_())
683 .collect()
684 } else {
685 vec![F::zero(); n_features]
686 }
687 }
688
689 pub fn tree_structure(&self) -> String {
691 if self.nodes.is_empty() {
692 return "Empty tree".to_string();
693 }
694
695 let mut result = String::new();
696 self.print_node(0, 0, &mut result);
697 result
698 }
699
700 fn print_node(&self, node_idx: usize, depth: usize, result: &mut String) {
702 let node = &self.nodes[node_idx];
703 let indent = " ".repeat(depth);
704
705 if node.is_leaf() {
706 if self.tree_type == TreeType::Classification {
707 let class_distribution = node.class_distribution.as_ref().unwrap();
708 let classes: Vec<String> = class_distribution
709 .iter()
710 .map(|(class, count)| format!("{:?}: {}", class, count))
711 .collect();
712
713 result.push_str(&format!(
714 "{}Leaf: prediction = {:?}, distribution = {{{}}}\n",
715 indent,
716 node.value.as_ref().unwrap(),
717 classes.join(", ")
718 ));
719 } else {
720 result.push_str(&format!(
721 "{}Leaf: prediction = {:?}\n",
722 indent,
723 node.value.as_ref().unwrap()
724 ));
725 }
726 } else {
727 result.push_str(&format!(
728 "{}Node: feature {} <= {:?}\n",
729 indent,
730 node.feature_idx.unwrap(),
731 node.threshold.as_ref().unwrap()
732 ));
733
734 if let Some(left_idx) = node.left {
735 self.print_node(left_idx, depth + 1, result);
736 }
737
738 if let Some(right_idx) = node.right {
739 self.print_node(right_idx, depth + 1, result);
740 }
741 }
742 }
743}
744
745impl<T, F> fmt::Display for DecisionTree<T, F>
746where
747 T: Clone + PartialOrd + Eq + Hash + Debug + ToPrimitive,
748 F: Float,
749{
750 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
751 write!(
752 f,
753 "DecisionTree({:?}, {:?}, max_depth={}, nodes={})",
754 self.tree_type,
755 self.criterion,
756 self.max_depth,
757 self.nodes.len()
758 )
759 }
760}
761
762impl<T, F> DecisionTree<T, F>
764where
765 T: Clone + PartialOrd + Eq + Hash + Send + Sync + NumCast + ToPrimitive + Debug,
766 F: Float + Send + Sync + NumCast + FromPrimitive + 'static,
767 f64: AsPrimitive<F>,
768 usize: AsPrimitive<F>,
769 T: AsPrimitive<F>,
770 F: AsPrimitive<T>,
771{
772 pub fn get_max_depth(&self) -> usize {
774 self.max_depth
775 }
776
777 pub fn get_node_count(&self) -> usize {
779 self.nodes.len()
780 }
781
782 pub fn is_trained(&self) -> bool {
784 !self.nodes.is_empty()
785 }
786
787 pub fn get_leaf_count(&self) -> usize {
789 self.nodes.iter().filter(|node| node.is_leaf()).count()
790 }
791
792 pub fn calculate_depth(&self) -> usize {
794 if self.nodes.is_empty() {
795 return 0;
796 }
797
798 fn depth_helper<T, F>(nodes: &[Node<T, F>], node_idx: usize, current_depth: usize) -> usize
800 where
801 T: Clone + PartialOrd + Eq + Hash + Debug + ToPrimitive,
802 F: Float,
803 {
804 let node = &nodes[node_idx];
805
806 if node.is_leaf() {
807 return current_depth;
808 }
809
810 let left_depth = depth_helper(nodes, node.left.unwrap(), current_depth + 1);
811 let right_depth = depth_helper(nodes, node.right.unwrap(), current_depth + 1);
812
813 std::cmp::max(left_depth, right_depth)
814 }
815
816 depth_helper(&self.nodes, 0, 0)
817 }
818
819 pub fn summary(&self) -> String {
821 if !self.is_trained() {
822 return "Decision tree is not trained yet".to_string();
823 }
824
825 let leaf_count = self.get_leaf_count();
826 let node_count = self.get_node_count();
827 let actual_depth = self.calculate_depth();
828
829 format!(
830 "Decision Tree Summary:\n\
831 - Type: {:?}\n\
832 - Criterion: {:?}\n\
833 - Max depth: {}\n\
834 - Actual depth: {}\n\
835 - Total nodes: {}\n\
836 - Leaf nodes: {}\n\
837 - Internal nodes: {}",
838 self.tree_type,
839 self.criterion,
840 self.max_depth,
841 actual_depth,
842 node_count,
843 leaf_count,
844 node_count - leaf_count
845 )
846 }
847}
848
849#[cfg(test)]
850mod tests {
851 use super::*;
852 use std::time::Duration;
853
/// `f64` wrapper usable as a hashable target type in tests.
///
/// Equality and hashing are both derived from the same normalised bit
/// pattern, so `a == b` always implies equal hashes and equality is
/// transitive. An epsilon-based `eq` paired with a bit-based `hash`
/// guarantees neither. `0.0` and `-0.0` compare equal, and all NaNs compare
/// equal to each other.
#[derive(Clone, Debug, PartialOrd, Copy)]
struct TestFloat(f64);

impl TestFloat {
    /// Bit pattern shared by every value that should compare equal.
    fn key_bits(self) -> u64 {
        if self.0 == 0.0 {
            // Collapses +0.0 and -0.0.
            0
        } else if self.0.is_nan() {
            // Collapses every NaN payload.
            f64::NAN.to_bits()
        } else {
            self.0.to_bits()
        }
    }
}

impl PartialEq for TestFloat {
    fn eq(&self, other: &Self) -> bool {
        self.key_bits() == other.key_bits()
    }
}

impl Eq for TestFloat {}

impl std::hash::Hash for TestFloat {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(&self.key_bits(), state);
    }
}
872
873 impl ToPrimitive for TestFloat {
874 fn to_i64(&self) -> Option<i64> {
875 self.0.to_i64()
876 }
877
878 fn to_u64(&self) -> Option<u64> {
879 self.0.to_u64()
880 }
881
882 fn to_f64(&self) -> Option<f64> {
883 Some(self.0)
884 }
885 }
886
887 impl NumCast for TestFloat {
888 fn from<T: ToPrimitive>(n: T) -> Option<Self> {
889 n.to_f64().map(TestFloat)
890 }
891 }
892
893 impl FromPrimitive for TestFloat {
894 fn from_i64(n: i64) -> Option<Self> {
895 Some(TestFloat(n as f64))
896 }
897
898 fn from_u64(n: u64) -> Option<Self> {
899 Some(TestFloat(n as f64))
900 }
901
902 fn from_f64(n: f64) -> Option<Self> {
903 Some(TestFloat(n))
904 }
905 }
906
907 impl AsPrimitive<f64> for TestFloat {
908 fn as_(self) -> f64 {
909 self.0
910 }
911 }
912
913 impl AsPrimitive<TestFloat> for f64 {
914 fn as_(self) -> TestFloat {
915 TestFloat(self)
916 }
917 }
918
/// Regression on `TestFloat` targets with `f64` features: a low-valued
/// sample must score below 5, a high-valued one above 5.
#[test]
fn test_diabetes_prediction() {
    let training_rows = vec![
        vec![45.0, 22.5, 95.0, 120.0, 0.0],
        vec![50.0, 26.0, 105.0, 140.0, 1.0],
        vec![35.0, 23.0, 90.0, 115.0, 0.0],
        vec![55.0, 30.0, 140.0, 150.0, 1.0],
        vec![60.0, 29.5, 130.0, 145.0, 1.0],
        vec![40.0, 24.0, 85.0, 125.0, 0.0],
        vec![48.0, 27.0, 110.0, 135.0, 1.0],
        vec![65.0, 31.0, 150.0, 155.0, 1.0],
        vec![42.0, 25.0, 100.0, 130.0, 0.0],
        vec![58.0, 32.0, 145.0, 160.0, 1.0],
    ];
    let risk_scores: Vec<TestFloat> = [2.0, 5.5, 1.5, 8.0, 6.5, 2.0, 5.0, 8.5, 3.0, 9.0]
        .iter()
        .copied()
        .map(TestFloat)
        .collect();

    let mut tree =
        DecisionTree::<TestFloat, f64>::new(TreeType::Regression, SplitCriterion::Mse, 5, 2, 1);
    tree.fit(&training_rows, &risk_scores);

    let query_rows = vec![
        vec![45.0, 23.0, 90.0, 120.0, 0.0],
        vec![62.0, 31.0, 145.0, 155.0, 1.0],
    ];
    let predictions = tree.predict(&query_rows);

    let TestFloat(low_risk) = predictions[0];
    assert!(
        low_risk < 5.0,
        "Young healthy patient should have low risk score"
    );
    let TestFloat(high_risk) = predictions[1];
    assert!(
        high_risk > 5.0,
        "Older patient with high metrics should have high risk score"
    );

    // Structural sanity checks on the fitted tree.
    assert!(tree.is_trained());
    assert!(tree.calculate_depth() <= tree.get_max_depth());
    assert!(tree.get_leaf_count() > 0);

    println!("Diabetes prediction tree:\n{}", tree.summary());
}
988
/// Four-class Gini classification on integer features; each query row must
/// be assigned the expected class label.
#[test]
fn test_disease_classification() {
    let training_rows = vec![
        vec![3, 1, 2, 1, 0, 0],
        vec![1, 3, 2, 0, 1, 3],
        vec![2, 0, 1, 3, 0, 0],
        vec![0, 3, 1, 0, 2, 2],
        vec![3, 2, 3, 2, 1, 0],
        vec![1, 3, 2, 0, 0, 3],
        vec![2, 0, 2, 3, 1, 0],
        vec![0, 2, 1, 0, 2, 2],
        vec![3, 1, 2, 1, 1, 0],
        vec![2, 3, 2, 0, 1, 2],
        vec![1, 0, 1, 3, 0, 0],
        vec![0, 3, 2, 0, 1, 3],
    ];
    let labels = vec![1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4];

    let mut tree =
        DecisionTree::<u8, f64>::new(TreeType::Classification, SplitCriterion::Gini, 4, 2, 1);
    tree.fit(&training_rows, &labels);

    let query_rows = vec![
        vec![3, 2, 2, 1, 1, 0],
        vec![1, 3, 2, 0, 1, 3],
        vec![2, 0, 1, 3, 0, 0],
    ];
    let predictions = tree.predict(&query_rows);

    let expectations = [
        (1u8, "Should diagnose as Flu"),
        (2u8, "Should diagnose as COVID"),
        (3u8, "Should diagnose as Migraine"),
    ];
    for (row, (wanted, reason)) in expectations.iter().enumerate() {
        assert_eq!(predictions[row], *wanted, "{}", reason);
    }

    println!("Disease classification tree:\n{}", tree.summary());
}
1041
/// Regression on integer targets with a shallow tree (`max_depth = 2`,
/// `min_samples_split = 5`, `min_samples_leaf = 2`).
///
/// Every check is unconditional: a tree that fails to train, or a panic
/// during prediction, fails the test instead of being skipped or swallowed.
#[test]
fn test_system_failure_prediction() {
    let mut tree =
        DecisionTree::<i32, f64>::new(TreeType::Regression, SplitCriterion::Mse, 2, 5, 2);

    let features = vec![
        vec![30, 40, 0],
        vec![35, 45, 1],
        vec![40, 50, 0],
        vec![25, 35, 1],
        vec![30, 40, 0],
        vec![90, 95, 10],
        vec![85, 90, 8],
        vec![95, 98, 15],
        vec![90, 95, 12],
        vec![80, 85, 7],
    ];

    let target = vec![1000, 900, 950, 1100, 1050, 10, 15, 5, 8, 20];

    tree.fit(&features, &target);
    assert!(tree.is_trained(), "fit should have built at least one node");

    println!("System failure tree summary:\n{}", tree.summary());
    println!("Tree structure:\n{}", tree.tree_structure());

    let test_features = vec![vec![30, 40, 0], vec![90, 95, 10]];
    let predictions = tree.predict(&test_features);
    println!("Successfully made predictions: {:?}", predictions);

    assert_eq!(
        predictions.len(),
        test_features.len(),
        "predict should return one value per input row"
    );
    assert!(
        predictions[0] > predictions[1],
        "Healthy system should have longer time to failure than failing system"
    );
}
1121
/// Three-class entropy classification; each query row must be assigned the
/// expected class.
#[test]
fn test_security_incident_classification() {
    let training_rows = vec![
        vec![1, 0, 0, 0, 0],
        vec![5, 1, 1, 1, 0],
        vec![15, 3, 2, 1, 1],
        vec![2, 0, 1, 0, 0],
        vec![8, 2, 1, 1, 0],
        vec![20, 4, 3, 1, 1],
        vec![1, 0, 0, 1, 0],
        vec![6, 1, 2, 1, 0],
        vec![25, 5, 3, 1, 1],
        vec![3, 0, 0, 0, 0],
        vec![7, 2, 1, 0, 0],
        vec![18, 3, 2, 1, 1],
        vec![0, 0, 0, 0, 0],
        vec![9, 2, 2, 1, 0],
        vec![22, 4, 3, 1, 1],
    ];
    let labels = vec![0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2];

    let mut tree = DecisionTree::<u8, f64>::new(
        TreeType::Classification,
        SplitCriterion::Entropy,
        5,
        2,
        1,
    );
    tree.fit(&training_rows, &labels);

    let query_rows = vec![
        vec![2, 0, 0, 0, 0],
        vec![7, 1, 1, 1, 0],
        vec![17, 3, 2, 1, 1],
    ];
    let predictions = tree.predict(&query_rows);

    let expectations = [
        (0u8, "Should classify as normal activity"),
        (1u8, "Should classify as suspicious activity"),
        (2u8, "Should classify as potential breach"),
    ];
    for (row, (wanted, reason)) in expectations.iter().enumerate() {
        assert_eq!(predictions[row], *wanted, "{}", reason);
    }

    println!(
        "Security incident classification tree:\n{}",
        tree.tree_structure()
    );
}
1179
/// Regression with a user-defined target type: a `Duration` wrapper whose
/// numeric conversions all go through whole milliseconds.
#[test]
fn test_custom_type_performance_analysis() {
    #[derive(Clone, PartialEq, Eq, Hash, Debug, Copy)]
    struct ResponseTime(Duration);

    /// Shorthand for a `ResponseTime` of `ms` whole milliseconds.
    fn millis(ms: u64) -> ResponseTime {
        ResponseTime(Duration::from_millis(ms))
    }

    impl PartialOrd for ResponseTime {
        fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
            PartialOrd::partial_cmp(&self.0, &other.0)
        }
    }

    impl ToPrimitive for ResponseTime {
        fn to_i64(&self) -> Option<i64> {
            let ms = self.0.as_millis();
            Some(ms as i64)
        }

        fn to_u64(&self) -> Option<u64> {
            let ms = self.0.as_millis();
            Some(ms as u64)
        }

        fn to_f64(&self) -> Option<f64> {
            let ms = self.0.as_millis();
            Some(ms as f64)
        }
    }

    impl AsPrimitive<f64> for ResponseTime {
        fn as_(self) -> f64 {
            let ms = self.0.as_millis();
            ms as f64
        }
    }

    impl NumCast for ResponseTime {
        fn from<T: ToPrimitive>(n: T) -> Option<Self> {
            let ms = n.to_u64()?;
            Some(millis(ms))
        }
    }

    impl FromPrimitive for ResponseTime {
        fn from_i64(n: i64) -> Option<Self> {
            match n {
                non_negative if non_negative >= 0 => Some(millis(non_negative as u64)),
                _ => None,
            }
        }

        fn from_u64(n: u64) -> Option<Self> {
            Some(millis(n))
        }

        fn from_f64(n: f64) -> Option<Self> {
            // `>=` (rather than `!(n < 0.0)`) so that NaN is rejected too.
            if n >= 0.0 {
                Some(millis(n as u64))
            } else {
                None
            }
        }
    }

    impl AsPrimitive<ResponseTime> for f64 {
        fn as_(self) -> ResponseTime {
            millis(self as u64)
        }
    }

    let training_rows = vec![
        vec![10, 20, 3, 5],
        vec![50, 40, 8, 2],
        vec![20, 30, 4, 4],
        vec![100, 60, 12, 0],
        vec![30, 35, 6, 3],
        vec![80, 50, 10, 1],
    ];
    let response_times: Vec<ResponseTime> = [100, 350, 150, 600, 200, 450]
        .iter()
        .map(|&ms| millis(ms))
        .collect();

    let mut tree = DecisionTree::<ResponseTime, f64>::new(
        TreeType::Regression,
        SplitCriterion::Mse,
        3,
        2,
        1,
    );
    tree.fit(&training_rows, &response_times);

    let query_rows = vec![vec![15, 25, 3, 4], vec![90, 55, 11, 0]];
    let predictions = tree.predict(&query_rows);

    let ResponseTime(fast) = predictions[0];
    assert!(
        fast.as_millis() < 200,
        "Small request should have fast response time"
    );
    let ResponseTime(slow) = predictions[1];
    assert!(
        slow.as_millis() > 400,
        "Large request should have slow response time"
    );

    println!("Response time prediction tree:\n{}", tree.summary());
}
1302
/// Fitting on empty input must be rejected with the documented panic message.
#[test]
#[should_panic(expected = "Features cannot be empty")]
fn test_empty_features() {
    let no_rows: Vec<Vec<f64>> = Vec::new();
    let no_targets: Vec<i32> = Vec::new();

    let mut tree =
        DecisionTree::<i32, f64>::new(TreeType::Regression, SplitCriterion::Mse, 3, 2, 1);

    tree.fit(&no_rows, &no_targets);
}
1316
/// When every sample has the same label the tree must collapse to a single
/// leaf that predicts that label.
#[test]
fn test_single_class_classification() {
    let training_rows = vec![
        vec![1, 2, 3],
        vec![4, 5, 6],
        vec![7, 8, 9],
        vec![10, 11, 12],
    ];
    let labels = vec![1; 4];

    let mut tree =
        DecisionTree::<u8, f64>::new(TreeType::Classification, SplitCriterion::Gini, 3, 2, 1);
    tree.fit(&training_rows, &labels);

    let query_rows = vec![vec![2, 3, 4]];
    let prediction = tree.predict(&query_rows);
    assert_eq!(prediction[0], 1);

    // A pure target never splits: the root is the only node and it is a leaf.
    assert_eq!(tree.get_node_count(), 1);
    assert_eq!(tree.get_leaf_count(), 1);
}
1347}