use crate::error::ModelError;
use ndarray::{Array1, Array2, Axis};
use ndarray_rand::rand::{SeedableRng, rng, rngs::StdRng, seq::SliceRandom};
/// Splits a feature matrix and its target vector into shuffled train and test subsets.
///
/// Row indices `0..x.nrows()` are shuffled, the first `n_test` shuffled indices form the
/// test set, and the remaining indices form the training set. `n_test` is
/// `round(n_samples * test_size)`, clamped so that both subsets hold at least one sample.
///
/// # Arguments
///
/// * `x` - Feature matrix with one sample per row.
/// * `y` - Target values; its length must equal the number of rows in `x`.
/// * `test_size` - Fraction of samples placed in the test set, strictly between 0 and 1.
///   Defaults to `0.3` when `None`.
/// * `random_state` - Seed for the shuffle. With `Some(seed)` a `StdRng` seeded from that
///   value is used; with `None` the generator returned by `rng()` is used.
///
/// # Returns
///
/// The tuple `(x_train, x_test, y_train, y_test)`.
///
/// # Errors
///
/// Returns `ModelError::InputValidationError` when:
/// * `x` has no rows,
/// * `x.nrows()` differs from `y.len()`,
/// * `test_size` is NaN or lies outside the open interval `(0, 1)`,
/// * the dataset holds exactly one sample.
pub fn train_test_split(
    x: Array2<f64>,
    y: Array1<f64>,
    test_size: Option<f64>,
    random_state: Option<u64>,
) -> Result<(Array2<f64>, Array2<f64>, Array1<f64>, Array1<f64>), ModelError> {
    let n_samples = x.nrows();
    if n_samples == 0 {
        return Err(ModelError::InputValidationError(
            "Cannot split empty dataset".to_string(),
        ));
    }
    if n_samples != y.len() {
        return Err(ModelError::InputValidationError(format!(
            "x and y must have the same number of samples, x rows: {}, y length: {}",
            n_samples,
            y.len()
        )));
    }

    let test_size = test_size.unwrap_or(0.3);
    // NaN compares false against every bound, so it must be rejected explicitly;
    // otherwise it would slip through and be cast to a test set of size 0 -> 1 below.
    if test_size.is_nan() || test_size <= 0.0 || test_size >= 1.0 {
        return Err(ModelError::InputValidationError(format!(
            "test_size must be between 0 and 1 (exclusive), got {}",
            test_size
        )));
    }

    if n_samples == 1 {
        return Err(ModelError::InputValidationError(
            "Cannot split a dataset with only 1 sample into train and test sets".to_string(),
        ));
    }

    // n_samples >= 2 here, so the clamp bounds are valid (1 <= n_samples - 1) and
    // both the train and the test subset are guaranteed at least one sample.
    let n_test = ((n_samples as f64 * test_size).round() as usize).clamp(1, n_samples - 1);

    let mut indices: Vec<usize> = (0..n_samples).collect();
    match random_state {
        Some(seed) => indices.shuffle(&mut StdRng::seed_from_u64(seed)),
        None => indices.shuffle(&mut rng()),
    }

    // The leading `n_test` shuffled indices become the test set; the rest is training data.
    let (test_indices, train_indices) = indices.split_at(n_test);

    let x_train = x.select(Axis(0), train_indices);
    let x_test = x.select(Axis(0), test_indices);
    let y_train = y.select(Axis(0), train_indices);
    let y_test = y.select(Axis(0), test_indices);

    Ok((x_train, x_test, y_train, y_test))
}