#![allow(unsafe_code)]
#![allow(trivial_casts)]
#![allow(clippy::borrow_as_ptr)]
#![allow(clippy::ref_as_ptr)]
#[cfg(feature = "cuda")]
use trueno_gpu::driver::{CublasHandle, CudaStream, GemmOp, GpuBuffer};
use crate::autograd::cuda_tensor::{CudaTensorError, Result};
#[cfg(feature = "cuda")]
use super::cache::FORWARD_KERNEL_CACHE;
#[cfg(feature = "cuda")]
/// Row-major fp16 GEMM forward pass: `C = A × B`, where `A` is `m×k`,
/// `B` is `k×n`, and `C` is `m×n`, all stored as IEEE-754 half values in
/// `u16` device buffers.
///
/// cuBLAS is column-major, so the call below uses the standard swap trick:
/// issuing the product with the operands reversed (`b` first) and the
/// dimensions reordered as `(n, m, k)` makes the column-major `B·A` result
/// land in `c` exactly as the row-major `A·B`.
///
/// # Errors
/// * [`CudaTensorError::DeviceNotInitialized`] — the global kernel cache was
///   never initialized.
/// * [`CudaTensorError::KernelError`] — the cache lock is poisoned, no cuBLAS
///   handle is available, or the GEMM launch itself fails.
pub fn gemm_forward_f16(
    a: &GpuBuffer<u16>,
    b: &GpuBuffer<u16>,
    c: &mut GpuBuffer<u16>,
    m: u32,
    k: u32,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;
    let cublas = cache.cublas().ok_or_else(|| {
        CudaTensorError::KernelError("cuBLAS handle required for fp16 GEMM".to_string())
    })?;
    // `stream` is currently unused here (kept for API symmetry with the
    // kernel-based paths); presumably the cuBLAS handle is bound to its own
    // stream elsewhere — TODO confirm.
    let _ = stream;
    cublas
        .gemm_f16(
            GemmOp::NoTrans,
            GemmOp::NoTrans,
            n as i32,
            m as i32,
            k as i32,
            1.0,
            b.as_ptr(),
            n as i32,
            a.as_ptr(),
            k as i32,
            0.0,
            // NOTE(review): `c` is the output yet only `as_ptr()` is taken —
            // presumably `GpuBuffer::as_ptr` yields a device pointer cuBLAS
            // may write through; confirm against the `trueno_gpu` API.
            c.as_ptr(),
            n as i32,
        )
        .map_err(|e| {
            CudaTensorError::KernelError(format!("cuBLAS fp16 GEMM forward failed: {e:?}"))
        })
}
#[cfg(feature = "cuda")]
/// fp16 backward pass w.r.t. the left operand of `C = A × B` (row-major):
/// writes `grad_a (m×k) = grad_output (m×n) × Bᵀ (n×k)`.
///
/// Expressed through cuBLAS's column-major convention by issuing
/// `Bᵀ · grad_output` with the dimension triple `(k, m, n)`.
///
/// # Errors
/// Maps any cuBLAS failure into [`CudaTensorError::KernelError`].
pub(crate) fn cublas_gemm_backward_a_f16(
    cublas: &CublasHandle,
    grad_output: &GpuBuffer<u16>,
    b: &GpuBuffer<u16>,
    grad_a: &mut GpuBuffer<u16>,
    m: u32,
    k: u32,
    n: u32,
) -> Result<()> {
    // Column-major view of the swapped product: result is k×m, shared dim n.
    let (rows, cols, depth) = (k as i32, m as i32, n as i32);
    // Leading dimensions of the row-major buffers as cuBLAS sees them.
    let ld_b = n as i32;
    let ld_grad_out = n as i32;
    let ld_grad_a = k as i32;
    cublas
        .gemm_f16(
            GemmOp::Trans,
            GemmOp::NoTrans,
            rows,
            cols,
            depth,
            1.0,
            b.as_ptr(),
            ld_b,
            grad_output.as_ptr(),
            ld_grad_out,
            0.0,
            grad_a.as_ptr(),
            ld_grad_a,
        )
        .map_err(|e| CudaTensorError::KernelError(format!("cuBLAS fp16 backward_a failed: {e:?}")))
}
#[cfg(feature = "cuda")]
/// fp16 backward pass w.r.t. the right operand of `C = A × B` (row-major):
/// writes `grad_b (k×n) = Aᵀ (k×m) × grad_output (m×n)`.
///
/// Issued through cuBLAS's column-major convention with the operands swapped
/// and the dimension triple `(n, k, m)`.
///
/// # Errors
/// Maps any cuBLAS failure into [`CudaTensorError::KernelError`].
pub(crate) fn cublas_gemm_backward_b_f16(
    cublas: &CublasHandle,
    a: &GpuBuffer<u16>,
    grad_output: &GpuBuffer<u16>,
    grad_b: &mut GpuBuffer<u16>,
    m: u32,
    k: u32,
    n: u32,
) -> Result<()> {
    // Wrap the raw cuBLAS error into the crate's error type.
    let wrap = |e| CudaTensorError::KernelError(format!("cuBLAS fp16 backward_b failed: {e:?}"));
    let launch = cublas.gemm_f16(
        GemmOp::NoTrans,
        GemmOp::Trans,
        n as i32,
        k as i32,
        m as i32,
        1.0,
        grad_output.as_ptr(),
        n as i32,
        a.as_ptr(),
        k as i32,
        0.0,
        grad_b.as_ptr(),
        n as i32,
    );
    launch.map_err(wrap)
}
#[cfg(feature = "cuda")]
/// Mixed-precision backward pass w.r.t. `A` for `C = A × B` (row-major):
/// fp16 inputs, fp32 gradient accumulation into `grad_a (m×k)`.
///
/// Computes `grad_a = grad_output (m×n) × Bᵀ (n×k)` via cuBLAS's
/// column-major convention (operands swapped, dimensions `(k, m, n)`).
///
/// # Errors
/// * [`CudaTensorError::DeviceNotInitialized`] — kernel cache missing.
/// * [`CudaTensorError::KernelError`] — poisoned cache lock, missing cuBLAS
///   handle, or a failed GEMM launch.
pub fn gemm_f16_to_f32_backward_a(
    grad_output: &GpuBuffer<u16>,
    b: &GpuBuffer<u16>,
    grad_a: &mut GpuBuffer<f32>,
    m: u32,
    k: u32,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    // `stream` is accepted for API symmetry but not consumed here.
    let _ = stream;
    let guard = FORWARD_KERNEL_CACHE
        .get()
        .ok_or(CudaTensorError::DeviceNotInitialized)?
        .lock()
        .map_err(|_err| {
            CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
        })?;
    let handle = guard.cublas().ok_or_else(|| {
        CudaTensorError::KernelError("cuBLAS handle required for fp16→fp32 backward".to_string())
    })?;
    handle
        .gemm_f16_to_f32(
            GemmOp::Trans,
            GemmOp::NoTrans,
            k as i32,
            m as i32,
            n as i32,
            1.0,
            b.as_ptr(),
            n as i32,
            grad_output.as_ptr(),
            n as i32,
            0.0,
            grad_a.as_ptr(),
            k as i32,
        )
        .map_err(|e| {
            CudaTensorError::KernelError(format!("cuBLAS fp16→fp32 backward_a failed: {e:?}"))
        })
}
#[cfg(feature = "cuda")]
/// Mixed-precision GEMM forward pass: fp16 inputs `a (m×k)` and `b (k×n)`,
/// fp32 output `c (m×n)`, all row-major.
///
/// Uses the standard cuBLAS column-major swap: the operands are passed in
/// reverse order with dimensions `(n, m, k)` so the column-major `B·A`
/// result is the row-major `A·B`.
///
/// # Errors
/// * [`CudaTensorError::DeviceNotInitialized`] — kernel cache missing.
/// * [`CudaTensorError::KernelError`] — poisoned cache lock, missing cuBLAS
///   handle, or a failed GEMM launch.
pub fn gemm_f16_to_f32_forward(
    a: &GpuBuffer<u16>,
    b: &GpuBuffer<u16>,
    c: &mut GpuBuffer<f32>,
    m: u32,
    k: u32,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    // `stream` is accepted for API symmetry but not consumed here.
    let _ = stream;
    let cell = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let guard = match cell.lock() {
        Ok(g) => g,
        Err(_err) => {
            return Err(CudaTensorError::KernelError(
                "Failed to acquire kernel cache lock".to_string(),
            ))
        }
    };
    let handle = match guard.cublas() {
        Some(h) => h,
        None => {
            return Err(CudaTensorError::KernelError(
                "cuBLAS handle required for fp16→fp32 GEMM".to_string(),
            ))
        }
    };
    handle
        .gemm_f16_to_f32(
            GemmOp::NoTrans,
            GemmOp::NoTrans,
            n as i32,
            m as i32,
            k as i32,
            1.0,
            b.as_ptr(),
            n as i32,
            a.as_ptr(),
            k as i32,
            0.0,
            c.as_ptr(),
            n as i32,
        )
        .map_err(|e| {
            CudaTensorError::KernelError(format!("cuBLAS fp16→fp32 GEMM forward failed: {e:?}"))
        })
}