opentslm 0.1.0

Rust implementation of OpenTSLM using Burn, WGPU, and llama.cpp
//! TSQA dataset loader — mirrors `TSQADataset` / `QADataset` in Python.
//!
//! **Source**: HuggingFace [`ChengsenWang/TSQA`](https://huggingface.co/datasets/ChengsenWang/TSQA),
//! served as JSONL files downloaded by the [`downloader`](crate::data::downloader).
//!
//! In the Rust pipeline the WISDM-W accelerometer dataset (which the
//! downloader uses) is converted into TSQA-format MCQ rows: the model must
//! identify which of four activity labels matches the x-axis accelerometer
//! signal shown in a 4-second window.
//!
//! # JSONL row schema
//!
//! ```json
//! { "Question": "...", "Answer": "...", "Task": "activity classification",
//!   "Series": "[0.1, -0.2, ...]" }
//! ```
//!
//! # Splits
//!
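//! 80 / 10 / 10 (train / val / test) after a deterministic shuffle seeded with
//! 42. The proportions and seed value match the Python splits, but the
//! resulting permutation differs, because Rust's `StdRng` and Python's RNG are
//! different generators.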
//!
//! # Curriculum stage
//!
//! Stage 1 — MCQ on time series.

use std::path::Path;

use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};

use crate::data::batch::{normalize, Sample};

// ── Row schema ────────────────────────────────────────────────────────────────

/// One record of the TSQA JSONL file, exactly as stored on disk.
///
/// The file uses PascalCase keys (`Question`, `Answer`, `Task`, `Series`);
/// `rename_all` maps them onto the snake_case fields below.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "PascalCase")]
struct TsqaRow {
    /// Question text; becomes the sample's pre-prompt.
    question: String,
    /// Ground-truth answer string.
    answer: String,
    /// Task name, interpolated into the post-prompt.
    task: String,
    /// The series values, stored as a JSON-encoded array inside a string.
    series: String,
}

// ── Public API ────────────────────────────────────────────────────────────────

/// The TSQA dataset partitioned into train / validation / test samples.
pub struct TsqaSplits {
    /// Training samples (roughly 80% of the rows).
    pub train: Vec<Sample>,
    /// Validation samples (roughly 10% of the rows).
    pub val: Vec<Sample>,
    /// Held-out test samples (roughly 10% of the rows).
    pub test: Vec<Sample>,
}

/// Load the TSQA dataset from a JSONL file downloaded from HuggingFace.
///
/// Expected file: `<data_dir>/tsqa/train.jsonl`
/// (Download with `huggingface-cli download ChengsenWang/TSQA --local-dir <data_dir>/tsqa`)
pub fn load_tsqa_splits(data_dir: &Path) -> Result<TsqaSplits> {
    let file = data_dir.join("tsqa").join("train.jsonl");
    let text = std::fs::read_to_string(&file)
        .with_context(|| format!("Cannot read TSQA dataset at {file:?}"))?;

    let rows: Vec<TsqaRow> = text
        .lines()
        .filter(|l| !l.trim().is_empty())
        .enumerate()
        .map(|(i, line)| {
            serde_json::from_str(line)
                .with_context(|| format!("TSQA line {i}: parse error"))
        })
        .collect::<Result<_>>()?;

    let samples: Vec<Sample> = rows.iter().map(row_to_sample).collect::<Result<Vec<_>>>()?.into_iter().collect();

    // 80 / 10 / 10 split (seeded deterministically)
    split_80_10_10(samples)
}

/// Convert one parsed TSQA row into a training [`Sample`].
///
/// The `Series` string is decoded as a JSON array of `f32` and passed through
/// [`normalize`]. The mean / std it returns are embedded in the text that
/// describes the series. The trimmed answer is also used as the label.
///
/// # Errors
///
/// Fails if `row.series` is not a JSON array of numbers.
fn row_to_sample(row: &TsqaRow) -> Result<Sample> {
    let raw: Vec<f32> =
        serde_json::from_str(&row.series).context("Failed to parse TSQA series field")?;

    let (normalized, mean, std) = normalize(&raw);
    let answer = row.answer.trim().to_string();

    Ok(Sample {
        pre_prompt: row.question.trim().to_string(),
        time_series_text: vec![format!(
            "This is the time series, it has mean {mean:.4} and std {std:.4}."
        )],
        time_series: vec![normalized],
        post_prompt: format!("Predict the {} Answer:", row.task),
        // Cloned before `answer` is moved into its own field just below.
        label: Some(answer.clone()),
        answer,
    })
}

/// Shuffle `samples` deterministically (seed 42) and split them into
/// train / val / test partitions of roughly 80 / 10 / 10 percent.
///
/// Val and test each receive `round(n * 0.10)` samples. Train receives the
/// remainder, so no sample is dropped and the three parts always sum to `n`.
///
/// Note: the permutation comes from `rand`'s `StdRng`. It is repeatable for a
/// given `rand` version, but it is NOT the permutation Python's
/// `random_state=42` produces. Only the proportions and the seed value match.
fn split_80_10_10(samples: Vec<Sample>) -> Result<TsqaSplits> {
    const SEED: u64 = 42;
    const HOLDOUT_FRACTION: f64 = 0.10;

    let n = samples.len();
    let holdout_n = (n as f64 * HOLDOUT_FRACTION).round() as usize;
    let test_n = holdout_n;
    let val_n = holdout_n;
    // Cannot underflow: 2 * round(0.1 * n) <= n for every n >= 0.
    let train_n = n - test_n - val_n;

    let mut it = deterministic_shuffle(samples, SEED).into_iter();
    let train: Vec<_> = it.by_ref().take(train_n).collect();
    let val: Vec<_> = it.by_ref().take(val_n).collect();
    // Whatever is left (exactly `test_n` items) forms the test split.
    let test: Vec<_> = it.collect();

    Ok(TsqaSplits { train, val, test })
}

fn deterministic_shuffle(mut v: Vec<Sample>, seed: u64) -> Vec<Sample> {
    use rand::{seq::SliceRandom, SeedableRng};
    use rand::rngs::StdRng;
    let mut rng = StdRng::seed_from_u64(seed);
    v.shuffle(&mut rng);
    v
}