1use crate::error::{StatsError, StatsResult};
2use num_traits::cast::AsPrimitive;
3use num_traits::{Float, FromPrimitive, NumCast, ToPrimitive};
4use rayon::prelude::*;
5use std::cmp::Ordering;
6use std::collections::HashMap;
7use std::fmt::{self, Debug};
8use std::hash::Hash;
9
/// The learning task a decision tree is built for.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TreeType {
    /// Predict a continuous value; leaves store the mean target of their samples.
    Regression,
    /// Predict a discrete class; leaves store the majority class of their samples.
    Classification,
}
18
/// Impurity measure used to score candidate splits.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SplitCriterion {
    /// Mean squared error (regression trees).
    Mse,
    /// Mean absolute error (regression trees).
    Mae,
    /// Gini impurity (classification trees).
    Gini,
    /// Shannon entropy, natural-log based (classification trees).
    Entropy,
}
31
/// A single tree node, stored flat in `DecisionTree::nodes` with children
/// referenced by index.
///
/// Internal nodes carry `feature_idx`/`threshold` plus `left`/`right`; leaves
/// carry `value` (and, for classification, `class_distribution`).
#[derive(Debug, Clone)]
struct Node<T, F>
where
    T: Clone + PartialOrd + Debug + ToPrimitive,
    F: Float,
{
    // Feature this node splits on; `None` marks a leaf.
    feature_idx: Option<usize>,
    // Split threshold (samples with value <= threshold are routed left).
    threshold: Option<T>,
    // Prediction stored at a leaf.
    value: Option<T>,
    // Per-class sample counts at a classification leaf.
    class_distribution: Option<HashMap<T, usize>>,
    // Index of the left child in the flat node vector.
    left: Option<usize>,
    // Index of the right child in the flat node vector.
    right: Option<usize>,
    // Ties the impurity float type `F` to the node without storing one.
    _phantom: std::marker::PhantomData<F>,
}
54
55impl<T, F> Node<T, F>
56where
57 T: Clone + PartialOrd + Eq + Hash + Debug + ToPrimitive,
58 F: Float,
59{
60 fn new_split(feature_idx: usize, threshold: T) -> Self {
62 Node {
63 feature_idx: Some(feature_idx),
64 threshold: Some(threshold),
65 value: None,
66 class_distribution: None,
67 left: None,
68 right: None,
69 _phantom: std::marker::PhantomData,
70 }
71 }
72
73 fn new_leaf_regression(value: T) -> Self {
75 Node {
76 feature_idx: None,
77 threshold: None,
78 value: Some(value),
79 class_distribution: None,
80 left: None,
81 right: None,
82 _phantom: std::marker::PhantomData,
83 }
84 }
85
86 fn new_leaf_classification(value: T, class_distribution: HashMap<T, usize>) -> Self {
88 Node {
89 feature_idx: None,
90 threshold: None,
91 value: Some(value),
92 class_distribution: Some(class_distribution),
93 left: None,
94 right: None,
95 _phantom: std::marker::PhantomData,
96 }
97 }
98
99 fn is_leaf(&self) -> bool {
101 self.feature_idx.is_none()
102 }
103}
104
/// Binary decision tree over a generic target type `T`, using float type `F`
/// for impurity arithmetic.
///
/// Nodes are stored in a flat `Vec`; index 0 is the root once fitted.
#[derive(Debug, Clone)]
pub struct DecisionTree<T, F>
where
    T: Clone + PartialOrd + Debug + ToPrimitive,
    F: Float,
{
    // Regression or classification; drives leaf construction and impurity choice.
    tree_type: TreeType,
    // Impurity measure used when searching for splits.
    criterion: SplitCriterion,
    // Depth limit; growth stops once a branch reaches this depth.
    max_depth: usize,
    // Minimum number of samples required to attempt a split.
    min_samples_split: usize,
    // Minimum number of samples each child of a split must keep.
    min_samples_leaf: usize,
    // Flat node storage; empty until `fit` succeeds.
    nodes: Vec<Node<T, F>>,
}
129
130impl<T, F> DecisionTree<T, F>
131where
132 T: Clone + PartialOrd + Eq + Hash + Send + Sync + NumCast + ToPrimitive + Debug,
133 F: Float + Send + Sync + NumCast + FromPrimitive + 'static,
134 f64: AsPrimitive<F>,
135 usize: AsPrimitive<F>,
136 T: AsPrimitive<F>,
137 F: AsPrimitive<T>,
138{
139 pub fn new(
141 tree_type: TreeType,
142 criterion: SplitCriterion,
143 max_depth: usize,
144 min_samples_split: usize,
145 min_samples_leaf: usize,
146 ) -> Self {
147 Self {
148 tree_type,
149 criterion,
150 max_depth,
151 min_samples_split,
152 min_samples_leaf,
153 nodes: Vec::new(),
154 }
155 }
156
157 pub fn fit<D>(&mut self, features: &[Vec<D>], target: &[T]) -> StatsResult<()>
165 where
166 D: Clone + PartialOrd + NumCast + ToPrimitive + AsPrimitive<F> + Send + Sync,
167 T: FromPrimitive,
168 {
169 if features.is_empty() {
170 return Err(StatsError::empty_data("Features cannot be empty"));
171 }
172 if target.is_empty() {
173 return Err(StatsError::empty_data("Target cannot be empty"));
174 }
175 if features.len() != target.len() {
176 return Err(StatsError::dimension_mismatch(format!(
177 "Features and target must have the same length (got {} and {})",
178 features.len(),
179 target.len()
180 )));
181 }
182
183 let n_features = features[0].len();
185 for (i, feature_vec) in features.iter().enumerate() {
186 if feature_vec.len() != n_features {
187 return Err(StatsError::invalid_input(format!(
188 "All feature vectors must have the same length (vector {} has {} features, expected {})",
189 i,
190 feature_vec.len(),
191 n_features
192 )));
193 }
194 }
195
196 self.nodes = Vec::new();
198
199 let indices: Vec<usize> = (0..features.len()).collect();
201
202 self.build_tree(features, target, &indices, 0)?;
204 Ok(())
205 }
206
207 fn build_tree<D>(
209 &mut self,
210 features: &[Vec<D>],
211 target: &[T],
212 indices: &[usize],
213 depth: usize,
214 ) -> StatsResult<usize>
215 where
216 D: Clone + PartialOrd + NumCast + ToPrimitive + AsPrimitive<F> + Send + Sync,
217 {
218 if depth >= self.max_depth
220 || indices.len() < self.min_samples_split
221 || self.is_pure(target, indices)
222 {
223 let node_idx = self.nodes.len();
224 if self.tree_type == TreeType::Regression {
225 let value = self.calculate_mean(target, indices)?;
227 self.nodes.push(Node::new_leaf_regression(value));
228 } else {
229 let (value, class_counts) = self.calculate_class_distribution(target, indices);
231 self.nodes
232 .push(Node::new_leaf_classification(value, class_counts));
233 }
234 return Ok(node_idx);
235 }
236
237 let (feature_idx, threshold, left_indices, right_indices) =
239 self.find_best_split(features, target, indices);
240
241 if left_indices.is_empty() || right_indices.is_empty() {
243 let node_idx = self.nodes.len();
244 if self.tree_type == TreeType::Regression {
245 let value = self.calculate_mean(target, indices)?;
246 self.nodes.push(Node::new_leaf_regression(value));
247 } else {
248 let (value, class_counts) = self.calculate_class_distribution(target, indices);
249 self.nodes
250 .push(Node::new_leaf_classification(value, class_counts));
251 }
252 return Ok(node_idx);
253 }
254
255 let node_idx = self.nodes.len();
257
258 let t_threshold = NumCast::from(threshold).ok_or_else(|| {
260 StatsError::conversion_error(
261 "Failed to convert threshold to the feature type".to_string(),
262 )
263 })?;
264
265 self.nodes.push(Node::new_split(feature_idx, t_threshold));
266
267 let left_idx = self.build_tree(features, target, &left_indices, depth + 1)?;
269 let right_idx = self.build_tree(features, target, &right_indices, depth + 1)?;
270
271 self.nodes[node_idx].left = Some(left_idx);
273 self.nodes[node_idx].right = Some(right_idx);
274
275 Ok(node_idx)
276 }
277
    /// Finds the best (feature, threshold) split of `indices`, evaluating
    /// features in parallel and scoring each candidate threshold with
    /// `calculate_split_impurity`.
    ///
    /// Candidate thresholds are midpoints between consecutive distinct sorted
    /// feature values; splits leaving fewer than `min_samples_leaf` samples
    /// on either side are skipped. When no feature admits a valid split, the
    /// returned index vectors are both empty (feature/threshold are then
    /// arbitrary placeholders) and the caller falls back to making a leaf.
    fn find_best_split<D>(
        &self,
        features: &[Vec<D>],
        target: &[T],
        indices: &[usize],
    ) -> (usize, D, Vec<usize>, Vec<usize>)
    where
        D: Clone + PartialOrd + NumCast + ToPrimitive + AsPrimitive<F> + Send + Sync,
    {
        let n_features = features[0].len();

        // Best split found so far across all features.
        let mut best_impurity = F::infinity();
        let mut best_feature = 0;
        let mut best_threshold = features[indices[0]][0];
        let mut best_left = Vec::new();
        let mut best_right = Vec::new();

        // One parallel task per feature; each yields its feature's best
        // candidate split, or None when the feature cannot split the node.
        let results: Vec<_> = (0..n_features)
            .into_par_iter()
            .filter_map(|feature_idx| {
                // Pair each sample with its value for this feature.
                let mut feature_values: Vec<(usize, D)> = indices
                    .iter()
                    .map(|&idx| (idx, features[idx][feature_idx]))
                    .collect();

                // NOTE(review): incomparable values (e.g. NaN) compare as
                // equal throughout this method — TODO confirm intended.
                feature_values.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));

                // Collapse runs of equal values into the distinct sorted values.
                let mut values: Vec<D> = Vec::new();
                let mut prev_val: Option<&D> = None;

                for (_, val) in &feature_values {
                    if prev_val.is_none()
                        || prev_val
                            .unwrap()
                            .partial_cmp(val)
                            .unwrap_or(Ordering::Equal)
                            != Ordering::Equal
                    {
                        values.push(*val);
                        prev_val = Some(val);
                    }
                }

                // A constant feature cannot split the node.
                if values.len() <= 1 {
                    return None;
                }

                let mut feature_best_impurity = F::infinity();
                let mut feature_best_threshold = values[0];
                let mut feature_best_left = Vec::new();
                let mut feature_best_right = Vec::new();

                // Try the midpoint between every pair of adjacent distinct values.
                for i in 0..values.len() - 1 {
                    let val1: F = values[i].as_();
                    let val2: F = values[i + 1].as_();

                    let two = match F::from(2.0) {
                        Some(t) => t,
                        None => continue,
                    };
                    let mid_value = (val1 + val2) / two;

                    // Midpoint is computed in F; convert back to the feature
                    // type D, skipping thresholds that don't convert.
                    let threshold = match NumCast::from(mid_value) {
                        Some(t) => t,
                        None => continue,
                    };

                    // Partition: value <= threshold goes left, otherwise right.
                    let mut left_indices = Vec::new();
                    let mut right_indices = Vec::new();

                    for &idx in indices {
                        let feature_value = &features[idx][feature_idx];
                        if feature_value
                            .partial_cmp(&threshold)
                            .unwrap_or(Ordering::Equal)
                            != Ordering::Greater
                        {
                            left_indices.push(idx);
                        } else {
                            right_indices.push(idx);
                        }
                    }

                    // Enforce the minimum leaf size on both children.
                    if left_indices.len() < self.min_samples_leaf
                        || right_indices.len() < self.min_samples_leaf
                    {
                        continue;
                    }

                    let impurity =
                        self.calculate_split_impurity(target, &left_indices, &right_indices);

                    if impurity < feature_best_impurity {
                        feature_best_impurity = impurity;
                        feature_best_threshold = threshold;
                        feature_best_left = left_indices;
                        feature_best_right = right_indices;
                    }
                }

                if !feature_best_left.is_empty() && !feature_best_right.is_empty() {
                    Some((
                        feature_idx,
                        feature_best_impurity,
                        feature_best_threshold,
                        feature_best_left,
                        feature_best_right,
                    ))
                } else {
                    None
                }
            })
            .collect();

        // Sequential reduction over the (order-preserved) per-feature results;
        // with strict `<`, the lowest-indexed feature wins impurity ties.
        for (feature_idx, impurity, threshold, left, right) in results {
            if impurity < best_impurity {
                best_impurity = impurity;
                best_feature = feature_idx;
                best_threshold = threshold;
                best_left = left;
                best_right = right;
            }
        }

        (best_feature, best_threshold, best_left, best_right)
    }
421
422 fn calculate_split_impurity(
424 &self,
425 target: &[T],
426 left_indices: &[usize],
427 right_indices: &[usize],
428 ) -> F {
429 let n_left = left_indices.len();
430 let n_right = right_indices.len();
431 let n_total = n_left + n_right;
432
433 if n_left == 0 || n_right == 0 {
434 return F::infinity();
435 }
436
437 let left_weight: F = (n_left as f64).as_();
438 let right_weight: F = (n_right as f64).as_();
439 let total: F = (n_total as f64).as_();
440
441 let left_ratio = left_weight / total;
442 let right_ratio = right_weight / total;
443
444 match (self.tree_type, self.criterion) {
445 (TreeType::Regression, SplitCriterion::Mse) => {
446 let left_mse = self.calculate_mse(target, left_indices);
448 let right_mse = self.calculate_mse(target, right_indices);
449 left_ratio * left_mse + right_ratio * right_mse
450 }
451 (TreeType::Regression, SplitCriterion::Mae) => {
452 let left_mae = self.calculate_mae(target, left_indices);
454 let right_mae = self.calculate_mae(target, right_indices);
455 left_ratio * left_mae + right_ratio * right_mae
456 }
457 (TreeType::Classification, SplitCriterion::Gini) => {
458 let left_gini = self.calculate_gini(target, left_indices);
460 let right_gini = self.calculate_gini(target, right_indices);
461 left_ratio * left_gini + right_ratio * right_gini
462 }
463 (TreeType::Classification, SplitCriterion::Entropy) => {
464 let left_entropy = self.calculate_entropy(target, left_indices);
466 let right_entropy = self.calculate_entropy(target, right_indices);
467 left_ratio * left_entropy + right_ratio * right_entropy
468 }
469 _ => {
470 F::infinity()
473 }
474 }
475 }
476
477 fn calculate_mse(&self, target: &[T], indices: &[usize]) -> F {
479 if indices.is_empty() {
480 return F::zero();
481 }
482
483 let mean = match self.calculate_mean(target, indices) {
485 Ok(m) => m,
486 Err(_) => return F::infinity(),
487 };
488 let mean_f: F = mean.as_();
489
490 let sum_squared_error: F = indices
491 .iter()
492 .map(|&idx| {
493 let error: F = target[idx].as_() - mean_f;
494 error * error
495 })
496 .fold(F::zero(), |a, b| a + b);
497
498 let count = F::from(indices.len()).unwrap_or(F::one());
499 sum_squared_error / count
500 }
501
502 fn calculate_mae(&self, target: &[T], indices: &[usize]) -> F {
504 if indices.is_empty() {
505 return F::zero();
506 }
507
508 let mean = match self.calculate_mean(target, indices) {
510 Ok(m) => m,
511 Err(_) => return F::infinity(),
512 };
513 let mean_f: F = mean.as_();
514
515 let sum_absolute_error: F = indices
516 .iter()
517 .map(|&idx| {
518 let error: F = target[idx].as_() - mean_f;
519 error.abs()
520 })
521 .fold(F::zero(), |a, b| a + b);
522
523 let count = F::from(indices.len()).unwrap_or(F::one());
524 sum_absolute_error / count
525 }
526
527 fn calculate_gini(&self, target: &[T], indices: &[usize]) -> F {
529 if indices.is_empty() {
530 return F::zero();
531 }
532
533 let (_, class_counts) = self.calculate_class_distribution(target, indices);
534 let n_samples = indices.len();
535
536 F::one()
537 - class_counts
538 .values()
539 .map(|&count| {
540 let probability: F = (count as f64 / n_samples as f64).as_();
541 probability * probability
542 })
543 .fold(F::zero(), |a, b| a + b)
544 }
545
546 fn calculate_entropy(&self, target: &[T], indices: &[usize]) -> F {
548 if indices.is_empty() {
549 return F::zero();
550 }
551
552 let (_, class_counts) = self.calculate_class_distribution(target, indices);
553 let n_samples = indices.len();
554
555 -class_counts
556 .values()
557 .map(|&count| {
558 let probability: F = (count as f64 / n_samples as f64).as_();
559 if probability > F::zero() {
560 probability * probability.ln()
561 } else {
562 F::zero()
563 }
564 })
565 .fold(F::zero(), |a, b| a + b)
566 }
567
568 fn calculate_mean(&self, target: &[T], indices: &[usize]) -> StatsResult<T> {
570 if indices.is_empty() {
571 return Err(StatsError::empty_data(
572 "Cannot calculate mean for empty indices",
573 ));
574 }
575
576 let sum: F = indices
579 .iter()
580 .map(|&idx| target[idx].as_())
581 .fold(F::zero(), |a, b| a + b);
582
583 let count: F = F::from(indices.len()).ok_or_else(|| {
584 StatsError::conversion_error(format!("Failed to convert {} to type F", indices.len()))
585 })?;
586 let mean_f = sum / count;
587
588 NumCast::from(mean_f).ok_or_else(|| {
590 StatsError::conversion_error("Failed to convert mean to the target type".to_string())
591 })
592 }
593
594 fn calculate_class_distribution(
596 &self,
597 target: &[T],
598 indices: &[usize],
599 ) -> (T, HashMap<T, usize>) {
600 let mut class_counts: HashMap<T, usize> = HashMap::new();
601
602 for &idx in indices {
603 let class = target[idx];
604 *class_counts.entry(class).or_insert(0) += 1;
605 }
606
607 let (majority_class, _) = class_counts
609 .iter()
610 .max_by_key(|&(_, count)| *count)
611 .map(|(&class, count)| (class, *count))
612 .unwrap_or_else(|| {
613 (NumCast::from(0.0).unwrap(), 0)
615 });
616
617 (majority_class, class_counts)
618 }
619
620 fn is_pure(&self, target: &[T], indices: &[usize]) -> bool {
622 if indices.is_empty() {
623 return true;
624 }
625
626 let first_value = &target[indices[0]];
627 indices.iter().all(|&idx| {
628 target[idx]
629 .partial_cmp(first_value)
630 .unwrap_or(Ordering::Equal)
631 == Ordering::Equal
632 })
633 }
634
635 pub fn predict<D>(&self, features: &[Vec<D>]) -> StatsResult<Vec<T>>
641 where
642 D: Clone + PartialOrd + NumCast,
643 T: NumCast,
644 {
645 features
646 .iter()
647 .map(|feature_vec| self.predict_single(feature_vec))
648 .collect()
649 }
650
651 fn predict_single<D>(&self, features: &[D]) -> StatsResult<T>
653 where
654 D: Clone + PartialOrd + NumCast,
655 T: NumCast,
656 {
657 if self.nodes.is_empty() {
658 return Err(StatsError::not_fitted(
659 "Decision tree has not been trained yet",
660 ));
661 }
662
663 let mut node_idx = 0;
664 loop {
665 let node = &self.nodes[node_idx];
666
667 if node.is_leaf() {
668 return node
669 .value
670 .ok_or_else(|| StatsError::invalid_input("Leaf node missing value"));
671 }
672
673 let feature_idx = node
674 .feature_idx
675 .ok_or_else(|| StatsError::invalid_input("Internal node missing feature index"))?;
676 let threshold = node
677 .threshold
678 .as_ref()
679 .ok_or_else(|| StatsError::invalid_input("Internal node missing threshold"))?;
680
681 if feature_idx >= features.len() {
682 return Err(StatsError::index_out_of_bounds(format!(
683 "Feature index {} is out of bounds (features has {} elements)",
684 feature_idx,
685 features.len()
686 )));
687 }
688
689 let feature_val = &features[feature_idx];
690
691 let threshold_d = D::from(*threshold).ok_or_else(|| {
694 StatsError::conversion_error(format!(
695 "Failed to convert threshold {:?} to feature type",
696 threshold
697 ))
698 })?;
699
700 let comparison = feature_val
701 .partial_cmp(&threshold_d)
702 .unwrap_or(Ordering::Equal);
703
704 if comparison != Ordering::Greater {
705 node_idx = node
706 .left
707 .ok_or_else(|| StatsError::invalid_input("Internal node missing left child"))?;
708 } else {
709 node_idx = node.right.ok_or_else(|| {
710 StatsError::invalid_input("Internal node missing right child")
711 })?;
712 }
713 }
714 }
715
716 pub fn feature_importances(&self) -> Vec<F> {
718 if self.nodes.is_empty() {
719 return Vec::new();
720 }
721
722 let n_features = self
724 .nodes
725 .iter()
726 .find(|node| !node.is_leaf())
727 .and_then(|node| node.feature_idx)
728 .map(|idx| idx + 1)
729 .unwrap_or(0);
730
731 if n_features == 0 {
732 return Vec::new();
733 }
734
735 let mut feature_counts = vec![0; n_features];
737 for node in &self.nodes {
738 if let Some(feature_idx) = node.feature_idx {
739 feature_counts[feature_idx] += 1;
740 }
741 }
742
743 let total_count: f64 = feature_counts.iter().sum::<usize>() as f64;
745 if total_count > 0.0 {
746 feature_counts
747 .iter()
748 .map(|&count| (count as f64 / total_count).as_())
749 .collect()
750 } else {
751 vec![F::zero(); n_features]
752 }
753 }
754
755 pub fn tree_structure(&self) -> String {
757 if self.nodes.is_empty() {
758 return "Empty tree".to_string();
759 }
760
761 let mut result = String::new();
762 self.print_node(0, 0, &mut result);
763 result
764 }
765
766 fn print_node(&self, node_idx: usize, depth: usize, result: &mut String) {
768 let node = &self.nodes[node_idx];
769 let indent = " ".repeat(depth);
770
771 if node.is_leaf() {
772 if self.tree_type == TreeType::Classification {
773 let class_distribution = node.class_distribution.as_ref().unwrap();
774 let classes: Vec<String> = class_distribution
775 .iter()
776 .map(|(class, count)| format!("{:?}: {}", class, count))
777 .collect();
778
779 result.push_str(&format!(
780 "{}Leaf: prediction = {:?}, distribution = {{{}}}\n",
781 indent,
782 node.value.as_ref().unwrap(),
783 classes.join(", ")
784 ));
785 } else {
786 result.push_str(&format!(
787 "{}Leaf: prediction = {:?}\n",
788 indent,
789 node.value.as_ref().unwrap()
790 ));
791 }
792 } else {
793 result.push_str(&format!(
794 "{}Node: feature {} <= {:?}\n",
795 indent,
796 node.feature_idx.unwrap(),
797 node.threshold.as_ref().unwrap()
798 ));
799
800 if let Some(left_idx) = node.left {
801 self.print_node(left_idx, depth + 1, result);
802 }
803
804 if let Some(right_idx) = node.right {
805 self.print_node(right_idx, depth + 1, result);
806 }
807 }
808 }
809}
810
811impl<T, F> fmt::Display for DecisionTree<T, F>
812where
813 T: Clone + PartialOrd + Eq + Hash + Debug + ToPrimitive,
814 F: Float,
815{
816 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
817 write!(
818 f,
819 "DecisionTree({:?}, {:?}, max_depth={}, nodes={})",
820 self.tree_type,
821 self.criterion,
822 self.max_depth,
823 self.nodes.len()
824 )
825 }
826}
827
828impl<T, F> DecisionTree<T, F>
830where
831 T: Clone + PartialOrd + Eq + Hash + Send + Sync + NumCast + ToPrimitive + Debug,
832 F: Float + Send + Sync + NumCast + FromPrimitive + 'static,
833 f64: AsPrimitive<F>,
834 usize: AsPrimitive<F>,
835 T: AsPrimitive<F>,
836 F: AsPrimitive<T>,
837{
838 pub fn get_max_depth(&self) -> usize {
840 self.max_depth
841 }
842
843 pub fn get_node_count(&self) -> usize {
845 self.nodes.len()
846 }
847
848 pub fn is_trained(&self) -> bool {
850 !self.nodes.is_empty()
851 }
852
853 pub fn get_leaf_count(&self) -> usize {
855 self.nodes.iter().filter(|node| node.is_leaf()).count()
856 }
857
858 pub fn calculate_depth(&self) -> usize {
860 if self.nodes.is_empty() {
861 return 0;
862 }
863
864 fn depth_helper<T, F>(nodes: &[Node<T, F>], node_idx: usize, current_depth: usize) -> usize
866 where
867 T: Clone + PartialOrd + Eq + Hash + Debug + ToPrimitive,
868 F: Float,
869 {
870 let node = &nodes[node_idx];
871
872 if node.is_leaf() {
873 return current_depth;
874 }
875
876 let left_depth = depth_helper(nodes, node.left.unwrap(), current_depth + 1);
877 let right_depth = depth_helper(nodes, node.right.unwrap(), current_depth + 1);
878
879 std::cmp::max(left_depth, right_depth)
880 }
881
882 depth_helper(&self.nodes, 0, 0)
883 }
884
885 pub fn summary(&self) -> String {
887 if !self.is_trained() {
888 return "Decision tree is not trained yet".to_string();
889 }
890
891 let leaf_count = self.get_leaf_count();
892 let node_count = self.get_node_count();
893 let actual_depth = self.calculate_depth();
894
895 format!(
896 "Decision Tree Summary:\n\
897 - Type: {:?}\n\
898 - Criterion: {:?}\n\
899 - Max depth: {}\n\
900 - Actual depth: {}\n\
901 - Total nodes: {}\n\
902 - Leaf nodes: {}\n\
903 - Internal nodes: {}",
904 self.tree_type,
905 self.criterion,
906 self.max_depth,
907 actual_depth,
908 node_count,
909 leaf_count,
910 node_count - leaf_count
911 )
912 }
913}
914
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    // Minimal float newtype providing the numeric traits DecisionTree needs
    // for a target type (Eq/Hash over raw bits, casts routed through f64).
    #[derive(Clone, Debug, PartialOrd, Copy)]
    struct TestFloat(f64);

    impl PartialEq for TestFloat {
        fn eq(&self, other: &Self) -> bool {
            // Epsilon-based equality; only bit-identical values also hash equal.
            (self.0 - other.0).abs() < f64::EPSILON
        }
    }

    impl Eq for TestFloat {}

    impl std::hash::Hash for TestFloat {
        fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
            let bits = self.0.to_bits();
            bits.hash(state);
        }
    }

    impl ToPrimitive for TestFloat {
        fn to_i64(&self) -> Option<i64> {
            self.0.to_i64()
        }

        fn to_u64(&self) -> Option<u64> {
            self.0.to_u64()
        }

        fn to_f64(&self) -> Option<f64> {
            Some(self.0)
        }
    }

    impl NumCast for TestFloat {
        fn from<T: ToPrimitive>(n: T) -> Option<Self> {
            n.to_f64().map(TestFloat)
        }
    }

    impl FromPrimitive for TestFloat {
        fn from_i64(n: i64) -> Option<Self> {
            Some(TestFloat(n as f64))
        }

        fn from_u64(n: u64) -> Option<Self> {
            Some(TestFloat(n as f64))
        }

        fn from_f64(n: f64) -> Option<Self> {
            Some(TestFloat(n))
        }
    }

    impl AsPrimitive<f64> for TestFloat {
        fn as_(self) -> f64 {
            self.0
        }
    }

    impl AsPrimitive<TestFloat> for f64 {
        fn as_(self) -> TestFloat {
            TestFloat(self)
        }
    }

    // Regression end-to-end: MSE tree on a small risk-score data set with a
    // custom float target type.
    #[test]
    fn test_diabetes_prediction() {
        let mut tree = DecisionTree::<TestFloat, f64>::new(
            TreeType::Regression,
            SplitCriterion::Mse,
            5, 2, 1, );

        let features = vec![
            vec![45.0, 22.5, 95.0, 120.0, 0.0], vec![50.0, 26.0, 105.0, 140.0, 1.0], vec![35.0, 23.0, 90.0, 115.0, 0.0], vec![55.0, 30.0, 140.0, 150.0, 1.0], vec![60.0, 29.5, 130.0, 145.0, 1.0], vec![40.0, 24.0, 85.0, 125.0, 0.0], vec![48.0, 27.0, 110.0, 135.0, 1.0], vec![65.0, 31.0, 150.0, 155.0, 1.0], vec![42.0, 25.0, 100.0, 130.0, 0.0], vec![58.0, 32.0, 145.0, 160.0, 1.0], ];

        let target = vec![
            TestFloat(2.0),
            TestFloat(5.5),
            TestFloat(1.5),
            TestFloat(8.0),
            TestFloat(6.5),
            TestFloat(2.0),
            TestFloat(5.0),
            TestFloat(8.5),
            TestFloat(3.0),
            TestFloat(9.0),
        ];

        tree.fit(&features, &target).unwrap();

        let test_features = vec![
            vec![45.0, 23.0, 90.0, 120.0, 0.0], vec![62.0, 31.0, 145.0, 155.0, 1.0], ];

        let predictions = tree.predict(&test_features).unwrap();

        // Only coarse, separable expectations are asserted — the exact leaf
        // values depend on the split search.
        assert!(
            predictions[0].0 < 5.0,
            "Young healthy patient should have low risk score"
        );
        assert!(
            predictions[1].0 > 5.0,
            "Older patient with high metrics should have high risk score"
        );

        assert!(tree.is_trained());
        assert!(tree.calculate_depth() <= tree.get_max_depth());
        assert!(tree.get_leaf_count() > 0);

        println!("Diabetes prediction tree:\n{}", tree.summary());
    }

    // Classification end-to-end: Gini tree over integer-coded symptom data.
    #[test]
    fn test_disease_classification() {
        let mut tree = DecisionTree::<u8, f64>::new(
            TreeType::Classification,
            SplitCriterion::Gini,
            4, 2, 1, );

        let features = vec![
            vec![3, 1, 2, 1, 0, 0], vec![1, 3, 2, 0, 1, 3], vec![2, 0, 1, 3, 0, 0], vec![0, 3, 1, 0, 2, 2], vec![3, 2, 3, 2, 1, 0], vec![1, 3, 2, 0, 0, 3], vec![2, 0, 2, 3, 1, 0], vec![0, 2, 1, 0, 2, 2], vec![3, 1, 2, 1, 1, 0], vec![2, 3, 2, 0, 1, 2], vec![1, 0, 1, 3, 0, 0], vec![0, 3, 2, 0, 1, 3], ];

        let target = vec![1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4];

        tree.fit(&features, &target).unwrap();

        let test_features = vec![
            vec![3, 2, 2, 1, 1, 0], vec![1, 3, 2, 0, 1, 3], vec![2, 0, 1, 3, 0, 0], ];

        let predictions = tree.predict(&test_features).unwrap();

        assert_eq!(predictions[0], 1, "Should diagnose as Flu");
        assert_eq!(predictions[1], 2, "Should diagnose as COVID");
        assert_eq!(predictions[2], 3, "Should diagnose as Migraine");

        println!("Disease classification tree:\n{}", tree.summary());
    }

    // Regression with restrictive limits (depth 2, min_samples_split 5) on
    // two well-separated clusters; prediction is checked only relatively.
    #[test]
    fn test_system_failure_prediction() {
        let mut tree = DecisionTree::<i32, f64>::new(
            TreeType::Regression,
            SplitCriterion::Mse,
            2, 5, 2, );

        let features = vec![
            vec![30, 40, 0],
            vec![35, 45, 1],
            vec![40, 50, 0],
            vec![25, 35, 1],
            vec![30, 40, 0],
            vec![90, 95, 10],
            vec![85, 90, 8],
            vec![95, 98, 15],
            vec![90, 95, 12],
            vec![80, 85, 7],
        ];

        let target = vec![
            1000, 900, 950, 1100, 1050, 10, 15, 5, 8, 20, ];

        tree.fit(&features, &target).unwrap();

        println!("System failure tree summary:\n{}", tree.summary());

        if tree.is_trained() {
            println!("Tree structure:\n{}", tree.tree_structure());
        }

        if tree.is_trained() {
            let test_features = vec![
                vec![30, 40, 0], vec![90, 95, 10], ];

            let predictions = match tree.predict(&test_features) {
                Ok(preds) => {
                    println!("Successfully made predictions: {:?}", preds);
                    preds
                }
                Err(e) => {
                    println!("Error during prediction: {:?}", e);
                    return; }
            };

            if predictions.len() == 2 {
                assert!(
                    predictions[0] > predictions[1],
                    "Healthy system should have longer time to failure than failing system"
                );
            }
        } else {
            println!("Tree wasn't properly trained - skipping prediction tests");
        }
    }

    // Classification with the entropy criterion over log-derived features.
    #[test]
    fn test_security_incident_classification() {
        let mut tree = DecisionTree::<u8, f64>::new(
            TreeType::Classification,
            SplitCriterion::Entropy,
            5, 2, 1, );

        let features = vec![
            vec![1, 0, 0, 0, 0], vec![5, 1, 1, 1, 0], vec![15, 3, 2, 1, 1], vec![2, 0, 1, 0, 0], vec![8, 2, 1, 1, 0], vec![20, 4, 3, 1, 1], vec![1, 0, 0, 1, 0], vec![6, 1, 2, 1, 0], vec![25, 5, 3, 1, 1], vec![3, 0, 0, 0, 0], vec![7, 2, 1, 0, 0], vec![18, 3, 2, 1, 1], vec![0, 0, 0, 0, 0], vec![9, 2, 2, 1, 0], vec![22, 4, 3, 1, 1], ];

        let target = vec![0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2];

        tree.fit(&features, &target).unwrap();

        let test_features = vec![
            vec![2, 0, 0, 0, 0], vec![7, 1, 1, 1, 0], vec![17, 3, 2, 1, 1], ];

        let predictions = tree.predict(&test_features).unwrap();

        assert_eq!(predictions[0], 0, "Should classify as normal activity");
        assert_eq!(predictions[1], 1, "Should classify as suspicious activity");
        assert_eq!(predictions[2], 2, "Should classify as potential breach");

        println!(
            "Security incident classification tree:\n{}",
            tree.tree_structure()
        );
    }

    // Exercises a fully custom target type (Duration newtype) implementing
    // all required numeric traits, round-tripping through milliseconds.
    #[test]
    fn test_custom_type_performance_analysis() {
        #[derive(Clone, PartialEq, Eq, Hash, Debug, Copy)]
        struct ResponseTime(Duration);

        impl PartialOrd for ResponseTime {
            fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
                self.0.partial_cmp(&other.0)
            }
        }

        impl ToPrimitive for ResponseTime {
            fn to_i64(&self) -> Option<i64> {
                Some(self.0.as_millis() as i64)
            }

            fn to_u64(&self) -> Option<u64> {
                Some(self.0.as_millis() as u64)
            }

            fn to_f64(&self) -> Option<f64> {
                Some(self.0.as_millis() as f64)
            }
        }

        impl AsPrimitive<f64> for ResponseTime {
            fn as_(self) -> f64 {
                self.0.as_millis() as f64
            }
        }

        impl NumCast for ResponseTime {
            fn from<T: ToPrimitive>(n: T) -> Option<Self> {
                n.to_u64()
                    .map(|ms| ResponseTime(Duration::from_millis(ms as u64)))
            }
        }

        impl FromPrimitive for ResponseTime {
            fn from_i64(n: i64) -> Option<Self> {
                // Durations cannot be negative.
                if n >= 0 {
                    Some(ResponseTime(Duration::from_millis(n as u64)))
                } else {
                    None
                }
            }

            fn from_u64(n: u64) -> Option<Self> {
                Some(ResponseTime(Duration::from_millis(n)))
            }

            fn from_f64(n: f64) -> Option<Self> {
                if n >= 0.0 {
                    Some(ResponseTime(Duration::from_millis(n as u64)))
                } else {
                    None
                }
            }
        }

        impl AsPrimitive<ResponseTime> for f64 {
            fn as_(self) -> ResponseTime {
                ResponseTime(Duration::from_millis(self as u64))
            }
        }

        let mut tree = DecisionTree::<ResponseTime, f64>::new(
            TreeType::Regression,
            SplitCriterion::Mse,
            3, 2, 1, );

        let features = vec![
            vec![10, 20, 3, 5],
            vec![50, 40, 8, 2],
            vec![20, 30, 4, 4],
            vec![100, 60, 12, 0],
            vec![30, 35, 6, 3],
            vec![80, 50, 10, 1],
        ];

        let target = vec![
            ResponseTime(Duration::from_millis(100)),
            ResponseTime(Duration::from_millis(350)),
            ResponseTime(Duration::from_millis(150)),
            ResponseTime(Duration::from_millis(600)),
            ResponseTime(Duration::from_millis(200)),
            ResponseTime(Duration::from_millis(450)),
        ];

        tree.fit(&features, &target).unwrap();

        let test_features = vec![
            vec![15, 25, 3, 4], vec![90, 55, 11, 0], ];

        let predictions = tree.predict(&test_features).unwrap();

        assert!(
            predictions[0].0.as_millis() < 200,
            "Small request should have fast response time"
        );
        assert!(
            predictions[1].0.as_millis() > 400,
            "Large request should have slow response time"
        );

        println!("Response time prediction tree:\n{}", tree.summary());
    }

    // Error path: fitting on empty inputs must fail, not panic.
    #[test]
    fn test_empty_features() {
        let mut tree =
            DecisionTree::<i32, f64>::new(TreeType::Regression, SplitCriterion::Mse, 3, 2, 1);

        let empty_features: Vec<Vec<f64>> = vec![];
        let empty_target: Vec<i32> = vec![];

        let result = tree.fit(&empty_features, &empty_target);
        assert!(
            result.is_err(),
            "Fitting with empty features should return an error"
        );
    }

    // A single-class target is pure at the root, so the tree is one leaf.
    #[test]
    fn test_single_class_classification() {
        let mut tree =
            DecisionTree::<u8, f64>::new(TreeType::Classification, SplitCriterion::Gini, 3, 2, 1);

        let features = vec![
            vec![1, 2, 3],
            vec![4, 5, 6],
            vec![7, 8, 9],
            vec![10, 11, 12],
        ];

        let target = vec![1, 1, 1, 1];

        tree.fit(&features, &target).unwrap();

        let prediction = tree.predict(&vec![vec![2, 3, 4]]).unwrap();

        assert_eq!(prediction[0], 1);

        assert_eq!(tree.get_node_count(), 1);
        assert_eq!(tree.get_leaf_count(), 1);
    }

    // Predicting before fit must yield the NotFitted error variant.
    #[test]
    fn test_predict_not_fitted() {
        let tree =
            DecisionTree::<i32, f64>::new(TreeType::Regression, SplitCriterion::Mse, 3, 2, 1);
        let features = vec![vec![1.0, 2.0]];
        let result = tree.predict(&features);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), StatsError::NotFitted { .. }));
    }

    // Empty target with non-empty features is an EmptyData error.
    #[test]
    fn test_fit_target_empty() {
        let mut tree =
            DecisionTree::<i32, f64>::new(TreeType::Regression, SplitCriterion::Mse, 3, 2, 1);
        let features = vec![vec![1.0, 2.0]];
        let target: Vec<i32> = vec![];
        let result = tree.fit(&features, &target);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), StatsError::EmptyData { .. }));
    }

    // Row-count / target-length disagreement is a DimensionMismatch error.
    #[test]
    fn test_fit_length_mismatch() {
        let mut tree =
            DecisionTree::<i32, f64>::new(TreeType::Regression, SplitCriterion::Mse, 3, 2, 1);
        let features = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let target = vec![1]; let result = tree.fit(&features, &target);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::DimensionMismatch { .. }
        ));
    }

    // Ragged feature rows are an InvalidInput error.
    #[test]
    fn test_fit_inconsistent_feature_lengths() {
        let mut tree =
            DecisionTree::<i32, f64>::new(TreeType::Regression, SplitCriterion::Mse, 3, 2, 1);
        let features = vec![vec![1.0, 2.0], vec![3.0]]; let target = vec![1, 2];
        let result = tree.fit(&features, &target);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }
}