#![cfg(feature = "machine_learning")]
use ndarray::{arr1, arr2};
use rustyml::machine_learning::decision_tree::{
Algorithm, DecisionTree, DecisionTreeParams, Node, NodeType,
};
#[test]
fn test_decision_tree_new() {
    // Default construction: CART classifier with no explicit parameters.
    let dt = DecisionTree::new(Algorithm::CART, true, None).unwrap();
    assert!(matches!(dt.get_algorithm(), Algorithm::CART));
    assert!(dt.get_is_classifier());

    // Custom parameters must be stored verbatim on the constructed tree.
    let params = DecisionTreeParams {
        max_depth: Some(5),
        min_samples_split: 10,
        min_samples_leaf: 5,
        min_impurity_decrease: 0.1,
        random_state: Some(42),
    };
    let dt = DecisionTree::new(Algorithm::CART, false, Some(params)).unwrap();
    assert!(matches!(dt.get_algorithm(), Algorithm::CART));
    assert!(!dt.get_is_classifier());
    assert_eq!(dt.get_parameters().max_depth, Some(5));
    assert_eq!(dt.get_parameters().min_samples_split, 10);
    assert_eq!(dt.get_parameters().min_samples_leaf, 5);
    // Compare the float with a tolerance rather than exact equality
    // (clippy::float_cmp); the value should round-trip as 0.1.
    assert!((dt.get_parameters().min_impurity_decrease - 0.1).abs() < 1e-12);
    assert_eq!(dt.get_parameters().random_state, Some(42));
}
#[test]
fn test_decision_tree_params_default() {
    // The defaults mirror common decision-tree library conventions:
    // unbounded depth, split at >=2 samples, leaves of >=1 sample,
    // no impurity-decrease threshold, no fixed RNG seed.
    let params = DecisionTreeParams::default();
    assert_eq!(params.max_depth, None);
    assert_eq!(params.min_samples_split, 2);
    assert_eq!(params.min_samples_leaf, 1);
    // Use a two-sided tolerance check: the original one-sided
    // `< 1e-6` would also accept arbitrarily negative defaults.
    assert!(params.min_impurity_decrease.abs() < 1e-6);
    assert_eq!(params.random_state, None);
}
#[test]
fn test_fit_predict_classifier() {
    // Two linearly separable clusters: class 0 around (2..3, 2..3),
    // class 1 around (1, 1..2).
    let x = arr2(&[
        [2.0, 2.0],
        [2.0, 3.0],
        [3.0, 2.0],
        [3.0, 3.0],
        [1.0, 1.0],
        [1.0, 2.0],
    ]);
    let y = arr1(&[0.0, 0.0, 0.0, 0.0, 1.0, 1.0]);

    let mut dt = DecisionTree::new(Algorithm::CART, true, None).unwrap();
    dt.fit(&x.view(), &y.view()).unwrap();
    assert_eq!(dt.get_n_features(), 2);
    assert!(dt.get_root().is_some());

    // A fitted tree should reproduce the training labels exactly on
    // this trivially separable data.
    let predictions = dt.predict(&x.view()).unwrap();
    assert_eq!(predictions.len(), 6);
    for (predicted, expected) in predictions.iter().zip(y.iter()) {
        assert_eq!(predicted, expected);
    }

    // Single-sample prediction agrees with the batch path.
    let pred = dt.predict_one(&[2.0, 2.0]).unwrap();
    assert_eq!(pred, 0.0);
    let pred = dt.predict_one(&[1.0, 1.0]).unwrap();
    assert_eq!(pred, 1.0);

    // Smoke-test the human-readable tree dump.
    println!("{}", dt.generate_tree_structure().unwrap());
}
#[test]
fn test_fit_predict_regressor() {
    // Simple 1-D identity mapping: y == x for x in 1..=6.
    let x = arr2(&[[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]);
    let y = arr1(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

    let mut dt = DecisionTree::new(Algorithm::CART, false, None).unwrap();
    dt.fit(&x.view(), &y.view()).unwrap();
    assert_eq!(dt.get_n_features(), 1);
    assert!(dt.get_root().is_some());

    // Regressor predictions on the training points should land close
    // to the targets (leaf means), within a loose 0.5 tolerance.
    let predictions = dt.predict(&x.view()).unwrap();
    assert_eq!(predictions.len(), 6);
    for (predicted, expected) in predictions.iter().zip(y.iter()) {
        assert!((predicted - expected).abs() < 0.5);
    }

    // Smoke-test the human-readable tree dump.
    println!("{}", dt.generate_tree_structure().unwrap());
}
#[test]
fn test_predict_proba() {
    // Three classes, two samples each.
    let x = arr2(&[
        [1.0, 1.0],
        [1.0, 2.0],
        [2.0, 1.0],
        [2.0, 2.0],
        [3.0, 3.0],
        [4.0, 4.0],
    ]);
    let y = arr1(&[0.0, 0.0, 1.0, 1.0, 2.0, 2.0]);

    let mut dt = DecisionTree::new(Algorithm::CART, true, None).unwrap();
    dt.fit(&x.view(), &y.view()).unwrap();

    // Probability matrix is (n_samples, n_classes) and each row is a
    // valid distribution (sums to 1).
    let prob = dt.predict_proba(&x.view()).unwrap();
    assert_eq!(prob.dim(), (6, 3));
    for row in prob.outer_iter() {
        let total: f64 = row.sum();
        assert!((total - 1.0).abs() < 1e-5);
    }

    // Single-sample variant also yields one probability per class.
    let probs = dt.predict_proba_one(&[1.0, 1.0]).unwrap();
    assert_eq!(probs.len(), 3);
    assert!((probs.iter().sum::<f64>() - 1.0).abs() < 1e-5);
}
#[test]
fn test_different_algorithms() {
    // Tiny separable dataset every algorithm should fit perfectly.
    let x = arr2(&[[1.0, 1.0], [1.0, 2.0], [2.0, 1.0], [2.0, 2.0]]);
    let y = arr1(&[0.0, 0.0, 1.0, 1.0]);

    // ID3, C4.5, and CART must all reproduce the training labels.
    for algorithm in [Algorithm::ID3, Algorithm::C45, Algorithm::CART] {
        let mut dt = DecisionTree::new(algorithm, true, None).unwrap();
        dt.fit(&x.view(), &y.view()).unwrap();
        let predictions = dt.predict(&x.view()).unwrap();
        for (predicted, expected) in predictions.iter().zip(y.iter()) {
            assert_eq!(predicted, expected);
        }
    }
}
#[test]
fn test_max_depth_parameter() {
    // Two clusters of four points each.
    let x = arr2(&[
        [1.0, 1.0],
        [1.0, 2.0],
        [2.0, 1.0],
        [2.0, 2.0],
        [3.0, 3.0],
        [3.0, 4.0],
        [4.0, 3.0],
        [4.0, 4.0],
    ]);
    let y = arr1(&[0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]);

    // A depth-1 tree (a single stump split) must still fit and predict.
    let params_depth1 = DecisionTreeParams {
        max_depth: Some(1),
        ..DecisionTreeParams::default()
    };
    let mut dt_limited = DecisionTree::new(Algorithm::CART, true, Some(params_depth1)).unwrap();
    dt_limited.fit(&x.view(), &y.view()).unwrap();
    // Previously these predictions were discarded unchecked; at minimum
    // the limited tree must return one prediction per input sample.
    let predictions_limited = dt_limited.predict(&x.view()).unwrap();
    assert_eq!(predictions_limited.len(), 8);

    // Without a depth limit the tree should classify the training
    // set perfectly (one axis-aligned split suffices here).
    let mut dt_unlimited = DecisionTree::new(Algorithm::CART, true, None).unwrap();
    dt_unlimited.fit(&x.view(), &y.view()).unwrap();
    let predictions_unlimited = dt_unlimited.predict(&x.view()).unwrap();
    for i in 0..8 {
        assert_eq!(predictions_unlimited[i], y[i]);
    }
}
#[test]
fn test_error_handling() {
    // An unfitted tree exposes no root or class count...
    let dt = DecisionTree::new(Algorithm::CART, true, None).unwrap();
    assert!(dt.get_root().is_none());
    assert!(dt.get_n_classes().is_none());

    // ...and every prediction entry point must return an error
    // instead of panicking.
    let x = arr2(&[[1.0, 2.0]]);
    let sample: &[f64] = &[1.0, 2.0];
    assert!(dt.predict(&x.view()).is_err());
    assert!(dt.predict_one(sample).is_err());
    assert!(dt.predict_proba(&x.view()).is_err());
    assert!(dt.predict_proba_one(sample).is_err());
}
#[test]
fn test_node_creation() {
    // Leaf constructor stores value, predicted class, and probabilities.
    let leaf = Node::new_leaf(1.5, Some(0), Some(vec![0.8, 0.2]));
    if let NodeType::Leaf {
        value,
        class,
        probabilities,
    } = leaf.node_type
    {
        assert_eq!(value, 1.5);
        assert_eq!(class, Some(0));
        assert_eq!(probabilities, Some(vec![0.8, 0.2]));
    } else {
        panic!("Expected a leaf node");
    }

    // Numeric internal node stores a feature index and threshold,
    // with no category list.
    let internal = Node::new_internal(1, 0.5);
    if let NodeType::Internal {
        feature_index,
        threshold,
        categories,
    } = internal.node_type
    {
        assert_eq!(feature_index, 1);
        assert_eq!(threshold, 0.5);
        assert_eq!(categories, None);
    } else {
        panic!("Expected an internal node");
    }

    // Categorical internal node stores the feature index and the
    // category labels it splits on.
    let categorical = Node::new_categorical(2, vec!["A".to_string(), "B".to_string()]);
    if let NodeType::Internal {
        feature_index,
        threshold: _,
        categories,
    } = categorical.node_type
    {
        assert_eq!(feature_index, 2);
        assert_eq!(categories, Some(vec!["A".to_string(), "B".to_string()]));
    } else {
        panic!("Expected an internal categorical node");
    }
}