use std::{
collections::HashMap,
fmt::Write as _,
fs,
path::{Path, PathBuf},
};
use anyhow::{bail, Context, Result};
use hf_hub::{
api::sync::{Api, ApiRepo},
Repo, RepoType,
};
use parquet::{
file::reader::{FileReader, SerializedFileReader},
record::reader::RowIter,
};
use rand::{rngs::StdRng, seq::{IndexedRandom, SliceRandom}, Rng, SeedableRng};
use serde_json::{json, Value};
use tracing::info;
/// Options for one dataset-preparation run; consumed by `run`.
pub struct DownloadConfig {
/// Root directory for every generated file. `run` creates it if it is missing.
pub out_dir: PathBuf,
/// Optional row cap forwarded to each dataset step. For the synthetic ECG step it is
/// the total number of records generated (9 000 when `None`).
pub limit: Option<usize>,
/// Restricts the run to a single dataset step: `"har"`, `"sleep"` or `"ecg"`.
/// `None` runs all three.
pub only: Option<String>,
}
pub fn run(cfg: &DownloadConfig) -> Result<()> {
fs::create_dir_all(&cfg.out_dir)?;
let run_har = cfg.only.as_deref().map_or(true, |o| o == "har");
let run_sleep = cfg.only.as_deref().map_or(true, |o| o == "sleep");
let run_ecg = cfg.only.as_deref().map_or(true, |o| o == "ecg");
if run_har { download_wisdm(&cfg.out_dir, cfg.limit)?; }
if run_sleep { download_sleep(&cfg.out_dir, cfg.limit)?; }
if run_ecg { generate_ecg(&cfg.out_dir, cfg.limit)?; }
print_summary(&cfg.out_dir)?;
Ok(())
}
/// Builds a Hub repo handle for `dataset` pinned to the `refs/convert/parquet` revision.
///
/// # Errors
///
/// Fails when the Hub API client cannot be constructed.
fn parquet_repo(dataset: &str) -> Result<ApiRepo> {
    let client = match Api::new() {
        Ok(client) => client,
        Err(e) => return Err(anyhow::anyhow!("HF Hub init failed: {e}")),
    };
    let revision = "refs/convert/parquet".to_string();
    let target = Repo::with_revision(dataset.to_string(), RepoType::Dataset, revision);
    Ok(client.repo(target))
}
fn fetch_split(dataset: &str, split: &str, limit: Option<usize>) -> Result<Vec<Value>> {
let repo = parquet_repo(dataset)?;
let info = repo
.info()
.map_err(|e| anyhow::anyhow!("Could not fetch repo info for {dataset}: {e}"))?;
let shards: Vec<String> = info
.siblings
.into_iter()
.map(|s| s.rfilename)
.filter(|name| {
name.ends_with(".parquet")
&& name.split('/').any(|seg| seg == split)
})
.collect();
if shards.is_empty() {
bail!("No parquet shards found for split '{split}' in dataset '{dataset}'");
}
info!(
" {dataset}/{split}: {} shard(s) to download",
shards.len()
);
let mut rows: Vec<Value> = Vec::new();
'outer: for shard in &shards {
let local = repo
.download(shard)
.map_err(|e| anyhow::anyhow!("Failed to download shard {shard}: {e}"))?;
let file = fs::File::open(&local)?;
let reader = SerializedFileReader::new(file)
.with_context(|| format!("Failed to open parquet file {}", local.display()))?;
for row_result in RowIter::from_file_into(Box::new(reader)) {
let row = row_result.context("Parquet row decode error")?;
rows.push(row.to_json_value());
if limit.is_some_and(|l| rows.len() >= l) {
break 'outer;
}
}
}
info!(" fetched {} rows ({split}/{dataset})", rows.len());
Ok(rows)
}
/// Fetches WISDM-W and writes the HAR, TSQA and M4-style JSONL files under `out_dir`.
///
/// The train split is shuffled with a fixed seed and its first tenth (at least one row,
/// when any exist) becomes the validation set. The test split is fetched with a fifth of
/// `limit`. Files written:
///
/// * `har_cot/{train,val,test}.jsonl`
/// * `tsqa/train.jsonl` — built from train, val and test rows together
/// * `m4/train_samples.jsonl` — built from train rows only
///
/// # Errors
///
/// Fails if a split cannot be fetched, a row cannot be parsed, or a file cannot be written.
fn download_wisdm(out_dir: &Path, limit: Option<usize>) -> Result<()> {
    info!("━━ WISDM-W (wrist accel+gyro, 12 activities) ━━");
    let mut train = fetch_split("claudiogsc/WISDM-W", "train", limit)?;
    let test = fetch_split("claudiogsc/WISDM-W", "test", limit.map(|l| l / 5))?;

    let mut rng = StdRng::seed_from_u64(42);
    train.shuffle(&mut rng);
    // Clamp to the number of rows available: `drain(..1)` on an empty split would panic.
    let n_val = (train.len() / 10).max(1).min(train.len());
    let val: Vec<Value> = train.drain(..n_val).collect();

    let to_har = |rows: &[Value]| -> Result<Vec<Value>> {
        rows.iter().map(wisdm_to_har).collect()
    };
    write_jsonl(out_dir, "har_cot/train.jsonl", &to_har(&train)?)?;
    write_jsonl(out_dir, "har_cot/val.jsonl", &to_har(&val)?)?;
    write_jsonl(out_dir, "har_cot/test.jsonl", &to_har(&test)?)?;

    // Separate fixed seed so the multiple-choice options do not depend on the shuffle above.
    let mut tsqa_rng = StdRng::seed_from_u64(7);
    let tsqa: Vec<Value> = train
        .iter()
        .chain(val.iter())
        .chain(test.iter())
        .map(|r| wisdm_to_tsqa(r, &mut tsqa_rng))
        .collect::<Result<_>>()?;
    write_jsonl(out_dir, "tsqa/train.jsonl", &tsqa)?;

    let m4: Vec<Value> = train.iter().map(wisdm_to_m4).collect::<Result<_>>()?;
    write_jsonl(out_dir, "m4/train_samples.jsonl", &m4)?;
    Ok(())
}
/// One WISDM row after decoding by `parse_wisdm`.
struct WisdmRow {
/// Per-time-step samples, six values each. The converters in this file read indices
/// 0, 1 and 2 as the x, y and z axes.
/// NOTE(review): indices 3-5 are presumably the gyroscope channels — confirm against the dataset.
seq: Vec<[f32; 6]>,
/// Integer class label taken from the row; turned into a name by `wisdm_label`.
label: usize,
}
/// Decodes the `sequence` and `label` fields of a WISDM JSON row.
///
/// Every step must be an array with at least six entries; only the first six are kept.
/// Entries that are not numbers are read as `0.0`.
///
/// # Errors
///
/// Fails if `sequence` or one of its steps is not an array, if a step has fewer than six
/// entries, or if `label` is not an integer.
fn parse_wisdm(row: &Value) -> Result<WisdmRow> {
    let steps = row["sequence"]
        .as_array()
        .context("sequence not an array")?;

    let mut seq: Vec<[f32; 6]> = Vec::with_capacity(steps.len());
    for step in steps {
        let channels = step.as_array().context("step not an array")?;
        if channels.len() < 6 {
            bail!("step has <6 elements");
        }
        seq.push(std::array::from_fn(|i| {
            channels[i].as_f64().unwrap_or(0.0) as f32
        }));
    }

    let label = row["label"].as_i64().context("label missing")? as usize;
    Ok(WisdmRow { seq, label })
}
fn wisdm_to_har(row: &Value) -> Result<Value> {
let r = parse_wisdm(row)?;
let label = wisdm_label(r.label);
let x: Vec<f32> = r.seq.iter().map(|s| s[0]).collect();
let y: Vec<f32> = r.seq.iter().map(|s| s[1]).collect();
let z: Vec<f32> = r.seq.iter().map(|s| s[2]).collect();
let stats = accel_stats(&x, &y, &z);
Ok(json!({
"x_axis": x,
"y_axis": y,
"z_axis": z,
"label": label,
"rationale": har_rationale(label, &stats),
}))
}
/// Converts a WISDM row into a multiple-choice question about its x-axis signal.
///
/// The answer options come from `mcq_options`, which draws from `rng`; the series is
/// stored as a JSON-encoded string of the x-axis values.
///
/// # Errors
///
/// Fails if the row cannot be parsed or the series cannot be serialised.
fn wisdm_to_tsqa(row: &Value, rng: &mut StdRng) -> Result<Value> {
    let parsed = parse_wisdm(row)?;
    let (options, answer) = mcq_options(wisdm_label(parsed.label), &WISDM_LABELS, rng);

    let mut x_axis: Vec<f32> = Vec::with_capacity(parsed.seq.len());
    x_axis.extend(parsed.seq.iter().map(|step| step[0]));

    let question = format!(
        "A wrist-worn smartwatch recorded the following 4-second accelerometer \
         (x-axis) signal during a physical activity. Which activity is most \
         likely being performed?\n{options}"
    );
    let series_json = serde_json::to_string(&x_axis)?;
    Ok(json!({
        "Question": question,
        "Answer": answer,
        "Task": "activity classification",
        "Series": series_json,
    }))
}
/// Converts a WISDM row into a captioned series record (x-axis values plus a text caption
/// built from the accelerometer statistics).
///
/// # Errors
///
/// Fails if the row cannot be parsed by `parse_wisdm`.
fn wisdm_to_m4(row: &Value) -> Result<Value> {
    let parsed = parse_wisdm(row)?;
    let activity = wisdm_label(parsed.label);

    let n_steps = parsed.seq.len();
    let mut x: Vec<f32> = Vec::with_capacity(n_steps);
    let mut y: Vec<f32> = Vec::with_capacity(n_steps);
    let mut z: Vec<f32> = Vec::with_capacity(n_steps);
    for step in &parsed.seq {
        x.push(step[0]);
        y.push(step[1]);
        z.push(step[2]);
    }

    let stats = accel_stats(&x, &y, &z);
    let (mean_x, std_x) = (stats["mean_x"], stats["std_x"]);
    let (mean_mag, std_mag) = (stats["mean_mag"], stats["std_mag"]);
    let note = wisdm_motion_note(activity);
    let caption = format!(
        "This 4-second wrist accelerometer segment captures {} activity. \
         X-axis mean {:.3} g (std {:.3} g). \
         Resultant magnitude mean {:.3} g (std {:.3} g). {}",
        activity, mean_x, std_x, mean_mag, std_mag, note,
    );
    Ok(json!({
        "series": x,
        "caption": caption,
        "info": "wrist accelerometer x-axis, WISDM-W, 50 Hz, 4 s",
    }))
}
/// Fetches the SleepEDF splits and writes `sleep_cot/{train,val,test}.jsonl` under `out_dir`.
///
/// The train split is shuffled with a fixed seed and its first tenth (at least one row,
/// when any exist) becomes the validation set. The test split is fetched with a fifth of
/// `limit`.
///
/// # Errors
///
/// Fails if a split cannot be fetched, a row cannot be converted, or a file cannot be written.
fn download_sleep(out_dir: &Path, limit: Option<usize>) -> Result<()> {
    info!("━━ SleepEDF (Fpz-Cz EEG, 30-sec, 5 stages) ━━");
    let mut train = fetch_split("wbxlala/sleep_edf_str", "train", limit)?;
    let test = fetch_split("wbxlala/sleep_edf_str", "test", limit.map(|l| l / 5))?;

    let mut rng = StdRng::seed_from_u64(42);
    train.shuffle(&mut rng);
    // Clamp to the number of rows available: `drain(..1)` on an empty split would panic.
    let n_val = (train.len() / 10).max(1).min(train.len());
    let val: Vec<Value> = train.drain(..n_val).collect();

    let to_sleep = |rows: &[Value]| -> Result<Vec<Value>> {
        rows.iter().map(sleep_to_cot).collect()
    };
    write_jsonl(out_dir, "sleep_cot/train.jsonl", &to_sleep(&train)?)?;
    write_jsonl(out_dir, "sleep_cot/val.jsonl", &to_sleep(&val)?)?;
    write_jsonl(out_dir, "sleep_cot/test.jsonl", &to_sleep(&test)?)?;
    Ok(())
}
fn sleep_to_cot(row: &Value) -> Result<Value> {
let sample_str = row["sample"].as_str().context("sample missing")?;
let ts: Vec<f32> = sample_str
.split_ascii_whitespace()
.map(|s| s.parse::<f32>().unwrap_or(0.0))
.collect();
let label_int = row["label"].as_i64().context("label missing")? as usize;
let label = SLEEP_STAGES.get(label_int).copied().unwrap_or("Unknown");
let stats = eeg_stats(&ts);
Ok(json!({
"time_series": ts,
"label": label,
"rationale": sleep_rationale(label, &stats),
}))
}
/// Generates synthetic ECG records and writes `ecg_qa_cot/{train,val,test}.jsonl`.
///
/// `limit` is the total number of records (9 000 when `None`). After a seeded shuffle,
/// a tenth of the total (at least one) goes to test, the same amount to validation, and
/// the rest to train. When fewer records exist than those counts ask for, test is filled
/// first, then validation; the remaining sets are written empty.
///
/// # Errors
///
/// Fails if any output file cannot be written.
fn generate_ecg(out_dir: &Path, limit: Option<usize>) -> Result<()> {
    info!("━━ Synthetic ECG (2-lead, 4-sec, 250 Hz) ━━");
    info!(" (PhysioNet is not served by the HF datasets API; using synthetic data)");

    let n_total = limit.unwrap_or(9_000);
    let mut rng = StdRng::seed_from_u64(123);
    let mut records: Vec<Value> = (0..n_total)
        .map(|_| synth_ecg_record(&mut rng))
        .collect();
    records.shuffle(&mut rng);

    // Each held-out set wants a tenth of the data but never more than what is left;
    // without the clamp a limit of 0 or 1 drains past the end of `records` and panics.
    let tenth = (n_total / 10).max(1);
    let n_test = tenth.min(records.len());
    let test: Vec<Value> = records.drain(..n_test).collect();
    let n_val = tenth.min(records.len());
    let val: Vec<Value> = records.drain(..n_val).collect();
    let train = records;

    write_jsonl(out_dir, "ecg_qa_cot/train.jsonl", &train)?;
    write_jsonl(out_dir, "ecg_qa_cot/val.jsonl", &val)?;
    write_jsonl(out_dir, "ecg_qa_cot/test.jsonl", &test)?;
    Ok(())
}
/// Synthesises one two-lead ECG record together with its question and rationale text.
///
/// Each lead has `N` = 1000 samples; beat spacing is derived from `FS` = 250 samples per
/// second and a heart rate drawn from 60–100 bpm. A record is flagged abnormal with
/// probability 0.5. Abnormal records get much larger beat-to-beat jitter, occasionally a
/// doubled beat interval, and may have all P-waves omitted.
///
/// The output for a given seed depends on the order of the `rng` draws, so statements
/// that draw from `rng` must not be reordered.
fn synth_ecg_record(rng: &mut StdRng) -> Value {
    const FS: usize = 250;
    const N: usize = 1000;

    let has_anomaly: bool = rng.random_bool(0.5);
    let hr_bpm: f32 = rng.random_range(60.0_f32..100.0);
    // Nominal distance between R-peaks, in samples.
    let rr_nominal = FS as f32 / (hr_bpm / 60.0);

    // Place R-peaks until the window is full.
    let mut rpeaks: Vec<usize> = Vec::new();
    let mut pos = rng.random_range(10_usize..40);
    while pos < N {
        rpeaks.push(pos);
        let jitter: f32 = if has_anomaly {
            rng.random_range(-60.0_f32..60.0)
        } else {
            rng.random_range(-5.0_f32..5.0)
        };
        // Short-circuit keeps the extra draw exclusive to abnormal records.
        let skip = has_anomaly && rng.random_bool(0.1);
        let advance = if skip { rr_nominal * 2.0 } else { rr_nominal };
        // Never advance by fewer than 40 samples, whatever the jitter.
        pos += (advance + jitter).round().max(40.0) as usize;
    }

    let noise_amp: f32 = rng.random_range(0.01..0.04);
    let mut lead1 = vec![0.0_f32; N];
    let mut lead2 = vec![0.0_f32; N];
    // Decided once per record: either every beat has a P-wave or none does.
    let skip_p = has_anomaly && rng.random_bool(0.6);

    for &rp in &rpeaks {
        let rp = rp as isize;

        // QRS complex: small negative lobes either side of the main peak.
        for di in -20_isize..=20 {
            let i = rp + di;
            if i < 0 || i >= N as isize { continue; }
            let t = di as f32;
            let qrs = -0.10 * gauss(t + 7.0, 3.0)
                + 1.00 * gauss(t, 2.0)
                - 0.15 * gauss(t - 7.0, 3.0);
            lead1[i as usize] += qrs;
            lead2[i as usize] += qrs * 0.8;
        }

        // P-wave, centred 20 samples before the R-peak.
        if !skip_p {
            let pp = rp - 20;
            for di in -10_isize..=10 {
                let i = pp + di;
                if i < 0 || i >= N as isize { continue; }
                let pw = 0.15 * gauss(di as f32, 5.0);
                lead1[i as usize] += pw;
                lead2[i as usize] += pw * 1.2;
            }
        }

        // T-wave, centred 40 samples after the R-peak.
        let tp = rp + 40;
        for di in -20_isize..=20 {
            let i = tp + di;
            if i < 0 || i >= N as isize { continue; }
            let tw = 0.30 * gauss(di as f32, 9.0);
            lead1[i as usize] += tw;
            lead2[i as usize] += tw * 0.9;
        }
    }

    // Baseline wander (one sine cycle over the window) plus independent uniform noise per
    // lead. Lead I is drawn before lead II for every sample.
    for i in 0..N {
        let wander = 0.02 * f32::sin(2.0 * std::f32::consts::PI * i as f32 / N as f32);
        lead1[i] += wander + noise_amp * (rng.random::<f32>() - 0.5) * 2.0;
        lead2[i] += wander + noise_amp * (rng.random::<f32>() - 0.5) * 2.0;
    }

    let (question, rationale, context) = ecg_rationale(has_anomaly, &lead1);
    // Round to five decimals to keep the JSON compact.
    let round5 = |v: f32| (v * 100_000.0).round() / 100_000.0;
    json!({
        "question": question,
        "rationale": rationale,
        "clinical_context": context,
        "template_id": if has_anomaly { 1 } else { 0 },
        "question_type": "rhythm classification",
        "source": "synthetic",
        "leads": {
            "I": lead1.iter().map(|&v| round5(v)).collect::<Vec<_>>(),
            "II": lead2.iter().map(|&v| round5(v)).collect::<Vec<_>>(),
        },
    })
}
/// Unnormalised Gaussian bump `exp(-0.5 * (x / sigma)^2)`; equals 1 at `x = 0`.
#[inline]
fn gauss(x: f32, sigma: f32) -> f32 {
    let scaled = x / sigma;
    (-0.5 * scaled.powi(2)).exp()
}
fn write_jsonl(out_dir: &Path, rel_path: &str, records: &[Value]) -> Result<()> {
let path = out_dir.join(rel_path);
fs::create_dir_all(path.parent().unwrap())?;
let mut out = String::with_capacity(records.len() * 256);
for rec in records {
serde_json::to_writer(unsafe { out.as_mut_vec() }, rec)?;
out.push('\n');
}
fs::write(&path, &out)?;
info!(" wrote {:>6} records → {rel_path}", records.len());
Ok(())
}
/// Logs the number of non-blank lines of every `.jsonl` file found under `out_dir`,
/// sorted by path relative to `out_dir`.
///
/// # Errors
///
/// Fails if the directory tree cannot be listed or a `.jsonl` file cannot be read as text.
fn print_summary(out_dir: &Path) -> Result<()> {
    let mut entries: Vec<(String, usize)> = Vec::new();
    for file in walkdir(out_dir)? {
        let is_jsonl = matches!(
            file.extension().and_then(|ext| ext.to_str()),
            Some("jsonl")
        );
        if !is_jsonl {
            continue;
        }

        let contents = fs::read_to_string(&file)?;
        let mut row_count = 0_usize;
        for line in contents.lines() {
            if !line.trim().is_empty() {
                row_count += 1;
            }
        }

        let relative = file.strip_prefix(out_dir)?.to_string_lossy().into_owned();
        entries.push((relative, row_count));
    }
    entries.sort();

    info!("\n✓ Output summary:");
    for (rel, n) in &entries {
        info!(" {rel:<45} {n:>6} rows");
    }
    Ok(())
}
/// Recursively lists every non-directory entry below `dir`, depth-first in the order
/// `read_dir` yields them. A path that does not exist gives an empty list.
///
/// # Errors
///
/// Fails if a directory cannot be read.
fn walkdir(dir: &Path) -> Result<Vec<PathBuf>> {
    if !dir.exists() {
        return Ok(Vec::new());
    }

    let mut found: Vec<PathBuf> = Vec::new();
    for item in fs::read_dir(dir)? {
        let child = item?.path();
        if !child.is_dir() {
            found.push(child);
            continue;
        }
        let mut nested = walkdir(&child)?;
        found.append(&mut nested);
    }
    Ok(found)
}
/// Summary statistics of a three-axis accelerometer segment.
///
/// Returns a map with four keys:
///
/// * `mean_x`, `std_x` — mean and population standard deviation of `x`
/// * `mean_mag`, `std_mag` — the same for the per-sample magnitude `sqrt(x² + y² + z²)`
///
/// The magnitude is computed over the shortest of the three slices. Statistics of an
/// empty series are `0.0` rather than the NaN that dividing by a zero length would give.
fn accel_stats(x: &[f32], y: &[f32], z: &[f32]) -> HashMap<&'static str, f32> {
    fn mean(v: &[f32]) -> f32 {
        if v.is_empty() {
            0.0
        } else {
            v.iter().sum::<f32>() / v.len() as f32
        }
    }
    fn std_dev(v: &[f32], m: f32) -> f32 {
        if v.is_empty() {
            return 0.0;
        }
        (v.iter().map(|a| (a - m).powi(2)).sum::<f32>() / v.len() as f32).sqrt()
    }

    let mag: Vec<f32> = x
        .iter()
        .zip(y)
        .zip(z)
        .map(|((xi, yi), zi)| (xi * xi + yi * yi + zi * zi).sqrt())
        .collect();

    let mean_x = mean(x);
    let mean_mag = mean(&mag);
    HashMap::from([
        ("mean_x", mean_x),
        ("std_x", std_dev(x, mean_x)),
        ("mean_mag", mean_mag),
        ("std_mag", std_dev(&mag, mean_mag)),
    ])
}
/// Summary statistics of a single-channel series.
///
/// Returns a map with the keys `mean`, `std` (population standard deviation) and
/// `mean_abs` (mean of absolute values). All three are `0.0` for an empty series rather
/// than the NaN that dividing by a zero length would give.
fn eeg_stats(ts: &[f32]) -> HashMap<&'static str, f32> {
    if ts.is_empty() {
        return HashMap::from([("mean", 0.0), ("std", 0.0), ("mean_abs", 0.0)]);
    }

    let n = ts.len() as f32;
    let mean = ts.iter().sum::<f32>() / n;
    let std = (ts.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n).sqrt();
    let mean_abs = ts.iter().map(|v| v.abs()).sum::<f32>() / n;
    HashMap::from([("mean", mean), ("std", std), ("mean_abs", mean_abs)])
}
/// Builds the explanatory text for a HAR sample, ending in `"\n\nAnswer: {label}"`.
///
/// Only the `std_x` and `mean_mag` entries of `s` are used. Labels without a dedicated
/// template get a generic sentence that names the label.
///
/// # Panics
///
/// Panics if `s` has no `std_x` or `mean_mag` entry.
fn har_rationale(label: &str, s: &HashMap<&str, f32>) -> String {
    // Look up only what the templates reference; the other statistics are not needed here.
    let (sx, mmag) = (s["std_x"], s["mean_mag"]);
    let description = match label {
        "walking" => format!("The wrist accelerometer shows rhythmic oscillations (std ≈ {sx:.2} g) at a cadence consistent with bipedal locomotion (~2 Hz). Mean magnitude {mmag:.2} g is typical for normal walking pace."),
        "jogging" => format!("The accelerometer shows high-amplitude, rapid oscillations (std ≈ {sx:.2} g, mean mag {mmag:.2} g). The high-frequency, high-energy pattern with repeated impact spikes is characteristic of running or jogging."),
        "climbing stairs" => format!("The accelerometer shows irregular, asymmetric steps (std ≈ {sx:.2} g). The alternating thrust-and-lift pattern with moderate amplitude ({mmag:.2} g) matches stair ascent or descent."),
        "sitting" => format!("The accelerometer shows near-constant values with very low variability (std ≈ {sx:.2} g). Dominant gravity component ({mmag:.2} g) with minimal dynamic movement indicates a seated, stationary posture."),
        "standing" => format!("The accelerometer is nearly static (std ≈ {sx:.2} g, mag ≈ {mmag:.2} g). The stable 1-g gravity vector with only micro-tremor noise is consistent with quiet standing."),
        "kicking" => format!("The accelerometer shows sharp, high-amplitude transients (std ≈ {sx:.2} g) interspersed with low-activity intervals. The impulsive kick events at {mmag:.2} g followed by return to rest match a kicking motion."),
        "catching" => format!("The accelerometer shows sudden bursts (std ≈ {sx:.2} g) followed by stabilisation. The reach-and-absorb motion pattern at {mmag:.2} g is consistent with catching an object."),
        "dribbling" => format!("The accelerometer shows repetitive, moderate-amplitude pulses (std ≈ {sx:.2} g, ~{mmag:.2} g) at a regular tempo. The rhythmic downward-push pattern matches ball dribbling."),
        "writing" => format!("The accelerometer shows low-amplitude, continuous micro-movements (std ≈ {sx:.2} g). The fine-motor, pen-on-paper motion with near-constant orientation ({mmag:.2} g) indicates writing."),
        "clapping" => format!("The accelerometer shows rapid, symmetric bilateral impulses (std ≈ {sx:.2} g) at a regular rate. The sharp collision events at {mmag:.2} g are consistent with hand clapping."),
        "brushing teeth" => format!("The accelerometer shows small, high-frequency oscillations (std ≈ {sx:.2} g) at ~4 Hz. The constrained, repetitive brush-stroke motion at {mmag:.2} g matches tooth brushing."),
        "eating" => format!("The accelerometer shows intermittent, low-amplitude wrist movements (std ≈ {sx:.2} g) punctuated by brief stillness. The lift-to-mouth pattern at {mmag:.2} g is consistent with eating."),
        other => format!("The accelerometer shows activity-dependent motion (std ≈ {sx:.2} g, mean mag {mmag:.2} g) consistent with {other}."),
    };
    format!("{description}\n\nAnswer: {label}")
}
/// Builds the explanatory text for a sleep-stage sample, ending in `"\n\nAnswer: {label}"`.
///
/// Reads the `std` and `mean_abs` entries of `s`. Stage names without a dedicated
/// template get a generic sentence that names the label.
///
/// # Panics
///
/// Panics if `s` has no `std` or `mean_abs` entry.
fn sleep_rationale(label: &str, s: &HashMap<&str, f32>) -> String {
    let std = s["std"];
    let mab = s["mean_abs"];

    let mut text = match label {
        "Wake" => format!("The EEG shows high-frequency, mixed-amplitude activity (std ≈ {std:.1} μV). Beta and alpha oscillations are present; no dominant slow waves. Mean absolute amplitude {mab:.1} μV is typical of wakefulness."),
        "Non-REM stage 1" => format!("The EEG shows low-amplitude, mixed-frequency activity (std ≈ {std:.1} μV) with theta waves (4–8 Hz) beginning to dominate. Alpha is diminishing; mean absolute amplitude {mab:.1} μV indicates light sleep onset."),
        "Non-REM stage 2" => format!("The EEG (std ≈ {std:.1} μV) shows characteristic sleep spindles (12–15 Hz bursts) and K-complexes. Mean absolute amplitude {mab:.1} μV. The spindle-K-complex signature definitively indicates NREM stage 2."),
        "Non-REM stage 3" => format!("The EEG shows high-amplitude, low-frequency delta waves (std ≈ {std:.1} μV, mean abs {mab:.1} μV). Delta waves occupy >20% of the epoch. The large, synchronised waveforms indicate deep slow-wave sleep."),
        "REM sleep" => format!("The EEG (std ≈ {std:.1} μV, mean abs {mab:.1} μV) shows low-amplitude mixed-frequency activity with sawtooth waves, resembling wakefulness. The desynchronised pattern alongside theta-band content indicates REM sleep."),
        other => format!("The EEG segment (std ≈ {std:.1} μV, mean abs {mab:.1} μV) is consistent with {other}."),
    };

    // Append the answer trailer to the chosen description.
    text.push_str("\n\nAnswer: ");
    text.push_str(label);
    text
}
/// Builds the `(question, rationale, clinical context)` texts for an ECG record.
///
/// The rationale quotes the population standard deviation and the largest absolute value
/// of `lead1`, and ends in `"\n\nAnswer: "` followed by `"abnormal rhythm"` or
/// `"normal sinus rhythm"` according to `has_anomaly`.
fn ecg_rationale(has_anomaly: bool, lead1: &[f32]) -> (String, String, String) {
    let count = lead1.len() as f32;
    let mean = lead1.iter().sum::<f32>() / count;
    let variance = lead1.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / count;
    let std = variance.sqrt();

    let mut peak = 0.0_f32;
    for v in lead1 {
        peak = f32::max(peak, v.abs());
    }

    let (label, desc, context) = if !has_anomaly {
        (
            "normal sinus rhythm",
            format!(
                "The ECG lead shows regular, repeating complexes (std ≈ {std:.3} mV, \
                 peak ≈ {peak:.3} mV). RR intervals are consistent, P-waves precede each \
                 QRS complex, and QRS morphology is narrow and uniform. These features are \
                 consistent with normal sinus rhythm."
            ),
            "Routine cardiac screening.",
        )
    } else {
        (
            "abnormal rhythm",
            format!(
                "The ECG lead shows irregular amplitude variations (std ≈ {std:.3} mV, \
                 peak ≈ {peak:.3} mV). RR intervals appear irregular and the baseline shows \
                 oscillatory activity without consistent P-wave morphology. These features \
                 are inconsistent with normal sinus rhythm and suggest an underlying arrhythmia."
            ),
            "Patient referred for cardiac monitoring due to palpitations.",
        )
    };

    let question =
        String::from("Does this short ECG segment show normal sinus rhythm or an abnormal rhythm?");
    let rationale = format!("{desc}\n\nAnswer: {label}");
    (question, rationale, String::from(context))
}
fn mcq_options(correct: &str, pool: &[&str], rng: &mut StdRng) -> (String, String) {
let wrong: Vec<&str> = pool.iter().copied().filter(|&a| a != correct).collect();
let mut choices: Vec<&str> = wrong.choose_multiple(rng, 3).copied().collect();
choices.push(correct);
choices.shuffle(rng);
let letters = ['A', 'B', 'C', 'D'];
let mut options = String::new();
let mut answer_letter = 'A';
for (letter, &choice) in letters.iter().zip(choices.iter()) {
write!(options, "({letter}) {choice} ").unwrap();
if choice == correct { answer_letter = *letter; }
}
(options.trim_end().to_string(), format!("({answer_letter}) {correct}"))
}
/// Returns a one-sentence motion description for an activity name, or a generic sentence
/// when the name is not recognised. Matching is exact and case-sensitive.
fn wisdm_motion_note(label: &str) -> &'static str {
    const NOTES: [(&str, &str); 12] = [
        ("walking", "Regular cadence oscillations at ~2 Hz."),
        ("jogging", "High-amplitude impacts at ~3 Hz."),
        ("climbing stairs", "Asymmetric step pattern with moderate thrust."),
        ("sitting", "Near-static signal dominated by gravity."),
        ("standing", "Stable 1-g baseline with micro-tremor noise."),
        ("kicking", "Sharp impulsive transients separated by rest."),
        ("catching", "Sudden deceleration bursts from catching motion."),
        ("dribbling", "Rhythmic downward pulses at regular tempo."),
        ("writing", "Low-amplitude fine-motor micro-movements."),
        ("clapping", "Bilateral symmetric impact pairs at regular rate."),
        ("brushing teeth", "Rapid small-amplitude oscillations at ~4 Hz."),
        ("eating", "Intermittent lift-to-mouth wrist trajectories."),
    ];

    NOTES
        .iter()
        .find(|(name, _)| *name == label)
        .map_or("Activity-dependent motion pattern.", |&(_, note)| note)
}
/// Activity names, indexed by the integer label of a WISDM row. Also used as the pool of
/// answer options for the multiple-choice questions.
/// NOTE(review): the index-to-name order is assumed to match the dataset's label
/// encoding — confirm.
const WISDM_LABELS: [&str; 12] = [
    "walking",
    "jogging",
    "climbing stairs",
    "sitting",
    "standing",
    "kicking",
    "catching",
    "dribbling",
    "writing",
    "clapping",
    "brushing teeth",
    "eating",
];

/// Name of the activity with index `idx`, or `"unknown"` when the index is out of range.
fn wisdm_label(idx: usize) -> &'static str {
    match WISDM_LABELS.get(idx) {
        Some(&name) => name,
        None => "unknown",
    }
}
/// Sleep-stage names, indexed by the integer `label` of a sleep row (see `sleep_to_cot`,
/// which falls back to `"Unknown"` for an out-of-range index).
/// NOTE(review): the index-to-stage order is assumed to match the dataset's label
/// encoding — confirm.
const SLEEP_STAGES: [&str; 5] = [
"Wake", "Non-REM stage 1", "Non-REM stage 2", "Non-REM stage 3", "REM sleep",
];