entrenar/autograd/cuda_forward/cache.rs
1#![allow(unsafe_code)]
2#![allow(trivial_casts)]
3#![allow(clippy::borrow_as_ptr)]
4#![allow(clippy::ref_as_ptr)]
5
6#[cfg(feature = "cuda")]
7use std::collections::HashMap;
8#[cfg(feature = "cuda")]
9use std::sync::{Mutex, OnceLock};
10
11#[cfg(feature = "cuda")]
12use trueno_gpu::driver::{CublasHandle, CudaContext, CudaModule, CudaStream};
13#[cfg(feature = "cuda")]
14use trueno_gpu::kernels::{
15 Batched4DGemmKernel, BatchedRopeBackwardKernel, BatchedSoftmaxKernel,
16 BatchedToInterleavedKernel, BatchedTransposeKernel, BatchedVectorizedRmsNormKernel,
17 ElementwiseMulKernel, FusedSwigluKernel, GemmKernel, InterleavedToBatchedKernel, Kernel,
18 Nf4GemmKernel, Nf4GemmTransposeKernel, ResidualAddKernel, ScaleKernel, SiluKernel,
19};
20
21use crate::autograd::cuda_tensor::{CudaTensorError, Result};
22
/// Process-wide cache of compiled CUDA modules for the forward kernels.
///
/// The cell starts empty and is filled exactly once by
/// `init_forward_kernel_cache`; until then the accessor functions in this
/// module report `CudaTensorError::DeviceNotInitialized`. The inner `Mutex`
/// serialises every lookup and compilation.
#[cfg(feature = "cuda")]
pub(super) static FORWARD_KERNEL_CACHE: OnceLock<Mutex<ForwardKernelCache>> = OnceLock::new();
26
/// Compiled forward-kernel modules, keyed by kernel name.
///
/// The cache remembers the device's SM target (e.g. "sm_89") detected when it
/// is constructed. Every PTX string handed to the cache must have been emitted
/// for that target.
///
/// # Contract: F-PTX-001 (Target Parity)
///
/// The PTX `.target` directive MUST match the device compute capability.
/// `get_or_compile` checks this before compiling and rejects mismatched PTX.
#[cfg(feature = "cuda")]
pub(super) struct ForwardKernelCache {
    /// CUDA context every module is compiled against.
    ctx: std::sync::Arc<CudaContext>,
    /// Compiled modules, indexed by cache key.
    modules: HashMap<String, CudaModule>,
    /// Device SM target string (e.g. "sm_89" for RTX 4090).
    sm_target: String,
    /// cuBLAS handle (ALB-075): forward=tensor cores, backward=SIMD (ALB-076/trueno#170).
    /// `None` when cuBLAS could not be initialised.
    cublas: Option<CublasHandle>,
}
45
#[cfg(feature = "cuda")]
impl ForwardKernelCache {
    /// Create an empty cache bound to `ctx`.
    ///
    /// Detects the device SM target (falling back to `"sm_70"` when detection
    /// fails) and tries to create a tensor-core cuBLAS handle. A cuBLAS
    /// failure is not fatal: it is reported on stderr and the handle stays
    /// `None`, in which case the PTX GEMMs are used instead.
    pub(super) fn new(ctx: std::sync::Arc<CudaContext>) -> Self {
        // Detect device compute capability at construction time.
        // Falls back to sm_70 if detection fails (should never happen
        // since we already have a valid CudaContext).
        let sm_target = ctx.sm_target().unwrap_or_else(|_| "sm_70".to_string());

        // entrenar#318: Forward uses TF32 tensor cores (~41x faster than SIMD on sm_89).
        // ALB-076: TF32 is safe for forward (NoTrans/NoTrans). Backward uses SIMD handle.
        let cublas = match CublasHandle::new_with_tensor_cores(&ctx) {
            Ok(handle) => {
                eprintln!("[CUDA] cuBLAS initialized — forward TF32 tensor cores (41x vs SIMD)");
                Some(handle)
            }
            Err(e) => {
                eprintln!("[CUDA] cuBLAS not available ({e:?}), using PTX GEMMs");
                None
            }
        };

        eprintln!("[CUDA] Kernel cache initialized for target: {sm_target}");
        Self { ctx, modules: HashMap::new(), sm_target, cublas }
    }

    /// Get a reference to the cuBLAS handle, if available.
    pub(super) fn cublas(&self) -> Option<&CublasHandle> {
        self.cublas.as_ref()
    }

    /// Bind cuBLAS to a stream for the current training step.
    ///
    /// Does nothing (and succeeds) when no cuBLAS handle is held.
    ///
    /// # Errors
    ///
    /// Returns `CudaTensorError::KernelError` if the handle rejects the stream.
    pub(super) fn set_cublas_stream(&self, stream: &CudaStream) -> Result<()> {
        if let Some(ref handle) = self.cublas {
            handle.set_stream(stream).map_err(|e| {
                CudaTensorError::KernelError(format!("cuBLAS set_stream failed: {e:?}"))
            })?;
        }
        Ok(())
    }

    /// Get the device SM target for PTX emission.
    ///
    /// Consumers MUST use this to emit PTX via `kernel.emit_ptx_for_target(cache.sm_target())`.
    pub(super) fn sm_target(&self) -> &str {
        &self.sm_target
    }

    /// Look up a previously compiled module by key (KAIZEN-058).
    ///
    /// Returns `Some` if the module is already cached (post-pre-warm: always).
    /// Callers should use this before generating PTX to avoid unnecessary
    /// multi-KB String allocations (~1000 per training step).
    pub(super) fn get_cached(&mut self, name: &str) -> Option<&mut CudaModule> {
        self.modules.get_mut(name)
    }

    /// Extract the architecture named by the first `.target` directive in `ptx`.
    ///
    /// PTX permits a comma-separated specifier list after `.target`
    /// (e.g. `.target sm_89, debug`); only the first specifier — the
    /// architecture — is returned. Leading indentation before the directive is
    /// tolerated. Returns `None` when no `.target` directive with an operand
    /// is present.
    fn declared_ptx_target(ptx: &str) -> Option<&str> {
        ptx.lines().find_map(|line| {
            let rest = line.trim_start().strip_prefix(".target")?;
            // Require a separator so identifiers that merely start with
            // ".target" are not mistaken for the directive.
            if !rest.starts_with(char::is_whitespace) {
                return None;
            }
            rest.split(|c: char| c == ',' || c.is_whitespace()).find(|tok| !tok.is_empty())
        })
    }

    /// Compile PTX and cache the resulting module.
    ///
    /// If `name` is already cached the existing module is returned and `ptx`
    /// is not compiled again.
    ///
    /// # Contract: F-PTX-001 (Target Parity)
    ///
    /// Validates that the architecture named by the PTX `.target` directive
    /// matches the device's compute capability. Rejects PTX compiled for the
    /// wrong architecture. PTX without a `.target` directive is not validated.
    ///
    /// # Errors
    ///
    /// Returns `CudaTensorError::KernelError` when the target check fails, when
    /// the compute capability cannot be queried, or when module compilation
    /// fails.
    pub(super) fn get_or_compile(&mut self, name: &str, ptx: &str) -> Result<&mut CudaModule> {
        use std::collections::hash_map::Entry;

        // F-PTX-001: Validate PTX target matches device
        if let Some(ptx_target) = Self::declared_ptx_target(ptx) {
            if ptx_target != self.sm_target {
                return Err(CudaTensorError::KernelError(format!(
                    "F-PTX-001 violated: PTX target '{ptx_target}' != device target '{}'. \
                     Use kernel.emit_ptx_for_target(\"{}\") instead of emit_ptx().",
                    self.sm_target, self.sm_target
                )));
            }
        }

        match self.modules.entry(name.to_string()) {
            Entry::Occupied(e) => Ok(e.into_mut()),
            Entry::Vacant(e) => {
                // PMAT-698i: diagnostic logging. Surfaces every forward-cache
                // JIT event with its kernel name so missing pre-warm entries
                // are identifiable in O(1) instead of O(N) iterations.
                eprintln!("[FWD-CACHE] Compiling '{name}' (ptx_len={})", ptx.len());
                // trueno#200: Use from_ptx_direct on Blackwell
                let (major, _) = self.ctx.compute_capability().map_err(|e| {
                    CudaTensorError::KernelError(format!("compute_capability: {e:?}"))
                })?;
                let module = if major >= 12 {
                    CudaModule::from_ptx_direct(&self.ctx, ptx)
                } else {
                    CudaModule::from_ptx(&self.ctx, ptx)
                }
                .map_err(|err| {
                    CudaTensorError::KernelError(format!("Failed to compile {name}: {err:?}"))
                })?;
                eprintln!("[FWD-CACHE] OK '{name}'");
                Ok(e.insert(module))
            }
        }
    }

    /// Pre-warm all kernels needed for transformer forward pass.
    ///
    /// # Contract: C-PREWARM-001 (JIT Before Payload)
    ///
    /// - **Precondition**: Kernel cache initialized, GPU VRAM mostly free (no blocks uploaded yet)
    /// - **Postcondition**: All forward-pass PTX modules JIT-compiled and cached
    /// - **Invariant**: Subsequent `get_or_compile()` calls for these keys hit cache (zero JIT)
    ///
    /// CUDA's `cuModuleLoadDataEx` JIT compiler needs device memory for compilation.
    /// If called after uploading 36 transformer blocks (~22 GB), the near-OOM state causes
    /// `CUDA_ERROR_ILLEGAL_ADDRESS` during JIT (trueno#107). Pre-warming compiles all PTX
    /// while VRAM is free, avoiding this failure mode entirely.
    ///
    /// Keys that are already cached are skipped without emitting PTX, so the
    /// count logged at the end is the number of modules compiled by this call.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `get_or_compile`.
    pub(super) fn pre_warm_for_model(
        &mut self,
        hidden_size: usize,
        intermediate_size: usize,
        num_heads: usize,
        num_kv_heads: usize,
        head_dim: usize,
        max_seq_len: usize,
    ) -> Result<()> {
        use trueno_gpu::kernels::{BatchedRopeKernel, FusedNf4GateUpGemmKernel};

        // NOTE(review): the `as u32` narrowings and the products below
        // (`s * h`, `s * i`, `nh * s * s`, `nh * s`) are unchecked; very long
        // sequences can overflow u32 (panic in debug, wrap in release).
        // Whether the affected kernels bake the element count into their PTX
        // is not visible from this file — confirm before changing.
        let s = max_seq_len as u32;
        let h = hidden_size as u32;
        let q_dim = (num_heads * head_dim) as u32; // Q/O projection dim (may differ from h)
        let kv_h = (num_kv_heads * head_dim) as u32;
        let i = intermediate_size as u32;
        let nh = num_heads as u32;
        let nkv = num_kv_heads as u32;
        let hd = head_dim as u32;
        let sh = s * h; // seq_len * hidden_size
        let si = s * i; // seq_len * intermediate_size

        let mut count = 0u32;
        let target = self.sm_target.clone();

        // Helper: compile `$kernel` under `$key` unless that key is already
        // cached. The cache is consulted BEFORE emitting PTX so a hit costs no
        // multi-KB String allocation and is not counted as a compilation.
        //
        // PMAT-698j: the key must be the kernel's real runtime key. A previous
        // version hardcoded "silu_forward" for every call, so only the first
        // kernel was stored and every other kernel (rmsnorm, rope, softmax,
        // swiglu, residual, etc.) cache-missed at runtime and JIT-compiled
        // mid-training — on Blackwell sm_121 that corrupted the CUDA stream
        // and surfaced as the cascading "Block 0 upload failed" /
        // "forward_backward_with_grad returned None" errors hunted across
        // PMAT-698e..i. Found via the PMAT-698i [FWD-CACHE] diagnostics.
        macro_rules! warm {
            ($key:expr, $kernel:expr) => {{
                let key = $key;
                if !self.modules.contains_key(&key) {
                    let ptx = $kernel.emit_ptx_for_target(&target);
                    self.get_or_compile(&key, &ptx)?;
                    count += 1;
                }
            }};
        }

        // 1. RMSNorm (batched: single launch for all rows via grid.y)
        // ALB-076: Use BatchedVectorizedRmsNormKernel instead of per-row RmsNormKernel
        //
        // PMAT-698k: the runtime key format includes the eps as bit-pattern
        // suffix (normalization.rs:139:
        //   let key = format!("batched_rmsnorm_fwd_{hidden_size}_eps{eps_bits:08x}"))
        // Pre-warm key used to omit the eps suffix → cache miss at runtime →
        // JIT mid-forward → Blackwell sm_121 stream poisoning.
        //
        // PMAT-698n: pre-warm 1e-6 (0x358637bd, Qwen2 / Qwen2.5 standard;
        // live diagnostics showed `batched_rmsnorm_fwd_896_eps358637bd` on the
        // Phase 3 dispatch) first, then 1e-5 (0x3727c5ac, Llama/Mistral
        // standard) for cross-family coverage. The cost of pre-warming both is
        // ~30 KB of cache headroom.
        for eps in [1.0e-6_f32, 1.0e-5_f32] {
            let eps_bits = eps.to_bits();
            warm!(
                format!("batched_rmsnorm_fwd_{h}_eps{eps_bits:08x}"),
                BatchedVectorizedRmsNormKernel::new(h, 1)
            );
        }

        // PMAT-700 (SPEC-BLACKWELL-FIX-001 Fix #2): when cuBLAS is available
        // and the runtime takes its fast path for the standard 2D GEMMs
        // (Q/K/V/O/gate/up/down projections — see ALB-075 dispatch in
        // gemm.rs:47-49 and cuda_block.rs:2895), pre-warming the PTX
        // equivalents is wasted VRAM. On sm_121 (Blackwell GB10) the
        // resulting JIT-cache footprint pushes block upload over the budget
        // and CUDA_ERROR_OUT_OF_MEMORY fires at "Block 0 upload". Skipping
        // these four pre-warms when cuBLAS is bound saves ~5-7 PTX modules
        // per cache (more on multi-block-size models) and unblocks gx10
        // dispatch without any runtime path change.
        //
        // Falsifier: F-BLACKWELL-CUBLAS-PREWARM-001 — assert the cache
        // module count after pre_warm_for_model decreases when cuBLAS is
        // present, and that runtime forward still produces identical
        // results on a known input (cuBLAS path was already taken).
        let has_cublas = self.cublas.is_some();
        if !has_cublas {
            // 2. GEMM: Q/O projections (S, H, H)
            warm!(format!("gemm_forward_{s}_{h}_{h}"), GemmKernel::naive(s, h, h));

            // 3. GEMM: K/V projections (S, H, kv_hidden)
            if kv_h != h {
                warm!(format!("gemm_forward_{s}_{h}_{kv_h}"), GemmKernel::naive(s, kv_h, h));
            }

            // 4. GEMM: gate/up projections (S, H, I)
            warm!(format!("gemm_forward_{s}_{h}_{i}"), GemmKernel::naive(s, i, h));

            // 5. GEMM: down projection (S, I, H)
            warm!(format!("gemm_forward_{s}_{i}_{h}"), GemmKernel::naive(s, h, i));
        } else {
            eprintln!("[CUDA] Skipping PTX pre-warm for 4 GEMM kernels (cuBLAS active — PMAT-700)");
        }

        // PMAT-698k + PMAT-698p: pre-warm batched_rope_fwd at BOTH seq_len=1
        // (Phase 3 single-token smoke) AND APR_DISTILL_SMOKE_SEQ_LEN
        // (default 256 — Phase 4 real-corpus seq). Runtime keys
        // (normalization.rs:339):
        //   batched_rope_fwd_{num_heads}_{head_dim}_{seq_len}_th{theta_bits:08x}
        // Stage C/D dispatch on gx10 confirmed runtime emits 2 [FWD-CACHE]
        // Compiling events post-pre-warm for rope_fwd at seq=256 — avoidable
        // JIT-cache pressure that PMAT-700-B closed for GEMMs.
        let qwen_theta = 1_000_000.0_f32;
        let qwen_theta_bits = qwen_theta.to_bits();
        // An unset or unparsable environment variable falls back to 256.
        let phase4_rope_seq: u32 = std::env::var("APR_DISTILL_SMOKE_SEQ_LEN")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(256);
        for rope_seq in [1_u32, phase4_rope_seq] {
            warm!(
                format!("batched_rope_fwd_{nh}_{hd}_{rope_seq}_th{qwen_theta_bits:08x}"),
                BatchedRopeKernel::new(nh, hd, rope_seq, qwen_theta)
            );
            // GQA: K uses a different head count than Q.
            if nkv != nh {
                warm!(
                    format!("batched_rope_fwd_{nkv}_{hd}_{rope_seq}_th{qwen_theta_bits:08x}"),
                    BatchedRopeKernel::new(nkv, hd, rope_seq, qwen_theta)
                );
            }
        }

        // 6. Fused SwiGLU
        warm!("fused_swiglu_forward".to_string(), FusedSwigluKernel::new(si));

        // 7. Residual add (seq * hidden)
        warm!("residual_add_forward".to_string(), ResidualAddKernel::new(sh));

        // 8. Interleaved-to-batched (dimension-independent: one module handles all dims)
        warm!("interleaved_to_batched".to_string(), InterleavedToBatchedKernel::new(s, nh, hd));

        // 9. Batched transpose (dimension-independent: one module handles all dims)
        warm!("batched_transpose".to_string(), BatchedTransposeKernel::new(nh, s, hd));

        // 10. Batched 4D GEMM: Q@K^T (1, NH, S, S, HD)
        warm!(
            format!("batched_4d_gemm_1_{nh}_{s}_{s}_{hd}"),
            Batched4DGemmKernel::new(1, nh, s, s, hd)
        );

        // 11. Scale: attention scores (NH * S * S)
        let score_n = nh * s * s;
        warm!("scale_forward".to_string(), ScaleKernel::new(score_n));

        // 12. Batched softmax (dimension-independent: one module handles all dims)
        let softmax_rows = nh * s;
        warm!("batched_softmax_forward".to_string(), BatchedSoftmaxKernel::new(softmax_rows, s));

        // 13. Batched 4D GEMM: attn@V (1, NH, S, HD, S)
        warm!(
            format!("batched_4d_gemm_1_{nh}_{s}_{hd}_{s}"),
            Batched4DGemmKernel::new(1, nh, s, hd, s)
        );

        // 13b. Batched 4D GEMM: attention backward grad_V^T (1, NH, HD, S, S)
        warm!(
            format!("batched_4d_gemm_1_{nh}_{hd}_{s}_{s}"),
            Batched4DGemmKernel::new(1, nh, hd, s, s)
        );

        // 14. Batched-to-interleaved (dimension-independent: one module handles all dims)
        warm!("batched_to_interleaved".to_string(), BatchedToInterleavedKernel::new(s, nh, hd));

        // 15. Element-wise multiply (used in FFN backward for SwiGLU gate * up)
        warm!("elementwise_mul_forward".to_string(), ElementwiseMulKernel::new(si));

        // 16. SiLU forward activation (standalone, used in LoRA FFN path)
        warm!("silu_forward".to_string(), SiluKernel::new(si));

        // 17-20. NF4 quantized GEMM variants (trueno#108: QLoRA support)
        // Same 4 GEMM shapes but with Nf4GemmKernel instead of GemmKernel.
        // Only compiled if K is divisible by 64 (NF4 block size).
        if h.is_multiple_of(64) {
            // NF4 cache keys exclude M (seq_len) — PTX is shape-independent
            // (m/n/k are runtime params). Including M causes cache misses when
            // actual seq_len != max_seq_len, triggering on-demand JIT that fails
            // after GPU memory is loaded (trueno#184).
            //
            // Attention projections use q_dim (= num_heads * head_dim) which may
            // differ from hidden_size (e.g. Qwen3-4B: h=2560, q_dim=4096).
            // Q proj: input[S,h] @ W_q[h, q_dim] — key {h}_{q_dim}
            warm!(format!("nf4_gemm_forward_{h}_{q_dim}"), Nf4GemmKernel::new(s, q_dim, h));
            // O proj: input[S,q_dim] @ W_o[q_dim, h] — key {q_dim}_{h}
            if q_dim != h {
                warm!(format!("nf4_gemm_forward_{q_dim}_{h}"), Nf4GemmKernel::new(s, h, q_dim));
            }
            if kv_h != h && kv_h != q_dim && kv_h.is_multiple_of(64) {
                warm!(format!("nf4_gemm_forward_{h}_{kv_h}"), Nf4GemmKernel::new(s, kv_h, h));
            }
            if i.is_multiple_of(64) {
                warm!(format!("nf4_gemm_forward_{h}_{i}"), Nf4GemmKernel::new(s, i, h));
                warm!(format!("nf4_gemm_forward_{i}_{h}"), Nf4GemmKernel::new(s, h, i));
            }
        }

        // PMAT-475: Fused NF4 Gate+Up GEMM for FFN (shared input load).
        if h.is_multiple_of(64) && i.is_multiple_of(64) {
            warm!(format!("fused_nf4_gate_up_{h}_{i}"), FusedNf4GateUpGemmKernel::new(s, i, h));
        }
        // PMAT-478: Fused K+V GEMM for GQA attention (reuses Gate+Up kernel).
        if h.is_multiple_of(64) && kv_h.is_multiple_of(64) && kv_h != i {
            warm!(
                format!("fused_nf4_gate_up_{h}_{kv_h}"),
                FusedNf4GateUpGemmKernel::new(s, kv_h, h)
            );
        }

        // 19-22. NF4 transposed GEMM for QLoRA backward (ENT-153).
        // C[M×K] = A[M×N] @ B[K×N]^T — gradient propagation through frozen NF4 layers.
        if h.is_multiple_of(64) {
            // Q proj backward: grad[S,q_dim] @ W_q[h, q_dim]^T → [S,h]
            warm!(
                format!("nf4_gemm_transpose_{q_dim}_{h}"),
                Nf4GemmTransposeKernel::new(s, q_dim, h)
            );
            // O proj backward: grad[S,h] @ W_o[q_dim, h]^T → [S,q_dim]
            if q_dim != h {
                warm!(
                    format!("nf4_gemm_transpose_{h}_{q_dim}"),
                    Nf4GemmTransposeKernel::new(s, h, q_dim)
                );
            }
            if kv_h != h && kv_h != q_dim && kv_h.is_multiple_of(64) {
                // K/V proj backward: grad[S,kv_h] @ W_k[h, kv_h]^T → [S,h]
                warm!(
                    format!("nf4_gemm_transpose_{kv_h}_{h}"),
                    Nf4GemmTransposeKernel::new(s, kv_h, h)
                );
            }
            if i.is_multiple_of(64) {
                // Gate/Up backward: grad[S,I] @ W_gate[h,I]^T → [S,h]
                warm!(format!("nf4_gemm_transpose_{i}_{h}"), Nf4GemmTransposeKernel::new(s, i, h));
                // Down backward: grad[S,h] @ W_down[I,h]^T → [S,I]
                warm!(format!("nf4_gemm_transpose_{h}_{i}"), Nf4GemmTransposeKernel::new(s, h, i));
            }
        }

        eprintln!("[CUDA] Pre-warmed {count} forward kernels (JIT compiled before block upload)");
        Ok(())
    }

    /// Pre-warm LoRA backward GEMM kernels for QLoRA training (ENT-153).
    ///
    /// The LoRA backward uses regular fp32 GEMMs for:
    /// - Forward LoRA: x @ A → [S, R], inter @ B → [S, proj_dim]
    /// - Backward A: x^T @ grad_inter → grad_A [H, R]
    /// - Backward B: inter^T @ grad_proj → grad_B [R, proj_dim]
    /// - Backward input: grad_proj @ B^T → [S, R], then [S, R] @ A^T → [S, H]
    ///
    /// These shapes are small (rank << hidden_size) but must still be JIT-compiled.
    ///
    /// A `lora_rank` of 0 is a no-op. Several shapes share a cache key when
    /// `hidden_size == q_dim`; keys that are already cached are skipped without
    /// emitting PTX, so the logged count is the number of modules compiled by
    /// this call.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `get_or_compile`.
    pub(super) fn pre_warm_lora_backward(
        &mut self,
        hidden_size: usize,
        q_dim: usize,
        kv_hidden_size: usize,
        max_seq_len: usize,
        lora_rank: usize,
    ) -> Result<()> {
        if lora_rank == 0 {
            return Ok(());
        }

        let s = max_seq_len as u32;
        let h = hidden_size as u32;
        let r = lora_rank as u32;
        let qd = q_dim as u32;
        let kv = kv_hidden_size as u32;

        let mut count = 0u32;
        let target = self.sm_target.clone();

        // Compile `$kernel` under `$key` unless the key is already cached;
        // the cache is checked before any PTX is generated.
        macro_rules! warm {
            ($key:expr, $kernel:expr) => {{
                let key = $key;
                if !self.modules.contains_key(&key) {
                    let ptx = $kernel.emit_ptx_for_target(&target);
                    self.get_or_compile(&key, &ptx)?;
                    count += 1;
                }
            }};
        }

        // LoRA forward GEMMs (also needed in backward for activation checkpointing)
        // x[S,H] @ A[H,R] → [S,R]
        warm!(format!("gemm_forward_{s}_{h}_{r}"), GemmKernel::naive(s, r, h));
        // inter[S,R] @ B[R,qd] → [S,qd]
        warm!(format!("gemm_forward_{s}_{r}_{qd}"), GemmKernel::naive(s, qd, r));
        // inter[S,R] @ B[R,kv] → [S,kv]
        if kv != qd {
            warm!(format!("gemm_forward_{s}_{r}_{kv}"), GemmKernel::naive(s, kv, r));
        }

        // LoRA backward GEMMs (gemm_backward_a and gemm_backward_b use regular GEMM shapes)
        // grad_B = inter^T[R,S] @ grad_proj[S,qd] → [R,qd]
        // This is a GEMM with M=R, N=qd, K=S
        warm!(format!("gemm_forward_{r}_{s}_{qd}"), GemmKernel::naive(r, qd, s));
        if kv != qd {
            warm!(format!("gemm_forward_{r}_{s}_{kv}"), GemmKernel::naive(r, kv, s));
        }

        // grad_li = grad_proj[S,qd] @ B^T[qd,R] → [S,R]
        // This is effectively GEMM with M=S, N=R, K=qd
        warm!(format!("gemm_forward_{s}_{qd}_{r}"), GemmKernel::naive(s, r, qd));
        if kv != qd {
            warm!(format!("gemm_forward_{s}_{kv}_{r}"), GemmKernel::naive(s, r, kv));
        }

        // grad_A = x^T[H,S] @ grad_li[S,R] → [H,R]
        warm!(format!("gemm_forward_{h}_{s}_{r}"), GemmKernel::naive(h, r, s));

        // grad_input += grad_li[S,R] @ A^T[R,H] → [S,H]
        warm!(format!("gemm_forward_{s}_{r}_{h}"), GemmKernel::naive(s, h, r));

        eprintln!("[CUDA] Pre-warmed {count} LoRA backward kernels");
        Ok(())
    }
}
494
/// Initialize the global forward kernel cache for `ctx`.
///
/// Only the first call constructs a cache; if the cache already exists the
/// supplied context is dropped and the existing cache is left untouched.
/// Always returns `Ok(())`.
#[cfg(feature = "cuda")]
pub fn init_forward_kernel_cache(ctx: std::sync::Arc<CudaContext>) -> Result<()> {
    let build_cache = move || Mutex::new(ForwardKernelCache::new(ctx));
    let _existing_or_new = FORWARD_KERNEL_CACHE.get_or_init(build_cache);
    Ok(())
}
/// Pre-allocate cuBLAS workspace for CUDA graph capture (PMAT-063).
///
/// Forwards `ptr` and `size` to the cached cuBLAS handle. When the cache holds
/// no cuBLAS handle this is a no-op that returns `Ok(())`.
///
/// # Errors
///
/// - `CudaTensorError::DeviceNotInitialized` if the forward kernel cache has
///   not been initialized.
/// - `CudaTensorError::KernelError` if the cache lock cannot be acquired or
///   the handle rejects the workspace.
#[cfg(feature = "cuda")]
pub fn set_cublas_workspace(ptr: u64, size: usize) -> Result<()> {
    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    // Same lock-failure message as the other cache accessors in this module,
    // so a poisoned lock is reported identically everywhere.
    let cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;
    if let Some(handle) = cache.cublas() {
        handle.set_workspace(ptr, size).map_err(|e| {
            CudaTensorError::KernelError(format!("cuBLAS set_workspace failed: {e}"))
        })?;
    }
    Ok(())
}
/// Bind the forward cuBLAS handle to `stream` (ALB-075).
///
/// # Errors
///
/// - `CudaTensorError::DeviceNotInitialized` if the forward kernel cache has
///   not been initialized.
/// - `CudaTensorError::KernelError` if the cache lock cannot be acquired, or
///   whatever `ForwardKernelCache::set_cublas_stream` reports.
#[cfg(feature = "cuda")]
pub fn set_forward_cublas_stream(stream: &CudaStream) -> Result<()> {
    let Some(shared) = FORWARD_KERNEL_CACHE.get() else {
        return Err(CudaTensorError::DeviceNotInitialized);
    };
    match shared.lock() {
        Ok(guard) => guard.set_cublas_stream(stream),
        Err(_poisoned) => Err(CudaTensorError::KernelError(
            "Failed to acquire kernel cache lock".to_string(),
        )),
    }
}
520
/// Pre-warm forward kernels (C-PREWARM-001: JIT before block upload).
///
/// Runs the backward rope pre-warm first, then takes the cache lock and
/// delegates to `ForwardKernelCache::pre_warm_for_model` with the same model
/// dimensions.
///
/// # Errors
///
/// - `CudaTensorError::DeviceNotInitialized` if the forward kernel cache has
///   not been initialized.
/// - `CudaTensorError::KernelError` if the cache lock cannot be acquired, or
///   any error from either pre-warm step.
#[cfg(feature = "cuda")]
pub fn pre_warm_forward_kernels(
    hidden_size: usize,
    intermediate_size: usize,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    max_seq_len: usize,
) -> Result<()> {
    // trueno#200: Pre-warm backward kernels too (Blackwell JIT crash workaround).
    // That helper takes and releases the cache lock itself, so it must finish
    // before the lock is taken below.
    pre_warm_backward_kernels_in_forward_cache(num_heads, num_kv_heads, head_dim, max_seq_len)?;

    let shared = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    match shared.lock() {
        Ok(mut guard) => guard.pre_warm_for_model(
            hidden_size,
            intermediate_size,
            num_heads,
            num_kv_heads,
            head_dim,
            max_seq_len,
        ),
        Err(_poisoned) => Err(CudaTensorError::KernelError(
            "Failed to acquire kernel cache lock".to_string(),
        )),
    }
}
546
/// Pre-warm backward kernels in forward cache (trueno#200 Blackwell).
///
/// CONTRACT: All backward kernels must be compiled before GPU work starts.
/// On Blackwell (sm_121), cuModuleLoadData fails during active GPU computation.
///
/// Compiles the batched RoPE backward kernel for `num_heads` and, when it
/// differs, for `num_kv_heads`. Keys that are already cached are skipped
/// without emitting PTX.
///
/// # Errors
///
/// - `CudaTensorError::DeviceNotInitialized` if the forward kernel cache has
///   not been initialized.
/// - `CudaTensorError::KernelError` if the cache lock cannot be acquired, or
///   any error from `get_or_compile`.
#[cfg(feature = "cuda")]
fn pre_warm_backward_kernels_in_forward_cache(
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    max_seq_len: usize,
) -> Result<()> {
    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let target = cache.sm_target.clone();

    // Compile `$kernel` under `$key` unless the key is already cached; the
    // cache is checked before any PTX is generated.
    macro_rules! warm {
        ($key:expr, $kernel:expr) => {{
            let key = $key;
            if cache.get_cached(&key).is_none() {
                let ptx = $kernel.emit_ptx_for_target(&target);
                cache.get_or_compile(&key, &ptx)?;
            }
        }};
    }

    // Batched RoPE backward — missing from pre_warm_for_model, causes
    // CUDA context poisoning on Blackwell when compiled during backward pass.
    // Need BOTH num_heads AND num_kv_heads variants (GQA uses different head count for K/V).
    //
    // FALSIFY-CUDA-ROPE-THETA-CACHE-KEY-001: cache key now includes theta_bits
    // (matching runtime in `batched_rope_neox_backward`). The hardcoded
    // 1_000_000.0 here matches Qwen2 / Qwen2.5 default; for Llama
    // pretrain (theta=10000) the runtime call will compile its own
    // module on first use, no longer silently shadowing the Qwen warm.
    let nh = num_heads as u32;
    let nkv = num_kv_heads as u32;
    let hd = head_dim as u32;
    let s = max_seq_len as u32;
    let qwen_theta_bits = 1_000_000.0_f32.to_bits();
    warm!(
        format!("batched_rope_bwd_{nh}_{hd}_{s}_th{qwen_theta_bits:08x}"),
        BatchedRopeBackwardKernel::new(nh, hd, s, 1_000_000.0)
    );
    if nkv != nh {
        warm!(
            format!("batched_rope_bwd_{nkv}_{hd}_{s}_th{qwen_theta_bits:08x}"),
            BatchedRopeBackwardKernel::new(nkv, hd, s, 1_000_000.0)
        );
    }

    eprintln!("  ✓ Backward rope kernel pre-warmed in forward cache");
    Ok(())
}
603
/// Pre-warm LoRA backward GEMM kernels for QLoRA training (ENT-153).
///
/// Must be called BEFORE uploading transformer blocks. Takes the cache lock
/// and delegates to `ForwardKernelCache::pre_warm_lora_backward`, which
/// compiles the small-matrix GEMMs needed for LoRA gradient computation.
///
/// # Errors
///
/// - `CudaTensorError::DeviceNotInitialized` if the forward kernel cache has
///   not been initialized.
/// - `CudaTensorError::KernelError` if the cache lock cannot be acquired, or
///   any error from the pre-warm itself.
#[cfg(feature = "cuda")]
pub fn pre_warm_lora_backward_kernels(
    hidden_size: usize,
    q_dim: usize,
    kv_hidden_size: usize,
    max_seq_len: usize,
    lora_rank: usize,
) -> Result<()> {
    let Some(shared) = FORWARD_KERNEL_CACHE.get() else {
        return Err(CudaTensorError::DeviceNotInitialized);
    };
    match shared.lock() {
        Ok(mut guard) => {
            guard.pre_warm_lora_backward(hidden_size, q_dim, kv_hidden_size, max_seq_len, lora_rank)
        }
        Err(_poisoned) => Err(CudaTensorError::KernelError(
            "Failed to acquire kernel cache lock".to_string(),
        )),
    }
}