Skip to main content

entrenar/autograd/cuda_forward/
cache.rs

1#![allow(unsafe_code)]
2#![allow(trivial_casts)]
3#![allow(clippy::borrow_as_ptr)]
4#![allow(clippy::ref_as_ptr)]
5
6#[cfg(feature = "cuda")]
7use std::collections::HashMap;
8#[cfg(feature = "cuda")]
9use std::sync::{Mutex, OnceLock};
10
11#[cfg(feature = "cuda")]
12use trueno_gpu::driver::{CublasHandle, CudaContext, CudaModule, CudaStream};
13#[cfg(feature = "cuda")]
14use trueno_gpu::kernels::{
15    Batched4DGemmKernel, BatchedRopeBackwardKernel, BatchedSoftmaxKernel,
16    BatchedToInterleavedKernel, BatchedTransposeKernel, BatchedVectorizedRmsNormKernel,
17    ElementwiseMulKernel, FusedSwigluKernel, GemmKernel, InterleavedToBatchedKernel, Kernel,
18    Nf4GemmKernel, Nf4GemmTransposeKernel, ResidualAddKernel, ScaleKernel, SiluKernel,
19};
20
21use crate::autograd::cuda_tensor::{CudaTensorError, Result};
22
/// Process-wide cache of compiled CUDA modules for forward kernels.
///
/// The cell starts empty and is filled by `init_forward_kernel_cache`. The
/// free functions in this module return `CudaTensorError::DeviceNotInitialized`
/// while it is still empty. The inner `Mutex` guards every access to the cache.
#[cfg(feature = "cuda")]
pub(super) static FORWARD_KERNEL_CACHE: OnceLock<Mutex<ForwardKernelCache>> = OnceLock::new();
26
/// Cache for compiled forward kernel modules.
///
/// Stores the device's SM target (e.g. "sm_89") detected at init time.
/// All PTX must be emitted for this target before compilation.
///
/// # Contract: F-PTX-001 (Target Parity)
///
/// PTX `.target` directive MUST match the device compute capability.
/// `get_or_compile` checks this before loading a module and rejects
/// mismatched PTX with an error.
#[cfg(feature = "cuda")]
pub(super) struct ForwardKernelCache {
    /// Shared CUDA context used to load PTX modules and to query the
    /// device compute capability.
    ctx: std::sync::Arc<CudaContext>,
    /// Compiled modules, keyed by the kernel name passed to `get_or_compile`.
    modules: HashMap<String, CudaModule>,
    /// Device SM target string (e.g. "sm_89" for RTX 4090)
    sm_target: String,
    /// cuBLAS handle (ALB-075): forward=tensor cores, backward=SIMD (ALB-076/trueno#170).
    /// `None` when handle creation failed at construction time.
    cublas: Option<CublasHandle>,
}
45
46#[cfg(feature = "cuda")]
47impl ForwardKernelCache {
48    pub(super) fn new(ctx: std::sync::Arc<CudaContext>) -> Self {
49        // Detect device compute capability at construction time.
50        // Falls back to sm_70 if detection fails (should never happen
51        // since we already have a valid CudaContext).
52        let sm_target = ctx.sm_target().unwrap_or_else(|_| "sm_70".to_string());
53
54        // entrenar#318: Forward uses TF32 tensor cores (~41x faster than SIMD on sm_89).
55        // ALB-076: TF32 is safe for forward (NoTrans/NoTrans). Backward uses SIMD handle.
56        let cublas = match CublasHandle::new_with_tensor_cores(&ctx) {
57            Ok(handle) => {
58                eprintln!("[CUDA] cuBLAS initialized — forward TF32 tensor cores (41x vs SIMD)");
59                Some(handle)
60            }
61            Err(e) => {
62                eprintln!("[CUDA] cuBLAS not available ({e:?}), using PTX GEMMs");
63                None
64            }
65        };
66
67        eprintln!("[CUDA] Kernel cache initialized for target: {sm_target}");
68        Self { ctx, modules: HashMap::new(), sm_target, cublas }
69    }
70
71    /// Get a reference to the cuBLAS handle, if available.
72    pub(super) fn cublas(&self) -> Option<&CublasHandle> {
73        self.cublas.as_ref()
74    }
75
76    /// Bind cuBLAS to a stream for the current training step.
77    pub(super) fn set_cublas_stream(&self, stream: &CudaStream) -> Result<()> {
78        if let Some(ref handle) = self.cublas {
79            handle.set_stream(stream).map_err(|e| {
80                CudaTensorError::KernelError(format!("cuBLAS set_stream failed: {e:?}"))
81            })?;
82        }
83        Ok(())
84    }
85
86    /// Get the device SM target for PTX emission.
87    ///
88    /// Consumers MUST use this to emit PTX via `kernel.emit_ptx_for_target(cache.sm_target())`.
89    pub(super) fn sm_target(&self) -> &str {
90        &self.sm_target
91    }
92
93    /// Look up a previously compiled module by key (KAIZEN-058).
94    ///
95    /// Returns `Some` if the module is already cached (post-pre-warm: always).
96    /// Callers should use this before generating PTX to avoid unnecessary
97    /// multi-KB String allocations (~1000 per training step).
98    pub(super) fn get_cached(&mut self, name: &str) -> Option<&mut CudaModule> {
99        self.modules.get_mut(name)
100    }
101
102    /// Compile PTX and cache the resulting module.
103    ///
104    /// # Contract: F-PTX-001 (Target Parity)
105    ///
106    /// Validates that the PTX `.target` directive matches the device's compute
107    /// capability. Rejects PTX compiled for the wrong architecture.
108    pub(super) fn get_or_compile(&mut self, name: &str, ptx: &str) -> Result<&mut CudaModule> {
109        use std::collections::hash_map::Entry;
110
111        // F-PTX-001: Validate PTX target matches device
112        if let Some(target_line) = ptx.lines().find(|l| l.starts_with(".target ")) {
113            let ptx_target = target_line.trim().trim_start_matches(".target ");
114            if ptx_target != self.sm_target {
115                return Err(CudaTensorError::KernelError(format!(
116                    "F-PTX-001 violated: PTX target '{ptx_target}' != device target '{}'. \
117                     Use kernel.emit_ptx_for_target(\"{}\") instead of emit_ptx().",
118                    self.sm_target, self.sm_target
119                )));
120            }
121        }
122
123        match self.modules.entry(name.to_string()) {
124            Entry::Occupied(e) => Ok(e.into_mut()),
125            Entry::Vacant(e) => {
126                // trueno#200: Use from_ptx_direct on Blackwell
127                let (major, _) = self.ctx.compute_capability().map_err(|e| {
128                    CudaTensorError::KernelError(format!("compute_capability: {e:?}"))
129                })?;
130                let module = if major >= 12 {
131                    CudaModule::from_ptx_direct(&self.ctx, ptx)
132                } else {
133                    CudaModule::from_ptx(&self.ctx, ptx)
134                }
135                .map_err(|err| {
136                    CudaTensorError::KernelError(format!("Failed to compile {name}: {err:?}"))
137                })?;
138                Ok(e.insert(module))
139            }
140        }
141    }
142
143    /// Pre-warm all kernels needed for transformer forward pass.
144    ///
145    /// # Contract: C-PREWARM-001 (JIT Before Payload)
146    ///
147    /// - **Precondition**: Kernel cache initialized, GPU VRAM mostly free (no blocks uploaded yet)
148    /// - **Postcondition**: All forward-pass PTX modules JIT-compiled and cached
149    /// - **Invariant**: Subsequent `get_or_compile()` calls for these keys hit cache (zero JIT)
150    ///
151    /// CUDA's `cuModuleLoadDataEx` JIT compiler needs device memory for compilation.
152    /// If called after uploading 36 transformer blocks (~22 GB), the near-OOM state causes
153    /// `CUDA_ERROR_ILLEGAL_ADDRESS` during JIT (trueno#107). Pre-warming compiles all PTX
154    /// while VRAM is free, avoiding this failure mode entirely.
155    pub(super) fn pre_warm_for_model(
156        &mut self,
157        hidden_size: usize,
158        intermediate_size: usize,
159        num_heads: usize,
160        num_kv_heads: usize,
161        head_dim: usize,
162        max_seq_len: usize,
163    ) -> Result<()> {
164        let s = max_seq_len as u32;
165        let h = hidden_size as u32;
166        let q_dim = (num_heads * head_dim) as u32; // Q/O projection dim (may differ from h)
167        let kv_h = (num_kv_heads * head_dim) as u32;
168        let i = intermediate_size as u32;
169        let nh = num_heads as u32;
170        let _nkv = num_kv_heads as u32;
171        let hd = head_dim as u32;
172        let sh = s * h; // seq_len * hidden_size
173        let si = s * i; // seq_len * intermediate_size
174
175        let mut count = 0u32;
176        let target = self.sm_target.clone();
177
178        // Helper: generate PTX and compile
179        macro_rules! warm {
180            ($key:expr, $kernel:expr) => {{
181                let ptx = $kernel.emit_ptx_for_target(&target);
182                self.get_or_compile("silu_forward", &ptx)?;
183                count += 1;
184            }};
185        }
186
187        // 1. RMSNorm (batched: single launch for all rows via grid.y)
188        // ALB-076: Use BatchedVectorizedRmsNormKernel instead of per-row RmsNormKernel
189        warm!(format!("batched_rmsnorm_fwd_{h}"), BatchedVectorizedRmsNormKernel::new(h, 1));
190
191        // 2. GEMM: Q/O projections (S, H, H)
192        warm!(format!("gemm_forward_{s}_{h}_{h}"), GemmKernel::naive(s, h, h));
193
194        // 3. GEMM: K/V projections (S, H, kv_hidden)
195        if kv_h != h {
196            warm!(format!("gemm_forward_{s}_{h}_{kv_h}"), GemmKernel::naive(s, kv_h, h));
197        }
198
199        // 4. GEMM: gate/up projections (S, H, I)
200        warm!(format!("gemm_forward_{s}_{h}_{i}"), GemmKernel::naive(s, i, h));
201
202        // 5. GEMM: down projection (S, I, H)
203        warm!(format!("gemm_forward_{s}_{i}_{h}"), GemmKernel::naive(s, h, i));
204
205        // 6. Fused SwiGLU
206        warm!("fused_swiglu_forward".to_string(), FusedSwigluKernel::new(si));
207
208        // 7. Residual add (seq * hidden)
209        warm!("residual_add_forward".to_string(), ResidualAddKernel::new(sh));
210
211        // 8. Interleaved-to-batched (dimension-independent: one module handles all dims)
212        warm!("interleaved_to_batched".to_string(), InterleavedToBatchedKernel::new(s, nh, hd));
213
214        // 9. Batched transpose (dimension-independent: one module handles all dims)
215        warm!("batched_transpose".to_string(), BatchedTransposeKernel::new(nh, s, hd));
216
217        // 10. Batched 4D GEMM: Q@K^T (1, NH, S, S, HD)
218        warm!(
219            format!("batched_4d_gemm_1_{nh}_{s}_{s}_{hd}"),
220            Batched4DGemmKernel::new(1, nh, s, s, hd)
221        );
222
223        // 11. Scale: attention scores (NH * S * S)
224        let score_n = nh * s * s;
225        warm!("scale_forward".to_string(), ScaleKernel::new(score_n));
226
227        // 12. Batched softmax (dimension-independent: one module handles all dims)
228        let softmax_rows = nh * s;
229        warm!("batched_softmax_forward".to_string(), BatchedSoftmaxKernel::new(softmax_rows, s));
230
231        // 13. Batched 4D GEMM: attn@V (1, NH, S, HD, S)
232        warm!(
233            format!("batched_4d_gemm_1_{nh}_{s}_{hd}_{s}"),
234            Batched4DGemmKernel::new(1, nh, s, hd, s)
235        );
236
237        // 13b. Batched 4D GEMM: attention backward grad_V^T (1, NH, HD, S, S)
238        warm!(
239            format!("batched_4d_gemm_1_{nh}_{hd}_{s}_{s}"),
240            Batched4DGemmKernel::new(1, nh, hd, s, s)
241        );
242
243        // 14. Batched-to-interleaved (dimension-independent: one module handles all dims)
244        warm!("batched_to_interleaved".to_string(), BatchedToInterleavedKernel::new(s, nh, hd));
245
246        // 15. Element-wise multiply (used in FFN backward for SwiGLU gate * up)
247        warm!("elementwise_mul_forward".to_string(), ElementwiseMulKernel::new(si));
248
249        // 16. SiLU forward activation (standalone, used in LoRA FFN path)
250        warm!("silu_forward".to_string(), SiluKernel::new(si));
251
252        // 17-20. NF4 quantized GEMM variants (trueno#108: QLoRA support)
253        // Same 4 GEMM shapes but with Nf4GemmKernel instead of GemmKernel.
254        // Only compiled if K is divisible by 64 (NF4 block size).
255        if h.is_multiple_of(64) {
256            // NF4 cache keys exclude M (seq_len) — PTX is shape-independent
257            // (m/n/k are runtime params). Including M causes cache misses when
258            // actual seq_len != max_seq_len, triggering on-demand JIT that fails
259            // after GPU memory is loaded (trueno#184).
260            //
261            // Attention projections use q_dim (= num_heads * head_dim) which may
262            // differ from hidden_size (e.g. Qwen3-4B: h=2560, q_dim=4096).
263            // Q proj: input[S,h] @ W_q[h, q_dim] — key {h}_{q_dim}
264            warm!(format!("nf4_gemm_forward_{h}_{q_dim}"), Nf4GemmKernel::new(s, q_dim, h));
265            // O proj: input[S,q_dim] @ W_o[q_dim, h] — key {q_dim}_{h}
266            if q_dim != h {
267                warm!(format!("nf4_gemm_forward_{q_dim}_{h}"), Nf4GemmKernel::new(s, h, q_dim));
268            }
269            if kv_h != h && kv_h != q_dim && kv_h.is_multiple_of(64) {
270                warm!(format!("nf4_gemm_forward_{h}_{kv_h}"), Nf4GemmKernel::new(s, kv_h, h));
271            }
272            if i.is_multiple_of(64) {
273                warm!(format!("nf4_gemm_forward_{h}_{i}"), Nf4GemmKernel::new(s, i, h));
274                warm!(format!("nf4_gemm_forward_{i}_{h}"), Nf4GemmKernel::new(s, h, i));
275            }
276        }
277
278        // PMAT-475: Fused NF4 Gate+Up GEMM for FFN (shared input load).
279        if h.is_multiple_of(64) && i.is_multiple_of(64) {
280            use trueno_gpu::kernels::FusedNf4GateUpGemmKernel;
281            warm!(format!("fused_nf4_gate_up_{h}_{i}"), FusedNf4GateUpGemmKernel::new(s, i, h));
282        }
283        // PMAT-478: Fused K+V GEMM for GQA attention (reuses Gate+Up kernel).
284        if h.is_multiple_of(64) && kv_h.is_multiple_of(64) && kv_h != i {
285            use trueno_gpu::kernels::FusedNf4GateUpGemmKernel;
286            warm!(
287                format!("fused_nf4_gate_up_{h}_{kv_h}"),
288                FusedNf4GateUpGemmKernel::new(s, kv_h, h)
289            );
290        }
291
292        // 19-22. NF4 transposed GEMM for QLoRA backward (ENT-153).
293        // C[M×K] = A[M×N] @ B[K×N]^T — gradient propagation through frozen NF4 layers.
294        if h.is_multiple_of(64) {
295            // Q proj backward: grad[S,q_dim] @ W_q[h, q_dim]^T → [S,h]
296            warm!(
297                format!("nf4_gemm_transpose_{q_dim}_{h}"),
298                Nf4GemmTransposeKernel::new(s, q_dim, h)
299            );
300            // O proj backward: grad[S,h] @ W_o[q_dim, h]^T → [S,q_dim]
301            if q_dim != h {
302                warm!(
303                    format!("nf4_gemm_transpose_{h}_{q_dim}"),
304                    Nf4GemmTransposeKernel::new(s, h, q_dim)
305                );
306            }
307            if kv_h != h && kv_h != q_dim && kv_h.is_multiple_of(64) {
308                // K/V proj backward: grad[S,kv_h] @ W_k[h, kv_h]^T → [S,h]
309                warm!(
310                    format!("nf4_gemm_transpose_{kv_h}_{h}"),
311                    Nf4GemmTransposeKernel::new(s, kv_h, h)
312                );
313            }
314            if i.is_multiple_of(64) {
315                // Gate/Up backward: grad[S,I] @ W_gate[h,I]^T → [S,h]
316                warm!(format!("nf4_gemm_transpose_{i}_{h}"), Nf4GemmTransposeKernel::new(s, i, h));
317                // Down backward: grad[S,h] @ W_down[I,h]^T → [S,I]
318                warm!(format!("nf4_gemm_transpose_{h}_{i}"), Nf4GemmTransposeKernel::new(s, h, i));
319            }
320        }
321
322        eprintln!("[CUDA] Pre-warmed {count} forward kernels (JIT compiled before block upload)");
323        Ok(())
324    }
325
326    /// Pre-warm LoRA backward GEMM kernels for QLoRA training (ENT-153).
327    ///
328    /// The LoRA backward uses regular fp32 GEMMs for:
329    /// - Forward LoRA: x @ A → [S, R], inter @ B → [S, proj_dim]
330    /// - Backward A: x^T @ grad_inter → grad_A [H, R]
331    /// - Backward B: inter^T @ grad_proj → grad_B [R, proj_dim]
332    /// - Backward input: grad_proj @ B^T → [S, R], then [S, R] @ A^T → [S, H]
333    ///
334    /// These shapes are small (rank << hidden_size) but must still be JIT-compiled.
335    pub(super) fn pre_warm_lora_backward(
336        &mut self,
337        hidden_size: usize,
338        q_dim: usize,
339        kv_hidden_size: usize,
340        max_seq_len: usize,
341        lora_rank: usize,
342    ) -> Result<()> {
343        if lora_rank == 0 {
344            return Ok(());
345        }
346
347        let s = max_seq_len as u32;
348        let h = hidden_size as u32;
349        let r = lora_rank as u32;
350        let qd = q_dim as u32;
351        let kv = kv_hidden_size as u32;
352
353        let mut count = 0u32;
354        let target = self.sm_target.clone();
355
356        macro_rules! warm {
357            ($key:expr, $kernel:expr) => {{
358                let ptx = $kernel.emit_ptx_for_target(&target);
359                self.get_or_compile(&$key, &ptx)?;
360                count += 1;
361            }};
362        }
363
364        // LoRA forward GEMMs (also needed in backward for activation checkpointing)
365        // x[S,H] @ A[H,R] → [S,R]
366        warm!(format!("gemm_forward_{s}_{h}_{r}"), GemmKernel::naive(s, r, h));
367        // inter[S,R] @ B[R,qd] → [S,qd]
368        warm!(format!("gemm_forward_{s}_{r}_{qd}"), GemmKernel::naive(s, qd, r));
369        // inter[S,R] @ B[R,kv] → [S,kv]
370        if kv != qd {
371            warm!(format!("gemm_forward_{s}_{r}_{kv}"), GemmKernel::naive(s, kv, r));
372        }
373
374        // LoRA backward GEMMs (gemm_backward_a and gemm_backward_b use regular GEMM shapes)
375        // grad_B = inter^T[R,S] @ grad_proj[S,qd] → [R,qd]
376        // This is a GEMM with M=R, N=qd, K=S
377        warm!(format!("gemm_forward_{r}_{s}_{qd}"), GemmKernel::naive(r, qd, s));
378        if kv != qd {
379            warm!(format!("gemm_forward_{r}_{s}_{kv}"), GemmKernel::naive(r, kv, s));
380        }
381
382        // grad_li = grad_proj[S,qd] @ B^T[qd,R] → [S,R]
383        // This is effectively GEMM with M=S, N=R, K=qd
384        warm!(format!("gemm_forward_{s}_{qd}_{r}"), GemmKernel::naive(s, r, qd));
385        if kv != qd {
386            warm!(format!("gemm_forward_{s}_{kv}_{r}"), GemmKernel::naive(s, r, kv));
387        }
388
389        // grad_A = x^T[H,S] @ grad_li[S,R] → [H,R]
390        warm!(format!("gemm_forward_{h}_{s}_{r}"), GemmKernel::naive(h, r, s));
391
392        // grad_input += grad_li[S,R] @ A^T[R,H] → [S,H]
393        warm!(format!("gemm_forward_{s}_{r}_{h}"), GemmKernel::naive(s, h, r));
394
395        eprintln!("[CUDA] Pre-warmed {count} LoRA backward kernels");
396        Ok(())
397    }
398}
399
/// Create the global forward kernel cache for `ctx` if it does not exist yet.
///
/// Only the first call builds a cache. Later calls leave the existing cache
/// in place and their `ctx` argument is dropped unused. Always returns `Ok(())`.
#[cfg(feature = "cuda")]
pub fn init_forward_kernel_cache(ctx: std::sync::Arc<CudaContext>) -> Result<()> {
    let _ = FORWARD_KERNEL_CACHE.get_or_init(|| {
        let cache = ForwardKernelCache::new(ctx);
        Mutex::new(cache)
    });
    Ok(())
}
/// Pre-allocate cuBLAS workspace for CUDA graph capture (PMAT-063).
///
/// Forwards `ptr` and `size` to the cached cuBLAS handle. Does nothing and
/// returns `Ok(())` when no cuBLAS handle is available.
///
/// NOTE(review): `ptr` is presumably a device address and `size` a byte
/// count — confirm against `CublasHandle::set_workspace`.
///
/// # Errors
///
/// - `CudaTensorError::DeviceNotInitialized` if the kernel cache has not been created.
/// - `CudaTensorError::KernelError` if the cache lock is poisoned or cuBLAS
///   rejects the workspace.
#[cfg(feature = "cuda")]
pub fn set_cublas_workspace(ptr: u64, size: usize) -> Result<()> {
    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    // Same message as the other cache entry points so a poisoned lock is
    // recognisable no matter which call reports it.
    let cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;
    if let Some(handle) = cache.cublas() {
        handle.set_workspace(ptr, size).map_err(|e| {
            CudaTensorError::KernelError(format!("cuBLAS set_workspace failed: {e}"))
        })?;
    }
    Ok(())
}
/// Bind the cached cuBLAS handle to `stream` (ALB-075).
///
/// # Errors
///
/// - `CudaTensorError::DeviceNotInitialized` if the kernel cache has not been created.
/// - `CudaTensorError::KernelError` if the cache lock cannot be taken or
///   the handle rejects the stream.
#[cfg(feature = "cuda")]
pub fn set_forward_cublas_stream(stream: &CudaStream) -> Result<()> {
    let Some(shared) = FORWARD_KERNEL_CACHE.get() else {
        return Err(CudaTensorError::DeviceNotInitialized);
    };
    match shared.lock() {
        Ok(guard) => guard.set_cublas_stream(stream),
        Err(_) => Err(CudaTensorError::KernelError(
            "Failed to acquire kernel cache lock".to_string(),
        )),
    }
}
425
/// Pre-warm forward kernels (C-PREWARM-001: JIT before block upload).
///
/// Warms the backward RoPE kernels first, then the forward-pass kernels for
/// the given model dimensions. The cache lock is taken separately for each
/// of the two phases.
///
/// # Errors
///
/// - `CudaTensorError::DeviceNotInitialized` if the kernel cache has not been created.
/// - `CudaTensorError::KernelError` if the cache lock cannot be taken or
///   any kernel fails to compile.
#[cfg(feature = "cuda")]
pub fn pre_warm_forward_kernels(
    hidden_size: usize,
    intermediate_size: usize,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    max_seq_len: usize,
) -> Result<()> {
    // trueno#200: Pre-warm backward kernels too (Blackwell JIT crash workaround)
    pre_warm_backward_kernels_in_forward_cache(num_heads, num_kv_heads, head_dim, max_seq_len)?;

    let Some(shared) = FORWARD_KERNEL_CACHE.get() else {
        return Err(CudaTensorError::DeviceNotInitialized);
    };
    match shared.lock() {
        Ok(mut guard) => guard.pre_warm_for_model(
            hidden_size,
            intermediate_size,
            num_heads,
            num_kv_heads,
            head_dim,
            max_seq_len,
        ),
        Err(_) => Err(CudaTensorError::KernelError(
            "Failed to acquire kernel cache lock".to_string(),
        )),
    }
}
451
/// Pre-warm backward kernels in forward cache (trueno#200 Blackwell).
///
/// CONTRACT: All backward kernels must be compiled before GPU work starts.
/// On Blackwell (sm_121), cuModuleLoadData fails during active GPU computation.
///
/// Currently warms the batched RoPE backward kernel for `num_heads`, plus
/// a second variant for `num_kv_heads` when the two differ (GQA).
///
/// # Errors
///
/// - `CudaTensorError::DeviceNotInitialized` if the kernel cache has not been created.
/// - `CudaTensorError::KernelError` if the cache lock cannot be taken or
///   a kernel fails to compile.
#[cfg(feature = "cuda")]
fn pre_warm_backward_kernels_in_forward_cache(
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    max_seq_len: usize,
) -> Result<()> {
    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let target = cache.sm_target.clone();
    let nh = num_heads as u32;
    let nkv = num_kv_heads as u32;
    let hd = head_dim as u32;
    let s = max_seq_len as u32;

    macro_rules! warm {
        ($key:expr, $kernel:expr) => {{
            let ptx = $kernel.emit_ptx_for_target(&target);
            cache.get_or_compile(&$key, &ptx)?;
        }};
    }

    // Batched RoPE backward — missing from pre_warm_for_model, causes
    // CUDA context poisoning on Blackwell when compiled during backward pass.
    // Need BOTH num_heads AND num_kv_heads variants (GQA uses different head count for K/V).
    //
    // FALSIFY-CUDA-ROPE-THETA-CACHE-KEY-001: cache key now includes theta_bits
    // (matching runtime in `batched_rope_neox_backward`). The hardcoded
    // 1_000_000.0 here matches Qwen2 / Qwen2.5 default; for Llama
    // pretrain (theta=10000) the runtime call will compile its own
    // module on first use, no longer silently shadowing the Qwen warm.
    let qwen_theta_bits = 1_000_000.0_f32.to_bits();

    // Query heads always get a kernel; K/V heads only when the count differs.
    let mut head_counts = vec![nh];
    if nkv != nh {
        head_counts.push(nkv);
    }
    for heads in head_counts {
        warm!(
            format!("batched_rope_bwd_{heads}_{hd}_{s}_th{qwen_theta_bits:08x}"),
            BatchedRopeBackwardKernel::new(heads, hd, s, 1_000_000.0)
        );
    }

    eprintln!("  ✓ Backward rope kernel pre-warmed in forward cache");
    Ok(())
}
508
/// Pre-warm LoRA backward GEMM kernels for QLoRA training (ENT-153).
///
/// Must be called BEFORE uploading transformer blocks. Compiles the
/// small-matrix GEMMs needed for LoRA gradient computation by delegating to
/// the global kernel cache.
///
/// # Errors
///
/// - `CudaTensorError::DeviceNotInitialized` if the kernel cache has not been created.
/// - `CudaTensorError::KernelError` if the cache lock cannot be taken or
///   a kernel fails to compile.
#[cfg(feature = "cuda")]
pub fn pre_warm_lora_backward_kernels(
    hidden_size: usize,
    q_dim: usize,
    kv_hidden_size: usize,
    max_seq_len: usize,
    lora_rank: usize,
) -> Result<()> {
    let Some(shared) = FORWARD_KERNEL_CACHE.get() else {
        return Err(CudaTensorError::DeviceNotInitialized);
    };
    match shared.lock() {
        Ok(mut guard) => {
            guard.pre_warm_lora_backward(hidden_size, q_dim, kv_hidden_size, max_seq_len, lora_rank)
        }
        Err(_) => Err(CudaTensorError::KernelError(
            "Failed to acquire kernel cache lock".to_string(),
        )),
    }
}