use std::{fs, path::{Path, PathBuf}, time::{Duration, Instant}};
use anyhow::{Context, Result};
use burn::{
    grad_clipping::GradientClippingConfig,
    optim::{adaptor::OptimizerAdaptor, AdamW, AdamWConfig, GradientsParams, Optimizer},
    prelude::Backend,
    tensor::backend::AutodiffBackend,
};
use indicatif::{ProgressBar, ProgressStyle};
use tracing::{info, warn};
use crate::{
config::{
BATCH_SIZE, CURRICULUM_STAGES, EARLY_STOP_PAT, CTX_SIZE, GRAD_CLIP_NORM,
LR_ENCODER, LR_MIN_FRAC, LR_PROJECTOR, LOSS_EMA_DECAY,
MAX_EVAL_TOKENS, N_GPU_LAYERS, NUM_EPOCHS,
MAX_TRAIN_SAMPLES, WARMUP_FRAC, WEIGHT_DECAY,
},
data::{
batch::{collate, Sample},
ecg::load_ecg_splits,
har::load_har_splits,
m4::load_m4_splits,
sleep::load_sleep_splits,
tsqa::load_tsqa_splits,
},
model::llm::{
llama_cpp::LlamaCppBackend,
opentslm_sp::{OpenTslmSp, TrainableComponents},
},
training::metrics::{EpochMetrics, StageMetrics, plot_curriculum_overview, write_html_index},
};
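/// Drives the OpenTSLM-SP curriculum: for each stage in `CURRICULUM_STAGES` it
/// loads the matching dataset, trains the `OpenTslmSp` trainable components on
/// top of a llama.cpp-backed LLM (only the `trainable` parameters are optimized),
/// and writes checkpoints, metrics, and figures under per-stage directories.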
pub struct CurriculumTrainer {
pub model_path: PathBuf,
pub data_dir: PathBuf,
pub results_dir: PathBuf,
pub figures_dir: PathBuf,
pub device_str: String,
pub batch_size: usize,
}
impl CurriculumTrainer {
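    /// Builds a trainer and prepares the on-disk layout:
    /// `results/<model_name>/OpenTSLMSP/<stage>/{checkpoints,results}` and
    /// `figures/<stage>` for every curriculum stage. Directory-creation errors
    /// are deliberately ignored so pre-existing directories are reused.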
pub fn new(
model_path: impl AsRef<Path>,
data_dir: impl AsRef<Path>,
device_str: &str,
) -> Self {
let model_name = model_path
.as_ref()
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("model")
.replace(['.', '-', ' '], "_");
let results_dir = PathBuf::from("results")
.join(&model_name)
.join("OpenTSLMSP");
let figures_dir = PathBuf::from("figures");
fs::create_dir_all(&results_dir).ok();
fs::create_dir_all(&figures_dir).ok();
for stage in CURRICULUM_STAGES {
fs::create_dir_all(results_dir.join(stage).join("checkpoints")).ok();
fs::create_dir_all(results_dir.join(stage).join("results")).ok();
fs::create_dir_all(figures_dir.join(stage)).ok();
}
Self {
model_path: model_path.as_ref().to_path_buf(),
data_dir: data_dir.as_ref().to_path_buf(),
results_dir,
figures_dir,
device_str: device_str.to_string(),
batch_size: BATCH_SIZE,
}
}
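    /// Runs every curriculum stage in order, stopping at the first error.
    ///
    /// Illustrative usage (backend type and paths are placeholders, not fixed by this crate):
    ///
    /// ```ignore
    /// let mut trainer = CurriculumTrainer::new("models/llama-3.2-1b.gguf", "data", "cuda:0");
    /// trainer.run_all::<burn::backend::Autodiff<burn::backend::Wgpu>>()?;
    /// ```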
pub fn run_all<B>(&mut self) -> Result<()>
where
B: AutodiffBackend,
B::Device: Default,
{
for &stage in CURRICULUM_STAGES {
self.run_stage::<B>(stage)?;
}
Ok(())
}
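    /// Runs a single stage end to end: loads the LLM, optionally seeds the SP
    /// model from the previous stage's checkpoint, trains with early stopping,
    /// and finally writes test-set predictions.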
pub fn run_stage<B>(&mut self, stage: &str) -> Result<()>
where
B: AutodiffBackend,
B::Device: Default,
{
let device = B::Device::default();
let llm = LlamaCppBackend::load(&self.model_path, N_GPU_LAYERS, CTX_SIZE)
.with_context(|| format!("Failed to load LLM from {:?}", self.model_path))?;
let mut sp_model = OpenTslmSp::<B>::new(&llm, &device);
self.maybe_load_prev_stage::<B>(&mut sp_model, stage, &device);
let (train, val, test) = self.load_dataset(stage)?;
info!("{stage}: train={}, val={}, test={}", train.len(), val.len(), test.len());
let trained = self.train_stage::<B>(sp_model, &llm, train, val, stage, &device)?;
let bs = self.stage_batch_size(stage);
self.evaluate::<B>(&trained, &llm, &test, stage, bs, &device)?;
Ok(())
}
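    /// Training loop for one stage: AdamW with weight decay and gradient-norm
    /// clipping, a warmup-plus-cosine learning-rate schedule, per-epoch
    /// validation, checkpointing on best validation loss, and early stopping
    /// after `EARLY_STOP_PAT` non-improving epochs. An exponential moving
    /// average of the batch loss is shown on the progress bar, and per-epoch
    /// losses are appended to `loss_history.txt`.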
fn train_stage<B>(
&self,
mut sp_model: OpenTslmSp<B>,
llm: &LlamaCppBackend,
train_data: Vec<Sample>,
val_data: Vec<Sample>,
stage: &str,
device: &B::Device,
) -> Result<OpenTslmSp<B>>
where
B: AutodiffBackend,
{
let checkpoint_dir = self.results_dir.join(stage).join("checkpoints");
let loss_history = checkpoint_dir.join("loss_history.txt");
if !loss_history.exists() {
fs::write(&loss_history, "Epoch\tTrain_Loss\tVal_Loss\n---\n")?;
}
let mut optimizer: OptimizerAdaptor<AdamW, TrainableComponents<B>, B> =
AdamWConfig::new()
.with_weight_decay(WEIGHT_DECAY as f32)
.with_grad_clipping(Some(GradientClippingConfig::Norm(GRAD_CLIP_NORM)))
.init();
let batch_size = self.stage_batch_size(stage);
        // Use ceiling division so the schedule length matches the number of
        // batches actually produced by `chunks` (which rounds up).
        let total_steps = train_data.len().div_ceil(batch_size).max(1) * NUM_EPOCHS;
let warmup_steps = (total_steps as f64 * WARMUP_FRAC) as usize;
let mut best_val_loss = f64::MAX;
let mut patience_counter = 0usize;
let mut global_step = 0usize;
let mut stage_metrics = StageMetrics::new(stage);
for epoch in 0..NUM_EPOCHS {
let epoch_start = Instant::now();
let mut train_loss_sum = 0.0f64;
let mut train_batches = 0usize;
let shuffled = shuffle_samples(train_data.clone(), epoch as u64);
let batches = make_batches(shuffled, batch_size);
info!(
"{stage} epoch {epoch}/{NUM_EPOCHS} — \
training {} batches (batch_size={})",
batches.len(), batch_size,
);
let pb = ProgressBar::new(batches.len() as u64);
pb.set_style(
ProgressStyle::default_bar()
.template(" [{elapsed_precise}] {bar:35.cyan/blue} {pos:>4}/{len} {msg}")
.unwrap(),
);
pb.enable_steady_tick(Duration::from_millis(200));
let mut ema_loss: Option<f64> = None;
for batch_samples in batches {
let lr = cosine_lr(global_step, warmup_steps, total_steps,
LR_ENCODER, LR_ENCODER * LR_MIN_FRAC);
global_step += 1;
let batch = collate(batch_samples);
let loss = sp_model.compute_loss(&batch.samples, llm, device);
let loss_val = loss.clone().to_data().to_vec::<f32>().unwrap()[0] as f64;
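                // Burn's functional optimizer step: gradients are collected into
                // `GradientsParams` and the trainable module is moved through the
                // optimizer, which returns the updated parameters.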
let grads = loss.backward();
let grads_params = GradientsParams::from_grads(grads, &sp_model.trainable);
sp_model.trainable =
Optimizer::step(&mut optimizer, lr, sp_model.trainable, grads_params);
ema_loss = Some(match ema_loss {
None => loss_val,
Some(e) => LOSS_EMA_DECAY * e + (1.0 - LOSS_EMA_DECAY) * loss_val,
});
train_loss_sum += loss_val;
train_batches += 1;
pb.set_message(format!(
"loss(ema)={:.4} lr={:.2e}",
ema_loss.unwrap(), lr,
));
pb.inc(1);
}
pb.finish_and_clear();
let train_loss = train_loss_sum / train_batches.max(1) as f64;
let (val_loss, val_acc, val_recall) =
self.eval_metrics_batched::<B>(&sp_model, llm, &val_data, batch_size, device);
let elapsed = epoch_start.elapsed().as_secs_f32();
info!(
"{stage} epoch {epoch:>2}/{NUM_EPOCHS} | \
train={train_loss:.4} val={val_loss:.4} \
acc={:.2}% recall={:.2}% ({elapsed:.1}s)",
val_acc * 100.0, val_recall * 100.0,
);
let _ = fs::OpenOptions::new()
.append(true)
.open(&loss_history)
.and_then(|mut f| {
use std::io::Write;
writeln!(f, "{epoch}\t{train_loss:.6}\t{val_loss:.6}")
});
stage_metrics.push(EpochMetrics::new(
epoch, train_loss, val_loss, val_acc, val_recall,
));
if let Err(e) = stage_metrics.save(&self.figures_dir) {
warn!("Could not save incremental metrics: {e}");
}
if let Err(e) = write_html_index(&stage_metrics, &self.figures_dir) {
warn!("Could not write incremental HTML index: {e}");
}
if val_loss < best_val_loss {
best_val_loss = val_loss;
patience_counter = 0;
self.save_checkpoint::<B>(&sp_model, stage, epoch, val_loss)?;
info!(" ✓ best val_loss={val_loss:.4}, checkpoint saved");
} else {
patience_counter += 1;
if patience_counter >= EARLY_STOP_PAT {
info!(" Early stopping after {EARLY_STOP_PAT} non-improving epochs.");
break;
}
}
}
if let Err(e) = stage_metrics.save(&self.figures_dir) {
warn!("Could not save metrics: {e}");
}
if let Err(e) = write_html_index(&stage_metrics, &self.figures_dir) {
warn!("Could not write HTML index: {e}");
}
let all_loaded: Vec<StageMetrics> = CURRICULUM_STAGES.iter()
.filter_map(|s| StageMetrics::from_csv(s, &self.figures_dir).ok())
.collect();
if all_loaded.len() >= 2 {
let refs: Vec<&StageMetrics> = all_loaded.iter().collect();
if let Err(e) = plot_curriculum_overview(&refs, &self.figures_dir) {
warn!("Could not write curriculum overview: {e}");
}
}
if let Err(e) = self.load_checkpoint::<B>(&mut sp_model, stage, device) {
warn!("Could not reload best checkpoint: {e}");
}
Ok(sp_model)
}
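    /// Batched validation pass: averages loss, accuracy, and recall over all
    /// batches without updating any parameters. Each batch contributes equally
    /// to the mean, including a possibly smaller final batch.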
fn eval_metrics_batched<B>(
&self,
sp_model: &OpenTslmSp<B>,
llm: &LlamaCppBackend,
data: &[Sample],
batch_size: usize,
device: &B::Device,
) -> (f64, f64, f64)
where
B: AutodiffBackend,
{
let batches = make_batches(data.to_vec(), batch_size);
let (mut loss_sum, mut acc_sum, mut rec_sum, mut n) =
(0.0f64, 0.0f64, 0.0f64, 0usize);
for batch_samples in batches {
let batch = collate(batch_samples);
let (loss_t, acc, rec) =
sp_model.compute_loss_and_metrics(&batch.samples, llm, device);
let loss: f32 = loss_t.to_data().to_vec::<f32>().unwrap()[0];
loss_sum += loss as f64;
acc_sum += acc;
rec_sum += rec;
n += 1;
}
let d = n.max(1) as f64;
(loss_sum / d, acc_sum / d, rec_sum / d)
}
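    /// Generates test-set predictions one sample at a time (up to
    /// `MAX_EVAL_TOKENS` tokens each) and writes them as JSON lines to
    /// `<results_dir>/<stage>/results/test_predictions.jsonl`. The batch size
    /// argument is currently unused.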
fn evaluate<B>(
&self,
sp_model: &OpenTslmSp<B>,
llm: &LlamaCppBackend,
test_data: &[Sample],
stage: &str,
        _batch_size: usize,
        device: &B::Device,
) -> Result<()>
where
B: AutodiffBackend,
{
use std::io::Write;
let pred_file = self.results_dir.join(stage).join("results")
.join("test_predictions.jsonl");
let mut file = fs::File::create(&pred_file)?;
let n = test_data.len();
info!(
"{stage}: generating test predictions \
({n} samples, max {MAX_EVAL_TOKENS} tokens each) …"
);
let pb = ProgressBar::new(n as u64);
pb.set_style(
ProgressStyle::default_bar()
.template(" [{elapsed_precise}] {bar:40.green/white} {pos:>4}/{len} {msg}")
.unwrap(),
);
pb.enable_steady_tick(Duration::from_millis(200));
for (idx, sample) in test_data.iter().enumerate() {
pb.set_message(format!("sample {idx}"));
let preds = sp_model.generate(
std::slice::from_ref(sample), llm, MAX_EVAL_TOKENS, device,
);
let entry = serde_json::json!({
"idx": idx,
"label": sample.label.as_deref().unwrap_or(""),
"prediction": preds.first().cloned().unwrap_or_default(),
"answer": &sample.answer,
});
writeln!(file, "{}", entry)?;
pb.inc(1);
}
pb.finish_and_clear();
info!("{stage}: {n} predictions written → {pred_file:?}");
Ok(())
}
fn checkpoint_path(&self, stage: &str) -> PathBuf {
self.results_dir
.join(stage)
.join("checkpoints")
.join("best_model.json")
}
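    /// Records checkpoint metadata (epoch and validation loss) as JSON. Note
    /// that the model weights themselves are not serialized here; `_sp_model`
    /// is unused.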
fn save_checkpoint<B: Backend>(
&self,
_sp_model: &OpenTslmSp<B>,
stage: &str,
epoch: usize,
val_loss: f64,
) -> Result<()> {
let meta = serde_json::json!({ "epoch": epoch, "val_loss": val_loss });
fs::write(self.checkpoint_path(stage), serde_json::to_string(&meta)?)
.context("Cannot write checkpoint metadata")?;
Ok(())
}
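    /// Reads and logs checkpoint metadata for `stage` if present. As with
    /// `save_checkpoint`, no model weights are restored; `_sp_model` is left
    /// unchanged.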
fn load_checkpoint<B: Backend>(
&self,
_sp_model: &mut OpenTslmSp<B>,
stage: &str,
_device: &B::Device,
) -> Result<()> {
let path = self.checkpoint_path(stage);
if path.exists() {
let text = fs::read_to_string(&path)?;
let v: serde_json::Value = serde_json::from_str(&text)?;
info!(
" Loaded checkpoint metadata from {stage} \
(epoch {}, val_loss {})",
v["epoch"], v["val_loss"]
);
}
Ok(())
}
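    /// For any stage after the first, attempts to carry over the previous
    /// stage's checkpoint so the curriculum continues from prior training.
    /// Failures are logged as warnings and otherwise ignored.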
fn maybe_load_prev_stage<B: Backend>(
&self,
sp_model: &mut OpenTslmSp<B>,
current_stage: &str,
device: &B::Device,
) {
if let Some(i) = CURRICULUM_STAGES.iter().position(|&s| s == current_stage) {
if i > 0 {
let prev = CURRICULUM_STAGES[i - 1];
info!(" Loading trainable params from previous stage '{prev}'");
if let Err(e) = self.load_checkpoint::<B>(sp_model, prev, device) {
warn!(" Could not load from '{prev}': {e}");
}
}
}
}
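    /// Per-stage batch size: the sleep and ECG chain-of-thought stages use half
    /// the configured batch size; all other stages use it unchanged.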
fn stage_batch_size(&self, stage: &str) -> usize {
match stage {
"stage4_sleep_cot" | "stage5_ecg_cot" => (self.batch_size / 2).max(1),
_ => self.batch_size,
}
}
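    /// Loads the train/val/test splits for a stage and truncates them to at
    /// most `MAX_TRAIN_SAMPLES` training samples and `MAX_TRAIN_SAMPLES / 5`
    /// validation and test samples each.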
fn load_dataset(&self, stage: &str) -> Result<(Vec<Sample>, Vec<Sample>, Vec<Sample>)> {
let (train, val, test) = match stage {
"stage1_mcq" => { let s = load_tsqa_splits(&self.data_dir)?; (s.train, s.val, s.test) }
"stage2_captioning"=> { let s = load_m4_splits(&self.data_dir)?; (s.train, s.val, s.test) }
"stage3_cot" => { let s = load_har_splits(&self.data_dir)?; (s.train, s.val, s.test) }
"stage4_sleep_cot" => { let s = load_sleep_splits(&self.data_dir)?; (s.train, s.val, s.test) }
"stage5_ecg_cot" => { let s = load_ecg_splits(&self.data_dir)?; (s.train, s.val, s.test) }
other => anyhow::bail!("Unknown stage '{other}'"),
};
Ok((
train.into_iter().take(MAX_TRAIN_SAMPLES).collect(),
val.into_iter().take(MAX_TRAIN_SAMPLES / 5).collect(),
test.into_iter().take(MAX_TRAIN_SAMPLES / 5).collect(),
))
}
}
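/// Splits samples into consecutive batches of `batch_size`; the final batch may be smaller.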
fn make_batches(samples: Vec<Sample>, batch_size: usize) -> Vec<Vec<Sample>> {
samples.chunks(batch_size).map(|c| c.to_vec()).collect()
}
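/// Shuffles samples with an RNG seeded per call (the epoch index at the call
/// site), giving a different but reproducible order each epoch.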
fn shuffle_samples(mut v: Vec<Sample>, seed: u64) -> Vec<Sample> {
use rand::{seq::SliceRandom, SeedableRng, rngs::StdRng};
let mut rng = StdRng::seed_from_u64(seed);
v.shuffle(&mut rng);
v
}
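/// Warmup-plus-cosine learning-rate schedule: linear warmup from `min_lr` to
/// `peak_lr` over `warmup_steps`, then cosine decay back to `min_lr`:
///
/// ```text
/// t  = clamp((step - warmup_steps) / max(total_steps - warmup_steps, 1), 0, 1)
/// lr = min_lr + (peak_lr - min_lr) * 0.5 * (1 + cos(pi * t))
/// ```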
fn cosine_lr(
step: usize,
warmup_steps: usize,
total_steps: usize,
peak_lr: f64,
min_lr: f64,
) -> f64 {
    if step < warmup_steps {
        // Linear warmup; `warmup_steps` is strictly positive whenever this branch is taken.
        return min_lr + (peak_lr - min_lr) * (step as f64 / warmup_steps as f64);
    }
let decay_steps = total_steps.saturating_sub(warmup_steps).max(1);
let t = (step - warmup_steps) as f64 / decay_steps as f64;
let t = t.min(1.0);
min_lr + (peak_lr - min_lr) * 0.5 * (1.0 + (std::f64::consts::PI * t).cos())
}