use std::ptr;
use super::cublas_sys::*;
use super::stream::CudaStream;
use crate::driver::context::CudaContext;
use crate::driver::sys::CUdeviceptr;
use crate::GpuError;
/// Transpose selector for a GEMM operand.
///
/// Each variant is translated to the matching cuBLAS operation constant by
/// `to_cublas` before being handed to the library.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GemmOp {
/// Translated to `CUBLAS_OP_N`.
NoTrans,
/// Translated to `CUBLAS_OP_T`.
Trans,
}
impl GemmOp {
    /// Converts this selector into the cuBLAS operation constant it stands for.
    ///
    /// `NoTrans` becomes `CUBLAS_OP_N` and `Trans` becomes `CUBLAS_OP_T`.
    fn to_cublas(self) -> CublasOperation {
        match self {
            Self::NoTrans => CUBLAS_OP_N,
            Self::Trans => CUBLAS_OP_T,
        }
    }
}
/// Owning wrapper around a raw cuBLAS handle.
///
/// In this file the handle is created by `CublasHandle::new` and destroyed by
/// the `Drop` implementation.
pub struct CublasHandle {
/// Raw handle value as written by `cublasCreate_v2`.
handle: super::cublas_sys::CublasHandle,
}
// NOTE(review): these impls assert that the raw cuBLAS handle may be moved to and
// shared between threads. Nothing in this file enforces that; presumably it relies
// on the thread-safety guarantees documented by cuBLAS — confirm. Be aware that
// `set_stream` takes `&self`, so it can reconfigure a handle that is concurrently
// shared with other threads.
unsafe impl Send for CublasHandle {}
unsafe impl Sync for CublasHandle {}
impl CublasHandle {
/// Creates a cuBLAS handle and selects `CUBLAS_DEFAULT_MATH` as its math mode.
///
/// `_ctx` is not read by this function.
/// NOTE(review): presumably it is taken so that callers must hold a live CUDA
/// context when the handle is created — confirm.
///
/// # Errors
///
/// * Whatever `get_cublas_driver` returns when the cuBLAS library cannot be
///   loaded ([`GpuError::CudaNotAvailable`]).
/// * [`GpuError::CudaDriver`] carrying the cuBLAS status code when either
///   `cublasCreate_v2` or `cublasSetMathMode` fails. When the latter fails, the
///   freshly created handle is destroyed before the error is returned.
pub fn new(_ctx: &CudaContext) -> Result<Self, GpuError> {
    let driver = get_cublas_driver()?;

    let mut handle: super::cublas_sys::CublasHandle = ptr::null_mut();
    // SAFETY: `&mut handle` points at a live, writable local that receives the
    // new handle.
    let result = unsafe { (driver.cublasCreate_v2)(&mut handle) };
    // Report the real status code, exactly like the `cublasSetMathMode` failure
    // path below, instead of a hard-coded 0 that is indistinguishable from success.
    CublasDriver::check(result)
        .map_err(|e| GpuError::CudaDriver(format!("cublasCreate_v2: {e}"), result))?;

    // SAFETY: `handle` was just initialised by a successful `cublasCreate_v2`.
    let result = unsafe { (driver.cublasSetMathMode)(handle, CUBLAS_DEFAULT_MATH) };
    if result != CUBLAS_STATUS_SUCCESS {
        // `Self` has not been constructed yet, so `Drop` will not run: release the
        // handle here. The cleanup status is discarded; the original failure is
        // the error worth reporting.
        // SAFETY: `handle` is the valid handle created above and is not used again.
        let _ = unsafe { (driver.cublasDestroy_v2)(handle) };
        return Err(GpuError::CudaDriver(
            format!("cublasSetMathMode: {}", cublas_status_string(result)),
            result,
        ));
    }

    Ok(Self { handle })
}
/// Associates `stream` with this handle by calling `cublasSetStream_v2`.
///
/// # Errors
///
/// Propagates the error from `get_cublas_driver`, or returns
/// [`GpuError::CudaDriver`] when `cublasSetStream_v2` reports a failure.
pub fn set_stream(&self, stream: &CudaStream) -> Result<(), GpuError> {
    let api = get_cublas_driver()?;
    let raw_stream = stream.raw();
    let status = unsafe { (api.cublasSetStream_v2)(self.handle, raw_stream) };
    if let Err(e) = CublasDriver::check(status) {
        return Err(GpuError::CudaDriver(format!("cublasSetStream_v2: {e}"), 0));
    }
    Ok(())
}
/// Half-precision GEMM issued through `cublasGemmEx`.
///
/// A, B and C are all described to cuBLAS as `CUDA_R_16F`; the compute type is
/// `CUBLAS_COMPUTE_32F` and the algorithm selector is
/// `CUBLAS_GEMM_DEFAULT_TENSOR_OP`. `transa`/`transb` are converted with
/// `to_cublas`; `m`, `n`, `k` and the leading dimensions are forwarded unchanged.
/// `alpha` and `beta` are `f32` values passed to the library by address.
///
/// `a_ptr`, `b_ptr` and `c_ptr` are cast to raw pointers and handed to cuBLAS
/// without any validation in this function.
///
/// # Errors
///
/// Propagates the error from `get_cublas_driver`, or returns
/// [`GpuError::CudaDriver`] (message includes `m`, `n`, `k`) when `cublasGemmEx`
/// reports a failure.
pub fn gemm_f16(
    &self,
    transa: GemmOp,
    transb: GemmOp,
    m: i32,
    n: i32,
    k: i32,
    alpha: f32,
    a_ptr: CUdeviceptr,
    lda: i32,
    b_ptr: CUdeviceptr,
    ldb: i32,
    beta: f32,
    c_ptr: CUdeviceptr,
    ldc: i32,
) -> Result<(), GpuError> {
    use std::ffi::c_void;

    let api = get_cublas_driver()?;

    // The scaling factors live in this frame; cuBLAS receives their addresses
    // as untyped pointers.
    let alpha_arg = &alpha as *const f32 as *const c_void;
    let beta_arg = &beta as *const f32 as *const c_void;
    let a_arg = a_ptr as *const c_void;
    let b_arg = b_ptr as *const c_void;
    let c_arg = c_ptr as *mut c_void;

    let status = unsafe {
        (api.cublasGemmEx)(
            self.handle,
            transa.to_cublas(),
            transb.to_cublas(),
            m,
            n,
            k,
            alpha_arg,
            a_arg,
            CUDA_R_16F,
            lda,
            b_arg,
            CUDA_R_16F,
            ldb,
            beta_arg,
            c_arg,
            CUDA_R_16F,
            ldc,
            CUBLAS_COMPUTE_32F,
            CUBLAS_GEMM_DEFAULT_TENSOR_OP,
        )
    };

    CublasDriver::check(status).map_err(|e| {
        let message = format!("cublasGemmEx(m={m}, n={n}, k={k}): {e}");
        GpuError::CudaDriver(message, 0)
    })
}
/// GEMM with half-precision inputs and a single-precision output, issued
/// through `cublasGemmEx`.
///
/// A and B are described to cuBLAS as `CUDA_R_16F`, C as `CUDA_R_32F`; the
/// compute type is `CUBLAS_COMPUTE_32F` and the algorithm selector is
/// `CUBLAS_GEMM_DEFAULT_TENSOR_OP`. All extents and leading dimensions are
/// forwarded unchanged; `alpha` and `beta` are passed by address.
///
/// The device pointers are cast and forwarded without validation here.
///
/// # Errors
///
/// Propagates the error from `get_cublas_driver`, or returns
/// [`GpuError::CudaDriver`] (message includes `m`, `n`, `k`) when `cublasGemmEx`
/// reports a failure.
pub fn gemm_f16_to_f32(
    &self,
    transa: GemmOp,
    transb: GemmOp,
    m: i32,
    n: i32,
    k: i32,
    alpha: f32,
    a_ptr: CUdeviceptr,
    lda: i32,
    b_ptr: CUdeviceptr,
    ldb: i32,
    beta: f32,
    c_ptr: CUdeviceptr,
    ldc: i32,
) -> Result<(), GpuError> {
    use std::ffi::c_void;

    let api = get_cublas_driver()?;
    let op_a = transa.to_cublas();
    let op_b = transb.to_cublas();

    // Inputs are half precision, the destination is single precision.
    let inputs = (a_ptr as *const c_void, b_ptr as *const c_void);
    let output = c_ptr as *mut c_void;
    let scale = (
        &alpha as *const f32 as *const c_void,
        &beta as *const f32 as *const c_void,
    );

    let status = unsafe {
        (api.cublasGemmEx)(
            self.handle,
            op_a,
            op_b,
            m,
            n,
            k,
            scale.0,
            inputs.0,
            CUDA_R_16F,
            lda,
            inputs.1,
            CUDA_R_16F,
            ldb,
            scale.1,
            output,
            CUDA_R_32F,
            ldc,
            CUBLAS_COMPUTE_32F,
            CUBLAS_GEMM_DEFAULT_TENSOR_OP,
        )
    };

    match CublasDriver::check(status) {
        Ok(()) => Ok(()),
        Err(e) => Err(GpuError::CudaDriver(
            format!("cublasGemmEx_f16_f32(m={m}, n={n}, k={k}): {e}"),
            0,
        )),
    }
}
/// Single-precision GEMM issued through `cublasGemmEx`.
///
/// A, B and C are all described to cuBLAS as `CUDA_R_32F`; the compute type is
/// `CUBLAS_COMPUTE_32F` and the algorithm selector is `CUBLAS_GEMM_DEFAULT`.
/// All extents and leading dimensions are forwarded unchanged; `alpha` and
/// `beta` are passed by address.
///
/// The device pointers are cast and forwarded without validation here.
///
/// # Errors
///
/// Propagates the error from `get_cublas_driver`, or returns
/// [`GpuError::CudaDriver`] (message includes `m`, `n`, `k`) when `cublasGemmEx`
/// reports a failure.
pub fn gemm_f32(
    &self,
    transa: GemmOp,
    transb: GemmOp,
    m: i32,
    n: i32,
    k: i32,
    alpha: f32,
    a_ptr: CUdeviceptr,
    lda: i32,
    b_ptr: CUdeviceptr,
    ldb: i32,
    beta: f32,
    c_ptr: CUdeviceptr,
    ldc: i32,
) -> Result<(), GpuError> {
    use std::ffi::c_void;

    let api = get_cublas_driver()?;

    let alpha_arg: *const c_void = (&alpha as *const f32).cast();
    let beta_arg: *const c_void = (&beta as *const f32).cast();
    let lhs = a_ptr as *const c_void;
    let rhs = b_ptr as *const c_void;
    let dst = c_ptr as *mut c_void;

    let status = unsafe {
        (api.cublasGemmEx)(
            self.handle,
            transa.to_cublas(),
            transb.to_cublas(),
            m,
            n,
            k,
            alpha_arg,
            lhs,
            CUDA_R_32F,
            lda,
            rhs,
            CUDA_R_32F,
            ldb,
            beta_arg,
            dst,
            CUDA_R_32F,
            ldc,
            CUBLAS_COMPUTE_32F,
            CUBLAS_GEMM_DEFAULT,
        )
    };

    if let Err(e) = CublasDriver::check(status) {
        return Err(GpuError::CudaDriver(
            format!("cublasGemmEx_f32(m={m}, n={n}, k={k}): {e}"),
            0,
        ));
    }
    Ok(())
}
/// Batched single-precision GEMM issued through `cublasSgemmStridedBatched`.
///
/// `transa`/`transb` are converted with `to_cublas`; the extents, leading
/// dimensions, per-operand strides and `batch_count` are forwarded unchanged.
/// `alpha` and `beta` are passed by reference.
///
/// The device pointers are cast and forwarded without validation here.
///
/// # Errors
///
/// Propagates the error from `get_cublas_driver`, or returns
/// [`GpuError::CudaDriver`] (message includes `m`, `n`, `k` and `batch_count`)
/// when `cublasSgemmStridedBatched` reports a failure.
pub fn gemm_f32_strided_batched(
    &self,
    transa: GemmOp,
    transb: GemmOp,
    m: i32,
    n: i32,
    k: i32,
    alpha: f32,
    a_ptr: CUdeviceptr,
    lda: i32,
    stride_a: i64,
    b_ptr: CUdeviceptr,
    ldb: i32,
    stride_b: i64,
    beta: f32,
    c_ptr: CUdeviceptr,
    ldc: i32,
    stride_c: i64,
    batch_count: i32,
) -> Result<(), GpuError> {
    use std::ffi::c_void;

    let api = get_cublas_driver()?;
    let op_a = transa.to_cublas();
    let op_b = transb.to_cublas();

    let a_base = a_ptr as *const c_void;
    let b_base = b_ptr as *const c_void;
    let c_base = c_ptr as *mut c_void;

    let status = unsafe {
        (api.cublasSgemmStridedBatched)(
            self.handle,
            op_a,
            op_b,
            m,
            n,
            k,
            &alpha,
            a_base,
            lda,
            stride_a,
            b_base,
            ldb,
            stride_b,
            &beta,
            c_base,
            ldc,
            stride_c,
            batch_count,
        )
    };

    match CublasDriver::check(status) {
        Ok(()) => Ok(()),
        Err(e) => Err(GpuError::CudaDriver(
            format!("cublasSgemmStridedBatched(m={m}, n={n}, k={k}, batch={batch_count}): {e}"),
            0,
        )),
    }
}
/// Returns a copy of the raw cuBLAS handle value stored in this wrapper.
///
/// Ownership is not transferred: the `Drop` impl of this wrapper still
/// destroys the handle.
#[must_use]
pub fn raw(&self) -> super::cublas_sys::CublasHandle {
self.handle
}
}
impl Drop for CublasHandle {
    /// Destroys the wrapped handle with `cublasDestroy_v2`.
    ///
    /// Does nothing when `CublasDriver::load` returns `None`. The status returned
    /// by `cublasDestroy_v2` is discarded.
    fn drop(&mut self) {
        let driver = match CublasDriver::load() {
            Some(driver) => driver,
            // Without the library there is no entry point to call.
            None => return,
        };

        // SAFETY: within this file `self.handle` is only ever set by a successful
        // `cublasCreate_v2` in `new`, and `drop` runs at most once per value.
        let _ = unsafe { (driver.cublasDestroy_v2)(self.handle) };
    }
}
/// Returns the loaded cuBLAS driver table.
///
/// # Errors
///
/// Returns [`GpuError::CudaNotAvailable`] when `CublasDriver::load` yields `None`.
fn get_cublas_driver() -> Result<&'static CublasDriver, GpuError> {
    match CublasDriver::load() {
        Some(driver) => Ok(driver),
        None => Err(GpuError::CudaNotAvailable(
            "cuBLAS library not found".to_string(),
        )),
    }
}
impl CublasHandle {
/// Row-major convenience wrapper around `gemm_f16`.
///
/// Calls `gemm_f16` with both operations set to `GemmOp::NoTrans`, the A and B
/// operands swapped and the `m`/`n` extents swapped. The leading dimensions
/// passed are `n` (for `b_ptr`), `k` (for `a_ptr`) and `n` (for `c_ptr`).
/// NOTE(review): this matches the usual way of computing a row-major product
/// of densely packed A (m×k) and B (k×n) into C (m×n) with a column-major
/// library — confirm the intended shapes against callers.
///
/// # Errors
///
/// Returns whatever `gemm_f16` returns.
pub fn gemm_f16_row_major(
    &self,
    m: i32,
    n: i32,
    k: i32,
    alpha: f32,
    a_ptr: CUdeviceptr,
    b_ptr: CUdeviceptr,
    beta: f32,
    c_ptr: CUdeviceptr,
) -> Result<(), GpuError> {
    // The caller's B goes first and the caller's A second.
    let (first, first_ld) = (b_ptr, n);
    let (second, second_ld) = (a_ptr, k);
    let out_ld = n;

    self.gemm_f16(
        GemmOp::NoTrans,
        GemmOp::NoTrans,
        n,
        m,
        k,
        alpha,
        first,
        first_ld,
        second,
        second_ld,
        beta,
        c_ptr,
        out_ld,
    )
}
/// Row-major convenience wrapper around `gemm_f32_strided_batched`.
///
/// Calls `gemm_f32_strided_batched` with both operations set to
/// `GemmOp::NoTrans`, the A and B operands (together with their strides)
/// swapped and the `m`/`n` extents swapped. The leading dimensions passed are
/// `n` (for `b_ptr`), `k` (for `a_ptr`) and `n` (for `c_ptr`).
/// NOTE(review): this matches the usual way of computing row-major products of
/// densely packed A (m×k) and B (k×n) into C (m×n) with a column-major
/// library — confirm the intended shapes against callers.
///
/// # Errors
///
/// Returns whatever `gemm_f32_strided_batched` returns.
pub fn gemm_f32_strided_batched_row_major(
    &self,
    m: i32,
    n: i32,
    k: i32,
    alpha: f32,
    a_ptr: CUdeviceptr,
    stride_a: i64,
    b_ptr: CUdeviceptr,
    stride_b: i64,
    beta: f32,
    c_ptr: CUdeviceptr,
    stride_c: i64,
    batch_count: i32,
) -> Result<(), GpuError> {
    // Each operand travels with its own leading dimension and batch stride.
    let (first, first_ld, first_stride) = (b_ptr, n, stride_b);
    let (second, second_ld, second_stride) = (a_ptr, k, stride_a);
    let out_ld = n;

    self.gemm_f32_strided_batched(
        GemmOp::NoTrans,
        GemmOp::NoTrans,
        n,
        m,
        k,
        alpha,
        first,
        first_ld,
        first_stride,
        second,
        second_ld,
        second_stride,
        beta,
        c_ptr,
        out_ld,
        stride_c,
        batch_count,
    )
}
/// Row-major convenience wrapper around `gemm_f32`.
///
/// Calls `gemm_f32` with both operations set to `GemmOp::NoTrans`, the A and B
/// operands swapped and the `m`/`n` extents swapped. The leading dimensions
/// passed are `n` (for `b_ptr`), `k` (for `a_ptr`) and `n` (for `c_ptr`).
/// NOTE(review): this matches the usual way of computing a row-major product
/// of densely packed A (m×k) and B (k×n) into C (m×n) with a column-major
/// library — confirm the intended shapes against callers.
///
/// # Errors
///
/// Returns whatever `gemm_f32` returns.
pub fn gemm_f32_row_major(
    &self,
    m: i32,
    n: i32,
    k: i32,
    alpha: f32,
    a_ptr: CUdeviceptr,
    b_ptr: CUdeviceptr,
    beta: f32,
    c_ptr: CUdeviceptr,
) -> Result<(), GpuError> {
    // The caller's B goes first and the caller's A second.
    let (first, first_ld) = (b_ptr, n);
    let (second, second_ld) = (a_ptr, k);
    let out_ld = n;

    self.gemm_f32(
        GemmOp::NoTrans,
        GemmOp::NoTrans,
        n,
        m,
        k,
        alpha,
        first,
        first_ld,
        second,
        second_ld,
        beta,
        c_ptr,
        out_ld,
    )
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each `GemmOp` variant converts to its cuBLAS operation constant.
    #[test]
    fn test_gemm_op_to_cublas() {
        let no_trans = GemmOp::NoTrans.to_cublas();
        let trans = GemmOp::Trans.to_cublas();

        assert_eq!(no_trans, CUBLAS_OP_N);
        assert_eq!(trans, CUBLAS_OP_T);
    }

    /// In builds without the `cuda` feature, looking up the driver must fail.
    #[cfg(not(feature = "cuda"))]
    #[test]
    fn test_cublas_handle_requires_cuda() {
        let result = get_cublas_driver();

        assert!(result.is_err());
    }
}