use crate::error::MattenMlprepError;
use crate::util::matrix_dims;
use matten::Tensor;
/// Splits a 2-D tensor into a leading "train" part and a trailing "test" part.
///
/// The training tensor takes the first `floor(rows * train_ratio) * cols` elements of
/// `x.as_slice()`. The test tensor takes every element after that. Both outputs keep the
/// column count reported for `x`, with shapes `[train_rows, cols]` and `[rows - train_rows, cols]`.
/// The elements are taken in their stored order; no shuffling is performed.
///
/// NOTE(review): treating that element prefix as "the first rows" assumes `x` is stored
/// row-major and contiguously. Confirm against `Tensor::as_slice`.
///
/// # Errors
///
/// - Propagates any error from `matrix_dims`. This check runs before the ratio is looked at.
/// - [`MattenMlprepError::InvalidRatio`] if `train_ratio` is NaN, infinite, or not strictly
///   between `0.0` and `1.0`.
/// - [`MattenMlprepError::EmptySplit`] if rounding down leaves the training part with zero rows.
/// - [`MattenMlprepError::Matten`] if building either output tensor fails.
pub fn train_test_split(
    x: &Tensor,
    train_ratio: f64,
) -> Result<(Tensor, Tensor), MattenMlprepError> {
    let (rows, cols) = matrix_dims(x)?;

    let ratio_in_range = train_ratio.is_finite() && train_ratio > 0.0 && train_ratio < 1.0;
    if !ratio_in_range {
        return Err(MattenMlprepError::InvalidRatio(train_ratio));
    }

    // Round down, so any fractional row goes to the test side.
    let train_rows = (rows as f64 * train_ratio).floor() as usize;
    if train_rows == 0 {
        return Err(MattenMlprepError::EmptySplit { rows, train_ratio });
    }
    let test_rows = rows - train_rows;

    // One cut at the train/test boundary yields both element ranges at once.
    let (train_values, test_values) = x.as_slice().split_at(train_rows * cols);

    let train = Tensor::try_new(train_values.to_vec(), &[train_rows, cols])
        .map_err(MattenMlprepError::Matten)?;
    let test = Tensor::try_new(test_values.to_vec(), &[test_rows, cols])
        .map_err(MattenMlprepError::Matten)?;

    Ok((train, test))
}