#![allow(unsafe_code)]
#![allow(trivial_casts)]
#![allow(clippy::borrow_as_ptr)]
#![allow(clippy::ref_as_ptr)]
#[cfg(feature = "cuda")]
use trueno_gpu::driver::{CudaStream, GpuBuffer, LaunchConfig};
#[cfg(feature = "cuda")]
use trueno_gpu::kernels::backward::{GemmBackwardAKernel, GemmBackwardBKernel};
#[cfg(feature = "cuda")]
use trueno_gpu::kernels::Kernel;
#[cfg(feature = "cuda")]
use super::super::cuda_tensor::{CudaTensorError, Result};
#[cfg(feature = "cuda")]
use super::cache::KERNEL_CACHE;
#[cfg(feature = "cuda")]
use crate::autograd::cuda_forward::{cublas_gemm_backward_a, cublas_gemm_backward_b};
/// Tile edge for the fallback tiled backward kernels: 16×16 thread blocks,
/// each staging two f32 tiles in shared memory.
#[cfg(feature = "cuda")]
const BACKWARD_TILE_SIZE: u32 = 16;
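/// Backward pass for operand A of `C = A · B` (A: m×k, B: k×n, C: m×n):
/// computes `grad_a = grad_output · Bᵀ`, shape m×k.
///
/// Uses cuBLAS when a handle is cached; otherwise emits and launches a tiled
/// PTX kernel, memoized per (m, k, n) in the kernel cache.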
#[cfg(feature = "cuda")]
pub fn gemm_backward_a(
grad_output: &GpuBuffer<f32>,
b: &GpuBuffer<f32>,
grad_a: &mut GpuBuffer<f32>,
m: u32,
k: u32,
n: u32,
stream: &CudaStream,
) -> Result<()> {
let cache = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
let mut cache = cache.lock().map_err(|_err| {
CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
})?;
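// Fast path: a cached cuBLAS handle computes the transposed GEMM directly.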
if let Some(cublas) = cache.cublas() {
return cublas_gemm_backward_a(cublas, grad_output, b, grad_a, m, k, n);
}
let tile = BACKWARD_TILE_SIZE;
let kernel = GemmBackwardAKernel::tiled_unrolled(m, n, k, tile);
let kernel_name = kernel.name();
let key = format!("gemm_backward_a_{m}_{k}_{n}");
let module = match cache.get_cached(&key) {
Some(module) => module,
None => {
let ptx = kernel.emit_ptx_for_target(cache.sm_target());
cache.get_or_compile(&key, &ptx)?
}
};
// Shared memory for two f32 tiles (one per input operand): 2 * tile^2 * 4 bytes.
let smem = 2 * tile * tile * 4;
let config = LaunchConfig {
grid: (k.div_ceil(tile), m.div_ceil(tile), 1),
block: (tile, tile, 1),
shared_mem: smem,
};
let grad_out_ptr = grad_output.as_ptr();
let b_ptr = b.as_ptr();
let grad_a_ptr = grad_a.as_ptr();
let mut args: [*mut std::ffi::c_void; 6] = [
&grad_out_ptr as *const _ as *mut _,
&b_ptr as *const _ as *mut _,
&grad_a_ptr as *const _ as *mut _,
&m as *const _ as *mut _,
&n as *const _ as *mut _,
&k as *const _ as *mut _,
];
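// SAFETY: `args` matches the compiled kernel's parameter list (three device
// pointers followed by three u32 dimensions), and every referenced value
// lives until the launch call returns.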
unsafe {
stream.launch_kernel(module, kernel_name, &config, &mut args).map_err(|e| {
CudaTensorError::KernelError(format!("GEMM backward A launch failed: {e:?}"))
})?;
}
Ok(())
}
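/// Backward pass for operand B of `C = A · B` (A: m×k, B: k×n, C: m×n):
/// computes `grad_b = Aᵀ · grad_output`, shape k×n.
///
/// Same dispatch as [`gemm_backward_a`]: cuBLAS when available, otherwise a
/// cached tiled PTX kernel.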
#[cfg(feature = "cuda")]
pub fn gemm_backward_b(
a: &GpuBuffer<f32>,
grad_output: &GpuBuffer<f32>,
grad_b: &mut GpuBuffer<f32>,
m: u32,
k: u32,
n: u32,
stream: &CudaStream,
) -> Result<()> {
let cache = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
let mut cache = cache.lock().map_err(|_err| {
CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
})?;
if let Some(cublas) = cache.cublas() {
return cublas_gemm_backward_b(cublas, a, grad_output, grad_b, m, k, n);
}
let tile = BACKWARD_TILE_SIZE;
let kernel = GemmBackwardBKernel::tiled_unrolled(m, n, k, tile);
let kernel_name = kernel.name();
let key = format!("gemm_backward_b_{m}_{k}_{n}");
let module = match cache.get_cached(&key) {
Some(module) => module,
None => {
let ptx = kernel.emit_ptx_for_target(cache.sm_target());
cache.get_or_compile(&key, &ptx)?
}
};
let smem = 2 * tile * tile * 4;
let config = LaunchConfig {
grid: (n.div_ceil(tile), k.div_ceil(tile), 1),
block: (tile, tile, 1),
shared_mem: smem,
};
let a_ptr = a.as_ptr();
let grad_out_ptr = grad_output.as_ptr();
let grad_b_ptr = grad_b.as_ptr();
let mut args: [*mut std::ffi::c_void; 6] = [
&a_ptr as *const _ as *mut _,
&grad_out_ptr as *const _ as *mut _,
&grad_b_ptr as *const _ as *mut _,
&m as *const _ as *mut _,
&n as *const _ as *mut _,
&k as *const _ as *mut _,
];
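// SAFETY: same contract as the launch in `gemm_backward_a`: `args` matches
// the kernel's parameter list and outlives the launch.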
unsafe {
stream.launch_kernel(module, kernel_name, &config, &mut args).map_err(|e| {
CudaTensorError::KernelError(format!("GEMM backward B launch failed: {e:?}"))
})?;
}
Ok(())
}
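/// Accumulating variant of [`gemm_backward_a`]: `grad_a += grad_output · Bᵀ`.
///
/// Only the cuBLAS path accumulates in place; without a cached cuBLAS handle
/// this returns a `KernelError` (NF4 training always provides one).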
#[cfg(feature = "cuda")]
pub fn gemm_backward_a_accumulate(
grad_output: &GpuBuffer<f32>,
b: &GpuBuffer<f32>,
grad_a: &mut GpuBuffer<f32>,
m: u32,
k: u32,
n: u32,
_stream: &CudaStream,
) -> Result<()> {
let cache = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
let cache = cache.lock().map_err(|_err| {
CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
})?;
if let Some(cublas) = cache.cublas() {
return crate::autograd::cuda_forward::cublas_gemm_backward_a_accumulate(
cublas,
grad_output,
b,
grad_a,
m,
k,
n,
);
}
Err(CudaTensorError::KernelError(
"gemm_backward_a_accumulate requires cuBLAS (NF4 training always has it)".to_string(),
))
}
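/// Accumulating, FP16-aware dispatch for grad_A.
///
/// When an f16 weight copy is present, the f16 GEMM path overwrites its
/// output, so the gradient is computed into a scratch buffer and then added
/// to `grad_a` in place; otherwise this delegates to
/// [`gemm_backward_a_accumulate`].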
#[cfg(feature = "cuda")]
pub fn gemm_backward_a_fp16_dispatch_accumulate(
grad_output: &GpuBuffer<f32>,
w_fp16: Option<&GpuBuffer<u16>>,
w_fp32: &GpuBuffer<f32>,
grad_a: &mut GpuBuffer<f32>,
m: u32,
k: u32,
n: u32,
stream: &CudaStream,
ctx: &trueno_gpu::driver::CudaContext,
) -> Result<()> {
if w_fp16.is_some() {
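// The f16 backward GEMM overwrites its output, so compute the contribution
// into a scratch buffer and then add it to `grad_a` in place.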
let mut temp = GpuBuffer::<f32>::new(ctx, (m * k) as usize)
.map_err(|e| CudaTensorError::AllocationFailed(format!("fp16 accum temp: {e:?}")))?;
gemm_backward_a_fp16_dispatch(
grad_output,
w_fp16,
w_fp32,
&mut temp,
m,
k,
n,
stream,
ctx,
)?;
crate::transformer::cuda_block::cuda_add_inplace(grad_a, &temp, (m * k) as usize, stream)?;
Ok(())
} else {
gemm_backward_a_accumulate(grad_output, w_fp32, grad_a, m, k, n, stream)
}
}
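/// FP16-aware dispatch for grad_A.
///
/// With an f16 weight copy (`w_fp16`), casts `grad_output` to f16 and runs
/// the mixed-precision (f16 inputs, f32 output) backward GEMM; otherwise
/// falls back to the f32 path in [`gemm_backward_a`].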
#[cfg(feature = "cuda")]
pub fn gemm_backward_a_fp16_dispatch(
grad_output: &GpuBuffer<f32>,
w_fp16: Option<&GpuBuffer<u16>>,
w_fp32: &GpuBuffer<f32>,
grad_a: &mut GpuBuffer<f32>,
m: u32,
k: u32,
n: u32,
stream: &CudaStream,
ctx: &trueno_gpu::driver::CudaContext,
) -> Result<()> {
if let Some(w16) = w_fp16 {
let elems = (m * n) as usize;
let mut grad_f16 = GpuBuffer::<u16>::new(ctx, elems)
.map_err(|e| CudaTensorError::AllocationFailed(format!("grad f16 cast: {e:?}")))?;
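// Cast the f32 upstream gradient to f16 so both GEMM operands are half precision.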
crate::autograd::cuda_forward::cast_f32_to_f16_gpu(
grad_output,
&mut grad_f16,
m * n,
stream,
)?;
crate::autograd::cuda_forward::gemm_f16_to_f32_backward_a(
&grad_f16, w16, grad_a, m, k, n, stream,
)
} else {
gemm_backward_a(grad_output, w_fp32, grad_a, m, k, n, stream)
}
}
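// CPU-side sketch of the gradients the kernels above compute, kept as a
// host-only reference test. It assumes row-major storage for all matrices
// (consistent with how the launch grids above tile the outputs); the names
// below are local to this test module and not part of the GPU API.
#[cfg(test)]
mod backward_reference {
    /// CPU reference: `grad_a[m×k] = grad_output[m×n] · bᵀ`, `b` stored row-major as k×n.
    fn gemm_backward_a_ref(grad_output: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
        let mut grad_a = vec![0.0f32; m * k];
        for i in 0..m {
            for j in 0..k {
                let mut acc = 0.0f32;
                for p in 0..n {
                    // (grad_output · Bᵀ)[i][j] = Σ_p grad_output[i][p] * B[j][p]
                    acc += grad_output[i * n + p] * b[j * n + p];
                }
                grad_a[i * k + j] = acc;
            }
        }
        grad_a
    }

    /// CPU reference: `grad_b[k×n] = aᵀ · grad_output`, `a` stored row-major as m×k.
    fn gemm_backward_b_ref(a: &[f32], grad_output: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
        let mut grad_b = vec![0.0f32; k * n];
        for i in 0..k {
            for j in 0..n {
                let mut acc = 0.0f32;
                for p in 0..m {
                    // (Aᵀ · grad_output)[i][j] = Σ_p A[p][i] * grad_output[p][j]
                    acc += a[p * k + i] * grad_output[p * n + j];
                }
                grad_b[i * n + j] = acc;
            }
        }
        grad_b
    }

    #[test]
    fn matches_hand_computed_gradients() {
        // Forward: C = A · B with A = [[2, 3]] (1×2), B = [[5], [7]] (2×1),
        // and upstream gradient dL/dC = [[1]].
        let a = [2.0f32, 3.0];
        let b = [5.0f32, 7.0];
        let grad_out = [1.0f32];
        // grad_A = grad_out · Bᵀ = [[5, 7]]; grad_B = Aᵀ · grad_out = [[2], [3]].
        assert_eq!(gemm_backward_a_ref(&grad_out, &b, 1, 2, 1), vec![5.0, 7.0]);
        assert_eq!(gemm_backward_b_ref(&a, &grad_out, 1, 2, 1), vec![2.0, 3.0]);
    }
}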