use crate::primitives::{Matrix, Vector};
use crate::traits::Estimator;
use crate::tree::helpers::{
bootstrap_sample, build_tree, find_best_split, find_best_split_for_feature, gini_impurity,
gini_split, majority_class,
};
use crate::tree::*;
#[test]
fn test_leaf_creation() {
    // A leaf is plain data: the predicted class plus the sample count that reached it.
    let Leaf {
        class_label,
        n_samples,
    } = Leaf {
        class_label: 1,
        n_samples: 10,
    };
    assert_eq!(class_label, 1);
    assert_eq!(n_samples, 10);
}
#[test]
fn test_node_creation() {
    // Build a single split by hand: feature 0 at threshold 0.5, leaf children
    // labelled 0 (left) and 1 (right).
    let left = TreeNode::Leaf(Leaf {
        class_label: 0,
        n_samples: 5,
    });
    let right = TreeNode::Leaf(Leaf {
        class_label: 1,
        n_samples: 5,
    });
    let node = Node {
        feature_idx: 0,
        threshold: 0.5,
        left: Box::new(left),
        right: Box::new(right),
    };

    // Split parameters are stored as given.
    assert_eq!(node.feature_idx, 0);
    assert!((node.threshold - 0.5).abs() < 1e-6);

    // Each child must stay on the side it was attached to. This catches a
    // constructor that swaps or drops children.
    assert!(matches!(
        &*node.left,
        TreeNode::Leaf(leaf) if leaf.class_label == 0 && leaf.n_samples == 5
    ));
    assert!(matches!(
        &*node.right,
        TreeNode::Leaf(leaf) if leaf.class_label == 1 && leaf.n_samples == 5
    ));
}
#[test]
fn test_tree_depth() {
    // Every leaf in these fixtures holds exactly one sample; only the label varies.
    let leaf = |class_label| {
        TreeNode::Leaf(Leaf {
            class_label,
            n_samples: 1,
        })
    };
    // Every split tests feature 0 against 0.5; only the children vary.
    let split = |left: TreeNode, right: TreeNode| {
        TreeNode::Node(Node {
            feature_idx: 0,
            threshold: 0.5,
            left: Box::new(left),
            right: Box::new(right),
        })
    };

    // A bare leaf has nothing below it.
    assert_eq!(leaf(0).depth(), 0);

    // One split over two leaves is a single level deep.
    let one_level = split(leaf(0), leaf(1));
    assert_eq!(one_level.depth(), 1);

    // Hanging that split off the left side of another split adds one more level.
    let two_levels = split(one_level, leaf(1));
    assert_eq!(two_levels.depth(), 2);
}
#[test]
fn test_decision_tree_creation() {
    let fresh = DecisionTreeClassifier::new();
    // Nothing has been fitted yet...
    assert!(fresh.tree.is_none());
    // ...and no depth cap is configured.
    assert!(fresh.max_depth.is_none());
}
#[test]
fn test_decision_tree_with_max_depth() {
    // The builder must record the requested depth cap exactly.
    let capped = DecisionTreeClassifier::new().with_max_depth(5);
    assert_eq!(capped.max_depth, Some(5));
}
#[test]
fn test_decision_tree_default() {
    // `Default` yields an unfitted classifier with no depth cap.
    let via_default = DecisionTreeClassifier::default();
    assert!(via_default.tree.is_none());
    assert!(via_default.max_depth.is_none());
}
#[test]
fn test_gini_impurity_pure() {
    // A single-class label set has zero impurity, whichever class it is.
    for labels in [vec![0, 0, 0, 0, 0], vec![1, 1, 1]] {
        assert!(gini_impurity(&labels).abs() < 1e-6);
    }
}
#[test]
fn test_gini_impurity_empty() {
    // With no labels there is nothing to be impure about.
    let no_labels = Vec::<usize>::new();
    assert!(gini_impurity(&no_labels).abs() < 1e-6);
}
#[test]
fn test_gini_impurity_binary_50_50() {
    // Two equally likely classes: 1 - 0.5^2 - 0.5^2 = 0.5.
    let alternating = vec![0, 1, 0, 1];
    let gini = gini_impurity(&alternating);
    assert!((gini - 0.5).abs() < 1e-6);
}
#[test]
fn test_gini_impurity_three_class_even() {
    // Three equally frequent classes: G = 1 - 3 * (1/3)^2 = 2/3.
    // Compare against the exact value rather than a rounded decimal, so the
    // tolerance can match the other Gini tests.
    let three_class = vec![0, 1, 2, 0, 1, 2];
    assert!((gini_impurity(&three_class) - 2.0 / 3.0).abs() < 1e-6);
}
#[test]
fn test_gini_impurity_bounds() {
    // A spread of pure, balanced and skewed label sets; Gini must stay in [0, 1].
    let label_sets = [
        vec![0, 0, 0],
        vec![0, 1],
        vec![0, 1, 2],
        vec![0, 0, 1, 1, 2, 2],
        vec![0, 0, 0, 1],
    ];
    for gini in label_sets.iter().map(|labels| gini_impurity(labels)) {
        assert!(gini >= 0.0, "Gini should be >= 0, got {gini}");
        assert!(gini <= 1.0, "Gini should be <= 1, got {gini}");
    }
}
#[test]
fn test_gini_split_calculation() {
    // Both sides are pure, so the weighted child impurity is zero.
    let all_zero = vec![0, 0, 0];
    let all_one = vec![1, 1, 1];
    assert!(gini_split(&all_zero, &all_one).abs() < 1e-6);
}
#[test]
fn test_gini_split_mixed() {
    // Left [0, 0, 1]: Gini = 1 - (2/3)^2 - (1/3)^2 = 4/9, weighted by 3 of 5 samples.
    // Right [1, 1] is pure and contributes nothing.
    let mixed_side = vec![0, 0, 1];
    let pure_side = vec![1, 1];
    let expected = (3.0 / 5.0) * (4.0 / 9.0);
    let actual = gini_split(&mixed_side, &pure_side);
    assert!((actual - expected).abs() < 1e-4);
}
#[test]
fn test_find_best_split_simple() {
    // Two clusters on one feature: {1, 2} are class 0, {5, 6} are class 1.
    let values = vec![1.0, 2.0, 5.0, 6.0];
    let labels = vec![0, 0, 1, 1];
    let best = find_best_split_for_feature(&values, &labels);
    assert!(best.is_some());
    let (threshold, gain) = best.expect("should have valid result");
    // The threshold must land strictly inside the gap, and the split must help.
    assert!(2.0 < threshold && threshold < 5.0);
    assert!(gain > 0.0);
}
#[test]
fn test_find_best_split_no_gain() {
    // Labels are already pure, so no threshold can reduce impurity.
    let values = vec![1.0, 2.0, 3.0, 4.0];
    let uniform = vec![0, 0, 0, 0];
    assert!(find_best_split_for_feature(&values, &uniform).is_none());
}
#[test]
fn test_find_best_split_too_small() {
    // A single sample cannot be divided into two sides.
    let lone_value = vec![1.0];
    let lone_label = vec![0];
    assert!(find_best_split_for_feature(&lone_value, &lone_label).is_none());
}
#[test]
fn test_find_best_split_gain_is_positive() {
    // Interleaved labels: whatever split is reported must not have negative gain.
    let values = vec![1.0, 2.0, 3.0, 4.0];
    let interleaved = vec![0, 1, 0, 1];
    let candidate = find_best_split_for_feature(&values, &interleaved);
    // NOTE(review): a `None` result makes this test pass without asserting anything.
    // Consider requiring `Some` once the helper's minimum-leaf-size policy is confirmed.
    if let Some((_, gain)) = candidate {
        assert!(gain >= 0.0);
    }
}
#[test]
fn test_find_best_split_across_features() {
    use crate::primitives::Matrix;
    // 4x2 feature matrix whose labels split cleanly into {0, 0} and {1, 1}.
    let samples = Matrix::from_vec(4, 2, vec![1.0, 1.0, 1.0, 2.0, 5.0, 5.0, 5.0, 6.0])
        .expect("Matrix creation should succeed in tests");
    let labels = vec![0, 0, 1, 1];
    let best = find_best_split(&samples, &labels);
    assert!(best.is_some());
    let (feature_idx, _, gain) = best.expect("should have valid result");
    // The chosen feature must be one of the two columns, and the split must help.
    assert!(feature_idx < 2);
    assert!(gain > 0.0);
}
#[test]
fn test_find_best_split_perfect_separation() {
    use crate::primitives::Matrix;
    let samples = Matrix::from_vec(4, 1, vec![1.0, 2.0, 5.0, 6.0])
        .expect("Matrix creation should succeed in tests");
    let labels = vec![0, 0, 1, 1];
    let best = find_best_split(&samples, &labels);
    assert!(best.is_some());
    let (feature_idx, threshold, gain) = best.expect("should have valid result");
    // Only one feature exists, so it has to be the one chosen.
    assert_eq!(feature_idx, 0);
    // The cut must fall strictly inside the gap between 2.0 and 5.0.
    assert!(2.0 < threshold && threshold < 5.0);
    // If gain is parent Gini (0.5) minus the weighted pure-children Gini (0),
    // a perfect split scores about 0.5.
    assert!(gain > 0.4);
}
#[test]
fn test_majority_class_simple() {
    // Class 0 appears three times against two for class 1.
    let votes = vec![0, 0, 1, 0, 1];
    let winner = majority_class(&votes);
    assert_eq!(winner, 0);
}
#[test]
fn test_majority_class_tie() {
    // On a tie either contender is acceptable; only the tie-breaking rule is unspecified.
    let tied = vec![0, 1, 0, 1];
    assert!(matches!(majority_class(&tied), 0 | 1));
}
#[test]
fn test_majority_class_single() {
    // A single vote decides the outcome, including a non-zero-based label.
    let only_vote = vec![5];
    assert_eq!(majority_class(&only_vote), 5);
}
#[test]
fn test_build_tree_pure_leaf() {
    use crate::primitives::Matrix;
    let samples = Matrix::from_vec(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
        .expect("Matrix creation should succeed in tests");
    let labels = vec![0, 0, 0];
    // All labels agree, so the builder should stop at a single leaf.
    let leaf = match build_tree(&samples, &labels, 0, None) {
        TreeNode::Leaf(leaf) => leaf,
        TreeNode::Node(_) => panic!("Expected Leaf node for pure data"),
    };
    assert_eq!(leaf.class_label, 0);
    assert_eq!(leaf.n_samples, 3);
}
#[test]
fn test_build_tree_max_depth_zero() {
    use crate::primitives::Matrix;
    let samples = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
        .expect("Matrix creation should succeed in tests");
    let labels = vec![0, 0, 1, 1];
    // A depth budget of zero forbids any split, even on separable data.
    let leaf = match build_tree(&samples, &labels, 0, Some(0)) {
        TreeNode::Leaf(leaf) => leaf,
        TreeNode::Node(_) => panic!("Expected Leaf node at max depth"),
    };
    // The labels are tied 2-2, so either class may be chosen.
    assert!(matches!(leaf.class_label, 0 | 1));
    assert_eq!(leaf.n_samples, 4);
}
#[test]
fn test_build_tree_simple_split() {
    use crate::primitives::Matrix;
    let samples = Matrix::from_vec(4, 1, vec![1.0, 2.0, 5.0, 6.0])
        .expect("Matrix creation should succeed in tests");
    let labels = vec![0, 0, 1, 1];
    // With room to grow, separable data must produce a split at the root.
    let root = match build_tree(&samples, &labels, 0, Some(5)) {
        TreeNode::Node(node) => node,
        TreeNode::Leaf(_) => panic!("Expected Node for splittable data"),
    };
    assert_eq!(root.feature_idx, 0);
    assert!(2.0 < root.threshold && root.threshold < 5.0);
}
#[test]
fn test_build_tree_depth_tracking() {
    use crate::primitives::Matrix;
    let samples = Matrix::from_vec(4, 1, vec![1.0, 2.0, 5.0, 6.0])
        .expect("Matrix creation should succeed in tests");
    let labels = vec![0, 0, 1, 1];
    // The builder must honour the depth cap it is given.
    let depth_cap = 1;
    assert!(build_tree(&samples, &labels, 0, Some(depth_cap)).depth() <= depth_cap);
}
#[test]
fn test_fit_simple() {
    use crate::primitives::Matrix;
    // Two well-separated clusters on a single feature.
    let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 5.0, 6.0])
        .expect("Matrix creation should succeed in tests");
    let y = vec![0, 0, 1, 1];
    let mut tree = DecisionTreeClassifier::new().with_max_depth(5);
    // `expect` reports the actual error on failure; `assert!(result.is_ok())`
    // would only report that the result was not `Ok`.
    tree.fit(&x, &y).expect("fit should succeed");
    // A successful fit must leave a trained tree behind.
    assert!(tree.tree.is_some());
}
#[test]
fn test_predict_perfect_classification() {
    use crate::primitives::Matrix;
    let samples = Matrix::from_vec(4, 1, vec![1.0, 2.0, 5.0, 6.0])
        .expect("Matrix creation should succeed in tests");
    let labels = vec![0, 0, 1, 1];
    let mut classifier = DecisionTreeClassifier::new().with_max_depth(5);
    classifier.fit(&samples, &labels).expect("fit should succeed");
    // Separable training data should be reproduced exactly on re-prediction.
    assert_eq!(classifier.predict(&samples), vec![0, 0, 1, 1]);
}
#[test]
fn test_predict_single_sample() {
    use crate::primitives::Matrix;
    let train_samples = Matrix::from_vec(4, 1, vec![1.0, 2.0, 5.0, 6.0])
        .expect("Matrix creation should succeed in tests");
    let train_labels = vec![0, 0, 1, 1];
    let mut classifier = DecisionTreeClassifier::new().with_max_depth(5);
    classifier
        .fit(&train_samples, &train_labels)
        .expect("fit should succeed");

    // 1.5 sits inside the class-0 cluster, so exactly one prediction of 0 is expected.
    let query =
        Matrix::from_vec(1, 1, vec![1.5]).expect("Matrix creation should succeed in tests");
    let predictions = classifier.predict(&query);
    assert_eq!(predictions, vec![0]);
}
#[test]
fn test_score_perfect() {
    use crate::primitives::Matrix;
    let samples = Matrix::from_vec(4, 1, vec![1.0, 2.0, 5.0, 6.0])
        .expect("Matrix creation should succeed in tests");
    let labels = vec![0, 0, 1, 1];
    let mut classifier = DecisionTreeClassifier::new().with_max_depth(5);
    classifier.fit(&samples, &labels).expect("fit should succeed");
    // Scoring on the training set of a perfectly separable problem gives full accuracy.
    let shortfall = (classifier.score(&samples, &labels) - 1.0).abs();
    assert!(shortfall < 1e-6);
}
// The remaining tests live in sibling files and are spliced in textually with
// `include!`, so they share this module's `use` declarations (including
// `use crate::tree::*`). Judging by the file names, they cover ensembles/boosting,
// regression trees and random-forest regressors.
include!("core_ensemble_and_boosting.rs");
include!("core_regression_tree.rs");
include!("core_random_forest_regressor.rs");