use super::sys;
use core::mem::MaybeUninit;
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct CutensorError(pub sys::cutensorStatus_t);
impl sys::cutensorStatus_t {
fn result(self) -> Result<(), CutensorError> {
match self {
sys::cutensorStatus_t::CUTENSOR_STATUS_SUCCESS => Ok(()),
_ => Err(CutensorError(self)),
}
}
}
#[cfg(feature = "std")]
impl std::fmt::Display for CutensorError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:?}")
}
}
#[cfg(feature = "std")]
impl std::error::Error for CutensorError {}
pub fn create_handle() -> Result<sys::cutensorHandle_t, CutensorError> {
let mut handle = MaybeUninit::uninit();
unsafe {
sys::cutensorCreate(handle.as_mut_ptr()).result()?;
Ok(handle.assume_init())
}
}
pub unsafe fn destroy_handle(handle: sys::cutensorHandle_t) -> Result<(), CutensorError> {
sys::cutensorDestroy(handle).result()
}
pub fn get_version() -> (usize, usize, usize) {
let version = unsafe { sys::cutensorGetVersion() };
let major = version / 10000;
let minor = (version % 10000) / 100;
let patch = version % 100;
(major, minor, patch)
}
pub unsafe fn create_tensor_descriptor(
handle: sys::cutensorHandle_t,
num_modes: u32,
extent: *const i64,
stride: *const i64,
data_type: sys::cudaDataType_t,
alignment_requirement: u32,
) -> Result<sys::cutensorTensorDescriptor_t, CutensorError> {
let mut desc = MaybeUninit::uninit();
sys::cutensorCreateTensorDescriptor(
handle,
desc.as_mut_ptr(),
num_modes,
extent,
stride,
data_type,
alignment_requirement,
)
.result()?;
Ok(desc.assume_init())
}
pub unsafe fn destroy_tensor_descriptor(
desc: sys::cutensorTensorDescriptor_t,
) -> Result<(), CutensorError> {
sys::cutensorDestroyTensorDescriptor(desc).result()
}
pub unsafe fn create_contraction(
handle: sys::cutensorHandle_t,
desc_a: sys::cutensorTensorDescriptor_t,
mode_a: *const i32,
op_a: sys::cutensorOperator_t,
desc_b: sys::cutensorTensorDescriptor_t,
mode_b: *const i32,
op_b: sys::cutensorOperator_t,
desc_c: sys::cutensorTensorDescriptor_t,
mode_c: *const i32,
op_c: sys::cutensorOperator_t,
desc_d: sys::cutensorTensorDescriptor_t,
mode_d: *const i32,
compute_desc: sys::cutensorComputeDescriptor_t,
) -> Result<sys::cutensorOperationDescriptor_t, CutensorError> {
let mut desc = MaybeUninit::uninit();
sys::cutensorCreateContraction(
handle,
desc.as_mut_ptr(),
desc_a,
mode_a,
op_a,
desc_b,
mode_b,
op_b,
desc_c,
mode_c,
op_c,
desc_d,
mode_d,
compute_desc,
)
.result()?;
Ok(desc.assume_init())
}
pub unsafe fn create_reduction(
handle: sys::cutensorHandle_t,
desc_a: sys::cutensorTensorDescriptor_t,
mode_a: *const i32,
op_a: sys::cutensorOperator_t,
desc_c: sys::cutensorTensorDescriptor_t,
mode_c: *const i32,
op_c: sys::cutensorOperator_t,
desc_d: sys::cutensorTensorDescriptor_t,
mode_d: *const i32,
op_reduce: sys::cutensorOperator_t,
compute_desc: sys::cutensorComputeDescriptor_t,
) -> Result<sys::cutensorOperationDescriptor_t, CutensorError> {
let mut desc = MaybeUninit::uninit();
sys::cutensorCreateReduction(
handle,
desc.as_mut_ptr(),
desc_a,
mode_a,
op_a,
desc_c,
mode_c,
op_c,
desc_d,
mode_d,
op_reduce,
compute_desc,
)
.result()?;
Ok(desc.assume_init())
}
pub unsafe fn destroy_operation_descriptor(
desc: sys::cutensorOperationDescriptor_t,
) -> Result<(), CutensorError> {
sys::cutensorDestroyOperationDescriptor(desc).result()
}
pub unsafe fn create_plan_preference(
handle: sys::cutensorHandle_t,
algo: sys::cutensorAlgo_t,
jit_mode: sys::cutensorJitMode_t,
) -> Result<sys::cutensorPlanPreference_t, CutensorError> {
let mut pref = MaybeUninit::uninit();
sys::cutensorCreatePlanPreference(handle, pref.as_mut_ptr(), algo, jit_mode).result()?;
Ok(pref.assume_init())
}
pub unsafe fn destroy_plan_preference(
pref: sys::cutensorPlanPreference_t,
) -> Result<(), CutensorError> {
sys::cutensorDestroyPlanPreference(pref).result()
}
pub unsafe fn estimate_workspace_size(
handle: sys::cutensorHandle_t,
desc: sys::cutensorOperationDescriptor_t,
pref: sys::cutensorPlanPreference_t,
workspace_pref: sys::cutensorWorksizePreference_t,
) -> Result<u64, CutensorError> {
let mut workspace_size = MaybeUninit::uninit();
sys::cutensorEstimateWorkspaceSize(
handle,
desc,
pref,
workspace_pref,
workspace_size.as_mut_ptr(),
)
.result()?;
Ok(workspace_size.assume_init())
}
pub unsafe fn create_plan(
handle: sys::cutensorHandle_t,
desc: sys::cutensorOperationDescriptor_t,
pref: sys::cutensorPlanPreference_t,
workspace_size: u64,
) -> Result<sys::cutensorPlan_t, CutensorError> {
let mut plan = MaybeUninit::uninit();
sys::cutensorCreatePlan(handle, plan.as_mut_ptr(), desc, pref, workspace_size).result()?;
Ok(plan.assume_init())
}
pub unsafe fn destroy_plan(plan: sys::cutensorPlan_t) -> Result<(), CutensorError> {
sys::cutensorDestroyPlan(plan).result()
}
pub unsafe fn contract(
handle: sys::cutensorHandle_t,
plan: sys::cutensorPlan_t,
alpha: *const core::ffi::c_void,
a: *const core::ffi::c_void,
b: *const core::ffi::c_void,
beta: *const core::ffi::c_void,
c: *const core::ffi::c_void,
d: *mut core::ffi::c_void,
workspace: *mut core::ffi::c_void,
workspace_size: u64,
stream: sys::cudaStream_t,
) -> Result<(), CutensorError> {
sys::cutensorContract(
handle,
plan,
alpha,
a,
b,
beta,
c,
d,
workspace,
workspace_size,
stream,
)
.result()
}
pub unsafe fn reduce(
handle: sys::cutensorHandle_t,
plan: sys::cutensorPlan_t,
alpha: *const core::ffi::c_void,
a: *const core::ffi::c_void,
beta: *const core::ffi::c_void,
c: *const core::ffi::c_void,
d: *mut core::ffi::c_void,
workspace: *mut core::ffi::c_void,
workspace_size: u64,
stream: sys::cudaStream_t,
) -> Result<(), CutensorError> {
sys::cutensorReduce(
handle,
plan,
alpha,
a,
beta,
c,
d,
workspace,
workspace_size,
stream,
)
.result()
}