Skip to main content

baracuda_cutensor_sys/
lib.rs

1//! Raw FFI + dynamic loader skeleton for NVIDIA cuTENSOR.
2//!
3//! `baracuda-cutensor` wraps this with a safe, typed API. Use this
4//! crate directly only if you need a function that the safe layer
5//! hasn't wrapped yet (in which case please file a bug).
6//!
7//! cuTENSOR is a separately-installed NVIDIA library for high-performance
8//! tensor contraction, reduction, and element-wise ops. v0.1 ships the
9//! loader + status enum; concrete contraction/permutation/reduction
10//! wrappers follow once CI has a cuTENSOR install.
11
12#![allow(non_camel_case_types, non_snake_case, non_upper_case_globals)]
13#![warn(missing_debug_implementations)]
14
15use std::sync::OnceLock;
16
17use baracuda_core::{Library, LoaderError};
18use baracuda_types::CudaStatus;
19
/// cuTENSOR status code (`cutensorStatus_t`), returned by most cuTENSOR
/// entry points.
///
/// Kept as a transparent `i32` newtype instead of a Rust `enum` so that any
/// value the library hands back — including codes not listed below — is a
/// valid instance of the type.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
#[repr(transparent)]
pub struct cutensorStatus_t(pub i32);

impl cutensorStatus_t {
    /// `CUTENSOR_STATUS_SUCCESS` (0) — the call completed successfully.
    pub const SUCCESS: Self = Self(0);
    /// `CUTENSOR_STATUS_NOT_INITIALIZED` (1) — the library has not been initialized.
    pub const NOT_INITIALIZED: Self = Self(1);
    /// `CUTENSOR_STATUS_ALLOC_FAILED` (3) — a resource allocation failed.
    pub const ALLOC_FAILED: Self = Self(3);
    /// `CUTENSOR_STATUS_INVALID_VALUE` (7) — an argument carried an invalid value.
    pub const INVALID_VALUE: Self = Self(7);
    /// `CUTENSOR_STATUS_ARCH_MISMATCH` (8) — unsupported device architecture.
    pub const ARCH_MISMATCH: Self = Self(8);
    /// `CUTENSOR_STATUS_MAPPING_ERROR` (11) — host/device memory mapping error.
    pub const MAPPING_ERROR: Self = Self(11);
    /// `CUTENSOR_STATUS_EXECUTION_FAILED` (13) — a kernel failed to execute.
    pub const EXECUTION_FAILED: Self = Self(13);
    /// `CUTENSOR_STATUS_INTERNAL_ERROR` (14) — cuTENSOR hit an internal error.
    pub const INTERNAL_ERROR: Self = Self(14);
    /// `CUTENSOR_STATUS_NOT_SUPPORTED` (15) — requested feature is unsupported.
    pub const NOT_SUPPORTED: Self = Self(15);
    /// `CUTENSOR_STATUS_LICENSE_ERROR` (16) — the license check failed.
    pub const LICENSE_ERROR: Self = Self(16);
    /// `CUTENSOR_STATUS_CUBLAS_ERROR` (17) — an internal cuBLAS call failed.
    pub const CUBLAS_ERROR: Self = Self(17);
    /// `CUTENSOR_STATUS_CUDA_ERROR` (18) — an internal CUDA call failed.
    pub const CUDA_ERROR: Self = Self(18);
    /// `CUTENSOR_STATUS_INSUFFICIENT_WORKSPACE` (19) — the workspace buffer is too small.
    pub const INSUFFICIENT_WORKSPACE: Self = Self(19);
    /// `CUTENSOR_STATUS_INSUFFICIENT_DRIVER` (20) — the CUDA driver is too old.
    pub const INSUFFICIENT_DRIVER: Self = Self(20);
    /// `CUTENSOR_STATUS_IO_ERROR` (21) — I/O failure (cache read/write, logger, ...).
    pub const IO_ERROR: Self = Self(21);

    /// `true` exactly when this is [`Self::SUCCESS`]; every other code,
    /// known or not, counts as a failure.
    pub const fn is_success(self) -> bool {
        self.0 == Self::SUCCESS.0
    }
}
62
63impl CudaStatus for cutensorStatus_t {
64    fn code(self) -> i32 {
65        self.0
66    }
67    fn name(self) -> &'static str {
68        match self.0 {
69            0 => "CUTENSOR_STATUS_SUCCESS",
70            1 => "CUTENSOR_STATUS_NOT_INITIALIZED",
71            3 => "CUTENSOR_STATUS_ALLOC_FAILED",
72            7 => "CUTENSOR_STATUS_INVALID_VALUE",
73            13 => "CUTENSOR_STATUS_EXECUTION_FAILED",
74            15 => "CUTENSOR_STATUS_NOT_SUPPORTED",
75            19 => "CUTENSOR_STATUS_INSUFFICIENT_WORKSPACE",
76            _ => "CUTENSOR_STATUS_UNRECOGNIZED",
77        }
78    }
79    fn description(self) -> &'static str {
80        match self.0 {
81            0 => "success",
82            15 => "operation not supported",
83            19 => "workspace buffer too small",
84            _ => "unrecognized cuTENSOR status code",
85        }
86    }
87    fn is_success(self) -> bool {
88        cutensorStatus_t::is_success(self)
89    }
90    fn library(self) -> &'static str {
91        "cutensor"
92    }
93}
94
95// ---- Handle + descriptor types ----
96
/// Opaque cuTENSOR library handle, created by `cutensorCreate` and released
/// by `cutensorDestroy`.
pub type cutensorHandle_t = *mut core::ffi::c_void;

/// Opaque tensor descriptor, which records a tensor's extents, strides,
/// element type and alignment. Created by `cutensorCreateTensorDescriptor`.
pub type cutensorTensorDescriptor_t = *mut core::ffi::c_void;

/// Opaque operation descriptor.
///
/// This type is not specific to contractions. The same descriptor type is
/// produced by each of these constructors:
/// - `cutensorCreateContraction`
/// - `cutensorCreateElementwiseBinary`
/// - `cutensorCreateElementwiseTrinary`
/// - `cutensorCreatePermutation`
/// - `cutensorCreateReduction`
///
/// It is released by `cutensorDestroyOperationDescriptor`.
pub type cutensorOperationDescriptor_t = *mut core::ffi::c_void;

/// Opaque plan-preference object, which carries the algorithm and JIT-mode
/// choices. Created by `cutensorCreatePlanPreference`.
pub type cutensorPlanPreference_t = *mut core::ffi::c_void;

/// Opaque execution plan, built by `cutensorCreatePlan` from an operation
/// descriptor plus a plan preference.
pub type cutensorPlan_t = *mut core::ffi::c_void;
111
/// `cutensorDataType_t` — element-type codes, passed as `data_type` when
/// creating tensor descriptors.
///
/// Listed in ascending numeric order.
#[allow(non_snake_case)]
pub mod cutensorDataType {
    /// `CUTENSOR_R_32F` — real 32-bit float (`float`).
    pub const R_32F: i32 = 0;
    /// `CUTENSOR_R_64F` — real 64-bit float (`double`).
    pub const R_64F: i32 = 1;
    /// `CUTENSOR_R_16F` — real 16-bit float (half precision).
    pub const R_16F: i32 = 2;
    /// `CUTENSOR_R_8I` — real signed 8-bit integer.
    pub const R_8I: i32 = 3;
    /// `CUTENSOR_C_32F` — complex, pair of 32-bit floats.
    pub const C_32F: i32 = 4;
    /// `CUTENSOR_C_64F` — complex, pair of 64-bit floats.
    pub const C_64F: i32 = 5;
    /// `CUTENSOR_R_8U` — real unsigned 8-bit integer.
    pub const R_8U: i32 = 8;
    /// `CUTENSOR_R_32I` — real signed 32-bit integer.
    pub const R_32I: i32 = 10;
    /// `CUTENSOR_R_32U` — real unsigned 32-bit integer.
    pub const R_32U: i32 = 12;
    /// `CUTENSOR_R_16BF` — real bfloat16.
    pub const R_16BF: i32 = 14;
}
136
/// `cutensorComputeDescriptor_t` — opaque pointer selecting the compute
/// precision of an operation (cuTENSOR v2+).
///
/// The operation constructors (`cutensorCreateContraction`,
/// `cutensorCreateReduction`, `cutensorCreatePermutation`, ...) take one as
/// their final argument. In v2 it must not be null there.
///
/// cuTENSOR exports ready-made descriptors as global variables
/// (`CUTENSOR_COMPUTE_DESC_32F` and friends). Resolve them through
/// [`Cutensor::compute_desc_32f`] and its sibling accessors.
pub type cutensorComputeDescriptor_t = *const core::ffi::c_void;
145
146impl Cutensor {
147    /// Return a pre-defined compute descriptor resolved from the cuTENSOR
148    /// shared library's exported data symbols. These symbols are global
149    /// pointer variables (`CUTENSOR_COMPUTE_DESC_*`) that the library
150    /// initializes — we read the pointer value at the symbol's address.
151    fn compute_desc_by_name(
152        &self,
153        name: &'static str,
154    ) -> Result<cutensorComputeDescriptor_t, LoaderError> {
155        // SAFETY: the symbol resolves to a `cutensorComputeDescriptor_t*`
156        // that cuTENSOR initializes on library load.
157        let raw: *mut () = unsafe { self.lib.raw_symbol(name)? };
158        let ptr_ptr = raw as *const cutensorComputeDescriptor_t;
159        Ok(unsafe { *ptr_ptr })
160    }
161
162    /// `CUTENSOR_COMPUTE_DESC_32F` — 32-bit float compute.
163    pub fn compute_desc_32f(&self) -> Result<cutensorComputeDescriptor_t, LoaderError> {
164        self.compute_desc_by_name("CUTENSOR_COMPUTE_DESC_32F")
165    }
166    /// `CUTENSOR_COMPUTE_DESC_64F` — 64-bit float compute.
167    pub fn compute_desc_64f(&self) -> Result<cutensorComputeDescriptor_t, LoaderError> {
168        self.compute_desc_by_name("CUTENSOR_COMPUTE_DESC_64F")
169    }
170    /// `CUTENSOR_COMPUTE_DESC_16F` — 16-bit float compute.
171    pub fn compute_desc_16f(&self) -> Result<cutensorComputeDescriptor_t, LoaderError> {
172        self.compute_desc_by_name("CUTENSOR_COMPUTE_DESC_16F")
173    }
174    /// `CUTENSOR_COMPUTE_DESC_16BF` — bfloat16 compute.
175    pub fn compute_desc_16bf(&self) -> Result<cutensorComputeDescriptor_t, LoaderError> {
176        self.compute_desc_by_name("CUTENSOR_COMPUTE_DESC_16BF")
177    }
178    /// `CUTENSOR_COMPUTE_DESC_TF32` — TensorFloat32 compute.
179    pub fn compute_desc_tf32(&self) -> Result<cutensorComputeDescriptor_t, LoaderError> {
180        self.compute_desc_by_name("CUTENSOR_COMPUTE_DESC_TF32")
181    }
182    /// `CUTENSOR_COMPUTE_DESC_3XTF32` — 3× TF32 mantissa-extended compute.
183    pub fn compute_desc_3xtf32(&self) -> Result<cutensorComputeDescriptor_t, LoaderError> {
184        self.compute_desc_by_name("CUTENSOR_COMPUTE_DESC_3XTF32")
185    }
186    /// `CUTENSOR_COMPUTE_DESC_4X16F` — 4× FP16 mantissa-extended compute.
187    pub fn compute_desc_4x16f(&self) -> Result<cutensorComputeDescriptor_t, LoaderError> {
188        self.compute_desc_by_name("CUTENSOR_COMPUTE_DESC_4X16F")
189    }
190    /// `CUTENSOR_COMPUTE_DESC_8XINT8` — 8× INT8 packed compute.
191    pub fn compute_desc_8xint8(&self) -> Result<cutensorComputeDescriptor_t, LoaderError> {
192        self.compute_desc_by_name("CUTENSOR_COMPUTE_DESC_8XINT8")
193    }
194    /// `CUTENSOR_COMPUTE_DESC_9X16BF` — 9× BF16 mantissa-extended compute.
195    pub fn compute_desc_9x16bf(&self) -> Result<cutensorComputeDescriptor_t, LoaderError> {
196        self.compute_desc_by_name("CUTENSOR_COMPUTE_DESC_9X16BF")
197    }
198}
199
/// `cutensorOperator_t` — element-wise operator codes.
///
/// Unary codes go in the per-operand `op_*` arguments of the operation
/// constructors. Binary codes (`ADD`, `MUL`, `MAX`, `MIN`) combine operands,
/// for example as `op_ac` of an element-wise binary op or as `op_reduce` of a
/// reduction.
///
/// Listed in ascending numeric order.
#[allow(non_snake_case)]
pub mod cutensorOperator {
    /// `CUTENSOR_OP_IDENTITY` (unary) — pass the value through unchanged.
    pub const IDENTITY: i32 = 1;
    /// `CUTENSOR_OP_SQRT` (unary) — square root.
    pub const SQRT: i32 = 2;
    /// `CUTENSOR_OP_ADD` (binary) — addition.
    pub const ADD: i32 = 3;
    /// `CUTENSOR_OP_MUL` (binary) — multiplication.
    pub const MUL: i32 = 5;
    /// `CUTENSOR_OP_MAX` (binary) — maximum.
    pub const MAX: i32 = 6;
    /// `CUTENSOR_OP_MIN` (binary) — minimum.
    pub const MIN: i32 = 7;
    /// `CUTENSOR_OP_RELU` (unary) — ReLU.
    pub const RELU: i32 = 8;
    /// `CUTENSOR_OP_CONJ` (unary) — complex conjugate.
    pub const CONJ: i32 = 9;
    /// `CUTENSOR_OP_RCP` (unary) — reciprocal.
    pub const RCP: i32 = 10;
    /// `CUTENSOR_OP_SIGMOID` (unary) — sigmoid.
    pub const SIGMOID: i32 = 11;
    /// `CUTENSOR_OP_TANH` (unary) — hyperbolic tangent.
    pub const TANH: i32 = 12;
}
226
/// `cutensorAlgo_t` — algorithm selector, passed as `algo` to
/// `cutensorCreatePlanPreference`.
#[allow(non_snake_case)]
pub mod cutensorAlgo {
    /// `CUTENSOR_ALGO_DEFAULT` — let the library choose.
    pub const DEFAULT: i32 = -1;
    /// `CUTENSOR_ALGO_TTGT` — transpose-transpose-GEMM-transpose.
    pub const TTGT: i32 = -2;
    /// `CUTENSOR_ALGO_TGETT` — transposed GETT.
    pub const TGETT: i32 = -3;
    /// `CUTENSOR_ALGO_GETT` — GETT (general tensor-tensor).
    pub const GETT: i32 = -4;
}
239
/// `cutensorJitMode_t` — just-in-time compilation selector (cuTENSOR 2.x),
/// passed as `jit_mode` to `cutensorCreatePlanPreference`.
#[allow(non_snake_case)]
pub mod cutensorJitMode {
    /// `CUTENSOR_JIT_MODE_NONE` — never JIT-compile kernels.
    pub const NONE: i32 = 0;
    /// `CUTENSOR_JIT_MODE_DEFAULT` — use the library's default JIT behavior.
    pub const DEFAULT: i32 = 1;
}
248
/// `cutensorWorksizePreference_t` — how much workspace to plan for, passed
/// as `workspace_pref` to `cutensorEstimateWorkspaceSize`.
#[allow(non_snake_case)]
pub mod cutensorWorksizePreference {
    /// `CUTENSOR_WORKSPACE_MIN` — smallest workspace the library accepts.
    pub const MIN: i32 = 1;
    /// `CUTENSOR_WORKSPACE_DEFAULT` — the library's default amount.
    pub const DEFAULT: i32 = 2;
    /// `CUTENSOR_WORKSPACE_MAX` — the largest amount the library may use.
    pub const MAX: i32 = 3;
}
259
260// ---- PFN types ----
261
262/// Function-pointer type for `cutensorCreate` (create cuTENSOR library handle). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
263pub type PFN_cutensorCreate =
264    unsafe extern "C" fn(handle_out: *mut cutensorHandle_t) -> cutensorStatus_t;
265/// Function-pointer type for `cutensorDestroy` (destroy cuTENSOR library handle). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
266pub type PFN_cutensorDestroy = unsafe extern "C" fn(handle: cutensorHandle_t) -> cutensorStatus_t;
267
268/// Function-pointer type for `cutensorCreateTensorDescriptor` (create a tensor descriptor). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
269pub type PFN_cutensorCreateTensorDescriptor = unsafe extern "C" fn(
270    handle: cutensorHandle_t,
271    desc_out: *mut cutensorTensorDescriptor_t,
272    num_modes: u32,
273    extents: *const i64,
274    strides: *const i64,
275    data_type: i32,
276    alignment_bytes: u32,
277) -> cutensorStatus_t;
278/// Function-pointer type for `cutensorDestroyTensorDescriptor` (destroy a tensor descriptor). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
279pub type PFN_cutensorDestroyTensorDescriptor =
280    unsafe extern "C" fn(desc: cutensorTensorDescriptor_t) -> cutensorStatus_t;
281
282/// Function-pointer type for `cutensorCreateContraction` (build an operation descriptor for a tensor contraction). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
283pub type PFN_cutensorCreateContraction = unsafe extern "C" fn(
284    handle: cutensorHandle_t,
285    op_desc_out: *mut cutensorOperationDescriptor_t,
286    desc_a: cutensorTensorDescriptor_t,
287    modes_a: *const i32,
288    op_a: i32,
289    desc_b: cutensorTensorDescriptor_t,
290    modes_b: *const i32,
291    op_b: i32,
292    desc_c: cutensorTensorDescriptor_t,
293    modes_c: *const i32,
294    op_c: i32,
295    desc_d: cutensorTensorDescriptor_t,
296    modes_d: *const i32,
297    compute_desc: cutensorComputeDescriptor_t,
298) -> cutensorStatus_t;
299
300/// Function-pointer type for `cutensorDestroyOperationDescriptor` (destroy an operation descriptor). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
301pub type PFN_cutensorDestroyOperationDescriptor =
302    unsafe extern "C" fn(desc: cutensorOperationDescriptor_t) -> cutensorStatus_t;
303
304/// Function-pointer type for `cutensorCreatePlanPreference` (create a plan-preference object). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
305pub type PFN_cutensorCreatePlanPreference = unsafe extern "C" fn(
306    handle: cutensorHandle_t,
307    pref_out: *mut cutensorPlanPreference_t,
308    algo: i32,
309    jit_mode: i32,
310) -> cutensorStatus_t;
311/// Function-pointer type for `cutensorDestroyPlanPreference` (destroy a plan-preference object). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
312pub type PFN_cutensorDestroyPlanPreference =
313    unsafe extern "C" fn(pref: cutensorPlanPreference_t) -> cutensorStatus_t;
314
315/// Function-pointer type for `cutensorEstimateWorkspaceSize` (estimate workspace bytes required by a plan). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
316pub type PFN_cutensorEstimateWorkspaceSize = unsafe extern "C" fn(
317    handle: cutensorHandle_t,
318    op_desc: cutensorOperationDescriptor_t,
319    pref: cutensorPlanPreference_t,
320    workspace_pref: i32,
321    workspace_size_bytes_out: *mut u64,
322) -> cutensorStatus_t;
323
324/// Function-pointer type for `cutensorCreatePlan` (build an execution plan from an operation descriptor). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
325pub type PFN_cutensorCreatePlan = unsafe extern "C" fn(
326    handle: cutensorHandle_t,
327    plan_out: *mut cutensorPlan_t,
328    op_desc: cutensorOperationDescriptor_t,
329    pref: cutensorPlanPreference_t,
330    workspace_size_limit: u64,
331) -> cutensorStatus_t;
332/// Function-pointer type for `cutensorDestroyPlan` (destroy an execution plan). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
333pub type PFN_cutensorDestroyPlan = unsafe extern "C" fn(plan: cutensorPlan_t) -> cutensorStatus_t;
334
335/// Function-pointer type for `cutensorContract` (execute tensor contraction). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
336pub type PFN_cutensorContract = unsafe extern "C" fn(
337    handle: cutensorHandle_t,
338    plan: cutensorPlan_t,
339    alpha: *const core::ffi::c_void,
340    a: *const core::ffi::c_void,
341    b: *const core::ffi::c_void,
342    beta: *const core::ffi::c_void,
343    c: *const core::ffi::c_void,
344    d: *mut core::ffi::c_void,
345    workspace: *mut core::ffi::c_void,
346    workspace_size_bytes: u64,
347    stream: *mut core::ffi::c_void, // cudaStream_t
348) -> cutensorStatus_t;
349
350/// Function-pointer type for `cutensorGetVersion` (query cuTENSOR library version). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
351pub type PFN_cutensorGetVersion = unsafe extern "C" fn() -> usize;
352/// Function-pointer type for `cutensorGetCudartVersion` (query the CUDA Runtime version cuTENSOR was built against). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
353pub type PFN_cutensorGetCudartVersion = unsafe extern "C" fn() -> usize;
354/// Function-pointer type for `cutensorGetErrorString` (decode a cutensorStatus_t into a static C string). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
355pub type PFN_cutensorGetErrorString =
356    unsafe extern "C" fn(status: cutensorStatus_t) -> *const core::ffi::c_char;
357
// ---- Element-wise operations (binary / trinary) ----
359
360/// Function-pointer type for `cutensorCreateElementwiseBinary` (build an operation descriptor for an element-wise binary op). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
361pub type PFN_cutensorCreateElementwiseBinary = unsafe extern "C" fn(
362    handle: cutensorHandle_t,
363    op_desc_out: *mut cutensorOperationDescriptor_t,
364    desc_a: cutensorTensorDescriptor_t,
365    modes_a: *const i32,
366    op_a: i32,
367    desc_c: cutensorTensorDescriptor_t,
368    modes_c: *const i32,
369    op_c: i32,
370    desc_d: cutensorTensorDescriptor_t,
371    modes_d: *const i32,
372    op_ac: i32, // op_ac is the binary operator between A and C
373    compute_desc: cutensorComputeDescriptor_t,
374) -> cutensorStatus_t;
375
376/// Function-pointer type for `cutensorElementwiseBinaryExecute` (execute an element-wise binary plan). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
377pub type PFN_cutensorElementwiseBinaryExecute = unsafe extern "C" fn(
378    handle: cutensorHandle_t,
379    plan: cutensorPlan_t,
380    alpha: *const core::ffi::c_void,
381    a: *const core::ffi::c_void,
382    gamma: *const core::ffi::c_void,
383    c: *const core::ffi::c_void,
384    d: *mut core::ffi::c_void,
385    stream: *mut core::ffi::c_void,
386) -> cutensorStatus_t;
387
388/// Function-pointer type for `cutensorCreateElementwiseTrinary` (build an operation descriptor for an element-wise trinary op). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
389pub type PFN_cutensorCreateElementwiseTrinary = unsafe extern "C" fn(
390    handle: cutensorHandle_t,
391    op_desc_out: *mut cutensorOperationDescriptor_t,
392    desc_a: cutensorTensorDescriptor_t,
393    modes_a: *const i32,
394    op_a: i32,
395    desc_b: cutensorTensorDescriptor_t,
396    modes_b: *const i32,
397    op_b: i32,
398    desc_c: cutensorTensorDescriptor_t,
399    modes_c: *const i32,
400    op_c: i32,
401    desc_d: cutensorTensorDescriptor_t,
402    modes_d: *const i32,
403    op_ab: i32,
404    op_abc: i32,
405    compute_desc: cutensorComputeDescriptor_t,
406) -> cutensorStatus_t;
407
408/// Function-pointer type for `cutensorElementwiseTrinaryExecute` (execute an element-wise trinary plan). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
409pub type PFN_cutensorElementwiseTrinaryExecute = unsafe extern "C" fn(
410    handle: cutensorHandle_t,
411    plan: cutensorPlan_t,
412    alpha: *const core::ffi::c_void,
413    a: *const core::ffi::c_void,
414    beta: *const core::ffi::c_void,
415    b: *const core::ffi::c_void,
416    gamma: *const core::ffi::c_void,
417    c: *const core::ffi::c_void,
418    d: *mut core::ffi::c_void,
419    stream: *mut core::ffi::c_void,
420) -> cutensorStatus_t;
421
422/// Function-pointer type for `cutensorCreatePermutation` (build an operation descriptor for a tensor permutation). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
423pub type PFN_cutensorCreatePermutation = unsafe extern "C" fn(
424    handle: cutensorHandle_t,
425    op_desc_out: *mut cutensorOperationDescriptor_t,
426    desc_a: cutensorTensorDescriptor_t,
427    modes_a: *const i32,
428    op_a: i32,
429    desc_b: cutensorTensorDescriptor_t,
430    modes_b: *const i32,
431    compute_desc: cutensorComputeDescriptor_t,
432) -> cutensorStatus_t;
433
434/// Function-pointer type for `cutensorPermute` (execute tensor permutation). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
435pub type PFN_cutensorPermute = unsafe extern "C" fn(
436    handle: cutensorHandle_t,
437    plan: cutensorPlan_t,
438    alpha: *const core::ffi::c_void,
439    a: *const core::ffi::c_void,
440    b: *mut core::ffi::c_void,
441    stream: *mut core::ffi::c_void,
442) -> cutensorStatus_t;
443
444/// Function-pointer type for `cutensorCreateReduction` (build an operation descriptor for a tensor reduction). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
445pub type PFN_cutensorCreateReduction = unsafe extern "C" fn(
446    handle: cutensorHandle_t,
447    op_desc_out: *mut cutensorOperationDescriptor_t,
448    desc_a: cutensorTensorDescriptor_t,
449    modes_a: *const i32,
450    op_a: i32,
451    desc_c: cutensorTensorDescriptor_t,
452    modes_c: *const i32,
453    op_c: i32,
454    desc_d: cutensorTensorDescriptor_t,
455    modes_d: *const i32,
456    op_reduce: i32, // ADD / MUL / MAX / MIN
457    compute_desc: cutensorComputeDescriptor_t,
458) -> cutensorStatus_t;
459
460/// Function-pointer type for `cutensorReduce` (execute tensor reduction). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
461pub type PFN_cutensorReduce = unsafe extern "C" fn(
462    handle: cutensorHandle_t,
463    plan: cutensorPlan_t,
464    alpha: *const core::ffi::c_void,
465    a: *const core::ffi::c_void,
466    beta: *const core::ffi::c_void,
467    c: *const core::ffi::c_void,
468    d: *mut core::ffi::c_void,
469    workspace: *mut core::ffi::c_void,
470    workspace_size: u64,
471    stream: *mut core::ffi::c_void,
472) -> cutensorStatus_t;
473
474// ---- Attribute getters / setters for operation descriptors + plan preferences ----
475
476/// Function-pointer type for `cutensorOperationDescriptorGetAttribute` (get an attribute on an operation descriptor). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
477pub type PFN_cutensorOperationDescriptorGetAttribute = unsafe extern "C" fn(
478    handle: cutensorHandle_t,
479    op_desc: cutensorOperationDescriptor_t,
480    attr: i32,
481    buf: *mut core::ffi::c_void,
482    size_in_bytes: usize,
483) -> cutensorStatus_t;
484
485/// Function-pointer type for `cutensorOperationDescriptorSetAttribute` (set an attribute on an operation descriptor). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
486pub type PFN_cutensorOperationDescriptorSetAttribute = unsafe extern "C" fn(
487    handle: cutensorHandle_t,
488    op_desc: cutensorOperationDescriptor_t,
489    attr: i32,
490    buf: *const core::ffi::c_void,
491    size_in_bytes: usize,
492) -> cutensorStatus_t;
493
494/// Function-pointer type for `cutensorPlanPreferenceSetAttribute` (set an attribute on a plan-preference object). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
495pub type PFN_cutensorPlanPreferenceSetAttribute = unsafe extern "C" fn(
496    handle: cutensorHandle_t,
497    pref: cutensorPlanPreference_t,
498    attr: i32,
499    buf: *const core::ffi::c_void,
500    size_in_bytes: usize,
501) -> cutensorStatus_t;
502
503/// Function-pointer type for `cutensorPlanGetAttribute` (get an attribute on a plan). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
504pub type PFN_cutensorPlanGetAttribute = unsafe extern "C" fn(
505    handle: cutensorHandle_t,
506    plan: cutensorPlan_t,
507    attr: i32,
508    buf: *mut core::ffi::c_void,
509    size_in_bytes: usize,
510) -> cutensorStatus_t;
511
512/// Function-pointer type for `cutensorTensorDescriptorGetAttribute` (get an attribute on a tensor descriptor). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
513pub type PFN_cutensorTensorDescriptorGetAttribute = unsafe extern "C" fn(
514    handle: cutensorHandle_t,
515    desc: cutensorTensorDescriptor_t,
516    attr: i32,
517    buf: *mut core::ffi::c_void,
518    size_in_bytes: usize,
519) -> cutensorStatus_t;
520
521// ---- Plan-cache management ----
522
523/// Function-pointer type for `cutensorHandleResizePlanCache` (resize a handle's plan cache). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
524pub type PFN_cutensorHandleResizePlanCache =
525    unsafe extern "C" fn(handle: cutensorHandle_t, num_entries: u32) -> cutensorStatus_t;
526
527/// Function-pointer type for `cutensorHandleReadCacheFromFile` (read a plan/kernel cache from a file). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
528pub type PFN_cutensorHandleReadCacheFromFile = unsafe extern "C" fn(
529    handle: cutensorHandle_t,
530    filename: *const core::ffi::c_char,
531) -> cutensorStatus_t;
532
533/// Function-pointer type for `cutensorHandleWriteCacheToFile` (write a plan/kernel cache to a file). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
534pub type PFN_cutensorHandleWriteCacheToFile = unsafe extern "C" fn(
535    handle: cutensorHandle_t,
536    filename: *const core::ffi::c_char,
537) -> cutensorStatus_t;
538
539// ---- Trinary contraction (3-tensor chains) ----
540
541/// Function-pointer type for `cutensorCreateContractionTrinary` (build an operation descriptor for a three-tensor contraction). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
542pub type PFN_cutensorCreateContractionTrinary = unsafe extern "C" fn(
543    handle: cutensorHandle_t,
544    op_desc_out: *mut cutensorOperationDescriptor_t,
545    desc_a: cutensorTensorDescriptor_t,
546    modes_a: *const i32,
547    op_a: i32,
548    desc_b: cutensorTensorDescriptor_t,
549    modes_b: *const i32,
550    op_b: i32,
551    desc_c: cutensorTensorDescriptor_t,
552    modes_c: *const i32,
553    op_c: i32,
554    desc_d: cutensorTensorDescriptor_t,
555    modes_d: *const i32,
556    op_d: i32,
557    desc_e: cutensorTensorDescriptor_t,
558    modes_e: *const i32,
559    compute_desc: cutensorComputeDescriptor_t,
560) -> cutensorStatus_t;
561
562/// Function-pointer type for `cutensorContractTrinary` (execute three-tensor contraction). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
563pub type PFN_cutensorContractTrinary = unsafe extern "C" fn(
564    handle: cutensorHandle_t,
565    plan: cutensorPlan_t,
566    alpha: *const core::ffi::c_void,
567    a: *const core::ffi::c_void,
568    b: *const core::ffi::c_void,
569    c: *const core::ffi::c_void,
570    beta: *const core::ffi::c_void,
571    d: *const core::ffi::c_void,
572    e: *mut core::ffi::c_void,
573    workspace: *mut core::ffi::c_void,
574    workspace_size: u64,
575    stream: *mut core::ffi::c_void,
576) -> cutensorStatus_t;
577
578// ---- Custom compute descriptor lifecycle ----
579
580/// Function-pointer type for `cutensorCreateComputeDescriptor` (create a compute-precision descriptor). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
581pub type PFN_cutensorCreateComputeDescriptor = unsafe extern "C" fn(
582    handle: cutensorHandle_t,
583    desc_out: *mut cutensorComputeDescriptor_t,
584) -> cutensorStatus_t;
585
586/// Function-pointer type for `cutensorDestroyComputeDescriptor` (destroy a compute-precision descriptor). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
587pub type PFN_cutensorDestroyComputeDescriptor =
588    unsafe extern "C" fn(desc: cutensorComputeDescriptor_t) -> cutensorStatus_t;
589
590/// Function-pointer type for `cutensorComputeDescriptorGetAttribute` (get an attribute on a compute descriptor). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
591pub type PFN_cutensorComputeDescriptorGetAttribute = unsafe extern "C" fn(
592    handle: cutensorHandle_t,
593    desc: cutensorComputeDescriptor_t,
594    attr: i32,
595    buf: *mut core::ffi::c_void,
596    size_in_bytes: usize,
597) -> cutensorStatus_t;
598
599/// Function-pointer type for `cutensorComputeDescriptorSetAttribute` (set an attribute on a compute descriptor). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
600pub type PFN_cutensorComputeDescriptorSetAttribute = unsafe extern "C" fn(
601    handle: cutensorHandle_t,
602    desc: cutensorComputeDescriptor_t,
603    attr: i32,
604    buf: *const core::ffi::c_void,
605    size_in_bytes: usize,
606) -> cutensorStatus_t;
607
608// ---- Additional attribute getters / setters ----
609
610/// Function-pointer type for `cutensorTensorDescriptorSetAttribute` (set an attribute on a tensor descriptor). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
611pub type PFN_cutensorTensorDescriptorSetAttribute = unsafe extern "C" fn(
612    handle: cutensorHandle_t,
613    desc: cutensorTensorDescriptor_t,
614    attr: i32,
615    buf: *const core::ffi::c_void,
616    size_in_bytes: usize,
617) -> cutensorStatus_t;
618
619/// Function-pointer type for `cutensorPlanPreferenceGetAttribute` (get an attribute on a plan-preference object). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
620pub type PFN_cutensorPlanPreferenceGetAttribute = unsafe extern "C" fn(
621    handle: cutensorHandle_t,
622    pref: cutensorPlanPreference_t,
623    attr: i32,
624    buf: *mut core::ffi::c_void,
625    size_in_bytes: usize,
626) -> cutensorStatus_t;
627
628// ---- Operation-level introspection ----
629
630/// Function-pointer type for `cutensorOperationEstimateRuntime` (estimate runtime in milliseconds for a planned operation). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
631pub type PFN_cutensorOperationEstimateRuntime = unsafe extern "C" fn(
632    handle: cutensorHandle_t,
633    op_desc: cutensorOperationDescriptor_t,
634    pref: cutensorPlanPreference_t,
635    algo: i32,
636    runtime_ms_out: *mut f32,
637) -> cutensorStatus_t;
638
639/// Function-pointer type for `cutensorOperationNumAlgos` (query the number of algorithms available for an operation). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
640pub type PFN_cutensorOperationNumAlgos = unsafe extern "C" fn(
641    op_desc: cutensorOperationDescriptor_t,
642    num_algos_out: *mut i32,
643) -> cutensorStatus_t;
644
645// ---- Logging ----
646
647/// Function-pointer type for `cutensorLoggerSetLevel` (set logger verbosity level). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
648pub type PFN_cutensorLoggerSetLevel = unsafe extern "C" fn(level: i32) -> cutensorStatus_t;
649
650/// Function-pointer type for `cutensorLoggerSetMask` (set logger category mask). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
651pub type PFN_cutensorLoggerSetMask = unsafe extern "C" fn(mask: i32) -> cutensorStatus_t;
652
653/// Function-pointer type for `cutensorLoggerOpenFile` (open a logger output file by path). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
654pub type PFN_cutensorLoggerOpenFile =
655    unsafe extern "C" fn(path: *const core::ffi::c_char) -> cutensorStatus_t;
656
657/// Function-pointer type for `cutensorLoggerSetFile` (redirect logger output to an open FILE*). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
658pub type PFN_cutensorLoggerSetFile =
659    unsafe extern "C" fn(file: *mut core::ffi::c_void) -> cutensorStatus_t;
660
661/// Function-pointer type for `cutensorLoggerSetCallback` (register a logger callback). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
662pub type PFN_cutensorLoggerSetCallback = unsafe extern "C" fn(
663    callback: Option<unsafe extern "C" fn(i32, *const core::ffi::c_char, *const core::ffi::c_char)>,
664) -> cutensorStatus_t;
665
666/// Function-pointer type for `cutensorLoggerForceDisable` (force-disable the logger). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
667pub type PFN_cutensorLoggerForceDisable = unsafe extern "C" fn() -> cutensorStatus_t;
668
669// ---- Block-sparse contraction (cuTENSOR 2.x) ----
670
671/// Opaque block-sparse tensor descriptor.
672pub type cutensorBlockSparseTensorDescriptor_t = *mut core::ffi::c_void;
673
674/// Function-pointer type for `cutensorCreateBlockSparseTensorDescriptor` (create a block-sparse tensor descriptor). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
675pub type PFN_cutensorCreateBlockSparseTensorDescriptor = unsafe extern "C" fn(
676    handle: cutensorHandle_t,
677    desc_out: *mut cutensorBlockSparseTensorDescriptor_t,
678    num_modes: u32,
679    extents: *const i64,
680    block_size: *const i64,
681    strides: *const i64,
682    block_index_count: i64,
683    block_indices: *const i32,
684    data_type: i32,
685    alignment_bytes: u32,
686) -> cutensorStatus_t;
687
688/// Function-pointer type for `cutensorDestroyBlockSparseTensorDescriptor` (destroy a block-sparse tensor descriptor). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
689pub type PFN_cutensorDestroyBlockSparseTensorDescriptor =
690    unsafe extern "C" fn(desc: cutensorBlockSparseTensorDescriptor_t) -> cutensorStatus_t;
691
692/// Function-pointer type for `cutensorCreateBlockSparseContraction` (build an operation descriptor for a block-sparse contraction). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
693pub type PFN_cutensorCreateBlockSparseContraction = unsafe extern "C" fn(
694    handle: cutensorHandle_t,
695    op_desc_out: *mut cutensorOperationDescriptor_t,
696    desc_a: cutensorBlockSparseTensorDescriptor_t,
697    modes_a: *const i32,
698    op_a: i32,
699    desc_b: cutensorTensorDescriptor_t,
700    modes_b: *const i32,
701    op_b: i32,
702    desc_c: cutensorTensorDescriptor_t,
703    modes_c: *const i32,
704    op_c: i32,
705    desc_d: cutensorTensorDescriptor_t,
706    modes_d: *const i32,
707    compute_desc: cutensorComputeDescriptor_t,
708) -> cutensorStatus_t;
709
710/// Function-pointer type for `cutensorBlockSparseContract` (execute block-sparse tensor contraction). See <https://docs.nvidia.com/cuda/cutensor/index.html>.
711pub type PFN_cutensorBlockSparseContract = unsafe extern "C" fn(
712    handle: cutensorHandle_t,
713    plan: cutensorPlan_t,
714    alpha: *const core::ffi::c_void,
715    a: *const core::ffi::c_void,
716    b: *const core::ffi::c_void,
717    beta: *const core::ffi::c_void,
718    c: *const core::ffi::c_void,
719    d: *mut core::ffi::c_void,
720    workspace: *mut core::ffi::c_void,
721    workspace_size: u64,
722    stream: *mut core::ffi::c_void,
723) -> cutensorStatus_t;
724
// ---- Loader ----

/// Generate the `Cutensor` function-pointer table plus one lazy resolver method per entry.
///
/// Each entry has the form `fn method_name as "cSymbolName": PFN_Type;` and may be preceded
/// by attributes (including `///` doc comments), which are forwarded to the generated method.
macro_rules! cutensor_fns {
    ($($(#[$attr:meta])* fn $name:ident as $sym:literal : $pfn:ty;)*) => {
        /// Lazily-resolved cuTENSOR function-pointer table.
        ///
        /// Every accessor looks its C symbol up in `lib` on the first call and caches the
        /// resulting function pointer, so later calls skip the symbol lookup.
        pub struct Cutensor {
            /// The opened cuTENSOR shared library that every symbol is resolved from.
            pub lib: Library,
            $(
                $name: OnceLock<$pfn>,
            )*
        }

        impl core::fmt::Debug for Cutensor {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                // Cached function pointers are omitted; only the library is shown.
                f.debug_struct("Cutensor").field("lib", &self.lib).finish_non_exhaustive()
            }
        }

        impl Cutensor {
            /// Create a table over `lib` with every symbol still unresolved.
            fn empty(lib: Library) -> Self {
                Self { lib, $($name: OnceLock::new(),)* }
            }
            $(
                $(#[$attr])*
                #[doc = concat!("Resolve `", $sym, "`.")]
                #[doc = ""]
                #[doc = "The symbol is looked up on the first call and cached for later calls."]
                #[doc = ""]
                #[doc = "# Errors"]
                #[doc = ""]
                #[doc = "Returns an error if the symbol cannot be resolved from the loaded library."]
                pub fn $name(&self) -> Result<$pfn, LoaderError> {
                    if let Some(&p) = self.$name.get() {
                        return Ok(p);
                    }
                    // SAFETY: relies on `Library::raw_symbol`'s own contract (review its
                    // `# Safety` docs); the returned address is not dereferenced here, only
                    // reinterpreted below.
                    let raw: *mut () = unsafe { self.lib.raw_symbol($sym)? };
                    // SAFETY: `$pfn` is declared in this crate as the C signature of `$sym`,
                    // so the exported address may be used as that function pointer.
                    // `transmute` (unlike `transmute_copy`) rejects at compile time any
                    // `$pfn` whose size differs from `*mut ()`. Assumes `raw_symbol` never
                    // returns a null address on success, since a null fn pointer is UB;
                    // confirm.
                    let p: $pfn = unsafe { core::mem::transmute::<*mut (), $pfn>(raw) };
                    // A concurrent caller may have cached this entry first. Both resolved the
                    // same symbol from the same library, so losing the race is harmless.
                    let _ = self.$name.set(p);
                    Ok(p)
                }
            )*
        }
    };
}

763cutensor_fns! {
764    fn cutensor_create as "cutensorCreate": PFN_cutensorCreate;
765    fn cutensor_destroy as "cutensorDestroy": PFN_cutensorDestroy;
766    fn cutensor_create_tensor_descriptor as "cutensorCreateTensorDescriptor":
767        PFN_cutensorCreateTensorDescriptor;
768    fn cutensor_destroy_tensor_descriptor as "cutensorDestroyTensorDescriptor":
769        PFN_cutensorDestroyTensorDescriptor;
770    fn cutensor_create_contraction as "cutensorCreateContraction": PFN_cutensorCreateContraction;
771    fn cutensor_destroy_operation_descriptor as "cutensorDestroyOperationDescriptor":
772        PFN_cutensorDestroyOperationDescriptor;
773    fn cutensor_create_plan_preference as "cutensorCreatePlanPreference":
774        PFN_cutensorCreatePlanPreference;
775    fn cutensor_destroy_plan_preference as "cutensorDestroyPlanPreference":
776        PFN_cutensorDestroyPlanPreference;
777    fn cutensor_estimate_workspace_size as "cutensorEstimateWorkspaceSize":
778        PFN_cutensorEstimateWorkspaceSize;
779    fn cutensor_create_plan as "cutensorCreatePlan": PFN_cutensorCreatePlan;
780    fn cutensor_destroy_plan as "cutensorDestroyPlan": PFN_cutensorDestroyPlan;
781    fn cutensor_contract as "cutensorContract": PFN_cutensorContract;
782    fn cutensor_get_version as "cutensorGetVersion": PFN_cutensorGetVersion;
783    fn cutensor_get_cudart_version as "cutensorGetCudartVersion": PFN_cutensorGetCudartVersion;
784    fn cutensor_get_error_string as "cutensorGetErrorString": PFN_cutensorGetErrorString;
785
786    // Elementwise binary (A op C → D)
787    fn cutensor_create_elementwise_binary as "cutensorCreateElementwiseBinary":
788        PFN_cutensorCreateElementwiseBinary;
789    fn cutensor_elementwise_binary_execute as "cutensorElementwiseBinaryExecute":
790        PFN_cutensorElementwiseBinaryExecute;
791
792    // Elementwise trinary ((A op B) op C → D)
793    fn cutensor_create_elementwise_trinary as "cutensorCreateElementwiseTrinary":
794        PFN_cutensorCreateElementwiseTrinary;
795    fn cutensor_elementwise_trinary_execute as "cutensorElementwiseTrinaryExecute":
796        PFN_cutensorElementwiseTrinaryExecute;
797
798    // Permutation
799    fn cutensor_create_permutation as "cutensorCreatePermutation":
800        PFN_cutensorCreatePermutation;
801    fn cutensor_permute as "cutensorPermute": PFN_cutensorPermute;
802
803    // Reduction
804    fn cutensor_create_reduction as "cutensorCreateReduction": PFN_cutensorCreateReduction;
805    fn cutensor_reduce as "cutensorReduce": PFN_cutensorReduce;
806
807    // Attributes
808    fn cutensor_operation_descriptor_get_attribute as "cutensorOperationDescriptorGetAttribute":
809        PFN_cutensorOperationDescriptorGetAttribute;
810    fn cutensor_operation_descriptor_set_attribute as "cutensorOperationDescriptorSetAttribute":
811        PFN_cutensorOperationDescriptorSetAttribute;
812    fn cutensor_plan_preference_set_attribute as "cutensorPlanPreferenceSetAttribute":
813        PFN_cutensorPlanPreferenceSetAttribute;
814    fn cutensor_plan_get_attribute as "cutensorPlanGetAttribute":
815        PFN_cutensorPlanGetAttribute;
816    fn cutensor_tensor_descriptor_get_attribute as "cutensorTensorDescriptorGetAttribute":
817        PFN_cutensorTensorDescriptorGetAttribute;
818
819    // Plan cache
820    fn cutensor_handle_resize_plan_cache as "cutensorHandleResizePlanCache":
821        PFN_cutensorHandleResizePlanCache;
822    fn cutensor_handle_read_plan_cache_from_file as "cutensorHandleReadPlanCacheFromFile":
823        PFN_cutensorHandleReadCacheFromFile;
824    fn cutensor_handle_write_plan_cache_to_file as "cutensorHandleWritePlanCacheToFile":
825        PFN_cutensorHandleWriteCacheToFile;
826    fn cutensor_read_kernel_cache_from_file as "cutensorReadKernelCacheFromFile":
827        PFN_cutensorHandleReadCacheFromFile;
828    fn cutensor_write_kernel_cache_to_file as "cutensorWriteKernelCacheToFile":
829        PFN_cutensorHandleWriteCacheToFile;
830
831    // Trinary contraction
832    fn cutensor_create_contraction_trinary as "cutensorCreateContractionTrinary":
833        PFN_cutensorCreateContractionTrinary;
834    fn cutensor_contract_trinary as "cutensorContractTrinary": PFN_cutensorContractTrinary;
835
836    // Custom compute descriptors
837    fn cutensor_create_compute_descriptor as "cutensorCreateComputeDescriptor":
838        PFN_cutensorCreateComputeDescriptor;
839    fn cutensor_destroy_compute_descriptor as "cutensorDestroyComputeDescriptor":
840        PFN_cutensorDestroyComputeDescriptor;
841    fn cutensor_compute_descriptor_get_attribute as "cutensorComputeDescriptorGetAttribute":
842        PFN_cutensorComputeDescriptorGetAttribute;
843    fn cutensor_compute_descriptor_set_attribute as "cutensorComputeDescriptorSetAttribute":
844        PFN_cutensorComputeDescriptorSetAttribute;
845
846    // Additional attributes
847    fn cutensor_tensor_descriptor_set_attribute as "cutensorTensorDescriptorSetAttribute":
848        PFN_cutensorTensorDescriptorSetAttribute;
849    fn cutensor_plan_preference_get_attribute as "cutensorPlanPreferenceGetAttribute":
850        PFN_cutensorPlanPreferenceGetAttribute;
851
852    // Introspection
853    fn cutensor_operation_estimate_runtime as "cutensorOperationEstimateRuntime":
854        PFN_cutensorOperationEstimateRuntime;
855    fn cutensor_operation_num_algos as "cutensorOperationNumAlgos":
856        PFN_cutensorOperationNumAlgos;
857
858    // Logging
859    fn cutensor_logger_set_level as "cutensorLoggerSetLevel": PFN_cutensorLoggerSetLevel;
860    fn cutensor_logger_set_mask as "cutensorLoggerSetMask": PFN_cutensorLoggerSetMask;
861    fn cutensor_logger_open_file as "cutensorLoggerOpenFile": PFN_cutensorLoggerOpenFile;
862    fn cutensor_logger_set_file as "cutensorLoggerSetFile": PFN_cutensorLoggerSetFile;
863    fn cutensor_logger_set_callback as "cutensorLoggerSetCallback":
864        PFN_cutensorLoggerSetCallback;
865    fn cutensor_logger_force_disable as "cutensorLoggerForceDisable":
866        PFN_cutensorLoggerForceDisable;
867
868    // Block-sparse contraction
869    fn cutensor_create_block_sparse_tensor_descriptor as "cutensorCreateBlockSparseTensorDescriptor":
870        PFN_cutensorCreateBlockSparseTensorDescriptor;
871    fn cutensor_destroy_block_sparse_tensor_descriptor
872        as "cutensorDestroyBlockSparseTensorDescriptor":
873        PFN_cutensorDestroyBlockSparseTensorDescriptor;
874    fn cutensor_create_block_sparse_contraction as "cutensorCreateBlockSparseContraction":
875        PFN_cutensorCreateBlockSparseContraction;
876    fn cutensor_block_sparse_contract as "cutensorBlockSparseContract":
877        PFN_cutensorBlockSparseContract;
878}
879
/// Shared-library file names to try, in priority order, when opening cuTENSOR by name.
///
/// Linux tries the versioned names before the unversioned `libcutensor.so`; Windows tries
/// `cutensor.dll`; every other target gets an empty list.
fn cutensor_candidates() -> &'static [&'static str] {
    // NOTE(review): the function-pointer types in this crate follow the cuTENSOR 2.x API;
    // confirm that also accepting a 1.x `libcutensor.so.1` is intended.
    const LINUX: &[&str] = &["libcutensor.so.2", "libcutensor.so.1", "libcutensor.so"];
    const WINDOWS: &[&str] = &["cutensor.dll"];

    if cfg!(target_os = "linux") {
        LINUX
    } else if cfg!(target_os = "windows") {
        WINDOWS
    } else {
        &[]
    }
}

/// Extra directories to search for cuTENSOR on Windows — NVIDIA's
/// installer places it in a non-CUDA-Toolkit location.
///
/// The returned list covers, in order:
/// 1. for each stand-alone root (`%ProgramFiles%\NVIDIA cuTENSOR` and
///    `%ProgramFiles%\NVIDIA\cuTENSOR`, using `C:\Program Files` when `ProgramFiles` cannot
///    be read): every version subdirectory's `bin` plus its `bin\<cuda-major>` and
///    `lib\<cuda-major>` directories, then the root's own `bin`;
/// 2. `bin` under `CUDA_PATH` and `CUDA_HOME`, when those variables are set.
///
/// Directories are listed whether or not they exist; the caller probes them in order.
/// Version subdirectories are visited newest-first (via `sort_newest_first`) because
/// `read_dir` yields entries in a platform- and filesystem-dependent order. Without the sort,
/// the release picked when several are installed would be arbitrary.
#[cfg(target_os = "windows")]
fn cutensor_extra_dirs() -> Vec<std::path::PathBuf> {
    use std::path::PathBuf;

    // Per-CUDA-major subdirectories probed under each `<root>\<ver>`, in preference order.
    const CUDA_MAJOR_SUBDIRS: [&str; 6] =
        ["bin\\12", "bin\\13", "bin\\11", "lib\\12", "lib\\13", "lib\\11"];

    let mut out = Vec::new();

    let progfiles = std::env::var("ProgramFiles").unwrap_or_else(|_| "C:\\Program Files".into());

    // Stand-alone cuTENSOR installs.
    let stand_alone_roots = [
        format!("{progfiles}\\NVIDIA cuTENSOR"),
        format!("{progfiles}\\NVIDIA\\cuTENSOR"),
    ];
    for root in &stand_alone_roots {
        // Typical layouts:
        //   <root>\<ver>\bin\<cuda-major>\cutensor.dll
        //   <root>\<ver>\lib\<cuda-major>\cutensor.dll
        //   <root>\bin\cutensor.dll
        let root_pb = PathBuf::from(root);
        // A missing or unreadable root simply contributes no version directories.
        let mut version_dirs: Vec<PathBuf> = std::fs::read_dir(&root_pb)
            .map(|entries| {
                entries
                    .flatten()
                    .map(|ent| ent.path())
                    .filter(|p| p.is_dir())
                    .collect()
            })
            .unwrap_or_default();
        sort_newest_first(&mut version_dirs);
        for ver in &version_dirs {
            out.push(ver.join("bin"));
            out.extend(CUDA_MAJOR_SUBDIRS.iter().map(|sub| ver.join(sub)));
        }
        out.push(root_pb.join("bin"));
    }

    // Also fall back to the CUDA Toolkit's own bin dir (some installers
    // drop a stub there).
    for var in ["CUDA_PATH", "CUDA_HOME"] {
        if let Ok(p) = std::env::var(var) {
            out.push(PathBuf::from(p).join("bin"));
        }
    }

    out
}

/// Order directories so that higher version numbers come first.
///
/// Each directory's final path component is compared by the sequence of decimal numbers it
/// contains, so `v2.10` ranks above `v2.9`. The full name breaks ties. Both comparisons are
/// descending, so names without any digits sort after versioned ones.
#[cfg_attr(not(target_os = "windows"), allow(dead_code))]
fn sort_newest_first(dirs: &mut [std::path::PathBuf]) {
    // Sort key: (numeric components of the final path component, that component as text).
    fn version_key(dir: &std::path::Path) -> (Vec<u64>, String) {
        let name = dir
            .file_name()
            .map(|n| n.to_string_lossy().into_owned())
            .unwrap_or_default();
        let numbers = name
            .split(|c: char| !c.is_ascii_digit())
            .filter(|run| !run.is_empty())
            // Absurdly long digit runs saturate instead of failing the sort.
            .map(|run| run.parse().unwrap_or(u64::MAX))
            .collect();
        (numbers, name)
    }
    dirs.sort_by_cached_key(|dir| std::cmp::Reverse(version_key(dir)));
}

942/// Return the lazily-loaded cuTENSOR library accessor.
943pub fn cutensor() -> Result<&'static Cutensor, LoaderError> {
944    static CUTENSOR: OnceLock<Cutensor> = OnceLock::new();
945    if let Some(c) = CUTENSOR.get() {
946        return Ok(c);
947    }
948    let lib = match Library::open("cutensor", cutensor_candidates()) {
949        Ok(l) => l,
950        Err(e) => {
951            #[cfg(target_os = "windows")]
952            {
953                let mut found: Option<Library> = None;
954                for dir in cutensor_extra_dirs() {
955                    for candidate in cutensor_candidates() {
956                        let full = dir.join(candidate);
957                        if let Ok(l) = Library::open_at("cutensor", &full) {
958                            found = Some(l);
959                            break;
960                        }
961                    }
962                    if found.is_some() {
963                        break;
964                    }
965                }
966                match found {
967                    Some(l) => l,
968                    None => return Err(e),
969                }
970            }
971            #[cfg(not(target_os = "windows"))]
972            {
973                return Err(e);
974            }
975        }
976    };
977    let _ = CUTENSOR.set(Cutensor::empty(lib));
978    Ok(CUTENSOR.get().expect("OnceLock set or lost race"))
979}