1use crate::error::{StatsError, StatsResult};
2use num_traits::cast::AsPrimitive;
3use num_traits::{Float, FromPrimitive, NumCast, ToPrimitive};
4#[cfg(feature = "parallel")]
5use rayon::prelude::*;
6use std::cmp::Ordering;
7use std::collections::HashMap;
8use std::fmt::{self, Debug};
9use std::hash::Hash;
10
/// Kind of prediction task the tree is fitted for.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TreeType {
    /// Predict a continuous value; leaves store the mean of their targets.
    Regression,
    /// Predict a discrete class; leaves store the majority class.
    Classification,
}
19
/// Impurity criterion used to score candidate splits.
///
/// `Mse`/`Mae` pair with [`TreeType::Regression`]; `Gini`/`Entropy` pair with
/// [`TreeType::Classification`]. A mismatched pairing yields infinite split
/// impurity, so the tree degenerates to a single leaf.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SplitCriterion {
    /// Mean squared error (regression).
    Mse,
    /// Mean absolute error (regression).
    Mae,
    /// Gini impurity (classification).
    Gini,
    /// Shannon entropy in nats — natural log, not log2 (classification).
    Entropy,
}
32
/// A single tree node stored in the flat `DecisionTree::nodes` arena.
///
/// Split nodes carry `feature_idx`/`threshold` plus child arena indices;
/// leaf nodes carry `value` (and, for classification, `class_distribution`).
#[derive(Debug, Clone)]
struct Node<T, F>
where
    T: Clone + PartialOrd + Debug + ToPrimitive,
    F: Float,
{
    // Index of the feature this node splits on; `None` marks a leaf.
    feature_idx: Option<usize>,
    // Split threshold: samples with value <= threshold go left.
    threshold: Option<T>,
    // Leaf prediction: mean target (regression) or majority class.
    value: Option<T>,
    // Per-class sample counts, populated only for classification leaves.
    class_distribution: Option<HashMap<T, usize>>,
    // Arena indices of the children; set after both subtrees are built.
    left: Option<usize>,
    right: Option<usize>,
    // Ties the impurity float type `F` to the node without storing one.
    _phantom: std::marker::PhantomData<F>,
}
55
56impl<T, F> Node<T, F>
57where
58 T: Clone + PartialOrd + Eq + Hash + Debug + ToPrimitive,
59 F: Float,
60{
61 fn new_split(feature_idx: usize, threshold: T) -> Self {
63 Node {
64 feature_idx: Some(feature_idx),
65 threshold: Some(threshold),
66 value: None,
67 class_distribution: None,
68 left: None,
69 right: None,
70 _phantom: std::marker::PhantomData,
71 }
72 }
73
74 fn new_leaf_regression(value: T) -> Self {
76 Node {
77 feature_idx: None,
78 threshold: None,
79 value: Some(value),
80 class_distribution: None,
81 left: None,
82 right: None,
83 _phantom: std::marker::PhantomData,
84 }
85 }
86
87 fn new_leaf_classification(value: T, class_distribution: HashMap<T, usize>) -> Self {
89 Node {
90 feature_idx: None,
91 threshold: None,
92 value: Some(value),
93 class_distribution: Some(class_distribution),
94 left: None,
95 right: None,
96 _phantom: std::marker::PhantomData,
97 }
98 }
99
100 fn is_leaf(&self) -> bool {
102 self.feature_idx.is_none()
103 }
104}
105
/// CART-style decision tree supporting regression (MSE/MAE) and
/// classification (Gini/entropy), generic over the target type `T` and the
/// float type `F` used for impurity arithmetic.
#[derive(Debug, Clone)]
pub struct DecisionTree<T, F>
where
    T: Clone + PartialOrd + Debug + ToPrimitive,
    F: Float,
{
    // Regression or classification.
    tree_type: TreeType,
    // Impurity measure used to score candidate splits.
    criterion: SplitCriterion,
    // Hard cap on tree depth (root is at depth 0).
    max_depth: usize,
    // Minimum samples a node must hold before a split is attempted.
    min_samples_split: usize,
    // Minimum samples each side of a split must retain.
    min_samples_leaf: usize,
    // Flat arena of nodes; index 0 is the root once fitted, empty before fit.
    nodes: Vec<Node<T, F>>,
}
130
131impl<T, F> DecisionTree<T, F>
132where
133 T: Clone + PartialOrd + Eq + Hash + Send + Sync + NumCast + ToPrimitive + Debug,
134 F: Float + Send + Sync + NumCast + FromPrimitive + 'static,
135 f64: AsPrimitive<F>,
136 usize: AsPrimitive<F>,
137 T: AsPrimitive<F>,
138 F: AsPrimitive<T>,
139{
140 pub fn new(
142 tree_type: TreeType,
143 criterion: SplitCriterion,
144 max_depth: usize,
145 min_samples_split: usize,
146 min_samples_leaf: usize,
147 ) -> Self {
148 Self {
149 tree_type,
150 criterion,
151 max_depth,
152 min_samples_split,
153 min_samples_leaf,
154 nodes: Vec::new(),
155 }
156 }
157
158 pub fn fit<D>(&mut self, features: &[Vec<D>], target: &[T]) -> StatsResult<()>
166 where
167 D: Clone + PartialOrd + NumCast + ToPrimitive + AsPrimitive<F> + Send + Sync,
168 T: FromPrimitive,
169 {
170 if features.is_empty() {
171 return Err(StatsError::empty_data("Features cannot be empty"));
172 }
173 if target.is_empty() {
174 return Err(StatsError::empty_data("Target cannot be empty"));
175 }
176 if features.len() != target.len() {
177 return Err(StatsError::dimension_mismatch(format!(
178 "Features and target must have the same length (got {} and {})",
179 features.len(),
180 target.len()
181 )));
182 }
183
184 let n_features = features[0].len();
186 for (i, feature_vec) in features.iter().enumerate() {
187 if feature_vec.len() != n_features {
188 return Err(StatsError::invalid_input(format!(
189 "All feature vectors must have the same length (vector {} has {} features, expected {})",
190 i,
191 feature_vec.len(),
192 n_features
193 )));
194 }
195 }
196
197 self.nodes = Vec::new();
199
200 let indices: Vec<usize> = (0..features.len()).collect();
202
203 self.build_tree(features, target, &indices, 0)?;
205 Ok(())
206 }
207
208 fn build_tree<D>(
210 &mut self,
211 features: &[Vec<D>],
212 target: &[T],
213 indices: &[usize],
214 depth: usize,
215 ) -> StatsResult<usize>
216 where
217 D: Clone + PartialOrd + NumCast + ToPrimitive + AsPrimitive<F> + Send + Sync,
218 {
219 if depth >= self.max_depth
221 || indices.len() < self.min_samples_split
222 || self.is_pure(target, indices)
223 {
224 let node_idx = self.nodes.len();
225 if self.tree_type == TreeType::Regression {
226 let value = self.calculate_mean(target, indices)?;
228 self.nodes.push(Node::new_leaf_regression(value));
229 } else {
230 let (value, class_counts) = self.calculate_class_distribution(target, indices);
232 self.nodes
233 .push(Node::new_leaf_classification(value, class_counts));
234 }
235 return Ok(node_idx);
236 }
237
238 let (feature_idx, threshold, left_indices, right_indices) =
240 self.find_best_split(features, target, indices);
241
242 if left_indices.is_empty() || right_indices.is_empty() {
244 let node_idx = self.nodes.len();
245 if self.tree_type == TreeType::Regression {
246 let value = self.calculate_mean(target, indices)?;
247 self.nodes.push(Node::new_leaf_regression(value));
248 } else {
249 let (value, class_counts) = self.calculate_class_distribution(target, indices);
250 self.nodes
251 .push(Node::new_leaf_classification(value, class_counts));
252 }
253 return Ok(node_idx);
254 }
255
256 let node_idx = self.nodes.len();
258
259 let t_threshold = NumCast::from(threshold).ok_or_else(|| {
261 StatsError::conversion_error(
262 "Failed to convert threshold to the feature type".to_string(),
263 )
264 })?;
265
266 self.nodes.push(Node::new_split(feature_idx, t_threshold));
267
268 let left_idx = self.build_tree(features, target, &left_indices, depth + 1)?;
270 let right_idx = self.build_tree(features, target, &right_indices, depth + 1)?;
271
272 self.nodes[node_idx].left = Some(left_idx);
274 self.nodes[node_idx].right = Some(right_idx);
275
276 Ok(node_idx)
277 }
278
    /// Exhaustively search all features (in parallel when the `parallel`
    /// feature is enabled) for the split minimizing weighted child impurity.
    ///
    /// Returns `(feature index, threshold, left indices, right indices)`.
    /// When no feature admits a valid split, the index vectors are empty and
    /// the caller (`build_tree`) creates a leaf instead.
    fn find_best_split<D>(
        &self,
        features: &[Vec<D>],
        target: &[T],
        indices: &[usize],
    ) -> (usize, D, Vec<usize>, Vec<usize>)
    where
        D: Clone + PartialOrd + NumCast + ToPrimitive + AsPrimitive<F> + Send + Sync,
    {
        let n_features = features[0].len();

        // Defaults returned when no split beats infinity; note the threshold
        // placeholder reads `features[indices[0]][0]`, which would panic for
        // empty `indices` — callers guarantee non-empty sets here.
        let mut best_impurity = F::infinity();
        let mut best_feature = 0;
        let mut best_threshold = features[indices[0]][0];
        let mut best_left = Vec::new();
        let mut best_right = Vec::new();

        // Same pipeline either way; only the iterator driver changes.
        #[cfg(feature = "parallel")]
        let iter = (0..n_features).into_par_iter();
        #[cfg(not(feature = "parallel"))]
        let iter = 0..n_features;

        let results: Vec<_> = iter
            .filter_map(|feature_idx| {
                // Pair each sample index with its value for this feature.
                // (Copying out of the slice relies on `D: Copy`, implied by
                // the `AsPrimitive<F>` bound.)
                let mut feature_values: Vec<(usize, D)> = indices
                    .iter()
                    .map(|&idx| (idx, features[idx][feature_idx]))
                    .collect();

                // Sort by value; incomparable pairs (e.g. NaN) are treated
                // as equal.
                feature_values.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));

                // Deduplicate consecutive equal values to get the distinct
                // candidate cut points.
                let mut values: Vec<D> = Vec::new();
                let mut prev_val: Option<&D> = None;

                for (_, val) in &feature_values {
                    if prev_val.is_none()
                        || prev_val
                            .unwrap()
                            .partial_cmp(val)
                            .unwrap_or(Ordering::Equal)
                            != Ordering::Equal
                    {
                        values.push(*val);
                        prev_val = Some(val);
                    }
                }

                // A constant feature cannot be split on.
                if values.len() <= 1 {
                    return None;
                }

                let mut feature_best_impurity = F::infinity();
                let mut feature_best_threshold = values[0];
                let mut feature_best_left = Vec::new();
                let mut feature_best_right = Vec::new();

                // Candidate thresholds are midpoints between adjacent
                // distinct values.
                for i in 0..values.len() - 1 {
                    let val1: F = values[i].as_();
                    let val2: F = values[i + 1].as_();

                    let two = match F::from(2.0) {
                        Some(t) => t,
                        None => continue, };
                    let mid_value = (val1 + val2) / two;

                    // Midpoint is computed in F and cast back to the feature
                    // type; skip candidates that don't round-trip.
                    let threshold = match NumCast::from(mid_value) {
                        Some(t) => t,
                        None => continue, };

                    // Partition: value <= threshold goes left.
                    let mut left_indices = Vec::new();
                    let mut right_indices = Vec::new();

                    for &idx in indices {
                        let feature_value = &features[idx][feature_idx];
                        if feature_value
                            .partial_cmp(&threshold)
                            .unwrap_or(Ordering::Equal)
                            != Ordering::Greater
                        {
                            left_indices.push(idx);
                        } else {
                            right_indices.push(idx);
                        }
                    }

                    // Enforce the minimum-leaf-size constraint.
                    if left_indices.len() < self.min_samples_leaf
                        || right_indices.len() < self.min_samples_leaf
                    {
                        continue;
                    }

                    let impurity =
                        self.calculate_split_impurity(target, &left_indices, &right_indices);

                    if impurity < feature_best_impurity {
                        feature_best_impurity = impurity;
                        feature_best_threshold = threshold;
                        feature_best_left = left_indices;
                        feature_best_right = right_indices;
                    }
                }

                // Only report this feature if some candidate was accepted.
                if !feature_best_left.is_empty() && !feature_best_right.is_empty() {
                    Some((
                        feature_idx,
                        feature_best_impurity,
                        feature_best_threshold,
                        feature_best_left,
                        feature_best_right,
                    ))
                } else {
                    None
                }
            })
            .collect();

        // Sequential reduction over per-feature winners keeps the result
        // deterministic regardless of the parallel driver.
        for (feature_idx, impurity, threshold, left, right) in results {
            if impurity < best_impurity {
                best_impurity = impurity;
                best_feature = feature_idx;
                best_threshold = threshold;
                best_left = left;
                best_right = right;
            }
        }

        (best_feature, best_threshold, best_left, best_right)
    }
426
427 fn calculate_split_impurity(
429 &self,
430 target: &[T],
431 left_indices: &[usize],
432 right_indices: &[usize],
433 ) -> F {
434 let n_left = left_indices.len();
435 let n_right = right_indices.len();
436 let n_total = n_left + n_right;
437
438 if n_left == 0 || n_right == 0 {
439 return F::infinity();
440 }
441
442 let left_weight: F = (n_left as f64).as_();
443 let right_weight: F = (n_right as f64).as_();
444 let total: F = (n_total as f64).as_();
445
446 let left_ratio = left_weight / total;
447 let right_ratio = right_weight / total;
448
449 match (self.tree_type, self.criterion) {
450 (TreeType::Regression, SplitCriterion::Mse) => {
451 let left_mse = self.calculate_mse(target, left_indices);
453 let right_mse = self.calculate_mse(target, right_indices);
454 left_ratio * left_mse + right_ratio * right_mse
455 }
456 (TreeType::Regression, SplitCriterion::Mae) => {
457 let left_mae = self.calculate_mae(target, left_indices);
459 let right_mae = self.calculate_mae(target, right_indices);
460 left_ratio * left_mae + right_ratio * right_mae
461 }
462 (TreeType::Classification, SplitCriterion::Gini) => {
463 let left_gini = self.calculate_gini(target, left_indices);
465 let right_gini = self.calculate_gini(target, right_indices);
466 left_ratio * left_gini + right_ratio * right_gini
467 }
468 (TreeType::Classification, SplitCriterion::Entropy) => {
469 let left_entropy = self.calculate_entropy(target, left_indices);
471 let right_entropy = self.calculate_entropy(target, right_indices);
472 left_ratio * left_entropy + right_ratio * right_entropy
473 }
474 _ => {
475 F::infinity()
478 }
479 }
480 }
481
482 fn calculate_mse(&self, target: &[T], indices: &[usize]) -> F {
484 if indices.is_empty() {
485 return F::zero();
486 }
487
488 let mean = match self.calculate_mean(target, indices) {
490 Ok(m) => m,
491 Err(_) => return F::infinity(),
492 };
493 let mean_f: F = mean.as_();
494
495 let sum_squared_error: F = indices
496 .iter()
497 .map(|&idx| {
498 let error: F = target[idx].as_() - mean_f;
499 error * error
500 })
501 .fold(F::zero(), |a, b| a + b);
502
503 let count = F::from(indices.len()).unwrap_or(F::one());
504 sum_squared_error / count
505 }
506
507 fn calculate_mae(&self, target: &[T], indices: &[usize]) -> F {
509 if indices.is_empty() {
510 return F::zero();
511 }
512
513 let mean = match self.calculate_mean(target, indices) {
515 Ok(m) => m,
516 Err(_) => return F::infinity(),
517 };
518 let mean_f: F = mean.as_();
519
520 let sum_absolute_error: F = indices
521 .iter()
522 .map(|&idx| {
523 let error: F = target[idx].as_() - mean_f;
524 error.abs()
525 })
526 .fold(F::zero(), |a, b| a + b);
527
528 let count = F::from(indices.len()).unwrap_or(F::one());
529 sum_absolute_error / count
530 }
531
532 fn calculate_gini(&self, target: &[T], indices: &[usize]) -> F {
534 if indices.is_empty() {
535 return F::zero();
536 }
537
538 let (_, class_counts) = self.calculate_class_distribution(target, indices);
539 let n_samples = indices.len();
540
541 F::one()
542 - class_counts
543 .values()
544 .map(|&count| {
545 let probability: F = (count as f64 / n_samples as f64).as_();
546 probability * probability
547 })
548 .fold(F::zero(), |a, b| a + b)
549 }
550
551 fn calculate_entropy(&self, target: &[T], indices: &[usize]) -> F {
553 if indices.is_empty() {
554 return F::zero();
555 }
556
557 let (_, class_counts) = self.calculate_class_distribution(target, indices);
558 let n_samples = indices.len();
559
560 -class_counts
561 .values()
562 .map(|&count| {
563 let probability: F = (count as f64 / n_samples as f64).as_();
564 if probability > F::zero() {
565 probability * probability.ln()
566 } else {
567 F::zero()
568 }
569 })
570 .fold(F::zero(), |a, b| a + b)
571 }
572
573 fn calculate_mean(&self, target: &[T], indices: &[usize]) -> StatsResult<T> {
575 if indices.is_empty() {
576 return Err(StatsError::empty_data(
577 "Cannot calculate mean for empty indices",
578 ));
579 }
580
581 let sum: F = indices
584 .iter()
585 .map(|&idx| target[idx].as_())
586 .fold(F::zero(), |a, b| a + b);
587
588 let count: F = F::from(indices.len()).ok_or_else(|| {
589 StatsError::conversion_error(format!("Failed to convert {} to type F", indices.len()))
590 })?;
591 let mean_f = sum / count;
592
593 NumCast::from(mean_f).ok_or_else(|| {
595 StatsError::conversion_error("Failed to convert mean to the target type".to_string())
596 })
597 }
598
    /// Count occurrences of each class over `indices`, returning the
    /// majority class together with the full count map.
    ///
    /// NOTE(review): for empty `indices` the majority class falls back to
    /// `NumCast::from(0.0).unwrap()`, which panics for target types that
    /// cannot represent zero — current callers never pass an empty set here.
    fn calculate_class_distribution(
        &self,
        target: &[T],
        indices: &[usize],
    ) -> (T, HashMap<T, usize>) {
        let mut class_counts: HashMap<T, usize> = HashMap::new();

        for &idx in indices {
            let class = target[idx];
            *class_counts.entry(class).or_insert(0) += 1;
        }

        // Pick the class with the highest count. Ties are resolved by
        // HashMap iteration order, which is unspecified.
        let (majority_class, _) = class_counts
            .iter()
            .max_by_key(|&(_, count)| *count)
            .map(|(&class, count)| (class, *count))
            .unwrap_or_else(|| {
                (NumCast::from(0.0).unwrap(), 0)
            });

        (majority_class, class_counts)
    }
624
625 fn is_pure(&self, target: &[T], indices: &[usize]) -> bool {
627 if indices.is_empty() {
628 return true;
629 }
630
631 let first_value = &target[indices[0]];
632 indices.iter().all(|&idx| {
633 target[idx]
634 .partial_cmp(first_value)
635 .unwrap_or(Ordering::Equal)
636 == Ordering::Equal
637 })
638 }
639
640 pub fn predict<D>(&self, features: &[Vec<D>]) -> StatsResult<Vec<T>>
646 where
647 D: Clone + PartialOrd + NumCast,
648 T: NumCast,
649 {
650 features
651 .iter()
652 .map(|feature_vec| self.predict_single(feature_vec))
653 .collect()
654 }
655
    /// Route a single feature row from the root to a leaf and return the
    /// leaf's prediction.
    ///
    /// # Errors
    /// Fails when the tree is unfitted, the row is shorter than a split's
    /// feature index, a threshold cannot be cast to `D`, or the stored tree
    /// structure is malformed (missing value/children).
    fn predict_single<D>(&self, features: &[D]) -> StatsResult<T>
    where
        D: Clone + PartialOrd + NumCast,
        T: NumCast,
    {
        if self.nodes.is_empty() {
            return Err(StatsError::not_fitted(
                "Decision tree has not been trained yet",
            ));
        }

        // Iterative descent starting at the root (arena index 0).
        let mut node_idx = 0;
        loop {
            let node = &self.nodes[node_idx];

            if node.is_leaf() {
                return node
                    .value
                    .ok_or_else(|| StatsError::invalid_input("Leaf node missing value"));
            }

            let feature_idx = node
                .feature_idx
                .ok_or_else(|| StatsError::invalid_input("Internal node missing feature index"))?;
            let threshold = node
                .threshold
                .as_ref()
                .ok_or_else(|| StatsError::invalid_input("Internal node missing threshold"))?;

            if feature_idx >= features.len() {
                return Err(StatsError::index_out_of_bounds(format!(
                    "Feature index {} is out of bounds (features has {} elements)",
                    feature_idx,
                    features.len()
                )));
            }

            let feature_val = &features[feature_idx];

            // Compare in the caller's feature type: the stored threshold is
            // cast to `D` rather than the other way around.
            let threshold_d = D::from(*threshold).ok_or_else(|| {
                StatsError::conversion_error(format!(
                    "Failed to convert threshold {:?} to feature type",
                    threshold
                ))
            })?;

            // Incomparable values (e.g. NaN) are treated as equal, which
            // routes them left along with value <= threshold.
            let comparison = feature_val
                .partial_cmp(&threshold_d)
                .unwrap_or(Ordering::Equal);

            if comparison != Ordering::Greater {
                node_idx = node
                    .left
                    .ok_or_else(|| StatsError::invalid_input("Internal node missing left child"))?;
            } else {
                node_idx = node.right.ok_or_else(|| {
                    StatsError::invalid_input("Internal node missing right child")
                })?;
            }
        }
    }
720
721 pub fn feature_importances(&self) -> Vec<F> {
723 if self.nodes.is_empty() {
724 return Vec::new();
725 }
726
727 let n_features = self
729 .nodes
730 .iter()
731 .find(|node| !node.is_leaf())
732 .and_then(|node| node.feature_idx)
733 .map(|idx| idx + 1)
734 .unwrap_or(0);
735
736 if n_features == 0 {
737 return Vec::new();
738 }
739
740 let mut feature_counts = vec![0; n_features];
742 for node in &self.nodes {
743 if let Some(feature_idx) = node.feature_idx {
744 feature_counts[feature_idx] += 1;
745 }
746 }
747
748 let total_count: f64 = feature_counts.iter().sum::<usize>() as f64;
750 if total_count > 0.0 {
751 feature_counts
752 .iter()
753 .map(|&count| (count as f64 / total_count).as_())
754 .collect()
755 } else {
756 vec![F::zero(); n_features]
757 }
758 }
759
760 pub fn tree_structure(&self) -> String {
762 if self.nodes.is_empty() {
763 return "Empty tree".to_string();
764 }
765
766 let mut result = String::new();
767 self.print_node(0, 0, &mut result);
768 result
769 }
770
    /// Recursively append a textual rendering of the subtree rooted at
    /// `node_idx` to `result`, indenting by one space per level.
    ///
    /// Uses `unwrap` on leaf values, class distributions and split fields;
    /// these are always populated by `build_tree` for a fitted tree.
    fn print_node(&self, node_idx: usize, depth: usize, result: &mut String) {
        let node = &self.nodes[node_idx];
        let indent = " ".repeat(depth);

        if node.is_leaf() {
            if self.tree_type == TreeType::Classification {
                // Classification leaves also show the per-class counts.
                let class_distribution = node.class_distribution.as_ref().unwrap();
                let classes: Vec<String> = class_distribution
                    .iter()
                    .map(|(class, count)| format!("{:?}: {}", class, count))
                    .collect();

                result.push_str(&format!(
                    "{}Leaf: prediction = {:?}, distribution = {{{}}}\n",
                    indent,
                    node.value.as_ref().unwrap(),
                    classes.join(", ")
                ));
            } else {
                result.push_str(&format!(
                    "{}Leaf: prediction = {:?}\n",
                    indent,
                    node.value.as_ref().unwrap()
                ));
            }
        } else {
            result.push_str(&format!(
                "{}Node: feature {} <= {:?}\n",
                indent,
                node.feature_idx.unwrap(),
                node.threshold.as_ref().unwrap()
            ));

            // Left subtree first, matching the <= routing order.
            if let Some(left_idx) = node.left {
                self.print_node(left_idx, depth + 1, result);
            }

            if let Some(right_idx) = node.right {
                self.print_node(right_idx, depth + 1, result);
            }
        }
    }
814}
815
// Compact one-line representation for logging, e.g.
// `DecisionTree(Regression, Mse, max_depth=5, nodes=7)`.
impl<T, F> fmt::Display for DecisionTree<T, F>
where
    T: Clone + PartialOrd + Eq + Hash + Debug + ToPrimitive,
    F: Float,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "DecisionTree({:?}, {:?}, max_depth={}, nodes={})",
            self.tree_type,
            self.criterion,
            self.max_depth,
            self.nodes.len()
        )
    }
}
832
833impl<T, F> DecisionTree<T, F>
835where
836 T: Clone + PartialOrd + Eq + Hash + Send + Sync + NumCast + ToPrimitive + Debug,
837 F: Float + Send + Sync + NumCast + FromPrimitive + 'static,
838 f64: AsPrimitive<F>,
839 usize: AsPrimitive<F>,
840 T: AsPrimitive<F>,
841 F: AsPrimitive<T>,
842{
843 pub fn get_max_depth(&self) -> usize {
845 self.max_depth
846 }
847
848 pub fn get_node_count(&self) -> usize {
850 self.nodes.len()
851 }
852
853 pub fn is_trained(&self) -> bool {
855 !self.nodes.is_empty()
856 }
857
858 pub fn get_leaf_count(&self) -> usize {
860 self.nodes.iter().filter(|node| node.is_leaf()).count()
861 }
862
863 pub fn calculate_depth(&self) -> usize {
865 if self.nodes.is_empty() {
866 return 0;
867 }
868
869 fn depth_helper<T, F>(nodes: &[Node<T, F>], node_idx: usize, current_depth: usize) -> usize
871 where
872 T: Clone + PartialOrd + Eq + Hash + Debug + ToPrimitive,
873 F: Float,
874 {
875 let node = &nodes[node_idx];
876
877 if node.is_leaf() {
878 return current_depth;
879 }
880
881 let left_depth = depth_helper(nodes, node.left.unwrap(), current_depth + 1);
882 let right_depth = depth_helper(nodes, node.right.unwrap(), current_depth + 1);
883
884 std::cmp::max(left_depth, right_depth)
885 }
886
887 depth_helper(&self.nodes, 0, 0)
888 }
889
890 pub fn summary(&self) -> String {
892 if !self.is_trained() {
893 return "Decision tree is not trained yet".to_string();
894 }
895
896 let leaf_count = self.get_leaf_count();
897 let node_count = self.get_node_count();
898 let actual_depth = self.calculate_depth();
899
900 format!(
901 "Decision Tree Summary:\n\
902 - Type: {:?}\n\
903 - Criterion: {:?}\n\
904 - Max depth: {}\n\
905 - Actual depth: {}\n\
906 - Total nodes: {}\n\
907 - Leaf nodes: {}\n\
908 - Internal nodes: {}",
909 self.tree_type,
910 self.criterion,
911 self.max_depth,
912 actual_depth,
913 node_count,
914 leaf_count,
915 node_count - leaf_count
916 )
917 }
918}
919
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Minimal float wrapper giving `f64` the `Eq + Hash` bounds the tree's
    /// target type requires (bit-level hash, epsilon equality).
    #[derive(Clone, Debug, PartialOrd, Copy)]
    struct TestFloat(f64);

    impl PartialEq for TestFloat {
        fn eq(&self, other: &Self) -> bool {
            (self.0 - other.0).abs() < f64::EPSILON
        }
    }

    impl Eq for TestFloat {}

    impl std::hash::Hash for TestFloat {
        fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
            // Hash the raw bits; consistent with equality for the exact
            // values these tests use.
            let bits = self.0.to_bits();
            bits.hash(state);
        }
    }

    impl ToPrimitive for TestFloat {
        fn to_i64(&self) -> Option<i64> {
            self.0.to_i64()
        }

        fn to_u64(&self) -> Option<u64> {
            self.0.to_u64()
        }

        fn to_f64(&self) -> Option<f64> {
            Some(self.0)
        }
    }

    impl NumCast for TestFloat {
        fn from<T: ToPrimitive>(n: T) -> Option<Self> {
            n.to_f64().map(TestFloat)
        }
    }

    impl FromPrimitive for TestFloat {
        fn from_i64(n: i64) -> Option<Self> {
            Some(TestFloat(n as f64))
        }

        fn from_u64(n: u64) -> Option<Self> {
            Some(TestFloat(n as f64))
        }

        fn from_f64(n: f64) -> Option<Self> {
            Some(TestFloat(n))
        }
    }

    impl AsPrimitive<f64> for TestFloat {
        fn as_(self) -> f64 {
            self.0
        }
    }

    impl AsPrimitive<TestFloat> for f64 {
        fn as_(self) -> TestFloat {
            TestFloat(self)
        }
    }

    // Regression smoke test on a small synthetic health-risk data set.
    #[test]
    fn test_diabetes_prediction() {
        let mut tree = DecisionTree::<TestFloat, f64>::new(
            TreeType::Regression,
            SplitCriterion::Mse,
            5, // max_depth
            2, // min_samples_split
            1, // min_samples_leaf
        );

        let features = vec![
            vec![45.0, 22.5, 95.0, 120.0, 0.0],
            vec![50.0, 26.0, 105.0, 140.0, 1.0],
            vec![35.0, 23.0, 90.0, 115.0, 0.0],
            vec![55.0, 30.0, 140.0, 150.0, 1.0],
            vec![60.0, 29.5, 130.0, 145.0, 1.0],
            vec![40.0, 24.0, 85.0, 125.0, 0.0],
            vec![48.0, 27.0, 110.0, 135.0, 1.0],
            vec![65.0, 31.0, 150.0, 155.0, 1.0],
            vec![42.0, 25.0, 100.0, 130.0, 0.0],
            vec![58.0, 32.0, 145.0, 160.0, 1.0],
        ];

        let target = vec![
            TestFloat(2.0),
            TestFloat(5.5),
            TestFloat(1.5),
            TestFloat(8.0),
            TestFloat(6.5),
            TestFloat(2.0),
            TestFloat(5.0),
            TestFloat(8.5),
            TestFloat(3.0),
            TestFloat(9.0),
        ];

        tree.fit(&features, &target).unwrap();

        let test_features = vec![
            vec![45.0, 23.0, 90.0, 120.0, 0.0],
            vec![62.0, 31.0, 145.0, 155.0, 1.0],
        ];

        let predictions = tree.predict(&test_features).unwrap();

        assert!(
            predictions[0].0 < 5.0,
            "Young healthy patient should have low risk score"
        );
        assert!(
            predictions[1].0 > 5.0,
            "Older patient with high metrics should have high risk score"
        );

        // Structural sanity checks on the fitted tree.
        assert!(tree.is_trained());
        assert!(tree.calculate_depth() <= tree.get_max_depth());
        assert!(tree.get_leaf_count() > 0);

        println!("Diabetes prediction tree:\n{}", tree.summary());
    }

    // Classification smoke test using the Gini criterion.
    #[test]
    fn test_disease_classification() {
        let mut tree = DecisionTree::<u8, f64>::new(
            TreeType::Classification,
            SplitCriterion::Gini,
            4, // max_depth
            2, // min_samples_split
            1, // min_samples_leaf
        );

        let features = vec![
            vec![3, 1, 2, 1, 0, 0],
            vec![1, 3, 2, 0, 1, 3],
            vec![2, 0, 1, 3, 0, 0],
            vec![0, 3, 1, 0, 2, 2],
            vec![3, 2, 3, 2, 1, 0],
            vec![1, 3, 2, 0, 0, 3],
            vec![2, 0, 2, 3, 1, 0],
            vec![0, 2, 1, 0, 2, 2],
            vec![3, 1, 2, 1, 1, 0],
            vec![2, 3, 2, 0, 1, 2],
            vec![1, 0, 1, 3, 0, 0],
            vec![0, 3, 2, 0, 1, 3],
        ];

        let target = vec![1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4];

        tree.fit(&features, &target).unwrap();

        let test_features = vec![
            vec![3, 2, 2, 1, 1, 0],
            vec![1, 3, 2, 0, 1, 3],
            vec![2, 0, 1, 3, 0, 0],
        ];

        let predictions = tree.predict(&test_features).unwrap();

        assert_eq!(predictions[0], 1, "Should diagnose as Flu");
        assert_eq!(predictions[1], 2, "Should diagnose as COVID");
        assert_eq!(predictions[2], 3, "Should diagnose as Migraine");

        println!("Disease classification tree:\n{}", tree.summary());
    }

    // Regression with restrictive stopping settings; exercises the path
    // where prediction may legitimately fail or succeed with a shallow tree.
    #[test]
    fn test_system_failure_prediction() {
        let mut tree = DecisionTree::<i32, f64>::new(
            TreeType::Regression,
            SplitCriterion::Mse,
            2, // max_depth
            5, // min_samples_split
            2, // min_samples_leaf
        );

        let features = vec![
            vec![30, 40, 0],
            vec![35, 45, 1],
            vec![40, 50, 0],
            vec![25, 35, 1],
            vec![30, 40, 0],
            vec![90, 95, 10],
            vec![85, 90, 8],
            vec![95, 98, 15],
            vec![90, 95, 12],
            vec![80, 85, 7],
        ];

        let target = vec![
            1000, 900, 950, 1100, 1050, 10, 15, 5, 8, 20,
        ];

        tree.fit(&features, &target).unwrap();

        println!("System failure tree summary:\n{}", tree.summary());

        if tree.is_trained() {
            println!("Tree structure:\n{}", tree.tree_structure());
        }

        if tree.is_trained() {
            let test_features = vec![
                vec![30, 40, 0],
                vec![90, 95, 10],
            ];

            let predictions = match tree.predict(&test_features) {
                Ok(preds) => {
                    println!("Successfully made predictions: {:?}", preds);
                    preds
                }
                Err(e) => {
                    println!("Error during prediction: {:?}", e);
                    return; }
            };

            if predictions.len() == 2 {
                assert!(
                    predictions[0] > predictions[1],
                    "Healthy system should have longer time to failure than failing system"
                );
            }
        } else {
            println!("Tree wasn't properly trained - skipping prediction tests");
        }
    }

    // Classification smoke test using the entropy criterion.
    #[test]
    fn test_security_incident_classification() {
        let mut tree = DecisionTree::<u8, f64>::new(
            TreeType::Classification,
            SplitCriterion::Entropy,
            5, // max_depth
            2, // min_samples_split
            1, // min_samples_leaf
        );

        let features = vec![
            vec![1, 0, 0, 0, 0],
            vec![5, 1, 1, 1, 0],
            vec![15, 3, 2, 1, 1],
            vec![2, 0, 1, 0, 0],
            vec![8, 2, 1, 1, 0],
            vec![20, 4, 3, 1, 1],
            vec![1, 0, 0, 1, 0],
            vec![6, 1, 2, 1, 0],
            vec![25, 5, 3, 1, 1],
            vec![3, 0, 0, 0, 0],
            vec![7, 2, 1, 0, 0],
            vec![18, 3, 2, 1, 1],
            vec![0, 0, 0, 0, 0],
            vec![9, 2, 2, 1, 0],
            vec![22, 4, 3, 1, 1],
        ];

        let target = vec![0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2];

        tree.fit(&features, &target).unwrap();

        let test_features = vec![
            vec![2, 0, 0, 0, 0],
            vec![7, 1, 1, 1, 0],
            vec![17, 3, 2, 1, 1],
        ];

        let predictions = tree.predict(&test_features).unwrap();

        assert_eq!(predictions[0], 0, "Should classify as normal activity");
        assert_eq!(predictions[1], 1, "Should classify as suspicious activity");
        assert_eq!(predictions[2], 2, "Should classify as potential breach");

        println!(
            "Security incident classification tree:\n{}",
            tree.tree_structure()
        );
    }

    // Exercises the tree with a fully custom target type wrapping Duration,
    // verifying the generic bounds work outside primitive numbers.
    #[test]
    fn test_custom_type_performance_analysis() {
        #[derive(Clone, PartialEq, Eq, Hash, Debug, Copy)]
        struct ResponseTime(Duration);

        impl PartialOrd for ResponseTime {
            fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
                self.0.partial_cmp(&other.0)
            }
        }

        impl ToPrimitive for ResponseTime {
            fn to_i64(&self) -> Option<i64> {
                Some(self.0.as_millis() as i64)
            }

            fn to_u64(&self) -> Option<u64> {
                Some(self.0.as_millis() as u64)
            }

            fn to_f64(&self) -> Option<f64> {
                Some(self.0.as_millis() as f64)
            }
        }

        impl AsPrimitive<f64> for ResponseTime {
            fn as_(self) -> f64 {
                self.0.as_millis() as f64
            }
        }

        impl NumCast for ResponseTime {
            fn from<T: ToPrimitive>(n: T) -> Option<Self> {
                n.to_u64()
                    .map(|ms| ResponseTime(Duration::from_millis(ms as u64)))
            }
        }

        impl FromPrimitive for ResponseTime {
            fn from_i64(n: i64) -> Option<Self> {
                if n >= 0 {
                    Some(ResponseTime(Duration::from_millis(n as u64)))
                } else {
                    None
                }
            }

            fn from_u64(n: u64) -> Option<Self> {
                Some(ResponseTime(Duration::from_millis(n)))
            }

            fn from_f64(n: f64) -> Option<Self> {
                if n >= 0.0 {
                    Some(ResponseTime(Duration::from_millis(n as u64)))
                } else {
                    None
                }
            }
        }

        impl AsPrimitive<ResponseTime> for f64 {
            fn as_(self) -> ResponseTime {
                ResponseTime(Duration::from_millis(self as u64))
            }
        }

        let mut tree = DecisionTree::<ResponseTime, f64>::new(
            TreeType::Regression,
            SplitCriterion::Mse,
            3, // max_depth
            2, // min_samples_split
            1, // min_samples_leaf
        );

        let features = vec![
            vec![10, 20, 3, 5],
            vec![50, 40, 8, 2],
            vec![20, 30, 4, 4],
            vec![100, 60, 12, 0],
            vec![30, 35, 6, 3],
            vec![80, 50, 10, 1],
        ];

        let target = vec![
            ResponseTime(Duration::from_millis(100)),
            ResponseTime(Duration::from_millis(350)),
            ResponseTime(Duration::from_millis(150)),
            ResponseTime(Duration::from_millis(600)),
            ResponseTime(Duration::from_millis(200)),
            ResponseTime(Duration::from_millis(450)),
        ];

        tree.fit(&features, &target).unwrap();

        let test_features = vec![
            vec![15, 25, 3, 4],
            vec![90, 55, 11, 0],
        ];

        let predictions = tree.predict(&test_features).unwrap();

        assert!(
            predictions[0].0.as_millis() < 200,
            "Small request should have fast response time"
        );
        assert!(
            predictions[1].0.as_millis() > 400,
            "Large request should have slow response time"
        );

        println!("Response time prediction tree:\n{}", tree.summary());
    }

    // Fitting on empty inputs must fail, not panic.
    #[test]
    fn test_empty_features() {
        let mut tree =
            DecisionTree::<i32, f64>::new(TreeType::Regression, SplitCriterion::Mse, 3, 2, 1);

        let empty_features: Vec<Vec<f64>> = vec![];
        let empty_target: Vec<i32> = vec![];

        let result = tree.fit(&empty_features, &empty_target);
        assert!(
            result.is_err(),
            "Fitting with empty features should return an error"
        );
    }

    // A single-class target is pure at the root: the tree should be one leaf.
    #[test]
    fn test_single_class_classification() {
        let mut tree =
            DecisionTree::<u8, f64>::new(TreeType::Classification, SplitCriterion::Gini, 3, 2, 1);

        let features = vec![
            vec![1, 2, 3],
            vec![4, 5, 6],
            vec![7, 8, 9],
            vec![10, 11, 12],
        ];

        let target = vec![1, 1, 1, 1];

        tree.fit(&features, &target).unwrap();

        let prediction = tree.predict(&vec![vec![2, 3, 4]]).unwrap();

        assert_eq!(prediction[0], 1);

        assert_eq!(tree.get_node_count(), 1);
        assert_eq!(tree.get_leaf_count(), 1);
    }

    // Predicting before fitting must surface NotFitted.
    #[test]
    fn test_predict_not_fitted() {
        let tree =
            DecisionTree::<i32, f64>::new(TreeType::Regression, SplitCriterion::Mse, 3, 2, 1);
        let features = vec![vec![1.0, 2.0]];
        let result = tree.predict(&features);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), StatsError::NotFitted { .. }));
    }

    // Non-empty features with an empty target must surface EmptyData.
    #[test]
    fn test_fit_target_empty() {
        let mut tree =
            DecisionTree::<i32, f64>::new(TreeType::Regression, SplitCriterion::Mse, 3, 2, 1);
        let features = vec![vec![1.0, 2.0]];
        let target: Vec<i32> = vec![];
        let result = tree.fit(&features, &target);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), StatsError::EmptyData { .. }));
    }

    // Mismatched feature/target lengths must surface DimensionMismatch.
    #[test]
    fn test_fit_length_mismatch() {
        let mut tree =
            DecisionTree::<i32, f64>::new(TreeType::Regression, SplitCriterion::Mse, 3, 2, 1);
        let features = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let target = vec![1]; // one target for two rows
        let result = tree.fit(&features, &target);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::DimensionMismatch { .. }
        ));
    }

    // Ragged feature rows must surface InvalidInput.
    #[test]
    fn test_fit_inconsistent_feature_lengths() {
        let mut tree =
            DecisionTree::<i32, f64>::new(TreeType::Regression, SplitCriterion::Mse, 3, 2, 1);
        let features = vec![vec![1.0, 2.0], vec![3.0]]; // second row is short
        let target = vec![1, 2];
        let result = tree.fit(&features, &target);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }
}
1471}