use scirs2_core::ndarray::{Array1, ArrayBase, Data, Dimension};
use scirs2_core::numeric::NumCast;
use scirs2_core::random::{rngs::StdRng, seq::SliceRandom, SeedableRng};
use crate::error::{MetricsError, Result};
pub mod advanced_statistical;
pub mod cross_validation;
pub mod statistical;
pub mod workflow;
pub use cross_validation::{
grouped_k_fold, k_fold_cross_validation, leave_one_out_cv, nested_cross_validation,
stratified_k_fold, time_series_split,
};
pub use statistical::{
bootstrap_confidence_interval, cochrans_q_test, friedman_test, mcnemars_test,
wilcoxon_signed_rank_test,
};
/// Return type of [`train_test_split`]: a `(train, test)` pair of vectors.
///
/// The first vector holds the training subsets and the second holds the test
/// subsets; each vector contains one `Array1<T>` per input array.
pub type TrainTestSplitResult<T> = (Vec<Array1<T>>, Vec<Array1<T>>);
#[allow(dead_code)]
pub fn train_test_split<T>(
arrays: &[&ArrayBase<impl Data<Elem = T>, impl Dimension>],
test_size: f64,
random_seed: Option<u64>,
) -> Result<TrainTestSplitResult<T>>
where
T: Clone + NumCast,
{
if arrays.is_empty() {
return Err(MetricsError::InvalidInput(
"No arrays provided for splitting".to_string(),
));
}
let n_samples = arrays[0].shape()[0];
for (i, arr) in arrays.iter().enumerate().skip(1) {
if arr.shape()[0] != n_samples {
return Err(MetricsError::InvalidInput(format!(
"Arrays have different lengths: arrays[0]: {}, arrays[{}]: {}",
n_samples,
i,
arr.shape()[0]
)));
}
}
if test_size <= 0.0 || test_size >= 1.0 {
return Err(MetricsError::InvalidInput(format!(
"test_size must be between 0 and 1, got {}",
test_size
)));
}
let n_test = (n_samples as f64 * test_size).round() as usize;
if n_test == 0 || n_test >= n_samples {
return Err(MetricsError::InvalidInput(format!(
"test_size={} resulted in an invalid test set _size: {}",
test_size, n_test
)));
}
let mut indices: Vec<usize> = (0..n_samples).collect();
let mut rng = match random_seed {
Some(_seed) => StdRng::seed_from_u64(_seed),
None => {
let mut r = scirs2_core::random::rng();
StdRng::from_rng(&mut r)
}
};
indices.shuffle(&mut rng);
let test_indices = &indices[0..n_test];
let train_indices = &indices[n_test..];
let mut train_arrays = Vec::with_capacity(arrays.len());
let mut test_arrays = Vec::with_capacity(arrays.len());
for &arr in arrays {
let mut train_arr = Vec::with_capacity(train_indices.len());
for &idx in train_indices {
let value = match arr.iter().nth(idx) {
Some(v) => v.clone(),
None => {
return Err(MetricsError::InvalidInput(format!(
"Index out of bounds: {} for array of shape {:?}",
idx,
arr.shape()
)))
}
};
train_arr.push(value);
}
train_arrays.push(Array1::from(train_arr));
let mut test_arr = Vec::with_capacity(test_indices.len());
for &idx in test_indices {
let value = match arr.iter().nth(idx) {
Some(v) => v.clone(),
None => {
return Err(MetricsError::InvalidInput(format!(
"Index out of bounds: {} for array of shape {:?}",
idx,
arr.shape()
)))
}
};
test_arr.push(value);
}
test_arrays.push(Array1::from(test_arr));
}
Ok((train_arrays, test_arrays))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Splits two aligned arrays (`y = 2 * x`) 70/30 with a fixed seed and
    /// checks the subset sizes and that the pairing survives the shuffle.
    #[test]
    fn test_train_test_split() {
        let x = scirs2_core::ndarray::Array::linspace(0.0, 9.0, 10);
        let y = &x * 2.0;

        let (train_arrays, test_arrays) =
            train_test_split(&[&x, &y], 0.3, Some(42)).expect("Operation failed");

        // One output array per input array on each side of the split.
        assert_eq!(train_arrays.len(), 2);
        assert_eq!(test_arrays.len(), 2);

        // 10 samples at test_size = 0.3 gives 7 train and 3 test.
        assert_eq!(train_arrays[0].len(), 7);
        assert_eq!(test_arrays[0].len(), 3);

        // Both arrays share the same permutation, so y == 2 * x must still
        // hold element-wise inside each subset.
        let assert_still_paired = |subset: &Vec<Array1<f64>>| {
            let xs = subset[0].iter();
            let ys = subset[1].iter();
            for (x_val, y_val) in xs.zip(ys) {
                assert_eq!(*y_val, *x_val * 2.0);
            }
        };
        assert_still_paired(&train_arrays);
        assert_still_paired(&test_arrays);
    }
}