1#![allow(non_camel_case_types, non_snake_case, non_upper_case_globals)]
13#![warn(missing_debug_implementations)]
14
15use std::sync::OnceLock;
16
17use baracuda_core::{Library, LoaderError};
18use baracuda_types::CudaStatus;
19
/// Status code returned by the cuTENSOR entry points declared in this file.
///
/// The wrapper is `#[repr(transparent)]` over the raw `i32`, so it can be used
/// directly as the return type of the `extern "C"` function-pointer aliases.
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct cutensorStatus_t(pub i32);

impl cutensorStatus_t {
    // Named constants for every raw status value this binding knows about,
    // listed in ascending numeric order.

    /// Raw value `0`; the only value for which [`Self::is_success`] is `true`.
    pub const SUCCESS: Self = cutensorStatus_t(0);
    /// Raw value `1`.
    pub const NOT_INITIALIZED: Self = cutensorStatus_t(1);
    /// Raw value `3`.
    pub const ALLOC_FAILED: Self = cutensorStatus_t(3);
    /// Raw value `7`.
    pub const INVALID_VALUE: Self = cutensorStatus_t(7);
    /// Raw value `8`.
    pub const ARCH_MISMATCH: Self = cutensorStatus_t(8);
    /// Raw value `11`.
    pub const MAPPING_ERROR: Self = cutensorStatus_t(11);
    /// Raw value `13`.
    pub const EXECUTION_FAILED: Self = cutensorStatus_t(13);
    /// Raw value `14`.
    pub const INTERNAL_ERROR: Self = cutensorStatus_t(14);
    /// Raw value `15`.
    pub const NOT_SUPPORTED: Self = cutensorStatus_t(15);
    /// Raw value `16`.
    pub const LICENSE_ERROR: Self = cutensorStatus_t(16);
    /// Raw value `17`.
    pub const CUBLAS_ERROR: Self = cutensorStatus_t(17);
    /// Raw value `18`.
    pub const CUDA_ERROR: Self = cutensorStatus_t(18);
    /// Raw value `19`.
    pub const INSUFFICIENT_WORKSPACE: Self = cutensorStatus_t(19);
    /// Raw value `20`.
    pub const INSUFFICIENT_DRIVER: Self = cutensorStatus_t(20);
    /// Raw value `21`.
    pub const IO_ERROR: Self = cutensorStatus_t(21);

    /// Returns `true` exactly when the raw code equals that of [`Self::SUCCESS`].
    pub const fn is_success(self) -> bool {
        self.0 == Self::SUCCESS.0
    }
}
62
63impl CudaStatus for cutensorStatus_t {
64 fn code(self) -> i32 {
65 self.0
66 }
67 fn name(self) -> &'static str {
68 match self.0 {
69 0 => "CUTENSOR_STATUS_SUCCESS",
70 1 => "CUTENSOR_STATUS_NOT_INITIALIZED",
71 3 => "CUTENSOR_STATUS_ALLOC_FAILED",
72 7 => "CUTENSOR_STATUS_INVALID_VALUE",
73 13 => "CUTENSOR_STATUS_EXECUTION_FAILED",
74 15 => "CUTENSOR_STATUS_NOT_SUPPORTED",
75 19 => "CUTENSOR_STATUS_INSUFFICIENT_WORKSPACE",
76 _ => "CUTENSOR_STATUS_UNRECOGNIZED",
77 }
78 }
79 fn description(self) -> &'static str {
80 match self.0 {
81 0 => "success",
82 15 => "operation not supported",
83 19 => "workspace buffer too small",
84 _ => "unrecognized cuTENSOR status code",
85 }
86 }
87 fn is_success(self) -> bool {
88 cutensorStatus_t::is_success(self)
89 }
90 fn library(self) -> &'static str {
91 "cutensor"
92 }
93}
94
/// Library handle, carried as an untyped raw pointer.
pub type cutensorHandle_t = *mut core::ffi::c_void;

/// Tensor-descriptor handle, carried as an untyped raw pointer.
pub type cutensorTensorDescriptor_t = *mut core::ffi::c_void;

/// Operation-descriptor handle, carried as an untyped raw pointer.
pub type cutensorOperationDescriptor_t = *mut core::ffi::c_void;

/// Plan-preference handle, carried as an untyped raw pointer.
pub type cutensorPlanPreference_t = *mut core::ffi::c_void;

/// Plan handle, carried as an untyped raw pointer.
pub type cutensorPlan_t = *mut core::ffi::c_void;
111
/// Integer codes for tensor element types.
///
/// These are the values passed as the `data_type: i32` argument of the
/// descriptor-creation signatures in this file. Listed in ascending numeric order.
#[allow(non_snake_case)]
pub mod cutensorDataType {
    pub const R_32F: i32 = 0;
    pub const R_64F: i32 = 1;
    pub const R_16F: i32 = 2;
    pub const R_8I: i32 = 3;
    pub const C_32F: i32 = 4;
    pub const C_64F: i32 = 5;
    pub const R_8U: i32 = 8;
    pub const R_32I: i32 = 10;
    pub const R_32U: i32 = 12;
    pub const R_16BF: i32 = 14;
}
136
/// Compute-descriptor handle, carried as an untyped `*const` raw pointer.
///
/// Unlike the other handle aliases in this file this one is `*const`; the
/// predefined descriptors are read out of exported library symbols by the
/// `Cutensor::compute_desc_*` accessors.
pub type cutensorComputeDescriptor_t = *const core::ffi::c_void;
145
146impl Cutensor {
147 fn compute_desc_by_name(
152 &self,
153 name: &'static str,
154 ) -> Result<cutensorComputeDescriptor_t, LoaderError> {
155 let raw: *mut () = unsafe { self.lib.raw_symbol(name)? };
158 let ptr_ptr = raw as *const cutensorComputeDescriptor_t;
159 Ok(unsafe { *ptr_ptr })
160 }
161
162 pub fn compute_desc_32f(&self) -> Result<cutensorComputeDescriptor_t, LoaderError> {
164 self.compute_desc_by_name("CUTENSOR_COMPUTE_DESC_32F")
165 }
166 pub fn compute_desc_64f(&self) -> Result<cutensorComputeDescriptor_t, LoaderError> {
168 self.compute_desc_by_name("CUTENSOR_COMPUTE_DESC_64F")
169 }
170 pub fn compute_desc_16f(&self) -> Result<cutensorComputeDescriptor_t, LoaderError> {
172 self.compute_desc_by_name("CUTENSOR_COMPUTE_DESC_16F")
173 }
174 pub fn compute_desc_16bf(&self) -> Result<cutensorComputeDescriptor_t, LoaderError> {
176 self.compute_desc_by_name("CUTENSOR_COMPUTE_DESC_16BF")
177 }
178 pub fn compute_desc_tf32(&self) -> Result<cutensorComputeDescriptor_t, LoaderError> {
180 self.compute_desc_by_name("CUTENSOR_COMPUTE_DESC_TF32")
181 }
182 pub fn compute_desc_3xtf32(&self) -> Result<cutensorComputeDescriptor_t, LoaderError> {
184 self.compute_desc_by_name("CUTENSOR_COMPUTE_DESC_3XTF32")
185 }
186 pub fn compute_desc_4x16f(&self) -> Result<cutensorComputeDescriptor_t, LoaderError> {
188 self.compute_desc_by_name("CUTENSOR_COMPUTE_DESC_4X16F")
189 }
190 pub fn compute_desc_8xint8(&self) -> Result<cutensorComputeDescriptor_t, LoaderError> {
192 self.compute_desc_by_name("CUTENSOR_COMPUTE_DESC_8XINT8")
193 }
194 pub fn compute_desc_9x16bf(&self) -> Result<cutensorComputeDescriptor_t, LoaderError> {
196 self.compute_desc_by_name("CUTENSOR_COMPUTE_DESC_9X16BF")
197 }
198}
199
/// Integer codes for element-wise operators.
///
/// These are the values passed as the `op_*: i32` arguments of the
/// operation-creation signatures in this file. Listed in ascending numeric order.
#[allow(non_snake_case)]
pub mod cutensorOperator {
    pub const IDENTITY: i32 = 1;
    pub const SQRT: i32 = 2;
    pub const ADD: i32 = 3;
    pub const MUL: i32 = 5;
    pub const MAX: i32 = 6;
    pub const MIN: i32 = 7;
    pub const RELU: i32 = 8;
    pub const CONJ: i32 = 9;
    pub const RCP: i32 = 10;
    pub const SIGMOID: i32 = 11;
    pub const TANH: i32 = 12;
}
226
/// Integer codes for algorithm selection.
///
/// These are the values passed as the `algo: i32` argument of the
/// plan-preference signatures in this file. Listed in ascending numeric order.
#[allow(non_snake_case)]
pub mod cutensorAlgo {
    pub const GETT: i32 = -4;
    pub const TGETT: i32 = -3;
    pub const TTGT: i32 = -2;
    pub const DEFAULT: i32 = -1;
}
239
/// Integer codes for the `jit_mode: i32` argument of
/// `PFN_cutensorCreatePlanPreference`.
#[allow(non_snake_case)]
pub mod cutensorJitMode {
    /// Raw value `0`.
    pub const NONE: i32 = 0;
    /// Raw value `1`.
    pub const DEFAULT: i32 = 1;
}
248
/// Integer codes for the `workspace_pref: i32` argument of
/// `PFN_cutensorEstimateWorkspaceSize`.
#[allow(non_snake_case)]
pub mod cutensorWorksizePreference {
    /// Raw value `1`.
    pub const MIN: i32 = 1;
    /// Raw value `2`.
    pub const DEFAULT: i32 = 2;
    /// Raw value `3`.
    pub const MAX: i32 = 3;
}
259
// ---- Core handle / descriptor / contraction signatures ---------------------
// Each alias below is the function-pointer type that the `cutensor_fns!`
// table associates with the same-named library symbol.

/// Function-pointer type for the `cutensorCreate` symbol.
pub type PFN_cutensorCreate =
    unsafe extern "C" fn(handle_out: *mut cutensorHandle_t) -> cutensorStatus_t;
/// Function-pointer type for the `cutensorDestroy` symbol.
pub type PFN_cutensorDestroy = unsafe extern "C" fn(handle: cutensorHandle_t) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorCreateTensorDescriptor` symbol.
///
/// `data_type` is a plain `i32`; the `cutensorDataType` module holds the
/// named values.
pub type PFN_cutensorCreateTensorDescriptor = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    desc_out: *mut cutensorTensorDescriptor_t,
    num_modes: u32,
    extents: *const i64,
    strides: *const i64,
    data_type: i32,
    alignment_bytes: u32,
) -> cutensorStatus_t;
/// Function-pointer type for the `cutensorDestroyTensorDescriptor` symbol.
pub type PFN_cutensorDestroyTensorDescriptor =
    unsafe extern "C" fn(desc: cutensorTensorDescriptor_t) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorCreateContraction` symbol.
pub type PFN_cutensorCreateContraction = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    op_desc_out: *mut cutensorOperationDescriptor_t,
    desc_a: cutensorTensorDescriptor_t,
    modes_a: *const i32,
    op_a: i32,
    desc_b: cutensorTensorDescriptor_t,
    modes_b: *const i32,
    op_b: i32,
    desc_c: cutensorTensorDescriptor_t,
    modes_c: *const i32,
    op_c: i32,
    desc_d: cutensorTensorDescriptor_t,
    modes_d: *const i32,
    compute_desc: cutensorComputeDescriptor_t,
) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorDestroyOperationDescriptor` symbol.
pub type PFN_cutensorDestroyOperationDescriptor =
    unsafe extern "C" fn(desc: cutensorOperationDescriptor_t) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorCreatePlanPreference` symbol.
///
/// `algo` and `jit_mode` are plain `i32`s; see the `cutensorAlgo` and
/// `cutensorJitMode` modules for named values.
pub type PFN_cutensorCreatePlanPreference = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    pref_out: *mut cutensorPlanPreference_t,
    algo: i32,
    jit_mode: i32,
) -> cutensorStatus_t;
/// Function-pointer type for the `cutensorDestroyPlanPreference` symbol.
pub type PFN_cutensorDestroyPlanPreference =
    unsafe extern "C" fn(pref: cutensorPlanPreference_t) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorEstimateWorkspaceSize` symbol.
///
/// `workspace_pref` is a plain `i32`; see `cutensorWorksizePreference`.
pub type PFN_cutensorEstimateWorkspaceSize = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    op_desc: cutensorOperationDescriptor_t,
    pref: cutensorPlanPreference_t,
    workspace_pref: i32,
    workspace_size_bytes_out: *mut u64,
) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorCreatePlan` symbol.
pub type PFN_cutensorCreatePlan = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    plan_out: *mut cutensorPlan_t,
    op_desc: cutensorOperationDescriptor_t,
    pref: cutensorPlanPreference_t,
    workspace_size_limit: u64,
) -> cutensorStatus_t;
/// Function-pointer type for the `cutensorDestroyPlan` symbol.
pub type PFN_cutensorDestroyPlan = unsafe extern "C" fn(plan: cutensorPlan_t) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorContract` symbol.
///
/// All data and scalar arguments are untyped pointers; `stream` is likewise
/// an untyped pointer in this binding.
pub type PFN_cutensorContract = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    plan: cutensorPlan_t,
    alpha: *const core::ffi::c_void,
    a: *const core::ffi::c_void,
    b: *const core::ffi::c_void,
    beta: *const core::ffi::c_void,
    c: *const core::ffi::c_void,
    d: *mut core::ffi::c_void,
    workspace: *mut core::ffi::c_void,
    workspace_size_bytes: u64,
    stream: *mut core::ffi::c_void,
) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorGetVersion` symbol.
pub type PFN_cutensorGetVersion = unsafe extern "C" fn() -> usize;
/// Function-pointer type for the `cutensorGetCudartVersion` symbol.
pub type PFN_cutensorGetCudartVersion = unsafe extern "C" fn() -> usize;
/// Function-pointer type for the `cutensorGetErrorString` symbol; returns a
/// C string pointer.
pub type PFN_cutensorGetErrorString =
    unsafe extern "C" fn(status: cutensorStatus_t) -> *const core::ffi::c_char;
357
// ---- Element-wise, permutation and reduction signatures --------------------

/// Function-pointer type for the `cutensorCreateElementwiseBinary` symbol.
pub type PFN_cutensorCreateElementwiseBinary = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    op_desc_out: *mut cutensorOperationDescriptor_t,
    desc_a: cutensorTensorDescriptor_t,
    modes_a: *const i32,
    op_a: i32,
    desc_c: cutensorTensorDescriptor_t,
    modes_c: *const i32,
    op_c: i32,
    desc_d: cutensorTensorDescriptor_t,
    modes_d: *const i32,
    op_ac: i32,
    compute_desc: cutensorComputeDescriptor_t,
) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorElementwiseBinaryExecute` symbol.
pub type PFN_cutensorElementwiseBinaryExecute = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    plan: cutensorPlan_t,
    alpha: *const core::ffi::c_void,
    a: *const core::ffi::c_void,
    gamma: *const core::ffi::c_void,
    c: *const core::ffi::c_void,
    d: *mut core::ffi::c_void,
    stream: *mut core::ffi::c_void,
) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorCreateElementwiseTrinary` symbol.
pub type PFN_cutensorCreateElementwiseTrinary = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    op_desc_out: *mut cutensorOperationDescriptor_t,
    desc_a: cutensorTensorDescriptor_t,
    modes_a: *const i32,
    op_a: i32,
    desc_b: cutensorTensorDescriptor_t,
    modes_b: *const i32,
    op_b: i32,
    desc_c: cutensorTensorDescriptor_t,
    modes_c: *const i32,
    op_c: i32,
    desc_d: cutensorTensorDescriptor_t,
    modes_d: *const i32,
    op_ab: i32,
    op_abc: i32,
    compute_desc: cutensorComputeDescriptor_t,
) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorElementwiseTrinaryExecute` symbol.
pub type PFN_cutensorElementwiseTrinaryExecute = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    plan: cutensorPlan_t,
    alpha: *const core::ffi::c_void,
    a: *const core::ffi::c_void,
    beta: *const core::ffi::c_void,
    b: *const core::ffi::c_void,
    gamma: *const core::ffi::c_void,
    c: *const core::ffi::c_void,
    d: *mut core::ffi::c_void,
    stream: *mut core::ffi::c_void,
) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorCreatePermutation` symbol.
pub type PFN_cutensorCreatePermutation = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    op_desc_out: *mut cutensorOperationDescriptor_t,
    desc_a: cutensorTensorDescriptor_t,
    modes_a: *const i32,
    op_a: i32,
    desc_b: cutensorTensorDescriptor_t,
    modes_b: *const i32,
    compute_desc: cutensorComputeDescriptor_t,
) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorPermute` symbol.
pub type PFN_cutensorPermute = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    plan: cutensorPlan_t,
    alpha: *const core::ffi::c_void,
    a: *const core::ffi::c_void,
    b: *mut core::ffi::c_void,
    stream: *mut core::ffi::c_void,
) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorCreateReduction` symbol.
pub type PFN_cutensorCreateReduction = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    op_desc_out: *mut cutensorOperationDescriptor_t,
    desc_a: cutensorTensorDescriptor_t,
    modes_a: *const i32,
    op_a: i32,
    desc_c: cutensorTensorDescriptor_t,
    modes_c: *const i32,
    op_c: i32,
    desc_d: cutensorTensorDescriptor_t,
    modes_d: *const i32,
    op_reduce: i32,
    compute_desc: cutensorComputeDescriptor_t,
) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorReduce` symbol.
pub type PFN_cutensorReduce = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    plan: cutensorPlan_t,
    alpha: *const core::ffi::c_void,
    a: *const core::ffi::c_void,
    beta: *const core::ffi::c_void,
    c: *const core::ffi::c_void,
    d: *mut core::ffi::c_void,
    workspace: *mut core::ffi::c_void,
    workspace_size: u64,
    stream: *mut core::ffi::c_void,
) -> cutensorStatus_t;
473
// ---- Attribute getters/setters and cache persistence -----------------------
// The attribute signatures share one shape: an `i32` attribute selector plus
// an untyped buffer and its size in bytes (`*mut` for getters, `*const` for
// setters).

/// Function-pointer type for the `cutensorOperationDescriptorGetAttribute` symbol.
pub type PFN_cutensorOperationDescriptorGetAttribute = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    op_desc: cutensorOperationDescriptor_t,
    attr: i32,
    buf: *mut core::ffi::c_void,
    size_in_bytes: usize,
) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorOperationDescriptorSetAttribute` symbol.
pub type PFN_cutensorOperationDescriptorSetAttribute = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    op_desc: cutensorOperationDescriptor_t,
    attr: i32,
    buf: *const core::ffi::c_void,
    size_in_bytes: usize,
) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorPlanPreferenceSetAttribute` symbol.
pub type PFN_cutensorPlanPreferenceSetAttribute = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    pref: cutensorPlanPreference_t,
    attr: i32,
    buf: *const core::ffi::c_void,
    size_in_bytes: usize,
) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorPlanGetAttribute` symbol.
pub type PFN_cutensorPlanGetAttribute = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    plan: cutensorPlan_t,
    attr: i32,
    buf: *mut core::ffi::c_void,
    size_in_bytes: usize,
) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorTensorDescriptorGetAttribute` symbol.
pub type PFN_cutensorTensorDescriptorGetAttribute = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    desc: cutensorTensorDescriptor_t,
    attr: i32,
    buf: *mut core::ffi::c_void,
    size_in_bytes: usize,
) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorHandleResizePlanCache` symbol.
pub type PFN_cutensorHandleResizePlanCache =
    unsafe extern "C" fn(handle: cutensorHandle_t, num_entries: u32) -> cutensorStatus_t;

/// Function-pointer type shared by the `cutensorHandleReadPlanCacheFromFile`
/// and `cutensorReadKernelCacheFromFile` symbols; `filename` is a C string.
///
/// NOTE(review): the cuTENSOR header may declare an additional
/// `uint32_t* numCachelinesRead` out-parameter on the plan-cache reader. If
/// so, this two-argument signature is wrong for that symbol — confirm against
/// `cutensor.h` before calling it.
pub type PFN_cutensorHandleReadCacheFromFile = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    filename: *const core::ffi::c_char,
) -> cutensorStatus_t;

/// Function-pointer type shared by the `cutensorHandleWritePlanCacheToFile`
/// and `cutensorWriteKernelCacheToFile` symbols; `filename` is a C string.
pub type PFN_cutensorHandleWriteCacheToFile = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    filename: *const core::ffi::c_char,
) -> cutensorStatus_t;
538
// ---- Trinary contraction, compute descriptors, remaining attributes --------

/// Function-pointer type for the `cutensorCreateContractionTrinary` symbol.
pub type PFN_cutensorCreateContractionTrinary = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    op_desc_out: *mut cutensorOperationDescriptor_t,
    desc_a: cutensorTensorDescriptor_t,
    modes_a: *const i32,
    op_a: i32,
    desc_b: cutensorTensorDescriptor_t,
    modes_b: *const i32,
    op_b: i32,
    desc_c: cutensorTensorDescriptor_t,
    modes_c: *const i32,
    op_c: i32,
    desc_d: cutensorTensorDescriptor_t,
    modes_d: *const i32,
    op_d: i32,
    desc_e: cutensorTensorDescriptor_t,
    modes_e: *const i32,
    compute_desc: cutensorComputeDescriptor_t,
) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorContractTrinary` symbol.
pub type PFN_cutensorContractTrinary = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    plan: cutensorPlan_t,
    alpha: *const core::ffi::c_void,
    a: *const core::ffi::c_void,
    b: *const core::ffi::c_void,
    c: *const core::ffi::c_void,
    beta: *const core::ffi::c_void,
    d: *const core::ffi::c_void,
    e: *mut core::ffi::c_void,
    workspace: *mut core::ffi::c_void,
    workspace_size: u64,
    stream: *mut core::ffi::c_void,
) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorCreateComputeDescriptor` symbol.
pub type PFN_cutensorCreateComputeDescriptor = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    desc_out: *mut cutensorComputeDescriptor_t,
) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorDestroyComputeDescriptor` symbol.
pub type PFN_cutensorDestroyComputeDescriptor =
    unsafe extern "C" fn(desc: cutensorComputeDescriptor_t) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorComputeDescriptorGetAttribute` symbol.
pub type PFN_cutensorComputeDescriptorGetAttribute = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    desc: cutensorComputeDescriptor_t,
    attr: i32,
    buf: *mut core::ffi::c_void,
    size_in_bytes: usize,
) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorComputeDescriptorSetAttribute` symbol.
pub type PFN_cutensorComputeDescriptorSetAttribute = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    desc: cutensorComputeDescriptor_t,
    attr: i32,
    buf: *const core::ffi::c_void,
    size_in_bytes: usize,
) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorTensorDescriptorSetAttribute` symbol.
pub type PFN_cutensorTensorDescriptorSetAttribute = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    desc: cutensorTensorDescriptor_t,
    attr: i32,
    buf: *const core::ffi::c_void,
    size_in_bytes: usize,
) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorPlanPreferenceGetAttribute` symbol.
pub type PFN_cutensorPlanPreferenceGetAttribute = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    pref: cutensorPlanPreference_t,
    attr: i32,
    buf: *mut core::ffi::c_void,
    size_in_bytes: usize,
) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorOperationEstimateRuntime` symbol;
/// the result is written through `runtime_ms_out` as an `f32`.
pub type PFN_cutensorOperationEstimateRuntime = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    op_desc: cutensorOperationDescriptor_t,
    pref: cutensorPlanPreference_t,
    algo: i32,
    runtime_ms_out: *mut f32,
) -> cutensorStatus_t;
638
/// Function-pointer type for the `cutensorOperationNumAlgos` symbol.
///
/// Note that, unlike most signatures in this file, this one takes no handle.
pub type PFN_cutensorOperationNumAlgos = unsafe extern "C" fn(
    op_desc: cutensorOperationDescriptor_t,
    num_algos_out: *mut i32,
) -> cutensorStatus_t;

// ---- Logger signatures (none of these take a library handle) ---------------

/// Function-pointer type for the `cutensorLoggerSetLevel` symbol.
pub type PFN_cutensorLoggerSetLevel = unsafe extern "C" fn(level: i32) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorLoggerSetMask` symbol.
pub type PFN_cutensorLoggerSetMask = unsafe extern "C" fn(mask: i32) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorLoggerOpenFile` symbol; `path` is a
/// C string.
pub type PFN_cutensorLoggerOpenFile =
    unsafe extern "C" fn(path: *const core::ffi::c_char) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorLoggerSetFile` symbol.
///
/// `file` is an untyped pointer here — presumably a C `FILE*`; confirm
/// against the header before use.
pub type PFN_cutensorLoggerSetFile =
    unsafe extern "C" fn(file: *mut core::ffi::c_void) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorLoggerSetCallback` symbol.
///
/// The callback is optional (`None` is a null function pointer) and receives
/// an `i32` followed by two C-string pointers.
pub type PFN_cutensorLoggerSetCallback = unsafe extern "C" fn(
    callback: Option<unsafe extern "C" fn(i32, *const core::ffi::c_char, *const core::ffi::c_char)>,
) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorLoggerForceDisable` symbol.
pub type PFN_cutensorLoggerForceDisable = unsafe extern "C" fn() -> cutensorStatus_t;
668
// ---- Block-sparse signatures ------------------------------------------------

/// Block-sparse tensor-descriptor handle, carried as an untyped raw pointer.
pub type cutensorBlockSparseTensorDescriptor_t = *mut core::ffi::c_void;

/// Function-pointer type for the `cutensorCreateBlockSparseTensorDescriptor` symbol.
///
/// NOTE(review): the parameter list is taken as written here; it has not been
/// cross-checked against the library header from within this file.
pub type PFN_cutensorCreateBlockSparseTensorDescriptor = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    desc_out: *mut cutensorBlockSparseTensorDescriptor_t,
    num_modes: u32,
    extents: *const i64,
    block_size: *const i64,
    strides: *const i64,
    block_index_count: i64,
    block_indices: *const i32,
    data_type: i32,
    alignment_bytes: u32,
) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorDestroyBlockSparseTensorDescriptor` symbol.
pub type PFN_cutensorDestroyBlockSparseTensorDescriptor =
    unsafe extern "C" fn(desc: cutensorBlockSparseTensorDescriptor_t) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorCreateBlockSparseContraction` symbol.
///
/// Only `desc_a` is a block-sparse descriptor in this signature; `desc_b`,
/// `desc_c` and `desc_d` use the regular tensor-descriptor alias.
pub type PFN_cutensorCreateBlockSparseContraction = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    op_desc_out: *mut cutensorOperationDescriptor_t,
    desc_a: cutensorBlockSparseTensorDescriptor_t,
    modes_a: *const i32,
    op_a: i32,
    desc_b: cutensorTensorDescriptor_t,
    modes_b: *const i32,
    op_b: i32,
    desc_c: cutensorTensorDescriptor_t,
    modes_c: *const i32,
    op_c: i32,
    desc_d: cutensorTensorDescriptor_t,
    modes_d: *const i32,
    compute_desc: cutensorComputeDescriptor_t,
) -> cutensorStatus_t;

/// Function-pointer type for the `cutensorBlockSparseContract` symbol.
pub type PFN_cutensorBlockSparseContract = unsafe extern "C" fn(
    handle: cutensorHandle_t,
    plan: cutensorPlan_t,
    alpha: *const core::ffi::c_void,
    a: *const core::ffi::c_void,
    b: *const core::ffi::c_void,
    beta: *const core::ffi::c_void,
    c: *const core::ffi::c_void,
    d: *mut core::ffi::c_void,
    workspace: *mut core::ffi::c_void,
    workspace_size: u64,
    stream: *mut core::ffi::c_void,
) -> cutensorStatus_t;
724
// Generates the `Cutensor` loader type from a table of
// `fn <method> as "<symbol>": <fn-pointer type>;` entries.
//
// Each entry contributes one private `OnceLock` cache field and one public
// resolver method of the same name. A symbol is looked up in the library the
// first time its resolver is called and served from the cache afterwards.
macro_rules! cutensor_fns {
    ($($(#[$attr:meta])* fn $name:ident as $sym:literal : $pfn:ty;)*) => {
        /// Loaded cuTENSOR library together with one lazily filled
        /// function-pointer cache per known entry point.
        pub struct Cutensor {
            /// The underlying dynamic library the symbols are resolved from.
            pub lib: Library,
            $(
                $name: OnceLock<$pfn>,
            )*
        }

        // Written by hand so that only `lib` is printed; the per-symbol caches
        // are elided via `finish_non_exhaustive`.
        impl core::fmt::Debug for Cutensor {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                let mut builder = f.debug_struct("Cutensor");
                builder.field("lib", &self.lib);
                builder.finish_non_exhaustive()
            }
        }

        impl Cutensor {
            /// Wraps `lib` with every symbol cache still unresolved.
            fn empty(lib: Library) -> Self {
                Self { lib, $($name: OnceLock::new(),)* }
            }
            $(
                $(#[$attr])*
                #[doc = concat!("Resolve `", $sym, "`.")]
                pub fn $name(&self) -> Result<$pfn, LoaderError> {
                    match self.$name.get() {
                        // Already resolved on an earlier call.
                        Some(&cached) => Ok(cached),
                        None => {
                            let address: *mut () = unsafe { self.lib.raw_symbol($sym)? };
                            // Reinterpret the raw symbol address as the typed
                            // function pointer (both are pointer-sized).
                            let resolved: $pfn = unsafe {
                                core::mem::transmute_copy::<*mut (), $pfn>(&address)
                            };
                            // A concurrent caller may have filled the cache
                            // first; the `set` result is deliberately ignored.
                            let _ = self.$name.set(resolved);
                            Ok(resolved)
                        }
                    }
                }
            )*
        }
    };
}
762
// Symbol table for the loader: each entry maps a Rust accessor name to the
// exported C symbol it resolves and the function-pointer type it is cast to.
cutensor_fns! {
    // Core handle, descriptor, plan and contraction entry points.
    fn cutensor_create as "cutensorCreate": PFN_cutensorCreate;
    fn cutensor_destroy as "cutensorDestroy": PFN_cutensorDestroy;
    fn cutensor_create_tensor_descriptor as "cutensorCreateTensorDescriptor":
        PFN_cutensorCreateTensorDescriptor;
    fn cutensor_destroy_tensor_descriptor as "cutensorDestroyTensorDescriptor":
        PFN_cutensorDestroyTensorDescriptor;
    fn cutensor_create_contraction as "cutensorCreateContraction": PFN_cutensorCreateContraction;
    fn cutensor_destroy_operation_descriptor as "cutensorDestroyOperationDescriptor":
        PFN_cutensorDestroyOperationDescriptor;
    fn cutensor_create_plan_preference as "cutensorCreatePlanPreference":
        PFN_cutensorCreatePlanPreference;
    fn cutensor_destroy_plan_preference as "cutensorDestroyPlanPreference":
        PFN_cutensorDestroyPlanPreference;
    fn cutensor_estimate_workspace_size as "cutensorEstimateWorkspaceSize":
        PFN_cutensorEstimateWorkspaceSize;
    fn cutensor_create_plan as "cutensorCreatePlan": PFN_cutensorCreatePlan;
    fn cutensor_destroy_plan as "cutensorDestroyPlan": PFN_cutensorDestroyPlan;
    fn cutensor_contract as "cutensorContract": PFN_cutensorContract;
    fn cutensor_get_version as "cutensorGetVersion": PFN_cutensorGetVersion;
    fn cutensor_get_cudart_version as "cutensorGetCudartVersion": PFN_cutensorGetCudartVersion;
    fn cutensor_get_error_string as "cutensorGetErrorString": PFN_cutensorGetErrorString;

    // Element-wise binary.
    fn cutensor_create_elementwise_binary as "cutensorCreateElementwiseBinary":
        PFN_cutensorCreateElementwiseBinary;
    fn cutensor_elementwise_binary_execute as "cutensorElementwiseBinaryExecute":
        PFN_cutensorElementwiseBinaryExecute;

    // Element-wise trinary.
    fn cutensor_create_elementwise_trinary as "cutensorCreateElementwiseTrinary":
        PFN_cutensorCreateElementwiseTrinary;
    fn cutensor_elementwise_trinary_execute as "cutensorElementwiseTrinaryExecute":
        PFN_cutensorElementwiseTrinaryExecute;

    // Permutation.
    fn cutensor_create_permutation as "cutensorCreatePermutation":
        PFN_cutensorCreatePermutation;
    fn cutensor_permute as "cutensorPermute": PFN_cutensorPermute;

    // Reduction.
    fn cutensor_create_reduction as "cutensorCreateReduction": PFN_cutensorCreateReduction;
    fn cutensor_reduce as "cutensorReduce": PFN_cutensorReduce;

    // Attribute getters/setters.
    fn cutensor_operation_descriptor_get_attribute as "cutensorOperationDescriptorGetAttribute":
        PFN_cutensorOperationDescriptorGetAttribute;
    fn cutensor_operation_descriptor_set_attribute as "cutensorOperationDescriptorSetAttribute":
        PFN_cutensorOperationDescriptorSetAttribute;
    fn cutensor_plan_preference_set_attribute as "cutensorPlanPreferenceSetAttribute":
        PFN_cutensorPlanPreferenceSetAttribute;
    fn cutensor_plan_get_attribute as "cutensorPlanGetAttribute":
        PFN_cutensorPlanGetAttribute;
    fn cutensor_tensor_descriptor_get_attribute as "cutensorTensorDescriptorGetAttribute":
        PFN_cutensorTensorDescriptorGetAttribute;

    // Plan cache and kernel cache. The read entries share one two-argument
    // pointer type, as do the write entries.
    // NOTE(review): confirm that the plan-cache reader really has the same C
    // signature as the kernel-cache reader.
    fn cutensor_handle_resize_plan_cache as "cutensorHandleResizePlanCache":
        PFN_cutensorHandleResizePlanCache;
    fn cutensor_handle_read_plan_cache_from_file as "cutensorHandleReadPlanCacheFromFile":
        PFN_cutensorHandleReadCacheFromFile;
    fn cutensor_handle_write_plan_cache_to_file as "cutensorHandleWritePlanCacheToFile":
        PFN_cutensorHandleWriteCacheToFile;
    fn cutensor_read_kernel_cache_from_file as "cutensorReadKernelCacheFromFile":
        PFN_cutensorHandleReadCacheFromFile;
    fn cutensor_write_kernel_cache_to_file as "cutensorWriteKernelCacheToFile":
        PFN_cutensorHandleWriteCacheToFile;

    // Trinary contraction.
    fn cutensor_create_contraction_trinary as "cutensorCreateContractionTrinary":
        PFN_cutensorCreateContractionTrinary;
    fn cutensor_contract_trinary as "cutensorContractTrinary": PFN_cutensorContractTrinary;

    // User-created compute descriptors.
    fn cutensor_create_compute_descriptor as "cutensorCreateComputeDescriptor":
        PFN_cutensorCreateComputeDescriptor;
    fn cutensor_destroy_compute_descriptor as "cutensorDestroyComputeDescriptor":
        PFN_cutensorDestroyComputeDescriptor;
    fn cutensor_compute_descriptor_get_attribute as "cutensorComputeDescriptorGetAttribute":
        PFN_cutensorComputeDescriptorGetAttribute;
    fn cutensor_compute_descriptor_set_attribute as "cutensorComputeDescriptorSetAttribute":
        PFN_cutensorComputeDescriptorSetAttribute;

    // Remaining attribute accessors.
    fn cutensor_tensor_descriptor_set_attribute as "cutensorTensorDescriptorSetAttribute":
        PFN_cutensorTensorDescriptorSetAttribute;
    fn cutensor_plan_preference_get_attribute as "cutensorPlanPreferenceGetAttribute":
        PFN_cutensorPlanPreferenceGetAttribute;

    // Operation introspection.
    fn cutensor_operation_estimate_runtime as "cutensorOperationEstimateRuntime":
        PFN_cutensorOperationEstimateRuntime;
    fn cutensor_operation_num_algos as "cutensorOperationNumAlgos":
        PFN_cutensorOperationNumAlgos;

    // Logger.
    fn cutensor_logger_set_level as "cutensorLoggerSetLevel": PFN_cutensorLoggerSetLevel;
    fn cutensor_logger_set_mask as "cutensorLoggerSetMask": PFN_cutensorLoggerSetMask;
    fn cutensor_logger_open_file as "cutensorLoggerOpenFile": PFN_cutensorLoggerOpenFile;
    fn cutensor_logger_set_file as "cutensorLoggerSetFile": PFN_cutensorLoggerSetFile;
    fn cutensor_logger_set_callback as "cutensorLoggerSetCallback":
        PFN_cutensorLoggerSetCallback;
    fn cutensor_logger_force_disable as "cutensorLoggerForceDisable":
        PFN_cutensorLoggerForceDisable;

    // Block-sparse tensors.
    fn cutensor_create_block_sparse_tensor_descriptor as "cutensorCreateBlockSparseTensorDescriptor":
        PFN_cutensorCreateBlockSparseTensorDescriptor;
    fn cutensor_destroy_block_sparse_tensor_descriptor
        as "cutensorDestroyBlockSparseTensorDescriptor":
        PFN_cutensorDestroyBlockSparseTensorDescriptor;
    fn cutensor_create_block_sparse_contraction as "cutensorCreateBlockSparseContraction":
        PFN_cutensorCreateBlockSparseContraction;
    fn cutensor_block_sparse_contract as "cutensorBlockSparseContract":
        PFN_cutensorBlockSparseContract;
}
879
/// Returns the shared-library file names to try for the current target OS,
/// in the order they should be attempted.
///
/// Linux gets the versioned sonames followed by the unversioned name, Windows
/// gets the single DLL name, and every other target gets an empty list.
fn cutensor_candidates() -> &'static [&'static str] {
    if cfg!(target_os = "linux") {
        &["libcutensor.so.2", "libcutensor.so.1", "libcutensor.so"]
    } else if cfg!(target_os = "windows") {
        &["cutensor.dll"]
    } else {
        &[]
    }
}
894
/// Builds the list of extra directories to probe for `cutensor.dll` on
/// Windows, in probe order.
///
/// For each stand-alone install root under `%ProgramFiles%`
/// (`NVIDIA cuTENSOR` and `NVIDIA\cuTENSOR`), every sub-directory of the root
/// contributes `<sub>\bin` plus a fixed set of versioned `bin\NN` / `lib\NN`
/// directories, followed by `<root>\bin` itself. After that come
/// `%CUDA_PATH%\bin` and `%CUDA_HOME%\bin` when those variables are set.
///
/// Environment values are read with `var_os`, so non-Unicode values are kept
/// rather than silently ignored. Empty values are treated as unset so that no
/// working-directory-relative path is ever probed. The returned directories
/// are not checked for existence.
#[cfg(target_os = "windows")]
fn cutensor_extra_dirs() -> Vec<std::path::PathBuf> {
    use std::path::PathBuf;

    // Versioned sub-directories probed under each install directory, in order.
    const VERSIONED_SUBDIRS: [&str; 6] = [
        "bin\\12", "bin\\13", "bin\\11", "lib\\12", "lib\\13", "lib\\11",
    ];

    let mut out = Vec::new();

    let progfiles = std::env::var_os("ProgramFiles")
        .filter(|value| !value.is_empty())
        .map(PathBuf::from)
        .unwrap_or_else(|| PathBuf::from("C:\\Program Files"));

    let stand_alone_roots = [
        progfiles.join("NVIDIA cuTENSOR"),
        progfiles.join("NVIDIA").join("cuTENSOR"),
    ];
    for root in &stand_alone_roots {
        // A missing or unreadable root contributes no per-version entries.
        if let Ok(top) = std::fs::read_dir(root) {
            for ent in top.flatten() {
                let p = ent.path();
                if p.is_dir() {
                    out.push(p.join("bin"));
                    out.extend(VERSIONED_SUBDIRS.iter().map(|sub| p.join(sub)));
                }
            }
        }
        out.push(root.join("bin"));
    }

    for var in ["CUDA_PATH", "CUDA_HOME"] {
        if let Some(value) = std::env::var_os(var).filter(|value| !value.is_empty()) {
            out.push(PathBuf::from(value).join("bin"));
        }
    }

    out
}
941
942pub fn cutensor() -> Result<&'static Cutensor, LoaderError> {
944 static CUTENSOR: OnceLock<Cutensor> = OnceLock::new();
945 if let Some(c) = CUTENSOR.get() {
946 return Ok(c);
947 }
948 let lib = match Library::open("cutensor", cutensor_candidates()) {
949 Ok(l) => l,
950 Err(e) => {
951 #[cfg(target_os = "windows")]
952 {
953 let mut found: Option<Library> = None;
954 for dir in cutensor_extra_dirs() {
955 for candidate in cutensor_candidates() {
956 let full = dir.join(candidate);
957 if let Ok(l) = Library::open_at("cutensor", &full) {
958 found = Some(l);
959 break;
960 }
961 }
962 if found.is_some() {
963 break;
964 }
965 }
966 match found {
967 Some(l) => l,
968 None => return Err(e),
969 }
970 }
971 #[cfg(not(target_os = "windows"))]
972 {
973 return Err(e);
974 }
975 }
976 };
977 let _ = CUTENSOR.set(Cutensor::empty(lib));
978 Ok(CUTENSOR.get().expect("OnceLock set or lost race"))
979}