Skip to main content

entrenar/autograd/cuda_forward/
cache.rs

1#![allow(unsafe_code)]
2#![allow(trivial_casts)]
3#![allow(clippy::borrow_as_ptr)]
4#![allow(clippy::ref_as_ptr)]
5
6#[cfg(feature = "cuda")]
7use std::collections::HashMap;
8#[cfg(feature = "cuda")]
9use std::sync::{Mutex, OnceLock};
10
11#[cfg(feature = "cuda")]
12use trueno_gpu::driver::{CublasHandle, CudaContext, CudaModule, CudaStream};
13#[cfg(feature = "cuda")]
14use trueno_gpu::kernels::{
15    Batched4DGemmKernel, BatchedRopeBackwardKernel, BatchedSoftmaxKernel,
16    BatchedToInterleavedKernel, BatchedTransposeKernel, BatchedVectorizedRmsNormKernel,
17    ElementwiseMulKernel, FusedSwigluKernel, GemmKernel, InterleavedToBatchedKernel, Kernel,
18    Nf4GemmKernel, Nf4GemmTransposeKernel, ResidualAddKernel, ScaleKernel, SiluKernel,
19};
20
21use crate::autograd::cuda_tensor::{CudaTensorError, Result};
22
/// Process-wide cache of compiled forward (and pre-warmed backward) CUDA modules.
///
/// Lazily initialized once by `init_forward_kernel_cache`; every access goes
/// through the inner `Mutex`.
#[cfg(feature = "cuda")]
pub(super) static FORWARD_KERNEL_CACHE: std::sync::OnceLock<
    std::sync::Mutex<ForwardKernelCache>,
> = std::sync::OnceLock::new();
26
/// Per-device store of JIT-compiled PTX modules used by the forward pass.
///
/// The device's SM target (e.g. `"sm_89"`) is captured once when the cache is
/// built, and every PTX string compiled through the cache must be emitted for
/// that target.
///
/// # Contract: F-PTX-001 (Target Parity)
///
/// The PTX `.target` directive MUST match the device compute capability.
/// `get_or_compile` checks this before JIT-compiling a module and returns an
/// error for PTX emitted for a different architecture.
#[cfg(feature = "cuda")]
pub(super) struct ForwardKernelCache {
    /// Architecture string reported by the device (e.g. "sm_89" for RTX 4090).
    sm_target: String,
    /// Compiled modules keyed by kernel name (shape/eps/theta encoded in the key).
    modules: HashMap<String, CudaModule>,
    /// Optional cuBLAS handle (ALB-075): forward=tensor cores,
    /// backward=SIMD (ALB-076/trueno#170). `None` when creation failed.
    cublas: Option<CublasHandle>,
    /// CUDA context that all modules in this cache are loaded into.
    ctx: std::sync::Arc<CudaContext>,
}
45
#[cfg(feature = "cuda")]
impl ForwardKernelCache {
    /// Build an empty cache bound to `ctx`.
    ///
    /// Queries the device SM target once, falling back to `"sm_70"` if the query
    /// fails. Then tries to create a tensor-core-enabled cuBLAS handle. A cuBLAS
    /// failure is logged and tolerated: the cache then holds no handle.
    pub(super) fn new(ctx: std::sync::Arc<CudaContext>) -> Self {
        // Detect device compute capability at construction time.
        // Falls back to sm_70 if detection fails (should never happen
        // since we already have a valid CudaContext).
        let sm_target = ctx.sm_target().unwrap_or_else(|_| "sm_70".to_string());

        // entrenar#318: Forward uses TF32 tensor cores (~41x faster than SIMD on sm_89).
        // ALB-076: TF32 is safe for forward (NoTrans/NoTrans). Backward uses SIMD handle.
        let cublas = match CublasHandle::new_with_tensor_cores(&ctx) {
            Ok(handle) => {
                eprintln!("[CUDA] cuBLAS initialized — forward TF32 tensor cores (41x vs SIMD)");
                Some(handle)
            }
            Err(e) => {
                eprintln!("[CUDA] cuBLAS not available ({e:?}), using PTX GEMMs");
                None
            }
        };

        eprintln!("[CUDA] Kernel cache initialized for target: {sm_target}");
        Self { ctx, modules: HashMap::new(), sm_target, cublas }
    }

    /// Get a reference to the cuBLAS handle, if one was created.
    pub(super) fn cublas(&self) -> Option<&CublasHandle> {
        self.cublas.as_ref()
    }

    /// Bind cuBLAS to a stream for the current training step.
    ///
    /// No-op when no cuBLAS handle exists.
    ///
    /// # Errors
    /// `KernelError` if cuBLAS rejects the stream.
    pub(super) fn set_cublas_stream(&self, stream: &CudaStream) -> Result<()> {
        if let Some(ref handle) = self.cublas {
            handle.set_stream(stream).map_err(|e| {
                CudaTensorError::KernelError(format!("cuBLAS set_stream failed: {e:?}"))
            })?;
        }
        Ok(())
    }

    /// Get the device SM target for PTX emission.
    ///
    /// Consumers MUST use this to emit PTX via `kernel.emit_ptx_for_target(cache.sm_target())`.
    pub(super) fn sm_target(&self) -> &str {
        &self.sm_target
    }

    /// Look up a previously compiled module by key (KAIZEN-058).
    ///
    /// Returns `Some` if the module is already cached (post-pre-warm: always).
    /// Callers should use this before generating PTX to avoid unnecessary
    /// multi-KB String allocations (~1000 per training step).
    pub(super) fn get_cached(&mut self, name: &str) -> Option<&mut CudaModule> {
        self.modules.get_mut(name)
    }

    /// Compile PTX and cache the resulting module under `name`.
    ///
    /// If `name` is already cached, the existing module is returned and `ptx`
    /// is not compiled again (only the target check below runs).
    ///
    /// # Contract: F-PTX-001 (Target Parity)
    ///
    /// Validates that the PTX `.target` directive matches the device's compute
    /// capability. Rejects PTX compiled for the wrong architecture. PTX without
    /// any `.target` directive is not validated.
    ///
    /// # Errors
    /// `KernelError` on a target mismatch, a failed compute-capability query,
    /// or a failed module load.
    pub(super) fn get_or_compile(&mut self, name: &str, ptx: &str) -> Result<&mut CudaModule> {
        use std::collections::hash_map::Entry;

        // F-PTX-001: Validate PTX target matches device
        if let Some(ptx_target) = Self::ptx_target_arch(ptx) {
            if ptx_target != self.sm_target {
                return Err(CudaTensorError::KernelError(format!(
                    "F-PTX-001 violated: PTX target '{ptx_target}' != device target '{}'. \
                     Use kernel.emit_ptx_for_target(\"{}\") instead of emit_ptx().",
                    self.sm_target, self.sm_target
                )));
            }
        }

        match self.modules.entry(name.to_string()) {
            Entry::Occupied(e) => Ok(e.into_mut()),
            Entry::Vacant(e) => {
                // PMAT-698i: diagnostic logging. Surfaces every forward-cache
                // JIT event with its kernel name so missing pre-warm entries
                // are identifiable in O(1) instead of O(N) iterations.
                eprintln!("[FWD-CACHE] Compiling '{name}' (ptx_len={})", ptx.len());
                // trueno#200: Use from_ptx_direct on Blackwell
                let (major, _) = self.ctx.compute_capability().map_err(|e| {
                    CudaTensorError::KernelError(format!("compute_capability: {e:?}"))
                })?;
                let module = if major >= 12 {
                    CudaModule::from_ptx_direct(&self.ctx, ptx)
                } else {
                    CudaModule::from_ptx(&self.ctx, ptx)
                }
                .map_err(|err| {
                    CudaTensorError::KernelError(format!("Failed to compile {name}: {err:?}"))
                })?;
                eprintln!("[FWD-CACHE] OK '{name}'");
                Ok(e.insert(module))
            }
        }
    }

    /// Extract the architecture named by the first PTX `.target` directive.
    ///
    /// PTX allows `.target` to carry a comma-separated list of specifiers
    /// (e.g. `.target sm_89, debug`), so only the first entry is returned.
    /// Leading indentation and a trailing `//` comment are ignored. Returns
    /// `None` if no `.target` directive is present.
    fn ptx_target_arch(ptx: &str) -> Option<&str> {
        ptx.lines().find_map(|line| {
            let rest = line.trim_start().strip_prefix(".target")?;
            // Require whitespace after the keyword so e.g. `.targetx` is ignored.
            if !rest.starts_with(char::is_whitespace) {
                return None;
            }
            let rest = rest.split("//").next().unwrap_or(rest);
            rest.split(',').next().map(str::trim)
        })
    }

    /// Pre-warm all kernels needed for transformer forward pass.
    ///
    /// # Contract: C-PREWARM-001 (JIT Before Payload)
    ///
    /// - **Precondition**: Kernel cache initialized, GPU VRAM mostly free (no blocks uploaded yet)
    /// - **Postcondition**: All forward-pass PTX modules JIT-compiled and cached
    /// - **Invariant**: Subsequent `get_or_compile()` calls for these keys hit cache (zero JIT)
    ///
    /// CUDA's `cuModuleLoadDataEx` JIT compiler needs device memory for compilation.
    /// If called after uploading 36 transformer blocks (~22 GB), the near-OOM state causes
    /// `CUDA_ERROR_ILLEGAL_ADDRESS` during JIT (trueno#107). Pre-warming compiles all PTX
    /// while VRAM is free, avoiding this failure mode entirely.
    ///
    /// The logged count includes warm requests that hit an already-cached key.
    ///
    /// # Errors
    /// Propagates the first `get_or_compile` failure.
    pub(super) fn pre_warm_for_model(
        &mut self,
        hidden_size: usize,
        intermediate_size: usize,
        num_heads: usize,
        num_kv_heads: usize,
        head_dim: usize,
        max_seq_len: usize,
    ) -> Result<()> {
        use trueno_gpu::kernels::{BatchedRopeKernel, FusedNf4GateUpGemmKernel};

        // NOTE(review): the `as u32` casts and the products below (`s * h`,
        // `s * i`, `nh * s * s`) are unchecked; very large configurations could
        // truncate or overflow.
        let s = max_seq_len as u32;
        let h = hidden_size as u32;
        let q_dim = (num_heads * head_dim) as u32; // Q/O projection dim (may differ from h)
        let kv_h = (num_kv_heads * head_dim) as u32;
        let i = intermediate_size as u32;
        let nh = num_heads as u32;
        let nkv = num_kv_heads as u32;
        let hd = head_dim as u32;
        let sh = s * h; // seq_len * hidden_size
        let si = s * i; // seq_len * intermediate_size

        let mut count = 0u32;
        let target = self.sm_target.clone();

        // Helper: generate PTX for the target and compile it under `$key`.
        //
        // PMAT-698j: this helper once hardcoded "silu_forward" as the cache key,
        // so every warm!() collided on one HashMap entry and all other kernels
        // JIT-compiled mid-training (Blackwell sm_121 stream corruption, the
        // cascading "Block 0 upload failed" errors of PMAT-698e..i). The key
        // must always be the same string the runtime looks up.
        macro_rules! warm {
            ($key:expr, $kernel:expr) => {{
                let key = $key;
                let ptx = $kernel.emit_ptx_for_target(&target);
                self.get_or_compile(&key, &ptx)?;
                count += 1;
            }};
        }

        // 1. RMSNorm (batched: single launch for all rows via grid.y)
        // ALB-076: Use BatchedVectorizedRmsNormKernel instead of per-row RmsNormKernel
        //
        // PMAT-698k: the runtime key includes the eps bit pattern
        // (normalization.rs:139:
        //   let key = format!("batched_rmsnorm_fwd_{hidden_size}_eps{eps_bits:08x}"))
        //
        // PMAT-698n: warm both 1e-6 (Qwen2 / Qwen2.5 rms_norm_eps, 0x358637bd)
        // and 1e-5 (Llama/Mistral, 0x3727c5ac) for cross-family coverage.
        let qwen2_eps_bits = 1.0e-6_f32.to_bits(); // 0x358637bd
        let llama_eps_bits = 1.0e-5_f32.to_bits(); // 0x3727c5ac
        warm!(
            format!("batched_rmsnorm_fwd_{h}_eps{qwen2_eps_bits:08x}"),
            BatchedVectorizedRmsNormKernel::new(h, 1)
        );
        if qwen2_eps_bits != llama_eps_bits {
            warm!(
                format!("batched_rmsnorm_fwd_{h}_eps{llama_eps_bits:08x}"),
                BatchedVectorizedRmsNormKernel::new(h, 1)
            );
        }

        // PMAT-700 (SPEC-BLACKWELL-FIX-001 Fix #2): when cuBLAS is available the
        // runtime takes its fast path for the standard 2D projection GEMMs
        // (ALB-075 dispatch), so pre-warming PTX equivalents only wastes VRAM.
        // On sm_121 (Blackwell GB10) that extra JIT-cache footprint pushed block
        // upload into CUDA_ERROR_OUT_OF_MEMORY at "Block 0 upload".
        //
        // Falsifier: F-BLACKWELL-CUBLAS-PREWARM-001 — assert the cache
        // module count after pre_warm_for_model decreases when cuBLAS is
        // present, and that runtime forward still produces identical
        // results on a known input (cuBLAS path was already taken).
        //
        // Key convention used below: `gemm_forward_{M}_{K}_{N}` is compiled
        // from `GemmKernel::naive(M, N, K)`.
        let has_cublas = self.cublas.is_some();
        if !has_cublas {
            // 2. GEMM: (S, H) x (H, H) projections
            warm!(format!("gemm_forward_{s}_{h}_{h}"), GemmKernel::naive(s, h, h));

            // 2b. GEMM: Q/O projections when q_dim = num_heads * head_dim
            // differs from hidden_size (e.g. Qwen3-4B: h=2560, q_dim=4096),
            // mirroring the q_dim handling of the NF4 section below.
            if q_dim != h {
                warm!(format!("gemm_forward_{s}_{h}_{q_dim}"), GemmKernel::naive(s, q_dim, h));
                warm!(format!("gemm_forward_{s}_{q_dim}_{h}"), GemmKernel::naive(s, h, q_dim));
            }

            // 3. GEMM: K/V projections (S, H, kv_hidden); skipped when an
            // entry above already produced the same key.
            if kv_h != h && kv_h != q_dim {
                warm!(format!("gemm_forward_{s}_{h}_{kv_h}"), GemmKernel::naive(s, kv_h, h));
            }

            // 4. GEMM: gate/up projections (S, H, I)
            warm!(format!("gemm_forward_{s}_{h}_{i}"), GemmKernel::naive(s, i, h));

            // 5. GEMM: down projection (S, I, H)
            warm!(format!("gemm_forward_{s}_{i}_{h}"), GemmKernel::naive(s, h, i));
        } else {
            eprintln!("[CUDA] Skipping PTX pre-warm for 4 GEMM kernels (cuBLAS active — PMAT-700)");
        }

        // PMAT-698k + PMAT-698p: batched RoPE forward. Runtime keys
        // (normalization.rs:339):
        //   batched_rope_fwd_{num_heads}_{head_dim}_{seq_len}_th{theta_bits:08x}
        // Warm at seq_len=1 (Phase 3 single-token smoke), at
        // APR_DISTILL_SMOKE_SEQ_LEN (default 256, Phase 4 real-corpus seq), and
        // at max_seq_len (the length every other seq-dependent kernel here is
        // warmed at). Duplicates are removed so no key is generated twice.
        let qwen_theta = 1_000_000.0_f32;
        let qwen_theta_bits = qwen_theta.to_bits();
        let phase4_rope_seq: u32 = std::env::var("APR_DISTILL_SMOKE_SEQ_LEN")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(256);
        let mut rope_seqs = vec![1_u32, phase4_rope_seq, s];
        rope_seqs.sort_unstable();
        rope_seqs.dedup();
        for rope_seq in rope_seqs {
            warm!(
                format!("batched_rope_fwd_{nh}_{hd}_{rope_seq}_th{qwen_theta_bits:08x}"),
                BatchedRopeKernel::new(nh, hd, rope_seq, qwen_theta)
            );
            // GQA: K uses num_kv_heads, so its RoPE key differs.
            if nkv != nh {
                warm!(
                    format!("batched_rope_fwd_{nkv}_{hd}_{rope_seq}_th{qwen_theta_bits:08x}"),
                    BatchedRopeKernel::new(nkv, hd, rope_seq, qwen_theta)
                );
            }
        }

        // 6. Fused SwiGLU
        warm!("fused_swiglu_forward".to_string(), FusedSwigluKernel::new(si));

        // 7. Residual add (seq * hidden)
        warm!("residual_add_forward".to_string(), ResidualAddKernel::new(sh));

        // 8. Interleaved-to-batched (dimension-independent: one module handles all dims)
        warm!("interleaved_to_batched".to_string(), InterleavedToBatchedKernel::new(s, nh, hd));

        // 9. Batched transpose (dimension-independent: one module handles all dims)
        warm!("batched_transpose".to_string(), BatchedTransposeKernel::new(nh, s, hd));

        // 10. Batched 4D GEMM: Q@K^T (1, NH, S, S, HD)
        warm!(
            format!("batched_4d_gemm_1_{nh}_{s}_{s}_{hd}"),
            Batched4DGemmKernel::new(1, nh, s, s, hd)
        );

        // 11. Scale: attention scores (NH * S * S)
        let score_n = nh * s * s;
        warm!("scale_forward".to_string(), ScaleKernel::new(score_n));

        // 12. Batched softmax (dimension-independent: one module handles all dims)
        let softmax_rows = nh * s;
        warm!("batched_softmax_forward".to_string(), BatchedSoftmaxKernel::new(softmax_rows, s));

        // 13. Batched 4D GEMM: attn@V (1, NH, S, HD, S)
        warm!(
            format!("batched_4d_gemm_1_{nh}_{s}_{hd}_{s}"),
            Batched4DGemmKernel::new(1, nh, s, hd, s)
        );

        // 13b. Batched 4D GEMM: attention backward grad_V^T (1, NH, HD, S, S)
        warm!(
            format!("batched_4d_gemm_1_{nh}_{hd}_{s}_{s}"),
            Batched4DGemmKernel::new(1, nh, hd, s, s)
        );

        // 14. Batched-to-interleaved (dimension-independent: one module handles all dims)
        warm!("batched_to_interleaved".to_string(), BatchedToInterleavedKernel::new(s, nh, hd));

        // 15. Element-wise multiply (used in FFN backward for SwiGLU gate * up)
        warm!("elementwise_mul_forward".to_string(), ElementwiseMulKernel::new(si));

        // 16. SiLU forward activation (standalone, used in LoRA FFN path)
        warm!("silu_forward".to_string(), SiluKernel::new(si));

        // 17. NF4 quantized GEMM variants (trueno#108: QLoRA support).
        // Only compiled when the relevant K dimension is divisible by 64
        // (NF4 block size).
        if h.is_multiple_of(64) {
            // NF4 cache keys exclude M (seq_len) — PTX is shape-independent
            // (m/n/k are runtime params). Including M causes cache misses when
            // actual seq_len != max_seq_len, triggering on-demand JIT that fails
            // after GPU memory is loaded (trueno#184).
            //
            // Attention projections use q_dim (= num_heads * head_dim) which may
            // differ from hidden_size (e.g. Qwen3-4B: h=2560, q_dim=4096).
            // Q proj: input[S,h] @ W_q[h, q_dim] — key {h}_{q_dim}
            warm!(format!("nf4_gemm_forward_{h}_{q_dim}"), Nf4GemmKernel::new(s, q_dim, h));
            // O proj: input[S,q_dim] @ W_o[q_dim, h] — key {q_dim}_{h}
            if q_dim != h {
                warm!(format!("nf4_gemm_forward_{q_dim}_{h}"), Nf4GemmKernel::new(s, h, q_dim));
            }
            if kv_h != h && kv_h != q_dim && kv_h.is_multiple_of(64) {
                warm!(format!("nf4_gemm_forward_{h}_{kv_h}"), Nf4GemmKernel::new(s, kv_h, h));
            }
            if i.is_multiple_of(64) {
                warm!(format!("nf4_gemm_forward_{h}_{i}"), Nf4GemmKernel::new(s, i, h));
                warm!(format!("nf4_gemm_forward_{i}_{h}"), Nf4GemmKernel::new(s, h, i));
            }
        }

        // 18. PMAT-475: Fused NF4 Gate+Up GEMM for FFN (shared input load).
        if h.is_multiple_of(64) && i.is_multiple_of(64) {
            warm!(format!("fused_nf4_gate_up_{h}_{i}"), FusedNf4GateUpGemmKernel::new(s, i, h));
        }
        // 19. PMAT-478: Fused K+V GEMM for GQA attention (reuses Gate+Up kernel).
        // When kv_h == i the key above already covers it.
        if h.is_multiple_of(64) && kv_h.is_multiple_of(64) && kv_h != i {
            warm!(
                format!("fused_nf4_gate_up_{h}_{kv_h}"),
                FusedNf4GateUpGemmKernel::new(s, kv_h, h)
            );
        }

        // 20. NF4 transposed GEMM for QLoRA backward (ENT-153).
        // C[M×K] = A[M×N] @ B[K×N]^T — gradient propagation through frozen NF4 layers.
        if h.is_multiple_of(64) {
            // Q proj backward: grad[S,q_dim] @ W_q[h, q_dim]^T → [S,h]
            warm!(
                format!("nf4_gemm_transpose_{q_dim}_{h}"),
                Nf4GemmTransposeKernel::new(s, q_dim, h)
            );
            // O proj backward: grad[S,h] @ W_o[q_dim, h]^T → [S,q_dim]
            if q_dim != h {
                warm!(
                    format!("nf4_gemm_transpose_{h}_{q_dim}"),
                    Nf4GemmTransposeKernel::new(s, h, q_dim)
                );
            }
            if kv_h != h && kv_h != q_dim && kv_h.is_multiple_of(64) {
                // K/V proj backward: grad[S,kv_h] @ W_k[h, kv_h]^T → [S,h]
                warm!(
                    format!("nf4_gemm_transpose_{kv_h}_{h}"),
                    Nf4GemmTransposeKernel::new(s, kv_h, h)
                );
            }
            if i.is_multiple_of(64) {
                // Gate/Up backward: grad[S,I] @ W_gate[h,I]^T → [S,h]
                warm!(format!("nf4_gemm_transpose_{i}_{h}"), Nf4GemmTransposeKernel::new(s, i, h));
                // Down backward: grad[S,h] @ W_down[I,h]^T → [S,I]
                warm!(format!("nf4_gemm_transpose_{h}_{i}"), Nf4GemmTransposeKernel::new(s, h, i));
            }
        }

        eprintln!("[CUDA] Pre-warmed {count} forward kernels (JIT compiled before block upload)");
        Ok(())
    }

    /// Pre-warm LoRA backward GEMM kernels for QLoRA training (ENT-153).
    ///
    /// The LoRA backward uses regular fp32 GEMMs for:
    /// - Forward LoRA: x @ A → [S, R], inter @ B → [S, proj_dim]
    /// - Backward A: x^T @ grad_inter → grad_A [H, R]
    /// - Backward B: inter^T @ grad_proj → grad_B [R, proj_dim]
    /// - Backward input: grad_proj @ B^T → [S, R], then [S, R] @ A^T → [S, H]
    ///
    /// These shapes are small (rank << hidden_size) but must still be JIT-compiled.
    /// Does nothing when `lora_rank == 0`.
    ///
    /// # Errors
    /// Propagates the first `get_or_compile` failure.
    pub(super) fn pre_warm_lora_backward(
        &mut self,
        hidden_size: usize,
        q_dim: usize,
        kv_hidden_size: usize,
        max_seq_len: usize,
        lora_rank: usize,
    ) -> Result<()> {
        if lora_rank == 0 {
            return Ok(());
        }

        let s = max_seq_len as u32;
        let h = hidden_size as u32;
        let r = lora_rank as u32;
        let qd = q_dim as u32;
        let kv = kv_hidden_size as u32;

        let mut count = 0u32;
        let target = self.sm_target.clone();

        // Same key convention as the forward pre-warm:
        // `gemm_forward_{M}_{K}_{N}` compiled from `GemmKernel::naive(M, N, K)`.
        macro_rules! warm {
            ($key:expr, $kernel:expr) => {{
                let ptx = $kernel.emit_ptx_for_target(&target);
                self.get_or_compile(&$key, &ptx)?;
                count += 1;
            }};
        }

        // LoRA forward GEMMs (also needed in backward for activation checkpointing)
        // x[S,H] @ A[H,R] → [S,R]
        warm!(format!("gemm_forward_{s}_{h}_{r}"), GemmKernel::naive(s, r, h));
        // inter[S,R] @ B[R,qd] → [S,qd]
        warm!(format!("gemm_forward_{s}_{r}_{qd}"), GemmKernel::naive(s, qd, r));
        // inter[S,R] @ B[R,kv] → [S,kv]
        if kv != qd {
            warm!(format!("gemm_forward_{s}_{r}_{kv}"), GemmKernel::naive(s, kv, r));
        }

        // LoRA backward GEMMs (gemm_backward_a and gemm_backward_b use regular GEMM shapes)
        // grad_B = inter^T[R,S] @ grad_proj[S,qd] → [R,qd]
        // This is a GEMM with M=R, N=qd, K=S
        warm!(format!("gemm_forward_{r}_{s}_{qd}"), GemmKernel::naive(r, qd, s));
        if kv != qd {
            warm!(format!("gemm_forward_{r}_{s}_{kv}"), GemmKernel::naive(r, kv, s));
        }

        // grad_li = grad_proj[S,qd] @ B^T[qd,R] → [S,R]
        // This is effectively GEMM with M=S, N=R, K=qd
        warm!(format!("gemm_forward_{s}_{qd}_{r}"), GemmKernel::naive(s, r, qd));
        if kv != qd {
            warm!(format!("gemm_forward_{s}_{kv}_{r}"), GemmKernel::naive(s, r, kv));
        }

        // grad_A = x^T[H,S] @ grad_li[S,R] → [H,R]
        warm!(format!("gemm_forward_{h}_{s}_{r}"), GemmKernel::naive(h, r, s));

        // grad_input += grad_li[S,R] @ A^T[R,H] → [S,H]
        warm!(format!("gemm_forward_{s}_{r}_{h}"), GemmKernel::naive(s, h, r));

        eprintln!("[CUDA] Pre-warmed {count} LoRA backward kernels");
        Ok(())
    }
}
494
/// Create the process-wide forward kernel cache for `ctx`.
///
/// Only the first call builds the cache. Later calls leave the existing cache
/// (and the context it was built with) in place and simply drop `ctx`. Always
/// returns `Ok(())`.
#[cfg(feature = "cuda")]
pub fn init_forward_kernel_cache(ctx: std::sync::Arc<CudaContext>) -> Result<()> {
    let build = move || Mutex::new(ForwardKernelCache::new(ctx));
    let _cache = FORWARD_KERNEL_CACHE.get_or_init(build);
    Ok(())
}
/// Pre-allocate cuBLAS workspace for CUDA graph capture (PMAT-063).
///
/// Passes the workspace at `ptr` (`size` bytes) to the forward cuBLAS handle.
/// `ptr` is presumably a device allocation the caller keeps alive for the
/// handle's lifetime — confirm at call sites. Does nothing when no cuBLAS
/// handle exists.
///
/// # Errors
/// - `DeviceNotInitialized` if `init_forward_kernel_cache` has not run.
/// - `KernelError` if the cache lock is poisoned or cuBLAS rejects the workspace.
#[cfg(feature = "cuda")]
pub fn set_cublas_workspace(ptr: u64, size: usize) -> Result<()> {
    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    // Same lock-error message as the other cache entry points.
    let cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;
    if let Some(handle) = cache.cublas() {
        handle.set_workspace(ptr, size).map_err(|e| {
            CudaTensorError::KernelError(format!("cuBLAS set_workspace failed: {e}"))
        })?;
    }
    Ok(())
}
/// Attach the shared forward cuBLAS handle to `stream` (ALB-075).
///
/// # Errors
/// - `DeviceNotInitialized` if `init_forward_kernel_cache` has not run.
/// - `KernelError` if the cache lock is poisoned or cuBLAS rejects the stream.
#[cfg(feature = "cuda")]
pub fn set_forward_cublas_stream(stream: &CudaStream) -> Result<()> {
    let Some(cell) = FORWARD_KERNEL_CACHE.get() else {
        return Err(CudaTensorError::DeviceNotInitialized);
    };
    match cell.lock() {
        Ok(guard) => guard.set_cublas_stream(stream),
        Err(_poisoned) => Err(CudaTensorError::KernelError(
            "Failed to acquire kernel cache lock".to_string(),
        )),
    }
}
520
/// JIT-compile every forward-pass kernel before transformer blocks are uploaded
/// (C-PREWARM-001).
///
/// The backward RoPE kernels are warmed first (trueno#200), then the full
/// forward set via `ForwardKernelCache::pre_warm_for_model`.
///
/// # Errors
/// - `DeviceNotInitialized` if `init_forward_kernel_cache` has not run.
/// - `KernelError` if the cache lock is poisoned or any compilation fails.
#[cfg(feature = "cuda")]
pub fn pre_warm_forward_kernels(
    hidden_size: usize,
    intermediate_size: usize,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    max_seq_len: usize,
) -> Result<()> {
    // trueno#200: Blackwell JIT crash workaround — backward kernels must be
    // compiled up front as well.
    pre_warm_backward_kernels_in_forward_cache(num_heads, num_kv_heads, head_dim, max_seq_len)?;

    let cell = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut guard = match cell.lock() {
        Ok(guard) => guard,
        Err(_poisoned) => {
            return Err(CudaTensorError::KernelError(
                "Failed to acquire kernel cache lock".to_string(),
            ));
        }
    };
    guard.pre_warm_for_model(
        hidden_size,
        intermediate_size,
        num_heads,
        num_kv_heads,
        head_dim,
        max_seq_len,
    )
}
546
/// Pre-warm backward kernels in forward cache (trueno#200 Blackwell).
///
/// CONTRACT: All backward kernels must be compiled before GPU work starts.
/// On Blackwell (sm_121), cuModuleLoadData fails during active GPU computation.
///
/// Currently compiles only the batched RoPE backward kernel. It is built at
/// `max_seq_len` with theta = 1e6, once for `num_heads`, plus once for
/// `num_kv_heads` when that differs (GQA uses a different head count for K/V).
///
/// # Errors
/// - `DeviceNotInitialized` if `init_forward_kernel_cache` has not run.
/// - `KernelError` if the cache lock is poisoned or compilation fails.
#[cfg(feature = "cuda")]
fn pre_warm_backward_kernels_in_forward_cache(
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    max_seq_len: usize,
) -> Result<()> {
    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let target = cache.sm_target.clone();
    let nh = num_heads as u32;
    let nkv = num_kv_heads as u32;
    let hd = head_dim as u32;
    let s = max_seq_len as u32;

    // Batched RoPE backward was missing from pre_warm_for_model. Compiling it
    // during the backward pass poisons the CUDA context on Blackwell.
    //
    // FALSIFY-CUDA-ROPE-THETA-CACHE-KEY-001: the cache key includes theta_bits
    // (matching runtime in `batched_rope_neox_backward`). The theta of 1e6
    // matches the Qwen2 / Qwen2.5 default; for Llama pretrain (theta=10000) the
    // runtime compiles its own module on first use instead of silently reusing
    // this one.
    let qwen_theta_bits = 1_000_000.0_f32.to_bits();

    // One warm per distinct head count (num_kv_heads only differs under GQA).
    let mut head_counts = vec![nh];
    if nkv != nh {
        head_counts.push(nkv);
    }
    for heads in head_counts {
        let key = format!("batched_rope_bwd_{heads}_{hd}_{s}_th{qwen_theta_bits:08x}");
        let ptx = BatchedRopeBackwardKernel::new(heads, hd, s, 1_000_000.0)
            .emit_ptx_for_target(&target);
        cache.get_or_compile(&key, &ptx)?;
    }

    eprintln!("  ✓ Backward rope kernel pre-warmed in forward cache");
    Ok(())
}
603
/// JIT-compile the small fp32 GEMMs used for LoRA gradients (ENT-153, QLoRA).
///
/// Must run BEFORE transformer blocks are uploaded. Delegates to
/// `ForwardKernelCache::pre_warm_lora_backward`, which does nothing when
/// `lora_rank` is zero.
///
/// # Errors
/// - `DeviceNotInitialized` if `init_forward_kernel_cache` has not run.
/// - `KernelError` if the cache lock is poisoned or any compilation fails.
#[cfg(feature = "cuda")]
pub fn pre_warm_lora_backward_kernels(
    hidden_size: usize,
    q_dim: usize,
    kv_hidden_size: usize,
    max_seq_len: usize,
    lora_rank: usize,
) -> Result<()> {
    let Some(cell) = FORWARD_KERNEL_CACHE.get() else {
        return Err(CudaTensorError::DeviceNotInitialized);
    };
    let Ok(mut guard) = cell.lock() else {
        return Err(CudaTensorError::KernelError(
            "Failed to acquire kernel cache lock".to_string(),
        ));
    };
    guard.pre_warm_lora_backward(hidden_size, q_dim, kv_hidden_size, max_seq_len, lora_rank)
}