use crate::core::{TimeSeriesError};
use crate::WindowedTimeSeries;
use std::collections::HashMap;
/// How [`split_windowed_data`] partitions windows into train / validation / test sets.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SplitStrategy {
    /// Split by window position in stored order: the leading `train_frac` of the
    /// windows for training, the following `val_frac` for validation, and the rest
    /// for testing.
    TimeBased {
        /// Fraction of the total window count used for training.
        train_frac: f64,
        /// Fraction of the total window count used for validation.
        val_frac: f64,
    },
    /// Assign whole series to partitions, so the windows of one series never end up
    /// in more than one partition. Series are ordered by a deterministic hash of
    /// their index before the fractions are applied.
    SeriesBased {
        /// Fraction of the distinct series used for training.
        train_frac: f64,
        /// Fraction of the distinct series used for validation.
        val_frac: f64,
    },
    /// Fixed counts: the first `train_windows` windows for training, the next
    /// `val_windows` for validation, and every remaining window for testing.
    /// At least one window must be left for the test set.
    RollingWindow {
        /// Number of leading windows used for training.
        train_windows: usize,
        /// Number of windows, following the training windows, used for validation.
        val_windows: usize,
    },
}
/// The partitions produced by [`split_windowed_data`].
#[derive(Debug, Clone)]
pub struct DataSplit<T> {
    /// Windows assigned to training.
    pub train: WindowedTimeSeries<T>,
    /// Windows assigned to validation, if the strategy produced a validation set.
    /// The time- and series-based strategies leave this `None` when it would be empty.
    pub val: Option<WindowedTimeSeries<T>>,
    /// Windows held out for testing.
    pub test: WindowedTimeSeries<T>,
}
pub fn split_windowed_data<T>(
data: WindowedTimeSeries<T>,
strategy: SplitStrategy,
) -> Result<DataSplit<T>, TimeSeriesError>
where
T: Clone + Default,
{
match strategy {
SplitStrategy::TimeBased { train_frac, val_frac } => {
split_time_based(data, train_frac, val_frac)
}
SplitStrategy::SeriesBased { train_frac, val_frac } => {
split_series_based(data, train_frac, val_frac)
}
SplitStrategy::RollingWindow { train_windows, val_windows } => {
split_rolling_window(data, train_windows, val_windows)
}
}
}
/// Splits windows chronologically into contiguous train, validation, and test runs.
///
/// The training set is the first `floor(len * train_frac)` windows. The validation
/// set is the next `floor(len * val_frac)` windows, and the test set is everything
/// after that.
///
/// Both cut points are clamped to the number of windows. Fractions that sum to more
/// than `1.0` therefore shrink the later partitions instead of panicking on an
/// out-of-range slice. Negative or NaN fractions yield an empty partition, because a
/// float-to-int `as` cast saturates them to `0`.
///
/// `val` is `None` when the validation partition is empty.
///
/// # Errors
///
/// Currently never returns `Err`. The `Result` keeps the signature in line with the
/// other strategies.
fn split_time_based<T>(
    data: WindowedTimeSeries<T>,
    train_frac: f64,
    val_frac: f64,
) -> Result<DataSplit<T>, TimeSeriesError>
where
    T: Clone + Default,
{
    let total_windows = data.len();
    // `as usize` saturates (negative/NaN -> 0, +inf -> usize::MAX); `min` keeps every
    // cut point inside the window vector so the slicing below cannot panic.
    let train_end = ((total_windows as f64 * train_frac) as usize).min(total_windows);
    let val_end = train_end
        .saturating_add((total_windows as f64 * val_frac) as usize)
        .min(total_windows);

    // Copies a contiguous range of windows, plus the metadata aligned with it, into a
    // new series.
    let slice = |range: std::ops::Range<usize>| WindowedTimeSeries {
        windows: data.windows[range.clone()].to_vec(),
        labels: data.labels.as_ref().map(|l| l[range.clone()].to_vec()),
        series_indices: data.series_indices[range.clone()].to_vec(),
        window_starts: data.window_starts[range].to_vec(),
        window_size: data.window_size,
        num_features: data.num_features,
    };

    let train = slice(0..train_end);
    let val = (val_end > train_end).then(|| slice(train_end..val_end));
    let test = slice(val_end..total_windows);
    Ok(DataSplit { train, val, test })
}
/// Splits by series: every window of a given series lands in the same partition.
///
/// Distinct series indices are ordered by a 32-bit multiplicative hash of the index,
/// with ties broken by the index itself. They are then cut into
/// `floor(n * train_frac)` training series, `floor(n * val_frac)` validation series,
/// and the remaining test series, where `n` is the number of distinct series.
///
/// Cut points are clamped to `n`. Fractions that sum to more than `1.0` therefore
/// shrink the later partitions instead of panicking.
///
/// Within a partition, windows are grouped series by series, in hash order. Inside
/// each series they keep their original relative order.
///
/// `val` is `None` when no windows fall into the validation partition.
///
/// # Errors
///
/// Currently never returns `Err`. The `Result` keeps the signature in line with the
/// other strategies.
fn split_series_based<T>(
    data: WindowedTimeSeries<T>,
    train_frac: f64,
    val_frac: f64,
) -> Result<DataSplit<T>, TimeSeriesError>
where
    T: Clone + Default,
{
    /// Concatenates the window positions of `ids`, one series after another.
    fn positions_of(groups: &HashMap<usize, Vec<usize>>, ids: &[usize]) -> Vec<usize> {
        ids.iter()
            .flat_map(|id| groups[id].iter().copied())
            .collect()
    }

    // Window positions grouped by series; each group is in ascending position order.
    let mut series_groups: HashMap<usize, Vec<usize>> = HashMap::new();
    for (idx, &series_idx) in data.series_indices.iter().enumerate() {
        series_groups.entry(series_idx).or_default().push(idx);
    }

    // Deterministic pseudo-shuffle of the series ids. The hash is computed in u64 and
    // masked to its low 32 bits, so it stays well-defined on 32-bit targets, where
    // `2usize.pow(32)` would overflow. Sorting on `(hash, id)` makes the order
    // independent of `HashMap` iteration order even if two hashes collide.
    let mut series_ids: Vec<usize> = series_groups.keys().copied().collect();
    series_ids.sort_unstable_by_key(|&id| {
        ((id as u64).wrapping_mul(2_654_435_761) & 0xFFFF_FFFF, id)
    });

    let total_series = series_ids.len();
    // Saturating cast plus clamp: the slices below never go out of range.
    let train_end = ((total_series as f64 * train_frac) as usize).min(total_series);
    let val_end = train_end
        .saturating_add((total_series as f64 * val_frac) as usize)
        .min(total_series);

    let train_indices = positions_of(&series_groups, &series_ids[..train_end]);
    let val_indices = positions_of(&series_groups, &series_ids[train_end..val_end]);
    let test_indices = positions_of(&series_groups, &series_ids[val_end..]);

    let train = extract_windows(&data, &train_indices);
    let val = (!val_indices.is_empty()).then(|| extract_windows(&data, &val_indices));
    let test = extract_windows(&data, &test_indices);
    Ok(DataSplit { train, val, test })
}
/// Splits windows by fixed counts, taking windows in their stored order.
///
/// The first `train_windows` windows are used for training and the next
/// `val_windows` for validation. All remaining windows are used for testing.
///
/// Despite the strategy's name, this produces a single split; it does not generate
/// rolling folds.
///
/// `val` is `None` when `val_windows` is `0`. This matches the time- and series-based
/// strategies, which never return an empty validation set.
///
/// # Errors
///
/// Returns `TimeSeriesError::LengthMismatch` when the request would leave no windows
/// for testing. In that error:
///
/// * `timestamps` is the requested train + validation count, saturating at
///   `usize::MAX`.
/// * `values` is the number of available windows.
fn split_rolling_window<T>(
    data: WindowedTimeSeries<T>,
    train_windows: usize,
    val_windows: usize,
) -> Result<DataSplit<T>, TimeSeriesError>
where
    T: Clone + Default,
{
    let total_windows = data.len();
    // Saturate rather than overflow. An absurdly large request is then reported as an
    // error instead of panicking (debug) or wrapping to a bogus range (release).
    let val_end = train_windows.saturating_add(val_windows);
    if val_end >= total_windows {
        return Err(TimeSeriesError::LengthMismatch {
            timestamps: val_end,
            values: total_windows,
        });
    }

    let train_indices: Vec<usize> = (0..train_windows).collect();
    let test_indices: Vec<usize> = (val_end..total_windows).collect();

    let train = extract_windows(&data, &train_indices);
    let val = (val_windows > 0).then(|| {
        let val_indices: Vec<usize> = (train_windows..val_end).collect();
        extract_windows(&data, &val_indices)
    });
    let test = extract_windows(&data, &test_indices);
    Ok(DataSplit { train, val, test })
}
/// Builds a new [`WindowedTimeSeries`] holding the windows of `data` at `indices`,
/// in the order the positions are listed (duplicates are copied again).
///
/// Labels (when present), series indices, and window starts are gathered from the
/// same positions, so they stay aligned with the windows. The window size and
/// feature count are copied unchanged.
///
/// # Panics
///
/// Panics if any position is out of range for `data.windows`, `data.labels`,
/// `data.series_indices`, or `data.window_starts`.
fn extract_windows<T>(data: &WindowedTimeSeries<T>, indices: &[usize]) -> WindowedTimeSeries<T>
where
    T: Clone + Default,
{
    // Struct-literal fields are evaluated in the order written, which keeps the
    // gathering order windows -> labels -> series indices -> window starts.
    WindowedTimeSeries {
        windows: indices
            .iter()
            .map(|&pos| data.windows[pos].clone())
            .collect::<Vec<Vec<Vec<T>>>>(),
        labels: data
            .labels
            .as_ref()
            .map(|all| indices.iter().map(|&pos| all[pos].clone()).collect()),
        series_indices: indices.iter().map(|&pos| data.series_indices[pos]).collect(),
        window_starts: indices.iter().map(|&pos| data.window_starts[pos]).collect(),
        window_size: data.window_size,
        num_features: data.num_features,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TimeSeries;

    /// Windows a six-point series with `from_series(&series, 2, 1)`, splits it with
    /// `strategy`, and reports the `(train, val, test)` window counts.
    fn split_lengths(strategy: SplitStrategy) -> (usize, Option<usize>, usize) {
        let series = TimeSeries::from_raw(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let windows = WindowedTimeSeries::from_series(&series, 2, 1).unwrap();
        let split = split_windowed_data(windows, strategy).unwrap();
        (
            split.train.len(),
            split.val.as_ref().map(|val| val.len()),
            split.test.len(),
        )
    }

    #[test]
    fn test_time_based_split() {
        let lengths = split_lengths(SplitStrategy::TimeBased {
            train_frac: 0.5,
            val_frac: 0.25,
        });
        assert_eq!(lengths, (2, Some(1), 2));
    }

    #[test]
    fn test_rolling_window_split() {
        let lengths = split_lengths(SplitStrategy::RollingWindow {
            train_windows: 2,
            val_windows: 1,
        });
        assert_eq!(lengths, (2, Some(1), 2));
    }
}