use crate::config::PATCH_SIZE;
/// A single dataset example combining free-form text with numeric time series.
///
/// NOTE(review): the per-field descriptions below are partly inferred from field
/// names — confirm against the code that builds `Sample`s.
#[derive(Debug, Clone)]
pub struct Sample {
    /// Text that presumably precedes the time series in the prompt.
    pub pre_prompt: String,
    /// Text associated with the series; presumably one entry per element of
    /// `time_series` — TODO confirm.
    pub time_series_text: Vec<String>,
    /// Raw series values. [`TokenizedBatch::from_samples`] right-pads these in place
    /// with `0.0` to a common length that is a multiple of the patch size.
    pub time_series: Vec<Vec<f32>>,
    /// Text that presumably follows the time series in the prompt.
    pub post_prompt: String,
    /// Answer text (presumably the target for this example).
    pub answer: String,
    /// Optional label; `None` when the example has none.
    pub label: Option<String>,
}
/// A batch of [`Sample`]s intended to have their `time_series` padded for patching.
///
/// [`TokenizedBatch::from_samples`] and [`collate`] build it with that padding applied.
/// Because `samples` is public, a batch built directly may be unpadded.
/// NOTE(review): despite the name, no text tokenization happens in this file —
/// confirm whether it is done elsewhere.
#[derive(Debug, Clone)]
pub struct TokenizedBatch {
    /// The batch's samples. When built via `from_samples`, each sample's series are
    /// right-padded with `0.0` to a patch-aligned length.
    pub samples: Vec<Sample>,
}
impl TokenizedBatch {
pub fn from_samples(mut samples: Vec<Sample>, patch_size: usize) -> Self {
for s in &mut samples {
extend_to_patch_multiple(&mut s.time_series, patch_size);
}
Self { samples }
}
}
/// Right-pads every series in `series_list` with `0.0` so that all share one length.
///
/// The common length is the length of the longest series, rounded up to the next
/// multiple of `patch_size`. No series is ever shortened. An empty `series_list` is
/// left untouched. If every series is empty, the common length is `0` and nothing
/// changes.
///
/// # Panics
///
/// Panics if `patch_size` is zero and `series_list` is non-empty.
pub fn extend_to_patch_multiple(series_list: &mut Vec<Vec<f32>>, patch_size: usize) {
    if series_list.is_empty() {
        return;
    }
    // Fail with a clear message instead of an arithmetic overflow / divide-by-zero.
    assert!(
        patch_size > 0,
        "extend_to_patch_multiple: patch_size must be non-zero"
    );

    let max_len = series_list.iter().map(Vec::len).max().unwrap_or(0);
    // Round up to a multiple of `patch_size` via the remainder. This avoids the
    // `max_len + patch_size - 1` intermediate sum.
    let remainder = max_len % patch_size;
    let padded_len = if remainder == 0 {
        max_len
    } else {
        max_len + (patch_size - remainder)
    };

    // `padded_len >= max_len >= series.len()`, so `resize` only ever appends zeros.
    for series in series_list.iter_mut() {
        series.resize(padded_len, 0.0);
    }
}
/// Collates `samples` into a [`TokenizedBatch`].
///
/// Each sample's time series are padded to a multiple of the crate-wide
/// [`PATCH_SIZE`] (see [`TokenizedBatch::from_samples`]).
pub fn collate(samples: Vec<Sample>) -> TokenizedBatch {
    let patch_size = PATCH_SIZE;
    TokenizedBatch::from_samples(samples, patch_size)
}
/// Standardizes `v` to zero mean and unit population standard deviation.
///
/// Returns `(normalized, mean, std)`. The original values can be recovered as
/// `x * std + mean`. The standard deviation is floored at `1e-6`, so a constant
/// series maps to all zeros instead of dividing by zero.
///
/// An empty slice yields `(vec![], 0.0, 1e-6)`, so the returned statistics are
/// always finite for finite input.
pub fn normalize(v: &[f32]) -> (Vec<f32>, f32, f32) {
    // Lower bound on the returned std; guards the division below.
    const MIN_STD: f32 = 1e-6;

    // Without this guard the mean would be `0.0 / 0.0 == NaN`.
    if v.is_empty() {
        return (Vec::new(), 0.0, MIN_STD);
    }

    let n = v.len() as f32;
    let mean = v.iter().sum::<f32>() / n;
    let var = v.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
    let std = var.sqrt().max(MIN_STD);
    let normed = v.iter().map(|x| (x - mean) / std).collect();
    (normed, mean, std)
}