#![cfg(feature = "cuda")]
use crate::train::pretrain::{CheckpointFn, EpochArtifact, StepFn, ValFn};
use crate::train::pretrain_real::llama_370m_train_config;
use crate::train::transformer_trainer::{CudaTransformerTrainer, LMBatch};
use std::cell::RefCell;
use std::rc::Rc;
/// Shared, interior-mutable handle to a single [`CudaTransformerTrainer`].
///
/// The step, validation and checkpoint adapters in this file each hold one of
/// these handles and take a mutable borrow for the duration of a call. `Rc` is
/// neither `Send` nor `Sync`, so a handle cannot cross thread boundaries.
pub type SharedCudaTrainer = Rc<RefCell<CudaTransformerTrainer>>;
/// Builds a CUDA trainer from the 370M training configuration and wraps it in
/// a [`SharedCudaTrainer`] handle.
///
/// `lr`, `seq_length` and `seed` are forwarded unchanged to
/// `llama_370m_train_config`; the resulting configuration is handed to
/// `CudaTransformerTrainer::new`.
///
/// # Errors
///
/// Returns whatever error `CudaTransformerTrainer::new` reports.
///
/// # Panics
///
/// In builds with `debug_assertions` enabled, panics (invariant
/// `INV-ARCH-370M-001`) when the summed `len()` of the model's parameters
/// falls outside the inclusive band `[366_000_000, 374_000_000]`.
pub fn build_shared_cuda_trainer(
    lr: f32,
    seq_length: usize,
    seed: u64,
) -> crate::Result<SharedCudaTrainer> {
    let cfg = llama_370m_train_config(lr, seq_length, seed);
    let trainer = CudaTransformerTrainer::new(cfg)?;

    // The whole check sits behind `cfg(debug_assertions)` so that release
    // builds do not even walk the parameter list.
    #[cfg(debug_assertions)]
    {
        let param_count: usize = trainer.model().parameters().iter().map(|t| t.len()).sum();
        debug_assert!(
            // `RangeInclusive::contains` takes its argument by reference.
            (366_000_000..=374_000_000).contains(&param_count),
            "INV-ARCH-370M-001: parameter count {param_count} outside [366M, 374M] band",
        );
    }

    Ok(Rc::new(RefCell::new(trainer)))
}
/// [`StepFn`] adapter that feeds batches from an iterator into a shared CUDA
/// trainer, one batch per step.
pub struct CudaRealStepFn {
/// Handle to the trainer that runs each training batch.
trainer: SharedCudaTrainer,
/// Source of training batches; one item is pulled per `step` call.
batches: Box<dyn Iterator<Item = LMBatch>>,
}
impl CudaRealStepFn {
/// Creates a step adapter from a trainer handle and a batch iterator.
///
/// Both arguments are stored as given; nothing is pulled from `batches`
/// and the trainer is not borrowed here.
pub fn new(trainer: SharedCudaTrainer, batches: Box<dyn Iterator<Item = LMBatch>>) -> Self {
Self { trainer, batches }
}
}
impl StepFn for CudaRealStepFn {
    /// Runs one training step on the next batch and returns `(loss, grad_norm)`.
    ///
    /// The `_step`, `_lr` and `_batch_tokens` arguments are not used by this
    /// implementation. When the batch iterator is exhausted the trainer is not
    /// touched and the placeholder pair `(1.0, 1.0)` is returned.
    ///
    /// # Panics
    ///
    /// Panics if the shared trainer is already borrowed when a batch is
    /// available, because the trainer is accessed through `RefCell::borrow_mut`.
    fn step(&mut self, _step: u64, _lr: f32, _batch_tokens: u64) -> (f32, f32) {
        match self.batches.next() {
            // NOTE(review): placeholder values rather than an error; callers
            // presumably rely on this when the data runs out. Confirm.
            None => (1.0, 1.0),
            Some(batch) => {
                let mut guard = self.trainer.borrow_mut();
                // Train first: the gradient norm is read back only after the
                // batch has been processed.
                let loss = guard.train_batch(&batch);
                let grad_norm = guard.last_grad_norm();
                (loss, grad_norm)
            }
        }
    }
}
/// [`ValFn`] adapter that evaluates a fixed set of held-out batches on a
/// shared CUDA trainer.
pub struct CudaRealValFn {
/// Handle to the trainer used for evaluation.
trainer: SharedCudaTrainer,
/// Held-out batches that `validate` iterates over on each call.
held_out: Vec<LMBatch>,
}
impl CudaRealValFn {
/// Creates a validation adapter from a trainer handle and held-out batches.
///
/// Both arguments are stored as given; no evaluation happens here.
pub fn new(trainer: SharedCudaTrainer, held_out: Vec<LMBatch>) -> Self {
Self { trainer, held_out }
}
}
impl ValFn for CudaRealValFn {
    /// Returns the mean of `eval_batch` over the held-out batches.
    ///
    /// Batches whose `batch_size` is zero are skipped and do not count toward
    /// the mean. The result is `f32::NAN` when there are no held-out batches
    /// or every batch was skipped. `_epoch` is not used.
    ///
    /// # Panics
    ///
    /// Panics if the shared trainer is already borrowed while the held-out
    /// set is non-empty, because the trainer is accessed through
    /// `RefCell::borrow_mut`.
    fn validate(&mut self, _epoch: usize) -> f32 {
        // Bail out before borrowing the trainer when there is nothing to do.
        if self.held_out.is_empty() {
            return f32::NAN;
        }

        let mut guard = self.trainer.borrow_mut();
        // Left-to-right fold keeps the floating-point summation order fixed.
        let (loss_sum, evaluated) = self
            .held_out
            .iter()
            .filter(|batch| batch.batch_size != 0)
            .fold((0.0_f32, 0_usize), |(sum, seen), batch| {
                (sum + guard.eval_batch(batch), seen + 1)
            });

        match evaluated {
            0 => f32::NAN,
            n => loss_sum / n as f32,
        }
    }
}
/// [`CheckpointFn`] adapter that saves the shared CUDA trainer's state through
/// its `save_apr` method.
pub struct CudaAprCheckpointFn {
    /// Handle to the trainer whose state is saved.
    trainer: SharedCudaTrainer,
    /// Model name passed to `save_apr` on every save.
    model_name: String,
    /// Architecture string passed to `save_apr` on every save.
    architecture: String,
}

impl CudaAprCheckpointFn {
    /// Creates a checkpoint adapter.
    ///
    /// `model_name` and `architecture` are converted to owned `String`s once
    /// here and reused for every subsequent save.
    pub fn new(
        trainer: SharedCudaTrainer,
        model_name: impl Into<String>,
        architecture: impl Into<String>,
    ) -> Self {
        let model_name: String = model_name.into();
        let architecture: String = architecture.into();
        Self { trainer, model_name, architecture }
    }
}
impl CheckpointFn for CudaAprCheckpointFn {
    /// Saves the trainer to `artifact.checkpoint_path` via `save_apr`, using
    /// the stored model name and architecture. `_epoch` is not used.
    ///
    /// # Errors
    ///
    /// Returns `"save_apr (cuda) failed: …"` with the underlying error's
    /// `Display` output appended when `save_apr` fails.
    ///
    /// # Panics
    ///
    /// Panics if the shared trainer is already borrowed, because the trainer
    /// is accessed through `RefCell::borrow_mut`.
    fn save(&mut self, _epoch: usize, artifact: &EpochArtifact) -> Result<(), String> {
        let destination = &artifact.checkpoint_path;
        let mut guard = self.trainer.borrow_mut();
        guard
            .save_apr(destination, &self.model_name, &self.architecture)
            .map_err(|cause| format!("save_apr (cuda) failed: {cause}"))?;
        Ok(())
    }
}