entrenar/train/pretrain_real_cuda.rs
1//! CUDA-backend `StepFn` / `ValFn` / `CheckpointFn` for the 370M pretrain
2//! loop (task #132 Phase 2, contract `gpu-training-backend-v1`).
3//!
4//! Mirrors `pretrain_real.rs` but swaps `TransformerTrainer`
5//! (CPU + trueno SIMD) for `CudaTransformerTrainer` (GPU-resident
6//! AdamW + fused CE). The entire module is gated on
7//! `#[cfg(feature = "cuda")]` because `CudaTransformerTrainer::new`
8//! / `train_batch` / `eval_batch` / `save_apr` only exist in the
9//! cuda build — the non-cuda stub returns an error from `new()` and
10//! exposes no step/eval/save methods.
11//!
12//! Contract obligations discharged / strengthened vs the CPU path:
13//! - INV-ARCH-370M-001 (param count ∈ [366M, 374M]) via `debug_assert`
14//! on `CudaTransformerTrainer::model().parameters()`, matching
15//! the CPU guard.
//! - INV-TRAIN-007 (no NaN/Inf): `train_batch` / `eval_batch` are
//!   expected to return a finite loss; any non-finite output is caught
//!   by `PretrainLoop`'s guards, which abort the run.
19//! - INV-TRAIN-008 (grad_norm ≥ 0): `last_grad_norm()` returns the
20//! real LM-head L2 norm. Strictly stronger than the CPU path's
21//! `1.0` placeholder.
22//!
23//! Deferred to a follow-up:
24//! - INV-TRAIN-003 (AdamW-state sha256). `CudaTransformerTrainer`
25//! keeps (m, v, t) on the GPU; discharging this cleanly needs a
26//! D2H sync that `save_apr` already pays for but `StepFn` does
27//! not want to pay per-step. Until that sync is factored out,
28//! the trait default `optimizer_state_sha256 -> None` is used,
29//! and GATE-TRAIN-006 runs only on the CPU path.
30
31#![cfg(feature = "cuda")]
32
33use crate::train::pretrain::{CheckpointFn, EpochArtifact, StepFn, ValFn};
34use crate::train::pretrain_real::llama_370m_train_config;
35use crate::train::transformer_trainer::{CudaTransformerTrainer, LMBatch};
36use std::cell::RefCell;
37use std::rc::Rc;
38
39/// Shared mutable ownership of a GPU-resident trainer. Both
40/// `CudaRealStepFn` (train steps) and `CudaRealValFn` (eval) clone
41/// this `Rc` so the three hooks see the same GPU memory.
42pub type SharedCudaTrainer = Rc<RefCell<CudaTransformerTrainer>>;
43
44/// Allocate a `CudaTransformerTrainer` with MODEL-2 v2-remedy defaults
45/// and verify INV-ARCH-370M-001 in debug builds.
46///
47/// Returns a `crate::Result` because `CudaTransformerTrainer::new`
48/// can fail on missing CUDA runtime, kernel pre-warm failure, or
49/// block upload failure — the CLI surfaces this as a
50/// GATE-GPUTRAIN-002 error so the operator knows to check their
51/// `--features cuda` build or their GPU.
52pub fn build_shared_cuda_trainer(
53 lr: f32,
54 seq_length: usize,
55 seed: u64,
56) -> crate::Result<SharedCudaTrainer> {
57 let cfg = llama_370m_train_config(lr, seq_length, seed);
58 let trainer = CudaTransformerTrainer::new(cfg)?;
59 #[cfg(debug_assertions)]
60 {
61 let param_count: usize = trainer.model().parameters().iter().map(|t| t.len()).sum();
62 debug_assert!(
63 (366_000_000..=374_000_000).contains(¶m_count),
64 "INV-ARCH-370M-001: parameter count {param_count} outside [366M, 374M] band",
65 );
66 }
67 Ok(Rc::new(RefCell::new(trainer)))
68}
69
70/// CUDA `StepFn` — pulls one `LMBatch` from the shard iterator and
71/// runs a real GPU forward + backward + AdamW step.
72pub struct CudaRealStepFn {
73 trainer: SharedCudaTrainer,
74 batches: Box<dyn Iterator<Item = LMBatch>>,
75}
76
77impl CudaRealStepFn {
78 pub fn new(trainer: SharedCudaTrainer, batches: Box<dyn Iterator<Item = LMBatch>>) -> Self {
79 Self { trainer, batches }
80 }
81}
82
83impl StepFn for CudaRealStepFn {
84 fn step(&mut self, _step: u64, _lr: f32, _batch_tokens: u64) -> (f32, f32) {
85 // Exhausted shard stream: emit a finite placeholder so the
86 // NaN/Inf guard (INV-TRAIN-007) doesn't mis-fire and the
87 // divergence guard (GATE-TRAIN-005) correctly does not abort.
88 let Some(batch) = self.batches.next() else {
89 return (1.0, 1.0);
90 };
91 let mut trainer = self.trainer.borrow_mut();
92 let loss = trainer.train_batch(&batch);
93 // Real LM-head L2 norm — strictly more informative than the
94 // CPU path's `1.0` placeholder for GATE-TRAIN-008 monitoring.
95 let grad_norm = trainer.last_grad_norm();
96 (loss, grad_norm)
97 }
98
99 // INV-TRAIN-003 intentionally deferred for the GPU path — see
100 // module docs. Uses trait default `-> None`, so the CPU gate
101 // (`--device cpu`) is the one that exercises AdamW-state parity.
102}
103
104/// CUDA `ValFn` — forward-only eval across pre-loaded held-out
105/// batches. Uses `eval_batch` (fused GPU cross-entropy, no logits
106/// D2H) and averages across batches.
107pub struct CudaRealValFn {
108 trainer: SharedCudaTrainer,
109 held_out: Vec<LMBatch>,
110}
111
112impl CudaRealValFn {
113 pub fn new(trainer: SharedCudaTrainer, held_out: Vec<LMBatch>) -> Self {
114 Self { trainer, held_out }
115 }
116}
117
118impl ValFn for CudaRealValFn {
119 fn validate(&mut self, _epoch: usize) -> f32 {
120 if self.held_out.is_empty() {
121 return f32::NAN;
122 }
123 let mut trainer = self.trainer.borrow_mut();
124 let mut total_loss = 0.0_f32;
125 let mut count = 0_usize;
126 for batch in &self.held_out {
127 if batch.batch_size == 0 {
128 continue;
129 }
130 total_loss += trainer.eval_batch(batch);
131 count += 1;
132 }
133 if count == 0 {
134 f32::NAN
135 } else {
136 total_loss / count as f32
137 }
138 }
139}
140
141/// CUDA `CheckpointFn` — writes the 370M weights to
142/// `artifact.checkpoint_path` in APR format. `save_apr` takes
143/// `&mut self` on the CUDA path because it syncs GPU→CPU before
144/// writing, which is why this holds the `SharedCudaTrainer` instead
145/// of cloning the trainer out.
146pub struct CudaAprCheckpointFn {
147 trainer: SharedCudaTrainer,
148 model_name: String,
149 architecture: String,
150}
151
152impl CudaAprCheckpointFn {
153 pub fn new(
154 trainer: SharedCudaTrainer,
155 model_name: impl Into<String>,
156 architecture: impl Into<String>,
157 ) -> Self {
158 Self { trainer, model_name: model_name.into(), architecture: architecture.into() }
159 }
160}
161
162impl CheckpointFn for CudaAprCheckpointFn {
163 fn save(&mut self, _epoch: usize, artifact: &EpochArtifact) -> Result<(), String> {
164 let mut trainer = self.trainer.borrow_mut();
165 trainer
166 .save_apr(&artifact.checkpoint_path, &self.model_name, &self.architecture)
167 .map_err(|e| format!("save_apr (cuda) failed: {e}"))
168 }
169}