use super::*;
use crate::primitives::Matrix;
/// DT-001: every prediction on the training matrix must lie in the label
/// range `[0, 2]` seen during `fit`.
///
/// The explicit length check guards against a vacuous pass: without it, an
/// empty prediction vector would satisfy the per-element loop trivially and
/// the claim could never be falsified.
#[test]
fn falsify_dt_001_predictions_in_label_range() {
    let x = Matrix::from_vec(
        6,
        2,
        vec![0.0, 0.0, 1.0, 0.0, 2.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 1.0],
    )
    .expect("valid matrix");
    let y = vec![0_usize, 0, 1, 1, 2, 2];
    let mut dt = DecisionTreeClassifier::new();
    dt.fit(&x, &y).expect("fit succeeds");
    let preds = dt.predict(&x);
    // One prediction per input row; otherwise the range loop below proves nothing.
    assert_eq!(
        preds.len(),
        y.len(),
        "FALSIFIED DT-001: {} predictions for {} inputs",
        preds.len(),
        y.len()
    );
    for (i, &p) in preds.iter().enumerate() {
        // Labels are `usize`, so only the upper bound needs checking.
        assert!(
            p <= 2,
            "FALSIFIED DT-001: prediction[{i}] = {p}, not in [0, 2]"
        );
    }
}
/// DT-002: calling `predict` twice on the same fitted tree with the same
/// matrix must produce identical label vectors.
#[test]
fn falsify_dt_002_deterministic() {
    let labels = vec![0_usize, 0, 1, 1];
    let samples =
        Matrix::from_vec(4, 2, vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0]).expect("valid matrix");

    let mut tree = DecisionTreeClassifier::new();
    tree.fit(&samples, &labels).expect("fit");

    // Two back-to-back predictions on the unchanged model and input.
    let [first_run, second_run] = [tree.predict(&samples), tree.predict(&samples)];
    assert_eq!(
        first_run, second_run,
        "FALSIFIED DT-002: predictions differ on same input"
    );
}
/// DT-003: a default tree trained on 1-D data with a wide gap between the two
/// classes (1.0 vs 10.0) must reproduce the training labels exactly.
#[test]
fn falsify_dt_003_perfect_separable() {
    let labels = vec![0_usize, 0, 1, 1];
    let samples = Matrix::from_vec(4, 1, vec![0.0, 1.0, 10.0, 11.0]).expect("valid matrix");

    let mut tree = DecisionTreeClassifier::new();
    tree.fit(&samples, &labels).expect("fit");

    assert_eq!(
        tree.predict(&samples),
        labels,
        "FALSIFIED DT-003: tree cannot perfectly fit separable data"
    );
}
/// DT-004: `predict` must return exactly one label per query row, including
/// rows that never appeared in the training set.
#[test]
fn falsify_dt_004_prediction_count() {
    let train_features = vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0];
    let train_labels = vec![0_usize, 0, 1, 1];
    let x_train = Matrix::from_vec(4, 2, train_features).expect("valid");

    let mut tree = DecisionTreeClassifier::new();
    tree.fit(&x_train, &train_labels).expect("fit");

    // Midpoints between consecutive training rows; none equals a training sample.
    let query_features = vec![0.5, 0.5, 1.5, 1.5, 2.5, 2.5];
    let x_test = Matrix::from_vec(3, 2, query_features).expect("valid");

    let n_preds = tree.predict(&x_test).len();
    assert_eq!(
        n_preds, 3,
        "FALSIFIED DT-004: {n_preds} predictions for 3 inputs"
    );
}