1#![allow(unsafe_code)]
2#![allow(trivial_casts)]
3#![allow(clippy::borrow_as_ptr)]
4#![allow(clippy::ref_as_ptr)]
5
6#[cfg(feature = "cuda")]
7use trueno_gpu::driver::{CublasHandle, CudaStream, GemmOp, GpuBuffer, LaunchConfig};
8#[cfg(feature = "cuda")]
9use trueno_gpu::kernels::{
10 Batched4DGemmKernel, FusedSwigluKernel, GemmKernel, Kernel, Nf4GemmKernel,
11 Nf4GemmTransposeKernel, Nf4TensorCoreGemmKernel,
12};
13
14use crate::autograd::cuda_tensor::{CudaTensorError, Result};
15
16#[cfg(feature = "cuda")]
17use super::cache::FORWARD_KERNEL_CACHE;
18
19#[cfg(feature = "cuda")]
24pub fn fused_swiglu_forward(
25 gate: &GpuBuffer<f32>,
26 up: &GpuBuffer<f32>,
27 output: &mut GpuBuffer<f32>,
28 n: u32,
29 stream: &CudaStream,
30) -> Result<()> {
31 let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
32 let mut cache = cache.lock().map_err(|_err| {
33 CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
34 })?;
35
36 let key = "fused_swiglu_forward".to_string(); let module = match cache.get_cached(&key) {
38 Some(m) => m,
39 None => {
40 let kernel = FusedSwigluKernel::new(n);
41 let ptx = kernel.emit_ptx_for_target(cache.sm_target());
42 cache.get_or_compile(&key, &ptx)?
43 }
44 };
45
46 let config = LaunchConfig { grid: (n.div_ceil(256), 1, 1), block: (256, 1, 1), shared_mem: 0 };
47
48 let gate_ptr = gate.as_ptr();
49 let up_ptr = up.as_ptr();
50 let output_ptr = output.as_ptr();
51
52 let mut args: [*mut std::ffi::c_void; 4] = [
53 &gate_ptr as *const _ as *mut _,
54 &up_ptr as *const _ as *mut _,
55 &output_ptr as *const _ as *mut _,
56 &n as *const _ as *mut _,
57 ];
58
59 unsafe {
62 stream.launch_kernel(module, "fused_swiglu", &config, &mut args).map_err(|e| {
63 CudaTensorError::KernelError(format!("Fused SwiGLU forward launch failed: {e:?}"))
64 })?;
65 }
66
67 Ok(())
68}
69
70#[cfg(feature = "cuda")]
77pub fn gemm_forward(
78 a: &GpuBuffer<f32>,
79 b: &GpuBuffer<f32>,
80 c: &mut GpuBuffer<f32>,
81 m: u32,
82 k: u32,
83 n: u32,
84 stream: &CudaStream,
85) -> Result<()> {
86 let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
87 let mut cache = cache.lock().map_err(|_err| {
88 CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
89 })?;
90 if let Some(cublas) = cache.cublas() {
91 return cublas_gemm_forward(cublas, a, b, c, m, k, n);
92 }
93
94 let key = format!("gemm_forward_{m}_{k}_{n}");
96 let module = match cache.get_cached(&key) {
97 Some(m) => m,
98 None => {
99 let kernel = GemmKernel::naive(m, n, k);
100 let ptx = kernel.emit_ptx_for_target(cache.sm_target());
101 cache.get_or_compile(&key, &ptx)?
102 }
103 };
104
105 let config = LaunchConfig {
109 grid: (n.div_ceil(16), m.div_ceil(16), 1),
110 block: (16, 16, 1),
111 shared_mem: 0,
112 };
113
114 let a_ptr = a.as_ptr();
115 let b_ptr = b.as_ptr();
116 let c_ptr = c.as_ptr();
117
118 let mut args: [*mut std::ffi::c_void; 6] = [
121 &a_ptr as *const _ as *mut _,
122 &b_ptr as *const _ as *mut _,
123 &c_ptr as *const _ as *mut _,
124 &m as *const _ as *mut _,
125 &n as *const _ as *mut _,
126 &k as *const _ as *mut _,
127 ];
128
129 unsafe {
132 stream.launch_kernel(module, "gemm_naive", &config, &mut args).map_err(|e| {
133 CudaTensorError::KernelError(format!("GEMM forward launch failed: {e:?}"))
134 })?;
135 }
136
137 Ok(())
138}
139
140#[cfg(feature = "cuda")]
143pub fn gemm_forward_bt(
144 a: &GpuBuffer<f32>,
145 b: &GpuBuffer<f32>,
146 c: &mut GpuBuffer<f32>,
147 m: u32,
148 k: u32,
149 n: u32,
150 _stream: &CudaStream,
151) -> Result<()> {
152 let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
153 let cache = cache.lock().map_err(|_| CudaTensorError::KernelError("cache lock".to_string()))?;
154 if let Some(cublas) = cache.cublas() {
155 return cublas_gemm_forward_bt(cublas, a, b, c, m, k, n);
156 }
157 Err(CudaTensorError::KernelError("gemm_forward_bt requires cuBLAS".to_string()))
158}
159
160#[cfg(feature = "cuda")]
161fn cublas_gemm_forward_bt(
162 cublas: &CublasHandle,
163 a: &GpuBuffer<f32>,
164 b: &GpuBuffer<f32>,
165 c: &mut GpuBuffer<f32>,
166 m: u32,
167 k: u32,
168 n: u32,
169) -> Result<()> {
170 cublas
173 .gemm_f32(
174 GemmOp::Trans, GemmOp::NoTrans, n as i32,
177 m as i32,
178 k as i32,
179 1.0,
180 b.as_ptr(),
181 k as i32, a.as_ptr(),
183 k as i32, 0.0,
185 c.as_ptr(),
186 n as i32, )
188 .map_err(|e| CudaTensorError::KernelError(format!("cuBLAS GEMM BT failed: {e:?}")))
189}
190
191#[cfg(feature = "cuda")]
193fn cublas_gemm_forward(
194 cublas: &CublasHandle,
195 a: &GpuBuffer<f32>,
196 b: &GpuBuffer<f32>,
197 c: &mut GpuBuffer<f32>,
198 m: u32,
199 k: u32,
200 n: u32,
201) -> Result<()> {
202 cublas
203 .gemm_f32(
204 GemmOp::NoTrans,
205 GemmOp::NoTrans,
206 n as i32,
207 m as i32,
208 k as i32,
209 1.0,
210 b.as_ptr(),
211 n as i32,
212 a.as_ptr(),
213 k as i32,
214 0.0,
215 c.as_ptr(),
216 n as i32,
217 )
218 .map_err(|e| CudaTensorError::KernelError(format!("cuBLAS GEMM forward failed: {e:?}")))
219}
220
221#[cfg(feature = "cuda")]
223pub(crate) fn cublas_gemm_backward_a(
224 cublas: &CublasHandle,
225 grad_output: &GpuBuffer<f32>,
226 b: &GpuBuffer<f32>,
227 grad_a: &mut GpuBuffer<f32>,
228 m: u32,
229 k: u32,
230 n: u32,
231) -> Result<()> {
232 cublas
233 .gemm_f32(
234 GemmOp::Trans,
235 GemmOp::NoTrans,
236 k as i32,
237 m as i32,
238 n as i32,
239 1.0,
240 b.as_ptr(),
241 n as i32,
242 grad_output.as_ptr(),
243 n as i32,
244 0.0,
245 grad_a.as_ptr(),
246 k as i32,
247 )
248 .map_err(|e| CudaTensorError::KernelError(format!("cuBLAS GEMM backward_a failed: {e:?}")))
249}
250
251#[cfg(feature = "cuda")]
257pub(crate) fn cublas_gemm_backward_a_accumulate(
258 cublas: &CublasHandle,
259 grad_output: &GpuBuffer<f32>,
260 b: &GpuBuffer<f32>,
261 grad_a: &mut GpuBuffer<f32>,
262 m: u32,
263 k: u32,
264 n: u32,
265) -> Result<()> {
266 cublas
267 .gemm_f32(
268 GemmOp::Trans,
269 GemmOp::NoTrans,
270 k as i32,
271 m as i32,
272 n as i32,
273 1.0,
274 b.as_ptr(),
275 n as i32,
276 grad_output.as_ptr(),
277 n as i32,
278 1.0, grad_a.as_ptr(),
280 k as i32,
281 )
282 .map_err(|e| {
283 CudaTensorError::KernelError(format!("cuBLAS GEMM backward_a accumulate failed: {e:?}"))
284 })
285}
286
287#[cfg(feature = "cuda")]
289pub(crate) fn cublas_gemm_backward_b(
290 cublas: &CublasHandle,
291 a: &GpuBuffer<f32>,
292 grad_output: &GpuBuffer<f32>,
293 grad_b: &mut GpuBuffer<f32>,
294 m: u32,
295 k: u32,
296 n: u32,
297) -> Result<()> {
298 cublas
299 .gemm_f32(
300 GemmOp::NoTrans,
301 GemmOp::Trans,
302 n as i32,
303 k as i32,
304 m as i32,
305 1.0,
306 grad_output.as_ptr(),
307 n as i32,
308 a.as_ptr(),
309 k as i32,
310 0.0,
311 grad_b.as_ptr(),
312 n as i32,
313 )
314 .map_err(|e| CudaTensorError::KernelError(format!("cuBLAS GEMM backward_b failed: {e:?}")))
315}
316
317#[cfg(feature = "cuda")]
329pub fn batched_4d_gemm_forward(
330 a: &GpuBuffer<f32>,
331 b: &GpuBuffer<f32>,
332 c: &mut GpuBuffer<f32>,
333 batch: u32,
334 heads: u32,
335 m: u32,
336 n: u32,
337 k: u32,
338 stream: &CudaStream,
339) -> Result<()> {
340 let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
341 let mut cache = cache.lock().map_err(|_err| {
342 CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
343 })?;
344
345 if let Some(cublas) = cache.cublas() {
347 let batch_count = (batch * heads) as i32;
348 let stride_a = i64::from(m) * i64::from(k);
349 let stride_b = i64::from(k) * i64::from(n);
350 let stride_c = i64::from(m) * i64::from(n);
351 return cublas
352 .gemm_f32_strided_batched_row_major(
353 m as i32,
354 n as i32,
355 k as i32,
356 1.0,
357 a.as_ptr(),
358 stride_a,
359 b.as_ptr(),
360 stride_b,
361 0.0,
362 c.as_ptr(),
363 stride_c,
364 batch_count,
365 )
366 .map_err(|e| {
367 CudaTensorError::KernelError(format!("cuBLAS batched 4D GEMM failed: {e:?}"))
368 });
369 }
370
371 let kernel = Batched4DGemmKernel::new(batch, heads, m, n, k);
372 let tile_size = kernel.config.tile_size;
373
374 let key = format!("batched_4d_gemm_{batch}_{heads}_{m}_{n}_{k}");
375 let module = match cache.get_cached(&key) {
376 Some(m) => m,
377 None => {
378 let ptx = kernel.emit_ptx_for_target(cache.sm_target());
379 cache.get_or_compile(&key, &ptx)?
380 }
381 };
382
383 let config = LaunchConfig {
387 grid: (n.div_ceil(tile_size), m.div_ceil(tile_size), batch * heads),
388 block: (tile_size, tile_size, 1),
389 shared_mem: tile_size * tile_size * 4 * 2,
390 };
391
392 let a_ptr = a.as_ptr();
393 let b_ptr = b.as_ptr();
394 let c_ptr = c.as_ptr();
395
396 let mut args: [*mut std::ffi::c_void; 8] = [
398 &a_ptr as *const _ as *mut _,
399 &b_ptr as *const _ as *mut _,
400 &c_ptr as *const _ as *mut _,
401 &batch as *const _ as *mut _,
402 &heads as *const _ as *mut _,
403 &m as *const _ as *mut _,
404 &n as *const _ as *mut _,
405 &k as *const _ as *mut _,
406 ];
407
408 unsafe {
411 stream.launch_kernel(module, "batched_4d_gemm", &config, &mut args).map_err(|e| {
412 CudaTensorError::KernelError(format!("Batched 4D GEMM forward launch failed: {e:?}"))
413 })?;
414 }
415
416 Ok(())
417}
418
419#[cfg(feature = "cuda")]
433pub fn gemm_nf4_forward(
434 a: &GpuBuffer<f32>,
435 b_nf4: &GpuBuffer<u8>,
436 b_scales: &GpuBuffer<f32>,
437 c: &mut GpuBuffer<f32>,
438 m: u32,
439 k: u32,
440 n: u32,
441 stream: &CudaStream,
442) -> Result<()> {
443 let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
444 let mut cache = cache.lock().map_err(|_err| {
445 CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
446 })?;
447
448 let kernel = Nf4GemmKernel::new(m, n, k);
449 let tile_size = kernel.tile_size;
450
451 let key = format!("nf4_gemm_forward_{k}_{n}");
456 let module = match cache.get_cached(&key) {
457 Some(m) => m,
458 None => {
459 let ptx = kernel.emit_ptx_for_target(cache.sm_target());
460 cache.get_or_compile(&key, &ptx)?
461 }
462 };
463
464 let config = LaunchConfig {
466 grid: (n.div_ceil(tile_size), m.div_ceil(tile_size), 1),
467 block: (tile_size * tile_size, 1, 1),
468 shared_mem: 16 * 4, };
470
471 let a_ptr = a.as_ptr();
472 let b_nf4_ptr = b_nf4.as_ptr();
473 let b_scales_ptr = b_scales.as_ptr();
474 let c_ptr = c.as_ptr();
475
476 let mut args: [*mut std::ffi::c_void; 7] = [
479 &a_ptr as *const _ as *mut _,
480 &b_nf4_ptr as *const _ as *mut _,
481 &b_scales_ptr as *const _ as *mut _,
482 &c_ptr as *const _ as *mut _,
483 &m as *const _ as *mut _,
484 &n as *const _ as *mut _,
485 &k as *const _ as *mut _,
486 ];
487
488 unsafe {
491 stream.launch_kernel(module, "nf4_gemm_fused", &config, &mut args).map_err(|e| {
492 CudaTensorError::KernelError(format!("NF4 GEMM forward launch failed: {e:?}"))
493 })?;
494 }
495
496 Ok(())
497}
498
499#[cfg(feature = "cuda")]
506pub fn gemm_nf4_tc_forward(
507 a: &GpuBuffer<f32>,
508 b_nf4: &GpuBuffer<u8>,
509 b_scales: &GpuBuffer<f32>,
510 c: &mut GpuBuffer<f32>,
511 m: u32,
512 k: u32,
513 n: u32,
514 stream: &CudaStream,
515) -> Result<()> {
516 let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
517 let mut cache = cache.lock().map_err(|_err| {
518 CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
519 })?;
520
521 let kernel = Nf4TensorCoreGemmKernel::new(m, n, k);
522
523 let key = format!("nf4_tc_gemm_forward_{k}_{n}");
524 let module = match cache.get_cached(&key) {
525 Some(m) => m,
526 None => {
527 let ptx = kernel.emit_ptx_for_target(cache.sm_target());
528 cache.get_or_compile(&key, &ptx)?
529 }
530 };
531
532 let config = LaunchConfig {
534 grid: (n.div_ceil(16), m.div_ceil(16), 1),
535 block: (32, 1, 1),
536 shared_mem: 16 * 16 * 2 * 2, };
538
539 let a_ptr = a.as_ptr();
540 let b_nf4_ptr = b_nf4.as_ptr();
541 let b_scales_ptr = b_scales.as_ptr();
542 let c_ptr = c.as_ptr();
543
544 let mut args: [*mut std::ffi::c_void; 7] = [
546 &a_ptr as *const _ as *mut _,
547 &b_scales_ptr as *const _ as *mut _,
548 &b_nf4_ptr as *const _ as *mut _,
549 &c_ptr as *const _ as *mut _,
550 &m as *const _ as *mut _,
551 &n as *const _ as *mut _,
552 &k as *const _ as *mut _,
553 ];
554
555 unsafe {
556 stream.launch_kernel(module, "nf4_tensor_core_gemm", &config, &mut args).map_err(|e| {
557 CudaTensorError::KernelError(format!(
558 "NF4 tensor core GEMM forward launch failed: {e:?}"
559 ))
560 })?;
561 }
562
563 Ok(())
564}
565
566pub fn gemm_nf4_gate_up_forward(
574 a: &GpuBuffer<f32>,
575 wg_nf4: &GpuBuffer<u8>,
576 wg_scales: &GpuBuffer<f32>,
577 wu_nf4: &GpuBuffer<u8>,
578 wu_scales: &GpuBuffer<f32>,
579 gate: &mut GpuBuffer<f32>,
580 up: &mut GpuBuffer<f32>,
581 m: u32,
582 k: u32,
583 n: u32,
584 stream: &CudaStream,
585) -> Result<()> {
586 use trueno_gpu::kernels::FusedNf4GateUpGemmKernel;
587
588 let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
589 let mut cache = cache.lock().map_err(|_err| {
590 CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
591 })?;
592
593 let kernel = FusedNf4GateUpGemmKernel::new(m, n, k);
594 let tile = kernel.tile_size;
595 let key = format!("fused_nf4_gate_up_{k}_{n}");
596 let module = match cache.get_cached(&key) {
597 Some(m) => m,
598 None => {
599 let ptx = kernel.emit_ptx_for_target(cache.sm_target());
600 cache.get_or_compile(&key, &ptx)?
601 }
602 };
603
604 let config = LaunchConfig {
605 grid: (n.div_ceil(tile), m.div_ceil(tile), 1),
606 block: (tile * tile, 1, 1),
607 shared_mem: 16 * 4,
608 };
609
610 let a_ptr = a.as_ptr();
611 let gate_ptr = gate.as_ptr();
612 let up_ptr = up.as_ptr();
613 let wg_nf4_ptr = wg_nf4.as_ptr();
614 let wg_scales_ptr = wg_scales.as_ptr();
615 let wu_nf4_ptr = wu_nf4.as_ptr();
616 let wu_scales_ptr = wu_scales.as_ptr();
617
618 let mut args: [*mut std::ffi::c_void; 10] = [
619 &gate_ptr as *const _ as *mut _,
620 &up_ptr as *const _ as *mut _,
621 &a_ptr as *const _ as *mut _,
622 &wg_scales_ptr as *const _ as *mut _,
623 &wg_nf4_ptr as *const _ as *mut _,
624 &wu_scales_ptr as *const _ as *mut _,
625 &wu_nf4_ptr as *const _ as *mut _,
626 &m as *const _ as *mut _,
627 &n as *const _ as *mut _,
628 &k as *const _ as *mut _,
629 ];
630
631 unsafe {
632 stream.launch_kernel(module, "fused_nf4_gate_up_gemm", &config, &mut args).map_err(
633 |e| CudaTensorError::KernelError(format!("Fused NF4 gate+up launch: {e:?}")),
634 )?;
635 }
636
637 Ok(())
638}
639
640#[cfg(feature = "cuda")]
658pub fn gemm_forward_bf16(
659 a: &GpuBuffer<f32>,
660 b: &GpuBuffer<f32>,
661 c: &mut GpuBuffer<f32>,
662 m: u32,
663 k: u32,
664 n: u32,
665 stream: &CudaStream,
666) -> Result<()> {
667 let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
668 let mut cache = cache.lock().map_err(|_err| {
669 CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
670 })?;
671
672 let key = format!("gemm_bf16_compute_{m}_{k}_{n}");
673 let module = match cache.get_cached(&key) {
674 Some(m) => m,
675 None => {
676 let ptx = build_gemm_bf16_compute_ptx(cache.sm_target());
677 cache.get_or_compile(&key, &ptx)?
678 }
679 };
680
681 let config = LaunchConfig {
682 grid: (n.div_ceil(16), m.div_ceil(16), 1),
683 block: (16, 16, 1),
684 shared_mem: 0,
685 };
686
687 let a_ptr = a.as_ptr();
688 let b_ptr = b.as_ptr();
689 let c_ptr = c.as_ptr();
690
691 let mut args: [*mut std::ffi::c_void; 6] = [
694 &a_ptr as *const _ as *mut _,
695 &b_ptr as *const _ as *mut _,
696 &c_ptr as *const _ as *mut _,
697 &m as *const _ as *mut _,
698 &n as *const _ as *mut _,
699 &k as *const _ as *mut _,
700 ];
701
702 unsafe {
705 stream.launch_kernel(module, "gemm_bf16_compute", &config, &mut args).map_err(|e| {
706 CudaTensorError::KernelError(format!("BF16 GEMM forward launch failed: {e:?}"))
707 })?;
708 }
709
710 Ok(())
711}
712
713#[cfg(feature = "cuda")]
720fn build_gemm_bf16_compute_ptx(sm_target: &str) -> String {
721 format!(
722 r".version 7.0
723.target {sm_target}
724.address_size 64
725
726.visible .entry gemm_bf16_compute(
727 .param .u64 a_ptr,
728 .param .u64 b_ptr,
729 .param .u64 c_ptr,
730 .param .u32 M,
731 .param .u32 N,
732 .param .u32 K
733) {{
734 .reg .u32 %r<20>;
735 .reg .u64 %rd<8>;
736 .reg .f32 %f<4>;
737 .reg .pred %p<4>;
738
739 // col = ctaid.x * 16 + tid.x
740 mov.u32 %r0, %ctaid.x;
741 mov.u32 %r1, %ntid.x;
742 mov.u32 %r2, %tid.x;
743 mad.lo.u32 %r3, %r0, %r1, %r2;
744
745 // row = ctaid.y * 16 + tid.y
746 mov.u32 %r4, %ctaid.y;
747 mov.u32 %r5, %ntid.y;
748 mov.u32 %r6, %tid.y;
749 mad.lo.u32 %r7, %r4, %r5, %r6;
750
751 // Load params
752 ld.param.u64 %rd0, [a_ptr];
753 ld.param.u64 %rd1, [b_ptr];
754 ld.param.u64 %rd2, [c_ptr];
755 ld.param.u32 %r8, [M];
756 ld.param.u32 %r9, [N];
757 ld.param.u32 %r10, [K];
758
759 // Bounds check: row < M && col < N
760 setp.ge.u32 %p0, %r7, %r8;
761 setp.ge.u32 %p1, %r3, %r9;
762 or.pred %p2, %p0, %p1;
763 @%p2 bra exit;
764
765 // acc = 0.0f
766 mov.f32 %f0, 0f00000000;
767
768 // Loop: for i = 0; i < K; i++
769 mov.u32 %r11, 0;
770loop_start:
771 setp.ge.u32 %p3, %r11, %r10;
772 @%p3 bra loop_end;
773
774 // Load A[row, i] as u32 bits, truncate to bf16 precision
775 mul.lo.u32 %r12, %r7, %r10;
776 add.u32 %r12, %r12, %r11;
777 mul.wide.u32 %rd3, %r12, 4;
778 add.u64 %rd3, %rd0, %rd3;
779 ld.global.u32 %r13, [%rd3];
780 and.b32 %r13, %r13, 0xFFFF0000;
781 mov.b32 %f1, %r13;
782
783 // Load B[i, col] as u32 bits, truncate to bf16 precision
784 mul.lo.u32 %r14, %r11, %r9;
785 add.u32 %r14, %r14, %r3;
786 mul.wide.u32 %rd4, %r14, 4;
787 add.u64 %rd4, %rd1, %rd4;
788 ld.global.u32 %r15, [%rd4];
789 and.b32 %r15, %r15, 0xFFFF0000;
790 mov.b32 %f2, %r15;
791
792 // acc += a_bf16 * b_bf16 (FMA in f32 accumulator)
793 fma.rn.f32 %f0, %f1, %f2, %f0;
794
795 add.u32 %r11, %r11, 1;
796 bra loop_start;
797
798loop_end:
799 // Store C[row, col]
800 mul.lo.u32 %r16, %r7, %r9;
801 add.u32 %r16, %r16, %r3;
802 mul.wide.u32 %rd5, %r16, 4;
803 add.u64 %rd5, %rd2, %rd5;
804 st.global.f32 [%rd5], %f0;
805
806exit:
807 ret;
808}}
809"
810 )
811}
812
813#[cfg(feature = "cuda")]
844pub fn gemm_nf4_dequant_cublas(
845 a: &GpuBuffer<f32>,
846 w: &GpuBuffer<f32>,
847 c: &mut GpuBuffer<f32>,
848 m: u32,
849 k: u32,
850 n: u32,
851 stream: &CudaStream,
852) -> Result<()> {
853 let _ = stream; let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
856 let cache = cache.lock().map_err(|_err| {
857 CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
858 })?;
859
860 let cublas = cache.cublas().ok_or_else(|| {
861 CudaTensorError::KernelError("cuBLAS not available for NF4 dequant GEMM".to_string())
862 })?;
863
864 cublas
870 .gemm_f32(
871 GemmOp::Trans, GemmOp::NoTrans, n as i32, m as i32, k as i32, 1.0,
877 w.as_ptr(), k as i32, a.as_ptr(), k as i32, 0.0,
882 c.as_ptr(), n as i32, )
885 .map_err(|e| {
886 CudaTensorError::KernelError(format!("cuBLAS NF4 dequant forward failed: {e:?}"))
887 })
888}
889
890#[cfg(feature = "cuda")]
917pub fn gemm_nf4_backward_a_cublas(
918 grad_output: &GpuBuffer<f32>,
919 w: &GpuBuffer<f32>,
920 grad_input: &mut GpuBuffer<f32>,
921 m: u32,
922 k: u32,
923 n: u32,
924 stream: &CudaStream,
925) -> Result<()> {
926 let _ = stream; let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
929 let cache = cache.lock().map_err(|_err| {
930 CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
931 })?;
932
933 let cublas = cache.cublas().ok_or_else(|| {
934 CudaTensorError::KernelError("cuBLAS not available for NF4 backward GEMM".to_string())
935 })?;
936
937 cublas
940 .gemm_f32(
941 GemmOp::NoTrans, GemmOp::NoTrans, k as i32, m as i32, n as i32, 1.0,
947 w.as_ptr(), k as i32, grad_output.as_ptr(), n as i32, 0.0,
952 grad_input.as_ptr(), k as i32, )
955 .map_err(|e| CudaTensorError::KernelError(format!("cuBLAS NF4 backward_a failed: {e:?}")))
956}
957
958#[cfg(feature = "cuda")]
979pub fn gemm_nf4_backward_a(
980 grad_output: &GpuBuffer<f32>,
981 w_nf4: &GpuBuffer<u8>,
982 w_scales: &GpuBuffer<f32>,
983 grad_input: &mut GpuBuffer<f32>,
984 m: u32,
985 n: u32,
986 k: u32,
987 stream: &CudaStream,
988) -> Result<()> {
989 let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
990 let mut cache = cache.lock().map_err(|_err| {
991 CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
992 })?;
993
994 let kernel = Nf4GemmTransposeKernel::new(m, n, k);
995 let tile_size = kernel.tile_size;
996
997 let key = format!("nf4_gemm_transpose_{n}_{k}");
999 let module = match cache.get_cached(&key) {
1000 Some(m) => m,
1001 None => {
1002 let ptx = kernel.emit_ptx_for_target(cache.sm_target());
1003 cache.get_or_compile(&key, &ptx)?
1004 }
1005 };
1006
1007 let config = LaunchConfig {
1009 grid: (k.div_ceil(tile_size), m.div_ceil(tile_size), 1),
1010 block: (tile_size * tile_size, 1, 1),
1011 shared_mem: 16 * 4, };
1013
1014 let a_ptr = grad_output.as_ptr();
1015 let b_nf4_ptr = w_nf4.as_ptr();
1016 let b_scales_ptr = w_scales.as_ptr();
1017 let c_ptr = grad_input.as_ptr();
1018
1019 let mut args: [*mut std::ffi::c_void; 7] = [
1020 &a_ptr as *const _ as *mut _,
1021 &b_nf4_ptr as *const _ as *mut _,
1022 &b_scales_ptr as *const _ as *mut _,
1023 &c_ptr as *const _ as *mut _,
1024 &m as *const _ as *mut _,
1025 &n as *const _ as *mut _,
1026 &k as *const _ as *mut _,
1027 ];
1028
1029 unsafe {
1031 stream.launch_kernel(module, "nf4_gemm_transpose", &config, &mut args).map_err(|e| {
1032 CudaTensorError::KernelError(format!("NF4 GEMM transpose launch failed: {e:?}"))
1033 })?;
1034 }
1035
1036 Ok(())
1037}
1038
1039#[cfg(feature = "cuda")]
1048pub fn gemm_nf4_tc_backward_a(
1049 grad_output: &GpuBuffer<f32>,
1050 w_nf4: &GpuBuffer<u8>,
1051 w_scales: &GpuBuffer<f32>,
1052 grad_input: &mut GpuBuffer<f32>,
1053 m: u32,
1054 n: u32,
1055 k: u32,
1056 stream: &CudaStream,
1057) -> Result<()> {
1058 use trueno_gpu::kernels::backward::Nf4TensorCoreGemmBackwardAKernel;
1059
1060 let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
1061 let mut cache = cache.lock().map_err(|_err| {
1062 CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
1063 })?;
1064
1065 let kernel = Nf4TensorCoreGemmBackwardAKernel::new(m, n, k);
1066
1067 let key = format!("nf4_tc_gemm_backward_a_{n}_{k}");
1069 let module = match cache.get_cached(&key) {
1070 Some(m) => m,
1071 None => {
1072 let ptx = kernel.emit_ptx_for_target(cache.sm_target());
1073 cache.get_or_compile(&key, &ptx)?
1074 }
1075 };
1076
1077 let config = LaunchConfig {
1079 grid: (k.div_ceil(16), m.div_ceil(16), 1),
1080 block: (32, 1, 1),
1081 shared_mem: 16 * 16 * 2 * 2, };
1083
1084 let grad_out_ptr = grad_output.as_ptr();
1085 let scales_ptr = w_scales.as_ptr();
1086 let data_ptr = w_nf4.as_ptr();
1087 let grad_a_ptr = grad_input.as_ptr();
1088
1089 let mut args: [*mut std::ffi::c_void; 7] = [
1091 &grad_out_ptr as *const _ as *mut _,
1092 &scales_ptr as *const _ as *mut _,
1093 &data_ptr as *const _ as *mut _,
1094 &grad_a_ptr as *const _ as *mut _,
1095 &m as *const _ as *mut _,
1096 &n as *const _ as *mut _,
1097 &k as *const _ as *mut _,
1098 ];
1099
1100 unsafe {
1101 stream
1102 .launch_kernel(module, "nf4_tensor_core_gemm_backward_a", &config, &mut args)
1103 .map_err(|e| {
1104 CudaTensorError::KernelError(format!(
1105 "NF4 tensor core GEMM backward_a launch failed: {e:?}"
1106 ))
1107 })?;
1108 }
1109
1110 Ok(())
1111}