use std::path::Path;
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use crate::data::batch::{normalize, Sample};
/// One raw record of the TSQA JSONL file, exactly as it is deserialized from a line.
///
/// The JSON keys are the PascalCase spellings of the field names:
/// `Question`, `Answer`, `Task` and `Series`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "PascalCase")]
struct TsqaRow {
    /// Question text; becomes the sample's pre-prompt after trimming.
    question: String,
    /// Reference answer; becomes both the sample's answer and its label after trimming.
    answer: String,
    /// Task name, interpolated into the sample's post-prompt.
    task: String,
    /// The time series, stored as a JSON-encoded string that is decoded to `Vec<f32>` later.
    series: String,
}
/// The three partitions produced by `load_tsqa_splits`.
///
/// The partitions are carved out of one shuffled sample list by `split_80_10_10`:
/// `val` and `test` each hold about 10% of the samples and `train` holds the rest.
pub struct TsqaSplits {
/// Training partition (the samples left after `val` and `test` are sized).
pub train: Vec<Sample>,
/// Validation partition (10% of the samples, rounded to the nearest integer).
pub val: Vec<Sample>,
/// Test partition (10% of the samples, rounded to the nearest integer).
pub test: Vec<Sample>,
}
pub fn load_tsqa_splits(data_dir: &Path) -> Result<TsqaSplits> {
let file = data_dir.join("tsqa").join("train.jsonl");
let text = std::fs::read_to_string(&file)
.with_context(|| format!("Cannot read TSQA dataset at {file:?}"))?;
let rows: Vec<TsqaRow> = text
.lines()
.filter(|l| !l.trim().is_empty())
.enumerate()
.map(|(i, line)| {
serde_json::from_str(line)
.with_context(|| format!("TSQA line {i}: parse error"))
})
.collect::<Result<_>>()?;
let samples: Vec<Sample> = rows.iter().map(row_to_sample).collect::<Result<Vec<_>>>()?.into_iter().collect();
split_80_10_10(samples)
}
/// Converts one deserialized TSQA row into a [`Sample`].
///
/// The row's `series` string is decoded as a JSON array of `f32` and passed through
/// `normalize`; the first element of the returned tuple becomes the sample's single time
/// series, and the two values bound as `mean` and `std` are rendered (four decimals) into the
/// accompanying descriptive text. The question and answer are trimmed of surrounding
/// whitespace; the trimmed answer is stored both as `answer` and as `label`.
///
/// # Errors
///
/// Returns an error if `row.series` does not deserialize into a `Vec<f32>`.
fn row_to_sample(row: &TsqaRow) -> Result<Sample> {
    let raw_values: Vec<f32> =
        serde_json::from_str(&row.series).context("Failed to parse TSQA series field")?;
    let (normalized, mean, std) = normalize(&raw_values);

    // Trim once; the same text serves as both the answer and the label.
    let answer = row.answer.trim().to_string();

    Ok(Sample {
        pre_prompt: row.question.trim().to_string(),
        time_series_text: vec![format!(
            "This is the time series, it has mean {mean:.4} and std {std:.4}."
        )],
        time_series: vec![normalized],
        post_prompt: format!("Predict the {} Answer:", row.task),
        label: Some(answer.clone()),
        answer,
    })
}
/// Shuffles `samples` with a fixed seed and partitions them into train/val/test.
///
/// Validation and test each receive `round(0.10 * n)` samples, where `n` is the input length;
/// training receives whatever remains. In the shuffled order the training samples come first,
/// then validation, then test.
///
/// This function only ever returns `Ok`; the `Result` wrapper is part of its signature.
fn split_80_10_10(samples: Vec<Sample>) -> Result<TsqaSplits> {
    let total = samples.len();
    // Validation and test are sized identically: 10% of the input, rounded to nearest.
    let holdout_n = (total as f64 * 0.10).round() as usize;
    let train_n = total - holdout_n - holdout_n;

    // The seed is fixed at 42 so the same input always yields the same partition.
    let mut train = deterministic_shuffle(samples, 42);
    // Peel the tail off twice: everything past `train_n` is val + test, and everything
    // past the first `holdout_n` of that tail is test.
    let mut val = train.split_off(train_n);
    let test = val.split_off(holdout_n);

    Ok(TsqaSplits { train, val, test })
}
/// Returns `v` with its elements shuffled by an `StdRng` seeded from `seed`.
///
/// The generator is created fresh from `seed` on every call, so the permutation depends only
/// on the seed and the input length.
///
/// NOTE(review): the algorithm behind `StdRng` may differ between `rand` releases — confirm
/// against the `rand` docs if the split must stay identical across dependency upgrades.
fn deterministic_shuffle(v: Vec<Sample>, seed: u64) -> Vec<Sample> {
    use rand::{rngs::StdRng, seq::SliceRandom, SeedableRng};

    let mut shuffled = v;
    shuffled.shuffle(&mut StdRng::seed_from_u64(seed));
    shuffled
}