opentslm 0.1.0

Rust implementation of OpenTSLM using Burn, WGPU, and llama.cpp
//! M4 captioning dataset loader — mirrors `M4QADataset` in Python.
//!
//! In the Rust pipeline the WISDM-W wrist accelerometer dataset is converted
//! into M4-style captioning samples: the model must generate a short
//! natural-language description of the provided time series.
//!
//! # JSONL row schema
//!
//! ```json
//! { "series": [0.1, -0.2, ...], "caption": "This 4-second segment ...",
//!   "info": "wrist accelerometer x-axis, WISDM-W, 50 Hz, 4 s" }
//! ```
//!
//! # Splits
//!
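//! Samples are shuffled with a fixed seed (42), then split 80 / 10 / 10
//! (train / val / test): val and test each receive `round(0.1 * n)` samples
//! and train receives the remainder.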
//!
//! # Curriculum stage
//!
//! Stage 2 — time-series captioning.

use std::path::Path;
use anyhow::{Context, Result};
use serde::Deserialize;

use crate::data::batch::{normalize, Sample};

/// Deserialised row from the M4 JSONL file.
///
/// Field names double as the JSON keys (see the module-level schema), so
/// renaming a field changes the accepted input format.
#[derive(Debug, Deserialize)]
struct M4Row {
    /// Raw time-series values as stored in the file; normalised later when the
    /// row is converted into a sample.
    series: Vec<f32>,
    /// Target caption; leading/trailing whitespace is trimmed when the sample
    /// is built.
    caption: String,
    /// Optional free-form description of the series. Defaults to `""` when the
    /// key is absent from the JSON row.
    #[serde(default)]
    info: String,
}

/// Train / val / test splits for the M4 captioning dataset.
///
/// Produced by [`load_m4_splits`]. Samples are shuffled with a fixed seed
/// before splitting, so the partition is reproducible for a given `rand`
/// version.
pub struct M4Splits {
    /// Training samples: everything left after val and test are carved out
    /// (roughly 80%).
    pub train: Vec<Sample>,
    /// Validation samples (`round(0.1 * n)` of them).
    pub val:   Vec<Sample>,
    /// Test samples (`round(0.1 * n)` of them).
    pub test:  Vec<Sample>,
}

/// Load M4 captioning dataset from `<data_dir>/m4/train_samples.jsonl`.
pub fn load_m4_splits(data_dir: &Path) -> Result<M4Splits> {
    let file = data_dir.join("m4").join("train_samples.jsonl");
    let text = std::fs::read_to_string(&file)
        .with_context(|| format!("Cannot read M4 dataset at {file:?}"))?;

    let rows: Vec<M4Row> = text
        .lines()
        .filter(|l| !l.trim().is_empty())
        .enumerate()
        .map(|(i, line)| {
            serde_json::from_str(line)
                .with_context(|| format!("M4 line {i}: parse error"))
        })
        .collect::<Result<_>>()?;

    let samples: Vec<Sample> = rows
        .iter()
        .map(row_to_sample)
        .collect::<Result<Vec<_>>>()?;

    split_80_10_10(samples)
}

/// Convert one parsed [`M4Row`] into a captioning [`Sample`].
///
/// The series is passed through [`normalize`]. The mean and standard deviation
/// it returns are written into the time-series text, formatted to 4 decimals.
/// When the trimmed `info` string is non-empty, it is appended to the prompt
/// in parentheses. The trimmed caption becomes the target answer.
///
/// # Errors
///
/// Currently never fails; the `Result` return lets callers propagate future
/// validation errors with `?`.
fn row_to_sample(row: &M4Row) -> Result<Sample> {
    let (series_norm, mean, std_dev) = normalize(&row.series);

    // Trim `info` just like the caption, so a whitespace-only value does not
    // leak an empty "(   )" into the prompt.
    let info = row.info.trim();
    let info_suffix = if info.is_empty() {
        String::new()
    } else {
        format!(" ({info})")
    };

    let pre_prompt = format!(
        "You are given the following univariate time series{info_suffix}. \
         Describe its key characteristics, trends, and patterns in a short paragraph."
    );

    let ts_text = format!("Time series data, mean {mean:.4}, std {std_dev:.4}:");

    Ok(Sample {
        pre_prompt,
        time_series_text: vec![ts_text],
        time_series: vec![series_norm],
        post_prompt: "Caption:".to_string(),
        answer: row.caption.trim().to_string(),
        label: None,
    })
}

/// Shuffle `samples` deterministically (seed 42) and partition them into
/// train / val / test.
///
/// Val and test each take `round(0.1 * n)` samples; train keeps the rest.
/// After the shuffle, train is the leading slice, val the middle one and test
/// the tail.
fn split_80_10_10(mut samples: Vec<Sample>) -> Result<M4Splits> {
    use rand::rngs::StdRng;
    use rand::seq::SliceRandom;
    use rand::SeedableRng;

    let total = samples.len();
    samples.shuffle(&mut StdRng::seed_from_u64(42));

    let ten_percent = |count: usize| (count as f64 * 0.10).round() as usize;
    let n_test = ten_percent(total);
    let n_val = ten_percent(total);
    let n_train = total - n_test - n_val;

    // Peel the tail off first (test), then the middle (val); what remains is train.
    let test = samples.split_off(n_train + n_val);
    let val = samples.split_off(n_train);

    Ok(M4Splits {
        train: samples,
        val,
        test,
    })
}