1use crate::error::{StatsError, StatsResult};
2use num_traits::cast::AsPrimitive;
3use num_traits::{Float, FromPrimitive, NumCast, ToPrimitive};
4use rayon::prelude::*;
5use std::cmp::Ordering;
6use std::collections::HashMap;
7use std::fmt::{self, Debug};
8use std::hash::Hash;
9
/// Kind of prediction task a decision tree is built for.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so the value can be
/// used as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TreeType {
    /// Leaves predict the mean of the training targets that reach them.
    Regression,
    /// Leaves predict the most frequent class among the training targets
    /// that reach them.
    Classification,
}
18
/// Impurity measure used to score candidate splits.
///
/// `Mse` and `Mae` are only evaluated for regression trees, `Gini` and
/// `Entropy` only for classification trees; any other pairing scores every
/// split as infinitely impure, so no split is ever chosen.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so the value can be
/// used as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SplitCriterion {
    /// Mean squared deviation from the subset mean.
    Mse,
    /// Mean absolute deviation from the subset mean.
    Mae,
    /// Gini impurity: one minus the sum of squared class probabilities.
    Gini,
    /// Shannon entropy of the class distribution (natural logarithm).
    Entropy,
}
31
/// A single node of the tree, stored in a flat vector and linked to its
/// children by vector index rather than by pointer.
#[derive(Debug, Clone)]
struct Node<T, F>
where
    T: Clone + PartialOrd + Debug + ToPrimitive,
    F: Float,
{
    /// Index of the feature tested by a split node; `None` marks a leaf.
    feature_idx: Option<usize>,
    /// Split threshold: samples whose feature value is not greater than this
    /// go to `left`, the others to `right`. `None` for leaves.
    threshold: Option<T>,
    /// Prediction stored in a leaf; `None` for split nodes.
    value: Option<T>,
    /// Per-class sample counts; only populated for classification leaves.
    class_distribution: Option<HashMap<T, usize>>,
    /// Index of the left child in the owning node vector, once linked.
    left: Option<usize>,
    /// Index of the right child in the owning node vector, once linked.
    right: Option<usize>,
    /// Ties the node to the floating-point type `F` without storing a value.
    _phantom: std::marker::PhantomData<F>,
}
54
55impl<T, F> Node<T, F>
56where
57 T: Clone + PartialOrd + Eq + Hash + Debug + ToPrimitive,
58 F: Float,
59{
60 fn new_split(feature_idx: usize, threshold: T) -> Self {
62 Node {
63 feature_idx: Some(feature_idx),
64 threshold: Some(threshold),
65 value: None,
66 class_distribution: None,
67 left: None,
68 right: None,
69 _phantom: std::marker::PhantomData,
70 }
71 }
72
73 fn new_leaf_regression(value: T) -> Self {
75 Node {
76 feature_idx: None,
77 threshold: None,
78 value: Some(value),
79 class_distribution: None,
80 left: None,
81 right: None,
82 _phantom: std::marker::PhantomData,
83 }
84 }
85
86 fn new_leaf_classification(value: T, class_distribution: HashMap<T, usize>) -> Self {
88 Node {
89 feature_idx: None,
90 threshold: None,
91 value: Some(value),
92 class_distribution: Some(class_distribution),
93 left: None,
94 right: None,
95 _phantom: std::marker::PhantomData,
96 }
97 }
98
99 fn is_leaf(&self) -> bool {
101 self.feature_idx.is_none()
102 }
103}
104
/// Decision tree for regression or classification.
///
/// `T` is the target (prediction) type and `F` the floating-point type used
/// for impurity arithmetic. Nodes live in a flat vector; index 0 is the root
/// once the tree has been fitted.
#[derive(Debug, Clone)]
pub struct DecisionTree<T, F>
where
    T: Clone + PartialOrd + Debug + ToPrimitive,
    F: Float,
{
    /// Whether leaves hold a mean (regression) or a majority class.
    tree_type: TreeType,
    /// Impurity measure used when scoring candidate splits.
    criterion: SplitCriterion,
    /// Nodes at this depth or deeper are always turned into leaves.
    max_depth: usize,
    /// Subsets with fewer samples than this are never split.
    min_samples_split: usize,
    /// Candidate splits leaving fewer samples than this on either side are
    /// rejected.
    min_samples_leaf: usize,
    /// Flat node storage; empty until `fit` succeeds.
    nodes: Vec<Node<T, F>>,
}
129
130impl<T, F> DecisionTree<T, F>
131where
132 T: Clone + PartialOrd + Eq + Hash + Send + Sync + NumCast + ToPrimitive + Debug,
133 F: Float + Send + Sync + NumCast + FromPrimitive + 'static,
134 f64: AsPrimitive<F>,
135 usize: AsPrimitive<F>,
136 T: AsPrimitive<F>,
137 F: AsPrimitive<T>,
138{
139 pub fn new(
141 tree_type: TreeType,
142 criterion: SplitCriterion,
143 max_depth: usize,
144 min_samples_split: usize,
145 min_samples_leaf: usize,
146 ) -> Self {
147 Self {
148 tree_type,
149 criterion,
150 max_depth,
151 min_samples_split,
152 min_samples_leaf,
153 nodes: Vec::new(),
154 }
155 }
156
157 pub fn fit<D>(&mut self, features: &[Vec<D>], target: &[T]) -> StatsResult<()>
165 where
166 D: Clone + PartialOrd + NumCast + ToPrimitive + AsPrimitive<F> + Send + Sync,
167 T: FromPrimitive,
168 {
169 if features.is_empty() {
170 return Err(StatsError::empty_data("Features cannot be empty"));
171 }
172 if target.is_empty() {
173 return Err(StatsError::empty_data("Target cannot be empty"));
174 }
175 if features.len() != target.len() {
176 return Err(StatsError::dimension_mismatch(format!(
177 "Features and target must have the same length (got {} and {})",
178 features.len(),
179 target.len()
180 )));
181 }
182
183 let n_features = features[0].len();
185 for (i, feature_vec) in features.iter().enumerate() {
186 if feature_vec.len() != n_features {
187 return Err(StatsError::invalid_input(format!(
188 "All feature vectors must have the same length (vector {} has {} features, expected {})",
189 i,
190 feature_vec.len(),
191 n_features
192 )));
193 }
194 }
195
196 self.nodes = Vec::new();
198
199 let indices: Vec<usize> = (0..features.len()).collect();
201
202 self.build_tree(features, target, &indices, 0)?;
204 Ok(())
205 }
206
207 fn build_tree<D>(
209 &mut self,
210 features: &[Vec<D>],
211 target: &[T],
212 indices: &[usize],
213 depth: usize,
214 ) -> StatsResult<usize>
215 where
216 D: Clone + PartialOrd + NumCast + ToPrimitive + AsPrimitive<F> + Send + Sync,
217 {
218 if depth >= self.max_depth
220 || indices.len() < self.min_samples_split
221 || self.is_pure(target, indices)
222 {
223 let node_idx = self.nodes.len();
224 if self.tree_type == TreeType::Regression {
225 let value = self.calculate_mean(target, indices)?;
227 self.nodes.push(Node::new_leaf_regression(value));
228 } else {
229 let (value, class_counts) = self.calculate_class_distribution(target, indices);
231 self.nodes
232 .push(Node::new_leaf_classification(value, class_counts));
233 }
234 return Ok(node_idx);
235 }
236
237 let (feature_idx, threshold, left_indices, right_indices) =
239 self.find_best_split(features, target, indices);
240
241 if left_indices.is_empty() || right_indices.is_empty() {
243 let node_idx = self.nodes.len();
244 if self.tree_type == TreeType::Regression {
245 let value = self.calculate_mean(target, indices)?;
246 self.nodes.push(Node::new_leaf_regression(value));
247 } else {
248 let (value, class_counts) = self.calculate_class_distribution(target, indices);
249 self.nodes
250 .push(Node::new_leaf_classification(value, class_counts));
251 }
252 return Ok(node_idx);
253 }
254
255 let node_idx = self.nodes.len();
257
258 let t_threshold = NumCast::from(threshold).ok_or_else(|| {
260 StatsError::conversion_error(
261 "Failed to convert threshold to the feature type".to_string(),
262 )
263 })?;
264
265 self.nodes.push(Node::new_split(feature_idx, t_threshold));
266
267 let left_idx = self.build_tree(features, target, &left_indices, depth + 1)?;
269 let right_idx = self.build_tree(features, target, &right_indices, depth + 1)?;
270
271 self.nodes[node_idx].left = Some(left_idx);
273 self.nodes[node_idx].right = Some(right_idx);
274
275 Ok(node_idx)
276 }
277
278 fn find_best_split<D>(
288 &self,
289 features: &[Vec<D>],
290 target: &[T],
291 indices: &[usize],
292 ) -> (usize, D, Vec<usize>, Vec<usize>)
293 where
294 D: Clone + Copy + PartialOrd + NumCast + ToPrimitive + AsPrimitive<F> + Send + Sync,
295 {
296 let n_features = features[0].len();
297
298 let mut best_impurity = F::infinity();
299 let mut best_feature = 0;
300 let mut best_threshold = features[indices[0]][0];
301 let mut best_left: Vec<usize> = Vec::new();
302 let mut best_right: Vec<usize> = Vec::new();
303
304 let results: Vec<_> = (0..n_features)
308 .into_par_iter()
309 .filter_map(|feature_idx| {
310 let mut sorted_indices: Vec<usize> = indices.to_vec();
313 sorted_indices.sort_by(|&a, &b| {
314 features[a][feature_idx]
315 .partial_cmp(&features[b][feature_idx])
316 .unwrap_or(Ordering::Equal)
317 });
318 let n = sorted_indices.len();
319 if n < 2 {
320 return None;
321 }
322
323 let two = F::from(2.0)?;
324
325 let mut feature_best_impurity = F::infinity();
330 let mut feature_best_split_pos: Option<usize> = None;
331 let mut feature_best_threshold = features[sorted_indices[0]][feature_idx];
332
333 for split_pos in 1..n {
334 let i_prev = sorted_indices[split_pos - 1];
335 let i_curr = sorted_indices[split_pos];
336 let v_prev = features[i_prev][feature_idx];
337 let v_curr = features[i_curr][feature_idx];
338 if v_prev.partial_cmp(&v_curr).unwrap_or(Ordering::Equal) == Ordering::Equal {
339 continue; }
341 let left = &sorted_indices[..split_pos];
342 let right = &sorted_indices[split_pos..];
343 if left.len() < self.min_samples_leaf || right.len() < self.min_samples_leaf {
344 continue;
345 }
346 let impurity = self.calculate_split_impurity(target, left, right);
347 if impurity < feature_best_impurity {
348 let v1: F = v_prev.as_();
349 let v2: F = v_curr.as_();
350 let mid = (v1 + v2) / two;
351 let threshold: D = match NumCast::from(mid) {
352 Some(t) => t,
353 None => continue,
354 };
355 feature_best_impurity = impurity;
356 feature_best_split_pos = Some(split_pos);
357 feature_best_threshold = threshold;
358 }
359 }
360
361 feature_best_split_pos.map(|split_pos| {
363 let (left, right) = sorted_indices.split_at(split_pos);
364 (
365 feature_idx,
366 feature_best_impurity,
367 feature_best_threshold,
368 left.to_vec(),
369 right.to_vec(),
370 )
371 })
372 })
373 .collect();
374
375 for (feature_idx, impurity, threshold, left, right) in results {
376 if impurity < best_impurity {
377 best_impurity = impurity;
378 best_feature = feature_idx;
379 best_threshold = threshold;
380 best_left = left;
381 best_right = right;
382 }
383 }
384
385 (best_feature, best_threshold, best_left, best_right)
386 }
387
388 fn calculate_split_impurity(
390 &self,
391 target: &[T],
392 left_indices: &[usize],
393 right_indices: &[usize],
394 ) -> F {
395 let n_left = left_indices.len();
396 let n_right = right_indices.len();
397 let n_total = n_left + n_right;
398
399 if n_left == 0 || n_right == 0 {
400 return F::infinity();
401 }
402
403 let left_weight: F = (n_left as f64).as_();
404 let right_weight: F = (n_right as f64).as_();
405 let total: F = (n_total as f64).as_();
406
407 let left_ratio = left_weight / total;
408 let right_ratio = right_weight / total;
409
410 match (self.tree_type, self.criterion) {
411 (TreeType::Regression, SplitCriterion::Mse) => {
412 let left_mse = self.calculate_mse(target, left_indices);
414 let right_mse = self.calculate_mse(target, right_indices);
415 left_ratio * left_mse + right_ratio * right_mse
416 }
417 (TreeType::Regression, SplitCriterion::Mae) => {
418 let left_mae = self.calculate_mae(target, left_indices);
420 let right_mae = self.calculate_mae(target, right_indices);
421 left_ratio * left_mae + right_ratio * right_mae
422 }
423 (TreeType::Classification, SplitCriterion::Gini) => {
424 let left_gini = self.calculate_gini(target, left_indices);
426 let right_gini = self.calculate_gini(target, right_indices);
427 left_ratio * left_gini + right_ratio * right_gini
428 }
429 (TreeType::Classification, SplitCriterion::Entropy) => {
430 let left_entropy = self.calculate_entropy(target, left_indices);
432 let right_entropy = self.calculate_entropy(target, right_indices);
433 left_ratio * left_entropy + right_ratio * right_entropy
434 }
435 _ => {
436 F::infinity()
439 }
440 }
441 }
442
443 fn calculate_mse(&self, target: &[T], indices: &[usize]) -> F {
445 if indices.is_empty() {
446 return F::zero();
447 }
448
449 let mean = match self.calculate_mean(target, indices) {
451 Ok(m) => m,
452 Err(_) => return F::infinity(),
453 };
454 let mean_f: F = mean.as_();
455
456 let sum_squared_error: F = indices
457 .iter()
458 .map(|&idx| {
459 let error: F = target[idx].as_() - mean_f;
460 error * error
461 })
462 .fold(F::zero(), |a, b| a + b);
463
464 let count = F::from(indices.len()).unwrap_or(F::one());
465 sum_squared_error / count
466 }
467
468 fn calculate_mae(&self, target: &[T], indices: &[usize]) -> F {
470 if indices.is_empty() {
471 return F::zero();
472 }
473
474 let mean = match self.calculate_mean(target, indices) {
476 Ok(m) => m,
477 Err(_) => return F::infinity(),
478 };
479 let mean_f: F = mean.as_();
480
481 let sum_absolute_error: F = indices
482 .iter()
483 .map(|&idx| {
484 let error: F = target[idx].as_() - mean_f;
485 error.abs()
486 })
487 .fold(F::zero(), |a, b| a + b);
488
489 let count = F::from(indices.len()).unwrap_or(F::one());
490 sum_absolute_error / count
491 }
492
493 fn calculate_gini(&self, target: &[T], indices: &[usize]) -> F {
495 if indices.is_empty() {
496 return F::zero();
497 }
498
499 let (_, class_counts) = self.calculate_class_distribution(target, indices);
500 let n_samples = indices.len();
501
502 F::one()
503 - class_counts
504 .values()
505 .map(|&count| {
506 let probability: F = (count as f64 / n_samples as f64).as_();
507 probability * probability
508 })
509 .fold(F::zero(), |a, b| a + b)
510 }
511
512 fn calculate_entropy(&self, target: &[T], indices: &[usize]) -> F {
514 if indices.is_empty() {
515 return F::zero();
516 }
517
518 let (_, class_counts) = self.calculate_class_distribution(target, indices);
519 let n_samples = indices.len();
520
521 -class_counts
522 .values()
523 .map(|&count| {
524 let probability: F = (count as f64 / n_samples as f64).as_();
525 if probability > F::zero() {
526 probability * probability.ln()
527 } else {
528 F::zero()
529 }
530 })
531 .fold(F::zero(), |a, b| a + b)
532 }
533
534 fn calculate_mean(&self, target: &[T], indices: &[usize]) -> StatsResult<T> {
536 if indices.is_empty() {
537 return Err(StatsError::empty_data(
538 "Cannot calculate mean for empty indices",
539 ));
540 }
541
542 let sum: F = indices
545 .iter()
546 .map(|&idx| target[idx].as_())
547 .fold(F::zero(), |a, b| a + b);
548
549 let count: F = F::from(indices.len()).ok_or_else(|| {
550 StatsError::conversion_error(format!("Failed to convert {} to type F", indices.len()))
551 })?;
552 let mean_f = sum / count;
553
554 NumCast::from(mean_f).ok_or_else(|| {
556 StatsError::conversion_error("Failed to convert mean to the target type".to_string())
557 })
558 }
559
560 fn calculate_class_distribution(
562 &self,
563 target: &[T],
564 indices: &[usize],
565 ) -> (T, HashMap<T, usize>) {
566 let mut class_counts: HashMap<T, usize> = HashMap::new();
567
568 for &idx in indices {
569 let class = target[idx];
570 *class_counts.entry(class).or_insert(0) += 1;
571 }
572
573 let (majority_class, _) = class_counts
575 .iter()
576 .max_by_key(|&(_, count)| *count)
577 .map(|(&class, count)| (class, *count))
578 .unwrap_or_else(|| {
579 (NumCast::from(0.0).unwrap(), 0)
581 });
582
583 (majority_class, class_counts)
584 }
585
586 fn is_pure(&self, target: &[T], indices: &[usize]) -> bool {
588 if indices.is_empty() {
589 return true;
590 }
591
592 let first_value = &target[indices[0]];
593 indices.iter().all(|&idx| {
594 target[idx]
595 .partial_cmp(first_value)
596 .unwrap_or(Ordering::Equal)
597 == Ordering::Equal
598 })
599 }
600
601 pub fn predict<D>(&self, features: &[Vec<D>]) -> StatsResult<Vec<T>>
607 where
608 D: Clone + PartialOrd + NumCast,
609 T: NumCast,
610 {
611 features
612 .iter()
613 .map(|feature_vec| self.predict_single(feature_vec))
614 .collect()
615 }
616
617 fn predict_single<D>(&self, features: &[D]) -> StatsResult<T>
619 where
620 D: Clone + PartialOrd + NumCast,
621 T: NumCast,
622 {
623 if self.nodes.is_empty() {
624 return Err(StatsError::not_fitted(
625 "Decision tree has not been trained yet",
626 ));
627 }
628
629 let mut node_idx = 0;
630 loop {
631 let node = &self.nodes[node_idx];
632
633 if node.is_leaf() {
634 return node
635 .value
636 .ok_or_else(|| StatsError::invalid_input("Leaf node missing value"));
637 }
638
639 let feature_idx = node
640 .feature_idx
641 .ok_or_else(|| StatsError::invalid_input("Internal node missing feature index"))?;
642 let threshold = node
643 .threshold
644 .as_ref()
645 .ok_or_else(|| StatsError::invalid_input("Internal node missing threshold"))?;
646
647 if feature_idx >= features.len() {
648 return Err(StatsError::index_out_of_bounds(format!(
649 "Feature index {} is out of bounds (features has {} elements)",
650 feature_idx,
651 features.len()
652 )));
653 }
654
655 let feature_val = &features[feature_idx];
656
657 let threshold_d = D::from(*threshold).ok_or_else(|| {
660 StatsError::conversion_error(format!(
661 "Failed to convert threshold {:?} to feature type",
662 threshold
663 ))
664 })?;
665
666 let comparison = feature_val
667 .partial_cmp(&threshold_d)
668 .unwrap_or(Ordering::Equal);
669
670 if comparison != Ordering::Greater {
671 node_idx = node
672 .left
673 .ok_or_else(|| StatsError::invalid_input("Internal node missing left child"))?;
674 } else {
675 node_idx = node.right.ok_or_else(|| {
676 StatsError::invalid_input("Internal node missing right child")
677 })?;
678 }
679 }
680 }
681
682 pub fn feature_importances(&self) -> Vec<F> {
684 if self.nodes.is_empty() {
685 return Vec::new();
686 }
687
688 let n_features = self
690 .nodes
691 .iter()
692 .find(|node| !node.is_leaf())
693 .and_then(|node| node.feature_idx)
694 .map(|idx| idx + 1)
695 .unwrap_or(0);
696
697 if n_features == 0 {
698 return Vec::new();
699 }
700
701 let mut feature_counts = vec![0; n_features];
703 for node in &self.nodes {
704 if let Some(feature_idx) = node.feature_idx {
705 feature_counts[feature_idx] += 1;
706 }
707 }
708
709 let total_count: f64 = feature_counts.iter().sum::<usize>() as f64;
711 if total_count > 0.0 {
712 feature_counts
713 .iter()
714 .map(|&count| (count as f64 / total_count).as_())
715 .collect()
716 } else {
717 vec![F::zero(); n_features]
718 }
719 }
720
721 pub fn tree_structure(&self) -> String {
723 if self.nodes.is_empty() {
724 return "Empty tree".to_string();
725 }
726
727 let mut result = String::new();
728 self.print_node(0, 0, &mut result);
729 result
730 }
731
732 fn print_node(&self, node_idx: usize, depth: usize, result: &mut String) {
734 let node = &self.nodes[node_idx];
735 let indent = " ".repeat(depth);
736
737 if node.is_leaf() {
738 if self.tree_type == TreeType::Classification {
739 let class_distribution = node.class_distribution.as_ref().unwrap();
740 let classes: Vec<String> = class_distribution
741 .iter()
742 .map(|(class, count)| format!("{:?}: {}", class, count))
743 .collect();
744
745 result.push_str(&format!(
746 "{}Leaf: prediction = {:?}, distribution = {{{}}}\n",
747 indent,
748 node.value.as_ref().unwrap(),
749 classes.join(", ")
750 ));
751 } else {
752 result.push_str(&format!(
753 "{}Leaf: prediction = {:?}\n",
754 indent,
755 node.value.as_ref().unwrap()
756 ));
757 }
758 } else {
759 result.push_str(&format!(
760 "{}Node: feature {} <= {:?}\n",
761 indent,
762 node.feature_idx.unwrap(),
763 node.threshold.as_ref().unwrap()
764 ));
765
766 if let Some(left_idx) = node.left {
767 self.print_node(left_idx, depth + 1, result);
768 }
769
770 if let Some(right_idx) = node.right {
771 self.print_node(right_idx, depth + 1, result);
772 }
773 }
774 }
775}
776
777impl<T, F> fmt::Display for DecisionTree<T, F>
778where
779 T: Clone + PartialOrd + Eq + Hash + Debug + ToPrimitive,
780 F: Float,
781{
782 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
783 write!(
784 f,
785 "DecisionTree({:?}, {:?}, max_depth={}, nodes={})",
786 self.tree_type,
787 self.criterion,
788 self.max_depth,
789 self.nodes.len()
790 )
791 }
792}
793
794impl<T, F> DecisionTree<T, F>
796where
797 T: Clone + PartialOrd + Eq + Hash + Send + Sync + NumCast + ToPrimitive + Debug,
798 F: Float + Send + Sync + NumCast + FromPrimitive + 'static,
799 f64: AsPrimitive<F>,
800 usize: AsPrimitive<F>,
801 T: AsPrimitive<F>,
802 F: AsPrimitive<T>,
803{
804 pub fn get_max_depth(&self) -> usize {
806 self.max_depth
807 }
808
809 pub fn get_node_count(&self) -> usize {
811 self.nodes.len()
812 }
813
814 pub fn is_trained(&self) -> bool {
816 !self.nodes.is_empty()
817 }
818
819 pub fn get_leaf_count(&self) -> usize {
821 self.nodes.iter().filter(|node| node.is_leaf()).count()
822 }
823
824 pub fn calculate_depth(&self) -> usize {
826 if self.nodes.is_empty() {
827 return 0;
828 }
829
830 fn depth_helper<T, F>(nodes: &[Node<T, F>], node_idx: usize, current_depth: usize) -> usize
832 where
833 T: Clone + PartialOrd + Eq + Hash + Debug + ToPrimitive,
834 F: Float,
835 {
836 let node = &nodes[node_idx];
837
838 if node.is_leaf() {
839 return current_depth;
840 }
841
842 let left_depth = depth_helper(nodes, node.left.unwrap(), current_depth + 1);
843 let right_depth = depth_helper(nodes, node.right.unwrap(), current_depth + 1);
844
845 std::cmp::max(left_depth, right_depth)
846 }
847
848 depth_helper(&self.nodes, 0, 0)
849 }
850
851 pub fn summary(&self) -> String {
853 if !self.is_trained() {
854 return "Decision tree is not trained yet".to_string();
855 }
856
857 let leaf_count = self.get_leaf_count();
858 let node_count = self.get_node_count();
859 let actual_depth = self.calculate_depth();
860
861 format!(
862 "Decision Tree Summary:\n\
863 - Type: {:?}\n\
864 - Criterion: {:?}\n\
865 - Max depth: {}\n\
866 - Actual depth: {}\n\
867 - Total nodes: {}\n\
868 - Leaf nodes: {}\n\
869 - Internal nodes: {}",
870 self.tree_type,
871 self.criterion,
872 self.max_depth,
873 actual_depth,
874 node_count,
875 leaf_count,
876 node_count - leaf_count
877 )
878 }
879}
880
881#[cfg(test)]
882mod tests {
883 use super::*;
884 use std::time::Duration;
885
886 #[derive(Clone, Debug, PartialOrd, Copy)]
888 struct TestFloat(f64);
889
890 impl PartialEq for TestFloat {
891 fn eq(&self, other: &Self) -> bool {
892 (self.0 - other.0).abs() < f64::EPSILON
893 }
894 }
895
896 impl Eq for TestFloat {}
897
898 impl std::hash::Hash for TestFloat {
899 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
900 let bits = self.0.to_bits();
901 bits.hash(state);
902 }
903 }
904
905 impl ToPrimitive for TestFloat {
906 fn to_i64(&self) -> Option<i64> {
907 self.0.to_i64()
908 }
909
910 fn to_u64(&self) -> Option<u64> {
911 self.0.to_u64()
912 }
913
914 fn to_f64(&self) -> Option<f64> {
915 Some(self.0)
916 }
917 }
918
919 impl NumCast for TestFloat {
920 fn from<T: ToPrimitive>(n: T) -> Option<Self> {
921 n.to_f64().map(TestFloat)
922 }
923 }
924
925 impl FromPrimitive for TestFloat {
926 fn from_i64(n: i64) -> Option<Self> {
927 Some(TestFloat(n as f64))
928 }
929
930 fn from_u64(n: u64) -> Option<Self> {
931 Some(TestFloat(n as f64))
932 }
933
934 fn from_f64(n: f64) -> Option<Self> {
935 Some(TestFloat(n))
936 }
937 }
938
939 impl AsPrimitive<f64> for TestFloat {
940 fn as_(self) -> f64 {
941 self.0
942 }
943 }
944
945 impl AsPrimitive<TestFloat> for f64 {
946 fn as_(self) -> TestFloat {
947 TestFloat(self)
948 }
949 }
950
951 #[test]
953 fn test_diabetes_prediction() {
954 let mut tree = DecisionTree::<TestFloat, f64>::new(
956 TreeType::Regression,
957 SplitCriterion::Mse,
958 5, 2, 1, );
962
963 let features = vec![
965 vec![45.0, 22.5, 95.0, 120.0, 0.0], vec![50.0, 26.0, 105.0, 140.0, 1.0], vec![35.0, 23.0, 90.0, 115.0, 0.0], vec![55.0, 30.0, 140.0, 150.0, 1.0], vec![60.0, 29.5, 130.0, 145.0, 1.0], vec![40.0, 24.0, 85.0, 125.0, 0.0], vec![48.0, 27.0, 110.0, 135.0, 1.0], vec![65.0, 31.0, 150.0, 155.0, 1.0], vec![42.0, 25.0, 100.0, 130.0, 0.0], vec![58.0, 32.0, 145.0, 160.0, 1.0], ];
976
977 let target = vec![
979 TestFloat(2.0),
980 TestFloat(5.5),
981 TestFloat(1.5),
982 TestFloat(8.0),
983 TestFloat(6.5),
984 TestFloat(2.0),
985 TestFloat(5.0),
986 TestFloat(8.5),
987 TestFloat(3.0),
988 TestFloat(9.0),
989 ];
990
991 tree.fit(&features, &target).unwrap();
993
994 let test_features = vec![
996 vec![45.0, 23.0, 90.0, 120.0, 0.0], vec![62.0, 31.0, 145.0, 155.0, 1.0], ];
999
1000 let predictions = tree.predict(&test_features).unwrap();
1001
1002 assert!(
1004 predictions[0].0 < 5.0,
1005 "Young healthy patient should have low risk score"
1006 );
1007 assert!(
1008 predictions[1].0 > 5.0,
1009 "Older patient with high metrics should have high risk score"
1010 );
1011
1012 assert!(tree.is_trained());
1014 assert!(tree.calculate_depth() <= tree.get_max_depth());
1015 assert!(tree.get_leaf_count() > 0);
1016
1017 println!("Diabetes prediction tree:\n{}", tree.summary());
1019 }
1020
1021 #[test]
1023 fn test_disease_classification() {
1024 let mut tree = DecisionTree::<u8, f64>::new(
1026 TreeType::Classification,
1027 SplitCriterion::Gini,
1028 4, 2, 1, );
1032
1033 let features = vec![
1036 vec![3, 1, 2, 1, 0, 0], vec![1, 3, 2, 0, 1, 3], vec![2, 0, 1, 3, 0, 0], vec![0, 3, 1, 0, 2, 2], vec![3, 2, 3, 2, 1, 0], vec![1, 3, 2, 0, 0, 3], vec![2, 0, 2, 3, 1, 0], vec![0, 2, 1, 0, 2, 2], vec![3, 1, 2, 1, 1, 0], vec![2, 3, 2, 0, 1, 2], vec![1, 0, 1, 3, 0, 0], vec![0, 3, 2, 0, 1, 3], ];
1049
1050 let target = vec![1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4];
1052
1053 tree.fit(&features, &target).unwrap();
1055
1056 let test_features = vec![
1058 vec![3, 2, 2, 1, 1, 0], vec![1, 3, 2, 0, 1, 3], vec![2, 0, 1, 3, 0, 0], ];
1062
1063 let predictions = tree.predict(&test_features).unwrap();
1064
1065 assert_eq!(predictions[0], 1, "Should diagnose as Flu");
1067 assert_eq!(predictions[1], 2, "Should diagnose as COVID");
1068 assert_eq!(predictions[2], 3, "Should diagnose as Migraine");
1069
1070 println!("Disease classification tree:\n{}", tree.summary());
1072 }
1073
1074 #[test]
1075 fn test_system_failure_prediction() {
1076 let mut tree = DecisionTree::<i32, f64>::new(
1081 TreeType::Regression,
1082 SplitCriterion::Mse,
1083 2, 5, 2, );
1087
1088 let features = vec![
1091 vec![30, 40, 0],
1093 vec![35, 45, 1],
1094 vec![40, 50, 0],
1095 vec![25, 35, 1],
1096 vec![30, 40, 0],
1097 vec![90, 95, 10],
1099 vec![85, 90, 8],
1100 vec![95, 98, 15],
1101 vec![90, 95, 12],
1102 vec![80, 85, 7],
1103 ];
1104
1105 let target = vec![
1107 1000, 900, 950, 1100, 1050, 10, 15, 5, 8, 20, ];
1110
1111 tree.fit(&features, &target).unwrap();
1113
1114 println!("System failure tree summary:\n{}", tree.summary());
1116
1117 if tree.is_trained() {
1119 println!("Tree structure:\n{}", tree.tree_structure());
1120 }
1121
1122 if tree.is_trained() {
1124 let test_features = vec![
1126 vec![30, 40, 0], vec![90, 95, 10], ];
1129
1130 let predictions = match tree.predict(&test_features) {
1132 Ok(preds) => {
1133 println!("Successfully made predictions: {:?}", preds);
1134 preds
1135 }
1136 Err(e) => {
1137 println!("Error during prediction: {:?}", e);
1138 return; }
1140 };
1141
1142 if predictions.len() == 2 {
1144 assert!(
1145 predictions[0] > predictions[1],
1146 "Healthy system should have longer time to failure than failing system"
1147 );
1148 }
1149 } else {
1150 println!("Tree wasn't properly trained - skipping prediction tests");
1151 }
1152 }
1153
1154 #[test]
1156 fn test_security_incident_classification() {
1157 let mut tree = DecisionTree::<u8, f64>::new(
1159 TreeType::Classification,
1160 SplitCriterion::Entropy,
1161 5, 2, 1, );
1165
1166 let features = vec![
1168 vec![1, 0, 0, 0, 0], vec![5, 1, 1, 1, 0], vec![15, 3, 2, 1, 1], vec![2, 0, 1, 0, 0], vec![8, 2, 1, 1, 0], vec![20, 4, 3, 1, 1], vec![1, 0, 0, 1, 0], vec![6, 1, 2, 1, 0], vec![25, 5, 3, 1, 1], vec![3, 0, 0, 0, 0], vec![7, 2, 1, 0, 0], vec![18, 3, 2, 1, 1], vec![0, 0, 0, 0, 0], vec![9, 2, 2, 1, 0], vec![22, 4, 3, 1, 1], ];
1184
1185 let target = vec![0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2];
1187
1188 tree.fit(&features, &target).unwrap();
1190
1191 let test_features = vec![
1193 vec![2, 0, 0, 0, 0], vec![7, 1, 1, 1, 0], vec![17, 3, 2, 1, 1], ];
1197
1198 let predictions = tree.predict(&test_features).unwrap();
1199
1200 assert_eq!(predictions[0], 0, "Should classify as normal activity");
1202 assert_eq!(predictions[1], 1, "Should classify as suspicious activity");
1203 assert_eq!(predictions[2], 2, "Should classify as potential breach");
1204
1205 println!(
1207 "Security incident classification tree:\n{}",
1208 tree.tree_structure()
1209 );
1210 }
1211
1212 #[test]
1214 fn test_custom_type_performance_analysis() {
1215 #[derive(Clone, PartialEq, Eq, Hash, Debug, Copy)]
1217 struct ResponseTime(Duration);
1218
1219 impl PartialOrd for ResponseTime {
1220 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
1221 self.0.partial_cmp(&other.0)
1222 }
1223 }
1224
1225 impl ToPrimitive for ResponseTime {
1226 fn to_i64(&self) -> Option<i64> {
1227 Some(self.0.as_millis() as i64)
1228 }
1229
1230 fn to_u64(&self) -> Option<u64> {
1231 Some(self.0.as_millis() as u64)
1232 }
1233
1234 fn to_f64(&self) -> Option<f64> {
1235 Some(self.0.as_millis() as f64)
1236 }
1237 }
1238
1239 impl AsPrimitive<f64> for ResponseTime {
1240 fn as_(self) -> f64 {
1241 self.0.as_millis() as f64
1242 }
1243 }
1244
1245 impl NumCast for ResponseTime {
1246 fn from<T: ToPrimitive>(n: T) -> Option<Self> {
1247 n.to_u64()
1248 .map(|ms| ResponseTime(Duration::from_millis(ms as u64)))
1249 }
1250 }
1251
1252 impl FromPrimitive for ResponseTime {
1253 fn from_i64(n: i64) -> Option<Self> {
1254 if n >= 0 {
1255 Some(ResponseTime(Duration::from_millis(n as u64)))
1256 } else {
1257 None
1258 }
1259 }
1260
1261 fn from_u64(n: u64) -> Option<Self> {
1262 Some(ResponseTime(Duration::from_millis(n)))
1263 }
1264
1265 fn from_f64(n: f64) -> Option<Self> {
1266 if n >= 0.0 {
1267 Some(ResponseTime(Duration::from_millis(n as u64)))
1268 } else {
1269 None
1270 }
1271 }
1272 }
1273
1274 impl AsPrimitive<ResponseTime> for f64 {
1276 fn as_(self) -> ResponseTime {
1277 ResponseTime(Duration::from_millis(self as u64))
1278 }
1279 }
1280
1281 let mut tree = DecisionTree::<ResponseTime, f64>::new(
1283 TreeType::Regression,
1284 SplitCriterion::Mse,
1285 3, 2, 1, );
1289
1290 let features = vec![
1292 vec![10, 20, 3, 5],
1293 vec![50, 40, 8, 2],
1294 vec![20, 30, 4, 4],
1295 vec![100, 60, 12, 0],
1296 vec![30, 35, 6, 3],
1297 vec![80, 50, 10, 1],
1298 ];
1299
1300 let target = vec![
1302 ResponseTime(Duration::from_millis(100)),
1303 ResponseTime(Duration::from_millis(350)),
1304 ResponseTime(Duration::from_millis(150)),
1305 ResponseTime(Duration::from_millis(600)),
1306 ResponseTime(Duration::from_millis(200)),
1307 ResponseTime(Duration::from_millis(450)),
1308 ];
1309
1310 tree.fit(&features, &target).unwrap();
1312
1313 let test_features = vec![
1315 vec![15, 25, 3, 4], vec![90, 55, 11, 0], ];
1318
1319 let predictions = tree.predict(&test_features).unwrap();
1320
1321 assert!(
1323 predictions[0].0.as_millis() < 200,
1324 "Small request should have fast response time"
1325 );
1326 assert!(
1327 predictions[1].0.as_millis() > 400,
1328 "Large request should have slow response time"
1329 );
1330
1331 println!("Response time prediction tree:\n{}", tree.summary());
1333 }
1334
1335 #[test]
1337 fn test_empty_features() {
1338 let mut tree =
1339 DecisionTree::<i32, f64>::new(TreeType::Regression, SplitCriterion::Mse, 3, 2, 1);
1340
1341 let empty_features: Vec<Vec<f64>> = vec![];
1343 let empty_target: Vec<i32> = vec![];
1344
1345 let result = tree.fit(&empty_features, &empty_target);
1346 assert!(
1347 result.is_err(),
1348 "Fitting with empty features should return an error"
1349 );
1350 }
1351
1352 #[test]
1354 fn test_single_class_classification() {
1355 let mut tree =
1356 DecisionTree::<u8, f64>::new(TreeType::Classification, SplitCriterion::Gini, 3, 2, 1);
1357
1358 let features = vec![
1360 vec![1, 2, 3],
1361 vec![4, 5, 6],
1362 vec![7, 8, 9],
1363 vec![10, 11, 12],
1364 ];
1365
1366 let target = vec![1, 1, 1, 1];
1368
1369 tree.fit(&features, &target).unwrap();
1371
1372 let prediction = tree.predict(&vec![vec![2, 3, 4]]).unwrap();
1374
1375 assert_eq!(prediction[0], 1);
1377
1378 assert_eq!(tree.get_node_count(), 1);
1380 assert_eq!(tree.get_leaf_count(), 1);
1381 }
1382
1383 #[test]
1384 fn test_predict_not_fitted() {
1385 let tree =
1387 DecisionTree::<i32, f64>::new(TreeType::Regression, SplitCriterion::Mse, 3, 2, 1);
1388 let features = vec![vec![1.0, 2.0]];
1389 let result = tree.predict(&features);
1390 assert!(result.is_err());
1391 assert!(matches!(result.unwrap_err(), StatsError::NotFitted { .. }));
1392 }
1393
1394 #[test]
1395 fn test_fit_target_empty() {
1396 let mut tree =
1397 DecisionTree::<i32, f64>::new(TreeType::Regression, SplitCriterion::Mse, 3, 2, 1);
1398 let features = vec![vec![1.0, 2.0]];
1399 let target: Vec<i32> = vec![];
1400 let result = tree.fit(&features, &target);
1401 assert!(result.is_err());
1402 assert!(matches!(result.unwrap_err(), StatsError::EmptyData { .. }));
1403 }
1404
1405 #[test]
1406 fn test_fit_length_mismatch() {
1407 let mut tree =
1408 DecisionTree::<i32, f64>::new(TreeType::Regression, SplitCriterion::Mse, 3, 2, 1);
1409 let features = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
1410 let target = vec![1]; let result = tree.fit(&features, &target);
1412 assert!(result.is_err());
1413 assert!(matches!(
1414 result.unwrap_err(),
1415 StatsError::DimensionMismatch { .. }
1416 ));
1417 }
1418
1419 #[test]
1420 fn test_fit_inconsistent_feature_lengths() {
1421 let mut tree =
1422 DecisionTree::<i32, f64>::new(TreeType::Regression, SplitCriterion::Mse, 3, 2, 1);
1423 let features = vec![vec![1.0, 2.0], vec![3.0]]; let target = vec![1, 2];
1425 let result = tree.fit(&features, &target);
1426 assert!(result.is_err());
1427 assert!(matches!(
1428 result.unwrap_err(),
1429 StatsError::InvalidInput { .. }
1430 ));
1431 }
1432}