Skip to main content

entrenar/autograd/cuda_forward/
matmul.rs

1#![allow(unsafe_code)]
2#![allow(trivial_casts)]
3#![allow(clippy::borrow_as_ptr)]
4#![allow(clippy::ref_as_ptr)]
5
6#[cfg(feature = "cuda")]
7use trueno_gpu::driver::{CublasHandle, CudaStream, GemmOp, GpuBuffer, LaunchConfig};
8#[cfg(feature = "cuda")]
9use trueno_gpu::kernels::{
10    Batched4DGemmKernel, FusedSwigluKernel, GemmKernel, Kernel, Nf4GemmKernel,
11    Nf4GemmTransposeKernel, Nf4TensorCoreGemmKernel,
12};
13
14use crate::autograd::cuda_tensor::{CudaTensorError, Result};
15
16#[cfg(feature = "cuda")]
17use super::cache::FORWARD_KERNEL_CACHE;
18
/// Launches the fused SwiGLU forward kernel on the GPU (ENT-150).
///
/// Computes `output = SiLU(gate) * up` in a single kernel launch instead of two
/// separate operations, for better memory bandwidth.
///
/// # Arguments
///
/// * `gate` - Device buffer holding the gate values.
/// * `up` - Device buffer holding the up values.
/// * `output` - Device buffer that receives the result.
/// * `n` - Element count handed to the kernel; the grid is `ceil(n / 256)` blocks of
///   256 threads.
/// * `stream` - CUDA stream the kernel is launched on.
///
/// # Errors
///
/// * [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache is not set.
/// * [`CudaTensorError::KernelError`] if the cache lock cannot be acquired or the
///   launch fails.
/// * Whatever error the kernel cache reports when compiling the PTX.
#[cfg(feature = "cuda")]
pub fn fused_swiglu_forward(
    gate: &GpuBuffer<f32>,
    up: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    // Threads per block; the grid is sized so that at least `n` threads are launched.
    const THREADS_PER_BLOCK: u32 = 256;

    let cache_cell = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut kernels = cache_cell.lock().map_err(|_| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // One cache entry serves every `n`: the PTX is n-independent (trueno#184).
    let cache_key = String::from("fused_swiglu_forward");
    let module = match kernels.get_cached(&cache_key) {
        Some(cached) => cached,
        None => {
            let ptx = FusedSwigluKernel::new(n).emit_ptx_for_target(kernels.sm_target());
            kernels.get_or_compile(&cache_key, &ptx)?
        }
    };

    let config = LaunchConfig {
        grid: (n.div_ceil(THREADS_PER_BLOCK), 1, 1),
        block: (THREADS_PER_BLOCK, 1, 1),
        shared_mem: 0,
    };

    // Kernel parameters are passed by address, so every value needs a named local.
    let gate_dev = gate.as_ptr();
    let up_dev = up.as_ptr();
    let out_dev = output.as_ptr();
    let mut params: [*mut std::ffi::c_void; 4] = [
        std::ptr::addr_of!(gate_dev).cast_mut().cast(),
        std::ptr::addr_of!(up_dev).cast_mut().cast(),
        std::ptr::addr_of!(out_dev).cast_mut().cast(),
        std::ptr::addr_of!(n).cast_mut().cast(),
    ];

    // SAFETY: Kernel launch requires FFI. All buffers are valid GPU allocations with
    // matching sizes, and the kernel parameters match the expected PTX signature.
    // Every entry of `params` points at a local that outlives the call.
    let launch = unsafe { stream.launch_kernel(module, "fused_swiglu", &config, &mut params) };
    launch.map_err(|e| {
        CudaTensorError::KernelError(format!("Fused SwiGLU forward launch failed: {e:?}"))
    })?;

    Ok(())
}
69
/// Row-major GEMM forward pass on the GPU: `C = A @ B`.
///
/// `A` is `M x K`, `B` is `K x N` and `C` is `M x N`.
///
/// When the kernel cache exposes a cuBLAS handle the call is forwarded to
/// `cublas_gemm_forward` (tensor core TF32, ALB-075) and `stream` is not touched here.
/// Otherwise the naive PTX GEMM kernel is fetched from the cache (compiled on a miss,
/// keyed by `m`, `k`, `n`) and launched on `stream`. Backward GEMMs use
/// CUBLAS_DEFAULT_MATH (SIMD) per ALB-076/trueno#170.
///
/// # Errors
///
/// * [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache is not set.
/// * [`CudaTensorError::KernelError`] if the cache lock cannot be acquired or the PTX
///   launch fails.
/// * Whatever error the cuBLAS helper or the PTX compilation reports.
#[cfg(feature = "cuda")]
pub fn gemm_forward(
    a: &GpuBuffer<f32>,
    b: &GpuBuffer<f32>,
    c: &mut GpuBuffer<f32>,
    m: u32,
    k: u32,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    // Edge length of the square thread block used by the naive kernel.
    const TILE: u32 = 16;

    let cache_cell = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut kernels = cache_cell.lock().map_err(|_| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // Preferred path: cuBLAS, whenever a handle is available.
    if let Some(cublas) = kernels.cublas() {
        return cublas_gemm_forward(cublas, a, b, c, m, k, n);
    }

    // PTX fallback: one cached module per (m, k, n) shape.
    let cache_key = format!("gemm_forward_{m}_{k}_{n}");
    let module = match kernels.get_cached(&cache_key) {
        Some(cached) => cached,
        None => {
            let ptx = GemmKernel::naive(m, n, k).emit_ptx_for_target(kernels.sm_target());
            kernels.get_or_compile(&cache_key, &ptx)?
        }
    };

    // Kernel: col = ctaid.x * 16 + tid.x, row = ctaid.y * 16 + tid.y,
    // so grid.x covers the N columns and grid.y covers the M rows.
    let config = LaunchConfig {
        grid: (n.div_ceil(TILE), m.div_ceil(TILE), 1),
        block: (TILE, TILE, 1),
        shared_mem: 0,
    };

    let a_dev = a.as_ptr();
    let b_dev = b.as_ptr();
    let c_dev = c.as_ptr();

    // PTX kernel signature: (a_ptr, b_ptr, c_ptr, m, n, k)
    // CRITICAL: must match param declaration order in GemmKernel::build_naive()
    let mut params: [*mut std::ffi::c_void; 6] = [
        std::ptr::addr_of!(a_dev).cast_mut().cast(),
        std::ptr::addr_of!(b_dev).cast_mut().cast(),
        std::ptr::addr_of!(c_dev).cast_mut().cast(),
        std::ptr::addr_of!(m).cast_mut().cast(),
        std::ptr::addr_of!(n).cast_mut().cast(),
        std::ptr::addr_of!(k).cast_mut().cast(),
    ];

    // SAFETY: Kernel launch requires FFI. All buffers are valid GPU allocations with
    // matching sizes, and the kernel parameters match the expected PTX signature.
    // Every entry of `params` points at a local that outlives the call.
    let launch = unsafe { stream.launch_kernel(module, "gemm_naive", &config, &mut params) };
    launch.map_err(|e| {
        CudaTensorError::KernelError(format!("GEMM forward launch failed: {e:?}"))
    })?;

    Ok(())
}
139
/// GEMM with transposed B: `C[M,N] = A[M,K] @ B[N,K]^T`.
///
/// `B` is stored row-major `[N,K]`. entrenar#318: GPU lm_head with embed_original.
///
/// This entry point has no PTX fallback: it forwards to `cublas_gemm_forward_bt` when the
/// kernel cache exposes a cuBLAS handle and fails otherwise. `_stream` is accepted for
/// signature parity with the other forward functions and is not used.
///
/// # Errors
///
/// * [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache is not set.
/// * [`CudaTensorError::KernelError`] if the cache lock cannot be acquired, if no cuBLAS
///   handle is available, or if the cuBLAS call fails.
#[cfg(feature = "cuda")]
pub fn gemm_forward_bt(
    a: &GpuBuffer<f32>,
    b: &GpuBuffer<f32>,
    c: &mut GpuBuffer<f32>,
    m: u32,
    k: u32,
    n: u32,
    _stream: &CudaStream,
) -> Result<()> {
    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    // Same lock-failure message as every other forward function in this file.
    let cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;
    if let Some(cublas) = cache.cublas() {
        return cublas_gemm_forward_bt(cublas, a, b, c, m, k, n);
    }
    Err(CudaTensorError::KernelError("gemm_forward_bt requires cuBLAS".to_string()))
}
159
/// cuBLAS GEMM with transposed B: `C[M,N] = A[M,K] @ B[N,K]^T`, all buffers row-major.
///
/// cuBLAS is column-major, so the product is issued as
/// `C^T[N,M] = Trans(B_col[K,N]) @ A_col[K,M]`: the row-major `[N,K]` buffer read
/// column-major is `[K,N]` with leading dimension `K`, `A` is `[K,M]` with leading
/// dimension `K`, and the output is written with leading dimension `N`.
///
/// Dimensions are narrowed with `as i32` before being handed to cuBLAS.
///
/// # Errors
///
/// Returns [`CudaTensorError::KernelError`] if the cuBLAS call reports an error.
#[cfg(feature = "cuda")]
fn cublas_gemm_forward_bt(
    cublas: &CublasHandle,
    a: &GpuBuffer<f32>,
    b: &GpuBuffer<f32>,
    c: &mut GpuBuffer<f32>,
    m: u32,
    k: u32,
    n: u32,
) -> Result<()> {
    // Column-major problem size: an N x M output contracted over K.
    let (out_rows, out_cols, depth) = (n as i32, m as i32, k as i32);
    let ld_b = k as i32; // B is [K,N] in col-major, transposed to [N,K]
    let ld_a = k as i32; // A is [K,M] in col-major
    let ld_c = n as i32;

    let outcome = cublas.gemm_f32(
        GemmOp::Trans,   // B transposed
        GemmOp::NoTrans, // A not transposed
        out_rows,
        out_cols,
        depth,
        1.0,
        b.as_ptr(),
        ld_b,
        a.as_ptr(),
        ld_a,
        0.0,
        c.as_ptr(),
        ld_c,
    );
    outcome.map_err(|e| CudaTensorError::KernelError(format!("cuBLAS GEMM BT failed: {e:?}")))
}
190
/// cuBLAS GEMM forward: `C[M,N] = A[M,K] @ B[K,N]`, all buffers row-major.
///
/// Uses the `C^T = B^T @ A^T` identity: the row-major buffers are handed to cuBLAS
/// untransposed with the operands swapped (`B` first, leading dimension `N`; `A` second,
/// leading dimension `K`), which yields row-major `C` with leading dimension `N`.
///
/// Dimensions are narrowed with `as i32` before being handed to cuBLAS.
///
/// # Errors
///
/// Returns [`CudaTensorError::KernelError`] if the cuBLAS call reports an error.
#[cfg(feature = "cuda")]
fn cublas_gemm_forward(
    cublas: &CublasHandle,
    a: &GpuBuffer<f32>,
    b: &GpuBuffer<f32>,
    c: &mut GpuBuffer<f32>,
    m: u32,
    k: u32,
    n: u32,
) -> Result<()> {
    // Column-major problem size: an N x M output contracted over K.
    let (out_rows, out_cols, depth) = (n as i32, m as i32, k as i32);
    let ld_b = n as i32;
    let ld_a = k as i32;
    let ld_c = n as i32;

    let outcome = cublas.gemm_f32(
        GemmOp::NoTrans,
        GemmOp::NoTrans,
        out_rows,
        out_cols,
        depth,
        1.0,
        b.as_ptr(),
        ld_b,
        a.as_ptr(),
        ld_a,
        0.0,
        c.as_ptr(),
        ld_c,
    );
    outcome
        .map_err(|e| CudaTensorError::KernelError(format!("cuBLAS GEMM forward failed: {e:?}")))
}
220
/// cuBLAS backward pass for A: `grad_A[M,K] = grad_C[M,N] @ B[K,N]^T`, row-major buffers.
///
/// Issued to cuBLAS as a `K x M` column-major output contracted over `N`: `b` is passed
/// first with `Trans` and leading dimension `N`, `grad_output` second with `NoTrans` and
/// leading dimension `N`. `beta` is `0.0`, so `grad_a` is overwritten.
///
/// Dimensions are narrowed with `as i32` before being handed to cuBLAS.
///
/// # Errors
///
/// Returns [`CudaTensorError::KernelError`] if the cuBLAS call reports an error.
#[cfg(feature = "cuda")]
pub(crate) fn cublas_gemm_backward_a(
    cublas: &CublasHandle,
    grad_output: &GpuBuffer<f32>,
    b: &GpuBuffer<f32>,
    grad_a: &mut GpuBuffer<f32>,
    m: u32,
    k: u32,
    n: u32,
) -> Result<()> {
    // Column-major problem size: a K x M output contracted over N.
    let (out_rows, out_cols, depth) = (k as i32, m as i32, n as i32);
    let ld_b = n as i32;
    let ld_grad_out = n as i32;
    let ld_grad_a = k as i32;

    let outcome = cublas.gemm_f32(
        GemmOp::Trans,
        GemmOp::NoTrans,
        out_rows,
        out_cols,
        depth,
        1.0,
        b.as_ptr(),
        ld_b,
        grad_output.as_ptr(),
        ld_grad_out,
        0.0,
        grad_a.as_ptr(),
        ld_grad_a,
    );
    outcome
        .map_err(|e| CudaTensorError::KernelError(format!("cuBLAS GEMM backward_a failed: {e:?}")))
}
250
/// cuBLAS backward pass for A with accumulation: `grad_A += grad_C @ B^T` (PMAT-484).
///
/// Issues the same cuBLAS call as `cublas_gemm_backward_a`, except that `beta` is `1.0`,
/// so the product is ADDED to the existing contents of `grad_a` instead of replacing
/// them. This enables a fused Gate+Up backward without a separate `cuda_add_inplace` call.
///
/// Dimensions are narrowed with `as i32` before being handed to cuBLAS.
///
/// # Errors
///
/// Returns [`CudaTensorError::KernelError`] if the cuBLAS call reports an error.
#[cfg(feature = "cuda")]
pub(crate) fn cublas_gemm_backward_a_accumulate(
    cublas: &CublasHandle,
    grad_output: &GpuBuffer<f32>,
    b: &GpuBuffer<f32>,
    grad_a: &mut GpuBuffer<f32>,
    m: u32,
    k: u32,
    n: u32,
) -> Result<()> {
    // Column-major problem size: a K x M output contracted over N.
    let (out_rows, out_cols, depth) = (k as i32, m as i32, n as i32);
    let ld_b = n as i32;
    let ld_grad_out = n as i32;
    let ld_grad_a = k as i32;

    let outcome = cublas.gemm_f32(
        GemmOp::Trans,
        GemmOp::NoTrans,
        out_rows,
        out_cols,
        depth,
        1.0,
        b.as_ptr(),
        ld_b,
        grad_output.as_ptr(),
        ld_grad_out,
        1.0, // ACCUMULATE: C = 1.0 * A @ B + 1.0 * C
        grad_a.as_ptr(),
        ld_grad_a,
    );
    outcome.map_err(|e| {
        CudaTensorError::KernelError(format!("cuBLAS GEMM backward_a accumulate failed: {e:?}"))
    })
}
286
/// cuBLAS backward pass for B: `grad_B[K,N] = A[M,K]^T @ grad_C[M,N]`, row-major buffers.
///
/// Issued to cuBLAS as an `N x K` column-major output contracted over `M`:
/// `grad_output` is passed first with `NoTrans` and leading dimension `N`, `a` second
/// with `Trans` and leading dimension `K`. `beta` is `0.0`, so `grad_b` is overwritten.
///
/// Dimensions are narrowed with `as i32` before being handed to cuBLAS.
///
/// # Errors
///
/// Returns [`CudaTensorError::KernelError`] if the cuBLAS call reports an error.
#[cfg(feature = "cuda")]
pub(crate) fn cublas_gemm_backward_b(
    cublas: &CublasHandle,
    a: &GpuBuffer<f32>,
    grad_output: &GpuBuffer<f32>,
    grad_b: &mut GpuBuffer<f32>,
    m: u32,
    k: u32,
    n: u32,
) -> Result<()> {
    // Column-major problem size: an N x K output contracted over M.
    let (out_rows, out_cols, depth) = (n as i32, k as i32, m as i32);
    let ld_grad_out = n as i32;
    let ld_a = k as i32;
    let ld_grad_b = n as i32;

    let outcome = cublas.gemm_f32(
        GemmOp::NoTrans,
        GemmOp::Trans,
        out_rows,
        out_cols,
        depth,
        1.0,
        grad_output.as_ptr(),
        ld_grad_out,
        a.as_ptr(),
        ld_a,
        0.0,
        grad_b.as_ptr(),
        ld_grad_b,
    );
    outcome
        .map_err(|e| CudaTensorError::KernelError(format!("cuBLAS GEMM backward_b failed: {e:?}")))
}
316
/// Batched 4D GEMM forward pass on GPU for multi-head attention
///
/// Computes: C[b,h] = A[b,h] @ B[b,h] for each batch b and head h
/// Pattern: [batch, heads, m, k] @ [batch, heads, k, n] -> [batch, heads, m, n]
///
/// When the kernel cache exposes a cuBLAS handle, a single strided-batched cuBLAS GEMM
/// with `batch * heads` problems is issued and `stream` is not used. Otherwise the
/// `batched_4d_gemm` PTX kernel is fetched from the cache (compiled on a miss) and
/// launched on `stream`.
///
/// # Contract (C-B4DGEMM-001)
///
/// - **Precondition**: a.len() >= batch * heads * m * k, b.len() >= batch * heads * k * n,
///   c.len() >= batch * heads * m * n
/// - **Postcondition**: C[b,h] = A[b,h] @ B[b,h] for all (b,h) in [0,batch)×[0,heads)
/// - **Invariant**: Zero CPU-side data transfers
///
/// # Errors
///
/// * [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache is not set.
/// * [`CudaTensorError::KernelError`] if the cache lock cannot be acquired, if
///   `batch * heads` overflows `u32`, if a dimension does not fit the `i32` arguments of
///   the cuBLAS path, or if the cuBLAS call / kernel launch fails.
/// * Whatever error the kernel cache reports when compiling the PTX.
#[cfg(feature = "cuda")]
pub fn batched_4d_gemm_forward(
    a: &GpuBuffer<f32>,
    b: &GpuBuffer<f32>,
    c: &mut GpuBuffer<f32>,
    batch: u32,
    heads: u32,
    m: u32,
    n: u32,
    k: u32,
    stream: &CudaStream,
) -> Result<()> {
    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // `batch * heads` is both the cuBLAS batch count and grid.z. Compute it once with an
    // overflow check: a plain `*` panics in debug builds and silently wraps in release,
    // which would process only part of the batch.
    let batch_heads = batch.checked_mul(heads).ok_or_else(|| {
        CudaTensorError::KernelError(format!(
            "Batched 4D GEMM: batch ({batch}) * heads ({heads}) overflows u32"
        ))
    })?;

    // ALB-075 Phase 4: cuBLAS strided batched GEMM for attention (16x faster than PTX)
    if let Some(cublas) = cache.cublas() {
        // cuBLAS takes signed 32-bit sizes; reject values that an `as i32` cast would
        // turn negative instead of passing them through.
        let to_i32 = |what: &str, value: u32| {
            i32::try_from(value).map_err(|_| {
                CudaTensorError::KernelError(format!(
                    "Batched 4D GEMM: {what} = {value} exceeds the i32 range required by cuBLAS"
                ))
            })
        };
        let batch_count = to_i32("batch * heads", batch_heads)?;
        let m_i32 = to_i32("m", m)?;
        let n_i32 = to_i32("n", n)?;
        let k_i32 = to_i32("k", k)?;

        // Element strides between consecutive (batch, head) matrices, in 64-bit.
        let stride_a = i64::from(m) * i64::from(k);
        let stride_b = i64::from(k) * i64::from(n);
        let stride_c = i64::from(m) * i64::from(n);
        return cublas
            .gemm_f32_strided_batched_row_major(
                m_i32,
                n_i32,
                k_i32,
                1.0,
                a.as_ptr(),
                stride_a,
                b.as_ptr(),
                stride_b,
                0.0,
                c.as_ptr(),
                stride_c,
                batch_count,
            )
            .map_err(|e| {
                CudaTensorError::KernelError(format!("cuBLAS batched 4D GEMM failed: {e:?}"))
            });
    }

    // PTX fallback. The kernel object is needed up front for its tile size.
    let kernel = Batched4DGemmKernel::new(batch, heads, m, n, k);
    let tile_size = kernel.config.tile_size;

    let key = format!("batched_4d_gemm_{batch}_{heads}_{m}_{n}_{k}");
    let module = match cache.get_cached(&key) {
        Some(cached) => cached,
        None => {
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    // Grid: (ceil(n / tile), ceil(m / tile), batch * heads)
    // Block: (tile_size, tile_size, 1)
    // Shared memory: tile_size * tile_size * 4 * 2 bytes (tiles for A and B)
    let config = LaunchConfig {
        grid: (n.div_ceil(tile_size), m.div_ceil(tile_size), batch_heads),
        block: (tile_size, tile_size, 1),
        shared_mem: tile_size * tile_size * 4 * 2,
    };

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let c_ptr = c.as_ptr();

    // PTX kernel signature: (a_ptr, b_ptr, c_ptr, batch, heads, m, n, k)
    let mut args: [*mut std::ffi::c_void; 8] = [
        &a_ptr as *const _ as *mut _,
        &b_ptr as *const _ as *mut _,
        &c_ptr as *const _ as *mut _,
        &batch as *const _ as *mut _,
        &heads as *const _ as *mut _,
        &m as *const _ as *mut _,
        &n as *const _ as *mut _,
        &k as *const _ as *mut _,
    ];

    // SAFETY: Kernel launch requires FFI. All buffers are valid GPU allocations with
    // matching sizes, and the kernel parameters match the expected PTX signature.
    unsafe {
        stream.launch_kernel(module, "batched_4d_gemm", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("Batched 4D GEMM forward launch failed: {e:?}"))
        })?;
    }

    Ok(())
}
418
/// Fused NF4-dequantize + GEMM forward pass on the GPU (trueno#108).
///
/// Computes `C = A @ dequant(B_nf4)`, where:
/// - `a` is the `M x K` f32 activation matrix,
/// - `b_nf4` holds the packed 4-bit NF4 weights (`u8`),
/// - `b_scales` holds the per-block f32 scale factors,
/// - `c` is the `M x N` f32 output.
///
/// Dequantization is fused into the matmul kernel, so no intermediate fp32 weight
/// buffer is needed.
///
/// # Contract: C-NF4-003 (GEMM Numerical Parity)
///
/// `nf4_gemm(A, Q) ≈ naive_gemm(A, dequantize(Q))` within 1e-3 per-element.
///
/// # Errors
///
/// * [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache is not set.
/// * [`CudaTensorError::KernelError`] if the cache lock cannot be acquired or the
///   launch fails.
/// * Whatever error the kernel cache reports when compiling the PTX.
#[cfg(feature = "cuda")]
pub fn gemm_nf4_forward(
    a: &GpuBuffer<f32>,
    b_nf4: &GpuBuffer<u8>,
    b_scales: &GpuBuffer<f32>,
    c: &mut GpuBuffer<f32>,
    m: u32,
    k: u32,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    let cache_cell = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut kernels = cache_cell.lock().map_err(|_| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // Built on every call because the launch geometry depends on its tile size.
    let kernel = Nf4GemmKernel::new(m, n, k);
    let tile = kernel.tile_size;

    // Cache key excludes M (seq_len) — PTX is shape-independent (m/n/k are
    // runtime params, only tile_size is baked in). Including M causes cache misses
    // when actual seq_len differs from max_seq_len used during pre-warming,
    // triggering on-demand JIT that fails on Blackwell (trueno#184).
    let cache_key = format!("nf4_gemm_forward_{k}_{n}");
    let module = match kernels.get_cached(&cache_key) {
        Some(cached) => cached,
        None => {
            let ptx = kernel.emit_ptx_for_target(kernels.sm_target());
            kernels.get_or_compile(&cache_key, &ptx)?
        }
    };

    // One block per tile x tile output tile, with tile * tile threads laid out in x
    // (same as Q4K GEMM).
    let config = LaunchConfig {
        grid: (n.div_ceil(tile), m.div_ceil(tile), 1),
        block: (tile * tile, 1, 1),
        shared_mem: 16 * 4, // NF4 codebook LUT (16 × f32)
    };

    let a_dev = a.as_ptr();
    let weights_dev = b_nf4.as_ptr();
    let scales_dev = b_scales.as_ptr();
    let c_dev = c.as_ptr();

    // PTX kernel signature: (a_ptr, b_nf4_ptr, b_scales_ptr, c_ptr, m, n, k)
    // CRITICAL: must match param declaration order in Nf4GemmKernel::build_ptx()
    let mut params: [*mut std::ffi::c_void; 7] = [
        std::ptr::addr_of!(a_dev).cast_mut().cast(),
        std::ptr::addr_of!(weights_dev).cast_mut().cast(),
        std::ptr::addr_of!(scales_dev).cast_mut().cast(),
        std::ptr::addr_of!(c_dev).cast_mut().cast(),
        std::ptr::addr_of!(m).cast_mut().cast(),
        std::ptr::addr_of!(n).cast_mut().cast(),
        std::ptr::addr_of!(k).cast_mut().cast(),
    ];

    // SAFETY: Kernel launch requires FFI. All buffers are valid GPU allocations with
    // matching sizes, and the kernel parameters match the expected PTX signature.
    // Every entry of `params` points at a local that outlives the call.
    let launch = unsafe { stream.launch_kernel(module, "nf4_gemm_fused", &config, &mut params) };
    launch.map_err(|e| {
        CudaTensorError::KernelError(format!("NF4 GEMM forward launch failed: {e:?}"))
    })?;

    Ok(())
}
498
/// PMAT-481: NF4 tensor core GEMM — WMMA 16×16×16 with inline NF4 dequant in SHMEM.
///
/// NF4 blocks are dequantized to FP16 in shared memory and the matmul runs on tensor
/// cores. Expected 5-40x compute improvement over naive tiled NF4 GEMM.
///
/// Contract: nf4-tensor-core-gemm-v1.yaml (F-NF4-TC-001, F-NF4-TC-002)
///
/// Note that this kernel takes the scales pointer BEFORE the packed-data pointer, the
/// reverse of `nf4_gemm_fused`.
///
/// # Errors
///
/// * [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache is not set.
/// * [`CudaTensorError::KernelError`] if the cache lock cannot be acquired or the
///   launch fails.
/// * Whatever error the kernel cache reports when compiling the PTX.
#[cfg(feature = "cuda")]
pub fn gemm_nf4_tc_forward(
    a: &GpuBuffer<f32>,
    b_nf4: &GpuBuffer<u8>,
    b_scales: &GpuBuffer<f32>,
    c: &mut GpuBuffer<f32>,
    m: u32,
    k: u32,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    // Edge length of the output tile handled by one warp.
    const WMMA_TILE: u32 = 16;

    let cache_cell = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut kernels = cache_cell.lock().map_err(|_| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let kernel = Nf4TensorCoreGemmKernel::new(m, n, k);

    // Keyed by (k, n) only; M is not part of the key.
    let cache_key = format!("nf4_tc_gemm_forward_{k}_{n}");
    let module = match kernels.get_cached(&cache_key) {
        Some(cached) => cached,
        None => {
            let ptx = kernel.emit_ptx_for_target(kernels.sm_target());
            kernels.get_or_compile(&cache_key, &ptx)?
        }
    };

    // WMMA: 1 warp (32 threads) per 16×16 tile
    let config = LaunchConfig {
        grid: (n.div_ceil(WMMA_TILE), m.div_ceil(WMMA_TILE), 1),
        block: (32, 1, 1),
        shared_mem: 16 * 16 * 2 * 2, // A[16×16] + B[16×16] in FP16
    };

    let a_dev = a.as_ptr();
    let weights_dev = b_nf4.as_ptr();
    let scales_dev = b_scales.as_ptr();
    let c_dev = c.as_ptr();

    // Kernel signature: (a_ptr, scales_ptr, data_ptr, c_ptr, m, n, k)
    let mut params: [*mut std::ffi::c_void; 7] = [
        std::ptr::addr_of!(a_dev).cast_mut().cast(),
        std::ptr::addr_of!(scales_dev).cast_mut().cast(),
        std::ptr::addr_of!(weights_dev).cast_mut().cast(),
        std::ptr::addr_of!(c_dev).cast_mut().cast(),
        std::ptr::addr_of!(m).cast_mut().cast(),
        std::ptr::addr_of!(n).cast_mut().cast(),
        std::ptr::addr_of!(k).cast_mut().cast(),
    ];

    // SAFETY: Kernel launch requires FFI. This relies on the buffers being valid GPU
    // allocations sized for the given m/n/k and on `params` matching the kernel signature
    // noted above. Every entry of `params` points at a local that outlives the call.
    let launch =
        unsafe { stream.launch_kernel(module, "nf4_tensor_core_gemm", &config, &mut params) };
    launch.map_err(|e| {
        CudaTensorError::KernelError(format!(
            "NF4 tensor core GEMM forward launch failed: {e:?}"
        ))
    })?;

    Ok(())
}
565
/// PMAT-475: Fused NF4 Gate+Up GEMM — computes both projections with shared input load.
///
/// Eliminates one full input activation read from DRAM per call.
/// Savings: M × K × 4 bytes/call (12 MB/layer for Qwen 1.5B batch=4 seq=512).
///
/// `gate[M×N] = A[M×K] @ dequant(W_gate_nf4)` and
/// `up[M×N]   = A[M×K] @ dequant(W_up_nf4)` in one kernel launch.
///
/// # Arguments
///
/// * `a` - Input activations `A[M×K]` (f32).
/// * `wg_nf4` / `wg_scales` - Packed NF4 data and f32 scales of the gate weights.
/// * `wu_nf4` / `wu_scales` - Packed NF4 data and f32 scales of the up weights.
/// * `gate` / `up` - Output buffers for the two projections.
/// * `m`, `k`, `n` - Matrix dimensions, passed to the kernel as runtime parameters.
/// * `stream` - CUDA stream the kernel is launched on.
///
/// # Errors
///
/// * [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache is not set.
/// * [`CudaTensorError::KernelError`] if the cache lock cannot be acquired or the
///   launch fails.
/// * Whatever error the kernel cache reports when compiling the PTX.
// Gated like every other function in this file: the GPU types and the kernel cache it
// uses are only imported when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
pub fn gemm_nf4_gate_up_forward(
    a: &GpuBuffer<f32>,
    wg_nf4: &GpuBuffer<u8>,
    wg_scales: &GpuBuffer<f32>,
    wu_nf4: &GpuBuffer<u8>,
    wu_scales: &GpuBuffer<f32>,
    gate: &mut GpuBuffer<f32>,
    up: &mut GpuBuffer<f32>,
    m: u32,
    k: u32,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    use trueno_gpu::kernels::FusedNf4GateUpGemmKernel;

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // Built on every call because the launch geometry depends on its tile size.
    let kernel = FusedNf4GateUpGemmKernel::new(m, n, k);
    let tile = kernel.tile_size;
    // Keyed by (k, n) only; M is not part of the key.
    let key = format!("fused_nf4_gate_up_{k}_{n}");
    let module = match cache.get_cached(&key) {
        Some(cached) => cached,
        None => {
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    // Same launch geometry as `gemm_nf4_forward`: one block per tile x tile output tile,
    // tile * tile threads in x, 16 f32 of shared memory.
    let config = LaunchConfig {
        grid: (n.div_ceil(tile), m.div_ceil(tile), 1),
        block: (tile * tile, 1, 1),
        shared_mem: 16 * 4,
    };

    let a_ptr = a.as_ptr();
    let gate_ptr = gate.as_ptr();
    let up_ptr = up.as_ptr();
    let wg_nf4_ptr = wg_nf4.as_ptr();
    let wg_scales_ptr = wg_scales.as_ptr();
    let wu_nf4_ptr = wu_nf4.as_ptr();
    let wu_scales_ptr = wu_scales.as_ptr();

    // Argument order as passed to the kernel:
    // (gate, up, a, wg_scales, wg_nf4, wu_scales, wu_nf4, m, n, k)
    // NOTE(review): presumably mirrors the param order declared by
    // FusedNf4GateUpGemmKernel — confirm against the kernel source.
    let mut args: [*mut std::ffi::c_void; 10] = [
        &gate_ptr as *const _ as *mut _,
        &up_ptr as *const _ as *mut _,
        &a_ptr as *const _ as *mut _,
        &wg_scales_ptr as *const _ as *mut _,
        &wg_nf4_ptr as *const _ as *mut _,
        &wu_scales_ptr as *const _ as *mut _,
        &wu_nf4_ptr as *const _ as *mut _,
        &m as *const _ as *mut _,
        &n as *const _ as *mut _,
        &k as *const _ as *mut _,
    ];

    // SAFETY: Kernel launch requires FFI. This relies on all buffers being valid GPU
    // allocations sized for the given m/n/k and on `args` matching the PTX signature of
    // `fused_nf4_gate_up_gemm`. Every entry of `args` points at a local that outlives
    // the call.
    unsafe {
        stream.launch_kernel(module, "fused_nf4_gate_up_gemm", &config, &mut args).map_err(
            |e| CudaTensorError::KernelError(format!("Fused NF4 gate+up launch: {e:?}")),
        )?;
    }

    Ok(())
}
639
/// GEMM forward pass at BF16 compute precision on the GPU (R-002: BF16 mixed precision).
///
/// Computes `C = A @ B` with `A` of shape `M x K`, `B` of shape `K x N` and `C` of shape
/// `M x N`. Inputs and output are f32 buffers (FP32 master weights); each operand is
/// truncated to BF16 precision (7-bit mantissa) right before the multiply, and the
/// products are accumulated in FP32.
///
/// This is the standard mixed-precision pattern:
/// - FP32 storage (master weights stay in full precision)
/// - BF16 compute (reduced precision multiply for bandwidth savings)
/// - FP32 accumulation (no loss in reduction precision)
///
/// # Contract (C-BF16GEMM-001)
///
/// - `C[i,j] = Σ_k trunc_bf16(A[i,k]) * trunc_bf16(B[k,j])` accumulated in f32
/// - `trunc_bf16(x)` = f32::from_bits(x.to_bits() & 0xFFFF0000)
/// - Output matches CPU BF16 reference within f32 accumulation tolerance
///
/// # Errors
///
/// * [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache is not set.
/// * [`CudaTensorError::KernelError`] if the cache lock cannot be acquired or the
///   launch fails.
/// * Whatever error the kernel cache reports when compiling the PTX.
#[cfg(feature = "cuda")]
pub fn gemm_forward_bf16(
    a: &GpuBuffer<f32>,
    b: &GpuBuffer<f32>,
    c: &mut GpuBuffer<f32>,
    m: u32,
    k: u32,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    // Edge length of the square thread block.
    const TILE: u32 = 16;

    let cache_cell = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut kernels = cache_cell.lock().map_err(|_| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // NOTE(review): the key is shape-specific although the PTX builder only takes the SM
    // target (M/N/K are runtime kernel params), so a shape-independent key looks possible.
    // Confirm against any pre-warming code that relies on this key format before changing.
    let cache_key = format!("gemm_bf16_compute_{m}_{k}_{n}");
    let module = match kernels.get_cached(&cache_key) {
        Some(cached) => cached,
        None => {
            let ptx = build_gemm_bf16_compute_ptx(kernels.sm_target());
            kernels.get_or_compile(&cache_key, &ptx)?
        }
    };

    // grid.x covers the N columns, grid.y the M rows, in 16 x 16 thread blocks.
    let config = LaunchConfig {
        grid: (n.div_ceil(TILE), m.div_ceil(TILE), 1),
        block: (TILE, TILE, 1),
        shared_mem: 0,
    };

    let a_dev = a.as_ptr();
    let b_dev = b.as_ptr();
    let c_dev = c.as_ptr();

    // PTX kernel signature: (a_ptr, b_ptr, c_ptr, m, n, k)
    // CRITICAL: must match param declaration order in build_gemm_bf16_compute_ptx()
    let mut params: [*mut std::ffi::c_void; 6] = [
        std::ptr::addr_of!(a_dev).cast_mut().cast(),
        std::ptr::addr_of!(b_dev).cast_mut().cast(),
        std::ptr::addr_of!(c_dev).cast_mut().cast(),
        std::ptr::addr_of!(m).cast_mut().cast(),
        std::ptr::addr_of!(n).cast_mut().cast(),
        std::ptr::addr_of!(k).cast_mut().cast(),
    ];

    // SAFETY: Kernel launch requires FFI. All buffers are valid GPU allocations with
    // matching sizes, and the kernel parameters match the expected PTX signature.
    // Every entry of `params` points at a local that outlives the call.
    let launch =
        unsafe { stream.launch_kernel(module, "gemm_bf16_compute", &config, &mut params) };
    launch.map_err(|e| {
        CudaTensorError::KernelError(format!("BF16 GEMM forward launch failed: {e:?}"))
    })?;

    Ok(())
}
712
/// Build the PTX source for the BF16-precision GEMM kernel `gemm_bf16_compute`.
///
/// The kernel is a naive GEMM with inline BF16 truncation: every thread computes one
/// `C[row, col]`, loading each f32 operand as raw bits, masking it with `0xFFFF0000` to
/// truncate to bf16 precision, multiplying as f32, and accumulating with an f32 FMA.
/// Threads whose row or column falls outside `M` / `N` exit without storing.
///
/// This matches the precision characteristics of hardware BF16 tensor cores (BF16
/// multiply, f32 accumulate). It is considered safe because forward GEMMs are
/// NoTrans/NoTrans and therefore unaffected by ALB-076/trueno#170.
///
/// # Arguments
///
/// * `sm_target` - Substituted verbatim into the `.target` directive of the PTX.
///
/// # Returns
///
/// The complete PTX module as a string. Kernel parameters, in declaration order, are
/// `(a_ptr, b_ptr, c_ptr, M, N, K)`. The matrix dimensions are runtime parameters, so
/// the generated text depends only on `sm_target`.
#[cfg(feature = "cuda")]
fn build_gemm_bf16_compute_ptx(sm_target: &str) -> String {
    format!(
        r".version 7.0
.target {sm_target}
.address_size 64

.visible .entry gemm_bf16_compute(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 c_ptr,
    .param .u32 M,
    .param .u32 N,
    .param .u32 K
) {{
    .reg .u32 %r<20>;
    .reg .u64 %rd<8>;
    .reg .f32 %f<4>;
    .reg .pred %p<4>;

    // col = ctaid.x * 16 + tid.x
    mov.u32 %r0, %ctaid.x;
    mov.u32 %r1, %ntid.x;
    mov.u32 %r2, %tid.x;
    mad.lo.u32 %r3, %r0, %r1, %r2;

    // row = ctaid.y * 16 + tid.y
    mov.u32 %r4, %ctaid.y;
    mov.u32 %r5, %ntid.y;
    mov.u32 %r6, %tid.y;
    mad.lo.u32 %r7, %r4, %r5, %r6;

    // Load params
    ld.param.u64 %rd0, [a_ptr];
    ld.param.u64 %rd1, [b_ptr];
    ld.param.u64 %rd2, [c_ptr];
    ld.param.u32 %r8, [M];
    ld.param.u32 %r9, [N];
    ld.param.u32 %r10, [K];

    // Bounds check: row < M && col < N
    setp.ge.u32 %p0, %r7, %r8;
    setp.ge.u32 %p1, %r3, %r9;
    or.pred %p2, %p0, %p1;
    @%p2 bra exit;

    // acc = 0.0f
    mov.f32 %f0, 0f00000000;

    // Loop: for i = 0; i < K; i++
    mov.u32 %r11, 0;
loop_start:
    setp.ge.u32 %p3, %r11, %r10;
    @%p3 bra loop_end;

    // Load A[row, i] as u32 bits, truncate to bf16 precision
    mul.lo.u32 %r12, %r7, %r10;
    add.u32 %r12, %r12, %r11;
    mul.wide.u32 %rd3, %r12, 4;
    add.u64 %rd3, %rd0, %rd3;
    ld.global.u32 %r13, [%rd3];
    and.b32 %r13, %r13, 0xFFFF0000;
    mov.b32 %f1, %r13;

    // Load B[i, col] as u32 bits, truncate to bf16 precision
    mul.lo.u32 %r14, %r11, %r9;
    add.u32 %r14, %r14, %r3;
    mul.wide.u32 %rd4, %r14, 4;
    add.u64 %rd4, %rd1, %rd4;
    ld.global.u32 %r15, [%rd4];
    and.b32 %r15, %r15, 0xFFFF0000;
    mov.b32 %f2, %r15;

    // acc += a_bf16 * b_bf16 (FMA in f32 accumulator)
    fma.rn.f32 %f0, %f1, %f2, %f0;

    add.u32 %r11, %r11, 1;
    bra loop_start;

loop_end:
    // Store C[row, col]
    mul.lo.u32 %r16, %r7, %r9;
    add.u32 %r16, %r16, %r3;
    mul.wide.u32 %rd5, %r16, 4;
    add.u64 %rd5, %rd2, %rd5;
    st.global.f32 [%rd5], %f0;

exit:
    ret;
}}
"
    )
}
812
/// cuBLAS GEMM for the NF4 QLoRA forward pass with pre-dequantized fp32 weights (ENT-287).
///
/// Computes `C[M,N] = A[M,K] @ W[N,K]^T`, where `W` is stored row-major `[N,K]`
/// (HuggingFace convention: `[out_features, in_features]`).
///
/// # Weight Layout Derivation
///
/// `W` is row-major `[N,K]`: element `(i,j)` lives at offset `i*K + j`. Read
/// column-major, the same memory is a `[K,N]` matrix with leading dimension `K`.
///
/// The target is `C = A @ W^T`, i.e. row-major `C[M,N] = A[M,K] @ W^T[K,N]`.
///
/// Column-major equivalent: `C_cm[N,M] = (W^T)_cm[N,K] @ A_cm[K,M]`.
/// `W_cm` is `[K,N]`, so cuBLAS `Trans` yields `[N,K]` with `lda = K`.
/// `A_cm` is `[K,M]` with `ldb = K`, and `C_cm` is `[N,M]` with `ldc = N`.
///
/// cuBLAS call: `(Trans, NoTrans, N, M, K, W_ptr, K, A_ptr, K, C_ptr, N)`.
///
/// # Arguments
///
/// * `a` - Input activations `[M, K]` row-major (f32)
/// * `w` - Weight matrix `[N, K]` row-major = `[K, N]` col-major (f32, original fp32 weights)
/// * `c` - Output `[M, N]` row-major (f32)
/// * `m` - Rows of A (seq_len)
/// * `k` - Columns of A / columns of W (input dimension)
/// * `n` - Rows of W (output dimension)
/// * `stream` - Ignored; cuBLAS uses its own stream
///
/// # Contract (C-NF4CUBLAS-001)
///
/// `gemm_nf4_dequant_cublas(A, W) = A @ W^T` within f32 precision.
///
/// # Errors
///
/// * [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache is not set.
/// * [`CudaTensorError::KernelError`] if the cache lock cannot be acquired, if no cuBLAS
///   handle is available, or if the cuBLAS call fails.
#[cfg(feature = "cuda")]
pub fn gemm_nf4_dequant_cublas(
    a: &GpuBuffer<f32>,
    w: &GpuBuffer<f32>,
    c: &mut GpuBuffer<f32>,
    m: u32,
    k: u32,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    let _ = stream; // cuBLAS uses its own stream (set via set_forward_cublas_stream)

    let cache_cell = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let kernels = cache_cell.lock().map_err(|_| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let cublas = kernels.cublas().ok_or_else(|| {
        CudaTensorError::KernelError("cuBLAS not available for NF4 dequant GEMM".to_string())
    })?;

    // C[M,N] = A[M,K] @ W[N,K]^T
    // col-major: C_cm[N,M] = Trans(W_cm[K,N]) @ A_cm[K,M]
    let (out_rows, out_cols, depth) = (n as i32, m as i32, k as i32);
    let ld_w = k as i32; // leading dim of W_cm[K,N]
    let ld_a = k as i32; // leading dim of A_cm[K,M]
    let ld_c = n as i32; // leading dim of C_cm[N,M]

    let outcome = cublas.gemm_f32(
        GemmOp::Trans,   // W_cm[K,N] transposed → [N,K]
        GemmOp::NoTrans, // A_cm[K,M]
        out_rows,        // rows of op(W) = N
        out_cols,        // cols of op(A) = M
        depth,           // shared dim = K
        1.0,
        w.as_ptr(), // W: row-major [N,K] = col-major [K,N]
        ld_w,
        a.as_ptr(), // A: row-major [M,K] = col-major [K,M]
        ld_a,
        0.0,
        c.as_ptr(), // C: row-major [M,N] = col-major [N,M]
        ld_c,
    );
    outcome.map_err(|e| {
        CudaTensorError::KernelError(format!("cuBLAS NF4 dequant forward failed: {e:?}"))
    })
}
889
/// Input-gradient GEMM for the NF4 QLoRA backward pass, executed by cuBLAS (ENT-287).
///
/// Evaluates `grad_input[M,K] = grad_output[M,N] @ W[N,K]`, a plain `C = A @ B`
/// product in which `B` is the row-major, already dequantized weight matrix.
///
/// # Layout mapping
///
/// A row-major `[R,C]` buffer is read by cuBLAS as a column-major `[C,R]` matrix
/// whose leading dimension is `C`, so the operands are handed over in swapped order:
///
/// ```text
/// row-major:  C[M,K]    = A[M,N]    @ B[N,K]
/// col-major:  C_cm[K,M] = B_cm[K,N] @ A_cm[N,M]
/// ```
///
/// - `B` = `w`, row-major `[N,K]` = col-major `[K,N]`, `lda = K`
/// - `A` = `grad_output`, row-major `[M,N]` = col-major `[N,M]`, `ldb = N`
/// - `C` = `grad_input`, row-major `[M,K]` = col-major `[K,M]`, `ldc = K`
///
/// which yields the call
/// `cublas(NoTrans, NoTrans, K, M, N, W_ptr, K, grad_out_ptr, N, grad_in_ptr, K)`.
///
/// # Arguments
///
/// * `grad_output` - Upstream gradient `[M, N]` (f32)
/// * `w` - Weight matrix `[N, K]`, row-major (f32, pre-dequantized)
/// * `grad_input` - Destination for the input gradient `[M, K]` (f32)
/// * `m` - Rows (seq_len)
/// * `k` - Output columns (input dimension)
/// * `n` - Shared dimension (output dimension)
/// * `stream` - Ignored; the cuBLAS handle uses its own stream
///   (set via `set_forward_cublas_stream`)
///
/// # Errors
///
/// * `CudaTensorError::DeviceNotInitialized` if the forward kernel cache has not been set up.
/// * `CudaTensorError::KernelError` if the cache lock cannot be acquired, the cache holds no
///   cuBLAS handle, or the GEMM call itself reports a failure.
///
/// # Contract (C-NF4CUBLAS-002)
///
/// `gemm_nf4_backward_a_cublas(grad, W) = grad @ W` within f32 precision.
#[cfg(feature = "cuda")]
pub fn gemm_nf4_backward_a_cublas(
    grad_output: &GpuBuffer<f32>,
    w: &GpuBuffer<f32>,
    grad_input: &mut GpuBuffer<f32>,
    m: u32,
    k: u32,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    // Accepted for signature parity only: cuBLAS uses its own stream
    // (set via set_forward_cublas_stream).
    let _ = stream;

    let cache_cell = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let guard = cache_cell.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let blas = match guard.cublas() {
        Some(handle) => handle,
        None => {
            return Err(CudaTensorError::KernelError(
                "cuBLAS not available for NF4 backward GEMM".to_string(),
            ));
        }
    };

    // Problem size as cuBLAS sees it (column-major): C_cm[K,M] = W_cm[K,N] @ grad_out_cm[N,M].
    let out_rows = k as i32; // rows of W_cm and of C_cm
    let out_cols = m as i32; // cols of grad_out_cm and of C_cm
    let inner = n as i32; // contracted dimension

    let outcome = blas.gemm_f32(
        GemmOp::NoTrans, // W_cm[K,N] used as stored
        GemmOp::NoTrans, // grad_out_cm[N,M] used as stored
        out_rows,
        out_cols,
        inner,
        1.0, // alpha
        w.as_ptr(),
        out_rows, // lda = K
        grad_output.as_ptr(),
        inner, // ldb = N
        0.0,   // beta
        grad_input.as_ptr(),
        out_rows, // ldc = K
    );

    outcome.map_err(|e| CudaTensorError::KernelError(format!("cuBLAS NF4 backward_a failed: {e:?}")))
}
957
/// Transposed NF4 GEMM used by the QLoRA backward pass (ENT-153).
///
/// Evaluates `grad_input[M×K] = grad_output[M×N] @ dequant(W_nf4[K×N])^T`.
///
/// This is the gradient-flow kernel: the upstream gradient is multiplied against the
/// frozen NF4 weights directly, so no fp32 copy of the weights is materialized.
///
/// # Arguments
///
/// * `grad_output` - Upstream gradient `[M × N]` (f32)
/// * `w_nf4` - Frozen NF4-packed weights for `W[K × N]` (u8)
/// * `w_scales` - Per-block scales for `W[K × N]` (f32)
/// * `grad_input` - Destination for the input gradient `[M × K]` (f32)
/// * `m` - Rows of `grad_output` (seq_len)
/// * `n` - Columns of `W` (reduction dimension)
/// * `k` - Rows of `W` (output columns = input dimension)
/// * `stream` - Stream the kernel is launched on
///
/// # Errors
///
/// * `CudaTensorError::DeviceNotInitialized` if the forward kernel cache has not been set up.
/// * `CudaTensorError::KernelError` if the cache lock cannot be acquired or the launch fails.
/// * Any error returned by the cache while compiling the PTX module.
///
/// # Contract: C-NF4T-001 (Transposed GEMM Parity)
///
/// `gemm_nf4_backward_a(grad, W_nf4) ≈ gemm(grad, dequant(W)^T)` within 1e-3.
#[cfg(feature = "cuda")]
pub fn gemm_nf4_backward_a(
    grad_output: &GpuBuffer<f32>,
    w_nf4: &GpuBuffer<u8>,
    w_scales: &GpuBuffer<f32>,
    grad_input: &mut GpuBuffer<f32>,
    m: u32,
    n: u32,
    k: u32,
    stream: &CudaStream,
) -> Result<()> {
    /// Type-erases the address of a local so it can sit in the kernel argument array.
    fn kernel_arg<T>(value: &T) -> *mut std::ffi::c_void {
        value as *const T as *mut std::ffi::c_void
    }

    let cache_cell = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut kernels = cache_cell.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // Built up front because the launch geometry below needs its tile size.
    let transpose_kernel = Nf4GemmTransposeKernel::new(m, n, k);
    let tile = transpose_kernel.tile_size;

    // M (seq_len) is left out of the key: the PTX is shape-independent (trueno#184).
    let cache_key = format!("nf4_gemm_transpose_{n}_{k}");
    let module = if let Some(cached) = kernels.get_cached(&cache_key) {
        cached
    } else {
        let ptx = transpose_kernel.emit_ptx_for_target(kernels.sm_target());
        kernels.get_or_compile(&cache_key, &ptx)?
    };

    // The [M × K] output is covered by tile × tile patches.
    let grid_x = k.div_ceil(tile);
    let grid_y = m.div_ceil(tile);
    let launch = LaunchConfig {
        grid: (grid_x, grid_y, 1),
        block: (tile * tile, 1, 1),
        shared_mem: 16 * 4, // NF4 codebook LUT
    };

    let grad_out_dev = grad_output.as_ptr();
    let packed_dev = w_nf4.as_ptr();
    let scales_dev = w_scales.as_ptr();
    let grad_in_dev = grad_input.as_ptr();

    // Argument order: grad_output, packed weights, scales, grad_input, m, n, k.
    let mut args: [*mut std::ffi::c_void; 7] = [
        kernel_arg(&grad_out_dev),
        kernel_arg(&packed_dev),
        kernel_arg(&scales_dev),
        kernel_arg(&grad_in_dev),
        kernel_arg(&m),
        kernel_arg(&n),
        kernel_arg(&k),
    ];

    // SAFETY: Kernel launch requires FFI. All buffers are valid GPU allocations, and every
    // entry of `args` points at a local that outlives the call.
    let launched =
        unsafe { stream.launch_kernel(module, "nf4_gemm_transpose", &launch, &mut args) };
    launched.map_err(|e| {
        CudaTensorError::KernelError(format!("NF4 GEMM transpose launch failed: {e:?}"))
    })?;

    Ok(())
}
1038
/// PMAT-481: tensor-core NF4 backward GEMM — WMMA 16×16×16 with inline NF4 dequant.
///
/// Evaluates `grad_input[M×K] = grad_output[M×N] @ dequant(B_nf4[K×N])^T`.
///
/// Replaces the separate dequant kernel plus generic cuBLAS GEMM that would otherwise
/// run per backward projection. The work is done by trueno's
/// `Nf4TensorCoreGemmBackwardAKernel` (WMMA, NF4 dequant in shared memory).
///
/// # Arguments
///
/// * `grad_output` - Upstream gradient `[M × N]` (f32)
/// * `w_nf4` - NF4-packed weights for `B[K × N]` (u8)
/// * `w_scales` - Scales accompanying `w_nf4` (f32)
/// * `grad_input` - Destination for the input gradient `[M × K]` (f32)
/// * `m`, `n`, `k` - Problem dimensions as used in the formula above
/// * `stream` - Stream the kernel is launched on
///
/// # Errors
///
/// * `CudaTensorError::DeviceNotInitialized` if the forward kernel cache has not been set up.
/// * `CudaTensorError::KernelError` if the cache lock cannot be acquired or the launch fails.
/// * Any error returned by the cache while compiling the PTX module.
///
/// Contract: nf4-backward-tensor-core-gemm-v1.yaml
#[cfg(feature = "cuda")]
pub fn gemm_nf4_tc_backward_a(
    grad_output: &GpuBuffer<f32>,
    w_nf4: &GpuBuffer<u8>,
    w_scales: &GpuBuffer<f32>,
    grad_input: &mut GpuBuffer<f32>,
    m: u32,
    n: u32,
    k: u32,
    stream: &CudaStream,
) -> Result<()> {
    use trueno_gpu::kernels::backward::Nf4TensorCoreGemmBackwardAKernel;

    /// Type-erases the address of a local so it can sit in the kernel argument array.
    fn kernel_arg<T>(value: &T) -> *mut std::ffi::c_void {
        value as *const T as *mut std::ffi::c_void
    }

    let cache_cell = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut kernels = cache_cell.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let tc_kernel = Nf4TensorCoreGemmBackwardAKernel::new(m, n, k);

    // One cached module per (n, k) pair; m does not take part in the key.
    let cache_key = format!("nf4_tc_gemm_backward_a_{n}_{k}");
    let module = if let Some(cached) = kernels.get_cached(&cache_key) {
        cached
    } else {
        let ptx = tc_kernel.emit_ptx_for_target(kernels.sm_target());
        kernels.get_or_compile(&cache_key, &ptx)?
    };

    // WMMA backward: Grid = (ceil(K/16), ceil(M/16)), Block = 32 threads (1 warp).
    let grid_x = k.div_ceil(16);
    let grid_y = m.div_ceil(16);
    let launch = LaunchConfig {
        grid: (grid_x, grid_y, 1),
        block: (32, 1, 1),
        shared_mem: 16 * 16 * 2 * 2, // grad_out[16×16] + B^T[16×16] in FP16
    };

    let grad_out_dev = grad_output.as_ptr();
    let scales_dev = w_scales.as_ptr();
    let packed_dev = w_nf4.as_ptr();
    let grad_in_dev = grad_input.as_ptr();

    // Kernel signature: (grad_out_ptr, scales_ptr, data_ptr, grad_a_ptr, m, n, k).
    // Note that scales precede the packed data here.
    let mut args: [*mut std::ffi::c_void; 7] = [
        kernel_arg(&grad_out_dev),
        kernel_arg(&scales_dev),
        kernel_arg(&packed_dev),
        kernel_arg(&grad_in_dev),
        kernel_arg(&m),
        kernel_arg(&n),
        kernel_arg(&k),
    ];

    // SAFETY: Kernel launch requires FFI. Every entry of `args` points at a local that
    // outlives the call. Buffer extents are not checked here, so the caller must pass
    // device buffers that match the m/n/k given.
    let launched = unsafe {
        stream.launch_kernel(module, "nf4_tensor_core_gemm_backward_a", &launch, &mut args)
    };
    launched.map_err(|e| {
        CudaTensorError::KernelError(format!(
            "NF4 tensor core GEMM backward_a launch failed: {e:?}"
        ))
    })?;

    Ok(())
}