Skip to main content

train_test_split

Function train_test_split 

Source
pub fn train_test_split(
    x: &Tensor,
    train_ratio: f64,
) -> Result<(Tensor, Tensor), MattenMlprepError>
Expand description

Splits the rows of a 2D tensor into (train, test) by an ordered, deterministic partition — no shuffling.

n_train = floor(n_rows * train_ratio)
train   = rows[0 .. n_train]
test    = rows[n_train .. n_rows]

The split is fully deterministic and reproducible: the same tensor and ratio always produce the same partition. If you need a randomized split, shuffle the rows yourself before calling this function (a seeded variant is planned but not included in this release; see RFC-024 §6).

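§Errors

Returns `Err(MattenMlprepError)` when the split cannot be performed. The specific error conditions are not listed in this copy of the documentation. Presumably they include an input that is not a 2D tensor, or a `train_ratio` outside `[0, 1]`; confirm against the crate source.

§Examples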

use matten::Tensor;
use matten_mlprep::train_test_split;

// Four single-feature rows. With a 0.75 ratio, floor(4 * 0.75) = 3 rows go to
// train and the remaining row goes to test.
let features = Tensor::new(vec![10.0, 20.0, 30.0, 40.0], &[4, 1]);
let ratio = 0.75;
let (train_rows, test_rows) = train_test_split(&features, ratio).unwrap();

// Shapes: only the row count is divided; the feature dimension is kept.
assert_eq!(train_rows.shape(), &[3, 1]);
assert_eq!(test_rows.shape(), &[1, 1]);

// Contents: rows keep their original order (no shuffling).
assert_eq!(train_rows.as_slice(), &[10.0, 20.0, 30.0]);
assert_eq!(test_rows.as_slice(), &[40.0]);
Examples found in repository?
examples/train_test_split.rs (line 20)
17fn main() {
18    // 5 samples, 2 features: rows 0..=4.
19    let x = Tensor::new((0..10).map(|v| v as f64).collect(), &[5, 2]);
20    let (train, test) = train_test_split(&x, 0.6).expect("valid split"); // 3 / 2
21    println!("train {:?}: {:?}", train.shape(), train.as_slice());
22    println!("test  {:?}: {:?}", test.shape(), test.as_slice());
23
24    // Deterministic ordered split: first 3 rows -> train, last 2 rows -> test.
25    assert_eq!(train.shape(), &[3, 2]);
26    assert_eq!(test.shape(), &[2, 2]);
27    assert_eq!(train.as_slice(), &[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
28    assert_eq!(test.as_slice(), &[6.0, 7.0, 8.0, 9.0]);
29    println!("train_test_split: OK");
30}