opentslm 0.1.0

Rust implementation of OpenTSLM using Burn, WGPU, and llama.cpp
//! Curriculum trainer — orchestrates the five progressive training stages.
//!
//! [`CurriculumTrainer`] mirrors the `curriculum_learning.py` script in the
//! Python repository, adapted to Burn's autodiff API and the llama.cpp GGUF
//! backend.
//!
//! # Five-stage curriculum
//!
//! | # | Stage key | Task | Dataset |
//! |---|-----------|------|---------|
//! | 1 | `stage1_mcq` | MCQ on univariate time series | TSQA (WISDM-W) |
//! | 2 | `stage2_captioning` | Time-series captioning | M4 (WISDM-W) |
//! | 3 | `stage3_cot` | HAR chain-of-thought | HAR CoT (WISDM-W) |
//! | 4 | `stage4_sleep_cot` | Sleep-stage CoT | SleepEDF |
//! | 5 | `stage5_ecg_cot` | ECG QA CoT | Synthetic PTB-XL |
//!
//! Each stage:
//! 1. Loads the frozen GGUF LLM once via [`LlamaCppBackend`].
//! 2. Loads trained parameters from the **previous** stage's checkpoint as
//!    the starting point for the encoder and logit-head.
//! 3. Runs AdamW with cosine LR schedule (linear warm-up + cosine decay).
//! 4. Saves a checkpoint whenever validation loss improves.
//! 5. Applies early stopping when val loss fails to improve for
//!    [`crate::config::EARLY_STOP_PAT`] consecutive epochs.
//! 6. Writes test-set predictions as JSONL and renders SVG/HTML training curves.
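//!
//! # Example
//!
//! A sketch of a typical invocation (the backend alias and file paths are
//! illustrative, not mandated by this module):
//!
//! ```rust,ignore
//! use burn::backend::{Autodiff, Wgpu};
//!
//! type B = Autodiff<Wgpu>;
//!
//! let mut trainer = CurriculumTrainer::new(
//!     "models/llama-3.2-1B-q8_0.gguf", // any GGUF model (hypothetical path)
//!     "data",                          // root containing the stage datasets
//!     "wgpu",
//! );
//! trainer.run_all::<B>()?;             // all five stages in order
//! // …or a single stage:
//! // trainer.run_stage::<B>("stage3_cot")?;
//! ```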

use std::{fs, path::{Path, PathBuf}, time::{Duration, Instant}};

use anyhow::{Context, Result};
use burn::{
    grad_clipping::GradientClippingConfig,
    optim::{AdamW, AdamWConfig, GradientsParams, Optimizer},
    optim::adaptor::OptimizerAdaptor,
    prelude::Backend,
    tensor::backend::AutodiffBackend,
};
use indicatif::{ProgressBar, ProgressStyle};
use tracing::{info, warn};

use crate::{
    config::{
        BATCH_SIZE, CURRICULUM_STAGES, EARLY_STOP_PAT, CTX_SIZE, GRAD_CLIP_NORM,
        LR_ENCODER, LR_MIN_FRAC, LOSS_EMA_DECAY,
        MAX_EVAL_TOKENS, N_GPU_LAYERS, NUM_EPOCHS,
        MAX_TRAIN_SAMPLES, WARMUP_FRAC, WEIGHT_DECAY,
    },
    data::{
        batch::{collate, Sample},
        ecg::load_ecg_splits,
        har::load_har_splits,
        m4::load_m4_splits,
        sleep::load_sleep_splits,
        tsqa::load_tsqa_splits,
    },
    model::llm::{
        llama_cpp::LlamaCppBackend,
        opentslm_sp::{OpenTslmSp, TrainableComponents},
    },
    training::metrics::{EpochMetrics, StageMetrics, plot_curriculum_overview, write_html_index},
};

// ── Trainer ───────────────────────────────────────────────────────────────────

/// Five-stage curriculum trainer for opentslm.
///
/// Constructed via [`CurriculumTrainer::new`]; run with
/// [`run_stage`](CurriculumTrainer::run_stage) (single stage) or
/// [`run_all`](CurriculumTrainer::run_all) (all five stages in order).
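///
/// # Output layout
///
/// For a GGUF file named `llama-3.2-1B.gguf` (an illustrative name), the
/// sanitised model stem becomes `llama_3_2_1B` and each stage writes:
///
/// ```text
/// results/llama_3_2_1B/OpenTSLMSP/<stage>/checkpoints/best_model.json
/// results/llama_3_2_1B/OpenTSLMSP/<stage>/results/test_predictions.jsonl
/// figures/<stage>/                       # SVG curves + HTML dashboard
/// ```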
pub struct CurriculumTrainer {
    /// Path to the GGUF model file loaded by [`LlamaCppBackend`].
    pub model_path:   PathBuf,
    /// Root directory containing the training JSONL files (one sub-dir per
    /// stage, e.g. `data/har_cot/train.jsonl`).
    pub data_dir:     PathBuf,
    /// Root for checkpoints and prediction outputs:
    /// `results/<model_stem>/OpenTSLMSP/<stage>/`.
    pub results_dir:  PathBuf,
    /// Root for SVG metric plots and HTML dashboards:
    /// `figures/<stage>/`.
    pub figures_dir:  PathBuf,
    /// Burn device string (currently only `"wgpu"` is used).
    pub device_str:   String,
    /// Base mini-batch size; halved automatically for memory-intensive stages.
    pub batch_size:   usize,
}

impl CurriculumTrainer {
    /// Construct a new [`CurriculumTrainer`].
    ///
    /// Creates the required output directories under `results/` and `figures/`
    /// (no-op if they already exist).  The GGUF model and training data are
    /// **not** loaded here; they are loaded lazily when a stage is run.
    pub fn new(
        model_path: impl AsRef<Path>,
        data_dir:   impl AsRef<Path>,
        device_str: &str,
    ) -> Self {
        let model_name = model_path
            .as_ref()
            .file_stem()
            .and_then(|s| s.to_str())
            .unwrap_or("model")
            .replace(['.', '-', ' '], "_");

        let results_dir = PathBuf::from("results")
            .join(&model_name)
            .join("OpenTSLMSP");

        let figures_dir = PathBuf::from("figures");

        fs::create_dir_all(&results_dir).ok();
        fs::create_dir_all(&figures_dir).ok();
        for stage in CURRICULUM_STAGES {
            fs::create_dir_all(results_dir.join(stage).join("checkpoints")).ok();
            fs::create_dir_all(results_dir.join(stage).join("results")).ok();
            fs::create_dir_all(figures_dir.join(stage)).ok();
        }

        Self {
            model_path:   model_path.as_ref().to_path_buf(),
            data_dir:     data_dir.as_ref().to_path_buf(),
            results_dir,
            figures_dir,
            device_str:   device_str.to_string(),
            batch_size:   BATCH_SIZE,
        }
    }

    // ── Public runners ────────────────────────────────────────────────────

    /// Run all five curriculum stages in order.
    ///
    /// Equivalent to calling [`run_stage`](Self::run_stage) for each element
    /// of [`CURRICULUM_STAGES`].
    pub fn run_all<B>(&mut self) -> Result<()>
    where
        B: AutodiffBackend,
        B::Device: Default,
    {
        for &stage in CURRICULUM_STAGES {
            self.run_stage::<B>(stage)?;
        }
        Ok(())
    }

    /// Run a single curriculum stage end-to-end.
    ///
    /// Steps performed:
    /// 1. Load the frozen GGUF LLM.
    /// 2. Build trainable Burn components sized to match the LLM.
    /// 3. Load parameters from the previous stage checkpoint (if any).
    /// 4. Load the stage's dataset.
    /// 5. Train for up to [`crate::config::NUM_EPOCHS`] epochs
    ///    with early stopping.
    /// 6. Evaluate on the test set and write predictions.
    pub fn run_stage<B>(&mut self, stage: &str) -> Result<()>
    where
        B: AutodiffBackend,
        B::Device: Default,
    {
        let device = B::Device::default();

        // Load the frozen LLM (GGUF, native quantisation).
        let llm = LlamaCppBackend::load(&self.model_path, N_GPU_LAYERS, CTX_SIZE)
            .with_context(|| format!("Failed to load LLM from {:?}", self.model_path))?;

        // Build trainable Burn components sized to match this LLM.
        let mut sp_model = OpenTslmSp::<B>::new(&llm, &device);

        // Load trainable params from the previous curriculum stage.
        self.maybe_load_prev_stage::<B>(&mut sp_model, stage, &device);

        // Dataset.
        let (train, val, test) = self.load_dataset(stage)?;
        info!("{stage}: train={}, val={}, test={}", train.len(), val.len(), test.len());

        // Train.
        let trained = self.train_stage::<B>(sp_model, &llm, train, val, stage, &device)?;

        // Evaluate.
        let bs = self.stage_batch_size(stage);
        self.evaluate::<B>(&trained, &llm, &test, stage, bs, &device)?;

        Ok(())
    }

    // ── Core training loop ────────────────────────────────────────────────

    fn train_stage<B>(
        &self,
        mut sp_model: OpenTslmSp<B>,
        llm:          &LlamaCppBackend,
        train_data:   Vec<Sample>,
        val_data:     Vec<Sample>,
        stage:        &str,
        device:       &B::Device,
    ) -> Result<OpenTslmSp<B>>
    where
        B: AutodiffBackend,
    {
        let checkpoint_dir = self.results_dir.join(stage).join("checkpoints");
        let loss_history   = checkpoint_dir.join("loss_history.txt");

        if !loss_history.exists() {
            fs::write(&loss_history, "Epoch\tTrain_Loss\tVal_Loss\n---\n")?;
        }

        // Single AdamW optimiser over all trainable parameters.
        // Burn's GradientsParams::from_grads consumes the Gradients value, so
        // a single backward pass can only feed one optimiser.step() call.
        // Per-component LRs would require two backward passes (expensive because
        // answer_logits re-runs the LLM KV-cache decode each time).
        // We use the encoder LR (the more conservative of the two) for all
        // components, so the separate LR_PROJECTOR in config.rs is not applied
        // here.  This works out because the head's gradient magnitude is
        // naturally larger (it maps 128 → 151 k), so AdamW's moment
        // normalisation already gives it a relatively larger effective step.
        let mut optimizer: OptimizerAdaptor<AdamW, TrainableComponents<B>, B> =
            AdamWConfig::new()
                .with_weight_decay(WEIGHT_DECAY as f32)
                .with_grad_clipping(Some(GradientClippingConfig::Norm(GRAD_CLIP_NORM)))
                .init();

        let batch_size   = self.stage_batch_size(stage);
        // Ceiling division matches `make_batches`, which keeps the final
        // short batch.
        let total_steps  = train_data.len().div_ceil(batch_size).max(1) * NUM_EPOCHS;
        let warmup_steps = (total_steps as f64 * WARMUP_FRAC) as usize;

        let mut best_val_loss    = f64::MAX;
        let mut patience_counter = 0usize;
        let mut global_step      = 0usize;
        let mut stage_metrics    = StageMetrics::new(stage);

        for epoch in 0..NUM_EPOCHS {
            let epoch_start = Instant::now();
            let mut train_loss_sum = 0.0f64;
            let mut train_batches  = 0usize;

            let shuffled = shuffle_samples(train_data.clone(), epoch as u64);
            let batches  = make_batches(shuffled, batch_size);

            // Log before the bar so there is always a visible line in the
            // scrollback buffer anchoring this epoch.  Without it, epochs
            // after the first are silent until they finish (≈10 min), and
            // the in-place progress bar can scroll off screen unnoticed.
            info!(
                "{stage} epoch {epoch}/{NUM_EPOCHS} — \
                 training {} batches (batch_size={})",
                batches.len(), batch_size,
            );

            let pb = ProgressBar::new(batches.len() as u64);
            pb.set_style(
                ProgressStyle::default_bar()
                    .template("  [{elapsed_precise}] {bar:35.cyan/blue} {pos:>4}/{len}  {msg}")
                    .unwrap(),
            );
            // Redraw on a background thread so the bar stays visible even if
            // a batch is slow and no other output triggers a forced redraw.
            pb.enable_steady_tick(Duration::from_millis(200));

            let mut ema_loss: Option<f64> = None;

            for batch_samples in batches {
                let lr = cosine_lr(global_step, warmup_steps, total_steps,
                                   LR_ENCODER, LR_ENCODER * LR_MIN_FRAC);
                global_step += 1;

                let batch    = collate(batch_samples);
                let loss     = sp_model.compute_loss(&batch.samples, llm, device);
                let loss_val = loss.clone().to_data().to_vec::<f32>().unwrap()[0] as f64;

                let grads        = loss.backward();
                let grads_params = GradientsParams::from_grads(grads, &sp_model.trainable);
                sp_model.trainable =
                    Optimizer::step(&mut optimizer, lr, sp_model.trainable, grads_params);

                // EMA-smoothed loss so the progress bar shows a visible trend
                // instead of per-batch noise (single batch of 4 CoT samples
                // has σ ≈ 0.4, making raw values appear flat).
                ema_loss = Some(match ema_loss {
                    None    => loss_val,
                    Some(e) => LOSS_EMA_DECAY * e + (1.0 - LOSS_EMA_DECAY) * loss_val,
                });

                train_loss_sum += loss_val;
                train_batches  += 1;
                pb.set_message(format!(
                    "loss(ema)={:.4}  lr={:.2e}",
                    ema_loss.unwrap(), lr,
                ));
                pb.inc(1);
            }
            pb.finish_and_clear();

            let train_loss = train_loss_sum / train_batches.max(1) as f64;
            let (val_loss, val_acc, val_recall) =
                self.eval_metrics_batched::<B>(&sp_model, llm, &val_data, batch_size, device);
            let elapsed = epoch_start.elapsed().as_secs_f32();

            info!(
                "{stage} epoch {epoch:>2}/{NUM_EPOCHS} | \
                 train={train_loss:.4}  val={val_loss:.4}  \
                 acc={:.2}%  recall={:.2}%  ({elapsed:.1}s)",
                val_acc * 100.0, val_recall * 100.0,
            );

            let _ = fs::OpenOptions::new()
                .append(true)
                .open(&loss_history)
                .and_then(|mut f| {
                    use std::io::Write;
                    writeln!(f, "{epoch}\t{train_loss:.6}\t{val_loss:.6}")
                });

            stage_metrics.push(EpochMetrics::new(
                epoch, train_loss, val_loss, val_acc, val_recall,
            ));

            // Save CSV + plots after every epoch so that a crash mid-training
            // still leaves figures for the epochs that did complete.
            if let Err(e) = stage_metrics.save(&self.figures_dir) {
                warn!("Could not save incremental metrics: {e}");
            }
            if let Err(e) = write_html_index(&stage_metrics, &self.figures_dir) {
                warn!("Could not write incremental HTML index: {e}");
            }

            if val_loss < best_val_loss {
                best_val_loss    = val_loss;
                patience_counter = 0;
                self.save_checkpoint::<B>(&sp_model, stage, epoch, val_loss)?;
                info!("  ✓ best val_loss={val_loss:.4}, checkpoint saved");
            } else {
                patience_counter += 1;
                if patience_counter >= EARLY_STOP_PAT {
                    info!("  Early stopping after {EARLY_STOP_PAT} non-improving epochs.");
                    break;
                }
            }
        }

        // ── Save metrics + plots ──────────────────────────────────────────
        if let Err(e) = stage_metrics.save(&self.figures_dir) {
            warn!("Could not save metrics: {e}");
        }
        if let Err(e) = write_html_index(&stage_metrics, &self.figures_dir) {
            warn!("Could not write HTML index: {e}");
        }

        // Regenerate curriculum overview with every stage completed so far.
        let all_loaded: Vec<StageMetrics> = CURRICULUM_STAGES.iter()
            .filter_map(|s| StageMetrics::from_csv(s, &self.figures_dir).ok())
            .collect();
        if all_loaded.len() >= 2 {
            let refs: Vec<&StageMetrics> = all_loaded.iter().collect();
            if let Err(e) = plot_curriculum_overview(&refs, &self.figures_dir) {
                warn!("Could not write curriculum overview: {e}");
            }
        }

        if let Err(e) = self.load_checkpoint::<B>(&mut sp_model, stage, device) {
            warn!("Could not reload best checkpoint: {e}");
        }
        Ok(sp_model)
    }

    // ── Evaluation ────────────────────────────────────────────────────────

    /// Evaluate loss, token accuracy, and macro recall on `data`.
    ///
    /// Returns `(mean_val_loss, token_accuracy, macro_recall)`; accuracy and
    /// recall are fractions in `[0, 1]`.
    ///
    /// `batch_size` is passed explicitly (rather than using `self.batch_size`)
    /// so callers can pass the per-stage effective batch size, which is halved
    /// for memory-heavy stages (sleep / ECG CoT).
    fn eval_metrics_batched<B>(
        &self,
        sp_model:   &OpenTslmSp<B>,
        llm:        &LlamaCppBackend,
        data:       &[Sample],
        batch_size: usize,
        device:     &B::Device,
    ) -> (f64, f64, f64)
    where
        B: AutodiffBackend,
    {
        let batches = make_batches(data.to_vec(), batch_size);
        let (mut loss_sum, mut acc_sum, mut rec_sum, mut n) =
            (0.0f64, 0.0f64, 0.0f64, 0usize);

        for batch_samples in batches {
            let batch = collate(batch_samples);
            let (loss_t, acc, rec) =
                sp_model.compute_loss_and_metrics(&batch.samples, llm, device);
            let loss: f32 = loss_t.to_data().to_vec::<f32>().unwrap()[0];
            loss_sum += loss as f64;
            acc_sum  += acc;
            rec_sum  += rec;
            n        += 1;
        }
        let d = n.max(1) as f64;
        (loss_sum / d, acc_sum / d, rec_sum / d)
    }

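    /// Generate free-text predictions for every test sample (one at a time)
    /// and write them to `<stage>/results/test_predictions.jsonl`, one JSON
    /// object per line: `{"idx", "label", "prediction", "answer"}`.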
    fn evaluate<B>(
        &self,
        sp_model:    &OpenTslmSp<B>,
        llm:         &LlamaCppBackend,
        test_data:   &[Sample],
        stage:       &str,
        _batch_size: usize,   // kept for API symmetry; generate runs one sample at a time
        device:      &B::Device,
    ) -> Result<()>
    where
        B: AutodiffBackend,
    {
        use std::io::Write;

        let pred_file = self.results_dir.join(stage).join("results")
            .join("test_predictions.jsonl");
        let mut file = fs::File::create(&pred_file)?;

        let n = test_data.len();
        info!(
            "{stage}: generating test predictions \
             ({n} samples, max {MAX_EVAL_TOKENS} tokens each) …"
        );

        let pb = ProgressBar::new(n as u64);
        pb.set_style(
            ProgressStyle::default_bar()
                .template("  [{elapsed_precise}] {bar:40.green/white} {pos:>4}/{len}  {msg}")
                .unwrap(),
        );
        pb.enable_steady_tick(Duration::from_millis(200));

        for (idx, sample) in test_data.iter().enumerate() {
            pb.set_message(format!("sample {idx}"));
            let preds = sp_model.generate(
                std::slice::from_ref(sample), llm, MAX_EVAL_TOKENS, device,
            );
            let entry = serde_json::json!({
                "idx":        idx,
                "label":      sample.label.as_deref().unwrap_or(""),
                "prediction": preds.first().cloned().unwrap_or_default(),
                "answer":     &sample.answer,
            });
            writeln!(file, "{}", entry)?;
            pb.inc(1);
        }

        pb.finish_and_clear();
        info!("{stage}: {n} predictions written → {pred_file:?}");
        Ok(())
    }

    // ── Checkpoint I/O ────────────────────────────────────────────────────

    /// Return the path to the best-model checkpoint JSON for `stage`.
    fn checkpoint_path(&self, stage: &str) -> PathBuf {
        self.results_dir
            .join(stage)
            .join("checkpoints")
            .join("best_model.json")
    }

    /// Save a checkpoint for `stage` at the current `epoch` / `val_loss`.
    ///
    /// # Note on full parameter serialisation
    ///
    /// Complete parameter serialisation requires Burn's `NamedMpkFileRecorder`,
    /// which needs the model's named parameter tree.  As a stepping stone, this
    /// implementation saves a JSON metadata file (`best_model.json`) containing
    /// only the epoch number and validation loss.  Full serialisation can be
    /// added by replacing this method with a `NamedMpkFileRecorder` call.
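    ///
    /// A minimal sketch of that upgrade (assuming `TrainableComponents`
    /// implements Burn's `Module`, which `GradientsParams::from_grads`
    /// already requires for training):
    ///
    /// ```rust,ignore
    /// use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder};
    ///
    /// let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
    /// sp_model.trainable
    ///     .clone()
    ///     .save_file(dir.join("best_model"), &recorder)?; // `dir`: the stage's checkpoint directory
    /// ```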
    fn save_checkpoint<B: Backend>(
        &self,
        _sp_model: &OpenTslmSp<B>,
        stage:     &str,
        epoch:     usize,
        val_loss:  f64,
    ) -> Result<()> {
        let meta = serde_json::json!({ "epoch": epoch, "val_loss": val_loss });
        fs::write(self.checkpoint_path(stage), serde_json::to_string(&meta)?)
            .context("Cannot write checkpoint metadata")?;
        Ok(())
    }

    /// Attempt to reload the best checkpoint for `stage`.
    ///
    /// Currently loads only the JSON metadata (epoch + val_loss) and logs it.
    /// When full parameter serialisation is implemented this method should
    /// restore the encoder and logit-head weights as well.
    fn load_checkpoint<B: Backend>(
        &self,
        _sp_model: &mut OpenTslmSp<B>,
        stage:     &str,
        _device:   &B::Device,
    ) -> Result<()> {
        let path = self.checkpoint_path(stage);
        if path.exists() {
            let text = fs::read_to_string(&path)?;
            let v: serde_json::Value = serde_json::from_str(&text)?;
            info!(
                "  Loaded checkpoint metadata from {stage} \
                 (epoch {}, val_loss {})",
                v["epoch"], v["val_loss"]
            );
        }
        Ok(())
    }

    /// Load trainable parameters from the immediately preceding curriculum
    /// stage, if one exists.
    ///
    /// This implements curriculum warm-starting: each stage begins from the
    /// parameters that performed best on the previous stage's validation set,
    /// rather than from a random initialisation.
    fn maybe_load_prev_stage<B: Backend>(
        &self,
        sp_model:      &mut OpenTslmSp<B>,
        current_stage: &str,
        device:        &B::Device,
    ) {
        if let Some(i) = CURRICULUM_STAGES.iter().position(|&s| s == current_stage) {
            if i > 0 {
                let prev = CURRICULUM_STAGES[i - 1];
                info!("  Loading trainable params from previous stage '{prev}'");
                if let Err(e) = self.load_checkpoint::<B>(sp_model, prev, device) {
                    warn!("  Could not load from '{prev}': {e}");
                }
            }
        }
    }

    // ── Per-stage batch size ──────────────────────────────────────────────

    /// Return the effective mini-batch size for `stage`.
    ///
    /// Sleep and ECG CoT samples are substantially more memory-intensive than
    /// the earlier stages (longer EEG/ECG series + longer CoT answers), so
    /// the batch size is halved for those stages to avoid OOM kills on
    /// machines with ≤ 16 GB RAM.
    fn stage_batch_size(&self, stage: &str) -> usize {
        match stage {
            "stage4_sleep_cot" | "stage5_ecg_cot" => (self.batch_size / 2).max(1),
            _ => self.batch_size,
        }
    }

    // ── Dataset loading ───────────────────────────────────────────────────

    /// Load train / val / test splits for `stage`.
    ///
    /// Each split is capped at [`crate::config::MAX_TRAIN_SAMPLES`]
    /// (val and test at 1/5 of that) to keep default runs fast.  Set
    /// `MAX_TRAIN_SAMPLES = usize::MAX` in `config.rs` for a full training run.
    fn load_dataset(&self, stage: &str) -> Result<(Vec<Sample>, Vec<Sample>, Vec<Sample>)> {
        let (train, val, test) = match stage {
            "stage1_mcq"       => { let s = load_tsqa_splits(&self.data_dir)?;  (s.train, s.val, s.test) }
            "stage2_captioning"=> { let s = load_m4_splits(&self.data_dir)?;    (s.train, s.val, s.test) }
            "stage3_cot"       => { let s = load_har_splits(&self.data_dir)?;   (s.train, s.val, s.test) }
            "stage4_sleep_cot" => { let s = load_sleep_splits(&self.data_dir)?; (s.train, s.val, s.test) }
            "stage5_ecg_cot"   => { let s = load_ecg_splits(&self.data_dir)?;   (s.train, s.val, s.test) }
            other => anyhow::bail!("Unknown stage '{other}'"),
        };
        Ok((
            train.into_iter().take(MAX_TRAIN_SAMPLES).collect(),
            val.into_iter().take(MAX_TRAIN_SAMPLES / 5).collect(),
            test.into_iter().take(MAX_TRAIN_SAMPLES / 5).collect(),
        ))
    }
}

// ── Helpers ───────────────────────────────────────────────────────────────────

/// Partition `samples` into consecutive mini-batches of `batch_size`.
/// The last batch may be smaller than `batch_size`.
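/// For example, 10 samples with `batch_size = 4` yield batches of sizes
/// 4, 4, and 2.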
fn make_batches(samples: Vec<Sample>, batch_size: usize) -> Vec<Vec<Sample>> {
    samples.chunks(batch_size).map(|c| c.to_vec()).collect()
}

/// Shuffle `samples` deterministically using a seeded RNG.
///
/// Different `seed` values produce different permutations, so passing the
/// epoch index as `seed` gives a new shuffle each epoch without relying on
/// global mutable state.
fn shuffle_samples(mut v: Vec<Sample>, seed: u64) -> Vec<Sample> {
    use rand::{seq::SliceRandom, SeedableRng, rngs::StdRng};
    let mut rng = StdRng::seed_from_u64(seed);
    v.shuffle(&mut rng);
    v
}

/// Cosine learning rate schedule with linear warm-up.
///
/// - Steps `0..warmup_steps`: LR rises linearly from `min_lr` to `peak_lr`.
/// - Steps `warmup_steps..total_steps`: LR decays via cosine from `peak_lr`
///   down to `min_lr`.
/// - Beyond `total_steps`: LR stays at `min_lr`.
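///
/// Illustrative values for `peak_lr = 1e-3`, `min_lr = 1e-5`, 100 warm-up
/// steps out of 1000 total:
///
/// ```rust,ignore
/// assert_eq!(cosine_lr(0,    100, 1000, 1e-3, 1e-5), 1e-5);              // warm-up start
/// assert!((cosine_lr(100,  100, 1000, 1e-3, 1e-5) - 1e-3).abs() < 1e-9); // peak
/// assert!((cosine_lr(1000, 100, 1000, 1e-3, 1e-5) - 1e-5).abs() < 1e-9); // floor
/// assert!((cosine_lr(5000, 100, 1000, 1e-3, 1e-5) - 1e-5).abs() < 1e-9); // clamped
/// ```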
fn cosine_lr(
    step:         usize,
    warmup_steps: usize,
    total_steps:  usize,
    peak_lr:      f64,
    min_lr:       f64,
) -> f64 {
    if step < warmup_steps {
        // Linear warm-up.
        let frac = if warmup_steps == 0 { 1.0 } else { step as f64 / warmup_steps as f64 };
        return min_lr + (peak_lr - min_lr) * frac;
    }
    let decay_steps = total_steps.saturating_sub(warmup_steps).max(1);
    let t = (step - warmup_steps) as f64 / decay_steps as f64;
    // Clamp t to [0, 1] so the final step doesn't undershoot.
    let t = t.min(1.0);
    min_lr + (peak_lr - min_lr) * 0.5 * (1.0 + (std::f64::consts::PI * t).cos())
}