use super::sys::{self};
use crate::cublaslt::sys::cublasLtMatmulAlgo_t;
use core::ffi::c_void;
use core::mem::MaybeUninit;
/// Error type for the cuBLASLt wrappers in this module.
///
/// Carries the cuBLAS status code that caused the failure.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CublasError(pub sys::cublasStatus_t);

impl sys::cublasStatus_t {
    /// Converts a raw cuBLAS status into a `Result`.
    ///
    /// # Errors
    ///
    /// Returns [`CublasError`] wrapping `self` for every status other than
    /// `CUBLAS_STATUS_SUCCESS`.
    pub fn result(self) -> Result<(), CublasError> {
        if matches!(self, sys::cublasStatus_t::CUBLAS_STATUS_SUCCESS) {
            Ok(())
        } else {
            Err(CublasError(self))
        }
    }
}
/// Formats the error using its derived `Debug` form (the wrapped status code).
#[cfg(feature = "std")]
impl std::fmt::Display for CublasError {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        out.write_fmt(format_args!("{:?}", self))
    }
}

/// Lets `CublasError` be used wherever a `std::error::Error` is expected.
#[cfg(feature = "std")]
impl std::error::Error for CublasError {}
/// Creates a new cuBLASLt library handle by calling `cublasLtCreate`.
///
/// Release the handle with [`destroy_handle`] when it is no longer needed.
///
/// # Errors
///
/// Returns the non-success status reported by `cublasLtCreate`.
pub fn create_handle() -> Result<sys::cublasLtHandle_t, CublasError> {
    let mut slot = MaybeUninit::uninit();
    // SAFETY: `slot.as_mut_ptr()` is valid, writable storage for exactly one handle.
    unsafe { sys::cublasLtCreate(slot.as_mut_ptr()) }.result()?;
    // SAFETY: `cublasLtCreate` reported success, so it wrote a handle into `slot`.
    Ok(unsafe { slot.assume_init() })
}
/// Destroys a cuBLASLt handle by calling `cublasLtDestroy`.
///
/// # Safety
///
/// `handle` must be a live handle, for example one returned by [`create_handle`], and it must
/// not already have been destroyed. It must not be used again after this call.
///
/// # Errors
///
/// Returns the non-success status reported by `cublasLtDestroy`.
pub unsafe fn destroy_handle(handle: sys::cublasLtHandle_t) -> Result<(), CublasError> {
    let status = sys::cublasLtDestroy(handle);
    status.result()
}
/// Creates a cuBLASLt matrix layout descriptor via `cublasLtMatrixLayoutCreate`.
///
/// * `matrix_type` - element data type of the matrix.
/// * `rows`, `cols` - matrix dimensions.
/// * `ld` - leading dimension, forwarded unchanged to cuBLASLt.
///
/// Release the layout with [`destroy_matrix_layout`].
///
/// # Errors
///
/// Returns the non-success status reported by `cublasLtMatrixLayoutCreate`.
pub fn create_matrix_layout(
    matrix_type: sys::cudaDataType,
    rows: u64,
    cols: u64,
    ld: i64,
) -> Result<sys::cublasLtMatrixLayout_t, CublasError> {
    let mut layout = MaybeUninit::uninit();
    // SAFETY: `layout.as_mut_ptr()` is writable storage for one layout handle; every other
    // argument is passed by value.
    let status = unsafe {
        sys::cublasLtMatrixLayoutCreate(layout.as_mut_ptr(), matrix_type, rows, cols, ld)
    };
    status.result()?;
    // SAFETY: creation reported success, so cuBLASLt wrote a layout handle into `layout`.
    Ok(unsafe { layout.assume_init() })
}
/// Sets attribute `attr` on `matrix_layout` via `cublasLtMatrixLayoutSetAttribute`.
///
/// # Safety
///
/// * `matrix_layout` must be a live layout, for example one returned by
///   [`create_matrix_layout`].
/// * `buf` must point to `buf_size` readable bytes holding a value of the type cuBLASLt
///   expects for `attr`.
///
/// # Errors
///
/// Returns the non-success status reported by `cublasLtMatrixLayoutSetAttribute`.
pub unsafe fn set_matrix_layout_attribute(
    matrix_layout: sys::cublasLtMatrixLayout_t,
    attr: sys::cublasLtMatrixLayoutAttribute_t,
    buf: *const c_void,
    buf_size: usize,
) -> Result<(), CublasError> {
    let status = sys::cublasLtMatrixLayoutSetAttribute(matrix_layout, attr, buf, buf_size);
    status.result()
}
/// Destroys a matrix layout descriptor via `cublasLtMatrixLayoutDestroy`.
///
/// # Safety
///
/// `matrix_layout` must be a live layout, for example one returned by
/// [`create_matrix_layout`], and it must not already have been destroyed. It must not be used
/// again after this call.
///
/// # Errors
///
/// Returns the non-success status reported by `cublasLtMatrixLayoutDestroy`.
pub unsafe fn destroy_matrix_layout(
    matrix_layout: sys::cublasLtMatrixLayout_t,
) -> Result<(), CublasError> {
    let status = sys::cublasLtMatrixLayoutDestroy(matrix_layout);
    status.result()
}
/// Creates a cuBLASLt matmul operation descriptor via `cublasLtMatmulDescCreate`.
///
/// * `compute_type` - computation precision/mode, forwarded unchanged.
/// * `scale_type` - data type forwarded as cuBLASLt's `scaleType`. Per the cuBLASLt
///   documentation, this is the type of the `alpha`/`beta` scalars.
///
/// Release the descriptor with [`destroy_matmul_desc`].
///
/// # Errors
///
/// Returns the non-success status reported by `cublasLtMatmulDescCreate`.
pub fn create_matmul_desc(
    compute_type: sys::cublasComputeType_t,
    scale_type: sys::cudaDataType,
) -> Result<sys::cublasLtMatmulDesc_t, CublasError> {
    let mut desc = MaybeUninit::uninit();
    // SAFETY: `desc.as_mut_ptr()` is writable storage for one descriptor handle; the other
    // arguments are passed by value.
    let status =
        unsafe { sys::cublasLtMatmulDescCreate(desc.as_mut_ptr(), compute_type, scale_type) };
    status.result()?;
    // SAFETY: creation reported success, so cuBLASLt wrote a descriptor handle into `desc`.
    Ok(unsafe { desc.assume_init() })
}
/// Sets attribute `attr` on `matmul_desc` via `cublasLtMatmulDescSetAttribute`.
///
/// # Safety
///
/// * `matmul_desc` must be a live descriptor, for example one returned by
///   [`create_matmul_desc`].
/// * `buf` must point to `buf_size` readable bytes holding a value of the type cuBLASLt
///   expects for `attr`.
///
/// # Errors
///
/// Returns the non-success status reported by `cublasLtMatmulDescSetAttribute`.
pub unsafe fn set_matmul_desc_attribute(
    matmul_desc: sys::cublasLtMatmulDesc_t,
    attr: sys::cublasLtMatmulDescAttributes_t,
    buf: *const c_void,
    buf_size: usize,
) -> Result<(), CublasError> {
    let status = sys::cublasLtMatmulDescSetAttribute(matmul_desc, attr, buf, buf_size);
    status.result()
}
/// Destroys a matmul operation descriptor via `cublasLtMatmulDescDestroy`.
///
/// # Safety
///
/// `matmul_desc` must be a live descriptor, for example one returned by
/// [`create_matmul_desc`], and it must not already have been destroyed. It must not be used
/// again after this call.
///
/// # Errors
///
/// Returns the non-success status reported by `cublasLtMatmulDescDestroy`.
pub unsafe fn destroy_matmul_desc(
    matmul_desc: sys::cublasLtMatmulDesc_t,
) -> Result<(), CublasError> {
    let status = sys::cublasLtMatmulDescDestroy(matmul_desc);
    status.result()
}
/// Creates a matmul preference descriptor via `cublasLtMatmulPreferenceCreate`.
///
/// Release it with [`destroy_matmul_pref`].
///
/// # Errors
///
/// Returns the non-success status reported by `cublasLtMatmulPreferenceCreate`.
pub fn create_matmul_pref() -> Result<sys::cublasLtMatmulPreference_t, CublasError> {
    let mut pref = MaybeUninit::uninit();
    // SAFETY: `pref.as_mut_ptr()` is writable storage for one preference handle.
    unsafe { sys::cublasLtMatmulPreferenceCreate(pref.as_mut_ptr()) }.result()?;
    // SAFETY: creation reported success, so cuBLASLt wrote a preference handle into `pref`.
    Ok(unsafe { pref.assume_init() })
}
/// Sets attribute `attr` on `matmul_pref` via `cublasLtMatmulPreferenceSetAttribute`.
///
/// # Safety
///
/// * `matmul_pref` must be a live preference descriptor, for example one returned by
///   [`create_matmul_pref`].
/// * `buf` must point to `buf_size` readable bytes holding a value of the type cuBLASLt
///   expects for `attr`.
///
/// # Errors
///
/// Returns the non-success status reported by `cublasLtMatmulPreferenceSetAttribute`.
pub unsafe fn set_matmul_pref_attribute(
    matmul_pref: sys::cublasLtMatmulPreference_t,
    attr: sys::cublasLtMatmulPreferenceAttributes_t,
    buf: *const c_void,
    buf_size: usize,
) -> Result<(), CublasError> {
    let status = sys::cublasLtMatmulPreferenceSetAttribute(matmul_pref, attr, buf, buf_size);
    status.result()
}
/// Destroys a matmul preference descriptor via `cublasLtMatmulPreferenceDestroy`.
///
/// # Safety
///
/// `matmul_pref` must be a live preference descriptor, for example one returned by
/// [`create_matmul_pref`], and it must not already have been destroyed. It must not be used
/// again after this call.
///
/// # Errors
///
/// Returns the non-success status reported by `cublasLtMatmulPreferenceDestroy`.
pub unsafe fn destroy_matmul_pref(
    matmul_pref: sys::cublasLtMatmulPreference_t,
) -> Result<(), CublasError> {
    let status = sys::cublasLtMatmulPreferenceDestroy(matmul_pref);
    status.result()
}
/// Asks cuBLASLt for a single candidate algorithm for the described matmul.
///
/// Calls `cublasLtMatmulAlgoGetHeuristic` with a requested result count of 1 and returns
/// that one result.
///
/// # Safety
///
/// `handle`, `matmul_desc`, the four layouts and `matmul_pref` must all be live objects,
/// created and not yet destroyed.
///
/// # Errors
///
/// * The non-success status returned by `cublasLtMatmulAlgoGetHeuristic` itself.
/// * `CUBLAS_STATUS_NOT_SUPPORTED` when the call succeeds but reports zero results.
/// * The non-success `state` recorded in the returned heuristic result.
pub unsafe fn get_matmul_algo_heuristic(
    handle: sys::cublasLtHandle_t,
    matmul_desc: sys::cublasLtMatmulDesc_t,
    a_layout: sys::cublasLtMatrixLayout_t,
    b_layout: sys::cublasLtMatrixLayout_t,
    c_layout: sys::cublasLtMatrixLayout_t,
    d_layout: sys::cublasLtMatrixLayout_t,
    matmul_pref: sys::cublasLtMatmulPreference_t,
) -> Result<sys::cublasLtMatmulHeuristicResult_t, CublasError> {
    let mut candidate = MaybeUninit::uninit();
    let mut returned = 0;
    let status = sys::cublasLtMatmulAlgoGetHeuristic(
        handle,
        matmul_desc,
        a_layout,
        b_layout,
        c_layout,
        d_layout,
        matmul_pref,
        // `candidate` has room for exactly one result.
        1,
        candidate.as_mut_ptr(),
        &mut returned,
    );
    status.result()?;

    // The call can succeed yet return nothing, in which case `candidate` is still
    // uninitialised and must not be read.
    if returned == 0 {
        return Err(CublasError(sys::cublasStatus_t::CUBLAS_STATUS_NOT_SUPPORTED));
    }

    // A non-zero count means cuBLASLt filled the single slot we provided.
    let candidate = candidate.assume_init();
    // Each result carries its own status; only hand back entries that report success.
    candidate.state.result().map(|()| candidate)
}
/// Launches a cuBLASLt matrix multiplication on `stream` via `cublasLtMatmul`.
///
/// Per the cuBLASLt documentation, this computes `D = alpha * (A * B) + beta * C`. Any
/// transposes or epilogues are configured on `matmul_desc`. Unlike `cublasLtMatmul`, this
/// wrapper takes `alpha` and `beta` up front; both are forwarded to their proper positions.
///
/// # Safety
///
/// * `handle`, `matmul_desc` and the four layouts must be live objects.
/// * `a`, `b`, `c` and `d` must point to memory matching their respective layouts, and `d`
///   must be writable.
/// * `alpha` and `beta` must point to scalars of the descriptor's scale type.
/// * `algo` must point to a valid algorithm, for example one taken from a result of
///   [`get_matmul_algo_heuristic`]. NOTE(review): cuBLASLt documents that null triggers an
///   implicit heuristic query; confirm before relying on it.
/// * `workspace` must point to at least `workspace_size` bytes.
/// * NOTE(review): the work presumably runs asynchronously on `stream`, so all referenced
///   memory should stay valid until it completes. Confirm against the cuBLASLt docs.
///
/// # Errors
///
/// Returns the non-success status reported by `cublasLtMatmul`.
// The parameter list mirrors the sixteen-argument cuBLASLt entry point one-to-one.
#[allow(clippy::too_many_arguments)]
pub unsafe fn matmul(
    handle: sys::cublasLtHandle_t,
    matmul_desc: sys::cublasLtMatmulDesc_t,
    alpha: *const c_void,
    beta: *const c_void,
    a: *const c_void,
    a_layout: sys::cublasLtMatrixLayout_t,
    b: *const c_void,
    b_layout: sys::cublasLtMatrixLayout_t,
    c: *const c_void,
    c_layout: sys::cublasLtMatrixLayout_t,
    d: *mut c_void,
    d_layout: sys::cublasLtMatrixLayout_t,
    algo: *const cublasLtMatmulAlgo_t,
    workspace: *mut c_void,
    workspace_size: usize,
    stream: sys::cudaStream_t,
) -> Result<(), CublasError> {
    let status = sys::cublasLtMatmul(
        handle,
        matmul_desc,
        alpha,
        a,
        a_layout,
        b,
        b_layout,
        beta,
        c,
        c_layout,
        d,
        d_layout,
        algo,
        workspace,
        workspace_size,
        stream,
    );
    status.result()
}