use std::path::Path;
use anyhow::{Context, Result};
use serde::Deserialize;
use crate::data::batch::{normalize, Sample};
/// One record of the M4 JSONL file; each non-blank line of the file is
/// deserialized into one `M4Row`.
#[derive(Debug, Deserialize)]
struct M4Row {
    /// Raw series values; they are normalized later, when the row is turned
    /// into a `Sample`.
    series: Vec<f32>,
    /// Reference caption for the series; trimmed and used as the target answer.
    caption: String,
    /// Optional free-text context about the series. `#[serde(default)]` makes it
    /// an empty string when the key is missing from the JSON object.
    #[serde(default)]
    info: String,
}
/// The M4 samples partitioned into disjoint train / validation / test sets.
///
/// When built by `load_m4_splits`, the partition is roughly 80% / 10% / 10%
/// after a fixed-seed shuffle.
pub struct M4Splits {
    /// Training samples (the remainder after validation and test are carved out).
    pub train: Vec<Sample>,
    /// Validation samples.
    pub val: Vec<Sample>,
    /// Held-out test samples.
    pub test: Vec<Sample>,
}
pub fn load_m4_splits(data_dir: &Path) -> Result<M4Splits> {
let file = data_dir.join("m4").join("train_samples.jsonl");
let text = std::fs::read_to_string(&file)
.with_context(|| format!("Cannot read M4 dataset at {file:?}"))?;
let rows: Vec<M4Row> = text
.lines()
.filter(|l| !l.trim().is_empty())
.enumerate()
.map(|(i, line)| {
serde_json::from_str(line)
.with_context(|| format!("M4 line {i}: parse error"))
})
.collect::<Result<_>>()?;
let samples: Vec<Sample> = rows
.iter()
.map(row_to_sample)
.collect::<Result<Vec<_>>>()?;
split_80_10_10(samples)
}
/// Builds a captioning [`Sample`] from one parsed M4 row.
///
/// - The series goes through `normalize`. The returned mean and standard deviation are
///   written, to four decimals, into the single time-series description string.
/// - A non-empty `info` string is appended in parentheses to the instruction prompt.
/// - The trimmed caption becomes the target answer, and no label is attached.
///
/// This function currently never returns `Err`.
fn row_to_sample(row: &M4Row) -> Result<Sample> {
    let (normalized, mean, std_dev) = normalize(&row.series);

    let info_suffix = match row.info.as_str() {
        "" => String::new(),
        info => format!(" ({info})"),
    };
    let pre_prompt = format!(
        "You are given the following univariate time series{info_suffix}. \
         Describe its key characteristics, trends, and patterns in a short paragraph."
    );
    let description = format!("Time series data, mean {mean:.4}, std {std_dev:.4}:");

    Ok(Sample {
        pre_prompt,
        time_series_text: vec![description],
        time_series: vec![normalized],
        post_prompt: String::from("Caption:"),
        answer: row.caption.trim().to_owned(),
        label: None,
    })
}
/// Shuffles `samples` with a fixed seed (42) and partitions them into train / val / test.
///
/// Validation and test each receive `round(0.10 * n)` samples. Train keeps the
/// remainder, so it absorbs any rounding slack. After the shuffle, train is the leading
/// slice, validation the next slice and test the tail.
///
/// The fixed seed makes the split repeatable for the same input order. `StdRng` output
/// is not guaranteed to be stable across `rand` versions, so the split may change if
/// that dependency is upgraded.
///
/// This function currently never returns `Err`.
fn split_80_10_10(samples: Vec<Sample>) -> Result<M4Splits> {
    use rand::rngs::StdRng;
    use rand::{seq::SliceRandom, SeedableRng};

    let total = samples.len();
    let tenth = |count: usize| (count as f64 * 0.10).round() as usize;
    let (val_len, test_len) = (tenth(total), tenth(total));

    let mut train = samples;
    train.shuffle(&mut StdRng::seed_from_u64(42));

    // Peel the tail off twice: first everything past the training prefix, then split
    // that tail into validation (front) and test (back).
    let mut val = train.split_off(total - test_len - val_len);
    let test = val.split_off(val_len);

    Ok(M4Splits { train, val, test })
}