entrenar/autograd/cuda_forward/
cache.rs1#![allow(unsafe_code)]
2#![allow(trivial_casts)]
3#![allow(clippy::borrow_as_ptr)]
4#![allow(clippy::ref_as_ptr)]
5
6#[cfg(feature = "cuda")]
7use std::collections::HashMap;
8#[cfg(feature = "cuda")]
9use std::sync::{Mutex, OnceLock};
10
11#[cfg(feature = "cuda")]
12use trueno_gpu::driver::{CublasHandle, CudaContext, CudaModule, CudaStream};
13#[cfg(feature = "cuda")]
14use trueno_gpu::kernels::{
15 Batched4DGemmKernel, BatchedRopeBackwardKernel, BatchedSoftmaxKernel,
16 BatchedToInterleavedKernel, BatchedTransposeKernel, BatchedVectorizedRmsNormKernel,
17 ElementwiseMulKernel, FusedSwigluKernel, GemmKernel, InterleavedToBatchedKernel, Kernel,
18 Nf4GemmKernel, Nf4GemmTransposeKernel, ResidualAddKernel, ScaleKernel, SiluKernel,
19};
20
21use crate::autograd::cuda_tensor::{CudaTensorError, Result};
22
/// Process-wide slot holding the forward kernel cache behind a mutex.
///
/// The slot starts empty; `init_forward_kernel_cache` fills it at most once.
/// Every accessor in this file treats an empty slot as
/// `CudaTensorError::DeviceNotInitialized`.
#[cfg(feature = "cuda")]
pub(super) static FORWARD_KERNEL_CACHE: OnceLock<Mutex<ForwardKernelCache>> = OnceLock::new();
26
/// Cache of compiled CUDA modules, keyed by kernel name, plus an optional
/// cuBLAS handle.
///
/// Field order is significant: Rust drops fields in declaration order.
#[cfg(feature = "cuda")]
pub(super) struct ForwardKernelCache {
    /// Context every module in this cache is compiled against.
    ctx: std::sync::Arc<CudaContext>,
    /// Compiled modules, looked up by the name passed to `get_or_compile`.
    modules: HashMap<String, CudaModule>,
    /// SM target string of the device (for example `"sm_70"`); PTX handed to
    /// `get_or_compile` must declare exactly this target.
    sm_target: String,
    /// cuBLAS handle, or `None` when its creation failed in `new`.
    cublas: Option<CublasHandle>,
}
45
#[cfg(feature = "cuda")]
impl ForwardKernelCache {
    /// Builds an empty cache bound to `ctx`.
    ///
    /// The SM target is queried from the context once and stored; if the query
    /// fails the cache falls back to `"sm_70"`. Creating the cuBLAS handle is
    /// best-effort: on failure the error is logged to stderr and the handle is
    /// left as `None`.
    pub(super) fn new(ctx: std::sync::Arc<CudaContext>) -> Self {
        let sm_target = ctx.sm_target().unwrap_or_else(|_| "sm_70".to_string());

        let cublas = match CublasHandle::new_with_tensor_cores(&ctx) {
            Ok(handle) => {
                eprintln!("[CUDA] cuBLAS initialized — forward TF32 tensor cores (41x vs SIMD)");
                Some(handle)
            }
            Err(e) => {
                eprintln!("[CUDA] cuBLAS not available ({e:?}), using PTX GEMMs");
                None
            }
        };

        eprintln!("[CUDA] Kernel cache initialized for target: {sm_target}");
        Self { ctx, modules: HashMap::new(), sm_target, cublas }
    }

    /// Returns the cuBLAS handle, or `None` if it could not be created in [`Self::new`].
    pub(super) fn cublas(&self) -> Option<&CublasHandle> {
        self.cublas.as_ref()
    }

    /// Calls `set_stream(stream)` on the cuBLAS handle when one exists.
    ///
    /// A cache without a cuBLAS handle is not an error: the call is a no-op and
    /// returns `Ok(())`.
    ///
    /// # Errors
    /// `CudaTensorError::KernelError` if the handle rejects the stream.
    pub(super) fn set_cublas_stream(&self, stream: &CudaStream) -> Result<()> {
        if let Some(ref handle) = self.cublas {
            handle.set_stream(stream).map_err(|e| {
                CudaTensorError::KernelError(format!("cuBLAS set_stream failed: {e:?}"))
            })?;
        }
        Ok(())
    }

    /// SM target string that all PTX compiled through this cache must declare.
    pub(super) fn sm_target(&self) -> &str {
        &self.sm_target
    }

    /// Looks up an already-compiled module by `name`; never compiles anything.
    pub(super) fn get_cached(&mut self, name: &str) -> Option<&mut CudaModule> {
        self.modules.get_mut(name)
    }

    /// Returns the module cached under `name`, compiling `ptx` first if the
    /// name is not cached yet.
    ///
    /// Before the lookup, the first line of `ptx` that starts with `.target `
    /// is compared against the cache's SM target. The check runs on every call,
    /// including cache hits. PTX with no such line skips the check.
    ///
    /// # Errors
    /// `CudaTensorError::KernelError` when the PTX target differs from the
    /// device target, when the compute capability cannot be queried, or when
    /// module compilation fails.
    pub(super) fn get_or_compile(&mut self, name: &str, ptx: &str) -> Result<&mut CudaModule> {
        use std::collections::hash_map::Entry;

        if let Some(target_line) = ptx.lines().find(|l| l.starts_with(".target ")) {
            let ptx_target = target_line.trim().trim_start_matches(".target ");
            if ptx_target != self.sm_target {
                return Err(CudaTensorError::KernelError(format!(
                    "F-PTX-001 violated: PTX target '{ptx_target}' != device target '{}'. \
                     Use kernel.emit_ptx_for_target(\"{}\") instead of emit_ptx().",
                    self.sm_target, self.sm_target
                )));
            }
        }

        match self.modules.entry(name.to_string()) {
            Entry::Occupied(e) => Ok(e.into_mut()),
            Entry::Vacant(e) => {
                let (major, _) = self.ctx.compute_capability().map_err(|e| {
                    CudaTensorError::KernelError(format!("compute_capability: {e:?}"))
                })?;
                // Devices reporting compute capability major >= 12 use the
                // `from_ptx_direct` loader; all others use `from_ptx`.
                let module = if major >= 12 {
                    CudaModule::from_ptx_direct(&self.ctx, ptx)
                } else {
                    CudaModule::from_ptx(&self.ctx, ptx)
                }
                .map_err(|err| {
                    CudaTensorError::KernelError(format!("Failed to compile {name}: {err:?}"))
                })?;
                Ok(e.insert(module))
            }
        }
    }

    /// Compiles, ahead of time, the forward kernels for a model of the given
    /// dimensions and stores each one under its own cache key.
    ///
    /// Kernels are emitted for the cache's SM target and compiled through
    /// [`Self::get_or_compile`]. The NF4 kernel families are only warmed when
    /// the relevant dimensions are multiples of 64.
    ///
    /// NOTE(review): all dimensions are narrowed with `as u32` and then
    /// multiplied without overflow checks (`s * h`, `s * i`, `nh * s * s`);
    /// confirm callers keep these products within `u32`.
    ///
    /// # Errors
    /// Returns the first error from [`Self::get_or_compile`]; kernels warmed
    /// before the failure stay cached.
    pub(super) fn pre_warm_for_model(
        &mut self,
        hidden_size: usize,
        intermediate_size: usize,
        num_heads: usize,
        num_kv_heads: usize,
        head_dim: usize,
        max_seq_len: usize,
    ) -> Result<()> {
        let s = max_seq_len as u32;
        let h = hidden_size as u32;
        let q_dim = (num_heads * head_dim) as u32;
        let kv_h = (num_kv_heads * head_dim) as u32;
        let i = intermediate_size as u32;
        let nh = num_heads as u32;
        let hd = head_dim as u32;
        let sh = s * h;
        let si = s * i;
        // Counts `warm!` invocations, not distinct modules: a key that is
        // already cached is served from the cache but still counted.
        let mut count = 0u32;
        let target = self.sm_target.clone();

        // Emit PTX for the device target and cache the module under `$key`.
        // The key must be the one passed in: a fixed name would make every
        // call after the first a cache hit on an unrelated module.
        macro_rules! warm {
            ($key:expr, $kernel:expr) => {{
                let ptx = $kernel.emit_ptx_for_target(&target);
                self.get_or_compile(&$key, &ptx)?;
                count += 1;
            }};
        }

        warm!(format!("batched_rmsnorm_fwd_{h}"), BatchedVectorizedRmsNormKernel::new(h, 1));

        warm!(format!("gemm_forward_{s}_{h}_{h}"), GemmKernel::naive(s, h, h));

        if kv_h != h {
            warm!(format!("gemm_forward_{s}_{h}_{kv_h}"), GemmKernel::naive(s, kv_h, h));
        }

        warm!(format!("gemm_forward_{s}_{h}_{i}"), GemmKernel::naive(s, i, h));

        warm!(format!("gemm_forward_{s}_{i}_{h}"), GemmKernel::naive(s, h, i));

        warm!("fused_swiglu_forward".to_string(), FusedSwigluKernel::new(si));

        warm!("residual_add_forward".to_string(), ResidualAddKernel::new(sh));

        warm!("interleaved_to_batched".to_string(), InterleavedToBatchedKernel::new(s, nh, hd));

        warm!("batched_transpose".to_string(), BatchedTransposeKernel::new(nh, s, hd));

        warm!(
            format!("batched_4d_gemm_1_{nh}_{s}_{s}_{hd}"),
            Batched4DGemmKernel::new(1, nh, s, s, hd)
        );

        let score_n = nh * s * s;
        warm!("scale_forward".to_string(), ScaleKernel::new(score_n));

        let softmax_rows = nh * s;
        warm!("batched_softmax_forward".to_string(), BatchedSoftmaxKernel::new(softmax_rows, s));

        warm!(
            format!("batched_4d_gemm_1_{nh}_{s}_{hd}_{s}"),
            Batched4DGemmKernel::new(1, nh, s, hd, s)
        );

        warm!(
            format!("batched_4d_gemm_1_{nh}_{hd}_{s}_{s}"),
            Batched4DGemmKernel::new(1, nh, hd, s, s)
        );

        warm!("batched_to_interleaved".to_string(), BatchedToInterleavedKernel::new(s, nh, hd));

        warm!("elementwise_mul_forward".to_string(), ElementwiseMulKernel::new(si));

        warm!("silu_forward".to_string(), SiluKernel::new(si));

        if h.is_multiple_of(64) {
            warm!(format!("nf4_gemm_forward_{h}_{q_dim}"), Nf4GemmKernel::new(s, q_dim, h));
            if q_dim != h {
                warm!(format!("nf4_gemm_forward_{q_dim}_{h}"), Nf4GemmKernel::new(s, h, q_dim));
            }
            if kv_h != h && kv_h != q_dim && kv_h.is_multiple_of(64) {
                warm!(format!("nf4_gemm_forward_{h}_{kv_h}"), Nf4GemmKernel::new(s, kv_h, h));
            }
            if i.is_multiple_of(64) {
                warm!(format!("nf4_gemm_forward_{h}_{i}"), Nf4GemmKernel::new(s, i, h));
                warm!(format!("nf4_gemm_forward_{i}_{h}"), Nf4GemmKernel::new(s, h, i));
            }
        }

        if h.is_multiple_of(64) && i.is_multiple_of(64) {
            use trueno_gpu::kernels::FusedNf4GateUpGemmKernel;
            warm!(format!("fused_nf4_gate_up_{h}_{i}"), FusedNf4GateUpGemmKernel::new(s, i, h));
        }
        if h.is_multiple_of(64) && kv_h.is_multiple_of(64) && kv_h != i {
            use trueno_gpu::kernels::FusedNf4GateUpGemmKernel;
            warm!(
                format!("fused_nf4_gate_up_{h}_{kv_h}"),
                FusedNf4GateUpGemmKernel::new(s, kv_h, h)
            );
        }

        if h.is_multiple_of(64) {
            warm!(
                format!("nf4_gemm_transpose_{q_dim}_{h}"),
                Nf4GemmTransposeKernel::new(s, q_dim, h)
            );
            if q_dim != h {
                warm!(
                    format!("nf4_gemm_transpose_{h}_{q_dim}"),
                    Nf4GemmTransposeKernel::new(s, h, q_dim)
                );
            }
            if kv_h != h && kv_h != q_dim && kv_h.is_multiple_of(64) {
                warm!(
                    format!("nf4_gemm_transpose_{kv_h}_{h}"),
                    Nf4GemmTransposeKernel::new(s, kv_h, h)
                );
            }
            if i.is_multiple_of(64) {
                warm!(format!("nf4_gemm_transpose_{i}_{h}"), Nf4GemmTransposeKernel::new(s, i, h));
                warm!(format!("nf4_gemm_transpose_{h}_{i}"), Nf4GemmTransposeKernel::new(s, h, i));
            }
        }

        eprintln!("[CUDA] Pre-warmed {count} forward kernels (JIT compiled before block upload)");
        Ok(())
    }

    /// Compiles, ahead of time, the naive GEMM shapes listed below for a LoRA
    /// adapter of rank `lora_rank`.
    ///
    /// Does nothing and returns `Ok(())` when `lora_rank` is 0. The extra
    /// `kv_hidden_size` shapes are only warmed when `kv_hidden_size` differs
    /// from `q_dim`.
    ///
    /// NOTE(review): dimensions are narrowed with `as u32` without a range check.
    ///
    /// # Errors
    /// Returns the first error from [`Self::get_or_compile`].
    pub(super) fn pre_warm_lora_backward(
        &mut self,
        hidden_size: usize,
        q_dim: usize,
        kv_hidden_size: usize,
        max_seq_len: usize,
        lora_rank: usize,
    ) -> Result<()> {
        if lora_rank == 0 {
            return Ok(());
        }

        let s = max_seq_len as u32;
        let h = hidden_size as u32;
        let r = lora_rank as u32;
        let qd = q_dim as u32;
        let kv = kv_hidden_size as u32;

        let mut count = 0u32;
        let target = self.sm_target.clone();

        // Emit PTX for the device target and cache the module under `$key`.
        macro_rules! warm {
            ($key:expr, $kernel:expr) => {{
                let ptx = $kernel.emit_ptx_for_target(&target);
                self.get_or_compile(&$key, &ptx)?;
                count += 1;
            }};
        }

        warm!(format!("gemm_forward_{s}_{h}_{r}"), GemmKernel::naive(s, r, h));
        warm!(format!("gemm_forward_{s}_{r}_{qd}"), GemmKernel::naive(s, qd, r));
        if kv != qd {
            warm!(format!("gemm_forward_{s}_{r}_{kv}"), GemmKernel::naive(s, kv, r));
        }

        warm!(format!("gemm_forward_{r}_{s}_{qd}"), GemmKernel::naive(r, qd, s));
        if kv != qd {
            warm!(format!("gemm_forward_{r}_{s}_{kv}"), GemmKernel::naive(r, kv, s));
        }

        warm!(format!("gemm_forward_{s}_{qd}_{r}"), GemmKernel::naive(s, r, qd));
        if kv != qd {
            warm!(format!("gemm_forward_{s}_{kv}_{r}"), GemmKernel::naive(s, r, kv));
        }

        warm!(format!("gemm_forward_{h}_{s}_{r}"), GemmKernel::naive(h, r, s));

        warm!(format!("gemm_forward_{s}_{r}_{h}"), GemmKernel::naive(s, h, r));

        eprintln!("[CUDA] Pre-warmed {count} LoRA backward kernels");
        Ok(())
    }
}
399
/// Installs the global forward kernel cache for `ctx`.
///
/// Only the first call builds a cache. On later calls the constructor closure
/// is never run, the existing cache is kept, and `ctx` is dropped unused.
/// As written, this function always returns `Ok(())`.
#[cfg(feature = "cuda")]
pub fn init_forward_kernel_cache(ctx: std::sync::Arc<CudaContext>) -> Result<()> {
    let build_cache = move || Mutex::new(ForwardKernelCache::new(ctx));
    let _installed = FORWARD_KERNEL_CACHE.get_or_init(build_cache);
    Ok(())
}
/// Forwards `ptr` and `size` to `set_workspace` on the cached cuBLAS handle.
///
/// When the cache has no cuBLAS handle the call is a no-op and returns `Ok(())`.
/// NOTE(review): `ptr` is presumably a device address and `size` a byte count;
/// confirm against `CublasHandle::set_workspace`.
///
/// # Errors
/// * `CudaTensorError::DeviceNotInitialized` if the global cache has not been
///   initialized.
/// * `CudaTensorError::KernelError` if the cache mutex is poisoned or the
///   handle rejects the workspace.
#[cfg(feature = "cuda")]
pub fn set_cublas_workspace(ptr: u64, size: usize) -> Result<()> {
    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    // Same lock-failure message as the other accessors in this file, so a
    // poisoned cache is reported consistently.
    let cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;
    if let Some(handle) = cache.cublas() {
        handle.set_workspace(ptr, size).map_err(|e| {
            CudaTensorError::KernelError(format!("cuBLAS set_workspace failed: {e}"))
        })?;
    }
    Ok(())
}
/// Passes `stream` to `set_cublas_stream` on the global forward kernel cache.
///
/// # Errors
/// * `CudaTensorError::DeviceNotInitialized` if the global cache has not been
///   initialized.
/// * `CudaTensorError::KernelError` if the cache mutex cannot be locked, or
///   whatever `set_cublas_stream` reports.
#[cfg(feature = "cuda")]
pub fn set_forward_cublas_stream(stream: &CudaStream) -> Result<()> {
    let Some(shared) = FORWARD_KERNEL_CACHE.get() else {
        return Err(CudaTensorError::DeviceNotInitialized);
    };
    let Ok(guard) = shared.lock() else {
        return Err(CudaTensorError::KernelError(
            "Failed to acquire kernel cache lock".to_string(),
        ));
    };
    guard.set_cublas_stream(stream)
}
425
/// Pre-compiles kernels for a model of the given dimensions into the global
/// forward kernel cache.
///
/// Runs `pre_warm_backward_kernels_in_forward_cache` first, then
/// `ForwardKernelCache::pre_warm_for_model`.
///
/// # Errors
/// * `CudaTensorError::DeviceNotInitialized` if the global cache has not been
///   initialized.
/// * `CudaTensorError::KernelError` if the cache mutex cannot be locked, or
///   any error reported by either pre-warm step.
#[cfg(feature = "cuda")]
pub fn pre_warm_forward_kernels(
    hidden_size: usize,
    intermediate_size: usize,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    max_seq_len: usize,
) -> Result<()> {
    // This helper locks the cache mutex itself, so it has to finish before the
    // guard below is taken; `std::sync::Mutex` must not be locked twice on the
    // same thread.
    pre_warm_backward_kernels_in_forward_cache(num_heads, num_kv_heads, head_dim, max_seq_len)?;

    let Some(shared) = FORWARD_KERNEL_CACHE.get() else {
        return Err(CudaTensorError::DeviceNotInitialized);
    };
    let Ok(mut guard) = shared.lock() else {
        return Err(CudaTensorError::KernelError(
            "Failed to acquire kernel cache lock".to_string(),
        ));
    };

    guard.pre_warm_for_model(
        hidden_size,
        intermediate_size,
        num_heads,
        num_kv_heads,
        head_dim,
        max_seq_len,
    )
}
451
/// Pre-compiles the batched RoPE backward kernel into the global forward
/// kernel cache.
///
/// One kernel is compiled for `num_heads`. A second one is compiled for
/// `num_kv_heads` only when the two head counts differ. Both are built with
/// the constant `1_000_000.0` as the last constructor argument.
/// NOTE(review): that constant is presumably the RoPE theta base; confirm
/// against `BatchedRopeBackwardKernel::new`.
///
/// NOTE(review): dimensions are narrowed with `as u32` without a range check.
///
/// # Errors
/// * `CudaTensorError::DeviceNotInitialized` if the global cache has not been
///   initialized.
/// * `CudaTensorError::KernelError` if the cache mutex cannot be locked or a
///   kernel fails to compile.
#[cfg(feature = "cuda")]
fn pre_warm_backward_kernels_in_forward_cache(
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    max_seq_len: usize,
) -> Result<()> {
    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let target = cache.sm_target.clone();
    let nh = num_heads as u32;
    let nkv = num_kv_heads as u32;
    let hd = head_dim as u32;
    let s = max_seq_len as u32;

    // Emit PTX for the device target and cache the module under `$key`.
    macro_rules! warm {
        ($key:expr, $kernel:expr) => {{
            let ptx = $kernel.emit_ptx_for_target(&target);
            cache.get_or_compile(&$key, &ptx)?;
        }};
    }

    warm!(
        format!("batched_rope_bwd_{nh}_{hd}"),
        BatchedRopeBackwardKernel::new(nh, hd, s, 1_000_000.0)
    );
    if nkv != nh {
        warm!(
            format!("batched_rope_bwd_{nkv}_{hd}"),
            BatchedRopeBackwardKernel::new(nkv, hd, s, 1_000_000.0)
        );
    }

    eprintln!(" ✓ Backward rope kernel pre-warmed in forward cache");
    Ok(())
}
501
/// Pre-compiles the LoRA GEMM kernels into the global forward kernel cache by
/// delegating to `ForwardKernelCache::pre_warm_lora_backward`.
///
/// # Errors
/// * `CudaTensorError::DeviceNotInitialized` if the global cache has not been
///   initialized.
/// * `CudaTensorError::KernelError` if the cache mutex cannot be locked, or
///   any error reported by the delegated pre-warm.
#[cfg(feature = "cuda")]
pub fn pre_warm_lora_backward_kernels(
    hidden_size: usize,
    q_dim: usize,
    kv_hidden_size: usize,
    max_seq_len: usize,
    lora_rank: usize,
) -> Result<()> {
    let Some(shared) = FORWARD_KERNEL_CACHE.get() else {
        return Err(CudaTensorError::DeviceNotInitialized);
    };
    let Ok(mut guard) = shared.lock() else {
        return Err(CudaTensorError::KernelError(
            "Failed to acquire kernel cache lock".to_string(),
        ));
    };
    guard.pre_warm_lora_backward(hidden_size, q_dim, kv_hidden_size, max_seq_len, lora_rank)
}