// opentslm 0.1.0
//
// Rust implementation of OpenTSLM using Burn, WGPU, and llama.cpp
//! SleepEDF chain-of-thought dataset loader — mirrors `SleepEDFCoTQADataset`.
//!
//! Sleep stage classification from 30-second Fpz-Cz EEG segments at 100 Hz
//! (3 000 samples per window, truncated to [`MAX_SERIES_LEN`] = 512 samples
//! before encoding to reduce attention memory by ~34×).
//!
//! Source: [`wbxlala/sleep_edf_str`](https://huggingface.co/datasets/wbxlala/sleep_edf_str)
//! on HuggingFace.
//!
//! # JSONL row schema
//!
//! ```json
//! { "time_series": [...], "label": "Non-REM stage 2",
//!   "rationale": "The EEG shows ... Answer: Non-REM stage 2" }
//! ```
//!
//! # Curriculum stage
//!
//! Stage 4 — sleep-stage chain-of-thought reasoning.
//!
//! [`MAX_SERIES_LEN`]: crate::config::MAX_SERIES_LEN

use std::path::Path;
use anyhow::{Context, Result};
use serde::Deserialize;

use crate::{config::MAX_SERIES_LEN, data::batch::{normalize, Sample}};

/// Ordered list of sleep-stage class labels used in the CoT answers.
///
/// The entries are joined with `", "` (in this order) and embedded in the
/// post-prompt so the model sees the full set of candidate classes.
const SLEEP_STAGES: &[&str] = &[
    "Wake",
    "Non-REM stage 1",
    "Non-REM stage 2",
    "Non-REM stage 3",
    "REM sleep",
    "Movement",
];

/// Deserialised row from a SleepEDF JSONL file.
///
/// One JSON object per line; the field names match the JSON keys exactly.
#[derive(Debug, Deserialize)]
struct SleepRow {
    /// Signal values of the window; truncated and normalised when the row is
    /// converted into a sample.
    time_series: Vec<f32>,
    /// Class label of the window; trimmed and stored as the sample's label.
    label: String,
    /// Reasoning text for the window; trimmed and used as the sample's answer.
    rationale: String,
}

/// Train / val / test splits for the SleepEDF CoT dataset.
pub struct SleepSplits {
    /// Samples loaded from `train.jsonl`.
    pub train: Vec<Sample>,
    /// Samples loaded from `val.jsonl`.
    pub val:   Vec<Sample>,
    /// Samples loaded from `test.jsonl`.
    pub test:  Vec<Sample>,
}

/// Read the three SleepEDF CoT partitions stored under `<data_dir>/sleep_cot/`.
///
/// The files `train.jsonl`, `val.jsonl` and `test.jsonl` are loaded in that
/// order. Every EEG window is cut down to at most
/// [`crate::config::MAX_SERIES_LEN`] samples and normalised before it is
/// wrapped in a [`crate::data::batch::Sample`].
///
/// # Errors
///
/// Returns the first error hit while loading a split; later splits are not
/// attempted once one has failed.
pub fn load_sleep_splits(data_dir: &Path) -> Result<SleepSplits> {
    let train = load_split(data_dir, "train.jsonl")?;
    let val = load_split(data_dir, "val.jsonl")?;
    let test = load_split(data_dir, "test.jsonl")?;

    Ok(SleepSplits { train, val, test })
}

/// Read one JSONL split file and convert every non-blank line into a [`Sample`].
///
/// The file is looked up at `<base>/sleep_cot/<filename>`. Lines that are
/// empty or whitespace-only are skipped.
///
/// # Errors
///
/// Fails if the file cannot be read, if a line does not deserialise into a
/// [`SleepRow`], or if converting a row fails. Parse errors report the
/// 1-based line number of the offending line within the file (blank lines
/// included in the count).
fn load_split(base: &Path, filename: &str) -> Result<Vec<Sample>> {
    let file = base.join("sleep_cot").join(filename);
    let text = std::fs::read_to_string(&file)
        .with_context(|| format!("Cannot read SleepEDF file {file:?}"))?;

    text.lines()
        // Number the lines *before* dropping blank ones so the position in the
        // error message matches the real line in the file.
        .enumerate()
        .filter(|(_, line)| !line.trim().is_empty())
        .map(|(idx, line)| {
            // `enumerate` is zero-based; editors and humans count from one.
            let line_no = idx + 1;
            let row: SleepRow = serde_json::from_str(line)
                .with_context(|| format!("{filename} line {line_no}: parse error"))?;
            row_to_sample(&row)
        })
        .collect()
}

/// Turn one parsed [`SleepRow`] into a prompt-ready [`Sample`].
///
/// The series is cut to its first [`MAX_SERIES_LEN`] values (shorter series
/// are kept whole), normalised, and paired with the fixed instruction prompt,
/// a one-line description carrying the normalisation statistics, and the
/// list of candidate sleep stages. The trimmed rationale becomes the target
/// answer and the trimmed label is kept alongside it.
fn row_to_sample(row: &SleepRow) -> Result<Sample> {
    // Instruction block shown before the time series.
    const PRE_PROMPT: &str = "\
You are given a 30-second EEG time series segment. Your task is to classify the sleep stage \
based on analysis of the data.

Instructions:
- Analyze the data objectively without presuming a particular label.
- Reason carefully and methodically about what the signal patterns suggest regarding sleep stage.
- Write your reasoning as a single, coherent paragraph. Do not use bullet points, lists, or section headers.
- Only reveal the correct class at the very end.
- Never state that you are uncertain or unable to classify the data. You must always provide a \
rationale and a final answer.";

    // Long EEG windows are shortened before normalisation.  At 100 Hz, 512
    // samples cover 5.12 s — enough for sleep spindles (12–15 Hz) and
    // K-complexes — and keep the patch count at 128 instead of 750 for a full
    // 3000-sample window, cutting attention memory by roughly 34×.
    let keep = row.time_series.len().min(MAX_SERIES_LEN);
    let window = &row.time_series[..keep];
    let (series, center, spread) = normalize(window);

    let description = format!(
        "The following is the EEG time series, it has mean {center:.4} and std {spread:.4}:"
    );

    let post_prompt = format!(
        "Possible sleep stages are:\n{}\n\n\
         Please now write your rationale. Make sure that your last word is the answer. \
         You MUST end your response with \"Answer: \"",
        SLEEP_STAGES.join(", "),
    );

    let sample = Sample {
        pre_prompt: String::from(PRE_PROMPT),
        time_series_text: vec![description],
        time_series: vec![series],
        post_prompt,
        answer: row.rationale.trim().to_string(),
        label: Some(row.label.trim().to_string()),
    };
    Ok(sample)
}