use candle_core::{Result, Tensor, Device, DType};
use candle_nn::{VarMap, VarBuilder, AdamW, ParamsAdamW, Optimizer, loss, ops};
use std::path::Path;
use crate::{TinyRecursiveModel, TRMConfig};
use crate::data::{NumpyDataLoader, BatchDataLoader};
use crate::models::InnerCarry;
use super::scheduler::CosineScheduler;
use super::ema::{EMA, EMAConfig};
use super::checkpoint::{Checkpoint, CheckpointMetadata};
/// Hyperparameters and bookkeeping settings handed to [`Trainer`].
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Number of passes over the data loader made by `Trainer::train`.
    pub num_epochs: usize,
    /// Batch size. Not read in this file; presumably consumed when the data
    /// loader is built (TODO confirm against the caller).
    pub batch_size: usize,
    /// Peak learning rate: initial AdamW `lr` and the scheduler's `lr_init`.
    pub learning_rate: f64,
    /// Value passed to the cosine scheduler as `lr_min`.
    pub lr_min: f64,
    /// Value passed to the cosine scheduler as `warmup_steps`.
    pub warmup_steps: usize,
    /// Value passed to the cosine scheduler as `total_steps`.
    pub total_steps: usize,
    /// AdamW weight-decay coefficient.
    pub weight_decay: f64,
    /// Optional gradient-clipping threshold.
    /// NOTE(review): not applied anywhere in this file.
    pub grad_clip: Option<f64>,
    /// EMA decay factor.
    /// NOTE(review): not read in this file; `Trainer::new` leaves EMA disabled.
    pub ema_decay: f64,
    /// A step checkpoint is written whenever the step counter is a multiple of this.
    pub save_every: usize,
    /// Evaluation interval in steps.
    /// NOTE(review): not read anywhere in this file.
    pub eval_every: usize,
    /// Directory that receives checkpoints and the final model.
    pub checkpoint_dir: String,
}

impl Default for TrainingConfig {
    /// Baseline configuration: 10 epochs, batch size 32, learning rate going
    /// from `3e-4` down to `3e-5`, checkpoints written under `checkpoints`.
    fn default() -> Self {
        let checkpoint_dir = String::from("checkpoints");

        TrainingConfig {
            // Run length.
            num_epochs: 10,
            batch_size: 32,
            total_steps: 100_000,
            // Learning-rate schedule.
            learning_rate: 3e-4,
            lr_min: 3e-5,
            warmup_steps: 1_000,
            // Regularisation.
            weight_decay: 0.1,
            grad_clip: Some(1.0),
            ema_decay: 0.9999,
            // Bookkeeping.
            save_every: 1_000,
            eval_every: 500,
            checkpoint_dir,
        }
    }
}
/// Bundles a `TinyRecursiveModel` with the state needed to train it:
/// its variables, the AdamW optimizer, the learning-rate scheduler and
/// the global step counter.
pub struct Trainer {
/// The network being trained.
model: TinyRecursiveModel,
/// Copy of the model configuration; `hidden_size` is read from it when an
/// empty carry is built for each training step.
model_config: TRMConfig,
/// Variable store the model was built from; this is what gets saved to disk.
varmap: VarMap,
/// AdamW optimizer constructed over every variable in `varmap`.
optimizer: AdamW,
/// Cosine scheduler queried for the learning rate before each optimizer step.
scheduler: CosineScheduler,
/// Optional EMA state. NOTE(review): `new` always stores `None` here and no
/// method in this file reads it.
ema: Option<EMA>,
/// Training hyperparameters and checkpoint settings.
config: TrainingConfig,
/// Device the variables live on and that batches are requested for.
device: Device,
/// Number of `train_step` calls that have completed successfully.
step: usize,
}
impl Trainer {
pub fn new(
model_config: TRMConfig,
training_config: TrainingConfig,
device: Device,
) -> Result<Self> {
let dtype = if device.is_cuda() { DType::F16 } else { DType::F32 };
let varmap = VarMap::new();
let vb = VarBuilder::from_varmap(&varmap, dtype, &device);
let model = TinyRecursiveModel::new(model_config.clone(), vb)
.map_err(|e| candle_core::Error::Msg(format!("Model init failed: {:?}", e)))?;
let optimizer_params = ParamsAdamW {
lr: training_config.learning_rate,
beta1: 0.9,
beta2: 0.999,
eps: 1e-8,
weight_decay: training_config.weight_decay,
};
let optimizer = AdamW::new(varmap.all_vars(), optimizer_params)?;
let scheduler = CosineScheduler::new(super::scheduler::CosineSchedulerConfig {
lr_init: training_config.learning_rate,
lr_min: training_config.lr_min,
warmup_steps: training_config.warmup_steps,
total_steps: training_config.total_steps,
});
let ema = None;
Ok(Self {
model,
model_config,
varmap,
optimizer,
scheduler,
ema,
config: training_config,
device,
step: 0,
})
}
fn compute_loss(
&self,
logits: &Tensor,
targets: &Tensor,
) -> Result<Tensor> {
let batch_size = logits.dim(0)?;
let seq_len = logits.dim(1)?;
let num_classes = logits.dim(2)?;
let target_shape = targets.dims();
if target_shape.len() == 2 && target_shape[1] == 1 {
let logits_pooled = logits.mean(1)?;
let targets_flat = targets.flatten_all()?;
let log_probs = ops::log_softmax(&logits_pooled, candle_core::D::Minus1)?
.to_dtype(DType::F32)?;
let mut loss_sum = 0.0f32;
for i in 0..batch_size {
let target_idx = targets_flat.get(i)?.to_scalar::<u32>()? as usize;
let log_prob = log_probs.get(i)?.get(target_idx)?.to_scalar::<f32>()?;
loss_sum -= log_prob;
}
let loss_val = loss_sum / batch_size as f32;
Tensor::from_slice(&[loss_val], 1, &self.device)?.squeeze(0)
} else {
let logits_flat = logits.reshape((batch_size * seq_len, num_classes))?;
let targets_flat = targets.flatten_all()?;
let log_probs = ops::log_softmax(&logits_flat, candle_core::D::Minus1)?
.to_dtype(DType::F32)?;
let mut loss_sum = 0.0f32;
for i in 0..(batch_size * seq_len) {
let target_idx = targets_flat.get(i)?.to_scalar::<u32>()? as usize;
let log_prob = log_probs.get(i)?.get(target_idx)?.to_scalar::<f32>()?;
loss_sum -= log_prob;
}
let loss_val = loss_sum / (batch_size * seq_len) as f32;
Tensor::from_slice(&[loss_val], 1, &self.device)?.squeeze(0)
}
}
pub fn train_step(
&mut self,
input_ids: &Tensor,
target_ids: &Tensor,
) -> Result<f32> {
let batch_size = input_ids.dim(0)?;
let seq_len = input_ids.dim(1)?;
log::debug!("Input dtype: {:?}, Target dtype: {:?}", input_ids.dtype(), target_ids.dtype());
let dtype = if self.device.is_cuda() { DType::F16 } else { DType::F32 };
let carry = InnerCarry::empty(
batch_size,
seq_len,
self.model_config.hidden_size,
dtype,
&self.device,
)?;
log::debug!("Running forward pass...");
let (_new_carry, logits) = self.model.forward(&carry, input_ids)
.map_err(|e| candle_core::Error::Msg(format!("Forward pass failed: {:?}", e)))?;
log::debug!("Logits shape: {:?}, dtype: {:?}", logits.dims(), logits.dtype());
log::debug!("Computing loss...");
let loss = self.compute_loss(&logits, target_ids)
.map_err(|e| candle_core::Error::Msg(format!("Loss computation failed: {:?}", e)))?;
let loss_val = loss.to_scalar::<f32>()?;
let lr = self.scheduler.get_lr();
self.optimizer.set_learning_rate(lr);
self.optimizer.backward_step(&loss)?;
self.scheduler.step();
self.step += 1;
Ok(loss_val)
}
/// Writes the model variables to `path` and a JSON metadata sidecar to
/// `<path>.meta.json`.
///
/// The metadata records the current step, the scheduler's current learning
/// rate and the supplied `loss`; its `config` field is left as `None`.
///
/// Both the configured checkpoint directory and the parent directory of
/// `path` are created if missing, so `path` does not have to live inside
/// `checkpoint_dir`.
///
/// # Errors
/// Returns an error if a directory cannot be created, the variables cannot
/// be saved, or the metadata cannot be serialised or written.
pub fn save_checkpoint<P: AsRef<Path>>(&self, path: P, loss: Option<f64>) -> Result<()> {
    let path = path.as_ref();

    std::fs::create_dir_all(&self.config.checkpoint_dir)
        .map_err(|e| candle_core::Error::Msg(format!("Failed to create checkpoint dir: {}", e)))?;
    // A bare file name has an empty parent; there is nothing to create then.
    if let Some(parent) = path.parent().filter(|dir| !dir.as_os_str().is_empty()) {
        std::fs::create_dir_all(parent)
            .map_err(|e| candle_core::Error::Msg(format!("Failed to create checkpoint dir: {}", e)))?;
    }

    self.varmap.save(path)?;

    let metadata = CheckpointMetadata {
        step: self.step,
        lr: self.scheduler.get_lr(),
        loss,
        config: None,
    };

    // Append the suffix on the OS string so non-UTF-8 paths are not mangled
    // by a lossy `display()` round-trip.
    let mut metadata_name = path.as_os_str().to_owned();
    metadata_name.push(".meta.json");
    let metadata_path = std::path::PathBuf::from(metadata_name);

    let metadata_json = serde_json::to_string_pretty(&metadata)
        .map_err(|e| candle_core::Error::Msg(format!("Metadata serialization failed: {}", e)))?;
    std::fs::write(&metadata_path, metadata_json)
        .map_err(|e| candle_core::Error::Msg(format!("Metadata write failed: {}", e)))?;

    log::debug!(
        "Saved checkpoint weights to {} and metadata to {}",
        path.display(),
        metadata_path.display()
    );
    Ok(())
}
pub fn train_epoch(&mut self, dataloader: &mut impl BatchDataLoader) -> Result<f32> {
let mut total_loss = 0.0;
let mut num_batches = 0;
dataloader.reset();
while let Some((input_ids, target_ids)) = dataloader.next_batch(&self.device)? {
let loss = self.train_step(&input_ids, &target_ids)?;
total_loss += loss;
num_batches += 1;
if self.step % 100 == 0 {
log::info!(
"Step {}: loss={:.4}, lr={:.6}",
self.step,
loss,
self.scheduler.get_lr()
);
}
if self.step % self.config.save_every == 0 {
let checkpoint_path = format!(
"{}/checkpoint_step_{}.safetensors",
self.config.checkpoint_dir,
self.step
);
log::info!("Saving checkpoint to {}", checkpoint_path);
self.save_checkpoint(&checkpoint_path, Some(loss as f64))?;
}
}
let avg_loss = total_loss / num_batches as f32;
Ok(avg_loss)
}
pub fn train(&mut self, dataloader: &mut impl BatchDataLoader) -> Result<()> {
log::info!("Starting training for {} epochs", self.config.num_epochs);
log::info!("Total batches per epoch: {}", dataloader.num_batches());
for epoch in 0..self.config.num_epochs {
log::info!("=== Epoch {}/{} ===", epoch + 1, self.config.num_epochs);
let avg_loss = self.train_epoch(dataloader)?;
log::info!(
"Epoch {} complete: avg_loss={:.4}, step={}",
epoch + 1,
avg_loss,
self.step
);
let checkpoint_path = format!(
"{}/checkpoint_epoch_{}.safetensors",
self.config.checkpoint_dir,
epoch + 1
);
self.save_checkpoint(&checkpoint_path, Some(avg_loss as f64))?;
}
log::info!("Training complete!");
let final_path = format!("{}/final_model.safetensors", self.config.checkpoint_dir);
log::info!("Saving final model to {}", final_path);
self.varmap.save(&final_path)?;
Ok(())
}
}