Skip to main content

baracuda_cutensor/
lib.rs

1//! Safe Rust wrappers for NVIDIA cuTENSOR (v2 API).
2//!
3//! cuTENSOR is NVIDIA's high-performance tensor-primitive library —
4//! einsum-style contractions, element-wise ops, reductions, and
5//! permutations. This crate wraps the full v2 host API surface.
6//!
7//! # Concepts
8//!
9//! - [`Handle`] — per-process library handle; owns the plan cache.
10//! - [`TensorDescriptor`] — shape + strides + dtype for one tensor.
11//! - [`OperationDescriptor`] — an *un*-compiled op (contraction,
12//!   reduction, elementwise binary/trinary, permutation). Created via
13//!   [`Contraction::new`], [`Reduction::new`], [`ElementwiseBinary::new`],
14//!   [`ElementwiseTrinary::new`], or [`Permutation::new`].
15//! - [`PlanPreference`] — algorithm selection + JIT mode.
16//! - [`Plan`] — compiled op, bound to a workspace size.
17//! - [`Plan::contract`] / [`Plan::reduce`] / etc. — execute the plan.
18//!
19//! # Example — `D = α · A ⊗ B + β · C` (matmul via contraction)
20//!
21//! Einstein notation: `D[m,n] = A[m,k] · B[k,n]`. Mode IDs identify the
22//! shared `k` index — pick any distinct integers per mode.
23//!
24//! ```no_run
25//! use baracuda_cutensor::*;
26//!
27//! # fn demo() -> Result<(), Error> {
28//! let handle = Handle::new()?;
29//! let m = 64i64; let n = 64i64; let k = 32i64;
30//! let a = TensorDescriptor::new(&handle, &[m, k], None, DataType::F32, 128)?;
31//! let b = TensorDescriptor::new(&handle, &[k, n], None, DataType::F32, 128)?;
32//! let c = TensorDescriptor::new(&handle, &[m, n], None, DataType::F32, 128)?;
33//! let modes_a = &[0i32, 2]; // [m, k]
34//! let modes_b = &[2, 1];     // [k, n]
35//! let modes_c = &[0, 1];     // [m, n]
36//! let op = unsafe {
37//!     Contraction::new(&handle, &a, modes_a, &b, modes_b, &c, modes_c, &c, modes_c,
38//!         core::ptr::null())
39//! }?;
40//! let pref = PlanPreference::default_for(&handle)?;
41//! let ws = op.estimate_workspace(&pref, WorkspaceKind::Default)?;
42//! let plan = Plan::new(&op, &pref, ws)?;
43//! # Ok(()) }
44//! ```
45//!
46//! # Example — reduce along an axis (sum over `k`)
47//!
48//! `D[m] = Σ_k A[m, k]`. Modes present in `A` but absent from `D` are
49//! reduced with the chosen [`BinaryOp`] (`Add` for sum).
50//!
51//! ```no_run
52//! use baracuda_cutensor::*;
53//!
54//! # fn demo() -> Result<(), Error> {
55//! let handle = Handle::new()?;
56//! let m = 128i64; let k = 64i64;
57//! let a = TensorDescriptor::new(&handle, &[m, k], None, DataType::F32, 128)?;
58//! let d = TensorDescriptor::new(&handle, &[m],    None, DataType::F32, 128)?;
59//!
60//! let modes_a = &[0i32, 1]; // [m, k]
61//! let modes_d = &[0i32];     // [m]
62//! let op = unsafe {
63//!     Reduction::new(&handle, &a, modes_a, &d, modes_d, &d, modes_d,
64//!         BinaryOp::Add, core::ptr::null())
65//! }?;
66//! let pref = PlanPreference::default_for(&handle)?;
67//! let ws = op.estimate_workspace(&pref, WorkspaceKind::Default)?;
68//! let _plan = Plan::new(&op, &pref, ws)?;
69//! # Ok(()) }
70//! ```
71//!
72//! # Example — element-wise `D = A + C` via [`ElementwiseBinary`]
73//!
74//! Same modes on every operand, no contraction or reduction — just a
75//! fused per-element op with optional unary pre-ops on each input.
76//!
77//! ```no_run
78//! use baracuda_cutensor::*;
79//!
80//! # fn demo() -> Result<(), Error> {
81//! let handle = Handle::new()?;
82//! let n = 1024i64;
83//! let a = TensorDescriptor::new(&handle, &[n], None, DataType::F32, 128)?;
84//! let c = TensorDescriptor::new(&handle, &[n], None, DataType::F32, 128)?;
85//! let d = TensorDescriptor::new(&handle, &[n], None, DataType::F32, 128)?;
86//!
87//! let modes = &[0i32];
88//! let op = unsafe {
89//!     ElementwiseBinary::new(
90//!         &handle,
91//!         &a, modes, UnaryOp::Identity,
92//!         &c, modes, UnaryOp::Identity,
93//!         &d, modes,
94//!         BinaryOp::Add,
95//!         core::ptr::null(),
96//!     )
97//! }?;
98//! let pref = PlanPreference::default_for(&handle)?;
99//! let _plan = Plan::new(&op, &pref, /* workspace */ 0)?;
100//! # Ok(()) }
101//! ```
102
103#![warn(missing_debug_implementations)]
104
105use core::ffi::c_void;
106use std::ffi::CString;
107
108use baracuda_cutensor_sys::{
109    cutensor, cutensorAlgo, cutensorDataType, cutensorHandle_t, cutensorJitMode,
110    cutensorOperationDescriptor_t, cutensorOperator, cutensorPlanPreference_t, cutensorPlan_t,
111    cutensorStatus_t, cutensorTensorDescriptor_t, cutensorWorksizePreference,
112};
113
114/// Error type for cuTENSOR operations.
115pub type Error = baracuda_core::Error<cutensorStatus_t>;
116/// Result alias.
117pub type Result<T, E = Error> = core::result::Result<T, E>;
118
119#[inline]
120fn check(status: cutensorStatus_t) -> Result<()> {
121    Error::check(status)
122}
123
124/// Verify cuTENSOR is loadable on this host.
125pub fn probe() -> Result<()> {
126    cutensor()?;
127    Ok(())
128}
129
130/// Encoded integer version from `cutensorGetVersion`. Decode as
131/// `major = v / 10000, minor = (v / 100) % 100, patch = v % 100`.
132pub fn version() -> Result<usize> {
133    let c = cutensor()?;
134    let cu = c.cutensor_get_version()?;
135    Ok(unsafe { cu() })
136}
137
138/// cuTENSOR's view of the CUDART version it was built against.
139pub fn cudart_version() -> Result<usize> {
140    let c = cutensor()?;
141    let cu = c.cutensor_get_cudart_version()?;
142    Ok(unsafe { cu() })
143}
144
145/// Set the cuTENSOR logger verbosity (0 = off, 1 = error, 2 = trace).
146pub fn set_log_level(level: i32) -> Result<()> {
147    let c = cutensor()?;
148    let cu = c.cutensor_logger_set_level()?;
149    check(unsafe { cu(level) })
150}
151
152/// Bitmask of log categories (API calls, hints, traces, …). Full value
153/// list in cuTENSOR headers.
154pub fn set_log_mask(mask: i32) -> Result<()> {
155    let c = cutensor()?;
156    let cu = c.cutensor_logger_set_mask()?;
157    check(unsafe { cu(mask) })
158}
159
160/// Open a log file path for cuTENSOR output.
161pub fn open_log_file(path: &str) -> Result<()> {
162    let cpath = std::ffi::CString::new(path).map_err(|_| Error::Status {
163        status: cutensorStatus_t::INVALID_VALUE,
164    })?;
165    let c = cutensor()?;
166    let cu = c.cutensor_logger_open_file()?;
167    check(unsafe { cu(cpath.as_ptr()) })
168}
169
170/// Force-disable all cuTENSOR logging (tightest possible quiet).
171pub fn force_disable_logging() -> Result<()> {
172    let c = cutensor()?;
173    let cu = c.cutensor_logger_force_disable()?;
174    check(unsafe { cu() })
175}
176
/// Element type stored in a tensor described by a tensor descriptor.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DataType {
    /// 16-bit IEEE-754 half float (`f16`).
    F16,
    /// 16-bit brain float (`bf16`).
    BF16,
    /// 32-bit IEEE-754 float (`f32`).
    F32,
    /// 64-bit IEEE-754 float (`f64`).
    F64,
    /// Complex number made of two `f32` components (real, imaginary).
    ComplexF32,
    /// Complex number made of two `f64` components (real, imaginary).
    ComplexF64,
    /// 8-bit signed integer.
    I8,
    /// 8-bit unsigned integer.
    U8,
    /// 32-bit signed integer.
    I32,
    /// 32-bit unsigned integer.
    U32,
}
201
202impl DataType {
203    #[inline]
204    fn raw(self) -> i32 {
205        match self {
206            DataType::F16 => cutensorDataType::R_16F,
207            DataType::BF16 => cutensorDataType::R_16BF,
208            DataType::F32 => cutensorDataType::R_32F,
209            DataType::F64 => cutensorDataType::R_64F,
210            DataType::ComplexF32 => cutensorDataType::C_32F,
211            DataType::ComplexF64 => cutensorDataType::C_64F,
212            DataType::I8 => cutensorDataType::R_8I,
213            DataType::U8 => cutensorDataType::R_8U,
214            DataType::I32 => cutensorDataType::R_32I,
215            DataType::U32 => cutensorDataType::R_32U,
216        }
217    }
218}
219
/// Unary operator applied to an individual operand (A, B or C) before the
/// main operation combines them.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UnaryOp {
    /// Leave the operand untouched.
    Identity,
    /// Square root of each element.
    Sqrt,
    /// Rectified linear unit, `max(0, x)`.
    Relu,
    /// Complex conjugate; has no effect on real types.
    Conj,
    /// Reciprocal, `1 / x`.
    Rcp,
    /// Logistic sigmoid, `1 / (1 + exp(-x))`.
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
}
238
239impl UnaryOp {
240    #[inline]
241    fn raw(self) -> i32 {
242        match self {
243            UnaryOp::Identity => cutensorOperator::IDENTITY,
244            UnaryOp::Sqrt => cutensorOperator::SQRT,
245            UnaryOp::Relu => cutensorOperator::RELU,
246            UnaryOp::Conj => cutensorOperator::CONJ,
247            UnaryOp::Rcp => cutensorOperator::RCP,
248            UnaryOp::Sigmoid => cutensorOperator::SIGMOID,
249            UnaryOp::Tanh => cutensorOperator::TANH,
250        }
251    }
252}
253
/// Operator that combines two values, used between operands of elementwise
/// ops and as the folding operator of reductions.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BinaryOp {
    /// Addition, `a + b`.
    Add,
    /// Multiplication, `a * b`.
    Mul,
    /// Larger of the two values.
    Max,
    /// Smaller of the two values.
    Min,
}
267
268impl BinaryOp {
269    #[inline]
270    fn raw(self) -> i32 {
271        match self {
272            BinaryOp::Add => cutensorOperator::ADD,
273            BinaryOp::Mul => cutensorOperator::MUL,
274            BinaryOp::Max => cutensorOperator::MAX,
275            BinaryOp::Min => cutensorOperator::MIN,
276        }
277    }
278}
279
280/// cuTENSOR library handle.
281#[derive(Debug)]
282pub struct Handle {
283    handle: cutensorHandle_t,
284}
285
286unsafe impl Send for Handle {}
287
288impl Handle {
289    /// Create a new cuTENSOR handle (`cutensorCreate`).
290    pub fn new() -> Result<Self> {
291        let c = cutensor()?;
292        let cu = c.cutensor_create()?;
293        let mut h: cutensorHandle_t = core::ptr::null_mut();
294        check(unsafe { cu(&mut h) })?;
295        Ok(Self { handle: h })
296    }
297
298    /// Raw `cutensorHandle_t`. Use with care.
299    #[inline]
300    pub fn as_raw(&self) -> cutensorHandle_t {
301        self.handle
302    }
303
304    /// Resize the internal plan cache — larger = more cached plans,
305    /// faster re-invocations. Default is 64.
306    pub fn resize_plan_cache(&self, num_entries: u32) -> Result<()> {
307        let c = cutensor()?;
308        let cu = c.cutensor_handle_resize_plan_cache()?;
309        check(unsafe { cu(self.handle, num_entries) })
310    }
311
312    /// Persist the plan cache to disk.
313    pub fn write_plan_cache_to_file(&self, path: &str) -> Result<()> {
314        let cpath = CString::new(path).map_err(|_| Error::Status {
315            status: cutensorStatus_t::INVALID_VALUE,
316        })?;
317        let c = cutensor()?;
318        let cu = c.cutensor_handle_write_plan_cache_to_file()?;
319        check(unsafe { cu(self.handle, cpath.as_ptr()) })
320    }
321
322    /// Read a previously-written plan cache from disk.
323    pub fn read_plan_cache_from_file(&self, path: &str) -> Result<()> {
324        let cpath = CString::new(path).map_err(|_| Error::Status {
325            status: cutensorStatus_t::INVALID_VALUE,
326        })?;
327        let c = cutensor()?;
328        let cu = c.cutensor_handle_read_plan_cache_from_file()?;
329        check(unsafe { cu(self.handle, cpath.as_ptr()) })
330    }
331
332    /// Persist the **kernel cache** (compiled binary kernels) to disk.
333    /// Separate from plan cache — kernel cache survives across planner
334    /// changes.
335    pub fn write_kernel_cache_to_file(&self, path: &str) -> Result<()> {
336        let cpath = CString::new(path).map_err(|_| Error::Status {
337            status: cutensorStatus_t::INVALID_VALUE,
338        })?;
339        let c = cutensor()?;
340        let cu = c.cutensor_write_kernel_cache_to_file()?;
341        check(unsafe { cu(self.handle, cpath.as_ptr()) })
342    }
343
344    /// Read a previously-written kernel cache from disk.
345    pub fn read_kernel_cache_from_file(&self, path: &str) -> Result<()> {
346        let cpath = CString::new(path).map_err(|_| Error::Status {
347            status: cutensorStatus_t::INVALID_VALUE,
348        })?;
349        let c = cutensor()?;
350        let cu = c.cutensor_read_kernel_cache_from_file()?;
351        check(unsafe { cu(self.handle, cpath.as_ptr()) })
352    }
353
354    /// Fetch cuTENSOR's pre-defined `CUTENSOR_COMPUTE_DESC_32F` descriptor.
355    /// Pass this (or one of the sibling accessors) as `compute_desc` to
356    /// any op constructor.
357    pub fn compute_desc_32f(&self) -> Result<*const c_void> {
358        Ok(cutensor()?.compute_desc_32f()?)
359    }
360    /// Fetch `CUTENSOR_COMPUTE_DESC_64F` — double-precision accumulator.
361    pub fn compute_desc_64f(&self) -> Result<*const c_void> {
362        Ok(cutensor()?.compute_desc_64f()?)
363    }
364    /// Fetch `CUTENSOR_COMPUTE_DESC_16F` — half-precision accumulator.
365    pub fn compute_desc_16f(&self) -> Result<*const c_void> {
366        Ok(cutensor()?.compute_desc_16f()?)
367    }
368    /// Fetch `CUTENSOR_COMPUTE_DESC_16BF` — bf16 accumulator.
369    pub fn compute_desc_16bf(&self) -> Result<*const c_void> {
370        Ok(cutensor()?.compute_desc_16bf()?)
371    }
372    /// Fetch `CUTENSOR_COMPUTE_DESC_TF32` — TF32 tensor-core accumulator.
373    pub fn compute_desc_tf32(&self) -> Result<*const c_void> {
374        Ok(cutensor()?.compute_desc_tf32()?)
375    }
376    /// Fetch `CUTENSOR_COMPUTE_DESC_3XTF32` — 3xTF32 emulation for f32.
377    pub fn compute_desc_3xtf32(&self) -> Result<*const c_void> {
378        Ok(cutensor()?.compute_desc_3xtf32()?)
379    }
380    /// Fetch `CUTENSOR_COMPUTE_DESC_4X16F` — 4x f16 mixed-precision.
381    pub fn compute_desc_4x16f(&self) -> Result<*const c_void> {
382        Ok(cutensor()?.compute_desc_4x16f()?)
383    }
384    /// Fetch `CUTENSOR_COMPUTE_DESC_8XINT8` — packed int8 tensor cores.
385    pub fn compute_desc_8xint8(&self) -> Result<*const c_void> {
386        Ok(cutensor()?.compute_desc_8xint8()?)
387    }
388    /// Fetch `CUTENSOR_COMPUTE_DESC_9X16BF` — bf16 stochastic-rounding mode.
389    pub fn compute_desc_9x16bf(&self) -> Result<*const c_void> {
390        Ok(cutensor()?.compute_desc_9x16bf()?)
391    }
392}
393
394/// A custom [compute descriptor]. Prefer the pre-defined ones
395/// ([`Handle::compute_desc_32f`], …) unless you need attribute
396/// customization.
397#[derive(Debug)]
398pub struct ComputeDescriptor<'h> {
399    desc: baracuda_cutensor_sys::cutensorComputeDescriptor_t,
400    _handle: &'h Handle,
401}
402
403impl<'h> ComputeDescriptor<'h> {
404    /// Create a new compute descriptor
405    /// (`cutensorCreateComputeDescriptor`).
406    pub fn new(handle: &'h Handle) -> Result<Self> {
407        let c = cutensor()?;
408        let cu = c.cutensor_create_compute_descriptor()?;
409        let mut desc: baracuda_cutensor_sys::cutensorComputeDescriptor_t = core::ptr::null();
410        check(unsafe { cu(handle.as_raw(), &mut desc as *mut _ as *mut _) })?;
411        Ok(Self {
412            desc,
413            _handle: handle,
414        })
415    }
416
417    /// Raw `cutensorComputeDescriptor_t`. Use with care.
418    #[inline]
419    pub fn as_raw(&self) -> baracuda_cutensor_sys::cutensorComputeDescriptor_t {
420        self.desc
421    }
422
423    /// # Safety
424    ///
425    /// `value` points at a buffer of `size_bytes` matching `attr`.
426    pub unsafe fn set_attribute(
427        &self,
428        attr: i32,
429        value: *const c_void,
430        size_bytes: usize,
431    ) -> Result<()> { unsafe {
432        let c = cutensor()?;
433        let cu = c.cutensor_compute_descriptor_set_attribute()?;
434        check(cu(
435            self._handle.as_raw(),
436            self.desc,
437            attr,
438            value,
439            size_bytes,
440        ))
441    }}
442
443    /// # Safety
444    ///
445    /// `value` points at a writable buffer of `size_bytes`.
446    pub unsafe fn get_attribute(
447        &self,
448        attr: i32,
449        value: *mut c_void,
450        size_bytes: usize,
451    ) -> Result<()> { unsafe {
452        let c = cutensor()?;
453        let cu = c.cutensor_compute_descriptor_get_attribute()?;
454        check(cu(
455            self._handle.as_raw(),
456            self.desc,
457            attr,
458            value,
459            size_bytes,
460        ))
461    }}
462}
463
464impl Drop for ComputeDescriptor<'_> {
465    fn drop(&mut self) {
466        if let Ok(c) = cutensor() {
467            if let Ok(cu) = c.cutensor_destroy_compute_descriptor() {
468                let _ = unsafe { cu(self.desc) };
469            }
470        }
471    }
472}
473
474/// A block-sparse tensor descriptor (cuTENSOR 2.x). Used on the A
475/// operand of a [`BlockSparseContraction`].
476#[derive(Debug)]
477pub struct BlockSparseTensorDescriptor<'h> {
478    desc: baracuda_cutensor_sys::cutensorBlockSparseTensorDescriptor_t,
479    _handle: &'h Handle,
480}
481
482impl<'h> BlockSparseTensorDescriptor<'h> {
483    /// Build a block-sparse tensor:
484    ///
485    /// - `extents` — full dense shape
486    /// - `block_size` — size per dim of each non-zero block (same length as extents)
487    /// - `strides` — optional custom strides; `None` = packed
488    /// - `block_indices` — array of `num_modes × block_count` ints identifying
489    ///   the non-zero block locations (index per mode per block)
490    #[allow(clippy::too_many_arguments)]
491    pub fn new(
492        handle: &'h Handle,
493        extents: &[i64],
494        block_size: &[i64],
495        strides: Option<&[i64]>,
496        block_indices: &[i32],
497        dtype: DataType,
498        alignment_bytes: u32,
499    ) -> Result<Self> {
500        assert_eq!(block_size.len(), extents.len());
501        if let Some(s) = strides {
502            assert_eq!(s.len(), extents.len());
503        }
504        let num_modes = extents.len() as u32;
505        let block_count = (block_indices.len() / extents.len()) as i64;
506        let c = cutensor()?;
507        let cu = c.cutensor_create_block_sparse_tensor_descriptor()?;
508        let mut desc: baracuda_cutensor_sys::cutensorBlockSparseTensorDescriptor_t =
509            core::ptr::null_mut();
510        check(unsafe {
511            cu(
512                handle.as_raw(),
513                &mut desc,
514                num_modes,
515                extents.as_ptr(),
516                block_size.as_ptr(),
517                strides.map_or(core::ptr::null(), |s| s.as_ptr()),
518                block_count,
519                block_indices.as_ptr(),
520                dtype.raw(),
521                alignment_bytes,
522            )
523        })?;
524        Ok(Self {
525            desc,
526            _handle: handle,
527        })
528    }
529
530    /// Raw `cutensorBlockSparseTensorDescriptor_t`. Use with care.
531    #[inline]
532    pub fn as_raw(&self) -> baracuda_cutensor_sys::cutensorBlockSparseTensorDescriptor_t {
533        self.desc
534    }
535}
536
537impl Drop for BlockSparseTensorDescriptor<'_> {
538    fn drop(&mut self) {
539        if let Ok(c) = cutensor() {
540            if let Ok(cu) = c.cutensor_destroy_block_sparse_tensor_descriptor() {
541                let _ = unsafe { cu(self.desc) };
542            }
543        }
544    }
545}
546
/// Marker for a block-sparse contraction: operand A is block-sparse, while
/// B, C and D are dense. Use `BlockSparseContraction::new` to build the
/// operation descriptor.
#[derive(Debug)]
pub struct BlockSparseContraction;
550
551impl BlockSparseContraction {
552    /// # Safety
553    ///
554    /// `compute_desc` must be null or a live `cutensorComputeDescriptor_t`.
555    #[allow(clippy::too_many_arguments, clippy::new_ret_no_self)]
556    pub unsafe fn new<'h>(
557        handle: &'h Handle,
558        a: &BlockSparseTensorDescriptor<'h>,
559        modes_a: &[i32],
560        b: &TensorDescriptor<'h>,
561        modes_b: &[i32],
562        c: &TensorDescriptor<'h>,
563        modes_c: &[i32],
564        d: &TensorDescriptor<'h>,
565        modes_d: &[i32],
566        compute_desc: *const c_void,
567    ) -> Result<OperationDescriptor<'h>> { unsafe {
568        let lib = cutensor()?;
569        let cu = lib.cutensor_create_block_sparse_contraction()?;
570        let mut desc: cutensorOperationDescriptor_t = core::ptr::null_mut();
571        check(cu(
572            handle.as_raw(),
573            &mut desc,
574            a.as_raw(),
575            modes_a.as_ptr(),
576            cutensorOperator::IDENTITY,
577            b.as_raw(),
578            modes_b.as_ptr(),
579            cutensorOperator::IDENTITY,
580            c.as_raw(),
581            modes_c.as_ptr(),
582            cutensorOperator::IDENTITY,
583            d.as_raw(),
584            modes_d.as_ptr(),
585            compute_desc,
586        ))?;
587        Ok(OperationDescriptor {
588            desc,
589            handle,
590            kind: OpKind::BlockSparseContraction,
591        })
592    }}
593}
594
/// Marker for a three-input contraction:
/// `E[mE] = α·op_a(A)·op_b(B)·op_c(C) + β·op_d(D)`.
/// Use `TrinaryContraction::new` to build the operation descriptor.
#[derive(Debug)]
pub struct TrinaryContraction;
598
599impl TrinaryContraction {
600    /// # Safety
601    ///
602    /// `compute_desc` must be null or a live `cutensorComputeDescriptor_t`.
603    #[allow(clippy::too_many_arguments, clippy::new_ret_no_self)]
604    pub unsafe fn new<'h>(
605        handle: &'h Handle,
606        a: &TensorDescriptor<'h>,
607        modes_a: &[i32],
608        b: &TensorDescriptor<'h>,
609        modes_b: &[i32],
610        c: &TensorDescriptor<'h>,
611        modes_c: &[i32],
612        d: &TensorDescriptor<'h>,
613        modes_d: &[i32],
614        e: &TensorDescriptor<'h>,
615        modes_e: &[i32],
616        compute_desc: *const c_void,
617    ) -> Result<OperationDescriptor<'h>> { unsafe {
618        let lib = cutensor()?;
619        let cu = lib.cutensor_create_contraction_trinary()?;
620        let mut desc: cutensorOperationDescriptor_t = core::ptr::null_mut();
621        check(cu(
622            handle.as_raw(),
623            &mut desc,
624            a.as_raw(),
625            modes_a.as_ptr(),
626            cutensorOperator::IDENTITY,
627            b.as_raw(),
628            modes_b.as_ptr(),
629            cutensorOperator::IDENTITY,
630            c.as_raw(),
631            modes_c.as_ptr(),
632            cutensorOperator::IDENTITY,
633            d.as_raw(),
634            modes_d.as_ptr(),
635            cutensorOperator::IDENTITY,
636            e.as_raw(),
637            modes_e.as_ptr(),
638            compute_desc,
639        ))?;
640        Ok(OperationDescriptor {
641            desc,
642            handle,
643            kind: OpKind::TrinaryContraction,
644        })
645    }}
646}
647
648impl Drop for Handle {
649    fn drop(&mut self) {
650        if let Ok(c) = cutensor() {
651            if let Ok(cu) = c.cutensor_destroy() {
652                let _ = unsafe { cu(self.handle) };
653            }
654        }
655    }
656}
657
658/// A tensor descriptor: modes + extents + dtype + stride layout.
659#[derive(Debug)]
660pub struct TensorDescriptor<'h> {
661    desc: cutensorTensorDescriptor_t,
662    _handle: &'h Handle,
663}
664
665impl<'h> TensorDescriptor<'h> {
666    /// `extents[i]` is the size along mode `i`. `strides = None` gets a
667    /// row-major packed layout.
668    pub fn new(
669        handle: &'h Handle,
670        extents: &[i64],
671        strides: Option<&[i64]>,
672        dtype: DataType,
673        alignment_bytes: u32,
674    ) -> Result<Self> {
675        let c = cutensor()?;
676        let cu = c.cutensor_create_tensor_descriptor()?;
677        let num_modes = extents.len() as u32;
678        if let Some(s) = strides {
679            assert_eq!(s.len(), extents.len(), "strides length mismatch");
680        }
681        let mut desc: cutensorTensorDescriptor_t = core::ptr::null_mut();
682        check(unsafe {
683            cu(
684                handle.as_raw(),
685                &mut desc,
686                num_modes,
687                extents.as_ptr(),
688                strides.map_or(core::ptr::null(), |s| s.as_ptr()),
689                dtype.raw(),
690                alignment_bytes,
691            )
692        })?;
693        Ok(Self {
694            desc,
695            _handle: handle,
696        })
697    }
698
699    /// Raw `cutensorTensorDescriptor_t`. Use with care.
700    #[inline]
701    pub fn as_raw(&self) -> cutensorTensorDescriptor_t {
702        self.desc
703    }
704
705    /// Low-level tensor-descriptor attribute setter.
706    ///
707    /// # Safety
708    ///
709    /// `buf` must point at `size_bytes` matching `attr`.
710    pub unsafe fn set_attribute(
711        &self,
712        attr: i32,
713        buf: *const c_void,
714        size_bytes: usize,
715    ) -> Result<()> { unsafe {
716        let c = cutensor()?;
717        let cu = c.cutensor_tensor_descriptor_set_attribute()?;
718        check(cu(self._handle.as_raw(), self.desc, attr, buf, size_bytes))
719    }}
720}
721
722impl Drop for TensorDescriptor<'_> {
723    fn drop(&mut self) {
724        if let Ok(c) = cutensor() {
725            if let Ok(cu) = c.cutensor_destroy_tensor_descriptor() {
726                let _ = unsafe { cu(self.desc) };
727            }
728        }
729    }
730}
731
/// Internal tag recording which kind of operation an
/// [`OperationDescriptor`] holds, so the compiled [`Plan`] can dispatch to
/// the matching `execute` path.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum OpKind {
    /// Two-input dense contraction.
    Contraction,
    /// Three-input dense contraction.
    TrinaryContraction,
    /// Contraction with a block-sparse A operand.
    BlockSparseContraction,
    /// Reduction over the modes missing from the output.
    Reduction,
    /// Two-input elementwise op.
    ElementwiseBinary,
    /// Three-input elementwise op.
    ElementwiseTrinary,
    /// Mode permutation.
    Permutation,
}
744
745/// An un-compiled operation descriptor. Users typically create these
746/// through constructors on [`Contraction`], [`Reduction`],
747/// [`ElementwiseBinary`], [`ElementwiseTrinary`], or [`Permutation`].
748#[derive(Debug)]
749pub struct OperationDescriptor<'h> {
750    desc: cutensorOperationDescriptor_t,
751    handle: &'h Handle,
752    kind: OpKind,
753}
754
755impl<'h> OperationDescriptor<'h> {
756    /// Raw `cutensorOperationDescriptor_t`. Use with care.
757    #[inline]
758    pub fn as_raw(&self) -> cutensorOperationDescriptor_t {
759        self.desc
760    }
761
762    /// Estimate the scratch workspace required by a plan built from
763    /// this descriptor + `pref`.
764    pub fn estimate_workspace(
765        &self,
766        pref: &PlanPreference<'h>,
767        kind: WorkspaceKind,
768    ) -> Result<u64> {
769        let c = cutensor()?;
770        let cu = c.cutensor_estimate_workspace_size()?;
771        let mut size: u64 = 0;
772        check(unsafe {
773            cu(
774                self.handle.as_raw(),
775                self.desc,
776                pref.as_raw(),
777                kind.raw(),
778                &mut size,
779            )
780        })?;
781        Ok(size)
782    }
783
784    /// Estimated runtime in milliseconds for this op at the given
785    /// algorithm (`cutensorAlgo::DEFAULT` for auto).
786    pub fn estimate_runtime(&self, pref: &PlanPreference<'h>, algo: i32) -> Result<f32> {
787        let c = cutensor()?;
788        let cu = c.cutensor_operation_estimate_runtime()?;
789        let mut ms: f32 = 0.0;
790        check(unsafe {
791            cu(
792                self.handle.as_raw(),
793                self.desc,
794                pref.as_raw(),
795                algo,
796                &mut ms,
797            )
798        })?;
799        Ok(ms)
800    }
801
802    /// Number of algorithms cuTENSOR has for this op shape.
803    pub fn num_algos(&self) -> Result<i32> {
804        let c = cutensor()?;
805        let cu = c.cutensor_operation_num_algos()?;
806        let mut n: i32 = 0;
807        check(unsafe { cu(self.desc, &mut n) })?;
808        Ok(n)
809    }
810
811    /// Low-level attribute getter (for attributes not exposed as typed fns).
812    ///
813    /// # Safety
814    ///
815    /// `buf` must be writable for `size_bytes` matching `attr`.
816    pub unsafe fn get_attribute(
817        &self,
818        attr: i32,
819        buf: *mut c_void,
820        size_bytes: usize,
821    ) -> Result<()> { unsafe {
822        let c = cutensor()?;
823        let cu = c.cutensor_operation_descriptor_get_attribute()?;
824        check(cu(self.handle.as_raw(), self.desc, attr, buf, size_bytes))
825    }}
826
827    /// Low-level attribute setter.
828    ///
829    /// # Safety
830    ///
831    /// `buf` must point at a buffer of `size_bytes` matching `attr`.
832    pub unsafe fn set_attribute(
833        &self,
834        attr: i32,
835        buf: *const c_void,
836        size_bytes: usize,
837    ) -> Result<()> { unsafe {
838        let c = cutensor()?;
839        let cu = c.cutensor_operation_descriptor_set_attribute()?;
840        check(cu(self.handle.as_raw(), self.desc, attr, buf, size_bytes))
841    }}
842}
843
844impl Drop for OperationDescriptor<'_> {
845    fn drop(&mut self) {
846        if let Ok(c) = cutensor() {
847            if let Ok(cu) = c.cutensor_destroy_operation_descriptor() {
848                let _ = unsafe { cu(self.desc) };
849            }
850        }
851    }
852}
853
/// Marker for a two-input contraction:
/// `D[mD] = α * op_a(A[mA]) * op_b(B[mB]) + β * op_c(C[mC])`.
/// Use `Contraction::new` to build the operation descriptor.
#[derive(Debug)]
pub struct Contraction;
857
858impl Contraction {
859    /// Build a contraction descriptor.
860    ///
861    /// `compute_desc` is an opaque pointer — pass `core::ptr::null()`
862    /// for the library default (compute-type matches C's dtype).
863    ///
864    /// # Safety
865    ///
866    /// `compute_desc` must be null or a valid `cutensorComputeDescriptor_t`.
867    #[allow(clippy::too_many_arguments, clippy::new_ret_no_self)]
868    pub unsafe fn new<'h>(
869        handle: &'h Handle,
870        a: &TensorDescriptor<'h>,
871        modes_a: &[i32],
872        b: &TensorDescriptor<'h>,
873        modes_b: &[i32],
874        c: &TensorDescriptor<'h>,
875        modes_c: &[i32],
876        d: &TensorDescriptor<'h>,
877        modes_d: &[i32],
878        compute_desc: *const c_void,
879    ) -> Result<OperationDescriptor<'h>> { unsafe {
880        let cu_lib = cutensor()?;
881        let cu = cu_lib.cutensor_create_contraction()?;
882        let mut desc: cutensorOperationDescriptor_t = core::ptr::null_mut();
883        check(cu(
884            handle.as_raw(),
885            &mut desc,
886            a.as_raw(),
887            modes_a.as_ptr(),
888            cutensorOperator::IDENTITY,
889            b.as_raw(),
890            modes_b.as_ptr(),
891            cutensorOperator::IDENTITY,
892            c.as_raw(),
893            modes_c.as_ptr(),
894            cutensorOperator::IDENTITY,
895            d.as_raw(),
896            modes_d.as_ptr(),
897            compute_desc,
898        ))?;
899        Ok(OperationDescriptor {
900            desc,
901            handle,
902            kind: OpKind::Contraction,
903        })
904    }}
905}
906
/// Marker for a reduction: `D[mD] = reduce(A[mA])` with a caller-chosen
/// folding operator. Use `Reduction::new` to build the operation descriptor.
#[derive(Debug)]
pub struct Reduction;
910
911impl Reduction {
912    /// Build a reduction. `modes_d` is a subset of `modes_a` — all
913    /// modes in `a` that do NOT appear in `d` are reduced. `op_reduce`
914    /// is ADD for sum, MUL for product, MAX/MIN for min-max.
915    ///
916    /// # Safety
917    ///
918    /// `compute_desc` must be null or valid.
919    #[allow(clippy::too_many_arguments, clippy::new_ret_no_self)]
920    pub unsafe fn new<'h>(
921        handle: &'h Handle,
922        a: &TensorDescriptor<'h>,
923        modes_a: &[i32],
924        c: &TensorDescriptor<'h>,
925        modes_c: &[i32],
926        d: &TensorDescriptor<'h>,
927        modes_d: &[i32],
928        op_reduce: BinaryOp,
929        compute_desc: *const c_void,
930    ) -> Result<OperationDescriptor<'h>> { unsafe {
931        let lib = cutensor()?;
932        let cu = lib.cutensor_create_reduction()?;
933        let mut desc: cutensorOperationDescriptor_t = core::ptr::null_mut();
934        check(cu(
935            handle.as_raw(),
936            &mut desc,
937            a.as_raw(),
938            modes_a.as_ptr(),
939            cutensorOperator::IDENTITY,
940            c.as_raw(),
941            modes_c.as_ptr(),
942            cutensorOperator::IDENTITY,
943            d.as_raw(),
944            modes_d.as_ptr(),
945            op_reduce.raw(),
946            compute_desc,
947        ))?;
948        Ok(OperationDescriptor {
949            desc,
950            handle,
951            kind: OpKind::Reduction,
952        })
953    }}
954}
955
956/// Elementwise binary op: `D[mD] = (α * op_a(A[mA])) op_ac (γ * op_c(C[mC]))`.
957#[derive(Debug)]
958pub struct ElementwiseBinary;
959
960impl ElementwiseBinary {
961    /// # Safety
962    ///
963    /// `compute_desc` must be null or valid.
964    #[allow(clippy::too_many_arguments, clippy::new_ret_no_self)]
965    pub unsafe fn new<'h>(
966        handle: &'h Handle,
967        a: &TensorDescriptor<'h>,
968        modes_a: &[i32],
969        op_a: UnaryOp,
970        c: &TensorDescriptor<'h>,
971        modes_c: &[i32],
972        op_c: UnaryOp,
973        d: &TensorDescriptor<'h>,
974        modes_d: &[i32],
975        op_ac: BinaryOp,
976        compute_desc: *const c_void,
977    ) -> Result<OperationDescriptor<'h>> { unsafe {
978        let lib = cutensor()?;
979        let cu = lib.cutensor_create_elementwise_binary()?;
980        let mut desc: cutensorOperationDescriptor_t = core::ptr::null_mut();
981        check(cu(
982            handle.as_raw(),
983            &mut desc,
984            a.as_raw(),
985            modes_a.as_ptr(),
986            op_a.raw(),
987            c.as_raw(),
988            modes_c.as_ptr(),
989            op_c.raw(),
990            d.as_raw(),
991            modes_d.as_ptr(),
992            op_ac.raw(),
993            compute_desc,
994        ))?;
995        Ok(OperationDescriptor {
996            desc,
997            handle,
998            kind: OpKind::ElementwiseBinary,
999        })
1000    }}
1001}
1002
1003/// Elementwise trinary op:
1004/// `D[mD] = ((α * op_a(A) op_ab β * op_b(B)) op_abc γ * op_c(C))`.
1005#[derive(Debug)]
1006pub struct ElementwiseTrinary;
1007
1008impl ElementwiseTrinary {
1009    /// # Safety
1010    ///
1011    /// `compute_desc` must be null or valid.
1012    #[allow(clippy::too_many_arguments, clippy::new_ret_no_self)]
1013    pub unsafe fn new<'h>(
1014        handle: &'h Handle,
1015        a: &TensorDescriptor<'h>,
1016        modes_a: &[i32],
1017        op_a: UnaryOp,
1018        b: &TensorDescriptor<'h>,
1019        modes_b: &[i32],
1020        op_b: UnaryOp,
1021        c: &TensorDescriptor<'h>,
1022        modes_c: &[i32],
1023        op_c: UnaryOp,
1024        d: &TensorDescriptor<'h>,
1025        modes_d: &[i32],
1026        op_ab: BinaryOp,
1027        op_abc: BinaryOp,
1028        compute_desc: *const c_void,
1029    ) -> Result<OperationDescriptor<'h>> { unsafe {
1030        let lib = cutensor()?;
1031        let cu = lib.cutensor_create_elementwise_trinary()?;
1032        let mut desc: cutensorOperationDescriptor_t = core::ptr::null_mut();
1033        check(cu(
1034            handle.as_raw(),
1035            &mut desc,
1036            a.as_raw(),
1037            modes_a.as_ptr(),
1038            op_a.raw(),
1039            b.as_raw(),
1040            modes_b.as_ptr(),
1041            op_b.raw(),
1042            c.as_raw(),
1043            modes_c.as_ptr(),
1044            op_c.raw(),
1045            d.as_raw(),
1046            modes_d.as_ptr(),
1047            op_ab.raw(),
1048            op_abc.raw(),
1049            compute_desc,
1050        ))?;
1051        Ok(OperationDescriptor {
1052            desc,
1053            handle,
1054            kind: OpKind::ElementwiseTrinary,
1055        })
1056    }}
1057}
1058
1059/// Tensor permutation (axis shuffle + optional unary op):
1060/// `B[mB] = α * op_a(A[mA])`.
1061#[derive(Debug)]
1062pub struct Permutation;
1063
1064impl Permutation {
1065    /// # Safety
1066    ///
1067    /// `compute_desc` must be null or valid.
1068    #[allow(clippy::too_many_arguments, clippy::new_ret_no_self)]
1069    pub unsafe fn new<'h>(
1070        handle: &'h Handle,
1071        a: &TensorDescriptor<'h>,
1072        modes_a: &[i32],
1073        op_a: UnaryOp,
1074        b: &TensorDescriptor<'h>,
1075        modes_b: &[i32],
1076        compute_desc: *const c_void,
1077    ) -> Result<OperationDescriptor<'h>> { unsafe {
1078        let lib = cutensor()?;
1079        let cu = lib.cutensor_create_permutation()?;
1080        let mut desc: cutensorOperationDescriptor_t = core::ptr::null_mut();
1081        check(cu(
1082            handle.as_raw(),
1083            &mut desc,
1084            a.as_raw(),
1085            modes_a.as_ptr(),
1086            op_a.raw(),
1087            b.as_raw(),
1088            modes_b.as_ptr(),
1089            compute_desc,
1090        ))?;
1091        Ok(OperationDescriptor {
1092            desc,
1093            handle,
1094            kind: OpKind::Permutation,
1095        })
1096    }}
1097}
1098
1099/// Plan preferences — algorithm selection + JIT mode.
1100#[derive(Debug)]
1101pub struct PlanPreference<'h> {
1102    pref: cutensorPlanPreference_t,
1103    _handle: &'h Handle,
1104}
1105
1106impl<'h> PlanPreference<'h> {
1107    /// Build a plan-preference (`cutensorCreatePlanPreference`)
1108    /// requesting `algo` (e.g. `cutensorAlgo::DEFAULT`) and `jit_mode`.
1109    pub fn new(handle: &'h Handle, algo: i32, jit_mode: i32) -> Result<Self> {
1110        let c = cutensor()?;
1111        let cu = c.cutensor_create_plan_preference()?;
1112        let mut p: cutensorPlanPreference_t = core::ptr::null_mut();
1113        check(unsafe { cu(handle.as_raw(), &mut p, algo, jit_mode) })?;
1114        Ok(Self {
1115            pref: p,
1116            _handle: handle,
1117        })
1118    }
1119
1120    /// Default preferences — library's best guess at algorithm, JIT off.
1121    pub fn default_for(handle: &'h Handle) -> Result<Self> {
1122        Self::new(handle, cutensorAlgo::DEFAULT, cutensorJitMode::NONE)
1123    }
1124
1125    /// Raw `cutensorPlanPreference_t`. Use with care.
1126    #[inline]
1127    pub fn as_raw(&self) -> cutensorPlanPreference_t {
1128        self.pref
1129    }
1130
1131    /// Set a plan-preference attribute (see cuTENSOR's
1132    /// `cutensorPlanPreferenceAttribute_t`).
1133    ///
1134    /// # Safety
1135    ///
1136    /// `value` must point at a buffer of at least `size_bytes` for the
1137    /// attribute kind being set.
1138    pub unsafe fn set_attribute(
1139        &self,
1140        attr: i32,
1141        value: *const c_void,
1142        size_bytes: usize,
1143    ) -> Result<()> { unsafe {
1144        let c = cutensor()?;
1145        let cu = c.cutensor_plan_preference_set_attribute()?;
1146        check(cu(
1147            self._handle.as_raw(),
1148            self.pref,
1149            attr,
1150            value,
1151            size_bytes,
1152        ))
1153    }}
1154
1155    /// Read a plan-preference attribute.
1156    ///
1157    /// # Safety
1158    ///
1159    /// `value` must be writable for `size_bytes` matching `attr`.
1160    pub unsafe fn get_attribute(
1161        &self,
1162        attr: i32,
1163        value: *mut c_void,
1164        size_bytes: usize,
1165    ) -> Result<()> { unsafe {
1166        let c = cutensor()?;
1167        let cu = c.cutensor_plan_preference_get_attribute()?;
1168        check(cu(
1169            self._handle.as_raw(),
1170            self.pref,
1171            attr,
1172            value,
1173            size_bytes,
1174        ))
1175    }}
1176}
1177
1178impl Drop for PlanPreference<'_> {
1179    fn drop(&mut self) {
1180        if let Ok(c) = cutensor() {
1181            if let Ok(cu) = c.cutensor_destroy_plan_preference() {
1182                let _ = unsafe { cu(self.pref) };
1183            }
1184        }
1185    }
1186}
1187
/// Workspace-size preference tier, e.g. as passed to `estimate_workspace`
/// when sizing the scratch buffer an operation may use.
///
/// It is a plain copyable tag, so it also derives `PartialEq`/`Eq`/`Hash`
/// and can be compared or used as a map key.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum WorkspaceKind {
    /// Smallest workspace the algorithm can run with.
    Min,
    /// Library default — balanced size vs. performance.
    Default,
    /// Largest workspace the algorithm will ever need.
    Max,
}
1198
1199impl WorkspaceKind {
1200    #[inline]
1201    fn raw(self) -> i32 {
1202        match self {
1203            WorkspaceKind::Min => cutensorWorksizePreference::MIN,
1204            WorkspaceKind::Default => cutensorWorksizePreference::DEFAULT,
1205            WorkspaceKind::Max => cutensorWorksizePreference::MAX,
1206        }
1207    }
1208}
1209
1210/// A compiled operation plan. Dispatch to the matching `execute` method
1211/// based on the op kind that built it.
1212#[derive(Debug)]
1213pub struct Plan<'h> {
1214    plan: cutensorPlan_t,
1215    handle: &'h Handle,
1216    kind: OpKind,
1217}
1218
1219impl<'h> Plan<'h> {
1220    /// Compile an operation descriptor into a plan.
1221    /// `workspace_size_limit` bytes — pass the estimate.
1222    pub fn new(
1223        op: &OperationDescriptor<'h>,
1224        pref: &PlanPreference<'h>,
1225        workspace_size_limit: u64,
1226    ) -> Result<Self> {
1227        let c = cutensor()?;
1228        let cu = c.cutensor_create_plan()?;
1229        let mut p: cutensorPlan_t = core::ptr::null_mut();
1230        check(unsafe {
1231            cu(
1232                op.handle.as_raw(),
1233                &mut p,
1234                op.as_raw(),
1235                pref.as_raw(),
1236                workspace_size_limit,
1237            )
1238        })?;
1239        Ok(Self {
1240            plan: p,
1241            handle: op.handle,
1242            kind: op.kind,
1243        })
1244    }
1245
1246    /// Raw `cutensorPlan_t`. Use with care.
1247    #[inline]
1248    pub fn as_raw(&self) -> cutensorPlan_t {
1249        self.plan
1250    }
1251
1252    /// Execute a contraction plan. Aborts if `self` wasn't built from a
1253    /// [`Contraction`] descriptor.
1254    ///
1255    /// # Safety
1256    ///
1257    /// All device pointers must be live, tensor-descriptor-conforming,
1258    /// and aligned. `workspace` must be at least the estimated size.
1259    #[allow(clippy::too_many_arguments)]
1260    pub unsafe fn contract(
1261        &self,
1262        alpha: *const c_void,
1263        a: *const c_void,
1264        b: *const c_void,
1265        beta: *const c_void,
1266        c: *const c_void,
1267        d: *mut c_void,
1268        workspace: *mut c_void,
1269        workspace_bytes: u64,
1270        stream: *mut c_void,
1271    ) -> Result<()> { unsafe {
1272        assert_eq!(self.kind, OpKind::Contraction, "plan is not a contraction");
1273        let lib = cutensor()?;
1274        let cu = lib.cutensor_contract()?;
1275        check(cu(
1276            self.handle.as_raw(),
1277            self.plan,
1278            alpha,
1279            a,
1280            b,
1281            beta,
1282            c,
1283            d,
1284            workspace,
1285            workspace_bytes,
1286            stream,
1287        ))
1288    }}
1289
1290    /// Execute a reduction plan.
1291    ///
1292    /// # Safety
1293    ///
1294    /// Same as [`Self::contract`].
1295    #[allow(clippy::too_many_arguments)]
1296    pub unsafe fn reduce(
1297        &self,
1298        alpha: *const c_void,
1299        a: *const c_void,
1300        beta: *const c_void,
1301        c: *const c_void,
1302        d: *mut c_void,
1303        workspace: *mut c_void,
1304        workspace_bytes: u64,
1305        stream: *mut c_void,
1306    ) -> Result<()> { unsafe {
1307        assert_eq!(self.kind, OpKind::Reduction, "plan is not a reduction");
1308        let lib = cutensor()?;
1309        let cu = lib.cutensor_reduce()?;
1310        check(cu(
1311            self.handle.as_raw(),
1312            self.plan,
1313            alpha,
1314            a,
1315            beta,
1316            c,
1317            d,
1318            workspace,
1319            workspace_bytes,
1320            stream,
1321        ))
1322    }}
1323
1324    /// Execute an elementwise-binary plan.
1325    ///
1326    /// # Safety
1327    ///
1328    /// Same as [`Self::contract`].
1329    #[allow(clippy::too_many_arguments)]
1330    pub unsafe fn elementwise_binary(
1331        &self,
1332        alpha: *const c_void,
1333        a: *const c_void,
1334        gamma: *const c_void,
1335        c: *const c_void,
1336        d: *mut c_void,
1337        stream: *mut c_void,
1338    ) -> Result<()> { unsafe {
1339        assert_eq!(
1340            self.kind,
1341            OpKind::ElementwiseBinary,
1342            "plan is not an elementwise-binary"
1343        );
1344        let lib = cutensor()?;
1345        let cu = lib.cutensor_elementwise_binary_execute()?;
1346        check(cu(
1347            self.handle.as_raw(),
1348            self.plan,
1349            alpha,
1350            a,
1351            gamma,
1352            c,
1353            d,
1354            stream,
1355        ))
1356    }}
1357
1358    /// Execute an elementwise-trinary plan.
1359    ///
1360    /// # Safety
1361    ///
1362    /// Same as [`Self::contract`].
1363    #[allow(clippy::too_many_arguments)]
1364    pub unsafe fn elementwise_trinary(
1365        &self,
1366        alpha: *const c_void,
1367        a: *const c_void,
1368        beta: *const c_void,
1369        b: *const c_void,
1370        gamma: *const c_void,
1371        c: *const c_void,
1372        d: *mut c_void,
1373        stream: *mut c_void,
1374    ) -> Result<()> { unsafe {
1375        assert_eq!(
1376            self.kind,
1377            OpKind::ElementwiseTrinary,
1378            "plan is not an elementwise-trinary"
1379        );
1380        let lib = cutensor()?;
1381        let cu = lib.cutensor_elementwise_trinary_execute()?;
1382        check(cu(
1383            self.handle.as_raw(),
1384            self.plan,
1385            alpha,
1386            a,
1387            beta,
1388            b,
1389            gamma,
1390            c,
1391            d,
1392            stream,
1393        ))
1394    }}
1395
1396    /// Execute a permutation plan.
1397    ///
1398    /// # Safety
1399    ///
1400    /// Same as [`Self::contract`].
1401    pub unsafe fn permute(
1402        &self,
1403        alpha: *const c_void,
1404        a: *const c_void,
1405        b: *mut c_void,
1406        stream: *mut c_void,
1407    ) -> Result<()> { unsafe {
1408        assert_eq!(self.kind, OpKind::Permutation, "plan is not a permutation");
1409        let lib = cutensor()?;
1410        let cu = lib.cutensor_permute()?;
1411        check(cu(self.handle.as_raw(), self.plan, alpha, a, b, stream))
1412    }}
1413
1414    /// Execute a block-sparse contraction plan.
1415    ///
1416    /// # Safety
1417    ///
1418    /// Same as [`Self::contract`]; `a` must be a block-sparse device
1419    /// buffer matching the `BlockSparseTensorDescriptor` passed to
1420    /// [`BlockSparseContraction::new`].
1421    #[allow(clippy::too_many_arguments)]
1422    pub unsafe fn contract_block_sparse(
1423        &self,
1424        alpha: *const c_void,
1425        a: *const c_void,
1426        b: *const c_void,
1427        beta: *const c_void,
1428        c: *const c_void,
1429        d: *mut c_void,
1430        workspace: *mut c_void,
1431        workspace_bytes: u64,
1432        stream: *mut c_void,
1433    ) -> Result<()> { unsafe {
1434        assert_eq!(
1435            self.kind,
1436            OpKind::BlockSparseContraction,
1437            "plan is not a block-sparse contraction"
1438        );
1439        let lib = cutensor()?;
1440        let cu = lib.cutensor_block_sparse_contract()?;
1441        check(cu(
1442            self.handle.as_raw(),
1443            self.plan,
1444            alpha,
1445            a,
1446            b,
1447            beta,
1448            c,
1449            d,
1450            workspace,
1451            workspace_bytes,
1452            stream,
1453        ))
1454    }}
1455
1456    /// Execute a trinary-contraction plan:
1457    /// `E = α·op_a(A)·op_b(B)·op_c(C) + β·op_d(D)`.
1458    ///
1459    /// # Safety
1460    ///
1461    /// Same as [`Self::contract`].
1462    #[allow(clippy::too_many_arguments)]
1463    pub unsafe fn contract_trinary(
1464        &self,
1465        alpha: *const c_void,
1466        a: *const c_void,
1467        b: *const c_void,
1468        c: *const c_void,
1469        beta: *const c_void,
1470        d: *const c_void,
1471        e: *mut c_void,
1472        workspace: *mut c_void,
1473        workspace_bytes: u64,
1474        stream: *mut c_void,
1475    ) -> Result<()> { unsafe {
1476        assert_eq!(
1477            self.kind,
1478            OpKind::TrinaryContraction,
1479            "plan is not a trinary-contraction"
1480        );
1481        let lib = cutensor()?;
1482        let cu = lib.cutensor_contract_trinary()?;
1483        check(cu(
1484            self.handle.as_raw(),
1485            self.plan,
1486            alpha,
1487            a,
1488            b,
1489            c,
1490            beta,
1491            d,
1492            e,
1493            workspace,
1494            workspace_bytes,
1495            stream,
1496        ))
1497    }}
1498}
1499
1500impl Drop for Plan<'_> {
1501    fn drop(&mut self) {
1502        if let Ok(c) = cutensor() {
1503            if let Ok(cu) = c.cutensor_destroy_plan() {
1504                let _ = unsafe { cu(self.plan) };
1505            }
1506        }
1507    }
1508}