1use crate::error::{Error, Result};
2use crate::execution::{ExecutableExecutionDescriptor, ExecutionDescriptor};
3use crate::ffi;
4use crate::graph::{
5 data_type, data_type_size, padding_mode, padding_style, tensor_named_data_layout,
6 Convolution2DDescriptor, Graph, Tensor,
7};
8use crate::types::{collect_owned_tensors, Operation, ShapedType};
9use core::ffi::{c_char, c_void};
10use core::ptr;
11use std::ffi::CString;
12
13fn release_handle(ptr: &mut *mut c_void) {
14 if !ptr.is_null() {
15 unsafe { ffi::mpsgraph_object_release(*ptr) };
17 *ptr = ptr::null_mut();
18 }
19}
20
21fn checked_byte_len(shape: &[usize], data_type: u32) -> Option<usize> {
22 let element_size = data_type_size(data_type)?;
23 shape
24 .iter()
25 .try_fold(element_size, |acc, dimension| acc.checked_mul(*dimension))
26}
27
/// Converts an optional Rust name into an owned C string.
///
/// Yields `None` both when `name` is absent and when it contains an interior
/// NUL byte; in the latter case the name is silently dropped.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let text = name?;
    CString::new(text).ok()
}
31
/// Returns the raw pointer of an optional C string, or null when it is `None`.
///
/// The pointer borrows from `value`; the caller must keep `value` alive for as
/// long as the pointer is used.
// Callers already hold an `Option<CString>` that must outlive the pointer, so
// borrowing the whole `Option` keeps every call site a uniform `cstring_ptr(&name)`.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(text) => text.as_ptr(),
        None => ptr::null(),
    }
}
36
37fn wrap_tensor(ptr: *mut c_void) -> Option<Tensor> {
38 if ptr.is_null() {
39 None
40 } else {
41 Some(Tensor::from_raw(ptr))
42 }
43}
44
45fn wrap_operation(ptr: *mut c_void) -> Option<Operation> {
46 if ptr.is_null() {
47 None
48 } else {
49 Some(Operation::from_raw(ptr))
50 }
51}
52
53fn wrap_tensor_pair(box_handle: *mut c_void) -> Option<(Tensor, Tensor)> {
54 let mut values = collect_owned_tensors(box_handle);
55 if values.len() != 2 {
56 return None;
57 }
58 let second = values.pop()?;
59 let first = values.pop()?;
60 Some((first, second))
61}
62
/// Declares a `pub` owning wrapper around a raw native object pointer.
///
/// The generated type exposes the raw pointer through `as_ptr` and releases
/// it with `release_handle` when dropped.
macro_rules! opaque_handle {
    ($name:ident) => {
        /// Owning wrapper around a native object handle.
        pub struct $name {
            ptr: *mut c_void,
        }

        impl $name {
            /// Returns the underlying raw handle without transferring ownership.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }

        impl Drop for $name {
            fn drop(&mut self) {
                release_handle(&mut self.ptr);
            }
        }

        // NOTE(review): these impls assert that the wrapped native objects may be
        // moved and shared across threads; nothing here verifies that — confirm
        // against the MPSGraph threading guarantees.
        unsafe impl Send for $name {}
        unsafe impl Sync for $name {}
    };
}
88
/// Raw stage codes for the `execution_stage` argument of the
/// `signal_shared_event_raw` methods.
pub mod execution_stage {
    /// Signal once execution has completed (MPSGraph's `ExecutionStageCompleted`).
    pub const COMPLETED: u64 = 0;
}
94
/// Raw reduction-mode codes, e.g. for `StencilDescriptorInfo::reduction_mode`.
pub mod reduction_mode {
    /// Minimum of the window.
    pub const MIN: usize = 0;
    /// Maximum of the window.
    pub const MAX: usize = 1;
    /// Sum of the window.
    pub const SUM: usize = 2;
    /// Product of the window.
    pub const PRODUCT: usize = 3;
    /// Index of the minimum.
    pub const ARGUMENT_MIN: usize = 4;
    /// Index of the maximum.
    pub const ARGUMENT_MAX: usize = 5;
}
110
/// Raw codes for `Pooling4DDescriptorInfo::return_indices_mode`.
pub mod pooling_return_indices_mode {
    /// Do not return indices.
    pub const NONE: usize = 0;
    /// Global indices flattened over the last dimension.
    pub const GLOBAL_FLATTEN_1D: usize = 1;
    /// Global indices flattened over the last two dimensions.
    pub const GLOBAL_FLATTEN_2D: usize = 2;
    /// Global indices flattened over the last three dimensions.
    pub const GLOBAL_FLATTEN_3D: usize = 3;
    /// Global indices flattened over the last four dimensions.
    pub const GLOBAL_FLATTEN_4D: usize = 4;
    /// Window-local indices flattened over one dimension.
    pub const LOCAL_FLATTEN_1D: usize = 5;
    /// Window-local indices flattened over two dimensions.
    pub const LOCAL_FLATTEN_2D: usize = 6;
    /// Window-local indices flattened over three dimensions.
    pub const LOCAL_FLATTEN_3D: usize = 7;
    /// Window-local indices flattened over four dimensions.
    pub const LOCAL_FLATTEN_4D: usize = 8;
}
132
/// Raw codes for `FftDescriptorInfo::scaling_mode`.
pub mod fft_scaling_mode {
    /// No scaling.
    pub const NONE: usize = 0;
    /// Scale by the reciprocal of the transform size.
    pub const SIZE: usize = 1;
    /// Scale by the reciprocal square root of the transform size.
    pub const UNITARY: usize = 2;
}
142
/// Raw codes for the `reduction_type` argument of `Graph::softmax_cross_entropy`.
pub mod loss_reduction_type {
    /// No reduction.
    pub const NONE: u64 = 0;
    /// Alias that shares `NONE`'s raw value (mirrors MPSGraph's deprecated `Axis`).
    pub const AXIS: u64 = NONE;
    /// Sum over the reduced axis.
    pub const SUM: u64 = 1;
    /// Mean over the reduced axis.
    pub const MEAN: u64 = 2;
}
154
/// Raw codes for the `coordinate_mode` argument of `Graph::non_maximum_suppression`.
pub mod non_maximum_suppression_coordinate_mode {
    /// Boxes given as corners, height coordinate first.
    pub const CORNERS_HEIGHT_FIRST: usize = 0;
    /// Boxes given as corners, width coordinate first.
    pub const CORNERS_WIDTH_FIRST: usize = 1;
    /// Boxes given as center plus size, height first.
    pub const CENTERS_HEIGHT_FIRST: usize = 2;
    /// Boxes given as center plus size, width first.
    pub const CENTERS_WIDTH_FIRST: usize = 3;
}
166
/// Raw codes for the `mode` argument of `Graph::resize`.
pub mod resize_mode {
    /// Nearest-neighbour sampling.
    pub const NEAREST: usize = 0;
    /// Bilinear interpolation.
    pub const BILINEAR: usize = 1;
}
174
/// Raw codes for the `nearest_rounding_mode` argument of `Graph::resize_nearest`.
pub mod resize_nearest_rounding_mode {
    /// Round half toward +infinity.
    pub const ROUND_PREFER_CEIL: usize = 0;
    /// Round half toward -infinity.
    pub const ROUND_PREFER_FLOOR: usize = 1;
    /// Always round up.
    pub const CEIL: usize = 2;
    /// Always round down.
    pub const FLOOR: usize = 3;
    /// Round half to even.
    pub const ROUND_TO_EVEN: usize = 4;
    /// Round half to odd.
    pub const ROUND_TO_ODD: usize = 5;
}
190
/// Raw codes for the `mode` argument of `Graph::scatter`, `Graph::scatter_nd`
/// and `Graph::scatter_along_axis`.
pub mod scatter_mode {
    /// Add updates into the destination.
    pub const ADD: isize = 0;
    /// Subtract updates from the destination.
    pub const SUB: isize = 1;
    /// Multiply the destination by updates.
    pub const MUL: isize = 2;
    /// Divide the destination by updates.
    pub const DIV: isize = 3;
    /// Keep the minimum of destination and update.
    pub const MIN: isize = 4;
    /// Keep the maximum of destination and update.
    pub const MAX: isize = 5;
    /// Overwrite the destination with updates.
    pub const SET: isize = 6;
}
208
/// Raw codes for the `storage_type` argument of `CreateSparseDescriptor::new`.
pub mod sparse_storage_type {
    /// Coordinate-list storage.
    pub const COO: u64 = 0;
    /// Compressed sparse column storage.
    pub const CSC: u64 = 1;
    /// Compressed sparse row storage.
    pub const CSR: u64 = 2;
}
218
219opaque_handle!(Object);
220impl Object {
221 fn retain_from(ptr: *mut c_void) -> Self {
222 let ptr = unsafe { ffi::mpsgraph_object_retain(ptr) };
224 Self { ptr }
225 }
226}
227
228opaque_handle!(GraphType);
229impl GraphType {
230 fn retain_from(ptr: *mut c_void) -> Self {
231 let ptr = unsafe { ffi::mpsgraph_object_retain(ptr) };
233 Self { ptr }
234 }
235
236#[must_use]
238 pub fn as_object(&self) -> Object {
239 Object::retain_from(self.ptr)
240 }
241}
242
243opaque_handle!(VariableOp);
244impl VariableOp {
245#[must_use]
247 pub fn shape(&self) -> Vec<isize> {
248 let len = unsafe { ffi::mpsgraph_variable_op_shape_len(self.ptr) };
250 let mut shape = vec![0_isize; len];
251 if len > 0 {
252 unsafe { ffi::mpsgraph_variable_op_copy_shape(self.ptr, shape.as_mut_ptr()) };
254 }
255 shape
256 }
257
258#[must_use]
260 pub fn data_type(&self) -> u32 {
261 unsafe { ffi::mpsgraph_variable_op_data_type(self.ptr) }
263 }
264
265#[must_use]
267 pub fn as_object(&self) -> Object {
268 Object::retain_from(self.ptr)
269 }
270
271#[must_use]
273 pub fn as_operation(&self) -> Operation {
274 let ptr = unsafe { ffi::mpsgraph_object_retain(self.ptr) };
276 Operation::from_raw(ptr)
277 }
278}
279
280impl ShapedType {
281#[must_use]
283 pub fn as_graph_type(&self) -> GraphType {
284 GraphType::retain_from(self.as_ptr())
285 }
286}
287
288impl Operation {
289#[must_use]
291 pub fn as_variable(&self) -> Option<VariableOp> {
292 let ptr = unsafe { ffi::mpsgraph_operation_as_variable(self.as_ptr()) };
294 if ptr.is_null() {
295 None
296 } else {
297 Some(VariableOp { ptr })
298 }
299 }
300}
301
/// Plain-data parameters for [`Convolution3DDescriptor::new`], forwarded
/// field-by-field to `ffi::mpsgraph_convolution3d_descriptor_new`.
///
/// `padding_style`, `data_layout` and `weights_layout` hold raw integer codes,
/// such as those in the `padding_style` and `tensor_named_data_layout` modules.
/// Equality and hashing are derived so that configurations can be compared or
/// used as cache keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Convolution3DDescriptorInfo {
    /// Stride along X.
    pub stride_in_x: usize,
    /// Stride along Y.
    pub stride_in_y: usize,
    /// Stride along Z.
    pub stride_in_z: usize,
    /// Dilation rate along X.
    pub dilation_rate_in_x: usize,
    /// Dilation rate along Y.
    pub dilation_rate_in_y: usize,
    /// Dilation rate along Z.
    pub dilation_rate_in_z: usize,
    /// Number of convolution groups.
    pub groups: usize,
    /// Left padding amount.
    pub padding_left: usize,
    /// Right padding amount.
    pub padding_right: usize,
    /// Top padding amount.
    pub padding_top: usize,
    /// Bottom padding amount.
    pub padding_bottom: usize,
    /// Front padding amount.
    pub padding_front: usize,
    /// Back padding amount.
    pub padding_back: usize,
    /// Raw padding-style code.
    pub padding_style: usize,
    /// Raw layout code of the source tensor.
    pub data_layout: usize,
    /// Raw layout code of the weights tensor.
    pub weights_layout: usize,
}
338
339impl Default for Convolution3DDescriptorInfo {
340 fn default() -> Self {
341 Self {
342 stride_in_x: 1,
343 stride_in_y: 1,
344 stride_in_z: 1,
345 dilation_rate_in_x: 1,
346 dilation_rate_in_y: 1,
347 dilation_rate_in_z: 1,
348 groups: 1,
349 padding_left: 0,
350 padding_right: 0,
351 padding_top: 0,
352 padding_bottom: 0,
353 padding_front: 0,
354 padding_back: 0,
355 padding_style: padding_style::EXPLICIT,
356 data_layout: tensor_named_data_layout::NDHWC,
357 weights_layout: tensor_named_data_layout::DHWIO,
358 }
359 }
360}
361
362opaque_handle!(Convolution3DDescriptor);
363impl Convolution3DDescriptor {
364#[must_use]
366 pub fn new(info: Convolution3DDescriptorInfo) -> Option<Self> {
367 let ptr = unsafe {
369 ffi::mpsgraph_convolution3d_descriptor_new(
370 info.stride_in_x,
371 info.stride_in_y,
372 info.stride_in_z,
373 info.dilation_rate_in_x,
374 info.dilation_rate_in_y,
375 info.dilation_rate_in_z,
376 info.groups,
377 info.padding_left,
378 info.padding_right,
379 info.padding_top,
380 info.padding_bottom,
381 info.padding_front,
382 info.padding_back,
383 info.padding_style,
384 info.data_layout,
385 info.weights_layout,
386 )
387 };
388 if ptr.is_null() {
389 None
390 } else {
391 Some(Self { ptr })
392 }
393 }
394}
395
/// Plain-data parameters for [`DepthwiseConvolution2DDescriptor::new`],
/// forwarded field-by-field to `ffi::mpsgraph_depthwise_convolution2d_descriptor_new`.
///
/// `padding_style`, `data_layout` and `weights_layout` hold raw integer codes.
/// Equality and hashing are derived so that configurations can be compared or
/// used as cache keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DepthwiseConvolution2DDescriptorInfo {
    /// Stride along X.
    pub stride_in_x: usize,
    /// Stride along Y.
    pub stride_in_y: usize,
    /// Dilation rate along X.
    pub dilation_rate_in_x: usize,
    /// Dilation rate along Y.
    pub dilation_rate_in_y: usize,
    /// Left padding amount.
    pub padding_left: usize,
    /// Right padding amount.
    pub padding_right: usize,
    /// Top padding amount.
    pub padding_top: usize,
    /// Bottom padding amount.
    pub padding_bottom: usize,
    /// Raw padding-style code.
    pub padding_style: usize,
    /// Raw layout code of the source tensor.
    pub data_layout: usize,
    /// Raw layout code of the weights tensor.
    pub weights_layout: usize,
}
422
423impl Default for DepthwiseConvolution2DDescriptorInfo {
424 fn default() -> Self {
425 Self {
426 stride_in_x: 1,
427 stride_in_y: 1,
428 dilation_rate_in_x: 1,
429 dilation_rate_in_y: 1,
430 padding_left: 0,
431 padding_right: 0,
432 padding_top: 0,
433 padding_bottom: 0,
434 padding_style: padding_style::EXPLICIT,
435 data_layout: tensor_named_data_layout::NHWC,
436 weights_layout: tensor_named_data_layout::HWIO,
437 }
438 }
439}
440
441opaque_handle!(DepthwiseConvolution2DDescriptor);
442impl DepthwiseConvolution2DDescriptor {
443#[must_use]
445 pub fn new(info: DepthwiseConvolution2DDescriptorInfo) -> Option<Self> {
446 let ptr = unsafe {
448 ffi::mpsgraph_depthwise_convolution2d_descriptor_new(
449 info.stride_in_x,
450 info.stride_in_y,
451 info.dilation_rate_in_x,
452 info.dilation_rate_in_y,
453 info.padding_left,
454 info.padding_right,
455 info.padding_top,
456 info.padding_bottom,
457 info.padding_style,
458 info.data_layout,
459 info.weights_layout,
460 )
461 };
462 if ptr.is_null() {
463 None
464 } else {
465 Some(Self { ptr })
466 }
467 }
468}
469
/// Plain-data parameters for [`DepthwiseConvolution3DDescriptor::new`].
///
/// The arrays are passed to `ffi::mpsgraph_depthwise_convolution3d_descriptor_new`
/// as pointer/length pairs. Equality and hashing are derived so that
/// configurations can be compared or used as cache keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DepthwiseConvolution3DDescriptorInfo {
    /// Per-axis strides (three entries).
    pub strides: [usize; 3],
    /// Per-axis dilation rates (three entries).
    pub dilation_rates: [usize; 3],
    /// Explicit padding values (six entries; presumably before/after per axis — confirm).
    pub padding_values: [usize; 6],
    /// Raw padding-style code.
    pub padding_style: usize,
    /// Index of the channel dimension; the default `-1` presumably counts from the end.
    pub channel_dimension_index: isize,
}
484
485impl Default for DepthwiseConvolution3DDescriptorInfo {
486 fn default() -> Self {
487 Self {
488 strides: [1, 1, 1],
489 dilation_rates: [1, 1, 1],
490 padding_values: [0, 0, 0, 0, 0, 0],
491 padding_style: padding_style::EXPLICIT,
492 channel_dimension_index: -1,
493 }
494 }
495}
496
497opaque_handle!(DepthwiseConvolution3DDescriptor);
498impl DepthwiseConvolution3DDescriptor {
499#[must_use]
501 pub fn new(info: DepthwiseConvolution3DDescriptorInfo) -> Option<Self> {
502 let ptr = unsafe {
504 ffi::mpsgraph_depthwise_convolution3d_descriptor_new(
505 info.strides.as_ptr(),
506 info.strides.len(),
507 info.dilation_rates.as_ptr(),
508 info.dilation_rates.len(),
509 info.padding_values.as_ptr(),
510 info.padding_values.len(),
511 info.padding_style,
512 info.channel_dimension_index,
513 )
514 };
515 if ptr.is_null() {
516 None
517 } else {
518 Some(Self { ptr })
519 }
520 }
521}
522
/// Plain-data parameters for [`FftDescriptor::new`], forwarded to
/// `ffi::mpsgraph_fft_descriptor_new`.
///
/// Equality and hashing are derived so that configurations can be compared or
/// used as cache keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FftDescriptorInfo {
    /// Whether the inverse transform is requested.
    pub inverse: bool,
    /// Raw scaling code (see the `fft_scaling_mode` module).
    pub scaling_mode: usize,
    /// Whether the Hermitean output size rounds to odd.
    pub round_to_odd_hermitean: bool,
}
533
534impl Default for FftDescriptorInfo {
535 fn default() -> Self {
536 Self {
537 inverse: false,
538 scaling_mode: fft_scaling_mode::NONE,
539 round_to_odd_hermitean: false,
540 }
541 }
542}
543
544opaque_handle!(FftDescriptor);
545impl FftDescriptor {
546#[must_use]
548 pub fn new(info: FftDescriptorInfo) -> Option<Self> {
549 let ptr = unsafe {
551 ffi::mpsgraph_fft_descriptor_new(
552 info.inverse,
553 info.scaling_mode,
554 info.round_to_odd_hermitean,
555 )
556 };
557 if ptr.is_null() {
558 None
559 } else {
560 Some(Self { ptr })
561 }
562 }
563}
564
/// Plain-data parameters for [`ImToColDescriptor::new`], forwarded
/// field-by-field to `ffi::mpsgraph_im_to_col_descriptor_new`.
///
/// Equality and hashing are derived so that configurations can be compared or
/// used as cache keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ImToColDescriptorInfo {
    /// Kernel width.
    pub kernel_width: usize,
    /// Kernel height.
    pub kernel_height: usize,
    /// Stride along X.
    pub stride_in_x: usize,
    /// Stride along Y.
    pub stride_in_y: usize,
    /// Dilation rate along X.
    pub dilation_rate_in_x: usize,
    /// Dilation rate along Y.
    pub dilation_rate_in_y: usize,
    /// Left padding amount.
    pub padding_left: usize,
    /// Right padding amount.
    pub padding_right: usize,
    /// Top padding amount.
    pub padding_top: usize,
    /// Bottom padding amount.
    pub padding_bottom: usize,
    /// Raw layout code of the source tensor.
    pub data_layout: usize,
}
591
592impl Default for ImToColDescriptorInfo {
593 fn default() -> Self {
594 Self {
595 kernel_width: 1,
596 kernel_height: 1,
597 stride_in_x: 1,
598 stride_in_y: 1,
599 dilation_rate_in_x: 1,
600 dilation_rate_in_y: 1,
601 padding_left: 0,
602 padding_right: 0,
603 padding_top: 0,
604 padding_bottom: 0,
605 data_layout: tensor_named_data_layout::NHWC,
606 }
607 }
608}
609
610opaque_handle!(ImToColDescriptor);
611impl ImToColDescriptor {
612#[must_use]
614 pub fn new(info: ImToColDescriptorInfo) -> Option<Self> {
615 let ptr = unsafe {
617 ffi::mpsgraph_im_to_col_descriptor_new(
618 info.kernel_width,
619 info.kernel_height,
620 info.stride_in_x,
621 info.stride_in_y,
622 info.dilation_rate_in_x,
623 info.dilation_rate_in_y,
624 info.padding_left,
625 info.padding_right,
626 info.padding_top,
627 info.padding_bottom,
628 info.data_layout,
629 )
630 };
631 if ptr.is_null() {
632 None
633 } else {
634 Some(Self { ptr })
635 }
636 }
637}
638
/// Plain-data parameters for [`Pooling4DDescriptor::new`].
///
/// The arrays are passed to `ffi::mpsgraph_pooling4d_descriptor_new` as
/// pointer/length pairs. Equality and hashing are derived so that
/// configurations can be compared or used as cache keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Pooling4DDescriptorInfo {
    /// Per-axis kernel sizes (four entries).
    pub kernel_sizes: [usize; 4],
    /// Per-axis strides (four entries).
    pub strides: [usize; 4],
    /// Per-axis dilation rates (four entries).
    pub dilation_rates: [usize; 4],
    /// Explicit padding values (eight entries; presumably before/after per axis — confirm).
    pub padding_values: [usize; 8],
    /// Raw padding-style code.
    pub padding_style: usize,
    /// Whether output sizes are computed with ceiling rounding.
    pub ceil_mode: bool,
    /// Whether zero padding counts toward averages.
    pub include_zero_pad_to_average: bool,
    /// Raw code from the `pooling_return_indices_mode` module.
    pub return_indices_mode: usize,
    /// Raw data-type code for returned indices (defaults to `data_type::INT32`).
    pub return_indices_data_type: u32,
}
661
662impl Default for Pooling4DDescriptorInfo {
663 fn default() -> Self {
664 Self {
665 kernel_sizes: [1, 1, 1, 1],
666 strides: [1, 1, 1, 1],
667 dilation_rates: [1, 1, 1, 1],
668 padding_values: [0, 0, 0, 0, 0, 0, 0, 0],
669 padding_style: padding_style::EXPLICIT,
670 ceil_mode: false,
671 include_zero_pad_to_average: false,
672 return_indices_mode: pooling_return_indices_mode::NONE,
673 return_indices_data_type: data_type::INT32,
674 }
675 }
676}
677
678opaque_handle!(Pooling4DDescriptor);
679impl Pooling4DDescriptor {
680#[must_use]
682 pub fn new(info: Pooling4DDescriptorInfo) -> Option<Self> {
683 let ptr = unsafe {
685 ffi::mpsgraph_pooling4d_descriptor_new(
686 info.kernel_sizes.as_ptr(),
687 info.kernel_sizes.len(),
688 info.strides.as_ptr(),
689 info.strides.len(),
690 info.dilation_rates.as_ptr(),
691 info.dilation_rates.len(),
692 info.padding_values.as_ptr(),
693 info.padding_values.len(),
694 info.padding_style,
695 info.ceil_mode,
696 info.include_zero_pad_to_average,
697 info.return_indices_mode,
698 info.return_indices_data_type,
699 )
700 };
701 if ptr.is_null() {
702 None
703 } else {
704 Some(Self { ptr })
705 }
706 }
707}
708
709opaque_handle!(CreateSparseDescriptor);
710impl CreateSparseDescriptor {
711#[must_use]
713 pub fn new(storage_type: u64, data_type: u32) -> Option<Self> {
714 let ptr = unsafe { ffi::mpsgraph_sparse_descriptor_new(storage_type, data_type) };
716 if ptr.is_null() {
717 None
718 } else {
719 Some(Self { ptr })
720 }
721 }
722}
723
/// Plain-data parameters for [`StencilDescriptor::new`].
///
/// The arrays are passed to `ffi::mpsgraph_stencil_descriptor_new` as
/// pointer/length pairs. `PartialEq` is derived so that configurations can be
/// compared; `Eq`/`Hash` are omitted because `padding_constant` is an `f32`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct StencilDescriptorInfo {
    /// Raw code from the `reduction_mode` module.
    pub reduction_mode: usize,
    /// Per-axis window offsets (four entries).
    pub offsets: [isize; 4],
    /// Per-axis strides (four entries).
    pub strides: [usize; 4],
    /// Per-axis dilation rates (four entries).
    pub dilation_rates: [usize; 4],
    /// Explicit padding values (eight entries; presumably before/after per axis — confirm).
    pub explicit_padding: [usize; 8],
    /// Raw `padding_mode` code for out-of-bounds reads.
    pub boundary_mode: isize,
    /// Raw padding-style code.
    pub padding_style: usize,
    /// Fill value, presumably used by constant boundary modes — confirm.
    pub padding_constant: f32,
}
744
745impl Default for StencilDescriptorInfo {
746 fn default() -> Self {
747 Self {
748 reduction_mode: reduction_mode::SUM,
749 offsets: [0, 0, 0, 0],
750 strides: [1, 1, 1, 1],
751 dilation_rates: [1, 1, 1, 1],
752 explicit_padding: [0, 0, 0, 0, 0, 0, 0, 0],
753 boundary_mode: padding_mode::ZERO,
754 padding_style: padding_style::EXPLICIT,
755 padding_constant: 0.0,
756 }
757 }
758}
759
760opaque_handle!(StencilDescriptor);
761impl StencilDescriptor {
762#[must_use]
764 pub fn new(info: StencilDescriptorInfo) -> Option<Self> {
765 let ptr = unsafe {
767 ffi::mpsgraph_stencil_descriptor_new(
768 info.reduction_mode,
769 info.offsets.as_ptr(),
770 info.offsets.len(),
771 info.strides.as_ptr(),
772 info.strides.len(),
773 info.dilation_rates.as_ptr(),
774 info.dilation_rates.len(),
775 info.explicit_padding.as_ptr(),
776 info.explicit_padding.len(),
777 info.boundary_mode,
778 info.padding_style,
779 info.padding_constant,
780 )
781 };
782 if ptr.is_null() {
783 None
784 } else {
785 Some(Self { ptr })
786 }
787 }
788}
789
790impl Graph {
791#[must_use]
793 pub fn convolution3d(
794 &self,
795 source: &Tensor,
796 weights: &Tensor,
797 descriptor: &Convolution3DDescriptor,
798 name: Option<&str>,
799 ) -> Option<Tensor> {
800 let name = optional_cstring(name);
801 let ptr = unsafe {
803 ffi::mpsgraph_graph_convolution3d(
804 self.as_ptr(),
805 source.as_ptr(),
806 weights.as_ptr(),
807 descriptor.as_ptr(),
808 cstring_ptr(&name),
809 )
810 };
811 wrap_tensor(ptr)
812 }
813
814#[must_use]
816 pub fn convolution_transpose2d(
817 &self,
818 source: &Tensor,
819 weights: &Tensor,
820 output_shape: &[usize],
821 descriptor: &Convolution2DDescriptor,
822 name: Option<&str>,
823 ) -> Option<Tensor> {
824 let name = optional_cstring(name);
825 let ptr = unsafe {
827 ffi::mpsgraph_graph_convolution_transpose2d(
828 self.as_ptr(),
829 source.as_ptr(),
830 weights.as_ptr(),
831 output_shape.as_ptr(),
832 output_shape.len(),
833 descriptor.as_ptr(),
834 cstring_ptr(&name),
835 )
836 };
837 wrap_tensor(ptr)
838 }
839
840#[must_use]
842 pub fn cumulative_sum(
843 &self,
844 tensor: &Tensor,
845 axis: isize,
846 exclusive: bool,
847 reverse: bool,
848 name: Option<&str>,
849 ) -> Option<Tensor> {
850 let name = optional_cstring(name);
851 let ptr = unsafe {
853 ffi::mpsgraph_graph_cumulative_sum(
854 self.as_ptr(),
855 tensor.as_ptr(),
856 axis,
857 exclusive,
858 reverse,
859 cstring_ptr(&name),
860 )
861 };
862 wrap_tensor(ptr)
863 }
864
865#[must_use]
867 pub fn depthwise_convolution2d(
868 &self,
869 source: &Tensor,
870 weights: &Tensor,
871 descriptor: &DepthwiseConvolution2DDescriptor,
872 name: Option<&str>,
873 ) -> Option<Tensor> {
874 let name = optional_cstring(name);
875 let ptr = unsafe {
877 ffi::mpsgraph_graph_depthwise_convolution2d(
878 self.as_ptr(),
879 source.as_ptr(),
880 weights.as_ptr(),
881 descriptor.as_ptr(),
882 cstring_ptr(&name),
883 )
884 };
885 wrap_tensor(ptr)
886 }
887
888#[must_use]
890 pub fn depthwise_convolution3d(
891 &self,
892 source: &Tensor,
893 weights: &Tensor,
894 descriptor: &DepthwiseConvolution3DDescriptor,
895 name: Option<&str>,
896 ) -> Option<Tensor> {
897 let name = optional_cstring(name);
898 let ptr = unsafe {
900 ffi::mpsgraph_graph_depthwise_convolution3d(
901 self.as_ptr(),
902 source.as_ptr(),
903 weights.as_ptr(),
904 descriptor.as_ptr(),
905 cstring_ptr(&name),
906 )
907 };
908 wrap_tensor(ptr)
909 }
910
911#[must_use]
913 pub fn fast_fourier_transform(
914 &self,
915 tensor: &Tensor,
916 axes: &[usize],
917 descriptor: &FftDescriptor,
918 name: Option<&str>,
919 ) -> Option<Tensor> {
920 let name = optional_cstring(name);
921 let ptr = unsafe {
923 ffi::mpsgraph_graph_fast_fourier_transform(
924 self.as_ptr(),
925 tensor.as_ptr(),
926 axes.as_ptr(),
927 axes.len(),
928 descriptor.as_ptr(),
929 cstring_ptr(&name),
930 )
931 };
932 wrap_tensor(ptr)
933 }
934
935#[must_use]
937 pub fn im_to_col(
938 &self,
939 source: &Tensor,
940 descriptor: &ImToColDescriptor,
941 name: Option<&str>,
942 ) -> Option<Tensor> {
943 let name = optional_cstring(name);
944 let ptr = unsafe {
946 ffi::mpsgraph_graph_im_to_col(
947 self.as_ptr(),
948 source.as_ptr(),
949 descriptor.as_ptr(),
950 cstring_ptr(&name),
951 )
952 };
953 wrap_tensor(ptr)
954 }
955
956#[must_use]
958 pub fn band_part(
959 &self,
960 tensor: &Tensor,
961 num_lower: isize,
962 num_upper: isize,
963 name: Option<&str>,
964 ) -> Option<Tensor> {
965 let name = optional_cstring(name);
966 let ptr = unsafe {
968 ffi::mpsgraph_graph_band_part(
969 self.as_ptr(),
970 tensor.as_ptr(),
971 num_lower,
972 num_upper,
973 cstring_ptr(&name),
974 )
975 };
976 wrap_tensor(ptr)
977 }
978
979#[must_use]
981 pub fn softmax_cross_entropy(
982 &self,
983 source: &Tensor,
984 labels: &Tensor,
985 axis: isize,
986 reduction_type: u64,
987 name: Option<&str>,
988 ) -> Option<Tensor> {
989 let name = optional_cstring(name);
990 let ptr = unsafe {
992 ffi::mpsgraph_graph_softmax_cross_entropy(
993 self.as_ptr(),
994 source.as_ptr(),
995 labels.as_ptr(),
996 axis,
997 reduction_type,
998 cstring_ptr(&name),
999 )
1000 };
1001 wrap_tensor(ptr)
1002 }
1003
1004#[must_use]
1006 pub fn matrix_inverse(&self, tensor: &Tensor, name: Option<&str>) -> Option<Tensor> {
1007 let name = optional_cstring(name);
1008 let ptr = unsafe {
1010 ffi::mpsgraph_graph_matrix_inverse(self.as_ptr(), tensor.as_ptr(), cstring_ptr(&name))
1011 };
1012 wrap_tensor(ptr)
1013 }
1014
1015#[must_use]
1017 pub fn variable_bytes(
1018 &self,
1019 data: &[u8],
1020 shape: &[usize],
1021 data_type: u32,
1022 name: Option<&str>,
1023 ) -> Option<Tensor> {
1024 let expected = checked_byte_len(shape, data_type)?;
1025 if data.len() != expected {
1026 return None;
1027 }
1028
1029 let name = optional_cstring(name);
1030 let ptr = unsafe {
1032 ffi::mpsgraph_graph_variable_data(
1033 self.as_ptr(),
1034 data.as_ptr().cast(),
1035 data.len(),
1036 shape.as_ptr(),
1037 shape.len(),
1038 data_type,
1039 cstring_ptr(&name),
1040 )
1041 };
1042 wrap_tensor(ptr)
1043 }
1044
1045#[must_use]
1047 pub fn variable_f32_slice(
1048 &self,
1049 values: &[f32],
1050 shape: &[usize],
1051 name: Option<&str>,
1052 ) -> Option<Tensor> {
1053 let bytes = unsafe {
1055 core::slice::from_raw_parts(
1056 values.as_ptr().cast::<u8>(),
1057 core::mem::size_of_val(values),
1058 )
1059 };
1060 self.variable_bytes(bytes, shape, data_type::FLOAT32, name)
1061 }
1062
1063#[must_use]
1065 pub fn read_variable(&self, variable: &Tensor, name: Option<&str>) -> Option<Tensor> {
1066 let name = optional_cstring(name);
1067 let ptr = unsafe {
1069 ffi::mpsgraph_graph_read_variable(self.as_ptr(), variable.as_ptr(), cstring_ptr(&name))
1070 };
1071 wrap_tensor(ptr)
1072 }
1073
1074#[must_use]
1076 pub fn assign_variable(
1077 &self,
1078 variable: &Tensor,
1079 value: &Tensor,
1080 name: Option<&str>,
1081 ) -> Option<Operation> {
1082 let name = optional_cstring(name);
1083 let ptr = unsafe {
1085 ffi::mpsgraph_graph_assign_variable(
1086 self.as_ptr(),
1087 variable.as_ptr(),
1088 value.as_ptr(),
1089 cstring_ptr(&name),
1090 )
1091 };
1092 wrap_operation(ptr)
1093 }
1094
1095#[must_use]
1097 #[allow(clippy::too_many_arguments)]
1098 pub fn non_maximum_suppression(
1099 &self,
1100 boxes: &Tensor,
1101 scores: &Tensor,
1102 iou_threshold: f32,
1103 score_threshold: f32,
1104 per_class_suppression: bool,
1105 coordinate_mode: usize,
1106 name: Option<&str>,
1107 ) -> Option<Tensor> {
1108 let name = optional_cstring(name);
1109 let ptr = unsafe {
1111 ffi::mpsgraph_graph_non_maximum_suppression(
1112 self.as_ptr(),
1113 boxes.as_ptr(),
1114 scores.as_ptr(),
1115 iou_threshold,
1116 score_threshold,
1117 per_class_suppression,
1118 coordinate_mode,
1119 cstring_ptr(&name),
1120 )
1121 };
1122 wrap_tensor(ptr)
1123 }
1124
1125#[must_use]
1127 pub fn non_zero_indices(&self, tensor: &Tensor, name: Option<&str>) -> Option<Tensor> {
1128 let name = optional_cstring(name);
1129 let ptr = unsafe {
1131 ffi::mpsgraph_graph_non_zero_indices(self.as_ptr(), tensor.as_ptr(), cstring_ptr(&name))
1132 };
1133 wrap_tensor(ptr)
1134 }
1135
1136#[must_use]
1138 pub fn one_hot(
1139 &self,
1140 indices: &Tensor,
1141 depth: usize,
1142 data_type: u32,
1143 name: Option<&str>,
1144 ) -> Option<Tensor> {
1145 let name = optional_cstring(name);
1146 let ptr = unsafe {
1148 ffi::mpsgraph_graph_one_hot(
1149 self.as_ptr(),
1150 indices.as_ptr(),
1151 depth,
1152 data_type,
1153 cstring_ptr(&name),
1154 )
1155 };
1156 wrap_tensor(ptr)
1157 }
1158
1159#[must_use]
1161 pub fn stochastic_gradient_descent(
1162 &self,
1163 learning_rate: &Tensor,
1164 values: &Tensor,
1165 gradient: &Tensor,
1166 name: Option<&str>,
1167 ) -> Option<Tensor> {
1168 let name = optional_cstring(name);
1169 let ptr = unsafe {
1171 ffi::mpsgraph_graph_stochastic_gradient_descent(
1172 self.as_ptr(),
1173 learning_rate.as_ptr(),
1174 values.as_ptr(),
1175 gradient.as_ptr(),
1176 cstring_ptr(&name),
1177 )
1178 };
1179 wrap_tensor(ptr)
1180 }
1181
1182#[must_use]
1184 pub fn max_pooling4d(
1185 &self,
1186 source: &Tensor,
1187 descriptor: &Pooling4DDescriptor,
1188 name: Option<&str>,
1189 ) -> Option<Tensor> {
1190 let name = optional_cstring(name);
1191 let ptr = unsafe {
1193 ffi::mpsgraph_graph_max_pooling4d(
1194 self.as_ptr(),
1195 source.as_ptr(),
1196 descriptor.as_ptr(),
1197 cstring_ptr(&name),
1198 )
1199 };
1200 wrap_tensor(ptr)
1201 }
1202
1203#[must_use]
1205 pub fn max_pooling4d_return_indices(
1206 &self,
1207 source: &Tensor,
1208 descriptor: &Pooling4DDescriptor,
1209 name: Option<&str>,
1210 ) -> Option<(Tensor, Tensor)> {
1211 let name = optional_cstring(name);
1212 let box_handle = unsafe {
1214 ffi::mpsgraph_graph_max_pooling4d_return_indices(
1215 self.as_ptr(),
1216 source.as_ptr(),
1217 descriptor.as_ptr(),
1218 cstring_ptr(&name),
1219 )
1220 };
1221 wrap_tensor_pair(box_handle)
1222 }
1223
1224#[must_use]
1226 pub fn quantize(
1227 &self,
1228 tensor: &Tensor,
1229 scale: f64,
1230 zero_point: f64,
1231 data_type: u32,
1232 name: Option<&str>,
1233 ) -> Option<Tensor> {
1234 let name = optional_cstring(name);
1235 let ptr = unsafe {
1237 ffi::mpsgraph_graph_quantize(
1238 self.as_ptr(),
1239 tensor.as_ptr(),
1240 scale,
1241 zero_point,
1242 data_type,
1243 cstring_ptr(&name),
1244 )
1245 };
1246 wrap_tensor(ptr)
1247 }
1248
1249#[must_use]
1251 pub fn dequantize(
1252 &self,
1253 tensor: &Tensor,
1254 scale: f64,
1255 zero_point: f64,
1256 data_type: u32,
1257 name: Option<&str>,
1258 ) -> Option<Tensor> {
1259 let name = optional_cstring(name);
1260 let ptr = unsafe {
1262 ffi::mpsgraph_graph_dequantize(
1263 self.as_ptr(),
1264 tensor.as_ptr(),
1265 scale,
1266 zero_point,
1267 data_type,
1268 cstring_ptr(&name),
1269 )
1270 };
1271 wrap_tensor(ptr)
1272 }
1273
1274#[must_use]
1276 #[allow(clippy::too_many_arguments)]
1277 pub fn resize(
1278 &self,
1279 images: &Tensor,
1280 size: &[usize],
1281 mode: usize,
1282 center_result: bool,
1283 align_corners: bool,
1284 layout: usize,
1285 name: Option<&str>,
1286 ) -> Option<Tensor> {
1287 let name = optional_cstring(name);
1288 let ptr = unsafe {
1290 ffi::mpsgraph_graph_resize(
1291 self.as_ptr(),
1292 images.as_ptr(),
1293 size.as_ptr(),
1294 size.len(),
1295 mode,
1296 center_result,
1297 align_corners,
1298 layout,
1299 cstring_ptr(&name),
1300 )
1301 };
1302 wrap_tensor(ptr)
1303 }
1304
1305#[must_use]
1307 #[allow(clippy::too_many_arguments)]
1308 pub fn resize_nearest(
1309 &self,
1310 images: &Tensor,
1311 size_tensor: &Tensor,
1312 nearest_rounding_mode: usize,
1313 center_result: bool,
1314 align_corners: bool,
1315 layout: usize,
1316 name: Option<&str>,
1317 ) -> Option<Tensor> {
1318 let name = optional_cstring(name);
1319 let ptr = unsafe {
1321 ffi::mpsgraph_graph_resize_nearest(
1322 self.as_ptr(),
1323 images.as_ptr(),
1324 size_tensor.as_ptr(),
1325 nearest_rounding_mode,
1326 center_result,
1327 align_corners,
1328 layout,
1329 cstring_ptr(&name),
1330 )
1331 };
1332 wrap_tensor(ptr)
1333 }
1334
1335#[must_use]
1337 #[allow(clippy::too_many_arguments)]
1338 pub fn sample_grid(
1339 &self,
1340 source: &Tensor,
1341 coordinates: &Tensor,
1342 layout: usize,
1343 normalize_coordinates: bool,
1344 relative_coordinates: bool,
1345 align_corners: bool,
1346 padding_mode: isize,
1347 sampling_mode: usize,
1348 constant_value: f64,
1349 name: Option<&str>,
1350 ) -> Option<Tensor> {
1351 let name = optional_cstring(name);
1352 let ptr = unsafe {
1354 ffi::mpsgraph_graph_sample_grid(
1355 self.as_ptr(),
1356 source.as_ptr(),
1357 coordinates.as_ptr(),
1358 layout,
1359 normalize_coordinates,
1360 relative_coordinates,
1361 align_corners,
1362 padding_mode,
1363 sampling_mode,
1364 constant_value,
1365 cstring_ptr(&name),
1366 )
1367 };
1368 wrap_tensor(ptr)
1369 }
1370
1371#[must_use]
1373 pub fn scatter_nd(
1374 &self,
1375 updates: &Tensor,
1376 indices: &Tensor,
1377 shape: &[usize],
1378 batch_dimensions: usize,
1379 mode: isize,
1380 name: Option<&str>,
1381 ) -> Option<Tensor> {
1382 let name = optional_cstring(name);
1383 let ptr = unsafe {
1385 ffi::mpsgraph_graph_scatter_nd(
1386 self.as_ptr(),
1387 updates.as_ptr(),
1388 indices.as_ptr(),
1389 shape.as_ptr(),
1390 shape.len(),
1391 batch_dimensions,
1392 mode,
1393 cstring_ptr(&name),
1394 )
1395 };
1396 wrap_tensor(ptr)
1397 }
1398
1399#[must_use]
1401 pub fn scatter(
1402 &self,
1403 updates: &Tensor,
1404 indices: &Tensor,
1405 shape: &[usize],
1406 axis: isize,
1407 mode: isize,
1408 name: Option<&str>,
1409 ) -> Option<Tensor> {
1410 let name = optional_cstring(name);
1411 let ptr = unsafe {
1413 ffi::mpsgraph_graph_scatter(
1414 self.as_ptr(),
1415 updates.as_ptr(),
1416 indices.as_ptr(),
1417 shape.as_ptr(),
1418 shape.len(),
1419 axis,
1420 mode,
1421 cstring_ptr(&name),
1422 )
1423 };
1424 wrap_tensor(ptr)
1425 }
1426
1427#[must_use]
1429 pub fn scatter_along_axis(
1430 &self,
1431 axis: isize,
1432 updates: &Tensor,
1433 indices: &Tensor,
1434 shape: &[usize],
1435 mode: isize,
1436 name: Option<&str>,
1437 ) -> Option<Tensor> {
1438 let name = optional_cstring(name);
1439 let ptr = unsafe {
1441 ffi::mpsgraph_graph_scatter_along_axis(
1442 self.as_ptr(),
1443 axis,
1444 updates.as_ptr(),
1445 indices.as_ptr(),
1446 shape.as_ptr(),
1447 shape.len(),
1448 mode,
1449 cstring_ptr(&name),
1450 )
1451 };
1452 wrap_tensor(ptr)
1453 }
1454
1455#[must_use]
1457 pub fn sort(
1458 &self,
1459 tensor: &Tensor,
1460 axis: isize,
1461 descending: bool,
1462 name: Option<&str>,
1463 ) -> Option<Tensor> {
1464 let name = optional_cstring(name);
1465 let ptr = unsafe {
1467 ffi::mpsgraph_graph_sort(
1468 self.as_ptr(),
1469 tensor.as_ptr(),
1470 axis,
1471 descending,
1472 cstring_ptr(&name),
1473 )
1474 };
1475 wrap_tensor(ptr)
1476 }
1477
1478#[must_use]
1480 pub fn arg_sort(
1481 &self,
1482 tensor: &Tensor,
1483 axis: isize,
1484 descending: bool,
1485 name: Option<&str>,
1486 ) -> Option<Tensor> {
1487 let name = optional_cstring(name);
1488 let ptr = unsafe {
1490 ffi::mpsgraph_graph_arg_sort(
1491 self.as_ptr(),
1492 tensor.as_ptr(),
1493 axis,
1494 descending,
1495 cstring_ptr(&name),
1496 )
1497 };
1498 wrap_tensor(ptr)
1499 }
1500
1501#[must_use]
1503 pub fn sparse_tensor_with_descriptor(
1504 &self,
1505 descriptor: &CreateSparseDescriptor,
1506 tensors: &[&Tensor],
1507 shape: &[usize],
1508 name: Option<&str>,
1509 ) -> Option<Tensor> {
1510 let name = optional_cstring(name);
1511 let handles = tensors
1512 .iter()
1513 .map(|tensor| tensor.as_ptr())
1514 .collect::<Vec<_>>();
1515 let ptr = unsafe {
1517 ffi::mpsgraph_graph_sparse_tensor_with_descriptor(
1518 self.as_ptr(),
1519 descriptor.as_ptr(),
1520 handles.as_ptr(),
1521 handles.len(),
1522 shape.as_ptr(),
1523 shape.len(),
1524 cstring_ptr(&name),
1525 )
1526 };
1527 wrap_tensor(ptr)
1528 }
1529
1530#[must_use]
1532 pub fn stencil(
1533 &self,
1534 source: &Tensor,
1535 weights: &Tensor,
1536 descriptor: &StencilDescriptor,
1537 name: Option<&str>,
1538 ) -> Option<Tensor> {
1539 let name = optional_cstring(name);
1540 let ptr = unsafe {
1542 ffi::mpsgraph_graph_stencil(
1543 self.as_ptr(),
1544 source.as_ptr(),
1545 weights.as_ptr(),
1546 descriptor.as_ptr(),
1547 cstring_ptr(&name),
1548 )
1549 };
1550 wrap_tensor(ptr)
1551 }
1552
1553#[must_use]
1555 pub fn top_k_gradient(
1556 &self,
1557 gradient: &Tensor,
1558 source: &Tensor,
1559 k: usize,
1560 name: Option<&str>,
1561 ) -> Option<Tensor> {
1562 let name = optional_cstring(name);
1563 let ptr = unsafe {
1565 ffi::mpsgraph_graph_topk_gradient(
1566 self.as_ptr(),
1567 gradient.as_ptr(),
1568 source.as_ptr(),
1569 k,
1570 cstring_ptr(&name),
1571 )
1572 };
1573 wrap_tensor(ptr)
1574 }
1575}
1576
1577impl ExecutionDescriptor {
1578 pub unsafe fn wait_for_shared_event_raw(
1582 &self,
1583 event_handle: *mut c_void,
1584 value: u64,
1585 ) -> Result<()> {
1586 let ok = unsafe {
1588 ffi::mpsgraph_execution_descriptor_wait_for_event(self.as_ptr(), event_handle, value)
1589 };
1590 if ok {
1591 Ok(())
1592 } else {
1593 Err(Error::OperationFailed(
1594 "failed to register execution descriptor shared-event wait",
1595 ))
1596 }
1597 }
1598
1599 pub unsafe fn signal_shared_event_raw(
1603 &self,
1604 event_handle: *mut c_void,
1605 execution_stage: u64,
1606 value: u64,
1607 ) -> Result<()> {
1608 let ok = unsafe {
1610 ffi::mpsgraph_execution_descriptor_signal_event(
1611 self.as_ptr(),
1612 event_handle,
1613 execution_stage,
1614 value,
1615 )
1616 };
1617 if ok {
1618 Ok(())
1619 } else {
1620 Err(Error::OperationFailed(
1621 "failed to register execution descriptor shared-event signal",
1622 ))
1623 }
1624 }
1625}
1626
1627impl ExecutableExecutionDescriptor {
1628 pub unsafe fn wait_for_shared_event_raw(
1632 &self,
1633 event_handle: *mut c_void,
1634 value: u64,
1635 ) -> Result<()> {
1636 let ok = unsafe {
1638 ffi::mpsgraph_executable_execution_descriptor_wait_for_event(
1639 self.as_ptr(),
1640 event_handle,
1641 value,
1642 )
1643 };
1644 if ok {
1645 Ok(())
1646 } else {
1647 Err(Error::OperationFailed(
1648 "failed to register executable execution descriptor shared-event wait",
1649 ))
1650 }
1651 }
1652
1653 pub unsafe fn signal_shared_event_raw(
1657 &self,
1658 event_handle: *mut c_void,
1659 execution_stage: u64,
1660 value: u64,
1661 ) -> Result<()> {
1662 let ok = unsafe {
1664 ffi::mpsgraph_executable_execution_descriptor_signal_event(
1665 self.as_ptr(),
1666 event_handle,
1667 execution_stage,
1668 value,
1669 )
1670 };
1671 if ok {
1672 Ok(())
1673 } else {
1674 Err(Error::OperationFailed(
1675 "failed to register executable execution descriptor shared-event signal",
1676 ))
1677 }
1678 }
1679}