#![warn(missing_debug_implementations)]
use core::ffi::c_void;
use std::ffi::CString;
use baracuda_cutensor_sys::{
cutensor, cutensorAlgo, cutensorDataType, cutensorHandle_t, cutensorJitMode,
cutensorOperationDescriptor_t, cutensorOperator, cutensorPlanPreference_t, cutensorPlan_t,
cutensorStatus_t, cutensorTensorDescriptor_t, cutensorWorksizePreference,
};
/// Error type used throughout this module: `baracuda_core::Error` specialised to
/// `cutensorStatus_t` status codes.
pub type Error = baracuda_core::Error<cutensorStatus_t>;
/// Result alias whose error parameter defaults to this module's [`Error`].
pub type Result<T, E = Error> = core::result::Result<T, E>;
/// Turns a raw `cutensorStatus_t` into a [`Result`] by delegating to `Error::check`.
#[inline]
fn check(status: cutensorStatus_t) -> Result<()> {
// The status-to-error mapping lives in `baracuda_core::Error::check`, not in this file.
Error::check(status)
}
/// Succeeds exactly when `cutensor()` succeeds.
///
/// # Errors
/// Propagates the error returned by `cutensor()`.
pub fn probe() -> Result<()> {
    // Only the outcome of the lookup matters; the returned value is not used.
    let _library = cutensor()?;
    Ok(())
}
/// Returns the value reported by the library's `cutensor_get_version` entry point.
///
/// # Errors
/// Fails if the library or the entry point cannot be resolved.
pub fn version() -> Result<usize> {
    let lib = cutensor()?;
    let get_version = lib.cutensor_get_version()?;
    // SAFETY: the entry point was just resolved by the loader and takes no arguments.
    let reported = unsafe { get_version() };
    Ok(reported)
}
/// Returns the value reported by the library's `cutensor_get_cudart_version` entry point.
///
/// # Errors
/// Fails if the library or the entry point cannot be resolved.
pub fn cudart_version() -> Result<usize> {
    let lib = cutensor()?;
    let get_cudart_version = lib.cutensor_get_cudart_version()?;
    // SAFETY: the entry point was just resolved by the loader and takes no arguments.
    let reported = unsafe { get_cudart_version() };
    Ok(reported)
}
/// Forwards `level` to the library's `cutensor_logger_set_level` entry point.
///
/// # Errors
/// Fails if the library or entry point cannot be resolved, or if the call
/// returns a status that `check` treats as an error.
pub fn set_log_level(level: i32) -> Result<()> {
    let lib = cutensor()?;
    let set_level = lib.cutensor_logger_set_level()?;
    // SAFETY: the entry point receives a plain integer by value.
    let status = unsafe { set_level(level) };
    check(status)
}
/// Forwards `mask` to the library's `cutensor_logger_set_mask` entry point.
///
/// # Errors
/// Fails if the library or entry point cannot be resolved, or if the call
/// returns a status that `check` treats as an error.
pub fn set_log_mask(mask: i32) -> Result<()> {
    let lib = cutensor()?;
    let set_mask = lib.cutensor_logger_set_mask()?;
    // SAFETY: the entry point receives a plain integer by value.
    let status = unsafe { set_mask(mask) };
    check(status)
}
pub fn open_log_file(path: &str) -> Result<()> {
let cpath = std::ffi::CString::new(path).map_err(|_| Error::Status {
status: cutensorStatus_t::INVALID_VALUE,
})?;
let c = cutensor()?;
let cu = c.cutensor_logger_open_file()?;
check(unsafe { cu(cpath.as_ptr()) })
}
/// Invokes the library's `cutensor_logger_force_disable` entry point.
///
/// # Errors
/// Fails if the library or entry point cannot be resolved, or if the call
/// returns a status that `check` treats as an error.
pub fn force_disable_logging() -> Result<()> {
    let lib = cutensor()?;
    let force_disable = lib.cutensor_logger_force_disable()?;
    // SAFETY: the entry point takes no arguments.
    let status = unsafe { force_disable() };
    check(status)
}
/// Element types accepted by the tensor-descriptor constructors in this module.
///
/// Each variant is translated to one `cutensorDataType` constant by the private
/// `raw` helper.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum DataType {
/// Translated to `cutensorDataType::R_16F`.
F16,
/// Translated to `cutensorDataType::R_16BF`.
BF16,
/// Translated to `cutensorDataType::R_32F`.
F32,
/// Translated to `cutensorDataType::R_64F`.
F64,
/// Translated to `cutensorDataType::C_32F`.
ComplexF32,
/// Translated to `cutensorDataType::C_64F`.
ComplexF64,
/// Translated to `cutensorDataType::R_8I`.
I8,
/// Translated to `cutensorDataType::R_8U`.
U8,
/// Translated to `cutensorDataType::R_32I`.
I32,
/// Translated to `cutensorDataType::R_32U`.
U32,
}
impl DataType {
    /// Translates this element type into the matching `cutensorDataType` constant.
    #[inline]
    fn raw(self) -> i32 {
        match self {
            Self::F16 => cutensorDataType::R_16F,
            Self::BF16 => cutensorDataType::R_16BF,
            Self::F32 => cutensorDataType::R_32F,
            Self::F64 => cutensorDataType::R_64F,
            Self::ComplexF32 => cutensorDataType::C_32F,
            Self::ComplexF64 => cutensorDataType::C_64F,
            Self::I8 => cutensorDataType::R_8I,
            Self::U8 => cutensorDataType::R_8U,
            Self::I32 => cutensorDataType::R_32I,
            Self::U32 => cutensorDataType::R_32U,
        }
    }
}
/// Operators accepted in the per-operand `op_*` parameters of the elementwise
/// and permutation constructors in this module.
///
/// Each variant is translated to one `cutensorOperator` constant by the private
/// `raw` helper.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum UnaryOp {
/// Translated to `cutensorOperator::IDENTITY`.
Identity,
/// Translated to `cutensorOperator::SQRT`.
Sqrt,
/// Translated to `cutensorOperator::RELU`.
Relu,
/// Translated to `cutensorOperator::CONJ`.
Conj,
/// Translated to `cutensorOperator::RCP`.
Rcp,
/// Translated to `cutensorOperator::SIGMOID`.
Sigmoid,
/// Translated to `cutensorOperator::TANH`.
Tanh,
}
impl UnaryOp {
    /// Translates this operator into the matching `cutensorOperator` constant.
    #[inline]
    fn raw(self) -> i32 {
        match self {
            Self::Identity => cutensorOperator::IDENTITY,
            Self::Sqrt => cutensorOperator::SQRT,
            Self::Relu => cutensorOperator::RELU,
            Self::Conj => cutensorOperator::CONJ,
            Self::Rcp => cutensorOperator::RCP,
            Self::Sigmoid => cutensorOperator::SIGMOID,
            Self::Tanh => cutensorOperator::TANH,
        }
    }
}
/// Operators accepted where this module combines two operands (the reduction
/// operator and the `op_ac` / `op_ab` / `op_abc` elementwise parameters).
///
/// Each variant is translated to one `cutensorOperator` constant by the private
/// `raw` helper.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum BinaryOp {
/// Translated to `cutensorOperator::ADD`.
Add,
/// Translated to `cutensorOperator::MUL`.
Mul,
/// Translated to `cutensorOperator::MAX`.
Max,
/// Translated to `cutensorOperator::MIN`.
Min,
}
impl BinaryOp {
    /// Translates this operator into the matching `cutensorOperator` constant.
    #[inline]
    fn raw(self) -> i32 {
        match self {
            Self::Add => cutensorOperator::ADD,
            Self::Mul => cutensorOperator::MUL,
            Self::Max => cutensorOperator::MAX,
            Self::Min => cutensorOperator::MIN,
        }
    }
}
/// Owning wrapper around a raw `cutensorHandle_t`.
///
/// The raw handle is obtained in `Handle::new` and released by the `Drop`
/// implementation for this type.
#[derive(Debug)]
pub struct Handle {
/// Raw library handle; written by the create call in `Handle::new`.
handle: cutensorHandle_t,
}
// NOTE(review): this asserts that a handle may be moved to another thread.
// Confirm against cuTENSOR's documented thread-safety guarantees.
unsafe impl Send for Handle {}
impl Handle {
pub fn new() -> Result<Self> {
let c = cutensor()?;
let cu = c.cutensor_create()?;
let mut h: cutensorHandle_t = core::ptr::null_mut();
check(unsafe { cu(&mut h) })?;
Ok(Self { handle: h })
}
#[inline]
pub fn as_raw(&self) -> cutensorHandle_t {
self.handle
}
pub fn resize_plan_cache(&self, num_entries: u32) -> Result<()> {
let c = cutensor()?;
let cu = c.cutensor_handle_resize_plan_cache()?;
check(unsafe { cu(self.handle, num_entries) })
}
pub fn write_plan_cache_to_file(&self, path: &str) -> Result<()> {
let cpath = CString::new(path).map_err(|_| Error::Status {
status: cutensorStatus_t::INVALID_VALUE,
})?;
let c = cutensor()?;
let cu = c.cutensor_handle_write_plan_cache_to_file()?;
check(unsafe { cu(self.handle, cpath.as_ptr()) })
}
pub fn read_plan_cache_from_file(&self, path: &str) -> Result<()> {
let cpath = CString::new(path).map_err(|_| Error::Status {
status: cutensorStatus_t::INVALID_VALUE,
})?;
let c = cutensor()?;
let cu = c.cutensor_handle_read_plan_cache_from_file()?;
check(unsafe { cu(self.handle, cpath.as_ptr()) })
}
pub fn write_kernel_cache_to_file(&self, path: &str) -> Result<()> {
let cpath = CString::new(path).map_err(|_| Error::Status {
status: cutensorStatus_t::INVALID_VALUE,
})?;
let c = cutensor()?;
let cu = c.cutensor_write_kernel_cache_to_file()?;
check(unsafe { cu(self.handle, cpath.as_ptr()) })
}
pub fn read_kernel_cache_from_file(&self, path: &str) -> Result<()> {
let cpath = CString::new(path).map_err(|_| Error::Status {
status: cutensorStatus_t::INVALID_VALUE,
})?;
let c = cutensor()?;
let cu = c.cutensor_read_kernel_cache_from_file()?;
check(unsafe { cu(self.handle, cpath.as_ptr()) })
}
pub fn compute_desc_32f(&self) -> Result<*const c_void> {
Ok(cutensor()?.compute_desc_32f()?)
}
pub fn compute_desc_64f(&self) -> Result<*const c_void> {
Ok(cutensor()?.compute_desc_64f()?)
}
pub fn compute_desc_16f(&self) -> Result<*const c_void> {
Ok(cutensor()?.compute_desc_16f()?)
}
pub fn compute_desc_16bf(&self) -> Result<*const c_void> {
Ok(cutensor()?.compute_desc_16bf()?)
}
pub fn compute_desc_tf32(&self) -> Result<*const c_void> {
Ok(cutensor()?.compute_desc_tf32()?)
}
pub fn compute_desc_3xtf32(&self) -> Result<*const c_void> {
Ok(cutensor()?.compute_desc_3xtf32()?)
}
pub fn compute_desc_4x16f(&self) -> Result<*const c_void> {
Ok(cutensor()?.compute_desc_4x16f()?)
}
pub fn compute_desc_8xint8(&self) -> Result<*const c_void> {
Ok(cutensor()?.compute_desc_8xint8()?)
}
pub fn compute_desc_9x16bf(&self) -> Result<*const c_void> {
Ok(cutensor()?.compute_desc_9x16bf()?)
}
}
/// Owning wrapper around a raw `cutensorComputeDescriptor_t`, tied to the
/// lifetime of the [`Handle`] it was created with.
#[derive(Debug)]
pub struct ComputeDescriptor<'h> {
/// Raw descriptor; released by this type's `Drop` implementation.
desc: baracuda_cutensor_sys::cutensorComputeDescriptor_t,
/// Handle passed to the create call; also passed to the attribute calls.
_handle: &'h Handle,
}
impl<'h> ComputeDescriptor<'h> {
    /// Creates a compute descriptor through `cutensor_create_compute_descriptor`.
    ///
    /// # Errors
    /// Fails if the library or entry point cannot be resolved, or if the call
    /// returns a status that `check` treats as an error.
    pub fn new(handle: &'h Handle) -> Result<Self> {
        let lib = cutensor()?;
        let create = lib.cutensor_create_compute_descriptor()?;
        let mut desc: baracuda_cutensor_sys::cutensorComputeDescriptor_t = core::ptr::null();
        // SAFETY: `desc` is a live local used as the out-parameter. The double
        // cast adapts its address to the pointer type the entry point declares.
        let status = unsafe { create(handle.as_raw(), &mut desc as *mut _ as *mut _) };
        check(status)?;
        Ok(Self {
            desc,
            _handle: handle,
        })
    }

    /// Returns the wrapped raw descriptor without transferring ownership.
    #[inline]
    pub fn as_raw(&self) -> baracuda_cutensor_sys::cutensorComputeDescriptor_t {
        self.desc
    }

    /// Forwards an attribute write to `cutensor_compute_descriptor_set_attribute`.
    ///
    /// # Safety
    /// `value` and `size_bytes` are handed to the library unchecked. The caller
    /// must ensure they describe readable memory of the kind the library
    /// expects for `attr`.
    pub unsafe fn set_attribute(
        &self,
        attr: i32,
        value: *const c_void,
        size_bytes: usize,
    ) -> Result<()> {
        let lib = cutensor()?;
        let setter = lib.cutensor_compute_descriptor_set_attribute()?;
        let owner = self._handle.as_raw();
        // SAFETY: handle and descriptor are alive for `&self`; the caller
        // guarantees the validity of `value` / `size_bytes`.
        let status = unsafe { setter(owner, self.desc, attr, value, size_bytes) };
        check(status)
    }

    /// Forwards an attribute read to `cutensor_compute_descriptor_get_attribute`.
    ///
    /// # Safety
    /// `value` and `size_bytes` are handed to the library unchecked. The caller
    /// must ensure they describe writable memory of the kind the library
    /// expects for `attr`.
    pub unsafe fn get_attribute(
        &self,
        attr: i32,
        value: *mut c_void,
        size_bytes: usize,
    ) -> Result<()> {
        let lib = cutensor()?;
        let getter = lib.cutensor_compute_descriptor_get_attribute()?;
        let owner = self._handle.as_raw();
        // SAFETY: handle and descriptor are alive for `&self`; the caller
        // guarantees the validity of `value` / `size_bytes`.
        let status = unsafe { getter(owner, self.desc, attr, value, size_bytes) };
        check(status)
    }
}
impl Drop for ComputeDescriptor<'_> {
    /// Releases the descriptor via `cutensor_destroy_compute_descriptor`.
    ///
    /// Best effort: if the library or entry point cannot be resolved nothing
    /// happens, and the status returned by the destroy call is discarded.
    fn drop(&mut self) {
        let Ok(lib) = cutensor() else { return };
        let Ok(destroy) = lib.cutensor_destroy_compute_descriptor() else {
            return;
        };
        // SAFETY: `self.desc` came from the matching create call and is not used
        // again after this point.
        let _ = unsafe { destroy(self.desc) };
    }
}
/// Owning wrapper around a raw `cutensorBlockSparseTensorDescriptor_t`, tied to
/// the lifetime of the [`Handle`] it was created with.
#[derive(Debug)]
pub struct BlockSparseTensorDescriptor<'h> {
/// Raw descriptor; released by this type's `Drop` implementation.
desc: baracuda_cutensor_sys::cutensorBlockSparseTensorDescriptor_t,
/// Handle passed to the create call; kept so the borrow outlives the descriptor.
_handle: &'h Handle,
}
impl<'h> BlockSparseTensorDescriptor<'h> {
#[allow(clippy::too_many_arguments)]
pub fn new(
handle: &'h Handle,
extents: &[i64],
block_size: &[i64],
strides: Option<&[i64]>,
block_indices: &[i32],
dtype: DataType,
alignment_bytes: u32,
) -> Result<Self> {
assert_eq!(block_size.len(), extents.len());
if let Some(s) = strides {
assert_eq!(s.len(), extents.len());
}
let num_modes = extents.len() as u32;
let block_count = (block_indices.len() / extents.len()) as i64;
let c = cutensor()?;
let cu = c.cutensor_create_block_sparse_tensor_descriptor()?;
let mut desc: baracuda_cutensor_sys::cutensorBlockSparseTensorDescriptor_t =
core::ptr::null_mut();
check(unsafe {
cu(
handle.as_raw(),
&mut desc,
num_modes,
extents.as_ptr(),
block_size.as_ptr(),
strides.map_or(core::ptr::null(), |s| s.as_ptr()),
block_count,
block_indices.as_ptr(),
dtype.raw(),
alignment_bytes,
)
})?;
Ok(Self {
desc,
_handle: handle,
})
}
#[inline]
pub fn as_raw(&self) -> baracuda_cutensor_sys::cutensorBlockSparseTensorDescriptor_t {
self.desc
}
}
impl Drop for BlockSparseTensorDescriptor<'_> {
    /// Releases the descriptor via `cutensor_destroy_block_sparse_tensor_descriptor`.
    ///
    /// Best effort: if the library or entry point cannot be resolved nothing
    /// happens, and the status returned by the destroy call is discarded.
    fn drop(&mut self) {
        let Ok(lib) = cutensor() else { return };
        let Ok(destroy) = lib.cutensor_destroy_block_sparse_tensor_descriptor() else {
            return;
        };
        // SAFETY: `self.desc` came from the matching create call and is not used
        // again after this point.
        let _ = unsafe { destroy(self.desc) };
    }
}
/// Namespace type whose `new` builds an [`OperationDescriptor`] for a
/// block-sparse contraction.
#[derive(Debug)]
pub struct BlockSparseContraction;

impl BlockSparseContraction {
    /// Builds an operation descriptor through
    /// `cutensor_create_block_sparse_contraction`.
    ///
    /// `cutensorOperator::IDENTITY` is passed in the operator slot that follows
    /// each of the A, B and C operands. The result is tagged so that only
    /// `Plan::contract_block_sparse` accepts plans created from it.
    ///
    /// # Safety
    /// The `modes_*` slices are passed as bare pointers without their lengths,
    /// and `compute_desc` is passed through unchecked. The caller must ensure
    /// each slice holds as many entries as the library reads (presumably one
    /// per mode of the matching descriptor — confirm against the cuTENSOR
    /// documentation) and that `compute_desc` is a pointer the library accepts.
    #[allow(clippy::too_many_arguments, clippy::new_ret_no_self)]
    pub unsafe fn new<'h>(
        handle: &'h Handle,
        a: &BlockSparseTensorDescriptor<'h>,
        modes_a: &[i32],
        b: &TensorDescriptor<'h>,
        modes_b: &[i32],
        c: &TensorDescriptor<'h>,
        modes_c: &[i32],
        d: &TensorDescriptor<'h>,
        modes_d: &[i32],
        compute_desc: *const c_void,
    ) -> Result<OperationDescriptor<'h>> {
        let lib = cutensor()?;
        let create = lib.cutensor_create_block_sparse_contraction()?;
        let mut desc: cutensorOperationDescriptor_t = core::ptr::null_mut();
        let identity = cutensorOperator::IDENTITY;
        // SAFETY: all descriptors borrow `handle` and are alive for the call;
        // the caller guarantees the mode slices and `compute_desc`.
        let status = unsafe {
            create(
                handle.as_raw(),
                &mut desc,
                a.as_raw(),
                modes_a.as_ptr(),
                identity,
                b.as_raw(),
                modes_b.as_ptr(),
                identity,
                c.as_raw(),
                modes_c.as_ptr(),
                identity,
                d.as_raw(),
                modes_d.as_ptr(),
                compute_desc,
            )
        };
        check(status)?;
        Ok(OperationDescriptor {
            desc,
            handle,
            kind: OpKind::BlockSparseContraction,
        })
    }
}
/// Namespace type whose `new` builds an [`OperationDescriptor`] for a trinary
/// contraction.
#[derive(Debug)]
pub struct TrinaryContraction;

impl TrinaryContraction {
    /// Builds an operation descriptor through `cutensor_create_contraction_trinary`.
    ///
    /// `cutensorOperator::IDENTITY` is passed in the operator slot that follows
    /// each of the A, B, C and D operands. The result is tagged so that only
    /// `Plan::contract_trinary` accepts plans created from it.
    ///
    /// # Safety
    /// The `modes_*` slices are passed as bare pointers without their lengths,
    /// and `compute_desc` is passed through unchecked. The caller must ensure
    /// each slice holds as many entries as the library reads (presumably one
    /// per mode of the matching descriptor — confirm against the cuTENSOR
    /// documentation) and that `compute_desc` is a pointer the library accepts.
    #[allow(clippy::too_many_arguments, clippy::new_ret_no_self)]
    pub unsafe fn new<'h>(
        handle: &'h Handle,
        a: &TensorDescriptor<'h>,
        modes_a: &[i32],
        b: &TensorDescriptor<'h>,
        modes_b: &[i32],
        c: &TensorDescriptor<'h>,
        modes_c: &[i32],
        d: &TensorDescriptor<'h>,
        modes_d: &[i32],
        e: &TensorDescriptor<'h>,
        modes_e: &[i32],
        compute_desc: *const c_void,
    ) -> Result<OperationDescriptor<'h>> {
        let lib = cutensor()?;
        let create = lib.cutensor_create_contraction_trinary()?;
        let mut desc: cutensorOperationDescriptor_t = core::ptr::null_mut();
        let identity = cutensorOperator::IDENTITY;
        // SAFETY: all descriptors borrow `handle` and are alive for the call;
        // the caller guarantees the mode slices and `compute_desc`.
        let status = unsafe {
            create(
                handle.as_raw(),
                &mut desc,
                a.as_raw(),
                modes_a.as_ptr(),
                identity,
                b.as_raw(),
                modes_b.as_ptr(),
                identity,
                c.as_raw(),
                modes_c.as_ptr(),
                identity,
                d.as_raw(),
                modes_d.as_ptr(),
                identity,
                e.as_raw(),
                modes_e.as_ptr(),
                compute_desc,
            )
        };
        check(status)?;
        Ok(OperationDescriptor {
            desc,
            handle,
            kind: OpKind::TrinaryContraction,
        })
    }
}
impl Drop for Handle {
    /// Releases the handle via `cutensor_destroy`.
    ///
    /// Best effort: if the library or entry point cannot be resolved nothing
    /// happens, and the status returned by the destroy call is discarded.
    fn drop(&mut self) {
        let Ok(lib) = cutensor() else { return };
        let Ok(destroy) = lib.cutensor_destroy() else {
            return;
        };
        // SAFETY: `self.handle` came from `cutensor_create` and is not used
        // again after this point.
        let _ = unsafe { destroy(self.handle) };
    }
}
/// Owning wrapper around a raw `cutensorTensorDescriptor_t`, tied to the
/// lifetime of the [`Handle`] it was created with.
#[derive(Debug)]
pub struct TensorDescriptor<'h> {
/// Raw descriptor; released by this type's `Drop` implementation.
desc: cutensorTensorDescriptor_t,
/// Handle passed to the create call; also passed to `set_attribute`.
_handle: &'h Handle,
}
impl<'h> TensorDescriptor<'h> {
    /// Creates a tensor descriptor through `cutensor_create_tensor_descriptor`.
    ///
    /// The number of modes is `extents.len()`. When `strides` is `None` a null
    /// pointer is passed in its place.
    ///
    /// # Errors
    /// Fails if the library or entry point cannot be resolved, or if the call
    /// returns a status that `check` treats as an error.
    ///
    /// # Panics
    /// Panics if `strides` is provided with a length different from `extents`.
    pub fn new(
        handle: &'h Handle,
        extents: &[i64],
        strides: Option<&[i64]>,
        dtype: DataType,
        alignment_bytes: u32,
    ) -> Result<Self> {
        let lib = cutensor()?;
        let create = lib.cutensor_create_tensor_descriptor()?;
        let num_modes = extents.len() as u32;
        if let Some(given) = strides {
            assert_eq!(given.len(), extents.len(), "strides length mismatch");
        }
        let strides_ptr = match strides {
            Some(given) => given.as_ptr(),
            None => core::ptr::null(),
        };
        let mut desc: cutensorTensorDescriptor_t = core::ptr::null_mut();
        // SAFETY: `extents` (and `strides`, when present) hold `num_modes`
        // entries and stay borrowed for the call; `desc` is a live out-parameter.
        let status = unsafe {
            create(
                handle.as_raw(),
                &mut desc,
                num_modes,
                extents.as_ptr(),
                strides_ptr,
                dtype.raw(),
                alignment_bytes,
            )
        };
        check(status)?;
        Ok(Self {
            desc,
            _handle: handle,
        })
    }

    /// Returns the wrapped raw descriptor without transferring ownership.
    #[inline]
    pub fn as_raw(&self) -> cutensorTensorDescriptor_t {
        self.desc
    }

    /// Forwards an attribute write to `cutensor_tensor_descriptor_set_attribute`.
    ///
    /// # Safety
    /// `buf` and `size_bytes` are handed to the library unchecked. The caller
    /// must ensure they describe readable memory of the kind the library
    /// expects for `attr`.
    pub unsafe fn set_attribute(
        &self,
        attr: i32,
        buf: *const c_void,
        size_bytes: usize,
    ) -> Result<()> {
        let lib = cutensor()?;
        let setter = lib.cutensor_tensor_descriptor_set_attribute()?;
        let owner = self._handle.as_raw();
        // SAFETY: handle and descriptor are alive for `&self`; the caller
        // guarantees the validity of `buf` / `size_bytes`.
        let status = unsafe { setter(owner, self.desc, attr, buf, size_bytes) };
        check(status)
    }
}
impl Drop for TensorDescriptor<'_> {
    /// Releases the descriptor via `cutensor_destroy_tensor_descriptor`.
    ///
    /// Best effort: if the library or entry point cannot be resolved nothing
    /// happens, and the status returned by the destroy call is discarded.
    fn drop(&mut self) {
        let Ok(lib) = cutensor() else { return };
        let Ok(destroy) = lib.cutensor_destroy_tensor_descriptor() else {
            return;
        };
        // SAFETY: `self.desc` came from the matching create call and is not used
        // again after this point.
        let _ = unsafe { destroy(self.desc) };
    }
}
/// Records which constructor produced an `OperationDescriptor`.
///
/// The tag is copied into `Plan` and asserted by the plan's execute methods so
/// that a plan is only run through the matching entry point.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum OpKind {
/// Set by `Contraction::new`; required by `Plan::contract`.
Contraction,
/// Set by `TrinaryContraction::new`; required by `Plan::contract_trinary`.
TrinaryContraction,
/// Set by `BlockSparseContraction::new`; required by `Plan::contract_block_sparse`.
BlockSparseContraction,
/// Set by `Reduction::new`; required by `Plan::reduce`.
Reduction,
/// Set by `ElementwiseBinary::new`; required by `Plan::elementwise_binary`.
ElementwiseBinary,
/// Set by `ElementwiseTrinary::new`; required by `Plan::elementwise_trinary`.
ElementwiseTrinary,
/// Set by `Permutation::new`; required by `Plan::permute`.
Permutation,
}
/// Owning wrapper around a raw `cutensorOperationDescriptor_t`, produced by the
/// `new` functions of the operation namespace types in this module.
#[derive(Debug)]
pub struct OperationDescriptor<'h> {
/// Raw descriptor; released by this type's `Drop` implementation.
desc: cutensorOperationDescriptor_t,
/// Handle the descriptor was created with; reused by the query methods and by `Plan::new`.
handle: &'h Handle,
/// Which constructor produced this descriptor; copied into plans built from it.
kind: OpKind,
}
impl<'h> OperationDescriptor<'h> {
    /// Returns the wrapped raw descriptor without transferring ownership.
    #[inline]
    pub fn as_raw(&self) -> cutensorOperationDescriptor_t {
        self.desc
    }

    /// Returns the size written by `cutensor_estimate_workspace_size` for this
    /// operation, the given plan preference and workspace `kind`.
    ///
    /// # Errors
    /// Fails if the library or entry point cannot be resolved, or if the call
    /// returns a status that `check` treats as an error.
    pub fn estimate_workspace(
        &self,
        pref: &PlanPreference<'h>,
        kind: WorkspaceKind,
    ) -> Result<u64> {
        let lib = cutensor()?;
        let estimate = lib.cutensor_estimate_workspace_size()?;
        let mut bytes: u64 = 0;
        // SAFETY: handle, descriptor and preference are alive for the call;
        // `bytes` is a live out-parameter.
        let status = unsafe {
            estimate(
                self.handle.as_raw(),
                self.desc,
                pref.as_raw(),
                kind.raw(),
                &mut bytes,
            )
        };
        check(status)?;
        Ok(bytes)
    }

    /// Returns the value written by `cutensor_operation_estimate_runtime` for
    /// this operation, the given plan preference and `algo`.
    ///
    /// # Errors
    /// Fails if the library or entry point cannot be resolved, or if the call
    /// returns a status that `check` treats as an error.
    pub fn estimate_runtime(&self, pref: &PlanPreference<'h>, algo: i32) -> Result<f32> {
        let lib = cutensor()?;
        let estimate = lib.cutensor_operation_estimate_runtime()?;
        let mut estimated: f32 = 0.0;
        // SAFETY: handle, descriptor and preference are alive for the call;
        // `estimated` is a live out-parameter.
        let status = unsafe {
            estimate(
                self.handle.as_raw(),
                self.desc,
                pref.as_raw(),
                algo,
                &mut estimated,
            )
        };
        check(status)?;
        Ok(estimated)
    }

    /// Returns the count written by `cutensor_operation_num_algos` for this
    /// descriptor.
    ///
    /// # Errors
    /// Fails if the library or entry point cannot be resolved, or if the call
    /// returns a status that `check` treats as an error.
    pub fn num_algos(&self) -> Result<i32> {
        let lib = cutensor()?;
        let query = lib.cutensor_operation_num_algos()?;
        let mut count: i32 = 0;
        // SAFETY: the descriptor is alive for `&self`; `count` is a live out-parameter.
        let status = unsafe { query(self.desc, &mut count) };
        check(status)?;
        Ok(count)
    }

    /// Forwards an attribute read to `cutensor_operation_descriptor_get_attribute`.
    ///
    /// # Safety
    /// `buf` and `size_bytes` are handed to the library unchecked. The caller
    /// must ensure they describe writable memory of the kind the library
    /// expects for `attr`.
    pub unsafe fn get_attribute(
        &self,
        attr: i32,
        buf: *mut c_void,
        size_bytes: usize,
    ) -> Result<()> {
        let lib = cutensor()?;
        let getter = lib.cutensor_operation_descriptor_get_attribute()?;
        let owner = self.handle.as_raw();
        // SAFETY: handle and descriptor are alive for `&self`; the caller
        // guarantees the validity of `buf` / `size_bytes`.
        let status = unsafe { getter(owner, self.desc, attr, buf, size_bytes) };
        check(status)
    }

    /// Forwards an attribute write to `cutensor_operation_descriptor_set_attribute`.
    ///
    /// # Safety
    /// `buf` and `size_bytes` are handed to the library unchecked. The caller
    /// must ensure they describe readable memory of the kind the library
    /// expects for `attr`.
    pub unsafe fn set_attribute(
        &self,
        attr: i32,
        buf: *const c_void,
        size_bytes: usize,
    ) -> Result<()> {
        let lib = cutensor()?;
        let setter = lib.cutensor_operation_descriptor_set_attribute()?;
        let owner = self.handle.as_raw();
        // SAFETY: handle and descriptor are alive for `&self`; the caller
        // guarantees the validity of `buf` / `size_bytes`.
        let status = unsafe { setter(owner, self.desc, attr, buf, size_bytes) };
        check(status)
    }
}
impl Drop for OperationDescriptor<'_> {
    /// Releases the descriptor via `cutensor_destroy_operation_descriptor`.
    ///
    /// Best effort: if the library or entry point cannot be resolved nothing
    /// happens, and the status returned by the destroy call is discarded.
    fn drop(&mut self) {
        let Ok(lib) = cutensor() else { return };
        let Ok(destroy) = lib.cutensor_destroy_operation_descriptor() else {
            return;
        };
        // SAFETY: `self.desc` came from one of the create calls and is not used
        // again after this point.
        let _ = unsafe { destroy(self.desc) };
    }
}
/// Namespace type whose `new` builds an [`OperationDescriptor`] for a contraction.
#[derive(Debug)]
pub struct Contraction;

impl Contraction {
    /// Builds an operation descriptor through `cutensor_create_contraction`.
    ///
    /// `cutensorOperator::IDENTITY` is passed in the operator slot that follows
    /// each of the A, B and C operands. The result is tagged so that only
    /// `Plan::contract` accepts plans created from it.
    ///
    /// # Safety
    /// The `modes_*` slices are passed as bare pointers without their lengths,
    /// and `compute_desc` is passed through unchecked. The caller must ensure
    /// each slice holds as many entries as the library reads (presumably one
    /// per mode of the matching descriptor — confirm against the cuTENSOR
    /// documentation) and that `compute_desc` is a pointer the library accepts.
    #[allow(clippy::too_many_arguments, clippy::new_ret_no_self)]
    pub unsafe fn new<'h>(
        handle: &'h Handle,
        a: &TensorDescriptor<'h>,
        modes_a: &[i32],
        b: &TensorDescriptor<'h>,
        modes_b: &[i32],
        c: &TensorDescriptor<'h>,
        modes_c: &[i32],
        d: &TensorDescriptor<'h>,
        modes_d: &[i32],
        compute_desc: *const c_void,
    ) -> Result<OperationDescriptor<'h>> {
        let lib = cutensor()?;
        let create = lib.cutensor_create_contraction()?;
        let mut desc: cutensorOperationDescriptor_t = core::ptr::null_mut();
        let identity = cutensorOperator::IDENTITY;
        // SAFETY: all descriptors borrow `handle` and are alive for the call;
        // the caller guarantees the mode slices and `compute_desc`.
        let status = unsafe {
            create(
                handle.as_raw(),
                &mut desc,
                a.as_raw(),
                modes_a.as_ptr(),
                identity,
                b.as_raw(),
                modes_b.as_ptr(),
                identity,
                c.as_raw(),
                modes_c.as_ptr(),
                identity,
                d.as_raw(),
                modes_d.as_ptr(),
                compute_desc,
            )
        };
        check(status)?;
        Ok(OperationDescriptor {
            desc,
            handle,
            kind: OpKind::Contraction,
        })
    }
}
/// Namespace type whose `new` builds an [`OperationDescriptor`] for a reduction.
#[derive(Debug)]
pub struct Reduction;

impl Reduction {
    /// Builds an operation descriptor through `cutensor_create_reduction`.
    ///
    /// `cutensorOperator::IDENTITY` is passed in the operator slot that follows
    /// each of the A and C operands; `op_reduce` is passed after the D operand.
    /// The result is tagged so that only `Plan::reduce` accepts plans created
    /// from it.
    ///
    /// # Safety
    /// The `modes_*` slices are passed as bare pointers without their lengths,
    /// and `compute_desc` is passed through unchecked. The caller must ensure
    /// each slice holds as many entries as the library reads (presumably one
    /// per mode of the matching descriptor — confirm against the cuTENSOR
    /// documentation) and that `compute_desc` is a pointer the library accepts.
    #[allow(clippy::too_many_arguments, clippy::new_ret_no_self)]
    pub unsafe fn new<'h>(
        handle: &'h Handle,
        a: &TensorDescriptor<'h>,
        modes_a: &[i32],
        c: &TensorDescriptor<'h>,
        modes_c: &[i32],
        d: &TensorDescriptor<'h>,
        modes_d: &[i32],
        op_reduce: BinaryOp,
        compute_desc: *const c_void,
    ) -> Result<OperationDescriptor<'h>> {
        let lib = cutensor()?;
        let create = lib.cutensor_create_reduction()?;
        let mut desc: cutensorOperationDescriptor_t = core::ptr::null_mut();
        let identity = cutensorOperator::IDENTITY;
        // SAFETY: all descriptors borrow `handle` and are alive for the call;
        // the caller guarantees the mode slices and `compute_desc`.
        let status = unsafe {
            create(
                handle.as_raw(),
                &mut desc,
                a.as_raw(),
                modes_a.as_ptr(),
                identity,
                c.as_raw(),
                modes_c.as_ptr(),
                identity,
                d.as_raw(),
                modes_d.as_ptr(),
                op_reduce.raw(),
                compute_desc,
            )
        };
        check(status)?;
        Ok(OperationDescriptor {
            desc,
            handle,
            kind: OpKind::Reduction,
        })
    }
}
/// Namespace type whose `new` builds an [`OperationDescriptor`] for a binary
/// elementwise operation.
#[derive(Debug)]
pub struct ElementwiseBinary;

impl ElementwiseBinary {
    /// Builds an operation descriptor through `cutensor_create_elementwise_binary`.
    ///
    /// `op_a` and `op_c` are passed in the operator slots that follow the A and
    /// C operands; `op_ac` is passed after the D operand. The result is tagged
    /// so that only `Plan::elementwise_binary` accepts plans created from it.
    ///
    /// # Safety
    /// The `modes_*` slices are passed as bare pointers without their lengths,
    /// and `compute_desc` is passed through unchecked. The caller must ensure
    /// each slice holds as many entries as the library reads (presumably one
    /// per mode of the matching descriptor — confirm against the cuTENSOR
    /// documentation) and that `compute_desc` is a pointer the library accepts.
    #[allow(clippy::too_many_arguments, clippy::new_ret_no_self)]
    pub unsafe fn new<'h>(
        handle: &'h Handle,
        a: &TensorDescriptor<'h>,
        modes_a: &[i32],
        op_a: UnaryOp,
        c: &TensorDescriptor<'h>,
        modes_c: &[i32],
        op_c: UnaryOp,
        d: &TensorDescriptor<'h>,
        modes_d: &[i32],
        op_ac: BinaryOp,
        compute_desc: *const c_void,
    ) -> Result<OperationDescriptor<'h>> {
        let lib = cutensor()?;
        let create = lib.cutensor_create_elementwise_binary()?;
        let mut desc: cutensorOperationDescriptor_t = core::ptr::null_mut();
        // SAFETY: all descriptors borrow `handle` and are alive for the call;
        // the caller guarantees the mode slices and `compute_desc`.
        let status = unsafe {
            create(
                handle.as_raw(),
                &mut desc,
                a.as_raw(),
                modes_a.as_ptr(),
                op_a.raw(),
                c.as_raw(),
                modes_c.as_ptr(),
                op_c.raw(),
                d.as_raw(),
                modes_d.as_ptr(),
                op_ac.raw(),
                compute_desc,
            )
        };
        check(status)?;
        Ok(OperationDescriptor {
            desc,
            handle,
            kind: OpKind::ElementwiseBinary,
        })
    }
}
/// Namespace type whose `new` builds an [`OperationDescriptor`] for a trinary
/// elementwise operation.
#[derive(Debug)]
pub struct ElementwiseTrinary;

impl ElementwiseTrinary {
    /// Builds an operation descriptor through `cutensor_create_elementwise_trinary`.
    ///
    /// `op_a`, `op_b` and `op_c` are passed in the operator slots that follow
    /// the A, B and C operands; `op_ab` and `op_abc` are passed, in that order,
    /// after the D operand. The result is tagged so that only
    /// `Plan::elementwise_trinary` accepts plans created from it.
    ///
    /// # Safety
    /// The `modes_*` slices are passed as bare pointers without their lengths,
    /// and `compute_desc` is passed through unchecked. The caller must ensure
    /// each slice holds as many entries as the library reads (presumably one
    /// per mode of the matching descriptor — confirm against the cuTENSOR
    /// documentation) and that `compute_desc` is a pointer the library accepts.
    #[allow(clippy::too_many_arguments, clippy::new_ret_no_self)]
    pub unsafe fn new<'h>(
        handle: &'h Handle,
        a: &TensorDescriptor<'h>,
        modes_a: &[i32],
        op_a: UnaryOp,
        b: &TensorDescriptor<'h>,
        modes_b: &[i32],
        op_b: UnaryOp,
        c: &TensorDescriptor<'h>,
        modes_c: &[i32],
        op_c: UnaryOp,
        d: &TensorDescriptor<'h>,
        modes_d: &[i32],
        op_ab: BinaryOp,
        op_abc: BinaryOp,
        compute_desc: *const c_void,
    ) -> Result<OperationDescriptor<'h>> {
        let lib = cutensor()?;
        let create = lib.cutensor_create_elementwise_trinary()?;
        let mut desc: cutensorOperationDescriptor_t = core::ptr::null_mut();
        // SAFETY: all descriptors borrow `handle` and are alive for the call;
        // the caller guarantees the mode slices and `compute_desc`.
        let status = unsafe {
            create(
                handle.as_raw(),
                &mut desc,
                a.as_raw(),
                modes_a.as_ptr(),
                op_a.raw(),
                b.as_raw(),
                modes_b.as_ptr(),
                op_b.raw(),
                c.as_raw(),
                modes_c.as_ptr(),
                op_c.raw(),
                d.as_raw(),
                modes_d.as_ptr(),
                op_ab.raw(),
                op_abc.raw(),
                compute_desc,
            )
        };
        check(status)?;
        Ok(OperationDescriptor {
            desc,
            handle,
            kind: OpKind::ElementwiseTrinary,
        })
    }
}
/// Namespace type whose `new` builds an [`OperationDescriptor`] for a permutation.
#[derive(Debug)]
pub struct Permutation;

impl Permutation {
    /// Builds an operation descriptor through `cutensor_create_permutation`.
    ///
    /// `op_a` is passed in the operator slot that follows the A operand. The
    /// result is tagged so that only `Plan::permute` accepts plans created
    /// from it.
    ///
    /// # Safety
    /// The `modes_*` slices are passed as bare pointers without their lengths,
    /// and `compute_desc` is passed through unchecked. The caller must ensure
    /// each slice holds as many entries as the library reads (presumably one
    /// per mode of the matching descriptor — confirm against the cuTENSOR
    /// documentation) and that `compute_desc` is a pointer the library accepts.
    #[allow(clippy::too_many_arguments, clippy::new_ret_no_self)]
    pub unsafe fn new<'h>(
        handle: &'h Handle,
        a: &TensorDescriptor<'h>,
        modes_a: &[i32],
        op_a: UnaryOp,
        b: &TensorDescriptor<'h>,
        modes_b: &[i32],
        compute_desc: *const c_void,
    ) -> Result<OperationDescriptor<'h>> {
        let lib = cutensor()?;
        let create = lib.cutensor_create_permutation()?;
        let mut desc: cutensorOperationDescriptor_t = core::ptr::null_mut();
        // SAFETY: both descriptors borrow `handle` and are alive for the call;
        // the caller guarantees the mode slices and `compute_desc`.
        let status = unsafe {
            create(
                handle.as_raw(),
                &mut desc,
                a.as_raw(),
                modes_a.as_ptr(),
                op_a.raw(),
                b.as_raw(),
                modes_b.as_ptr(),
                compute_desc,
            )
        };
        check(status)?;
        Ok(OperationDescriptor {
            desc,
            handle,
            kind: OpKind::Permutation,
        })
    }
}
/// Owning wrapper around a raw `cutensorPlanPreference_t`, tied to the lifetime
/// of the [`Handle`] it was created with.
#[derive(Debug)]
pub struct PlanPreference<'h> {
/// Raw preference object; released by this type's `Drop` implementation.
pref: cutensorPlanPreference_t,
/// Handle passed to the create call; also passed to the attribute calls.
_handle: &'h Handle,
}
impl<'h> PlanPreference<'h> {
    /// Creates a plan preference through `cutensor_create_plan_preference`,
    /// forwarding `algo` and `jit_mode` unchanged.
    ///
    /// # Errors
    /// Fails if the library or entry point cannot be resolved, or if the call
    /// returns a status that `check` treats as an error.
    pub fn new(handle: &'h Handle, algo: i32, jit_mode: i32) -> Result<Self> {
        let lib = cutensor()?;
        let create = lib.cutensor_create_plan_preference()?;
        let mut raw: cutensorPlanPreference_t = core::ptr::null_mut();
        // SAFETY: the handle is alive for `'h`; `raw` is a live out-parameter.
        let status = unsafe { create(handle.as_raw(), &mut raw, algo, jit_mode) };
        check(status)?;
        Ok(Self {
            pref: raw,
            _handle: handle,
        })
    }

    /// Creates a preference using `cutensorAlgo::DEFAULT` and `cutensorJitMode::NONE`.
    pub fn default_for(handle: &'h Handle) -> Result<Self> {
        let algo = cutensorAlgo::DEFAULT;
        let jit_mode = cutensorJitMode::NONE;
        Self::new(handle, algo, jit_mode)
    }

    /// Returns the wrapped raw preference object without transferring ownership.
    #[inline]
    pub fn as_raw(&self) -> cutensorPlanPreference_t {
        self.pref
    }

    /// Forwards an attribute write to `cutensor_plan_preference_set_attribute`.
    ///
    /// # Safety
    /// `value` and `size_bytes` are handed to the library unchecked. The caller
    /// must ensure they describe readable memory of the kind the library
    /// expects for `attr`.
    pub unsafe fn set_attribute(
        &self,
        attr: i32,
        value: *const c_void,
        size_bytes: usize,
    ) -> Result<()> {
        let lib = cutensor()?;
        let setter = lib.cutensor_plan_preference_set_attribute()?;
        let owner = self._handle.as_raw();
        // SAFETY: handle and preference are alive for `&self`; the caller
        // guarantees the validity of `value` / `size_bytes`.
        let status = unsafe { setter(owner, self.pref, attr, value, size_bytes) };
        check(status)
    }

    /// Forwards an attribute read to `cutensor_plan_preference_get_attribute`.
    ///
    /// # Safety
    /// `value` and `size_bytes` are handed to the library unchecked. The caller
    /// must ensure they describe writable memory of the kind the library
    /// expects for `attr`.
    pub unsafe fn get_attribute(
        &self,
        attr: i32,
        value: *mut c_void,
        size_bytes: usize,
    ) -> Result<()> {
        let lib = cutensor()?;
        let getter = lib.cutensor_plan_preference_get_attribute()?;
        let owner = self._handle.as_raw();
        // SAFETY: handle and preference are alive for `&self`; the caller
        // guarantees the validity of `value` / `size_bytes`.
        let status = unsafe { getter(owner, self.pref, attr, value, size_bytes) };
        check(status)
    }
}
impl Drop for PlanPreference<'_> {
    /// Releases the preference object via `cutensor_destroy_plan_preference`.
    ///
    /// Best effort: if the library or entry point cannot be resolved nothing
    /// happens, and the status returned by the destroy call is discarded.
    fn drop(&mut self) {
        let Ok(lib) = cutensor() else { return };
        let Ok(destroy) = lib.cutensor_destroy_plan_preference() else {
            return;
        };
        // SAFETY: `self.pref` came from the matching create call and is not used
        // again after this point.
        let _ = unsafe { destroy(self.pref) };
    }
}
/// Selects which workspace-size preference is passed to
/// `OperationDescriptor::estimate_workspace`.
///
/// Derives `Eq` and `PartialEq` like the other public enums in this module, so
/// values can be compared and used in assertions.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum WorkspaceKind {
    /// Translated to `cutensorWorksizePreference::MIN`.
    Min,
    /// Translated to `cutensorWorksizePreference::DEFAULT`.
    Default,
    /// Translated to `cutensorWorksizePreference::MAX`.
    Max,
}
impl WorkspaceKind {
    /// Translates this preference into the matching `cutensorWorksizePreference` constant.
    #[inline]
    fn raw(self) -> i32 {
        match self {
            Self::Min => cutensorWorksizePreference::MIN,
            Self::Default => cutensorWorksizePreference::DEFAULT,
            Self::Max => cutensorWorksizePreference::MAX,
        }
    }
}
/// Owning wrapper around a raw `cutensorPlan_t` built from an
/// [`OperationDescriptor`] and a [`PlanPreference`].
#[derive(Debug)]
pub struct Plan<'h> {
/// Raw plan; released by this type's `Drop` implementation.
plan: cutensorPlan_t,
/// Handle copied from the operation descriptor; passed to every execute call.
handle: &'h Handle,
/// Operation kind copied from the descriptor; asserted by the execute methods.
kind: OpKind,
}
impl<'h> Plan<'h> {
pub fn new(
op: &OperationDescriptor<'h>,
pref: &PlanPreference<'h>,
workspace_size_limit: u64,
) -> Result<Self> {
let c = cutensor()?;
let cu = c.cutensor_create_plan()?;
let mut p: cutensorPlan_t = core::ptr::null_mut();
check(unsafe {
cu(
op.handle.as_raw(),
&mut p,
op.as_raw(),
pref.as_raw(),
workspace_size_limit,
)
})?;
Ok(Self {
plan: p,
handle: op.handle,
kind: op.kind,
})
}
#[inline]
pub fn as_raw(&self) -> cutensorPlan_t {
self.plan
}
#[allow(clippy::too_many_arguments)]
pub unsafe fn contract(
&self,
alpha: *const c_void,
a: *const c_void,
b: *const c_void,
beta: *const c_void,
c: *const c_void,
d: *mut c_void,
workspace: *mut c_void,
workspace_bytes: u64,
stream: *mut c_void,
) -> Result<()> { unsafe {
assert_eq!(self.kind, OpKind::Contraction, "plan is not a contraction");
let lib = cutensor()?;
let cu = lib.cutensor_contract()?;
check(cu(
self.handle.as_raw(),
self.plan,
alpha,
a,
b,
beta,
c,
d,
workspace,
workspace_bytes,
stream,
))
}}
#[allow(clippy::too_many_arguments)]
pub unsafe fn reduce(
&self,
alpha: *const c_void,
a: *const c_void,
beta: *const c_void,
c: *const c_void,
d: *mut c_void,
workspace: *mut c_void,
workspace_bytes: u64,
stream: *mut c_void,
) -> Result<()> { unsafe {
assert_eq!(self.kind, OpKind::Reduction, "plan is not a reduction");
let lib = cutensor()?;
let cu = lib.cutensor_reduce()?;
check(cu(
self.handle.as_raw(),
self.plan,
alpha,
a,
beta,
c,
d,
workspace,
workspace_bytes,
stream,
))
}}
#[allow(clippy::too_many_arguments)]
pub unsafe fn elementwise_binary(
&self,
alpha: *const c_void,
a: *const c_void,
gamma: *const c_void,
c: *const c_void,
d: *mut c_void,
stream: *mut c_void,
) -> Result<()> { unsafe {
assert_eq!(
self.kind,
OpKind::ElementwiseBinary,
"plan is not an elementwise-binary"
);
let lib = cutensor()?;
let cu = lib.cutensor_elementwise_binary_execute()?;
check(cu(
self.handle.as_raw(),
self.plan,
alpha,
a,
gamma,
c,
d,
stream,
))
}}
#[allow(clippy::too_many_arguments)]
pub unsafe fn elementwise_trinary(
&self,
alpha: *const c_void,
a: *const c_void,
beta: *const c_void,
b: *const c_void,
gamma: *const c_void,
c: *const c_void,
d: *mut c_void,
stream: *mut c_void,
) -> Result<()> { unsafe {
assert_eq!(
self.kind,
OpKind::ElementwiseTrinary,
"plan is not an elementwise-trinary"
);
let lib = cutensor()?;
let cu = lib.cutensor_elementwise_trinary_execute()?;
check(cu(
self.handle.as_raw(),
self.plan,
alpha,
a,
beta,
b,
gamma,
c,
d,
stream,
))
}}
pub unsafe fn permute(
&self,
alpha: *const c_void,
a: *const c_void,
b: *mut c_void,
stream: *mut c_void,
) -> Result<()> { unsafe {
assert_eq!(self.kind, OpKind::Permutation, "plan is not a permutation");
let lib = cutensor()?;
let cu = lib.cutensor_permute()?;
check(cu(self.handle.as_raw(), self.plan, alpha, a, b, stream))
}}
#[allow(clippy::too_many_arguments)]
pub unsafe fn contract_block_sparse(
&self,
alpha: *const c_void,
a: *const c_void,
b: *const c_void,
beta: *const c_void,
c: *const c_void,
d: *mut c_void,
workspace: *mut c_void,
workspace_bytes: u64,
stream: *mut c_void,
) -> Result<()> { unsafe {
assert_eq!(
self.kind,
OpKind::BlockSparseContraction,
"plan is not a block-sparse contraction"
);
let lib = cutensor()?;
let cu = lib.cutensor_block_sparse_contract()?;
check(cu(
self.handle.as_raw(),
self.plan,
alpha,
a,
b,
beta,
c,
d,
workspace,
workspace_bytes,
stream,
))
}}
#[allow(clippy::too_many_arguments)]
pub unsafe fn contract_trinary(
&self,
alpha: *const c_void,
a: *const c_void,
b: *const c_void,
c: *const c_void,
beta: *const c_void,
d: *const c_void,
e: *mut c_void,
workspace: *mut c_void,
workspace_bytes: u64,
stream: *mut c_void,
) -> Result<()> { unsafe {
assert_eq!(
self.kind,
OpKind::TrinaryContraction,
"plan is not a trinary-contraction"
);
let lib = cutensor()?;
let cu = lib.cutensor_contract_trinary()?;
check(cu(
self.handle.as_raw(),
self.plan,
alpha,
a,
b,
c,
beta,
d,
e,
workspace,
workspace_bytes,
stream,
))
}}
}
impl Drop for Plan<'_> {
    /// Releases the plan via `cutensor_destroy_plan`.
    ///
    /// Best effort: if the library or entry point cannot be resolved nothing
    /// happens, and the status returned by the destroy call is discarded.
    fn drop(&mut self) {
        let Ok(lib) = cutensor() else { return };
        let Ok(destroy) = lib.cutensor_destroy_plan() else {
            return;
        };
        // SAFETY: `self.plan` came from `cutensor_create_plan` and is not used
        // again after this point.
        let _ = unsafe { destroy(self.plan) };
    }
}