/// `gini_impurity` on an empty label slice must be zero (within 1e-7).
#[test]
fn test_gini_impurity_empty() {
    let impurity = gini_impurity(&[]);
    assert!(impurity.abs() < 1e-7);
}
/// Four samples sharing one label form a pure set, so the impurity must be zero.
#[test]
fn test_gini_impurity_pure_single_class() {
    let all_zero = vec![0; 4];
    let impurity = gini_impurity(&all_zero);
    assert!(impurity.abs() < 1e-7);
}
/// Two classes with two samples each must give an impurity of 0.5.
#[test]
fn test_gini_impurity_two_classes_balanced() {
    let alternating = vec![0, 1, 0, 1];
    let deviation = (gini_impurity(&alternating) - 0.5).abs();
    assert!(deviation < 1e-7);
}
/// A 3:1 class ratio must give an impurity of 0.375, i.e. 1 - (0.75^2 + 0.25^2).
#[test]
fn test_gini_impurity_two_classes_unbalanced() {
    let three_to_one = vec![0, 0, 0, 1];
    let deviation = (gini_impurity(&three_to_one) - 0.375).abs();
    assert!(deviation < 1e-7);
}
/// Three equally frequent classes must give 1 - 3 * (1/3)^2, checked to within 1e-6.
#[test]
fn test_gini_impurity_three_classes_uniform() {
    let uniform = vec![0, 1, 2, 0, 1, 2];

    // Each class has probability 1/3; computed in f32 like the original expectation.
    let class_probability = 1.0_f32 / 3.0;
    let expected = 1.0 - 3.0 * class_probability.powi(2);

    let deviation = (gini_impurity(&uniform) - expected).abs();
    assert!(deviation < 1e-6);
}
/// A single sample is trivially pure, so the impurity must be zero.
#[test]
fn test_gini_impurity_single_sample() {
    let lone_label = vec![5];
    let impurity = gini_impurity(&lone_label);
    assert!(impurity.abs() < 1e-7);
}
/// `gini_split` on two empty partitions must return zero (within 1e-7).
#[test]
fn test_gini_split_empty_both() {
    let weighted = gini_split(&[], &[]);
    assert!(weighted.abs() < 1e-7);
}
/// When each partition contains only one class, the split impurity must be zero.
#[test]
fn test_gini_split_perfect_split() {
    let only_zeros = vec![0; 3];
    let only_ones = vec![1; 3];
    let weighted = gini_split(&only_zeros, &only_ones);
    assert!(weighted.abs() < 1e-7);
}
/// Two partitions that are each an even two-class mix must give a split impurity of 0.5.
#[test]
fn test_gini_split_worst_case() {
    let mixed_left = vec![0, 1];
    let mixed_right = vec![0, 1];
    let deviation = (gini_split(&mixed_left, &mixed_right) - 0.5).abs();
    assert!(deviation < 1e-7);
}
/// Unsorted input with duplicates must come back ascending and de-duplicated.
#[test]
fn test_get_sorted_unique_values_basic() {
    let shuffled_with_repeats = vec![3.0, 1.0, 2.0, 1.0, 3.0];
    assert_eq!(
        get_sorted_unique_values(&shuffled_with_repeats),
        vec![1.0, 2.0, 3.0]
    );
}
/// A one-element input must be returned unchanged.
#[test]
fn test_get_sorted_unique_values_single() {
    let lone_value = vec![5.0];
    assert_eq!(get_sorted_unique_values(&lone_value), vec![5.0]);
}
/// Repeated copies of one value must collapse to a single entry.
#[test]
fn test_get_sorted_unique_values_all_same() {
    let repeated = vec![2.0; 3];
    assert_eq!(get_sorted_unique_values(&repeated), vec![2.0]);
}
/// Input that is already ascending and unique must pass through unchanged.
#[test]
fn test_get_sorted_unique_values_already_sorted() {
    let ascending = vec![1.0, 2.0, 3.0, 4.0];
    assert_eq!(
        get_sorted_unique_values(&ascending),
        vec![1.0, 2.0, 3.0, 4.0]
    );
}
/// A threshold of 2.5 over features 1..4 must put the first two labels on the
/// left and the last two on the right.
#[test]
fn test_split_labels_by_threshold_normal() {
    let feature_values = vec![1.0, 2.0, 3.0, 4.0];
    let labels = vec![0, 0, 1, 1];

    let (below, above) = split_labels_by_threshold(&feature_values, &labels, 2.5)
        .expect("split should produce non-empty partitions");

    assert_eq!(below, vec![0, 0]);
    assert_eq!(above, vec![1, 1]);
}
/// A threshold above every feature value leaves no valid split, so `None` is expected.
#[test]
fn test_split_labels_by_threshold_all_left() {
    let feature_values = vec![1.0, 2.0, 3.0];
    let labels = vec![0, 1, 0];
    let outcome = split_labels_by_threshold(&feature_values, &labels, 10.0);
    assert!(outcome.is_none());
}
/// A threshold below every feature value leaves no valid split, so `None` is expected.
#[test]
fn test_split_labels_by_threshold_all_right() {
    let feature_values = vec![5.0, 6.0, 7.0];
    let labels = vec![0, 1, 0];
    let outcome = split_labels_by_threshold(&feature_values, &labels, 0.0);
    assert!(outcome.is_none());
}
/// A feature value exactly equal to the threshold must land in the left partition.
#[test]
fn test_split_labels_by_threshold_boundary() {
    let feature_values = vec![1.0, 2.0, 3.0];
    let labels = vec![0, 1, 2];

    // The sample with value 2.0 sits exactly on the threshold.
    let split = split_labels_by_threshold(&feature_values, &labels, 2.0);
    let (at_or_below, above) = split.expect("split should succeed");

    assert_eq!(at_or_below, vec![0, 1]);
    assert_eq!(above, vec![2]);
}
/// Splitting a mixed parent into two pure children must gain the parent's full impurity.
#[test]
fn test_calculate_information_gain_perfect_split() {
    let parent_impurity = gini_impurity(&[0, 0, 1, 1]);
    let gain = calculate_information_gain(parent_impurity, &[0, 0], &[1, 1]);
    let deviation = (gain - parent_impurity).abs();
    assert!(deviation < 1e-7);
}
/// When both children are as mixed as the parent, the gain must be zero.
#[test]
fn test_calculate_information_gain_no_improvement() {
    let parent_impurity = gini_impurity(&[0, 1, 0, 1]);
    let gain = calculate_information_gain(parent_impurity, &[0, 1], &[0, 1]);
    let magnitude = gain.abs();
    assert!(magnitude < 1e-7);
}
/// A single sample cannot be split, so `None` is expected.
#[test]
fn test_find_best_split_for_feature_too_few_samples() {
    let feature_values = vec![1.0];
    let labels = vec![0];
    let best = find_best_split_for_feature(&feature_values, &labels);
    assert!(best.is_none());
}
/// A constant feature offers no threshold to split on, so `None` is expected.
#[test]
fn test_find_best_split_for_feature_all_same_value() {
    let constant_feature = vec![1.0; 3];
    let labels = vec![0, 1, 0];
    let best = find_best_split_for_feature(&constant_feature, &labels);
    assert!(best.is_none());
}
/// With labels changing between feature values 2 and 3, the best threshold must
/// be 2.5 and the gain strictly positive.
#[test]
fn test_find_best_split_for_feature_clear_split() {
    let feature_values = vec![1.0, 2.0, 3.0, 4.0];
    let labels = vec![0, 0, 1, 1];

    let (threshold, gain) =
        find_best_split_for_feature(&feature_values, &labels).expect("should find a split");

    let threshold_error = (threshold - 2.5).abs();
    assert!(threshold_error < 1e-7);
    assert!(gain > 0.0);
}
/// When every label is identical, no split is reported (`None`).
#[test]
fn test_find_best_split_for_feature_pure_labels() {
    let feature_values = vec![1.0, 2.0, 3.0];
    let uniform_labels = vec![0; 3];
    let best = find_best_split_for_feature(&feature_values, &uniform_labels);
    assert!(best.is_none());
}
/// A 1x2 matrix paired with a single label cannot be split, so `None` is expected.
#[test]
fn test_find_best_split_single_sample() {
    let labels = vec![0];
    let samples = Matrix::from_vec(1, 2, vec![1.0, 2.0]).expect("matrix creation");
    let best = find_best_split(&samples, &labels);
    assert!(best.is_none());
}
/// On a 4x2 matrix the best split must be on feature 0 at threshold 2.5 with
/// positive gain.
#[test]
fn test_find_best_split_separable() {
    // NOTE(review): assuming row-major storage, column 0 holds 1..4 and column 1
    // is the constant 5.0 — confirm against `Matrix::from_vec`.
    let raw = vec![1.0, 5.0, 2.0, 5.0, 3.0, 5.0, 4.0, 5.0];
    let samples = Matrix::from_vec(4, 2, raw).expect("matrix creation");
    let labels = vec![0, 0, 1, 1];

    let best = find_best_split(&samples, &labels);
    let (feature_index, threshold, gain) = best.expect("should find split");

    assert_eq!(feature_index, 0);
    let threshold_error = (threshold - 2.5).abs();
    assert!(threshold_error < 1e-7);
    assert!(gain > 0.0);
}
/// When all labels are identical, no split is reported for any feature (`None`).
#[test]
fn test_find_best_split_pure_labels() {
    let raw = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let samples = Matrix::from_vec(3, 2, raw).expect("matrix creation");
    let uniform_labels = vec![0; 3];
    let best = find_best_split(&samples, &uniform_labels);
    assert!(best.is_none());
}
/// With a single label, that label must be the majority class.
#[test]
fn test_majority_class_single() {
    let winner = majority_class(&[7]);
    assert_eq!(winner, 7);
}
/// Class 1 appears three times against two for class 0, so it must win.
#[test]
fn test_majority_class_clear_winner() {
    let winner = majority_class(&[0, 1, 1, 1, 0]);
    assert_eq!(winner, 1);
}
/// When every label is 3, the majority class must be 3.
#[test]
fn test_majority_class_all_same() {
    let winner = majority_class(&[3, 3, 3]);
    assert_eq!(winner, 3);
}
/// Selecting rows 0 and 2 of a 4x2 matrix must yield a 2x2 matrix holding those
/// rows' values, together with the matching labels.
#[test]
fn test_split_data_by_indices_basic() {
    let raw = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let samples = Matrix::from_vec(4, 2, raw).expect("matrix creation");
    let labels = vec![0, 1, 2, 3];
    let picked_rows = vec![0, 2];

    let (subset_x, subset_y) = split_data_by_indices(&samples, &labels, &picked_rows);

    assert_eq!(subset_x.shape(), (2, 2));
    assert_eq!(subset_y, vec![0, 2]);

    // (row, column, expected value) for every cell of the selected subset.
    let expected_cells = [(0, 0, 1.0), (0, 1, 2.0), (1, 0, 5.0), (1, 1, 6.0)];
    for &(row, col, expected) in &expected_cells {
        assert!((subset_x.get(row, col) - expected).abs() < 1e-7);
    }
}
/// With identical labels, the stopping check must return a leaf carrying that
/// label and the sample count, even at depth 0 with no depth limit.
#[test]
fn test_check_stopping_criteria_pure_node() {
    let pure_labels = vec![1; 3];

    let stop_node =
        check_stopping_criteria(&pure_labels, 0, None).expect("pure node should trigger stop");

    if let TreeNode::Leaf(leaf) = stop_node {
        assert_eq!(leaf.class_label, 1);
        assert_eq!(leaf.n_samples, 3);
    } else {
        panic!("expected leaf, got internal node");
    }
}
/// When the depth equals the configured maximum, the stopping check must return
/// a leaf labelled with the majority class (1) covering all three samples.
#[test]
fn test_check_stopping_criteria_max_depth() {
    let mixed_labels = vec![0, 1, 1];

    let stop_node =
        check_stopping_criteria(&mixed_labels, 5, Some(5)).expect("max depth should trigger stop");

    if let TreeNode::Leaf(leaf) = stop_node {
        assert_eq!(leaf.class_label, 1);
        assert_eq!(leaf.n_samples, 3);
    } else {
        panic!("expected leaf, got internal node");
    }
}
/// Mixed labels at depth 2 with a limit of 5 must not trigger a stop (`None`).
#[test]
fn test_check_stopping_criteria_no_stop() {
    let mixed_labels = vec![0, 1, 1];
    assert!(check_stopping_criteria(&mixed_labels, 2, Some(5)).is_none());
}
/// With no depth limit configured, even depth 100 must not stop a mixed node (`None`).
#[test]
fn test_check_stopping_criteria_no_max_depth() {
    let mixed_labels = vec![0, 1, 1];
    assert!(check_stopping_criteria(&mixed_labels, 100, None).is_none());
}
/// Splitting the single column [1, 3, 2, 4] at 2.5 must send rows 0 and 2 left
/// and rows 1 and 3 right.
#[test]
fn test_split_indices_by_threshold_normal() {
    let samples = Matrix::from_vec(4, 1, vec![1.0, 3.0, 2.0, 4.0]).expect("matrix creation");

    // The last argument equals the row count (presumably the number of samples).
    let (left_rows, right_rows) =
        split_indices_by_threshold(&samples, 0, 2.5, 4).expect("should split successfully");

    assert_eq!(left_rows, vec![0, 2]);
    assert_eq!(right_rows, vec![1, 3]);
}
/// A threshold above every value in the column must yield `None`.
#[test]
fn test_split_indices_by_threshold_all_left() {
    let samples = Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("matrix creation");
    let outcome = split_indices_by_threshold(&samples, 0, 10.0, 3);
    assert!(outcome.is_none());
}
/// A threshold below every value in the column must yield `None`.
#[test]
fn test_split_indices_by_threshold_all_right() {
    let samples = Matrix::from_vec(3, 1, vec![5.0, 6.0, 7.0]).expect("matrix creation");
    let outcome = split_indices_by_threshold(&samples, 0, 0.0, 3);
    assert!(outcome.is_none());
}
/// Round-trips a lone leaf through `flatten_tree_node` and
/// `reconstruct_tree_node`, checking the flat encoding and the rebuilt leaf.
#[test]
fn test_flatten_and_reconstruct_leaf() {
    let original = TreeNode::Leaf(Leaf {
        class_label: 2,
        n_samples: 10,
    });

    // Parallel output buffers that the flattening step appends to.
    let (mut features, mut thresholds, mut classes) = (Vec::new(), Vec::new(), Vec::new());
    let (mut samples, mut left_children, mut right_children) =
        (Vec::new(), Vec::new(), Vec::new());

    let root_idx = flatten_tree_node(
        &original,
        &mut features,
        &mut thresholds,
        &mut classes,
        &mut samples,
        &mut left_children,
        &mut right_children,
    );

    // A single leaf occupies slot 0; its feature entry is expected to be -1.0.
    assert_eq!(root_idx, 0);
    assert_eq!(features.len(), 1);
    assert!((features[0] + 1.0).abs() < 1e-7);
    assert!((classes[0] - 2.0).abs() < 1e-7);
    assert!((samples[0] - 10.0).abs() < 1e-7);

    let rebuilt = reconstruct_tree_node(
        root_idx,
        &features,
        &thresholds,
        &classes,
        &samples,
        &left_children,
        &right_children,
    );

    if let TreeNode::Leaf(leaf) = rebuilt {
        assert_eq!(leaf.class_label, 2);
        assert_eq!(leaf.n_samples, 10);
    } else {
        panic!("expected leaf");
    }
}