use core::ffi::c_void;
use core::mem::MaybeUninit;
use cudarc::cutensor::sys as ct_sys;
/// Error produced when a cuTENSOR call returns any status other than
/// `CUTENSOR_STATUS_SUCCESS`. The raw status code is kept unchanged.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct CutensorError(pub ct_sys::cutensorStatus_t);

impl std::fmt::Display for CutensorError {
    /// Formats the wrapped status through its `Debug` representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let status = self.0;
        f.write_fmt(format_args!("cutensor status: {:?}", status))
    }
}

/// Lets `CutensorError` be used wherever a standard error type is expected.
impl std::error::Error for CutensorError {}
/// Turns a raw cuTENSOR status into a `Result`.
///
/// `CUTENSOR_STATUS_SUCCESS` becomes `Ok(())`. Every other status is returned
/// unchanged inside a `CutensorError`.
#[inline]
fn check(status: ct_sys::cutensorStatus_t) -> Result<(), CutensorError> {
    if matches!(status, ct_sys::cutensorStatus_t::CUTENSOR_STATUS_SUCCESS) {
        Ok(())
    } else {
        Err(CutensorError(status))
    }
}
use std::sync::OnceLock;
/// Process-wide table of cuTENSOR compute descriptors, resolved once from the
/// dynamically loaded cuTENSOR library.
struct ComputeDescriptors {
    // "Minimum precision" real descriptors.
    r_min_16f: ct_sys::cutensorComputeDescriptor_t,
    r_min_16bf: ct_sys::cutensorComputeDescriptor_t,
    r_min_tf32: ct_sys::cutensorComputeDescriptor_t,
    r_min_32f: ct_sys::cutensorComputeDescriptor_t,
    r_min_64f: ct_sys::cutensorComputeDescriptor_t,
    // Plain real descriptors.
    r_32f: ct_sys::cutensorComputeDescriptor_t,
    r_64f: ct_sys::cutensorComputeDescriptor_t,
    // Presumably the single-precision complex descriptor (name-based; confirm).
    c_32f: ct_sys::cutensorComputeDescriptor_t,
}

// SAFETY: the fields are opaque descriptor handles. They are written exactly once,
// inside the `OnceLock` initializer, and afterwards are only copied out by value.
// The library that owns them is never unloaded, because its handle is leaked after
// a successful load. This assumes cuTENSOR never mutates these descriptors (TODO confirm).
unsafe impl Send for ComputeDescriptors {}
// SAFETY: see the `Send` impl above; shared access only ever reads the handles.
unsafe impl Sync for ComputeDescriptors {}

/// Lazily populated descriptor table, filled on first access.
static DESCRIPTORS: OnceLock<ComputeDescriptors> = OnceLock::new();
/// Opens the cuTENSOR shared library and reads its exported compute-descriptor globals.
///
/// Candidate library names are tried in order. For each library that opens, every
/// descriptor is looked up first under its legacy `CUTENSOR_R_*` / `CUTENSOR_C_*` name.
/// The minimum-precision descriptors are then also looked up under the cuTENSOR 2.x
/// `CUTENSOR_COMPUTE_DESC_*` name.
///
/// The first library that exports both the 32F and 64F descriptors wins. Any optional
/// descriptor it lacks falls back to the 32F descriptor, or to the 64F one for `r_64f`.
///
/// The winning library handle is deliberately leaked with `mem::forget`, so it stays
/// loaded for the rest of the process.
///
/// # Panics
///
/// Panics when no candidate library can be opened, or when none of them exports the
/// required 32F and 64F descriptors.
fn load_descriptors() -> ComputeDescriptors {
    const CANDIDATES: [&str; 4] = [
        "libcutensor.so.2",
        "libcutensor.so.1",
        "libcutensor.so",
        "cutensor.dll",
    ];
    for &cand in CANDIDATES.iter() {
        // SAFETY: loading runs the library's initializers; cuTENSOR is assumed to be
        // safe to load, as it is for any other consumer of the library.
        let Ok(lib) = (unsafe { libloading::Library::new(cand) }) else {
            continue;
        };
        if let Some(table) = read_descriptor_table(&lib) {
            // Never unload: the descriptor handles presumably refer to library-owned memory.
            std::mem::forget(lib);
            return table;
        }
        // `lib` is dropped (closed) here because it lacked the mandatory symbols.
    }
    panic!(
        "ContextPoisoned: failed to dlopen libcutensor.so / locate \
         CUTENSOR_R_MIN_32F or CUTENSOR_COMPUTE_DESC_32F (compute descriptor symbol). \
         cuTENSOR must be installed on the host for cutensor-feature builds."
    );
}

/// Resolves the full descriptor table from `lib`.
///
/// Returns `None` when the mandatory 32F or 64F descriptor is missing.
fn read_descriptor_table(lib: &libloading::Library) -> Option<ComputeDescriptors> {
    let r_min_32f = read_descriptor_symbol(
        lib,
        &["CUTENSOR_R_MIN_32F\0", "CUTENSOR_COMPUTE_DESC_32F\0"],
    )?;
    let r_min_64f = read_descriptor_symbol(
        lib,
        &["CUTENSOR_R_MIN_64F\0", "CUTENSOR_COMPUTE_DESC_64F\0"],
    )?;
    let r_min_16f = read_descriptor_symbol(
        lib,
        &["CUTENSOR_R_MIN_16F\0", "CUTENSOR_COMPUTE_DESC_16F\0"],
    );
    let r_min_16bf = read_descriptor_symbol(
        lib,
        &["CUTENSOR_R_MIN_16BF\0", "CUTENSOR_COMPUTE_DESC_16BF\0"],
    );
    let r_min_tf32 = read_descriptor_symbol(
        lib,
        &["CUTENSOR_R_MIN_TF32\0", "CUTENSOR_COMPUTE_DESC_TF32\0"],
    );
    let r_32f = read_descriptor_symbol(lib, &["CUTENSOR_R_32F\0"]);
    let r_64f = read_descriptor_symbol(lib, &["CUTENSOR_R_64F\0"]);
    let c_32f = read_descriptor_symbol(lib, &["CUTENSOR_C_32F\0"]);
    Some(ComputeDescriptors {
        r_min_32f,
        r_min_64f,
        r_min_16f: r_min_16f.unwrap_or(r_min_32f),
        r_min_16bf: r_min_16bf.unwrap_or(r_min_32f),
        r_min_tf32: r_min_tf32.unwrap_or(r_min_32f),
        r_32f: r_32f.unwrap_or(r_min_32f),
        r_64f: r_64f.unwrap_or(r_min_64f),
        c_32f: c_32f.unwrap_or(r_min_32f),
    })
}

/// Reads the value of the first symbol in `names` that `lib` exports at a non-null address.
///
/// Each name is a NUL-terminated symbol name handed to `Library::get` as bytes.
fn read_descriptor_symbol(
    lib: &libloading::Library,
    names: &[&str],
) -> Option<ct_sys::cutensorComputeDescriptor_t> {
    names.iter().find_map(|name| {
        // SAFETY: each symbol is assumed to be a global of type
        // `cutensorComputeDescriptor_t` (TODO confirm per cuTENSOR version). Its
        // address is checked for null before being read, so a null symbol is
        // treated as "not exported" instead of being dereferenced.
        unsafe {
            let sym: libloading::Symbol<*const ct_sys::cutensorComputeDescriptor_t> =
                lib.get(name.as_bytes()).ok()?;
            let addr: *const ct_sys::cutensorComputeDescriptor_t = *sym;
            if addr.is_null() {
                None
            } else {
                Some(*addr)
            }
        }
    })
}
#[inline]
fn descriptors() -> &'static ComputeDescriptors {
DESCRIPTORS.get_or_init(load_descriptors)
}
/// Single-precision minimum compute descriptor. Loads cuTENSOR on first call.
pub fn r_min_32f() -> ct_sys::cutensorComputeDescriptor_t {
    let table = descriptors();
    table.r_min_32f
}

/// Double-precision minimum compute descriptor. Loads cuTENSOR on first call.
pub fn r_min_64f() -> ct_sys::cutensorComputeDescriptor_t {
    let table = descriptors();
    table.r_min_64f
}

/// Half-precision minimum compute descriptor.
///
/// May alias the 32F descriptor when the library provides no dedicated one.
pub fn r_min_16f() -> ct_sys::cutensorComputeDescriptor_t {
    let table = descriptors();
    table.r_min_16f
}

/// bfloat16 minimum compute descriptor.
///
/// May alias the 32F descriptor when the library provides no dedicated one.
pub fn r_min_16bf() -> ct_sys::cutensorComputeDescriptor_t {
    let table = descriptors();
    table.r_min_16bf
}

/// TF32 minimum compute descriptor.
///
/// May alias the 32F descriptor when the library provides no dedicated one.
pub fn r_min_tf32() -> ct_sys::cutensorComputeDescriptor_t {
    let table = descriptors();
    table.r_min_tf32
}

/// Real 32-bit compute descriptor.
///
/// May alias the 32F minimum descriptor when the library provides no dedicated one.
pub fn r_32f() -> ct_sys::cutensorComputeDescriptor_t {
    let table = descriptors();
    table.r_32f
}

/// Real 64-bit compute descriptor.
///
/// May alias the 64F minimum descriptor when the library provides no dedicated one.
pub fn r_64f() -> ct_sys::cutensorComputeDescriptor_t {
    let table = descriptors();
    table.r_64f
}

/// `c_32f` compute descriptor (presumably complex single precision).
///
/// May alias the 32F minimum descriptor when the library provides no dedicated one.
pub fn c_32f() -> ct_sys::cutensorComputeDescriptor_t {
    let table = descriptors();
    table.c_32f
}
/// Executes the reduction described by `plan` through `cutensorReduce`.
///
/// Every argument is forwarded to cuTENSOR unchanged.
///
/// # Errors
///
/// Returns [`CutensorError`] with the reported status when cuTENSOR does not succeed.
///
/// # Safety
///
/// Nothing is validated here. The caller must supply:
/// - a live `handle` and `plan`;
/// - scalar and operand pointers matching what cuTENSOR expects for this plan;
/// - a `workspace` of at least `workspace_size` bytes;
/// - a valid `stream`.
pub unsafe fn reduce(
    handle: ct_sys::cutensorHandle_t,
    plan: ct_sys::cutensorPlan_t,
    alpha: *const c_void,
    a: *const c_void,
    beta: *const c_void,
    c: *const c_void,
    d: *mut c_void,
    workspace: *mut c_void,
    workspace_size: u64,
    stream: ct_sys::cudaStream_t,
) -> Result<(), CutensorError> {
    // SAFETY: the caller upholds the contract documented above.
    let status = unsafe {
        ct_sys::cutensorReduce(
            handle,
            plan,
            alpha,
            a,
            beta,
            c,
            d,
            workspace,
            workspace_size,
            stream,
        )
    };
    check(status)
}
/// Creates an elementwise-binary operation descriptor via `cutensorCreateElementwiseBinary`.
///
/// On success, returns the descriptor that cuTENSOR wrote.
///
/// # Errors
///
/// Returns [`CutensorError`] with the reported status when cuTENSOR does not succeed.
/// No descriptor is returned in that case.
///
/// # Safety
///
/// The caller must pass a live `handle` and valid tensor descriptors. Each mode pointer
/// must reference as many `i32` mode labels as cuTENSOR expects for its descriptor.
pub unsafe fn create_elementwise_binary(
    handle: ct_sys::cutensorHandle_t,
    desc_a: ct_sys::cutensorTensorDescriptor_t,
    mode_a: *const i32,
    op_a: ct_sys::cutensorOperator_t,
    desc_c: ct_sys::cutensorTensorDescriptor_t,
    mode_c: *const i32,
    op_c: ct_sys::cutensorOperator_t,
    desc_d: ct_sys::cutensorTensorDescriptor_t,
    mode_d: *const i32,
    op_ac: ct_sys::cutensorOperator_t,
    desc_compute: ct_sys::cutensorComputeDescriptor_t,
) -> Result<ct_sys::cutensorOperationDescriptor_t, CutensorError> {
    let mut op_desc = MaybeUninit::<ct_sys::cutensorOperationDescriptor_t>::uninit();
    // SAFETY: `op_desc` is a valid out-slot; the remaining arguments follow the caller contract.
    let status = unsafe {
        ct_sys::cutensorCreateElementwiseBinary(
            handle,
            op_desc.as_mut_ptr(),
            desc_a,
            mode_a,
            op_a,
            desc_c,
            mode_c,
            op_c,
            desc_d,
            mode_d,
            op_ac,
            desc_compute,
        )
    };
    check(status)?;
    // SAFETY: a success status means cuTENSOR initialized the out-slot.
    Ok(unsafe { op_desc.assume_init() })
}
/// Runs an elementwise-binary plan through `cutensorElementwiseBinaryExecute`.
///
/// # Errors
///
/// Returns [`CutensorError`] with the reported status when cuTENSOR does not succeed.
///
/// # Safety
///
/// The caller must pass a live `handle` and `plan`, scalar and operand pointers matching
/// what cuTENSOR expects for this plan, and a valid `stream`.
pub unsafe fn elementwise_binary_execute(
    handle: ct_sys::cutensorHandle_t,
    plan: ct_sys::cutensorPlan_t,
    alpha: *const c_void,
    a: *const c_void,
    gamma: *const c_void,
    c: *const c_void,
    d: *mut c_void,
    stream: ct_sys::cudaStream_t,
) -> Result<(), CutensorError> {
    // SAFETY: arguments are forwarded verbatim under the caller contract above.
    let status = unsafe {
        ct_sys::cutensorElementwiseBinaryExecute(handle, plan, alpha, a, gamma, c, d, stream)
    };
    check(status)
}
/// Creates an elementwise-trinary operation descriptor via `cutensorCreateElementwiseTrinary`.
///
/// On success, returns the descriptor that cuTENSOR wrote.
///
/// # Errors
///
/// Returns [`CutensorError`] with the reported status when cuTENSOR does not succeed.
/// No descriptor is returned in that case.
///
/// # Safety
///
/// The caller must pass a live `handle` and valid tensor descriptors. Each mode pointer
/// must reference as many `i32` mode labels as cuTENSOR expects for its descriptor.
pub unsafe fn create_elementwise_trinary(
    handle: ct_sys::cutensorHandle_t,
    desc_a: ct_sys::cutensorTensorDescriptor_t,
    mode_a: *const i32,
    op_a: ct_sys::cutensorOperator_t,
    desc_b: ct_sys::cutensorTensorDescriptor_t,
    mode_b: *const i32,
    op_b: ct_sys::cutensorOperator_t,
    desc_c: ct_sys::cutensorTensorDescriptor_t,
    mode_c: *const i32,
    op_c: ct_sys::cutensorOperator_t,
    desc_d: ct_sys::cutensorTensorDescriptor_t,
    mode_d: *const i32,
    op_ab: ct_sys::cutensorOperator_t,
    op_abc: ct_sys::cutensorOperator_t,
    desc_compute: ct_sys::cutensorComputeDescriptor_t,
) -> Result<ct_sys::cutensorOperationDescriptor_t, CutensorError> {
    let mut op_desc = MaybeUninit::<ct_sys::cutensorOperationDescriptor_t>::uninit();
    // SAFETY: `op_desc` is a valid out-slot; the remaining arguments follow the caller contract.
    let status = unsafe {
        ct_sys::cutensorCreateElementwiseTrinary(
            handle,
            op_desc.as_mut_ptr(),
            desc_a,
            mode_a,
            op_a,
            desc_b,
            mode_b,
            op_b,
            desc_c,
            mode_c,
            op_c,
            desc_d,
            mode_d,
            op_ab,
            op_abc,
            desc_compute,
        )
    };
    check(status)?;
    // SAFETY: a success status means cuTENSOR initialized the out-slot.
    Ok(unsafe { op_desc.assume_init() })
}
/// Runs an elementwise-trinary plan through `cutensorElementwiseTrinaryExecute`.
///
/// # Errors
///
/// Returns [`CutensorError`] with the reported status when cuTENSOR does not succeed.
///
/// # Safety
///
/// The caller must pass a live `handle` and `plan`, scalar and operand pointers matching
/// what cuTENSOR expects for this plan, and a valid `stream`.
pub unsafe fn elementwise_trinary_execute(
    handle: ct_sys::cutensorHandle_t,
    plan: ct_sys::cutensorPlan_t,
    alpha: *const c_void,
    a: *const c_void,
    beta: *const c_void,
    b: *const c_void,
    gamma: *const c_void,
    c: *const c_void,
    d: *mut c_void,
    stream: ct_sys::cudaStream_t,
) -> Result<(), CutensorError> {
    // SAFETY: arguments are forwarded verbatim under the caller contract above.
    let status = unsafe {
        ct_sys::cutensorElementwiseTrinaryExecute(
            handle, plan, alpha, a, beta, b, gamma, c, d, stream,
        )
    };
    check(status)
}
/// Creates a permutation operation descriptor via `cutensorCreatePermutation`.
///
/// On success, returns the descriptor that cuTENSOR wrote.
///
/// # Errors
///
/// Returns [`CutensorError`] with the reported status when cuTENSOR does not succeed.
/// No descriptor is returned in that case.
///
/// # Safety
///
/// The caller must pass a live `handle` and valid tensor descriptors. Each mode pointer
/// must reference as many `i32` mode labels as cuTENSOR expects for its descriptor.
pub unsafe fn create_permutation(
    handle: ct_sys::cutensorHandle_t,
    desc_a: ct_sys::cutensorTensorDescriptor_t,
    mode_a: *const i32,
    op_a: ct_sys::cutensorOperator_t,
    desc_b: ct_sys::cutensorTensorDescriptor_t,
    mode_b: *const i32,
    desc_compute: ct_sys::cutensorComputeDescriptor_t,
) -> Result<ct_sys::cutensorOperationDescriptor_t, CutensorError> {
    let mut op_desc = MaybeUninit::<ct_sys::cutensorOperationDescriptor_t>::uninit();
    // SAFETY: `op_desc` is a valid out-slot; the remaining arguments follow the caller contract.
    let status = unsafe {
        ct_sys::cutensorCreatePermutation(
            handle,
            op_desc.as_mut_ptr(),
            desc_a,
            mode_a,
            op_a,
            desc_b,
            mode_b,
            desc_compute,
        )
    };
    check(status)?;
    // SAFETY: a success status means cuTENSOR initialized the out-slot.
    Ok(unsafe { op_desc.assume_init() })
}
/// Runs a permutation plan through `cutensorPermute`.
///
/// # Errors
///
/// Returns [`CutensorError`] with the reported status when cuTENSOR does not succeed.
///
/// # Safety
///
/// The caller must pass a live `handle` and `plan`, `alpha`/`a`/`b` pointers matching what
/// cuTENSOR expects for this plan, and a valid `stream`.
pub unsafe fn permute(
    handle: ct_sys::cutensorHandle_t,
    plan: ct_sys::cutensorPlan_t,
    alpha: *const c_void,
    a: *const c_void,
    b: *mut c_void,
    stream: ct_sys::cudaStream_t,
) -> Result<(), CutensorError> {
    // SAFETY: arguments are forwarded verbatim under the caller contract above.
    let status = unsafe { ct_sys::cutensorPermute(handle, plan, alpha, a, b, stream) };
    check(status)
}
/// Sets the `CUTENSOR_PLAN_PREFERENCE_ALGO` attribute of `pref` to `algo`.
///
/// The algorithm is passed to cuTENSOR as a 4-byte `i32` buffer.
///
/// # Errors
///
/// Returns [`CutensorError`] with the reported status when cuTENSOR does not succeed.
///
/// # Safety
///
/// The caller must pass a live `handle` and a valid plan preference `pref`.
pub unsafe fn plan_preference_set_algo(
    handle: ct_sys::cutensorHandle_t,
    pref: ct_sys::cutensorPlanPreference_t,
    algo: ct_sys::cutensorAlgo_t,
) -> Result<(), CutensorError> {
    let raw_algo: i32 = algo as i32;
    let buf: *const c_void = core::ptr::addr_of!(raw_algo).cast::<c_void>();
    let buf_len: usize = core::mem::size_of_val(&raw_algo);
    // SAFETY: `buf` points at `raw_algo`, which outlives the call, and `buf_len` is its size.
    let status = unsafe {
        ct_sys::cutensorPlanPreferenceSetAttribute(
            handle,
            pref,
            ct_sys::cutensorPlanPreferenceAttribute_t::CUTENSOR_PLAN_PREFERENCE_ALGO,
            buf,
            buf_len,
        )
    };
    check(status)
}