Skip to main content

entrenar/autograd/cuda_backward/
cache.rs

1#![allow(unsafe_code)]
2#![allow(trivial_casts)]
3#![allow(clippy::borrow_as_ptr)]
4#![allow(clippy::ref_as_ptr)]
5
6#[cfg(feature = "cuda")]
7use std::collections::HashMap;
8#[cfg(feature = "cuda")]
9use std::sync::{Arc, Mutex, OnceLock};
10
11#[cfg(feature = "cuda")]
12use trueno_gpu::driver::{CublasHandle, CudaContext, CudaModule, CudaStream};
13
// trueno#200: The Blackwell workaround requires a local trueno-gpu build that includes the
// from_ptx fix. To use one, add the following to Cargo.toml:
//   [patch.crates-io]
//   trueno-gpu = { path = "../trueno/trueno-gpu" }
16
17use super::super::cuda_tensor::{CudaTensorError, Result};
18
/// Process-wide cache of compiled CUDA modules for the backward kernels.
///
/// Created on first call to `init_kernel_cache`; callers lock the inner `Mutex`
/// before reading or compiling modules.
#[cfg(feature = "cuda")]
pub(super) static KERNEL_CACHE: OnceLock<Mutex<KernelCache>> = OnceLock::new();

/// Cache for compiled backward kernel modules.
///
/// # Contract: F-PTX-001 (Target Parity)
///
/// Same invariant as forward cache — PTX target must match device.
///
/// # Field order
///
/// Rust drops struct fields in declaration order. `ctx` is declared last so the
/// compiled modules and the cuBLAS handle, which were created against this
/// context, are released before this struct's reference to the context is.
/// (The global `KERNEL_CACHE` is a `static` and is never dropped, but
/// `KernelCache::new` can be used to build other instances.)
#[cfg(feature = "cuda")]
pub(super) struct KernelCache {
    /// Compiled modules keyed by cache name (e.g. `gemm_backward_a_{s}_{h}_{i}`).
    modules: HashMap<String, CudaModule>,
    /// Device SM target string (e.g. `sm_70`) that every compiled PTX must target.
    sm_target: String,
    /// cuBLAS handle for backward GEMMs (ALB-075). Uses CUBLAS_DEFAULT_MATH
    /// (SIMD, no tensor cores) per ALB-076/trueno#170 to avoid NaN in transposed GEMMs.
    /// `None` when handle creation failed in `new`.
    cublas: Option<CublasHandle>,
    /// CUDA context the modules and cuBLAS handle were created against.
    /// Declared last on purpose — see "Field order" above.
    ctx: Arc<CudaContext>,
}
37
#[cfg(feature = "cuda")]
impl KernelCache {
    /// Build an empty cache bound to `ctx`.
    ///
    /// Best-effort initialization: if the device SM target cannot be queried the
    /// cache falls back to `"sm_70"`, and if cuBLAS handle creation fails the
    /// cache is built without one (`cublas()` then returns `None`).
    pub(super) fn new(ctx: Arc<CudaContext>) -> Self {
        let sm_target = ctx.sm_target().unwrap_or_else(|_| "sm_70".to_string());
        let cublas = CublasHandle::new(&ctx).ok();
        Self { ctx, modules: HashMap::new(), sm_target, cublas }
    }

    /// Get a reference to the cuBLAS handle, if available.
    pub(super) fn cublas(&self) -> Option<&CublasHandle> {
        self.cublas.as_ref()
    }

    /// Bind cuBLAS to a stream for the current training step.
    ///
    /// A no-op returning `Ok(())` when no cuBLAS handle exists.
    pub(super) fn set_cublas_stream(&self, stream: &CudaStream) -> Result<()> {
        if let Some(ref handle) = self.cublas {
            handle.set_stream(stream).map_err(|e| {
                CudaTensorError::KernelError(format!("cuBLAS set_stream failed: {e:?}"))
            })?;
        }
        Ok(())
    }

    /// Device SM target string that all PTX compiled into this cache must match.
    pub(super) fn sm_target(&self) -> &str {
        &self.sm_target
    }

    /// Get the CudaContext this cache was built against. Used by callers that
    /// need to allocate temp `GpuBuffer<T>` (which require a context) — e.g.
    /// `rms_norm_backward`'s per-row partial buffer for the FALSIFY-GPUTRAIN-006
    /// deterministic two-stage gamma reduction.
    pub(super) fn ctx(&self) -> &Arc<CudaContext> {
        &self.ctx
    }

    /// Look up a previously compiled module by key (KAIZEN-058).
    ///
    /// Pure lookup: never compiles and performs no F-PTX-001 check.
    pub(super) fn get_cached(&mut self, name: &str) -> Option<&mut CudaModule> {
        self.modules.get_mut(name)
    }

    /// Return the module cached under `name`, compiling `ptx` on first use.
    ///
    /// On a cache hit the existing module is returned as-is; `ptx` is only
    /// compiled when `name` has not been seen before.
    ///
    /// # Errors
    /// - `KernelError` if the architecture in the PTX `.target` directive differs
    ///   from the device target (F-PTX-001). This is checked on every call,
    ///   including cache hits.
    /// - `KernelError` if querying the compute capability or loading the module fails.
    pub(super) fn get_or_compile(&mut self, name: &str, ptx: &str) -> Result<&mut CudaModule> {
        use std::collections::hash_map::Entry;

        // F-PTX-001: Validate PTX target matches device
        if let Some(ptx_target) = Self::ptx_sm_target(ptx) {
            if ptx_target != self.sm_target {
                return Err(CudaTensorError::KernelError(format!(
                    "F-PTX-001 violated: PTX target '{ptx_target}' != device target '{}'",
                    self.sm_target
                )));
            }
        }

        match self.modules.entry(name.to_string()) {
            Entry::Occupied(e) => Ok(e.into_mut()),
            Entry::Vacant(e) => {
                eprintln!("[BWD-CACHE] Compiling '{name}' (ptx_len={})", ptx.len());

                // trueno#200: On Blackwell, CudaModule::from_ptx uses cuModuleLoadDataEx
                // which poisons the CUDA context. Bypass it entirely with direct
                // cuModuleLoadData via the raw driver API.
                let (major, _minor) = self.ctx.compute_capability().map_err(|e| {
                    CudaTensorError::KernelError(format!("compute_capability: {e:?}"))
                })?;

                // trueno#200: the direct-load path is gated on compute capability
                // major >= 12 (e.g. sm_121); all older devices use from_ptx.
                let module = if major >= 12 {
                    CudaModule::from_ptx_direct(&self.ctx, ptx)
                } else {
                    CudaModule::from_ptx(&self.ctx, ptx)
                }
                .map_err(|err| {
                    eprintln!("[BWD-CACHE] FAILED '{name}': {err:?}");
                    CudaTensorError::KernelError(format!("Failed to compile {name}: {err:?}"))
                })?;
                eprintln!("[BWD-CACHE] OK '{name}'");
                Ok(e.insert(module))
            }
        }
    }

    /// Extract the SM architecture named by a PTX module's `.target` directive.
    ///
    /// The PTX `.target` directive takes a comma-separated list of specifiers
    /// (e.g. `.target sm_80, debug`). This returns the first entry starting with
    /// `sm_`, or the first entry if none does. Leading indentation and trailing
    /// `//` comments are ignored. Returns `None` when no `.target` directive
    /// is present.
    fn ptx_sm_target(ptx: &str) -> Option<&str> {
        ptx.lines().find_map(|line| {
            let code = line.split("//").next().unwrap_or("").trim();
            let rest = code.strip_prefix(".target")?;
            // Require a separator so identifiers like `.targetfoo` don't match.
            if !rest.starts_with(char::is_whitespace) {
                return None;
            }
            let mut entries = rest.split(',').map(str::trim).filter(|e| !e.is_empty());
            let first = entries.clone().next()?;
            Some(entries.find(|e| e.starts_with("sm_")).unwrap_or(first))
        })
    }
}
120
/// Initialize the process-wide backward kernel cache with `ctx`.
///
/// Only the first call has any effect. Once the global cache exists, later
/// calls leave it untouched (their `ctx` is dropped unused) and still return
/// `Ok(())`. This function currently never returns an error.
#[cfg(feature = "cuda")]
pub fn init_kernel_cache(ctx: Arc<CudaContext>) -> Result<()> {
    let build = move || Mutex::from(KernelCache::new(ctx));
    let _initialized: &Mutex<KernelCache> = KERNEL_CACHE.get_or_init(build);
    Ok(())
}
127
/// Bind the backward cache's cuBLAS handle to `stream` (ALB-075).
///
/// # Errors
/// - `DeviceNotInitialized` if `init_kernel_cache` has not been called yet.
/// - `KernelError` if the cache mutex is poisoned or cuBLAS rejects the stream.
#[cfg(feature = "cuda")]
pub fn set_backward_cublas_stream(stream: &CudaStream) -> Result<()> {
    let Some(shared) = KERNEL_CACHE.get() else {
        return Err(CudaTensorError::DeviceNotInitialized);
    };
    match shared.lock() {
        Ok(guard) => guard.set_cublas_stream(stream),
        Err(_poisoned) => Err(CudaTensorError::KernelError(
            "Failed to acquire backward kernel cache lock".to_string(),
        )),
    }
}
137
/// Pre-warm backward kernels for training gradient computation (ENT-153).
///
/// Covers LoRA-only shapes (NF4 QLoRA), full fp32 backward shapes, and the
/// activation/normalization backward kernels that every training mode uses.
///
/// Dimension legend: `S` = `max_seq_len`, `H` = `hidden_size`, `R` = `lora_rank`,
/// `qd` = `q_dim`, `kv` = `kv_hidden_size`, `I` = `intermediate_size`,
/// `NH` = `num_heads`.
///
/// ## LoRA backward shapes (only when `lora_rank > 0`)
/// `gemm_backward_b` (weight grads): `(S,R,qd)`, `(S,R,kv)`, `(S,H,R)`
/// `gemm_backward_a` (input grads): `(S,qd,R)`, `(S,kv,R)`, `(S,R,H)`
///
/// ## Full fp32 backward shapes (only when `quantize_nf4 == false`)
/// `gemm_backward_a`: `(S,I,H)` down, `(S,H,I)` gate/up, `(S,H,H)` Q/O, `(S,kv,H)` K/V
/// `gemm_backward_b`: `(S,I,H)` grad_w_down, `(S,H,I)` grad_w_gate/up,
///                     `(S,H,H)` grad_w_q/o, `(S,kv,H)` grad_w_k/v
///
/// ## Always pre-warmed
/// `silu_backward` (`S*I`), `batched_softmax_backward` (`NH*S`, `S`),
/// `batched_rms_norm_backward` (`S`, `H`) and `rms_norm_gamma_reduce` (`S`, `H`).
///
/// All must be JIT-compiled before block upload fills VRAM (C-PREWARM-001).
///
/// # Errors
/// - `DeviceNotInitialized` if `init_kernel_cache` has not been called.
/// - `KernelError` in any of these cases:
///   - the cache mutex is poisoned;
///   - a dimension (or the `S*I` / `NH*S` products) does not fit in `u32`;
///   - a kernel fails to compile, including F-PTX-001 target mismatches.
#[cfg(feature = "cuda")]
pub fn pre_warm_lora_backward_kernels(
    hidden_size: usize,
    q_dim: usize,
    kv_hidden_size: usize,
    max_seq_len: usize,
    lora_rank: usize,
    intermediate_size: usize,
    num_heads: usize,
    quantize_nf4: bool,
) -> Result<()> {
    use std::convert::TryFrom;
    use trueno_gpu::kernels::backward::{
        BatchedRmsNormBackwardKernel, BatchedSoftmaxBackwardKernel, GemmBackwardAKernel,
        GemmBackwardBKernel, RmsNormGammaReduceKernel, SiluBackwardKernel,
    };
    use trueno_gpu::kernels::Kernel;

    eprintln!("[BWD-PREWARM] Called with lora_rank={lora_rank}, hidden={hidden_size}, inter={intermediate_size}");

    // PMAT-698g: non-LoRA backward pre-warm. The original gate at
    // `lora_rank == 0` short-circuited the entire function — leaving
    // silu_backward, batched_softmax_backward, batched_rms_norm_backward,
    // and rms_norm_gamma_reduce to JIT on demand mid-training. On Blackwell
    // sm_121, on-demand JIT during the first backward step corrupts the
    // CUDA stream (trueno#200), surfacing as
    //   forward_backward_with_grad returned None
    // mid-step (Phase 3 dispatch v8 confirmed this: kernels compiled
    // successfully but stream state was poisoned afterwards).
    //
    // Now: only the LoRA-specific gemm_backward warmups are skipped when
    // lora_rank == 0; the activation/norm kernels and the standard FP32
    // GEMM backward shapes are always pre-warmed (they're needed by
    // non-LoRA distillation training too).
    let is_lora = lora_rank > 0;

    let cache = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire backward kernel cache lock".to_string())
    })?;

    // Kernel constructors take u32 dimensions. A bare `as u32` would silently
    // truncate oversized values and compile kernels for the wrong shape.
    let to_u32 = |label: &str, value: usize| -> Result<u32> {
        u32::try_from(value).map_err(|_| {
            CudaTensorError::KernelError(format!(
                "[BWD-PREWARM] {label}={value} does not fit in u32"
            ))
        })
    };
    let s = to_u32("max_seq_len", max_seq_len)?;
    let h = to_u32("hidden_size", hidden_size)?;
    let r = to_u32("lora_rank", lora_rank)?;
    let qd = to_u32("q_dim", q_dim)?;
    let kv = to_u32("kv_hidden_size", kv_hidden_size)?;
    let i = to_u32("intermediate_size", intermediate_size)?;
    let nh = to_u32("num_heads", num_heads)?;

    // Element/row counts used below. Plain u32 multiplication would panic in
    // debug and wrap in release for long sequences or wide FFNs. They are
    // checked up front so nothing is compiled for an invalid configuration.
    let checked_product = |label: &str, a: u32, b: u32| -> Result<u32> {
        a.checked_mul(b).ok_or_else(|| {
            CudaTensorError::KernelError(format!(
                "[BWD-PREWARM] {label} ({a}*{b}) overflows u32"
            ))
        })
    };
    let si = checked_product("max_seq_len*intermediate_size", s, i)?;
    let softmax_rows = checked_product("num_heads*max_seq_len", nh, s)?;

    let mut count = 0u32;
    // Owned copy so the macro can take `&mut cache` while reading the target.
    let target = cache.sm_target().to_string();

    // Emit PTX for the device target and compile it under `key` (no-op on a cache hit).
    macro_rules! warm {
        ($key:expr, $kernel:expr) => {{
            let key = $key;
            let ptx = $kernel.emit_ptx_for_target(&target);
            cache.get_or_compile(&key, &ptx)?;
            count += 1;
        }};
    }

    // Tile size must match BACKWARD_TILE_SIZE in gemm.rs (C-TILE-BWD-007)
    let tile: u32 = 16;

    // ── LoRA backward shapes (LoRA training only) ──
    if is_lora {
        // gemm_backward_b: weight gradients
        warm!(
            format!("gemm_backward_b_{s}_{r}_{qd}"),
            GemmBackwardBKernel::tiled_unrolled(s, r, qd, tile)
        );
        if kv != qd {
            warm!(
                format!("gemm_backward_b_{s}_{r}_{kv}"),
                GemmBackwardBKernel::tiled_unrolled(s, r, kv, tile)
            );
        }
        warm!(
            format!("gemm_backward_b_{s}_{h}_{r}"),
            GemmBackwardBKernel::tiled_unrolled(s, h, r, tile)
        );

        // gemm_backward_a: input gradients
        warm!(
            format!("gemm_backward_a_{s}_{qd}_{r}"),
            GemmBackwardAKernel::tiled_unrolled(s, qd, r, tile)
        );
        if kv != qd {
            warm!(
                format!("gemm_backward_a_{s}_{kv}_{r}"),
                GemmBackwardAKernel::tiled_unrolled(s, kv, r, tile)
            );
        }
        warm!(
            format!("gemm_backward_a_{s}_{r}_{h}"),
            GemmBackwardAKernel::tiled_unrolled(s, r, h, tile)
        );
    }

    // ── Full fp32 backward shapes (non-NF4 mode) ──
    if !quantize_nf4 {
        // Attention backward: Q/O (S,H,H), K/V (S,kv,H)
        warm!(
            format!("gemm_backward_a_{s}_{h}_{h}"),
            GemmBackwardAKernel::tiled_unrolled(s, h, h, tile)
        );
        warm!(
            format!("gemm_backward_b_{s}_{h}_{h}"),
            GemmBackwardBKernel::tiled_unrolled(s, h, h, tile)
        );
        if kv != h {
            warm!(
                format!("gemm_backward_a_{s}_{kv}_{h}"),
                GemmBackwardAKernel::tiled_unrolled(s, kv, h, tile)
            );
            warm!(
                format!("gemm_backward_b_{s}_{kv}_{h}"),
                GemmBackwardBKernel::tiled_unrolled(s, kv, h, tile)
            );
        }

        // FFN backward: gate/up (S,H,I), down (S,I,H)
        warm!(
            format!("gemm_backward_a_{s}_{h}_{i}"),
            GemmBackwardAKernel::tiled_unrolled(s, h, i, tile)
        );
        warm!(
            format!("gemm_backward_b_{s}_{h}_{i}"),
            GemmBackwardBKernel::tiled_unrolled(s, h, i, tile)
        );
        warm!(
            format!("gemm_backward_a_{s}_{i}_{h}"),
            GemmBackwardAKernel::tiled_unrolled(s, i, h, tile)
        );
        warm!(
            format!("gemm_backward_b_{s}_{i}_{h}"),
            GemmBackwardBKernel::tiled_unrolled(s, i, h, tile)
        );
    }

    // ── Activation backward: SiLU ──
    warm!("silu_backward".to_string(), SiluBackwardKernel::new(si));

    // ── Structured backward kernels (attention + normalization) ──
    // Batched softmax backward (cached under a dimension-free key)
    warm!(
        "batched_softmax_backward".to_string(),
        BatchedSoftmaxBackwardKernel::new(softmax_rows, s)
    );

    // RMSNorm backward (cached under a dimension-free key)
    let eps = 1e-5_f32;
    warm!("batched_rms_norm_backward".to_string(), BatchedRmsNormBackwardKernel::new(s, h, eps));

    // PMAT-698h: RMSNorm backward is TWO stages — `batched_rms_norm_backward`
    // produces a [batch_size, hidden] gamma-partial buffer, then
    // `rms_norm_gamma_reduce` (in structured.rs:247) sums it row-by-row in
    // deterministic fixed order to produce the final [hidden] gamma
    // gradient. The reduce kernel was previously JIT'd on first call —
    // discovered in Phase 3 dispatch v9 where every other backward kernel
    // was pre-warmed but gamma_reduce still hit on-demand JIT and the
    // sm_121 stream poisoned during it. Now: always pre-warm with the
    // same (batch_size=max_seq_len, hidden_size) dims that structured.rs
    // constructs at call time.
    warm!("rms_norm_gamma_reduce".to_string(), RmsNormGammaReduceKernel::new(s, h));

    eprintln!("[BWD-PREWARM] Pre-warmed {count} backward kernels (target={target})");
    Ok(())
}
317    warm!("rms_norm_gamma_reduce".to_string(), RmsNormGammaReduceKernel::new(s, h));
318
319    let _ = count;
320    Ok(())
321}