#![allow(unsafe_code)]
#![allow(trivial_casts)]
#![allow(clippy::borrow_as_ptr)]
#![allow(clippy::ref_as_ptr)]
#[cfg(feature = "cuda")]
use trueno_gpu::driver::{CublasHandle, CudaStream, GemmOp, GpuBuffer, LaunchConfig};
#[cfg(feature = "cuda")]
use trueno_gpu::kernels::{
Batched4DGemmKernel, FusedSwigluKernel, GemmKernel, Kernel, Nf4GemmKernel,
Nf4GemmTransposeKernel, Nf4TensorCoreGemmKernel,
};
use crate::autograd::cuda_tensor::{CudaTensorError, Result};
#[cfg(feature = "cuda")]
use super::cache::FORWARD_KERNEL_CACHE;
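/// Launches the fused SwiGLU activation kernel over `n` elements, one
/// thread per element in 256-thread blocks. By the usual SwiGLU
/// formulation this computes `output[i] = silu(gate[i]) * up[i]`.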
#[cfg(feature = "cuda")]
pub fn fused_swiglu_forward(
gate: &GpuBuffer<f32>,
up: &GpuBuffer<f32>,
output: &mut GpuBuffer<f32>,
n: u32,
stream: &CudaStream,
) -> Result<()> {
let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
let mut cache = cache.lock().map_err(|_err| {
CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
})?;
let key = "fused_swiglu_forward".to_string(); let module = match cache.get_cached(&key) {
Some(m) => m,
None => {
let kernel = FusedSwigluKernel::new(n);
let ptx = kernel.emit_ptx_for_target(cache.sm_target());
cache.get_or_compile(&key, &ptx)?
}
};
    let config = LaunchConfig {
        grid: (n.div_ceil(256), 1, 1),
        block: (256, 1, 1),
        shared_mem: 0,
    };
let gate_ptr = gate.as_ptr();
let up_ptr = up.as_ptr();
let output_ptr = output.as_ptr();
let mut args: [*mut std::ffi::c_void; 4] = [
&gate_ptr as *const _ as *mut _,
&up_ptr as *const _ as *mut _,
&output_ptr as *const _ as *mut _,
&n as *const _ as *mut _,
];
unsafe {
stream.launch_kernel(module, "fused_swiglu", &config, &mut args).map_err(|e| {
CudaTensorError::KernelError(format!("Fused SwiGLU forward launch failed: {e:?}"))
})?;
}
Ok(())
}
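/// Row-major single-precision GEMM: `C = A * B` with `A: m x k`, `B: k x n`,
/// `C: m x n`. Dispatches to cuBLAS when a handle is available, otherwise
/// compiles and launches the naive PTX kernel with 16x16 thread tiles.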
#[cfg(feature = "cuda")]
pub fn gemm_forward(
a: &GpuBuffer<f32>,
b: &GpuBuffer<f32>,
c: &mut GpuBuffer<f32>,
m: u32,
k: u32,
n: u32,
stream: &CudaStream,
) -> Result<()> {
let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
let mut cache = cache.lock().map_err(|_err| {
CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
})?;
if let Some(cublas) = cache.cublas() {
return cublas_gemm_forward(cublas, a, b, c, m, k, n);
}
let key = format!("gemm_forward_{m}_{k}_{n}");
let module = match cache.get_cached(&key) {
Some(m) => m,
None => {
let kernel = GemmKernel::naive(m, n, k);
let ptx = kernel.emit_ptx_for_target(cache.sm_target());
cache.get_or_compile(&key, &ptx)?
}
};
let config = LaunchConfig {
grid: (n.div_ceil(16), m.div_ceil(16), 1),
block: (16, 16, 1),
shared_mem: 0,
};
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let c_ptr = c.as_ptr();
let mut args: [*mut std::ffi::c_void; 6] = [
&a_ptr as *const _ as *mut _,
&b_ptr as *const _ as *mut _,
&c_ptr as *const _ as *mut _,
&m as *const _ as *mut _,
&n as *const _ as *mut _,
&k as *const _ as *mut _,
];
unsafe {
stream.launch_kernel(module, "gemm_naive", &config, &mut args).map_err(|e| {
CudaTensorError::KernelError(format!("GEMM forward launch failed: {e:?}"))
})?;
}
Ok(())
}
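/// Row-major GEMM against a transposed right-hand side: `C = A * B^T` with
/// `A: m x k`, `B: n x k`, `C: m x n`. cuBLAS-only; there is currently no
/// PTX fallback for this layout.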
#[cfg(feature = "cuda")]
pub fn gemm_forward_bt(
a: &GpuBuffer<f32>,
b: &GpuBuffer<f32>,
c: &mut GpuBuffer<f32>,
m: u32,
k: u32,
n: u32,
_stream: &CudaStream,
) -> Result<()> {
let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;
if let Some(cublas) = cache.cublas() {
return cublas_gemm_forward_bt(cublas, a, b, c, m, k, n);
}
Err(CudaTensorError::KernelError("gemm_forward_bt requires cuBLAS".to_string()))
}
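/// `C = A * B^T` on row-major buffers through column-major cuBLAS. A
/// row-major matrix read column-major is its transpose, so passing the
/// operands swapped computes `C^T = B * A^T`, which lands in `c` as
/// row-major `C`.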
#[cfg(feature = "cuda")]
fn cublas_gemm_forward_bt(
cublas: &CublasHandle,
a: &GpuBuffer<f32>,
b: &GpuBuffer<f32>,
c: &mut GpuBuffer<f32>,
m: u32,
k: u32,
n: u32,
) -> Result<()> {
cublas
.gemm_f32(
            GemmOp::Trans,
            GemmOp::NoTrans,
            n as i32,
            m as i32,
            k as i32,
            1.0,
            b.as_ptr(),
            k as i32,
            a.as_ptr(),
            k as i32,
            0.0,
            c.as_ptr(),
            n as i32,
        )
.map_err(|e| CudaTensorError::KernelError(format!("cuBLAS GEMM BT failed: {e:?}")))
}
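/// `C = A * B` on row-major buffers through column-major cuBLAS: passing
/// the operands in reverse order (both `NoTrans`) computes `C^T = B^T * A^T`
/// in column-major terms, i.e. row-major `C`, without explicit transposes.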
#[cfg(feature = "cuda")]
fn cublas_gemm_forward(
cublas: &CublasHandle,
a: &GpuBuffer<f32>,
b: &GpuBuffer<f32>,
c: &mut GpuBuffer<f32>,
m: u32,
k: u32,
n: u32,
) -> Result<()> {
cublas
.gemm_f32(
GemmOp::NoTrans,
GemmOp::NoTrans,
n as i32,
m as i32,
k as i32,
1.0,
b.as_ptr(),
n as i32,
a.as_ptr(),
k as i32,
0.0,
c.as_ptr(),
n as i32,
)
.map_err(|e| CudaTensorError::KernelError(format!("cuBLAS GEMM forward failed: {e:?}")))
}
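/// Gradient w.r.t. the left operand of `C = A * B`:
/// `grad_A = grad_C * B^T` (`m x k`), overwriting `grad_a` (`beta = 0`).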
#[cfg(feature = "cuda")]
pub(crate) fn cublas_gemm_backward_a(
cublas: &CublasHandle,
grad_output: &GpuBuffer<f32>,
b: &GpuBuffer<f32>,
grad_a: &mut GpuBuffer<f32>,
m: u32,
k: u32,
n: u32,
) -> Result<()> {
cublas
.gemm_f32(
GemmOp::Trans,
GemmOp::NoTrans,
k as i32,
m as i32,
n as i32,
1.0,
b.as_ptr(),
n as i32,
grad_output.as_ptr(),
n as i32,
0.0,
grad_a.as_ptr(),
k as i32,
)
.map_err(|e| CudaTensorError::KernelError(format!("cuBLAS GEMM backward_a failed: {e:?}")))
}
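/// Same as [`cublas_gemm_backward_a`], but accumulates into the existing
/// contents of `grad_a` (`beta = 1`) instead of overwriting them.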
#[cfg(feature = "cuda")]
pub(crate) fn cublas_gemm_backward_a_accumulate(
cublas: &CublasHandle,
grad_output: &GpuBuffer<f32>,
b: &GpuBuffer<f32>,
grad_a: &mut GpuBuffer<f32>,
m: u32,
k: u32,
n: u32,
) -> Result<()> {
cublas
.gemm_f32(
GemmOp::Trans,
GemmOp::NoTrans,
k as i32,
m as i32,
n as i32,
1.0,
b.as_ptr(),
n as i32,
grad_output.as_ptr(),
n as i32,
            1.0,
            grad_a.as_ptr(),
k as i32,
)
.map_err(|e| {
CudaTensorError::KernelError(format!("cuBLAS GEMM backward_a accumulate failed: {e:?}"))
})
}
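/// Gradient w.r.t. the right operand of `C = A * B`:
/// `grad_B = A^T * grad_C` (`k x n`), overwriting `grad_b` (`beta = 0`).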
#[cfg(feature = "cuda")]
pub(crate) fn cublas_gemm_backward_b(
cublas: &CublasHandle,
a: &GpuBuffer<f32>,
grad_output: &GpuBuffer<f32>,
grad_b: &mut GpuBuffer<f32>,
m: u32,
k: u32,
n: u32,
) -> Result<()> {
cublas
.gemm_f32(
GemmOp::NoTrans,
GemmOp::Trans,
n as i32,
k as i32,
m as i32,
1.0,
grad_output.as_ptr(),
n as i32,
a.as_ptr(),
k as i32,
0.0,
grad_b.as_ptr(),
n as i32,
)
.map_err(|e| CudaTensorError::KernelError(format!("cuBLAS GEMM backward_b failed: {e:?}")))
}
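/// Batched row-major GEMM over `batch * heads` independent `m x k @ k x n`
/// products, i.e. a `[batch, heads, m, k] x [batch, heads, k, n]` 4D
/// contraction. Prefers cuBLAS strided-batched GEMM; the PTX fallback runs
/// one grid `z`-slice per (batch, head) pair with two `tile x tile` f32
/// tiles in shared memory.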
#[cfg(feature = "cuda")]
pub fn batched_4d_gemm_forward(
a: &GpuBuffer<f32>,
b: &GpuBuffer<f32>,
c: &mut GpuBuffer<f32>,
batch: u32,
heads: u32,
m: u32,
n: u32,
k: u32,
stream: &CudaStream,
) -> Result<()> {
let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
let mut cache = cache.lock().map_err(|_err| {
CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
})?;
if let Some(cublas) = cache.cublas() {
let batch_count = (batch * heads) as i32;
let stride_a = i64::from(m) * i64::from(k);
let stride_b = i64::from(k) * i64::from(n);
let stride_c = i64::from(m) * i64::from(n);
return cublas
.gemm_f32_strided_batched_row_major(
m as i32,
n as i32,
k as i32,
1.0,
a.as_ptr(),
stride_a,
b.as_ptr(),
stride_b,
0.0,
c.as_ptr(),
stride_c,
batch_count,
)
.map_err(|e| {
CudaTensorError::KernelError(format!("cuBLAS batched 4D GEMM failed: {e:?}"))
});
}
let kernel = Batched4DGemmKernel::new(batch, heads, m, n, k);
let tile_size = kernel.config.tile_size;
let key = format!("batched_4d_gemm_{batch}_{heads}_{m}_{n}_{k}");
let module = match cache.get_cached(&key) {
Some(m) => m,
None => {
let ptx = kernel.emit_ptx_for_target(cache.sm_target());
cache.get_or_compile(&key, &ptx)?
}
};
let config = LaunchConfig {
grid: (n.div_ceil(tile_size), m.div_ceil(tile_size), batch * heads),
block: (tile_size, tile_size, 1),
shared_mem: tile_size * tile_size * 4 * 2,
};
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let c_ptr = c.as_ptr();
let mut args: [*mut std::ffi::c_void; 8] = [
&a_ptr as *const _ as *mut _,
&b_ptr as *const _ as *mut _,
&c_ptr as *const _ as *mut _,
&batch as *const _ as *mut _,
&heads as *const _ as *mut _,
&m as *const _ as *mut _,
&n as *const _ as *mut _,
&k as *const _ as *mut _,
];
unsafe {
stream.launch_kernel(module, "batched_4d_gemm", &config, &mut args).map_err(|e| {
CudaTensorError::KernelError(format!("Batched 4D GEMM forward launch failed: {e:?}"))
})?;
}
Ok(())
}
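/// Fused GEMM with an NF4-quantized weight operand: the kernel dequantizes
/// the packed 4-bit NormalFloat values in `b_nf4` (per-block scales in
/// `b_scales`) on the fly while computing the `m x n` product with `a`.
/// The 64 bytes of shared memory presumably hold the 16-entry f32 NF4
/// codebook.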
#[cfg(feature = "cuda")]
pub fn gemm_nf4_forward(
a: &GpuBuffer<f32>,
b_nf4: &GpuBuffer<u8>,
b_scales: &GpuBuffer<f32>,
c: &mut GpuBuffer<f32>,
m: u32,
k: u32,
n: u32,
stream: &CudaStream,
) -> Result<()> {
let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
let mut cache = cache.lock().map_err(|_err| {
CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
})?;
let kernel = Nf4GemmKernel::new(m, n, k);
let tile_size = kernel.tile_size;
let key = format!("nf4_gemm_forward_{k}_{n}");
let module = match cache.get_cached(&key) {
Some(m) => m,
None => {
let ptx = kernel.emit_ptx_for_target(cache.sm_target());
cache.get_or_compile(&key, &ptx)?
}
};
let config = LaunchConfig {
grid: (n.div_ceil(tile_size), m.div_ceil(tile_size), 1),
block: (tile_size * tile_size, 1, 1),
        shared_mem: 16 * 4,
    };
let a_ptr = a.as_ptr();
let b_nf4_ptr = b_nf4.as_ptr();
let b_scales_ptr = b_scales.as_ptr();
let c_ptr = c.as_ptr();
let mut args: [*mut std::ffi::c_void; 7] = [
&a_ptr as *const _ as *mut _,
&b_nf4_ptr as *const _ as *mut _,
&b_scales_ptr as *const _ as *mut _,
&c_ptr as *const _ as *mut _,
&m as *const _ as *mut _,
&n as *const _ as *mut _,
&k as *const _ as *mut _,
];
unsafe {
stream.launch_kernel(module, "nf4_gemm_fused", &config, &mut args).map_err(|e| {
CudaTensorError::KernelError(format!("NF4 GEMM forward launch failed: {e:?}"))
})?;
}
Ok(())
}
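/// Tensor-core variant of the NF4 GEMM: one 32-thread warp per 16x16
/// output tile, with shared memory sized for two 16x16 half-precision
/// tiles. Note the kernel takes the scales before the packed NF4 data,
/// the reverse of `nf4_gemm_fused`'s argument order.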
#[cfg(feature = "cuda")]
pub fn gemm_nf4_tc_forward(
a: &GpuBuffer<f32>,
b_nf4: &GpuBuffer<u8>,
b_scales: &GpuBuffer<f32>,
c: &mut GpuBuffer<f32>,
m: u32,
k: u32,
n: u32,
stream: &CudaStream,
) -> Result<()> {
let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
let mut cache = cache.lock().map_err(|_err| {
CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
})?;
let kernel = Nf4TensorCoreGemmKernel::new(m, n, k);
let key = format!("nf4_tc_gemm_forward_{k}_{n}");
let module = match cache.get_cached(&key) {
Some(m) => m,
None => {
let ptx = kernel.emit_ptx_for_target(cache.sm_target());
cache.get_or_compile(&key, &ptx)?
}
};
let config = LaunchConfig {
grid: (n.div_ceil(16), m.div_ceil(16), 1),
block: (32, 1, 1),
        shared_mem: 16 * 16 * 2 * 2,
    };
let a_ptr = a.as_ptr();
let b_nf4_ptr = b_nf4.as_ptr();
let b_scales_ptr = b_scales.as_ptr();
let c_ptr = c.as_ptr();
let mut args: [*mut std::ffi::c_void; 7] = [
&a_ptr as *const _ as *mut _,
&b_scales_ptr as *const _ as *mut _,
&b_nf4_ptr as *const _ as *mut _,
&c_ptr as *const _ as *mut _,
&m as *const _ as *mut _,
&n as *const _ as *mut _,
&k as *const _ as *mut _,
];
unsafe {
stream.launch_kernel(module, "nf4_tensor_core_gemm", &config, &mut args).map_err(|e| {
CudaTensorError::KernelError(format!(
"NF4 tensor core GEMM forward launch failed: {e:?}"
))
})?;
}
Ok(())
}
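/// Fused gate/up projection for SwiGLU MLP blocks: computes both
/// `gate = A * dequant(Wg)` and `up = A * dequant(Wu)` in a single launch,
/// presumably so the activations in `a` are read only once for the two
/// NF4 GEMMs.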
#[cfg(feature = "cuda")]
pub fn gemm_nf4_gate_up_forward(
a: &GpuBuffer<f32>,
wg_nf4: &GpuBuffer<u8>,
wg_scales: &GpuBuffer<f32>,
wu_nf4: &GpuBuffer<u8>,
wu_scales: &GpuBuffer<f32>,
gate: &mut GpuBuffer<f32>,
up: &mut GpuBuffer<f32>,
m: u32,
k: u32,
n: u32,
stream: &CudaStream,
) -> Result<()> {
use trueno_gpu::kernels::FusedNf4GateUpGemmKernel;
let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
let mut cache = cache.lock().map_err(|_err| {
CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
})?;
let kernel = FusedNf4GateUpGemmKernel::new(m, n, k);
let tile = kernel.tile_size;
let key = format!("fused_nf4_gate_up_{k}_{n}");
let module = match cache.get_cached(&key) {
Some(m) => m,
None => {
let ptx = kernel.emit_ptx_for_target(cache.sm_target());
cache.get_or_compile(&key, &ptx)?
}
};
let config = LaunchConfig {
grid: (n.div_ceil(tile), m.div_ceil(tile), 1),
block: (tile * tile, 1, 1),
shared_mem: 16 * 4,
};
let a_ptr = a.as_ptr();
let gate_ptr = gate.as_ptr();
let up_ptr = up.as_ptr();
let wg_nf4_ptr = wg_nf4.as_ptr();
let wg_scales_ptr = wg_scales.as_ptr();
let wu_nf4_ptr = wu_nf4.as_ptr();
let wu_scales_ptr = wu_scales.as_ptr();
let mut args: [*mut std::ffi::c_void; 10] = [
&gate_ptr as *const _ as *mut _,
&up_ptr as *const _ as *mut _,
&a_ptr as *const _ as *mut _,
&wg_scales_ptr as *const _ as *mut _,
&wg_nf4_ptr as *const _ as *mut _,
&wu_scales_ptr as *const _ as *mut _,
&wu_nf4_ptr as *const _ as *mut _,
&m as *const _ as *mut _,
&n as *const _ as *mut _,
&k as *const _ as *mut _,
];
unsafe {
stream.launch_kernel(module, "fused_nf4_gate_up_gemm", &config, &mut args).map_err(
|e| CudaTensorError::KernelError(format!("Fused NF4 gate+up launch: {e:?}")),
)?;
}
Ok(())
}
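/// GEMM emulating bf16 arithmetic on f32 buffers: each loaded value is
/// truncated to bf16 precision before the multiply, while accumulation
/// stays in f32. Presumably used to match the numerics of a bf16 hardware
/// path without changing the storage format.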
#[cfg(feature = "cuda")]
pub fn gemm_forward_bf16(
a: &GpuBuffer<f32>,
b: &GpuBuffer<f32>,
c: &mut GpuBuffer<f32>,
m: u32,
k: u32,
n: u32,
stream: &CudaStream,
) -> Result<()> {
let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
let mut cache = cache.lock().map_err(|_err| {
CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
})?;
let key = format!("gemm_bf16_compute_{m}_{k}_{n}");
let module = match cache.get_cached(&key) {
Some(m) => m,
None => {
let ptx = build_gemm_bf16_compute_ptx(cache.sm_target());
cache.get_or_compile(&key, &ptx)?
}
};
let config = LaunchConfig {
grid: (n.div_ceil(16), m.div_ceil(16), 1),
block: (16, 16, 1),
shared_mem: 0,
};
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let c_ptr = c.as_ptr();
let mut args: [*mut std::ffi::c_void; 6] = [
&a_ptr as *const _ as *mut _,
&b_ptr as *const _ as *mut _,
&c_ptr as *const _ as *mut _,
&m as *const _ as *mut _,
&n as *const _ as *mut _,
&k as *const _ as *mut _,
];
unsafe {
stream.launch_kernel(module, "gemm_bf16_compute", &config, &mut args).map_err(|e| {
CudaTensorError::KernelError(format!("BF16 GEMM forward launch failed: {e:?}"))
})?;
}
Ok(())
}
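/// Emits hand-written PTX for `gemm_bf16_compute`: a naive
/// one-thread-per-output-element GEMM that masks each loaded f32 against
/// `0xFFFF0000` (the bits representable in bf16) and accumulates with
/// `fma.rn.f32`.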
#[cfg(feature = "cuda")]
fn build_gemm_bf16_compute_ptx(sm_target: &str) -> String {
format!(
r".version 7.0
.target {sm_target}
.address_size 64
.visible .entry gemm_bf16_compute(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 c_ptr,
.param .u32 M,
.param .u32 N,
.param .u32 K
) {{
.reg .u32 %r<20>;
.reg .u64 %rd<8>;
.reg .f32 %f<4>;
.reg .pred %p<4>;
// col = ctaid.x * 16 + tid.x
mov.u32 %r0, %ctaid.x;
mov.u32 %r1, %ntid.x;
mov.u32 %r2, %tid.x;
mad.lo.u32 %r3, %r0, %r1, %r2;
// row = ctaid.y * 16 + tid.y
mov.u32 %r4, %ctaid.y;
mov.u32 %r5, %ntid.y;
mov.u32 %r6, %tid.y;
mad.lo.u32 %r7, %r4, %r5, %r6;
// Load params
ld.param.u64 %rd0, [a_ptr];
ld.param.u64 %rd1, [b_ptr];
ld.param.u64 %rd2, [c_ptr];
ld.param.u32 %r8, [M];
ld.param.u32 %r9, [N];
ld.param.u32 %r10, [K];
// Bounds check: row < M && col < N
setp.ge.u32 %p0, %r7, %r8;
setp.ge.u32 %p1, %r3, %r9;
or.pred %p2, %p0, %p1;
@%p2 bra exit;
// acc = 0.0f
mov.f32 %f0, 0f00000000;
// Loop: for i = 0; i < K; i++
mov.u32 %r11, 0;
loop_start:
setp.ge.u32 %p3, %r11, %r10;
@%p3 bra loop_end;
// Load A[row, i] as u32 bits, truncate to bf16 precision
mul.lo.u32 %r12, %r7, %r10;
add.u32 %r12, %r12, %r11;
mul.wide.u32 %rd3, %r12, 4;
add.u64 %rd3, %rd0, %rd3;
ld.global.u32 %r13, [%rd3];
and.b32 %r13, %r13, 0xFFFF0000;
mov.b32 %f1, %r13;
// Load B[i, col] as u32 bits, truncate to bf16 precision
mul.lo.u32 %r14, %r11, %r9;
add.u32 %r14, %r14, %r3;
mul.wide.u32 %rd4, %r14, 4;
add.u64 %rd4, %rd1, %rd4;
ld.global.u32 %r15, [%rd4];
and.b32 %r15, %r15, 0xFFFF0000;
mov.b32 %f2, %r15;
// acc += a_bf16 * b_bf16 (FMA in f32 accumulator)
fma.rn.f32 %f0, %f1, %f2, %f0;
add.u32 %r11, %r11, 1;
bra loop_start;
loop_end:
// Store C[row, col]
mul.lo.u32 %r16, %r7, %r9;
add.u32 %r16, %r16, %r3;
mul.wide.u32 %rd5, %r16, 4;
add.u64 %rd5, %rd2, %rd5;
st.global.f32 [%rd5], %f0;
exit:
ret;
}}
"
)
}
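/// NF4 forward pass through cuBLAS, given a weight matrix `w` (`n x k`,
/// row-major) that has presumably already been dequantized to f32:
/// computes `C = A * W^T` (`m x n`).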
#[cfg(feature = "cuda")]
pub fn gemm_nf4_dequant_cublas(
a: &GpuBuffer<f32>,
w: &GpuBuffer<f32>,
c: &mut GpuBuffer<f32>,
m: u32,
k: u32,
n: u32,
stream: &CudaStream,
) -> Result<()> {
let _ = stream;
let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
let cache = cache.lock().map_err(|_err| {
CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
})?;
let cublas = cache.cublas().ok_or_else(|| {
CudaTensorError::KernelError("cuBLAS not available for NF4 dequant GEMM".to_string())
})?;
cublas
.gemm_f32(
            GemmOp::Trans,
            GemmOp::NoTrans,
            n as i32,
            m as i32,
            k as i32,
            1.0,
            w.as_ptr(),
            k as i32,
            a.as_ptr(),
            k as i32,
            0.0,
            c.as_ptr(),
            n as i32,
        )
.map_err(|e| {
CudaTensorError::KernelError(format!("cuBLAS NF4 dequant forward failed: {e:?}"))
})
}
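/// Input gradient for the NF4 linear layer via cuBLAS, given dequantized
/// f32 weights `w` (`n x k`, row-major): `grad_input = grad_output * W`
/// (`m x k`).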
#[cfg(feature = "cuda")]
pub fn gemm_nf4_backward_a_cublas(
grad_output: &GpuBuffer<f32>,
w: &GpuBuffer<f32>,
grad_input: &mut GpuBuffer<f32>,
m: u32,
k: u32,
n: u32,
stream: &CudaStream,
) -> Result<()> {
let _ = stream;
let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
let cache = cache.lock().map_err(|_err| {
CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
})?;
let cublas = cache.cublas().ok_or_else(|| {
CudaTensorError::KernelError("cuBLAS not available for NF4 backward GEMM".to_string())
})?;
cublas
.gemm_f32(
            GemmOp::NoTrans,
            GemmOp::NoTrans,
            k as i32,
            m as i32,
            n as i32,
            1.0,
            w.as_ptr(),
            k as i32,
            grad_output.as_ptr(),
            n as i32,
            0.0,
            grad_input.as_ptr(),
            k as i32,
        )
.map_err(|e| CudaTensorError::KernelError(format!("cuBLAS NF4 backward_a failed: {e:?}")))
}
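/// Input gradient for the NF4 linear layer with on-the-fly dequantization:
/// `grad_input = grad_output * dequant(W)` (`m x k`), using the transpose
/// variant of the NF4 GEMM kernel, presumably so the packed weights can be
/// read in their stored `n x k` layout.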
#[cfg(feature = "cuda")]
pub fn gemm_nf4_backward_a(
grad_output: &GpuBuffer<f32>,
w_nf4: &GpuBuffer<u8>,
w_scales: &GpuBuffer<f32>,
grad_input: &mut GpuBuffer<f32>,
m: u32,
n: u32,
k: u32,
stream: &CudaStream,
) -> Result<()> {
let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
let mut cache = cache.lock().map_err(|_err| {
CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
})?;
let kernel = Nf4GemmTransposeKernel::new(m, n, k);
let tile_size = kernel.tile_size;
let key = format!("nf4_gemm_transpose_{n}_{k}");
let module = match cache.get_cached(&key) {
Some(m) => m,
None => {
let ptx = kernel.emit_ptx_for_target(cache.sm_target());
cache.get_or_compile(&key, &ptx)?
}
};
let config = LaunchConfig {
grid: (k.div_ceil(tile_size), m.div_ceil(tile_size), 1),
block: (tile_size * tile_size, 1, 1),
        shared_mem: 16 * 4,
    };
let a_ptr = grad_output.as_ptr();
let b_nf4_ptr = w_nf4.as_ptr();
let b_scales_ptr = w_scales.as_ptr();
let c_ptr = grad_input.as_ptr();
let mut args: [*mut std::ffi::c_void; 7] = [
&a_ptr as *const _ as *mut _,
&b_nf4_ptr as *const _ as *mut _,
&b_scales_ptr as *const _ as *mut _,
&c_ptr as *const _ as *mut _,
&m as *const _ as *mut _,
&n as *const _ as *mut _,
&k as *const _ as *mut _,
];
unsafe {
stream.launch_kernel(module, "nf4_gemm_transpose", &config, &mut args).map_err(|e| {
CudaTensorError::KernelError(format!("NF4 GEMM transpose launch failed: {e:?}"))
})?;
}
Ok(())
}
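/// Tensor-core variant of [`gemm_nf4_backward_a`]: one warp per 16x16 tile
/// of the `m x k` input gradient, with shared memory for two 16x16
/// half-precision tiles.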
#[cfg(feature = "cuda")]
pub fn gemm_nf4_tc_backward_a(
grad_output: &GpuBuffer<f32>,
w_nf4: &GpuBuffer<u8>,
w_scales: &GpuBuffer<f32>,
grad_input: &mut GpuBuffer<f32>,
m: u32,
n: u32,
k: u32,
stream: &CudaStream,
) -> Result<()> {
use trueno_gpu::kernels::backward::Nf4TensorCoreGemmBackwardAKernel;
let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
let mut cache = cache.lock().map_err(|_err| {
CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
})?;
let kernel = Nf4TensorCoreGemmBackwardAKernel::new(m, n, k);
let key = format!("nf4_tc_gemm_backward_a_{n}_{k}");
let module = match cache.get_cached(&key) {
Some(m) => m,
None => {
let ptx = kernel.emit_ptx_for_target(cache.sm_target());
cache.get_or_compile(&key, &ptx)?
}
};
let config = LaunchConfig {
grid: (k.div_ceil(16), m.div_ceil(16), 1),
block: (32, 1, 1),
        shared_mem: 16 * 16 * 2 * 2,
    };
let grad_out_ptr = grad_output.as_ptr();
let scales_ptr = w_scales.as_ptr();
let data_ptr = w_nf4.as_ptr();
let grad_a_ptr = grad_input.as_ptr();
let mut args: [*mut std::ffi::c_void; 7] = [
&grad_out_ptr as *const _ as *mut _,
&scales_ptr as *const _ as *mut _,
&data_ptr as *const _ as *mut _,
&grad_a_ptr as *const _ as *mut _,
&m as *const _ as *mut _,
&n as *const _ as *mut _,
&k as *const _ as *mut _,
];
unsafe {
stream
.launch_kernel(module, "nf4_tensor_core_gemm_backward_a", &config, &mut args)
.map_err(|e| {
CudaTensorError::KernelError(format!(
"NF4 tensor core GEMM backward_a launch failed: {e:?}"
))
})?;
}
Ok(())
}