matten_mlprep/split.rs
1//! Ordered, deterministic train/test split (RFC-028 §4.4).
2
3use crate::error::MattenMlprepError;
4use crate::util::matrix_dims;
5use matten::Tensor;
6
7/// Splits the rows of a 2D tensor into `(train, test)` by an ordered,
8/// deterministic partition — **no shuffling**.
9///
10/// ```text
11/// n_train = floor(n_rows * train_ratio)
12/// train = rows[0 .. n_train]
13/// test = rows[n_train .. n_rows]
14/// ```
15///
16/// The split is fully deterministic and reproducible. If you need a randomized
17/// split, shuffle the rows yourself first (a seeded variant is planned but not
18/// in this release; see RFC-024 §6).
19///
20/// # Errors
21///
22/// - [`MattenMlprepError::ExpectedMatrix`] if `x` is not rank-2.
23/// - [`MattenMlprepError::InvalidRatio`] if `train_ratio` is not finite or not in `(0.0, 1.0)`.
24/// - [`MattenMlprepError::EmptySplit`] if `floor(rows * train_ratio) == 0`.
25/// - [`MattenMlprepError::DynamicTensor`] (with the `dynamic` feature) if `x` is dynamic.
26///
27/// ```
28/// use matten::Tensor;
29/// use matten_mlprep::train_test_split;
30///
31/// // 4 rows, 1 feature; 0.75 -> 3 train rows, 1 test row.
32/// let x = Tensor::new(vec![10.0, 20.0, 30.0, 40.0], &[4, 1]);
33/// let (train, test) = train_test_split(&x, 0.75).unwrap();
34/// assert_eq!(train.shape(), &[3, 1]);
35/// assert_eq!(test.shape(), &[1, 1]);
36/// assert_eq!(train.as_slice(), &[10.0, 20.0, 30.0]);
37/// assert_eq!(test.as_slice(), &[40.0]);
38/// ```
39pub fn train_test_split(
40 x: &Tensor,
41 train_ratio: f64,
42) -> Result<(Tensor, Tensor), MattenMlprepError> {
43 let (rows, cols) = matrix_dims(x)?;
44
45 if !train_ratio.is_finite() || train_ratio <= 0.0 || train_ratio >= 1.0 {
46 return Err(MattenMlprepError::InvalidRatio(train_ratio));
47 }
48
49 let n_train = (rows as f64 * train_ratio).floor() as usize;
50 // For any ratio < 1.0, n_train <= rows - 1, so the test set is never empty.
51 // The only failure is an empty train set.
52 if n_train == 0 {
53 return Err(MattenMlprepError::EmptySplit { rows, train_ratio });
54 }
55 let n_test = rows - n_train;
56
57 let data = x.as_slice();
58 let split = n_train * cols;
59
60 let train = Tensor::try_new(data[..split].to_vec(), &[n_train, cols])
61 .map_err(MattenMlprepError::Matten)?;
62 let test = Tensor::try_new(data[split..].to_vec(), &[n_test, cols])
63 .map_err(MattenMlprepError::Matten)?;
64
65 Ok((train, test))
66}