use std::ffi::c_void;
use std::os::raw::c_int;
use crate::GpuError;
/// Opaque cuBLAS context handle (`cublasHandle_t`).
pub type CublasHandle = *mut c_void;

/// Return code of every cuBLAS call (`cublasStatus_t`).
pub type CublasStatus = c_int;
pub const CUBLAS_STATUS_SUCCESS: CublasStatus = 0;
pub const CUBLAS_STATUS_NOT_INITIALIZED: CublasStatus = 1;
pub const CUBLAS_STATUS_ALLOC_FAILED: CublasStatus = 3;
pub const CUBLAS_STATUS_INVALID_VALUE: CublasStatus = 7;
pub const CUBLAS_STATUS_ARCH_MISMATCH: CublasStatus = 8;
pub const CUBLAS_STATUS_MAPPING_ERROR: CublasStatus = 11;
pub const CUBLAS_STATUS_EXECUTION_FAILED: CublasStatus = 13;
pub const CUBLAS_STATUS_INTERNAL_ERROR: CublasStatus = 14;
pub const CUBLAS_STATUS_NOT_SUPPORTED: CublasStatus = 15;
pub const CUBLAS_STATUS_LICENSE_ERROR: CublasStatus = 16;

/// Transpose selector for a GEMM operand (`cublasOperation_t`).
pub type CublasOperation = c_int;
pub const CUBLAS_OP_N: CublasOperation = 0;
pub const CUBLAS_OP_T: CublasOperation = 1;
pub const CUBLAS_OP_C: CublasOperation = 2;

/// Element type of a matrix operand (`cudaDataType`).
pub type CudaDataType = c_int;
pub const CUDA_R_32F: CudaDataType = 0;
pub const CUDA_R_16F: CudaDataType = 2;
pub const CUDA_R_16BF: CudaDataType = 14;

/// Compute/accumulation type passed to `cublasGemmEx` (`cublasComputeType_t`).
pub type CublasComputeType = c_int;
pub const CUBLAS_COMPUTE_16F: CublasComputeType = 64;
pub const CUBLAS_COMPUTE_32F: CublasComputeType = 68;
/// FP32 compute that allows down-converting inputs to half precision.
pub const CUBLAS_COMPUTE_32F_FAST_16F: CublasComputeType = 74;
/// FP32 compute that allows down-converting inputs to bfloat16.
pub const CUBLAS_COMPUTE_32F_FAST_16BF: CublasComputeType = 75;
/// FP32 compute that allows down-converting inputs to TF32.
///
/// The header value is 77. 74 is `CUBLAS_COMPUTE_32F_FAST_16F` and would
/// silently request FP16 inputs instead of TF32.
pub const CUBLAS_COMPUTE_32F_FAST_TF32: CublasComputeType = 77;

/// Handle-wide math mode (`cublasMath_t`), applied with `cublasSetMathMode`.
pub type CublasMathMode = c_int;
pub const CUBLAS_DEFAULT_MATH: CublasMathMode = 0;
pub const CUBLAS_TENSOR_OP_MATH: CublasMathMode = 1;
pub const CUBLAS_PEDANTIC_MATH: CublasMathMode = 2;
pub const CUBLAS_TF32_TENSOR_OP_MATH: CublasMathMode = 3;

/// GEMM algorithm selector for `cublasGemmEx`'s `algo` argument (`cublasGemmAlgo_t`).
pub type CublasGemmAlgo = c_int;
pub const CUBLAS_GEMM_DEFAULT: CublasGemmAlgo = -1;
pub const CUBLAS_GEMM_DEFAULT_TENSOR_OP: CublasGemmAlgo = 99;

/// Signature of `cublasCreate_v2`.
pub type CublasCreateFn = unsafe extern "C" fn(handle: *mut CublasHandle) -> CublasStatus;
/// Signature of `cublasDestroy_v2`.
pub type CublasDestroyFn = unsafe extern "C" fn(handle: CublasHandle) -> CublasStatus;
/// Signature of `cublasSetStream_v2`; `stream` is a raw CUDA stream pointer.
pub type CublasSetStreamFn =
    unsafe extern "C" fn(handle: CublasHandle, stream: *mut c_void) -> CublasStatus;
/// Signature of `cublasSetMathMode`.
pub type CublasSetMathModeFn =
    unsafe extern "C" fn(handle: CublasHandle, mode: CublasMathMode) -> CublasStatus;
/// Signature of `cublasGemmEx` (mixed-type GEMM; operand types passed at runtime).
pub type CublasGemmExFn = unsafe extern "C" fn(
    handle: CublasHandle,
    transa: CublasOperation,
    transb: CublasOperation,
    m: c_int,
    n: c_int,
    k: c_int,
    alpha: *const c_void,
    a: *const c_void,
    a_type: CudaDataType,
    lda: c_int,
    b: *const c_void,
    b_type: CudaDataType,
    ldb: c_int,
    beta: *const c_void,
    c: *mut c_void,
    c_type: CudaDataType,
    ldc: c_int,
    compute_type: CublasComputeType,
    algo: CublasGemmAlgo,
) -> CublasStatus;
/// Signature of `cublasSgemmStridedBatched`; strides are 64-bit (`long long`) element counts.
pub type CublasSgemmStridedBatchedFn = unsafe extern "C" fn(
    handle: CublasHandle,
    transa: CublasOperation,
    transb: CublasOperation,
    m: c_int,
    n: c_int,
    k: c_int,
    alpha: *const f32,
    a: *const c_void,
    lda: c_int,
    stride_a: i64,
    b: *const c_void,
    ldb: c_int,
    stride_b: i64,
    beta: *const f32,
    c: *mut c_void,
    ldc: c_int,
    stride_c: i64,
    batch_count: c_int,
) -> CublasStatus;

/// Table of cuBLAS entry points resolved at runtime from the shared library.
///
/// Obtain it through `CublasDriver::load()`. With the `cuda` feature, the pointers
/// stay valid for the whole process because the library is never unloaded.
// Field names deliberately match the exported C symbols; the loader looks each
// one up under exactly that name.
#[allow(non_snake_case)]
#[derive(Clone, Copy)]
pub struct CublasDriver {
    pub cublasCreate_v2: CublasCreateFn,
    pub cublasDestroy_v2: CublasDestroyFn,
    pub cublasSetStream_v2: CublasSetStreamFn,
    pub cublasSetMathMode: CublasSetMathModeFn,
    pub cublasGemmEx: CublasGemmExFn,
    pub cublasSgemmStridedBatched: CublasSgemmStridedBatchedFn,
}

impl std::fmt::Debug for CublasDriver {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Raw function addresses carry no diagnostic value; print only the type.
        f.debug_struct("CublasDriver").finish_non_exhaustive()
    }
}

#[cfg(feature = "cuda")]
mod loading {
    use super::*;
    use libloading::{Library, Symbol};
    use std::sync::OnceLock;

    /// Shared-library names tried, in order, when opening cuBLAS.
    const LIBRARY_CANDIDATES: [&str; 2] = ["libcublas.so.12", "libcublas.so"];

    /// The opened cuBLAS library, or `None` if no candidate could be opened.
    ///
    /// It lives in a `static` and is never dropped, which keeps the function
    /// pointers copied into `CUBLAS_DRIVER` valid for `'static`.
    static CUBLAS_LIBRARY: OnceLock<Option<Library>> = OnceLock::new();

    /// Entry points resolved from `CUBLAS_LIBRARY`, computed at most once.
    static CUBLAS_DRIVER: OnceLock<Option<CublasDriver>> = OnceLock::new();

    /// Opens the first entry of `LIBRARY_CANDIDATES` that loads successfully.
    fn open_library() -> Option<Library> {
        // SAFETY: `Library::new` is unsafe because the library's initialisers run
        // on load. We only open the system cuBLAS library by its well-known name.
        LIBRARY_CANDIDATES
            .iter()
            .find_map(|&name| unsafe { Library::new(name) }.ok())
    }

    impl CublasDriver {
        /// Returns the process-wide cuBLAS entry points, loading them on first use.
        ///
        /// Returns `None` when no candidate library can be opened or when any
        /// required symbol is missing. Success and failure are both cached, so
        /// later calls return the same result.
        #[must_use]
        pub fn load() -> Option<&'static Self> {
            CUBLAS_DRIVER
                .get_or_init(|| {
                    // The `&'static` borrow returned here ties the resolved pointers
                    // to a library that is never unloaded.
                    let lib = CUBLAS_LIBRARY.get_or_init(open_library).as_ref()?;
                    Self::load_from_library(lib)
                })
                .as_ref()
        }

        /// Resolves every entry point in `lib`, or returns `None` if any is missing.
        ///
        /// The returned pointers are only valid while `lib` stays loaded.
        fn load_from_library(lib: &Library) -> Option<Self> {
            macro_rules! load_sym {
                ($name:ident, $ty:ty) => {{
                    // SAFETY: `$ty` is declared to match the C prototype of `$name`.
                    // Calling through a mismatched signature would be undefined
                    // behaviour. The struct literal below checks `$ty` against the
                    // field type.
                    let sym: Symbol<'_, $ty> =
                        unsafe { lib.get(stringify!($name).as_bytes()) }.ok()?;
                    *sym
                }};
            }

            Some(CublasDriver {
                cublasCreate_v2: load_sym!(cublasCreate_v2, CublasCreateFn),
                cublasDestroy_v2: load_sym!(cublasDestroy_v2, CublasDestroyFn),
                cublasSetStream_v2: load_sym!(cublasSetStream_v2, CublasSetStreamFn),
                cublasSetMathMode: load_sym!(cublasSetMathMode, CublasSetMathModeFn),
                cublasGemmEx: load_sym!(cublasGemmEx, CublasGemmExFn),
                cublasSgemmStridedBatched: load_sym!(
                    cublasSgemmStridedBatched,
                    CublasSgemmStridedBatchedFn
                ),
            })
        }

        /// Converts a cuBLAS status code into a `Result`.
        ///
        /// # Errors
        ///
        /// For any status other than `CUBLAS_STATUS_SUCCESS`, returns
        /// `GpuError::CudaDriver` carrying the name from `cublas_status_string`
        /// and the raw code.
        pub fn check(result: CublasStatus) -> Result<(), GpuError> {
            if result == CUBLAS_STATUS_SUCCESS {
                Ok(())
            } else {
                Err(GpuError::CudaDriver(cublas_status_string(result).to_string(), result))
            }
        }
    }
}

#[cfg(not(feature = "cuda"))]
mod loading {
    use super::*;

    impl CublasDriver {
        /// Always `None`: this build was compiled without the `cuda` feature.
        #[must_use]
        pub fn load() -> Option<&'static Self> {
            None
        }

        /// Always fails, because cuBLAS support is compiled out.
        ///
        /// # Errors
        ///
        /// Unconditionally returns `GpuError::CudaNotAvailable`, whatever `_result` is.
        pub fn check(_result: CublasStatus) -> Result<(), GpuError> {
            Err(GpuError::CudaNotAvailable("cuda feature not enabled".to_string()))
        }
    }
}

/// Returns the symbolic name of a cuBLAS status code.
///
/// Codes this module does not define map to `"CUBLAS_STATUS_UNKNOWN"`.
#[must_use]
pub fn cublas_status_string(status: CublasStatus) -> &'static str {
    match status {
        CUBLAS_STATUS_SUCCESS => "CUBLAS_STATUS_SUCCESS",
        CUBLAS_STATUS_NOT_INITIALIZED => "CUBLAS_STATUS_NOT_INITIALIZED",
        CUBLAS_STATUS_ALLOC_FAILED => "CUBLAS_STATUS_ALLOC_FAILED",
        CUBLAS_STATUS_INVALID_VALUE => "CUBLAS_STATUS_INVALID_VALUE",
        CUBLAS_STATUS_ARCH_MISMATCH => "CUBLAS_STATUS_ARCH_MISMATCH",
        CUBLAS_STATUS_MAPPING_ERROR => "CUBLAS_STATUS_MAPPING_ERROR",
        CUBLAS_STATUS_EXECUTION_FAILED => "CUBLAS_STATUS_EXECUTION_FAILED",
        CUBLAS_STATUS_INTERNAL_ERROR => "CUBLAS_STATUS_INTERNAL_ERROR",
        CUBLAS_STATUS_NOT_SUPPORTED => "CUBLAS_STATUS_NOT_SUPPORTED",
        CUBLAS_STATUS_LICENSE_ERROR => "CUBLAS_STATUS_LICENSE_ERROR",
        _ => "CUBLAS_STATUS_UNKNOWN",
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_status_strings() {
        assert_eq!(cublas_status_string(CUBLAS_STATUS_SUCCESS), "CUBLAS_STATUS_SUCCESS");
        assert_eq!(
            cublas_status_string(CUBLAS_STATUS_INVALID_VALUE),
            "CUBLAS_STATUS_INVALID_VALUE"
        );
        assert_eq!(
            cublas_status_string(CUBLAS_STATUS_MAPPING_ERROR),
            "CUBLAS_STATUS_MAPPING_ERROR"
        );
        assert_eq!(
            cublas_status_string(CUBLAS_STATUS_LICENSE_ERROR),
            "CUBLAS_STATUS_LICENSE_ERROR"
        );
        assert_eq!(cublas_status_string(999), "CUBLAS_STATUS_UNKNOWN");
    }

    #[test]
    fn test_operation_constants() {
        assert_eq!(CUBLAS_OP_N, 0);
        assert_eq!(CUBLAS_OP_T, 1);
        assert_eq!(CUBLAS_OP_C, 2);
    }

    #[test]
    fn test_data_type_constants() {
        assert_eq!(CUDA_R_16F, 2);
        assert_eq!(CUDA_R_32F, 0);
        assert_eq!(CUDA_R_16BF, 14);
    }

    #[test]
    fn test_compute_type_constants() {
        // Values from `cublasComputeType_t` in cublas_api.h.
        assert_eq!(CUBLAS_COMPUTE_16F, 64);
        assert_eq!(CUBLAS_COMPUTE_32F, 68);
        assert_eq!(CUBLAS_COMPUTE_32F_FAST_16F, 74);
        assert_eq!(CUBLAS_COMPUTE_32F_FAST_16BF, 75);
        assert_eq!(CUBLAS_COMPUTE_32F_FAST_TF32, 77);
    }

    #[test]
    fn test_math_mode_constants() {
        assert_eq!(CUBLAS_DEFAULT_MATH, 0);
        assert_eq!(CUBLAS_TENSOR_OP_MATH, 1);
        assert_eq!(CUBLAS_PEDANTIC_MATH, 2);
        assert_eq!(CUBLAS_TF32_TENSOR_OP_MATH, 3);
    }

    #[test]
    fn test_gemm_algo_constants() {
        assert_eq!(CUBLAS_GEMM_DEFAULT, -1);
        assert_eq!(CUBLAS_GEMM_DEFAULT_TENSOR_OP, 99);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_check_maps_status_to_result() {
        assert!(CublasDriver::check(CUBLAS_STATUS_SUCCESS).is_ok());
        assert!(CublasDriver::check(CUBLAS_STATUS_INVALID_VALUE).is_err());
    }

    #[cfg(not(feature = "cuda"))]
    #[test]
    fn test_cublas_not_available_without_feature() {
        assert!(CublasDriver::load().is_none());
        assert!(CublasDriver::check(CUBLAS_STATUS_SUCCESS).is_err());
    }
}