opentslm 0.1.0

Rust implementation of OpenTSLM using Burn, WGPU, and llama.cpp
//! ECG-QA chain-of-thought dataset loader — mirrors `ECGQACoTQADataset`.
//!
//! 12-lead ECG question answering with chain-of-thought reasoning.
//! In the Rust pipeline, synthetic ECG waveforms are used (generated by the
//! downloader) because PhysioNet/PTB-XL is not available through the
//! HuggingFace Datasets Parquet API.
//!
//! Each record has up to 12 leads; the Rust loader emits the leads it finds
//! in canonical order (`I II III aVR aVL aVF V1–V6`), truncating each one to
//! [`MAX_SERIES_LEN`] samples before normalising it.
//!
//! # JSONL row schema
//!
//! ```json
//! {
//!   "question": "Does this ECG show normal sinus rhythm?",
//!   "rationale": "The ECG shows regular P waves ... Answer: normal sinus rhythm",
//!   "clinical_context": "Routine cardiac screening.",
//!   "template_id": 0,
//!   "question_type": "rhythm classification",
//!   "leads": { "I": [...], "II": [...] }
//! }
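//! ```
//!
//! `question`, `rationale`, and `leads` are required; `clinical_context`
//! defaults to an empty string, and `template_id` / `question_type` are
//! optional.
//!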
//! # Curriculum stage
//!
//! Stage 5 — ECG QA chain-of-thought reasoning.

use std::path::Path;
use anyhow::{Context, Result};
use serde::Deserialize;
use std::collections::HashMap;

use crate::{config::MAX_SERIES_LEN, data::batch::{normalize, Sample}};

/// Canonical 12-lead order used when building per-lead encoder inputs.
/// Leads present in the JSON object are emitted in this order; missing leads
/// are silently skipped.
/// Canonical 12-lead ECG ordering used when building per-lead encoder inputs.
///
/// Rows are scanned in this order, so the emitted series always follow it;
/// a lead name absent from a row is simply skipped, without error.
const LEAD_ORDER: &[&str] = &[
    // Limb leads.
    "I", "II", "III", "aVR", "aVL", "aVF",
    // Precordial (chest) leads.
    "V1", "V2", "V3", "V4", "V5", "V6",
];

/// Deserialised row from an ECG-QA JSONL file (schema in the module docs).
#[derive(Debug, Deserialize)]
struct EcgRow {
    /// Question about the ECG; embedded verbatim in the pre-prompt.
    question: String,
    /// Chain-of-thought rationale ending in the answer; its trimmed form is
    /// used as the training target.
    rationale: String,
    /// Free-text clinical context; an empty string when the key is absent.
    #[serde(default)]
    clinical_context: String,
    /// Optional template identifier. Parsed for schema fidelity but not read
    /// by this loader yet, hence the lint allowance (the derived `Debug` impl
    /// does not count as a read for dead-code analysis).
    #[allow(dead_code)]
    #[serde(default)]
    template_id: Option<u32>,
    /// Optional question category (e.g. `"rhythm classification"`). Parsed
    /// but not read by this loader yet; see `template_id`.
    #[allow(dead_code)]
    #[serde(default)]
    question_type: Option<String>,
    /// Map from lead name (`"I"`, `"II"`, etc.) to raw float samples.
    leads: HashMap<String, Vec<f32>>,
}

/// ECG-QA CoT samples partitioned into the three standard splits.
///
/// Produced by [`load_ecg_splits`]; each field holds the samples of the
/// matching `*.jsonl` file, in file order.
pub struct EcgSplits {
    /// Samples read from `train.jsonl`.
    pub train: Vec<Sample>,
    /// Samples read from `val.jsonl`.
    pub val: Vec<Sample>,
    /// Samples read from `test.jsonl`.
    pub test: Vec<Sample>,
}

/// Load all three ECG-QA CoT splits from `<data_dir>/ecg_qa_cot/`.
///
/// The files `train.jsonl`, `val.jsonl` and `test.jsonl` are read in that
/// order. Each lead of every row is truncated to
/// [`crate::config::MAX_SERIES_LEN`] samples and normalised.
///
/// # Errors
///
/// Fails on the first split that cannot be read or that contains a row which
/// does not parse or has no lead data at all.
pub fn load_ecg_splits(data_dir: &Path) -> Result<EcgSplits> {
    let load = |name: &str| load_split(data_dir, name);

    let train = load("train.jsonl")?;
    let val = load("val.jsonl")?;
    let test = load("test.jsonl")?;

    Ok(EcgSplits { train, val, test })
}

/// Read one JSONL split file, `<base>/ecg_qa_cot/<filename>`, into samples.
///
/// Whitespace-only lines are skipped. Every other line must deserialise into
/// an [`EcgRow`] and convert into a [`Sample`]; the first failing line aborts
/// the whole split.
///
/// # Errors
///
/// Returns an error if the file cannot be read, if a line does not match the
/// row schema, or if a row cannot be converted (e.g. it has no lead data).
/// Row-level errors name the file and the 1-based physical line number.
fn load_split(base: &Path, filename: &str) -> Result<Vec<Sample>> {
    let file = base.join("ecg_qa_cot").join(filename);
    let text = std::fs::read_to_string(&file)
        .with_context(|| format!("Cannot read ECG-QA file {file:?}"))?;

    text.lines()
        // Enumerate *before* filtering so the reported number is the real
        // line in the file even when blank lines are present.
        .enumerate()
        .filter(|(_, line)| !line.trim().is_empty())
        .map(|(idx, line)| {
            let lineno = idx + 1;
            let row: EcgRow = serde_json::from_str(line)
                .with_context(|| format!("{filename} line {lineno}: parse error"))?;
            // Without this context a conversion failure would not say which
            // file or row it came from.
            row_to_sample(&row)
                .with_context(|| format!("{filename} line {lineno}: invalid ECG row"))
        })
        .collect()
}

/// Convert one parsed [`EcgRow`] into a training [`Sample`].
///
/// Leads are visited in [`LEAD_ORDER`]. Each lead that is present and
/// non-empty is truncated to [`MAX_SERIES_LEN`] samples and passed through
/// [`normalize`]. The normalised series is paired with a caption reporting
/// the mean and std that `normalize` returns. The clinical context and
/// question go into the pre-prompt, and the trimmed rationale becomes the
/// target answer.
///
/// # Errors
///
/// Returns an error when no canonical lead is usable (every one is missing
/// or empty). The message lists the keys that *were* present, which usually
/// exposes a lead-naming mismatch (e.g. `"AVR"` instead of `"aVR"`).
fn row_to_sample(row: &EcgRow) -> Result<Sample> {
    let mut ts_texts = Vec::with_capacity(LEAD_ORDER.len());
    let mut ts_data = Vec::with_capacity(LEAD_ORDER.len());

    // Emit leads in canonical 12-lead order; skip any missing leads.
    // Truncate each lead to MAX_SERIES_LEN samples so that batches with
    // many leads stay within memory budget.
    for &lead_name in LEAD_ORDER {
        // A lead that is present but empty carries no signal, and the
        // statistics of an empty series are undefined, so treat it exactly
        // like a missing lead rather than handing `normalize` an empty slice.
        if let Some(raw) = row.leads.get(lead_name).filter(|s| !s.is_empty()) {
            let raw = &raw[..raw.len().min(MAX_SERIES_LEN)];
            let (normed, mean, std) = normalize(raw);
            ts_texts.push(format!(
                "This is ECG Lead {lead_name}, it has mean {mean:.4} and std {std:.4}:"
            ));
            ts_data.push(normed);
        }
    }

    if ts_data.is_empty() {
        // Sorted so the message is deterministic (HashMap iteration is not).
        let mut found: Vec<&str> = row.leads.keys().map(String::as_str).collect();
        found.sort_unstable();
        anyhow::bail!(
            "ECG row has no lead data (expected non-empty leads from {:?}, found keys {:?})",
            LEAD_ORDER,
            found
        );
    }

    let pre_prompt = format!(
        "You are an expert cardiologist analyzing an ECG (electrocardiogram).

Clinical Context: {}

Your task is to examine the ECG signal and answer the following medical question:

Question: {}

Instructions:
- Begin by analyzing the time series without assuming a specific answer.
- Think step-by-step about what the observed patterns suggest regarding the cardiac condition.
- Write your rationale as a single, natural paragraph — do not use bullet points.
- Do **not** mention any final answer until the very end.",
        row.clinical_context, row.question
    );

    let post_prompt = "\
Based on your analysis of the ECG data, provide your answer.
Make sure that your last word is the answer. You MUST end your response with \"Answer: \""
        .to_string();

    Ok(Sample {
        pre_prompt,
        time_series_text: ts_texts,
        time_series: ts_data,
        post_prompt,
        answer: row.rationale.trim().to_string(),
        label: None,
    })
}