/// Builds expanding-window train/test index splits for time-ordered samples.
///
/// Every fold's test set holds `n_samples / (n_splits + 1)` consecutive indices, and the
/// test sets are laid out back to back so that the final one ends at `n_samples`. The
/// training set of a fold is every index that precedes its test set (`0..test_start`), so
/// training sets grow from one fold to the next and never contain a later index than the
/// test set they are paired with.
///
/// # Parameters
///
/// * `n_samples` - total number of samples; indices run over `0..n_samples`.
/// * `n_splits` - number of `(train, test)` folds to produce.
///
/// # Returns
///
/// A vector of `n_splits` `(train_indices, test_indices)` pairs in chronological order, or
/// an empty vector when no split can be formed: `n_splits < 2`, `n_samples < 2`, or
/// `n_splits >= n_samples` (each test fold would be empty).
pub fn time_series_split(n_samples: usize, n_splits: usize) -> Vec<(Vec<usize>, Vec<usize>)> {
    // `n_splits >= n_samples` would give a zero-sized test fold. Rejecting it up front also
    // guarantees `n_splits + 1 <= n_samples`, so the addition below cannot overflow even
    // for `n_splits == usize::MAX`.
    if n_splits < 2 || n_samples < 2 || n_splits >= n_samples {
        return Vec::new();
    }

    // At least 1 here, because `n_splits + 1 <= n_samples`.
    let test_size = n_samples / (n_splits + 1);

    // Anchor the folds to the end of the data: any remainder of the division is absorbed
    // by the first training set rather than being dropped from the tail.
    let first_test_start = n_samples - n_splits * test_size;

    (0..n_splits)
        .map(|fold| {
            let test_start = first_test_start + fold * test_size;
            let test_end = test_start + test_size;
            let train: Vec<usize> = (0..test_start).collect();
            let test: Vec<usize> = (test_start..test_end).collect();
            (train, test)
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Ten samples in three folds: test folds of two indices each, ending at index 9.
    #[test]
    fn basic_split() {
        let expected: Vec<(Vec<usize>, Vec<usize>)> = vec![
            ((0..4).collect(), vec![4, 5]),
            ((0..6).collect(), vec![6, 7]),
            ((0..8).collect(), vec![8, 9]),
        ];
        assert_eq!(time_series_split(10, 3), expected);
    }

    /// Training sets must strictly grow while every test set keeps the same size.
    #[test]
    fn expanding_window() {
        let splits = time_series_split(100, 5);
        assert_eq!(splits.len(), 5);

        let train_grows = splits
            .windows(2)
            .all(|pair| pair[1].0.len() > pair[0].0.len());
        assert!(train_grows);

        let test_size = splits[0].1.len();
        assert!(splits.iter().all(|(_, test)| test.len() == test_size));
    }

    /// No test index may appear in its fold's training set, and the test block must start
    /// after the last training index.
    #[test]
    fn no_overlap() {
        let splits = time_series_split(50, 5);
        for (train, test) in splits.iter() {
            for t in test.iter().copied() {
                assert!(!train.contains(&t), "test index {t} found in training set");
            }
            if let Some((&last_train, &first_test)) = train.last().zip(test.first()) {
                assert!(first_test > last_train, "test must come after training");
            }
        }
    }

    /// More folds requested than there are samples: nothing can be produced.
    #[test]
    fn too_few_samples() {
        assert!(time_series_split(2, 5).is_empty());
    }

    /// Zero folds requested yields no splits.
    #[test]
    fn zero_splits() {
        let splits = time_series_split(100, 0);
        assert!(splits.is_empty());
    }

    /// A single fold is below the minimum of two and yields no splits.
    #[test]
    fn single_split() {
        assert!(time_series_split(10, 1).is_empty());
    }

    /// 1000 samples over 10 folds: each test fold holds 1000 / 11 = 90 indices.
    #[test]
    fn large_dataset() {
        let splits = time_series_split(1000, 10);
        assert_eq!(splits.len(), 10);
        let first_test_len = splits.first().map(|(_, test)| test.len());
        assert_eq!(first_test_len, Some(90));
    }
}