1#![warn(missing_debug_implementations)]
104
105use core::ffi::c_void;
106use std::ffi::CString;
107
108use baracuda_cutensor_sys::{
109 cutensor, cutensorAlgo, cutensorDataType, cutensorHandle_t, cutensorJitMode,
110 cutensorOperationDescriptor_t, cutensorOperator, cutensorPlanPreference_t, cutensorPlan_t,
111 cutensorStatus_t, cutensorTensorDescriptor_t, cutensorWorksizePreference,
112};
113
114pub type Error = baracuda_core::Error<cutensorStatus_t>;
116pub type Result<T, E = Error> = core::result::Result<T, E>;
118
119#[inline]
120fn check(status: cutensorStatus_t) -> Result<()> {
121 Error::check(status)
122}
123
124pub fn probe() -> Result<()> {
126 cutensor()?;
127 Ok(())
128}
129
130pub fn version() -> Result<usize> {
133 let c = cutensor()?;
134 let cu = c.cutensor_get_version()?;
135 Ok(unsafe { cu() })
136}
137
138pub fn cudart_version() -> Result<usize> {
140 let c = cutensor()?;
141 let cu = c.cutensor_get_cudart_version()?;
142 Ok(unsafe { cu() })
143}
144
145pub fn set_log_level(level: i32) -> Result<()> {
147 let c = cutensor()?;
148 let cu = c.cutensor_logger_set_level()?;
149 check(unsafe { cu(level) })
150}
151
152pub fn set_log_mask(mask: i32) -> Result<()> {
155 let c = cutensor()?;
156 let cu = c.cutensor_logger_set_mask()?;
157 check(unsafe { cu(mask) })
158}
159
160pub fn open_log_file(path: &str) -> Result<()> {
162 let cpath = std::ffi::CString::new(path).map_err(|_| Error::Status {
163 status: cutensorStatus_t::INVALID_VALUE,
164 })?;
165 let c = cutensor()?;
166 let cu = c.cutensor_logger_open_file()?;
167 check(unsafe { cu(cpath.as_ptr()) })
168}
169
170pub fn force_disable_logging() -> Result<()> {
172 let c = cutensor()?;
173 let cu = c.cutensor_logger_force_disable()?;
174 check(unsafe { cu() })
175}
176
/// Element types that can be named when building tensor descriptors.
///
/// Each variant is translated to a `cutensorDataType` constant by the
/// private `raw` method.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataType {
    /// Translated to `cutensorDataType::R_16F`.
    F16,
    /// Translated to `cutensorDataType::R_16BF`.
    BF16,
    /// Translated to `cutensorDataType::R_32F`.
    F32,
    /// Translated to `cutensorDataType::R_64F`.
    F64,
    /// Translated to `cutensorDataType::C_32F`.
    ComplexF32,
    /// Translated to `cutensorDataType::C_64F`.
    ComplexF64,
    /// Translated to `cutensorDataType::R_8I`.
    I8,
    /// Translated to `cutensorDataType::R_8U`.
    U8,
    /// Translated to `cutensorDataType::R_32I`.
    I32,
    /// Translated to `cutensorDataType::R_32U`.
    U32,
}
201
202impl DataType {
203 #[inline]
204 fn raw(self) -> i32 {
205 match self {
206 DataType::F16 => cutensorDataType::R_16F,
207 DataType::BF16 => cutensorDataType::R_16BF,
208 DataType::F32 => cutensorDataType::R_32F,
209 DataType::F64 => cutensorDataType::R_64F,
210 DataType::ComplexF32 => cutensorDataType::C_32F,
211 DataType::ComplexF64 => cutensorDataType::C_64F,
212 DataType::I8 => cutensorDataType::R_8I,
213 DataType::U8 => cutensorDataType::R_8U,
214 DataType::I32 => cutensorDataType::R_32I,
215 DataType::U32 => cutensorDataType::R_32U,
216 }
217 }
218}
219
/// Per-operand operators accepted by the elementwise and permutation builders.
///
/// Each variant is translated to a `cutensorOperator` constant by the
/// private `raw` method.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// Translated to `cutensorOperator::IDENTITY`.
    Identity,
    /// Translated to `cutensorOperator::SQRT`.
    Sqrt,
    /// Translated to `cutensorOperator::RELU`.
    Relu,
    /// Translated to `cutensorOperator::CONJ`.
    Conj,
    /// Translated to `cutensorOperator::RCP`.
    Rcp,
    /// Translated to `cutensorOperator::SIGMOID`.
    Sigmoid,
    /// Translated to `cutensorOperator::TANH`.
    Tanh,
}
238
239impl UnaryOp {
240 #[inline]
241 fn raw(self) -> i32 {
242 match self {
243 UnaryOp::Identity => cutensorOperator::IDENTITY,
244 UnaryOp::Sqrt => cutensorOperator::SQRT,
245 UnaryOp::Relu => cutensorOperator::RELU,
246 UnaryOp::Conj => cutensorOperator::CONJ,
247 UnaryOp::Rcp => cutensorOperator::RCP,
248 UnaryOp::Sigmoid => cutensorOperator::SIGMOID,
249 UnaryOp::Tanh => cutensorOperator::TANH,
250 }
251 }
252}
253
/// Two-operand operators accepted by the reduction and elementwise builders.
///
/// Each variant is translated to a `cutensorOperator` constant by the
/// private `raw` method.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOp {
    /// Translated to `cutensorOperator::ADD`.
    Add,
    /// Translated to `cutensorOperator::MUL`.
    Mul,
    /// Translated to `cutensorOperator::MAX`.
    Max,
    /// Translated to `cutensorOperator::MIN`.
    Min,
}
267
268impl BinaryOp {
269 #[inline]
270 fn raw(self) -> i32 {
271 match self {
272 BinaryOp::Add => cutensorOperator::ADD,
273 BinaryOp::Mul => cutensorOperator::MUL,
274 BinaryOp::Max => cutensorOperator::MAX,
275 BinaryOp::Min => cutensorOperator::MIN,
276 }
277 }
278}
279
280#[derive(Debug)]
282pub struct Handle {
283 handle: cutensorHandle_t,
284}
285
286unsafe impl Send for Handle {}
287
288impl Handle {
289 pub fn new() -> Result<Self> {
291 let c = cutensor()?;
292 let cu = c.cutensor_create()?;
293 let mut h: cutensorHandle_t = core::ptr::null_mut();
294 check(unsafe { cu(&mut h) })?;
295 Ok(Self { handle: h })
296 }
297
298 #[inline]
300 pub fn as_raw(&self) -> cutensorHandle_t {
301 self.handle
302 }
303
304 pub fn resize_plan_cache(&self, num_entries: u32) -> Result<()> {
307 let c = cutensor()?;
308 let cu = c.cutensor_handle_resize_plan_cache()?;
309 check(unsafe { cu(self.handle, num_entries) })
310 }
311
312 pub fn write_plan_cache_to_file(&self, path: &str) -> Result<()> {
314 let cpath = CString::new(path).map_err(|_| Error::Status {
315 status: cutensorStatus_t::INVALID_VALUE,
316 })?;
317 let c = cutensor()?;
318 let cu = c.cutensor_handle_write_plan_cache_to_file()?;
319 check(unsafe { cu(self.handle, cpath.as_ptr()) })
320 }
321
322 pub fn read_plan_cache_from_file(&self, path: &str) -> Result<()> {
324 let cpath = CString::new(path).map_err(|_| Error::Status {
325 status: cutensorStatus_t::INVALID_VALUE,
326 })?;
327 let c = cutensor()?;
328 let cu = c.cutensor_handle_read_plan_cache_from_file()?;
329 check(unsafe { cu(self.handle, cpath.as_ptr()) })
330 }
331
332 pub fn write_kernel_cache_to_file(&self, path: &str) -> Result<()> {
336 let cpath = CString::new(path).map_err(|_| Error::Status {
337 status: cutensorStatus_t::INVALID_VALUE,
338 })?;
339 let c = cutensor()?;
340 let cu = c.cutensor_write_kernel_cache_to_file()?;
341 check(unsafe { cu(self.handle, cpath.as_ptr()) })
342 }
343
344 pub fn read_kernel_cache_from_file(&self, path: &str) -> Result<()> {
346 let cpath = CString::new(path).map_err(|_| Error::Status {
347 status: cutensorStatus_t::INVALID_VALUE,
348 })?;
349 let c = cutensor()?;
350 let cu = c.cutensor_read_kernel_cache_from_file()?;
351 check(unsafe { cu(self.handle, cpath.as_ptr()) })
352 }
353
354 pub fn compute_desc_32f(&self) -> Result<*const c_void> {
358 Ok(cutensor()?.compute_desc_32f()?)
359 }
360 pub fn compute_desc_64f(&self) -> Result<*const c_void> {
362 Ok(cutensor()?.compute_desc_64f()?)
363 }
364 pub fn compute_desc_16f(&self) -> Result<*const c_void> {
366 Ok(cutensor()?.compute_desc_16f()?)
367 }
368 pub fn compute_desc_16bf(&self) -> Result<*const c_void> {
370 Ok(cutensor()?.compute_desc_16bf()?)
371 }
372 pub fn compute_desc_tf32(&self) -> Result<*const c_void> {
374 Ok(cutensor()?.compute_desc_tf32()?)
375 }
376 pub fn compute_desc_3xtf32(&self) -> Result<*const c_void> {
378 Ok(cutensor()?.compute_desc_3xtf32()?)
379 }
380 pub fn compute_desc_4x16f(&self) -> Result<*const c_void> {
382 Ok(cutensor()?.compute_desc_4x16f()?)
383 }
384 pub fn compute_desc_8xint8(&self) -> Result<*const c_void> {
386 Ok(cutensor()?.compute_desc_8xint8()?)
387 }
388 pub fn compute_desc_9x16bf(&self) -> Result<*const c_void> {
390 Ok(cutensor()?.compute_desc_9x16bf()?)
391 }
392}
393
394#[derive(Debug)]
398pub struct ComputeDescriptor<'h> {
399 desc: baracuda_cutensor_sys::cutensorComputeDescriptor_t,
400 _handle: &'h Handle,
401}
402
403impl<'h> ComputeDescriptor<'h> {
404 pub fn new(handle: &'h Handle) -> Result<Self> {
407 let c = cutensor()?;
408 let cu = c.cutensor_create_compute_descriptor()?;
409 let mut desc: baracuda_cutensor_sys::cutensorComputeDescriptor_t = core::ptr::null();
410 check(unsafe { cu(handle.as_raw(), &mut desc as *mut _ as *mut _) })?;
411 Ok(Self {
412 desc,
413 _handle: handle,
414 })
415 }
416
417 #[inline]
419 pub fn as_raw(&self) -> baracuda_cutensor_sys::cutensorComputeDescriptor_t {
420 self.desc
421 }
422
423 pub unsafe fn set_attribute(
427 &self,
428 attr: i32,
429 value: *const c_void,
430 size_bytes: usize,
431 ) -> Result<()> { unsafe {
432 let c = cutensor()?;
433 let cu = c.cutensor_compute_descriptor_set_attribute()?;
434 check(cu(
435 self._handle.as_raw(),
436 self.desc,
437 attr,
438 value,
439 size_bytes,
440 ))
441 }}
442
443 pub unsafe fn get_attribute(
447 &self,
448 attr: i32,
449 value: *mut c_void,
450 size_bytes: usize,
451 ) -> Result<()> { unsafe {
452 let c = cutensor()?;
453 let cu = c.cutensor_compute_descriptor_get_attribute()?;
454 check(cu(
455 self._handle.as_raw(),
456 self.desc,
457 attr,
458 value,
459 size_bytes,
460 ))
461 }}
462}
463
464impl Drop for ComputeDescriptor<'_> {
465 fn drop(&mut self) {
466 if let Ok(c) = cutensor() {
467 if let Ok(cu) = c.cutensor_destroy_compute_descriptor() {
468 let _ = unsafe { cu(self.desc) };
469 }
470 }
471 }
472}
473
474#[derive(Debug)]
477pub struct BlockSparseTensorDescriptor<'h> {
478 desc: baracuda_cutensor_sys::cutensorBlockSparseTensorDescriptor_t,
479 _handle: &'h Handle,
480}
481
482impl<'h> BlockSparseTensorDescriptor<'h> {
483 #[allow(clippy::too_many_arguments)]
491 pub fn new(
492 handle: &'h Handle,
493 extents: &[i64],
494 block_size: &[i64],
495 strides: Option<&[i64]>,
496 block_indices: &[i32],
497 dtype: DataType,
498 alignment_bytes: u32,
499 ) -> Result<Self> {
500 assert_eq!(block_size.len(), extents.len());
501 if let Some(s) = strides {
502 assert_eq!(s.len(), extents.len());
503 }
504 let num_modes = extents.len() as u32;
505 let block_count = (block_indices.len() / extents.len()) as i64;
506 let c = cutensor()?;
507 let cu = c.cutensor_create_block_sparse_tensor_descriptor()?;
508 let mut desc: baracuda_cutensor_sys::cutensorBlockSparseTensorDescriptor_t =
509 core::ptr::null_mut();
510 check(unsafe {
511 cu(
512 handle.as_raw(),
513 &mut desc,
514 num_modes,
515 extents.as_ptr(),
516 block_size.as_ptr(),
517 strides.map_or(core::ptr::null(), |s| s.as_ptr()),
518 block_count,
519 block_indices.as_ptr(),
520 dtype.raw(),
521 alignment_bytes,
522 )
523 })?;
524 Ok(Self {
525 desc,
526 _handle: handle,
527 })
528 }
529
530 #[inline]
532 pub fn as_raw(&self) -> baracuda_cutensor_sys::cutensorBlockSparseTensorDescriptor_t {
533 self.desc
534 }
535}
536
537impl Drop for BlockSparseTensorDescriptor<'_> {
538 fn drop(&mut self) {
539 if let Ok(c) = cutensor() {
540 if let Ok(cu) = c.cutensor_destroy_block_sparse_tensor_descriptor() {
541 let _ = unsafe { cu(self.desc) };
542 }
543 }
544 }
545}
546
547#[derive(Debug)]
549pub struct BlockSparseContraction;
550
551impl BlockSparseContraction {
552 #[allow(clippy::too_many_arguments, clippy::new_ret_no_self)]
556 pub unsafe fn new<'h>(
557 handle: &'h Handle,
558 a: &BlockSparseTensorDescriptor<'h>,
559 modes_a: &[i32],
560 b: &TensorDescriptor<'h>,
561 modes_b: &[i32],
562 c: &TensorDescriptor<'h>,
563 modes_c: &[i32],
564 d: &TensorDescriptor<'h>,
565 modes_d: &[i32],
566 compute_desc: *const c_void,
567 ) -> Result<OperationDescriptor<'h>> { unsafe {
568 let lib = cutensor()?;
569 let cu = lib.cutensor_create_block_sparse_contraction()?;
570 let mut desc: cutensorOperationDescriptor_t = core::ptr::null_mut();
571 check(cu(
572 handle.as_raw(),
573 &mut desc,
574 a.as_raw(),
575 modes_a.as_ptr(),
576 cutensorOperator::IDENTITY,
577 b.as_raw(),
578 modes_b.as_ptr(),
579 cutensorOperator::IDENTITY,
580 c.as_raw(),
581 modes_c.as_ptr(),
582 cutensorOperator::IDENTITY,
583 d.as_raw(),
584 modes_d.as_ptr(),
585 compute_desc,
586 ))?;
587 Ok(OperationDescriptor {
588 desc,
589 handle,
590 kind: OpKind::BlockSparseContraction,
591 })
592 }}
593}
594
595#[derive(Debug)]
597pub struct TrinaryContraction;
598
599impl TrinaryContraction {
600 #[allow(clippy::too_many_arguments, clippy::new_ret_no_self)]
604 pub unsafe fn new<'h>(
605 handle: &'h Handle,
606 a: &TensorDescriptor<'h>,
607 modes_a: &[i32],
608 b: &TensorDescriptor<'h>,
609 modes_b: &[i32],
610 c: &TensorDescriptor<'h>,
611 modes_c: &[i32],
612 d: &TensorDescriptor<'h>,
613 modes_d: &[i32],
614 e: &TensorDescriptor<'h>,
615 modes_e: &[i32],
616 compute_desc: *const c_void,
617 ) -> Result<OperationDescriptor<'h>> { unsafe {
618 let lib = cutensor()?;
619 let cu = lib.cutensor_create_contraction_trinary()?;
620 let mut desc: cutensorOperationDescriptor_t = core::ptr::null_mut();
621 check(cu(
622 handle.as_raw(),
623 &mut desc,
624 a.as_raw(),
625 modes_a.as_ptr(),
626 cutensorOperator::IDENTITY,
627 b.as_raw(),
628 modes_b.as_ptr(),
629 cutensorOperator::IDENTITY,
630 c.as_raw(),
631 modes_c.as_ptr(),
632 cutensorOperator::IDENTITY,
633 d.as_raw(),
634 modes_d.as_ptr(),
635 cutensorOperator::IDENTITY,
636 e.as_raw(),
637 modes_e.as_ptr(),
638 compute_desc,
639 ))?;
640 Ok(OperationDescriptor {
641 desc,
642 handle,
643 kind: OpKind::TrinaryContraction,
644 })
645 }}
646}
647
648impl Drop for Handle {
649 fn drop(&mut self) {
650 if let Ok(c) = cutensor() {
651 if let Ok(cu) = c.cutensor_destroy() {
652 let _ = unsafe { cu(self.handle) };
653 }
654 }
655 }
656}
657
658#[derive(Debug)]
660pub struct TensorDescriptor<'h> {
661 desc: cutensorTensorDescriptor_t,
662 _handle: &'h Handle,
663}
664
665impl<'h> TensorDescriptor<'h> {
666 pub fn new(
669 handle: &'h Handle,
670 extents: &[i64],
671 strides: Option<&[i64]>,
672 dtype: DataType,
673 alignment_bytes: u32,
674 ) -> Result<Self> {
675 let c = cutensor()?;
676 let cu = c.cutensor_create_tensor_descriptor()?;
677 let num_modes = extents.len() as u32;
678 if let Some(s) = strides {
679 assert_eq!(s.len(), extents.len(), "strides length mismatch");
680 }
681 let mut desc: cutensorTensorDescriptor_t = core::ptr::null_mut();
682 check(unsafe {
683 cu(
684 handle.as_raw(),
685 &mut desc,
686 num_modes,
687 extents.as_ptr(),
688 strides.map_or(core::ptr::null(), |s| s.as_ptr()),
689 dtype.raw(),
690 alignment_bytes,
691 )
692 })?;
693 Ok(Self {
694 desc,
695 _handle: handle,
696 })
697 }
698
699 #[inline]
701 pub fn as_raw(&self) -> cutensorTensorDescriptor_t {
702 self.desc
703 }
704
705 pub unsafe fn set_attribute(
711 &self,
712 attr: i32,
713 buf: *const c_void,
714 size_bytes: usize,
715 ) -> Result<()> { unsafe {
716 let c = cutensor()?;
717 let cu = c.cutensor_tensor_descriptor_set_attribute()?;
718 check(cu(self._handle.as_raw(), self.desc, attr, buf, size_bytes))
719 }}
720}
721
722impl Drop for TensorDescriptor<'_> {
723 fn drop(&mut self) {
724 if let Ok(c) = cutensor() {
725 if let Ok(cu) = c.cutensor_destroy_tensor_descriptor() {
726 let _ = unsafe { cu(self.desc) };
727 }
728 }
729 }
730}
731
/// Records which builder produced an operation descriptor.
///
/// The tag is copied from the descriptor into a plan so that the plan's
/// execute methods can assert they are being used with the matching kind.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum OpKind {
    /// Set by `Contraction::new`.
    Contraction,
    /// Set by `TrinaryContraction::new`.
    TrinaryContraction,
    /// Set by `BlockSparseContraction::new`.
    BlockSparseContraction,
    /// Set by `Reduction::new`.
    Reduction,
    /// Set by `ElementwiseBinary::new`.
    ElementwiseBinary,
    /// Set by `ElementwiseTrinary::new`.
    ElementwiseTrinary,
    /// Set by `Permutation::new`.
    Permutation,
}
744
745#[derive(Debug)]
749pub struct OperationDescriptor<'h> {
750 desc: cutensorOperationDescriptor_t,
751 handle: &'h Handle,
752 kind: OpKind,
753}
754
755impl<'h> OperationDescriptor<'h> {
756 #[inline]
758 pub fn as_raw(&self) -> cutensorOperationDescriptor_t {
759 self.desc
760 }
761
762 pub fn estimate_workspace(
765 &self,
766 pref: &PlanPreference<'h>,
767 kind: WorkspaceKind,
768 ) -> Result<u64> {
769 let c = cutensor()?;
770 let cu = c.cutensor_estimate_workspace_size()?;
771 let mut size: u64 = 0;
772 check(unsafe {
773 cu(
774 self.handle.as_raw(),
775 self.desc,
776 pref.as_raw(),
777 kind.raw(),
778 &mut size,
779 )
780 })?;
781 Ok(size)
782 }
783
784 pub fn estimate_runtime(&self, pref: &PlanPreference<'h>, algo: i32) -> Result<f32> {
787 let c = cutensor()?;
788 let cu = c.cutensor_operation_estimate_runtime()?;
789 let mut ms: f32 = 0.0;
790 check(unsafe {
791 cu(
792 self.handle.as_raw(),
793 self.desc,
794 pref.as_raw(),
795 algo,
796 &mut ms,
797 )
798 })?;
799 Ok(ms)
800 }
801
802 pub fn num_algos(&self) -> Result<i32> {
804 let c = cutensor()?;
805 let cu = c.cutensor_operation_num_algos()?;
806 let mut n: i32 = 0;
807 check(unsafe { cu(self.desc, &mut n) })?;
808 Ok(n)
809 }
810
811 pub unsafe fn get_attribute(
817 &self,
818 attr: i32,
819 buf: *mut c_void,
820 size_bytes: usize,
821 ) -> Result<()> { unsafe {
822 let c = cutensor()?;
823 let cu = c.cutensor_operation_descriptor_get_attribute()?;
824 check(cu(self.handle.as_raw(), self.desc, attr, buf, size_bytes))
825 }}
826
827 pub unsafe fn set_attribute(
833 &self,
834 attr: i32,
835 buf: *const c_void,
836 size_bytes: usize,
837 ) -> Result<()> { unsafe {
838 let c = cutensor()?;
839 let cu = c.cutensor_operation_descriptor_set_attribute()?;
840 check(cu(self.handle.as_raw(), self.desc, attr, buf, size_bytes))
841 }}
842}
843
844impl Drop for OperationDescriptor<'_> {
845 fn drop(&mut self) {
846 if let Ok(c) = cutensor() {
847 if let Ok(cu) = c.cutensor_destroy_operation_descriptor() {
848 let _ = unsafe { cu(self.desc) };
849 }
850 }
851 }
852}
853
854#[derive(Debug)]
856pub struct Contraction;
857
858impl Contraction {
859 #[allow(clippy::too_many_arguments, clippy::new_ret_no_self)]
868 pub unsafe fn new<'h>(
869 handle: &'h Handle,
870 a: &TensorDescriptor<'h>,
871 modes_a: &[i32],
872 b: &TensorDescriptor<'h>,
873 modes_b: &[i32],
874 c: &TensorDescriptor<'h>,
875 modes_c: &[i32],
876 d: &TensorDescriptor<'h>,
877 modes_d: &[i32],
878 compute_desc: *const c_void,
879 ) -> Result<OperationDescriptor<'h>> { unsafe {
880 let cu_lib = cutensor()?;
881 let cu = cu_lib.cutensor_create_contraction()?;
882 let mut desc: cutensorOperationDescriptor_t = core::ptr::null_mut();
883 check(cu(
884 handle.as_raw(),
885 &mut desc,
886 a.as_raw(),
887 modes_a.as_ptr(),
888 cutensorOperator::IDENTITY,
889 b.as_raw(),
890 modes_b.as_ptr(),
891 cutensorOperator::IDENTITY,
892 c.as_raw(),
893 modes_c.as_ptr(),
894 cutensorOperator::IDENTITY,
895 d.as_raw(),
896 modes_d.as_ptr(),
897 compute_desc,
898 ))?;
899 Ok(OperationDescriptor {
900 desc,
901 handle,
902 kind: OpKind::Contraction,
903 })
904 }}
905}
906
907#[derive(Debug)]
909pub struct Reduction;
910
911impl Reduction {
912 #[allow(clippy::too_many_arguments, clippy::new_ret_no_self)]
920 pub unsafe fn new<'h>(
921 handle: &'h Handle,
922 a: &TensorDescriptor<'h>,
923 modes_a: &[i32],
924 c: &TensorDescriptor<'h>,
925 modes_c: &[i32],
926 d: &TensorDescriptor<'h>,
927 modes_d: &[i32],
928 op_reduce: BinaryOp,
929 compute_desc: *const c_void,
930 ) -> Result<OperationDescriptor<'h>> { unsafe {
931 let lib = cutensor()?;
932 let cu = lib.cutensor_create_reduction()?;
933 let mut desc: cutensorOperationDescriptor_t = core::ptr::null_mut();
934 check(cu(
935 handle.as_raw(),
936 &mut desc,
937 a.as_raw(),
938 modes_a.as_ptr(),
939 cutensorOperator::IDENTITY,
940 c.as_raw(),
941 modes_c.as_ptr(),
942 cutensorOperator::IDENTITY,
943 d.as_raw(),
944 modes_d.as_ptr(),
945 op_reduce.raw(),
946 compute_desc,
947 ))?;
948 Ok(OperationDescriptor {
949 desc,
950 handle,
951 kind: OpKind::Reduction,
952 })
953 }}
954}
955
956#[derive(Debug)]
958pub struct ElementwiseBinary;
959
960impl ElementwiseBinary {
961 #[allow(clippy::too_many_arguments, clippy::new_ret_no_self)]
965 pub unsafe fn new<'h>(
966 handle: &'h Handle,
967 a: &TensorDescriptor<'h>,
968 modes_a: &[i32],
969 op_a: UnaryOp,
970 c: &TensorDescriptor<'h>,
971 modes_c: &[i32],
972 op_c: UnaryOp,
973 d: &TensorDescriptor<'h>,
974 modes_d: &[i32],
975 op_ac: BinaryOp,
976 compute_desc: *const c_void,
977 ) -> Result<OperationDescriptor<'h>> { unsafe {
978 let lib = cutensor()?;
979 let cu = lib.cutensor_create_elementwise_binary()?;
980 let mut desc: cutensorOperationDescriptor_t = core::ptr::null_mut();
981 check(cu(
982 handle.as_raw(),
983 &mut desc,
984 a.as_raw(),
985 modes_a.as_ptr(),
986 op_a.raw(),
987 c.as_raw(),
988 modes_c.as_ptr(),
989 op_c.raw(),
990 d.as_raw(),
991 modes_d.as_ptr(),
992 op_ac.raw(),
993 compute_desc,
994 ))?;
995 Ok(OperationDescriptor {
996 desc,
997 handle,
998 kind: OpKind::ElementwiseBinary,
999 })
1000 }}
1001}
1002
1003#[derive(Debug)]
1006pub struct ElementwiseTrinary;
1007
1008impl ElementwiseTrinary {
1009 #[allow(clippy::too_many_arguments, clippy::new_ret_no_self)]
1013 pub unsafe fn new<'h>(
1014 handle: &'h Handle,
1015 a: &TensorDescriptor<'h>,
1016 modes_a: &[i32],
1017 op_a: UnaryOp,
1018 b: &TensorDescriptor<'h>,
1019 modes_b: &[i32],
1020 op_b: UnaryOp,
1021 c: &TensorDescriptor<'h>,
1022 modes_c: &[i32],
1023 op_c: UnaryOp,
1024 d: &TensorDescriptor<'h>,
1025 modes_d: &[i32],
1026 op_ab: BinaryOp,
1027 op_abc: BinaryOp,
1028 compute_desc: *const c_void,
1029 ) -> Result<OperationDescriptor<'h>> { unsafe {
1030 let lib = cutensor()?;
1031 let cu = lib.cutensor_create_elementwise_trinary()?;
1032 let mut desc: cutensorOperationDescriptor_t = core::ptr::null_mut();
1033 check(cu(
1034 handle.as_raw(),
1035 &mut desc,
1036 a.as_raw(),
1037 modes_a.as_ptr(),
1038 op_a.raw(),
1039 b.as_raw(),
1040 modes_b.as_ptr(),
1041 op_b.raw(),
1042 c.as_raw(),
1043 modes_c.as_ptr(),
1044 op_c.raw(),
1045 d.as_raw(),
1046 modes_d.as_ptr(),
1047 op_ab.raw(),
1048 op_abc.raw(),
1049 compute_desc,
1050 ))?;
1051 Ok(OperationDescriptor {
1052 desc,
1053 handle,
1054 kind: OpKind::ElementwiseTrinary,
1055 })
1056 }}
1057}
1058
1059#[derive(Debug)]
1062pub struct Permutation;
1063
1064impl Permutation {
1065 #[allow(clippy::too_many_arguments, clippy::new_ret_no_self)]
1069 pub unsafe fn new<'h>(
1070 handle: &'h Handle,
1071 a: &TensorDescriptor<'h>,
1072 modes_a: &[i32],
1073 op_a: UnaryOp,
1074 b: &TensorDescriptor<'h>,
1075 modes_b: &[i32],
1076 compute_desc: *const c_void,
1077 ) -> Result<OperationDescriptor<'h>> { unsafe {
1078 let lib = cutensor()?;
1079 let cu = lib.cutensor_create_permutation()?;
1080 let mut desc: cutensorOperationDescriptor_t = core::ptr::null_mut();
1081 check(cu(
1082 handle.as_raw(),
1083 &mut desc,
1084 a.as_raw(),
1085 modes_a.as_ptr(),
1086 op_a.raw(),
1087 b.as_raw(),
1088 modes_b.as_ptr(),
1089 compute_desc,
1090 ))?;
1091 Ok(OperationDescriptor {
1092 desc,
1093 handle,
1094 kind: OpKind::Permutation,
1095 })
1096 }}
1097}
1098
1099#[derive(Debug)]
1101pub struct PlanPreference<'h> {
1102 pref: cutensorPlanPreference_t,
1103 _handle: &'h Handle,
1104}
1105
1106impl<'h> PlanPreference<'h> {
1107 pub fn new(handle: &'h Handle, algo: i32, jit_mode: i32) -> Result<Self> {
1110 let c = cutensor()?;
1111 let cu = c.cutensor_create_plan_preference()?;
1112 let mut p: cutensorPlanPreference_t = core::ptr::null_mut();
1113 check(unsafe { cu(handle.as_raw(), &mut p, algo, jit_mode) })?;
1114 Ok(Self {
1115 pref: p,
1116 _handle: handle,
1117 })
1118 }
1119
1120 pub fn default_for(handle: &'h Handle) -> Result<Self> {
1122 Self::new(handle, cutensorAlgo::DEFAULT, cutensorJitMode::NONE)
1123 }
1124
1125 #[inline]
1127 pub fn as_raw(&self) -> cutensorPlanPreference_t {
1128 self.pref
1129 }
1130
1131 pub unsafe fn set_attribute(
1139 &self,
1140 attr: i32,
1141 value: *const c_void,
1142 size_bytes: usize,
1143 ) -> Result<()> { unsafe {
1144 let c = cutensor()?;
1145 let cu = c.cutensor_plan_preference_set_attribute()?;
1146 check(cu(
1147 self._handle.as_raw(),
1148 self.pref,
1149 attr,
1150 value,
1151 size_bytes,
1152 ))
1153 }}
1154
1155 pub unsafe fn get_attribute(
1161 &self,
1162 attr: i32,
1163 value: *mut c_void,
1164 size_bytes: usize,
1165 ) -> Result<()> { unsafe {
1166 let c = cutensor()?;
1167 let cu = c.cutensor_plan_preference_get_attribute()?;
1168 check(cu(
1169 self._handle.as_raw(),
1170 self.pref,
1171 attr,
1172 value,
1173 size_bytes,
1174 ))
1175 }}
1176}
1177
1178impl Drop for PlanPreference<'_> {
1179 fn drop(&mut self) {
1180 if let Ok(c) = cutensor() {
1181 if let Ok(cu) = c.cutensor_destroy_plan_preference() {
1182 let _ = unsafe { cu(self.pref) };
1183 }
1184 }
1185 }
1186}
1187
/// Selects which workspace-size estimate to request from cuTENSOR.
///
/// Each variant is translated to a `cutensorWorksizePreference` constant by
/// the private `raw` method. `Eq`/`PartialEq` are derived so this enum can be
/// compared like the other public enums in this module.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum WorkspaceKind {
    /// Translated to `cutensorWorksizePreference::MIN`.
    Min,
    /// Translated to `cutensorWorksizePreference::DEFAULT`.
    Default,
    /// Translated to `cutensorWorksizePreference::MAX`.
    Max,
}
1198
1199impl WorkspaceKind {
1200 #[inline]
1201 fn raw(self) -> i32 {
1202 match self {
1203 WorkspaceKind::Min => cutensorWorksizePreference::MIN,
1204 WorkspaceKind::Default => cutensorWorksizePreference::DEFAULT,
1205 WorkspaceKind::Max => cutensorWorksizePreference::MAX,
1206 }
1207 }
1208}
1209
1210#[derive(Debug)]
1213pub struct Plan<'h> {
1214 plan: cutensorPlan_t,
1215 handle: &'h Handle,
1216 kind: OpKind,
1217}
1218
1219impl<'h> Plan<'h> {
1220 pub fn new(
1223 op: &OperationDescriptor<'h>,
1224 pref: &PlanPreference<'h>,
1225 workspace_size_limit: u64,
1226 ) -> Result<Self> {
1227 let c = cutensor()?;
1228 let cu = c.cutensor_create_plan()?;
1229 let mut p: cutensorPlan_t = core::ptr::null_mut();
1230 check(unsafe {
1231 cu(
1232 op.handle.as_raw(),
1233 &mut p,
1234 op.as_raw(),
1235 pref.as_raw(),
1236 workspace_size_limit,
1237 )
1238 })?;
1239 Ok(Self {
1240 plan: p,
1241 handle: op.handle,
1242 kind: op.kind,
1243 })
1244 }
1245
1246 #[inline]
1248 pub fn as_raw(&self) -> cutensorPlan_t {
1249 self.plan
1250 }
1251
1252 #[allow(clippy::too_many_arguments)]
1260 pub unsafe fn contract(
1261 &self,
1262 alpha: *const c_void,
1263 a: *const c_void,
1264 b: *const c_void,
1265 beta: *const c_void,
1266 c: *const c_void,
1267 d: *mut c_void,
1268 workspace: *mut c_void,
1269 workspace_bytes: u64,
1270 stream: *mut c_void,
1271 ) -> Result<()> { unsafe {
1272 assert_eq!(self.kind, OpKind::Contraction, "plan is not a contraction");
1273 let lib = cutensor()?;
1274 let cu = lib.cutensor_contract()?;
1275 check(cu(
1276 self.handle.as_raw(),
1277 self.plan,
1278 alpha,
1279 a,
1280 b,
1281 beta,
1282 c,
1283 d,
1284 workspace,
1285 workspace_bytes,
1286 stream,
1287 ))
1288 }}
1289
1290 #[allow(clippy::too_many_arguments)]
1296 pub unsafe fn reduce(
1297 &self,
1298 alpha: *const c_void,
1299 a: *const c_void,
1300 beta: *const c_void,
1301 c: *const c_void,
1302 d: *mut c_void,
1303 workspace: *mut c_void,
1304 workspace_bytes: u64,
1305 stream: *mut c_void,
1306 ) -> Result<()> { unsafe {
1307 assert_eq!(self.kind, OpKind::Reduction, "plan is not a reduction");
1308 let lib = cutensor()?;
1309 let cu = lib.cutensor_reduce()?;
1310 check(cu(
1311 self.handle.as_raw(),
1312 self.plan,
1313 alpha,
1314 a,
1315 beta,
1316 c,
1317 d,
1318 workspace,
1319 workspace_bytes,
1320 stream,
1321 ))
1322 }}
1323
1324 #[allow(clippy::too_many_arguments)]
1330 pub unsafe fn elementwise_binary(
1331 &self,
1332 alpha: *const c_void,
1333 a: *const c_void,
1334 gamma: *const c_void,
1335 c: *const c_void,
1336 d: *mut c_void,
1337 stream: *mut c_void,
1338 ) -> Result<()> { unsafe {
1339 assert_eq!(
1340 self.kind,
1341 OpKind::ElementwiseBinary,
1342 "plan is not an elementwise-binary"
1343 );
1344 let lib = cutensor()?;
1345 let cu = lib.cutensor_elementwise_binary_execute()?;
1346 check(cu(
1347 self.handle.as_raw(),
1348 self.plan,
1349 alpha,
1350 a,
1351 gamma,
1352 c,
1353 d,
1354 stream,
1355 ))
1356 }}
1357
1358 #[allow(clippy::too_many_arguments)]
1364 pub unsafe fn elementwise_trinary(
1365 &self,
1366 alpha: *const c_void,
1367 a: *const c_void,
1368 beta: *const c_void,
1369 b: *const c_void,
1370 gamma: *const c_void,
1371 c: *const c_void,
1372 d: *mut c_void,
1373 stream: *mut c_void,
1374 ) -> Result<()> { unsafe {
1375 assert_eq!(
1376 self.kind,
1377 OpKind::ElementwiseTrinary,
1378 "plan is not an elementwise-trinary"
1379 );
1380 let lib = cutensor()?;
1381 let cu = lib.cutensor_elementwise_trinary_execute()?;
1382 check(cu(
1383 self.handle.as_raw(),
1384 self.plan,
1385 alpha,
1386 a,
1387 beta,
1388 b,
1389 gamma,
1390 c,
1391 d,
1392 stream,
1393 ))
1394 }}
1395
1396 pub unsafe fn permute(
1402 &self,
1403 alpha: *const c_void,
1404 a: *const c_void,
1405 b: *mut c_void,
1406 stream: *mut c_void,
1407 ) -> Result<()> { unsafe {
1408 assert_eq!(self.kind, OpKind::Permutation, "plan is not a permutation");
1409 let lib = cutensor()?;
1410 let cu = lib.cutensor_permute()?;
1411 check(cu(self.handle.as_raw(), self.plan, alpha, a, b, stream))
1412 }}
1413
1414 #[allow(clippy::too_many_arguments)]
1422 pub unsafe fn contract_block_sparse(
1423 &self,
1424 alpha: *const c_void,
1425 a: *const c_void,
1426 b: *const c_void,
1427 beta: *const c_void,
1428 c: *const c_void,
1429 d: *mut c_void,
1430 workspace: *mut c_void,
1431 workspace_bytes: u64,
1432 stream: *mut c_void,
1433 ) -> Result<()> { unsafe {
1434 assert_eq!(
1435 self.kind,
1436 OpKind::BlockSparseContraction,
1437 "plan is not a block-sparse contraction"
1438 );
1439 let lib = cutensor()?;
1440 let cu = lib.cutensor_block_sparse_contract()?;
1441 check(cu(
1442 self.handle.as_raw(),
1443 self.plan,
1444 alpha,
1445 a,
1446 b,
1447 beta,
1448 c,
1449 d,
1450 workspace,
1451 workspace_bytes,
1452 stream,
1453 ))
1454 }}
1455
1456 #[allow(clippy::too_many_arguments)]
1463 pub unsafe fn contract_trinary(
1464 &self,
1465 alpha: *const c_void,
1466 a: *const c_void,
1467 b: *const c_void,
1468 c: *const c_void,
1469 beta: *const c_void,
1470 d: *const c_void,
1471 e: *mut c_void,
1472 workspace: *mut c_void,
1473 workspace_bytes: u64,
1474 stream: *mut c_void,
1475 ) -> Result<()> { unsafe {
1476 assert_eq!(
1477 self.kind,
1478 OpKind::TrinaryContraction,
1479 "plan is not a trinary-contraction"
1480 );
1481 let lib = cutensor()?;
1482 let cu = lib.cutensor_contract_trinary()?;
1483 check(cu(
1484 self.handle.as_raw(),
1485 self.plan,
1486 alpha,
1487 a,
1488 b,
1489 c,
1490 beta,
1491 d,
1492 e,
1493 workspace,
1494 workspace_bytes,
1495 stream,
1496 ))
1497 }}
1498}
1499
1500impl Drop for Plan<'_> {
1501 fn drop(&mut self) {
1502 if let Ok(c) = cutensor() {
1503 if let Ok(cu) = c.cutensor_destroy_plan() {
1504 let _ = unsafe { cu(self.plan) };
1505 }
1506 }
1507 }
1508}