#![cfg(feature = "utility")]
use ndarray::prelude::*;
use rustyml::error::ModelError;
use rustyml::utility::train_test_split::*;
/// A valid 5-sample input with `test_size = 0.4` must split into 3 training
/// and 2 test samples, and no sample may be lost or duplicated by the split.
#[test]
fn test_train_test_split_valid_input() {
    let x = Array2::from_shape_vec(
        (5, 2),
        vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
    )
    .unwrap();
    let y = Array1::from(vec![0.0, 1.0, 0.0, 1.0, 0.0]);

    // The originals are kept (hence the clones) so the partitions can be
    // compared against them below.
    let (x_train, x_test, y_train, y_test) =
        train_test_split(x.clone(), y.clone(), Some(0.4), Some(42))
            .expect("splitting valid input should succeed");

    // Partition sizes.
    assert_eq!(x_train.nrows(), 3);
    assert_eq!(x_test.nrows(), 2);
    assert_eq!(y_train.len(), 3);
    assert_eq!(y_test.len(), 2);

    // Together the partitions hold as many samples as the input.
    assert_eq!(x_train.nrows() + x_test.nrows(), x.nrows());
    assert_eq!(y_train.len() + y_test.len(), y.len());

    // Every original row is unique in this fixture, so it must show up exactly
    // once across train + test. Comparing lengths alone would not catch a
    // dropped or duplicated sample.
    for original in x.rows() {
        let copies = x_train
            .rows()
            .into_iter()
            .chain(x_test.rows())
            .filter(|row| *row == original)
            .count();
        assert_eq!(copies, 1, "row {:?} should appear exactly once", original);
    }

    // Labels repeat, so compare how often the positive label occurs instead.
    let positives_in = y.iter().filter(|&&label| label == 1.0).count();
    let positives_out = y_train
        .iter()
        .chain(y_test.iter())
        .filter(|&&label| label == 1.0)
        .count();
    assert_eq!(positives_out, positives_in);
}
/// With `test_size` passed as `None`, ten samples are expected to be split
/// into seven training rows and three test rows.
#[test]
fn test_default_parameters() {
    // Ten 2-feature samples holding 1.0..=20.0 in row-major order.
    let features = Array2::from_shape_fn((10, 2), |(row, col)| (row * 2 + col + 1) as f64);
    // Alternating 0.0 / 1.0 labels.
    let labels = Array1::from_shape_fn(10, |i| (i % 2) as f64);

    let outcome = train_test_split(features, labels, None, Some(42));
    assert!(outcome.is_ok());

    // Only the feature partitions are inspected here.
    let (x_train, x_test, ..) = outcome.unwrap();
    assert_eq!((x_train.nrows(), x_test.nrows()), (7, 3));
}
/// Splitting the same data twice with the same seed must produce identical
/// partitions for both features and labels.
#[test]
fn test_same_random_state_gives_same_split() {
    let x = Array2::from_shape_vec((10, 2), (1..=20).map(f64::from).collect()).unwrap();
    let y = Array1::from_iter((0..10).map(|i| f64::from(i % 2)));

    // Same inputs, same test size, same seed on every call.
    let split = || train_test_split(x.clone(), y.clone(), Some(0.3), Some(42)).unwrap();

    let (x_train_a, x_test_a, y_train_a, y_test_a) = split();
    let (x_train_b, x_test_b, y_train_b, y_test_b) = split();

    assert_eq!(x_train_a, x_train_b);
    assert_eq!(x_test_a, x_test_b);
    assert_eq!(y_train_a, y_train_b);
    assert_eq!(y_test_a, y_test_b);
}
/// Two different seeds applied to a 100-sample data set are expected to yield
/// feature partitions that differ in the training set or the test set.
#[test]
fn test_different_random_states_give_different_splits() {
    let x = Array2::from_shape_vec((100, 2), (1..=200).map(f64::from).collect()).unwrap();
    let y = Array1::from_iter((0..100).map(|i| f64::from(i % 2)));

    // Only the seed varies between the two calls.
    let split_with_seed =
        |seed| train_test_split(x.clone(), y.clone(), Some(0.3), Some(seed)).unwrap();

    let (x_train_a, x_test_a, ..) = split_with_seed(42);
    let (x_train_b, x_test_b, ..) = split_with_seed(43);

    let identical = x_train_a == x_train_b && x_test_a == x_test_b;
    assert!(!identical);
}
/// Passing five feature rows with only three labels must be rejected with an
/// `InputValidationError` whose message mentions the sample-count mismatch.
#[test]
fn test_error_different_sample_sizes() {
    let features = Array2::from_shape_vec((5, 2), (1..=10).map(f64::from).collect()).unwrap();
    // Deliberately shorter than the number of feature rows.
    let labels = Array1::from(vec![0.0, 1.0, 0.0]);

    let outcome = train_test_split(features, labels, Some(0.4), Some(42));
    assert!(outcome.is_err());

    match outcome {
        Err(ModelError::InputValidationError(msg)) => {
            assert!(msg.contains("x and y must have the same number of samples"));
        }
        _ => panic!("Expected InputValidationError"),
    }
}
/// A negative `test_size` and a `test_size` above 1.0 must both be rejected.
#[test]
fn test_error_invalid_test_size() {
    let features = Array2::from_shape_vec((5, 2), (1..=10).map(f64::from).collect()).unwrap();
    let labels = Array1::from(vec![0.0, 1.0, 0.0, 1.0, 0.0]);

    // One value below the valid range, one above it.
    for invalid_size in [-0.1, 1.5] {
        let outcome = train_test_split(
            features.clone(),
            labels.clone(),
            Some(invalid_size),
            Some(42),
        );
        assert!(outcome.is_err());
    }
}
/// Features and labels must stay paired through the split: each sample lands
/// in exactly one partition, its row and label land in the same partition,
/// and within a partition the label at position `j` belongs to the row at
/// position `j`.
#[test]
fn test_consistent_x_y_split() {
    let x = Array2::from_shape_vec(
        (5, 2),
        vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
    )
    .unwrap();
    // Rows and labels are all distinct, so every sample can be identified
    // unambiguously by value.
    let y = Array1::from(vec![100.0, 200.0, 300.0, 400.0, 500.0]);

    let (x_train, x_test, y_train, y_test) =
        train_test_split(x.clone(), y.clone(), Some(0.4), Some(42)).unwrap();

    // Membership: a sample's row and label must be found in the same partition,
    // and in exactly one of the two.
    for i in 0..x.nrows() {
        let x_row = x.row(i);
        let y_val = y[i];

        let in_train_x = x_train.rows().into_iter().any(|r| r == x_row);
        let in_train_y = y_train.iter().any(|&val| val == y_val);
        let in_test_x = x_test.rows().into_iter().any(|r| r == x_row);
        let in_test_y = y_test.iter().any(|&val| val == y_val);

        assert_eq!(in_train_x, in_train_y);
        assert_eq!(in_test_x, in_test_y);
        assert_ne!(
            in_train_x, in_test_x,
            "sample {} must land in exactly one partition",
            i
        );
    }

    // Alignment: membership alone would still pass if labels were permuted
    // relative to their rows inside a partition, so check position by position.
    for (part_x, part_y) in [(&x_train, &y_train), (&x_test, &y_test)] {
        assert_eq!(part_x.nrows(), part_y.len());
        for (row, &label) in part_x.rows().into_iter().zip(part_y.iter()) {
            let source = x
                .rows()
                .into_iter()
                .position(|original| original == row)
                .expect("every split row must come from the original data");
            assert_eq!(
                label, y[source],
                "label is not paired with the row of original sample {}",
                source
            );
        }
    }
}