use scirs2_core::ndarray::{array, Array2};
use sklears::metrics::classification::accuracy_score;
use sklears::neighbors::KNeighborsClassifier;
use sklears::prelude::*;
use sklears::utils::data_generation::make_classification;
/// End-to-end smoke test: generate a small synthetic dataset, fit a
/// `KNeighborsClassifier` with k = 5 on it, predict on the same data, and
/// check the length, accuracy and label range of the predictions.
#[test]
// `X` follows the common ML convention of an upper-case feature matrix.
#[allow(non_snake_case)]
fn test_basic_knn_pipeline() {
    let (X, y) = make_classification(50, 2, 2, None, None, 0.0, 3.0, Some(42))
        .expect("make_classification should succeed for these parameters");
    assert_eq!(X.nrows(), y.len());

    let classifier = KNeighborsClassifier::new(5);
    let fitted_classifier = classifier
        .fit(&X, &y)
        .expect("model fitting should succeed");
    let predictions = fitted_classifier
        .predict(&X)
        .expect("prediction should succeed");

    // Check the length before scoring so a shape mismatch is reported by this
    // assertion rather than as an error from `accuracy_score`.
    assert_eq!(predictions.len(), y.len());

    let accuracy =
        accuracy_score(&y, &predictions).expect("accuracy_score should succeed");
    // Scored on the training data itself, so the threshold is deliberately loose.
    assert!(
        accuracy >= 0.7,
        "Accuracy should be >= 0.7, got {}",
        accuracy
    );

    // NOTE(review): assumes the third `make_classification` argument (2) is the
    // class count, so labels are expected to be 0 or 1.
    for &pred in predictions.iter() {
        assert!(
            (0..=1).contains(&pred),
            "Predicted class {} should be in [0, 1]",
            pred
        );
    }
}
/// Generating a dataset twice with the same arguments and seed must produce
/// identical features and labels, with 30 rows, 3 columns and exactly two
/// distinct label values.
#[test]
fn test_data_generation_consistency() {
    // One closure for both calls guarantees the two generations use identical
    // arguments, including the fixed seed.
    let generate = || {
        make_classification(30, 3, 2, None, None, 0.0, 1.0, Some(123))
            .expect("operation should succeed")
    };
    let (features_a, labels_a) = generate();
    let (features_b, labels_b) = generate();

    // Reproducibility: same seed, same output.
    assert_eq!(features_a, features_b);
    assert_eq!(labels_a, labels_b);

    // Shape checks on the first generation.
    assert_eq!(features_a.shape(), &[30, 3]);
    assert_eq!(labels_a.len(), 30);

    // Count distinct label values; a set deduplicates for us.
    let distinct_labels: std::collections::BTreeSet<i32> =
        labels_a.iter().copied().collect();
    assert_eq!(distinct_labels.len(), 2);
}
/// Checks `accuracy_score` against a hand-computed value on a fixed pair of
/// label vectors.
#[test]
fn test_metrics_basic_functionality() {
    let y_true = array![0, 1, 1, 0, 1, 0, 1, 1, 0, 0];
    let y_pred = array![0, 1, 0, 0, 1, 1, 1, 1, 0, 1];
    let accuracy = accuracy_score(&y_true, &y_pred).expect("operation should succeed");
    assert!((0.0..=1.0).contains(&accuracy));
    // Positions 2, 5 and 9 disagree, so 7 of the 10 labels match: 7 / 10 = 0.7.
    let expected_accuracy = 0.7;
    assert!(
        (accuracy - expected_accuracy).abs() < 1e-10,
        "Expected accuracy {}, got {}",
        expected_accuracy,
        accuracy
    );
}
/// Sanity checks on basic ndarray construction.
///
/// Covers the `array!` macro and `Array2::from_shape_vec`: shape and
/// row-major element placement.
#[test]
fn test_utility_functions() {
    let data = array![1, 2, 3, 2, 1, 3, 2];
    assert_eq!(data.len(), 7);

    // Six values into a 3x2 shape; `[2, 1]` is the last element, so the data
    // is laid out row by row.
    let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
        .expect("shape and data length should match");
    assert_eq!(x.shape(), &[3, 2]);
    assert_eq!(x[[0, 0]], 1.0);
    assert_eq!(x[[2, 1]], 6.0);
}