use std::ffi::c_void;
use std::ptr;
#[cfg(test)]
use cudarc::cublaslt::sys::cublasLtMatmulDescOpaque_t;
use cudarc::cublaslt::sys::{
cublasLtMatmulDescAttributes_t, cublasLtMatmulDesc_t, cublasLtMatmulPreferenceAttributes_t,
cublasLtMatmulPreferenceOpaque_t, cublasLtMatmulPreference_t, cublasStatus_t,
};
/// Converts a cuBLAS status code into a `Result`.
///
/// `op` is a label for the call that produced `status`; it is used only to
/// build the error message.
///
/// # Errors
///
/// Any status other than `CUBLAS_STATUS_SUCCESS` yields an `Err` containing
/// `op`, a colon, and the `Debug` rendering of the status.
pub fn check(status: cublasStatus_t, op: &str) -> Result<(), String> {
    if let cublasStatus_t::CUBLAS_STATUS_SUCCESS = status {
        return Ok(());
    }
    Err(format!("{op}: {status:?}"))
}
/// Owning wrapper around a raw cuBLASLt matmul preference handle.
///
/// The handle is created by [`Preference::new`] and destroyed when the
/// wrapper is dropped.
pub struct Preference {
    /// Raw handle that is passed to the cuBLASLt FFI calls.
    pub raw: cublasLtMatmulPreference_t,
}

// SAFETY: NOTE(review): these impls assert that the handle may be moved to and
// shared between threads. Nothing in this file enforces that, and `set_u64`
// modifies the underlying object through `&self` — confirm against the
// cuBLASLt thread-safety guarantees and against how callers share this type.
unsafe impl Send for Preference {}
unsafe impl Sync for Preference {}
impl Preference {
    /// Creates a new preference handle with `cublasLtMatmulPreferenceCreate`.
    ///
    /// # Errors
    ///
    /// Returns the message built by [`check`] when the call reports anything
    /// other than success.
    pub fn new() -> Result<Self, String> {
        // Start from a null handle; the create call fills it in.
        let mut handle = ptr::null_mut::<cublasLtMatmulPreferenceOpaque_t>();
        // SAFETY: `&mut handle` points at a live local for the whole call.
        let created = unsafe {
            cudarc::cublaslt::sys::cublasLtMatmulPreferenceCreate(&mut handle)
        };
        check(created, "cublasLtMatmulPreferenceCreate").map(|()| Self { raw: handle })
    }

    /// Sets a `u64`-valued preference attribute.
    ///
    /// The address of `value` and the size of a `u64` are handed to
    /// `cublasLtMatmulPreferenceSetAttribute` as the attribute buffer.
    ///
    /// # Errors
    ///
    /// Returns the message built by [`check`] when the call reports anything
    /// other than success.
    pub fn set_u64(
        &self,
        attr: cublasLtMatmulPreferenceAttributes_t,
        value: u64,
    ) -> Result<(), String> {
        let buf = (&value as *const u64).cast::<c_void>();
        let buf_len = std::mem::size_of_val(&value);
        // SAFETY: `buf` points at the local `value`, which is alive for the
        // call, and `buf_len` is exactly its size.
        // NOTE(review): assumes `self.raw` is still the handle created in
        // `new`; the field is `pub`, so confirm no caller replaces it.
        let outcome = unsafe {
            cudarc::cublaslt::sys::cublasLtMatmulPreferenceSetAttribute(
                self.raw, attr, buf, buf_len,
            )
        };
        check(outcome, "cublasLtMatmulPreferenceSetAttribute")
    }
}
impl Drop for Preference {
    /// Destroys the wrapped handle if it is non-null, then nulls the field.
    fn drop(&mut self) {
        if self.raw.is_null() {
            return;
        }
        // The destroy status is discarded on purpose: `drop` has no way to
        // report a failure to the caller.
        // SAFETY: `self.raw` is non-null (checked above).
        // NOTE(review): assumes it is still the handle created in `new`; the
        // field is `pub`, so confirm no caller replaces or frees it.
        let _ = unsafe { cudarc::cublaslt::sys::cublasLtMatmulPreferenceDestroy(self.raw) };
        self.raw = ptr::null_mut();
    }
}
/// Sets a pointer-valued attribute on a cuBLASLt matmul descriptor.
///
/// The attribute buffer passed to `cublasLtMatmulDescSetAttribute` is the
/// address of the `ptr` argument itself, with a length of one pointer — that
/// is, the pointer value is what gets handed over, not the memory behind it.
///
/// # Safety
///
/// `desc` is forwarded to the FFI call unchecked. Presumably it must be a
/// valid, live descriptor handle and `ptr` must meet whatever `attr`
/// requires — confirm against the cuBLASLt documentation.
///
/// # Errors
///
/// Returns the message built by [`check`] when the call reports anything
/// other than success.
pub unsafe fn set_desc_pointer_attr(
    desc: cublasLtMatmulDesc_t,
    attr: cublasLtMatmulDescAttributes_t,
    ptr: *const c_void,
) -> Result<(), String> {
    let buf = (&ptr as *const *const c_void).cast::<c_void>();
    let buf_len = std::mem::size_of_val(&ptr);
    // SAFETY: `buf` points at the local `ptr`, alive for the call, and
    // `buf_len` is exactly its size; validity of `desc` is the caller's
    // obligation per this function's contract.
    let outcome = unsafe {
        cudarc::cublaslt::sys::cublasLtMatmulDescSetAttribute(desc, attr, buf, buf_len)
    };
    check(outcome, "cublasLtMatmulDescSetAttribute(pointer)")
}
/// Sets an `i32`-valued attribute on a cuBLASLt matmul descriptor.
///
/// The address of `value` and the size of an `i32` are handed to
/// `cublasLtMatmulDescSetAttribute` as the attribute buffer.
///
/// # Safety
///
/// `desc` is forwarded to the FFI call unchecked. Presumably it must be a
/// valid, live descriptor handle and `attr` must be an attribute that takes
/// an `i32` — confirm against the cuBLASLt documentation.
///
/// # Errors
///
/// Returns the message built by [`check`] when the call reports anything
/// other than success.
pub unsafe fn set_desc_i32_attr(
    desc: cublasLtMatmulDesc_t,
    attr: cublasLtMatmulDescAttributes_t,
    value: i32,
) -> Result<(), String> {
    let buf = (&value as *const i32).cast::<c_void>();
    let buf_len = std::mem::size_of_val(&value);
    // SAFETY: `buf` points at the local `value`, alive for the call, and
    // `buf_len` is exactly its size; validity of `desc` is the caller's
    // obligation per this function's contract.
    let outcome = unsafe {
        cudarc::cublaslt::sys::cublasLtMatmulDescSetAttribute(desc, attr, buf, buf_len)
    };
    check(outcome, "cublasLtMatmulDescSetAttribute(i32)")
}
/// Test-only helper: heap-allocates a zero-filled descriptor and returns it
/// as a raw handle.
///
/// The allocation is not released automatically; hand the pointer to
/// [`drop_mock_desc`] to free it.
#[cfg(test)]
pub fn mock_desc_handle() -> cublasLtMatmulDesc_t {
    // SAFETY: NOTE(review): assumes an all-zero bit pattern is a valid
    // `cublasLtMatmulDescOpaque_t` — confirm against the generated bindings.
    let blank: cublasLtMatmulDescOpaque_t = unsafe { std::mem::zeroed() };
    Box::into_raw(Box::new(blank))
}
/// Test-only helper: frees a descriptor allocated through `Box`.
///
/// A null `desc` is ignored.
///
/// # Safety
///
/// `desc` must be null or a pointer that came from `Box::into_raw` (as
/// [`mock_desc_handle`] produces) and has not already been freed.
#[cfg(test)]
pub unsafe fn drop_mock_desc(desc: cublasLtMatmulDesc_t) {
    if desc.is_null() {
        return;
    }
    // SAFETY: non-null was checked above; ownership of the allocation is
    // guaranteed by the caller per this function's contract.
    drop(unsafe { Box::from_raw(desc) });
}