Skip to main content

entrenar/train/
pretrain_real_cuda.rs

1//! CUDA-backend `StepFn` / `ValFn` / `CheckpointFn` for the 370M pretrain
2//! loop (task #132 Phase 2, contract `gpu-training-backend-v1`).
3//!
4//! Mirrors `pretrain_real.rs` but swaps `TransformerTrainer`
5//! (CPU + trueno SIMD) for `CudaTransformerTrainer` (GPU-resident
6//! AdamW + fused CE). The entire module is gated on
7//! `#[cfg(feature = "cuda")]` because `CudaTransformerTrainer::new`
8//! / `train_batch` / `eval_batch` / `save_apr` only exist in the
9//! cuda build — the non-cuda stub returns an error from `new()` and
10//! exposes no step/eval/save methods.
11//!
12//! Contract obligations discharged / strengthened vs the CPU path:
13//! - INV-ARCH-370M-001 (param count ∈ [366M, 374M]) via `debug_assert`
14//!   on `CudaTransformerTrainer::model().parameters()`, matching
15//!   the CPU guard.
16//! - INV-TRAIN-007 (no NaN/Inf): `train_batch` / `eval_batch` return
17//!   finite loss by construction; non-finite outputs abort via
18//!   `PretrainLoop`'s guards.
19//! - INV-TRAIN-008 (grad_norm ≥ 0): `last_grad_norm()` returns the
20//!   real LM-head L2 norm. Strictly stronger than the CPU path's
21//!   `1.0` placeholder.
22//!
23//! Deferred to a follow-up:
24//! - INV-TRAIN-003 (AdamW-state sha256). `CudaTransformerTrainer`
25//!   keeps (m, v, t) on the GPU; discharging this cleanly needs a
26//!   D2H sync that `save_apr` already pays for but `StepFn` does
27//!   not want to pay per-step. Until that sync is factored out,
28//!   the trait default `optimizer_state_sha256 -> None` is used,
29//!   and GATE-TRAIN-006 runs only on the CPU path.
30
31#![cfg(feature = "cuda")]
32
33use crate::train::pretrain::{CheckpointFn, EpochArtifact, StepFn, ValFn};
34use crate::train::pretrain_real::llama_370m_train_config;
35use crate::train::transformer_trainer::{CudaTransformerTrainer, LMBatch};
36use std::cell::RefCell;
37use std::rc::Rc;
38
39/// Shared mutable ownership of a GPU-resident trainer. Both
40/// `CudaRealStepFn` (train steps) and `CudaRealValFn` (eval) clone
41/// this `Rc` so the three hooks see the same GPU memory.
42pub type SharedCudaTrainer = Rc<RefCell<CudaTransformerTrainer>>;
43
44/// Allocate a `CudaTransformerTrainer` with MODEL-2 v2-remedy defaults
45/// and verify INV-ARCH-370M-001 in debug builds.
46///
47/// Returns a `crate::Result` because `CudaTransformerTrainer::new`
48/// can fail on missing CUDA runtime, kernel pre-warm failure, or
49/// block upload failure — the CLI surfaces this as a
50/// GATE-GPUTRAIN-002 error so the operator knows to check their
51/// `--features cuda` build or their GPU.
52pub fn build_shared_cuda_trainer(
53    lr: f32,
54    seq_length: usize,
55    seed: u64,
56) -> crate::Result<SharedCudaTrainer> {
57    let cfg = llama_370m_train_config(lr, seq_length, seed);
58    let trainer = CudaTransformerTrainer::new(cfg)?;
59    #[cfg(debug_assertions)]
60    {
61        let param_count: usize = trainer.model().parameters().iter().map(|t| t.len()).sum();
62        debug_assert!(
63            (366_000_000..=374_000_000).contains(&param_count),
64            "INV-ARCH-370M-001: parameter count {param_count} outside [366M, 374M] band",
65        );
66    }
67    Ok(Rc::new(RefCell::new(trainer)))
68}
69
70/// CUDA `StepFn` — pulls one `LMBatch` from the shard iterator and
71/// runs a real GPU forward + backward + AdamW step.
72pub struct CudaRealStepFn {
73    trainer: SharedCudaTrainer,
74    batches: Box<dyn Iterator<Item = LMBatch>>,
75}
76
77impl CudaRealStepFn {
78    pub fn new(trainer: SharedCudaTrainer, batches: Box<dyn Iterator<Item = LMBatch>>) -> Self {
79        Self { trainer, batches }
80    }
81}
82
83impl StepFn for CudaRealStepFn {
84    fn step(&mut self, _step: u64, _lr: f32, _batch_tokens: u64) -> (f32, f32) {
85        // Exhausted shard stream: emit a finite placeholder so the
86        // NaN/Inf guard (INV-TRAIN-007) doesn't mis-fire and the
87        // divergence guard (GATE-TRAIN-005) correctly does not abort.
88        let Some(batch) = self.batches.next() else {
89            return (1.0, 1.0);
90        };
91        let mut trainer = self.trainer.borrow_mut();
92        let loss = trainer.train_batch(&batch);
93        // Real LM-head L2 norm — strictly more informative than the
94        // CPU path's `1.0` placeholder for GATE-TRAIN-008 monitoring.
95        let grad_norm = trainer.last_grad_norm();
96        (loss, grad_norm)
97    }
98
99    // INV-TRAIN-003 intentionally deferred for the GPU path — see
100    // module docs. Uses trait default `-> None`, so the CPU gate
101    // (`--device cpu`) is the one that exercises AdamW-state parity.
102}
103
104/// CUDA `ValFn` — forward-only eval across pre-loaded held-out
105/// batches. Uses `eval_batch` (fused GPU cross-entropy, no logits
106/// D2H) and averages across batches.
107pub struct CudaRealValFn {
108    trainer: SharedCudaTrainer,
109    held_out: Vec<LMBatch>,
110}
111
112impl CudaRealValFn {
113    pub fn new(trainer: SharedCudaTrainer, held_out: Vec<LMBatch>) -> Self {
114        Self { trainer, held_out }
115    }
116}
117
118impl ValFn for CudaRealValFn {
119    fn validate(&mut self, _epoch: usize) -> f32 {
120        if self.held_out.is_empty() {
121            return f32::NAN;
122        }
123        let mut trainer = self.trainer.borrow_mut();
124        let mut total_loss = 0.0_f32;
125        let mut count = 0_usize;
126        for batch in &self.held_out {
127            if batch.batch_size == 0 {
128                continue;
129            }
130            total_loss += trainer.eval_batch(batch);
131            count += 1;
132        }
133        if count == 0 {
134            f32::NAN
135        } else {
136            total_loss / count as f32
137        }
138    }
139}
140
141/// CUDA `CheckpointFn` — writes the 370M weights to
142/// `artifact.checkpoint_path` in APR format. `save_apr` takes
143/// `&mut self` on the CUDA path because it syncs GPU→CPU before
144/// writing, which is why this holds the `SharedCudaTrainer` instead
145/// of cloning the trainer out.
146pub struct CudaAprCheckpointFn {
147    trainer: SharedCudaTrainer,
148    model_name: String,
149    architecture: String,
150}
151
152impl CudaAprCheckpointFn {
153    pub fn new(
154        trainer: SharedCudaTrainer,
155        model_name: impl Into<String>,
156        architecture: impl Into<String>,
157    ) -> Self {
158        Self { trainer, model_name: model_name.into(), architecture: architecture.into() }
159    }
160}
161
162impl CheckpointFn for CudaAprCheckpointFn {
163    fn save(&mut self, _epoch: usize, artifact: &EpochArtifact) -> Result<(), String> {
164        let mut trainer = self.trainer.borrow_mut();
165        trainer
166            .save_apr(&artifact.checkpoint_path, &self.model_name, &self.architecture)
167            .map_err(|e| format!("save_apr (cuda) failed: {e}"))
168    }
169}