1use crate::error::{Error, Result};
2use crate::execution::{ExecutableExecutionDescriptor, ExecutionDescriptor};
3use crate::ffi;
4use crate::graph::{
5 data_type, data_type_size, padding_mode, padding_style, tensor_named_data_layout,
6 Convolution2DDescriptor, Graph, Tensor,
7};
8use crate::types::{collect_owned_tensors, Operation, ShapedType};
9use core::ffi::{c_char, c_void};
10use core::ptr;
11use std::ffi::CString;
12
13fn release_handle(ptr: &mut *mut c_void) {
14 if !ptr.is_null() {
15 unsafe { ffi::mpsgraph_object_release(*ptr) };
17 *ptr = ptr::null_mut();
18 }
19}
20
21fn checked_byte_len(shape: &[usize], data_type: u32) -> Option<usize> {
22 let element_size = data_type_size(data_type)?;
23 shape
24 .iter()
25 .try_fold(element_size, |acc, dimension| acc.checked_mul(*dimension))
26}
27
/// Converts an optional Rust name into an owned C string.
///
/// A name containing an interior NUL byte cannot be represented as a C string and is
/// silently treated as absent (`None`).
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    name.map(CString::new)?.ok()
}
31
// Takes `&Option<CString>` (rather than `Option<&CString>`) so every call site can simply pass
// `&name`; the borrow also ties the returned pointer's validity to the caller's `CString`.
#[allow(clippy::ref_option)]
/// Returns the C string's pointer, or null when no name is present.
///
/// The pointer is only valid while the borrowed `CString` is alive.
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(text) => text.as_ptr(),
        None => ptr::null(),
    }
}
36
37fn wrap_tensor(ptr: *mut c_void) -> Option<Tensor> {
38 if ptr.is_null() {
39 None
40 } else {
41 Some(Tensor::from_raw(ptr))
42 }
43}
44
45fn wrap_operation(ptr: *mut c_void) -> Option<Operation> {
46 if ptr.is_null() {
47 None
48 } else {
49 Some(Operation::from_raw(ptr))
50 }
51}
52
53fn wrap_tensor_pair(box_handle: *mut c_void) -> Option<(Tensor, Tensor)> {
54 let mut values = collect_owned_tensors(box_handle);
55 if values.len() != 2 {
56 return None;
57 }
58 let second = values.pop()?;
59 let first = values.pop()?;
60 Some((first, second))
61}
62
/// Declares an owning wrapper type around a raw `*mut c_void` handle from the FFI shim.
///
/// The generated type exposes `as_ptr` and releases the handle (if non-null) on drop.
macro_rules! opaque_handle {
    ($name:ident) => {
        /// Owning wrapper around an opaque handle returned by the FFI shim.
        ///
        /// The handle is released through `mpsgraph_object_release` when the wrapper drops.
        pub struct $name {
            ptr: *mut c_void,
        }

        impl $name {
            /// Returns the raw handle without transferring ownership.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }

        impl Drop for $name {
            fn drop(&mut self) {
                release_handle(&mut self.ptr);
            }
        }

        // SAFETY: NOTE(review): this asserts that the underlying object may be retained,
        // released and used from any thread. Nothing here checks that; confirm it against
        // the threading guarantees of the wrapped framework objects.
        unsafe impl Send for $name {}
        unsafe impl Sync for $name {}
    };
}
86
/// Raw execution-stage values for the `execution_stage` argument of the
/// `signal_shared_event_raw` methods.
///
/// Presumably mirrors `MPSGraphExecutionStage`; the value is forwarded to the FFI unchanged.
pub mod execution_stage {
    /// Stage value `0`.
    pub const COMPLETED: u64 = 0;
}
91
/// Raw reduction-mode values, e.g. for `StencilDescriptorInfo::reduction_mode`.
///
/// Presumably mirrors `MPSGraphReductionMode`; values are forwarded to the FFI unchanged.
pub mod reduction_mode {
    pub const MIN: usize = 0;
    pub const MAX: usize = 1;
    pub const SUM: usize = 2;
    pub const PRODUCT: usize = 3;
    pub const ARGUMENT_MIN: usize = 4;
    pub const ARGUMENT_MAX: usize = 5;
}
101
/// Raw values for `Pooling4DDescriptorInfo::return_indices_mode`.
///
/// Presumably mirrors `MPSGraphPoolingReturnIndicesMode`; values are forwarded unchanged.
pub mod pooling_return_indices_mode {
    pub const NONE: usize = 0;

    // Global flattening variants.
    pub const GLOBAL_FLATTEN_1D: usize = 1;
    pub const GLOBAL_FLATTEN_2D: usize = 2;
    pub const GLOBAL_FLATTEN_3D: usize = 3;
    pub const GLOBAL_FLATTEN_4D: usize = 4;

    // Local flattening variants.
    pub const LOCAL_FLATTEN_1D: usize = 5;
    pub const LOCAL_FLATTEN_2D: usize = 6;
    pub const LOCAL_FLATTEN_3D: usize = 7;
    pub const LOCAL_FLATTEN_4D: usize = 8;
}
114
/// Raw values for `FftDescriptorInfo::scaling_mode`.
///
/// Presumably mirrors `MPSGraphFFTScalingMode`; values are forwarded to the FFI unchanged.
pub mod fft_scaling_mode {
    pub const NONE: usize = 0;
    pub const SIZE: usize = 1;
    pub const UNITARY: usize = 2;
}
121
/// Raw loss-reduction values, e.g. for the `reduction_type` argument of
/// `Graph::softmax_cross_entropy`.
///
/// Presumably mirrors `MPSGraphLossReductionType`; values are forwarded unchanged.
pub mod loss_reduction_type {
    pub const NONE: u64 = 0;
    /// Same raw value as [`NONE`].
    pub const AXIS: u64 = 0;
    pub const SUM: u64 = 1;
    pub const MEAN: u64 = 2;
}
129
/// Raw values for the `coordinate_mode` argument of `Graph::non_maximum_suppression`.
///
/// Presumably mirrors `MPSGraphNonMaximumSuppressionCoordinateMode`; forwarded unchanged.
pub mod non_maximum_suppression_coordinate_mode {
    pub const CORNERS_HEIGHT_FIRST: usize = 0;
    pub const CORNERS_WIDTH_FIRST: usize = 1;
    pub const CENTERS_HEIGHT_FIRST: usize = 2;
    pub const CENTERS_WIDTH_FIRST: usize = 3;
}
137
/// Raw values for the `mode` argument of `Graph::resize`.
///
/// Presumably mirrors `MPSGraphResizeMode`; values are forwarded to the FFI unchanged.
pub mod resize_mode {
    pub const NEAREST: usize = 0;
    pub const BILINEAR: usize = 1;
}
143
/// Raw values for the `nearest_rounding_mode` argument of `Graph::resize_nearest`.
///
/// Presumably mirrors `MPSGraphResizeNearestRoundingMode`; values are forwarded unchanged.
pub mod resize_nearest_rounding_mode {
    pub const ROUND_PREFER_CEIL: usize = 0;
    pub const ROUND_PREFER_FLOOR: usize = 1;
    pub const CEIL: usize = 2;
    pub const FLOOR: usize = 3;
    pub const ROUND_TO_EVEN: usize = 4;
    pub const ROUND_TO_ODD: usize = 5;
}
153
/// Raw values for the `mode` argument of the `Graph::scatter*` methods (signed, like the
/// `isize` parameters they feed).
///
/// Presumably mirrors `MPSGraphScatterMode`; values are forwarded to the FFI unchanged.
pub mod scatter_mode {
    pub const ADD: isize = 0;
    pub const SUB: isize = 1;
    pub const MUL: isize = 2;
    pub const DIV: isize = 3;
    pub const MIN: isize = 4;
    pub const MAX: isize = 5;
    pub const SET: isize = 6;
}
164
/// Raw values for the `storage_type` argument of `CreateSparseDescriptor::new`.
///
/// Presumably mirrors `MPSGraphSparseStorageType`; values are forwarded unchanged.
pub mod sparse_storage_type {
    pub const COO: u64 = 0;
    pub const CSC: u64 = 1;
    pub const CSR: u64 = 2;
}
171
172opaque_handle!(Object);
173impl Object {
174 fn retain_from(ptr: *mut c_void) -> Self {
175 let ptr = unsafe { ffi::mpsgraph_object_retain(ptr) };
177 Self { ptr }
178 }
179}
180
181opaque_handle!(GraphType);
182impl GraphType {
183 fn retain_from(ptr: *mut c_void) -> Self {
184 let ptr = unsafe { ffi::mpsgraph_object_retain(ptr) };
186 Self { ptr }
187 }
188
189 #[must_use]
190 pub fn as_object(&self) -> Object {
191 Object::retain_from(self.ptr)
192 }
193}
194
195opaque_handle!(VariableOp);
196impl VariableOp {
197 #[must_use]
198 pub fn shape(&self) -> Vec<isize> {
199 let len = unsafe { ffi::mpsgraph_variable_op_shape_len(self.ptr) };
201 let mut shape = vec![0_isize; len];
202 if len > 0 {
203 unsafe { ffi::mpsgraph_variable_op_copy_shape(self.ptr, shape.as_mut_ptr()) };
205 }
206 shape
207 }
208
209 #[must_use]
210 pub fn data_type(&self) -> u32 {
211 unsafe { ffi::mpsgraph_variable_op_data_type(self.ptr) }
213 }
214
215 #[must_use]
216 pub fn as_object(&self) -> Object {
217 Object::retain_from(self.ptr)
218 }
219
220 #[must_use]
221 pub fn as_operation(&self) -> Operation {
222 let ptr = unsafe { ffi::mpsgraph_object_retain(self.ptr) };
224 Operation::from_raw(ptr)
225 }
226}
227
228impl ShapedType {
229 #[must_use]
230 pub fn as_graph_type(&self) -> GraphType {
231 GraphType::retain_from(self.as_ptr())
232 }
233}
234
235impl Operation {
236 #[must_use]
237 pub fn as_variable(&self) -> Option<VariableOp> {
238 let ptr = unsafe { ffi::mpsgraph_operation_as_variable(self.as_ptr()) };
240 if ptr.is_null() {
241 None
242 } else {
243 Some(VariableOp { ptr })
244 }
245 }
246}
247
/// Plain-value configuration for `Convolution3DDescriptor::new`.
///
/// Every field is forwarded to the FFI shim unchanged. The type derives equality and hashing
/// so configurations can be compared or used as cache keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Convolution3DDescriptorInfo {
    pub stride_in_x: usize,
    pub stride_in_y: usize,
    pub stride_in_z: usize,
    pub dilation_rate_in_x: usize,
    pub dilation_rate_in_y: usize,
    pub dilation_rate_in_z: usize,
    /// Number of convolution groups.
    pub groups: usize,
    // Explicit padding amounts, one per side.
    pub padding_left: usize,
    pub padding_right: usize,
    pub padding_top: usize,
    pub padding_bottom: usize,
    pub padding_front: usize,
    pub padding_back: usize,
    /// Raw `padding_style` constant.
    pub padding_style: usize,
    /// Raw `tensor_named_data_layout` constant for the source tensor.
    pub data_layout: usize,
    /// Raw `tensor_named_data_layout` constant for the weights tensor.
    pub weights_layout: usize,
}
267
268impl Default for Convolution3DDescriptorInfo {
269 fn default() -> Self {
270 Self {
271 stride_in_x: 1,
272 stride_in_y: 1,
273 stride_in_z: 1,
274 dilation_rate_in_x: 1,
275 dilation_rate_in_y: 1,
276 dilation_rate_in_z: 1,
277 groups: 1,
278 padding_left: 0,
279 padding_right: 0,
280 padding_top: 0,
281 padding_bottom: 0,
282 padding_front: 0,
283 padding_back: 0,
284 padding_style: padding_style::EXPLICIT,
285 data_layout: tensor_named_data_layout::NDHWC,
286 weights_layout: tensor_named_data_layout::DHWIO,
287 }
288 }
289}
290
291opaque_handle!(Convolution3DDescriptor);
292impl Convolution3DDescriptor {
293 #[must_use]
294 pub fn new(info: Convolution3DDescriptorInfo) -> Option<Self> {
295 let ptr = unsafe {
297 ffi::mpsgraph_convolution3d_descriptor_new(
298 info.stride_in_x,
299 info.stride_in_y,
300 info.stride_in_z,
301 info.dilation_rate_in_x,
302 info.dilation_rate_in_y,
303 info.dilation_rate_in_z,
304 info.groups,
305 info.padding_left,
306 info.padding_right,
307 info.padding_top,
308 info.padding_bottom,
309 info.padding_front,
310 info.padding_back,
311 info.padding_style,
312 info.data_layout,
313 info.weights_layout,
314 )
315 };
316 if ptr.is_null() {
317 None
318 } else {
319 Some(Self { ptr })
320 }
321 }
322}
323
/// Plain-value configuration for `DepthwiseConvolution2DDescriptor::new`.
///
/// Every field is forwarded to the FFI shim unchanged. The type derives equality and hashing
/// so configurations can be compared or used as cache keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DepthwiseConvolution2DDescriptorInfo {
    pub stride_in_x: usize,
    pub stride_in_y: usize,
    pub dilation_rate_in_x: usize,
    pub dilation_rate_in_y: usize,
    // Explicit padding amounts, one per side.
    pub padding_left: usize,
    pub padding_right: usize,
    pub padding_top: usize,
    pub padding_bottom: usize,
    /// Raw `padding_style` constant.
    pub padding_style: usize,
    /// Raw `tensor_named_data_layout` constant for the source tensor.
    pub data_layout: usize,
    /// Raw `tensor_named_data_layout` constant for the weights tensor.
    pub weights_layout: usize,
}
338
339impl Default for DepthwiseConvolution2DDescriptorInfo {
340 fn default() -> Self {
341 Self {
342 stride_in_x: 1,
343 stride_in_y: 1,
344 dilation_rate_in_x: 1,
345 dilation_rate_in_y: 1,
346 padding_left: 0,
347 padding_right: 0,
348 padding_top: 0,
349 padding_bottom: 0,
350 padding_style: padding_style::EXPLICIT,
351 data_layout: tensor_named_data_layout::NHWC,
352 weights_layout: tensor_named_data_layout::HWIO,
353 }
354 }
355}
356
357opaque_handle!(DepthwiseConvolution2DDescriptor);
358impl DepthwiseConvolution2DDescriptor {
359 #[must_use]
360 pub fn new(info: DepthwiseConvolution2DDescriptorInfo) -> Option<Self> {
361 let ptr = unsafe {
363 ffi::mpsgraph_depthwise_convolution2d_descriptor_new(
364 info.stride_in_x,
365 info.stride_in_y,
366 info.dilation_rate_in_x,
367 info.dilation_rate_in_y,
368 info.padding_left,
369 info.padding_right,
370 info.padding_top,
371 info.padding_bottom,
372 info.padding_style,
373 info.data_layout,
374 info.weights_layout,
375 )
376 };
377 if ptr.is_null() {
378 None
379 } else {
380 Some(Self { ptr })
381 }
382 }
383}
384
/// Plain-value configuration for `DepthwiseConvolution3DDescriptor::new`.
///
/// The arrays are passed to the FFI shim as pointer/length pairs. The type derives equality
/// and hashing so configurations can be compared or used as cache keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DepthwiseConvolution3DDescriptorInfo {
    /// Three stride values (dimension order as expected by the shim — TODO confirm).
    pub strides: [usize; 3],
    /// Three dilation rates, in the same order as `strides`.
    pub dilation_rates: [usize; 3],
    /// Six explicit padding amounts (presumably before/after per dimension — TODO confirm).
    pub padding_values: [usize; 6],
    /// Raw `padding_style` constant.
    pub padding_style: usize,
    /// Channel dimension index; signed (the default is `-1`).
    pub channel_dimension_index: isize,
}
393
394impl Default for DepthwiseConvolution3DDescriptorInfo {
395 fn default() -> Self {
396 Self {
397 strides: [1, 1, 1],
398 dilation_rates: [1, 1, 1],
399 padding_values: [0, 0, 0, 0, 0, 0],
400 padding_style: padding_style::EXPLICIT,
401 channel_dimension_index: -1,
402 }
403 }
404}
405
406opaque_handle!(DepthwiseConvolution3DDescriptor);
407impl DepthwiseConvolution3DDescriptor {
408 #[must_use]
409 pub fn new(info: DepthwiseConvolution3DDescriptorInfo) -> Option<Self> {
410 let ptr = unsafe {
412 ffi::mpsgraph_depthwise_convolution3d_descriptor_new(
413 info.strides.as_ptr(),
414 info.strides.len(),
415 info.dilation_rates.as_ptr(),
416 info.dilation_rates.len(),
417 info.padding_values.as_ptr(),
418 info.padding_values.len(),
419 info.padding_style,
420 info.channel_dimension_index,
421 )
422 };
423 if ptr.is_null() {
424 None
425 } else {
426 Some(Self { ptr })
427 }
428 }
429}
430
/// Plain-value configuration for `FftDescriptor::new`.
///
/// The type derives equality and hashing so configurations can be compared or used as keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FftDescriptorInfo {
    /// Whether to request the inverse transform.
    pub inverse: bool,
    /// Raw `fft_scaling_mode` constant.
    pub scaling_mode: usize,
    /// Forwarded unchanged to the shim's `round_to_odd_hermitean` flag.
    pub round_to_odd_hermitean: bool,
}
437
438impl Default for FftDescriptorInfo {
439 fn default() -> Self {
440 Self {
441 inverse: false,
442 scaling_mode: fft_scaling_mode::NONE,
443 round_to_odd_hermitean: false,
444 }
445 }
446}
447
448opaque_handle!(FftDescriptor);
449impl FftDescriptor {
450 #[must_use]
451 pub fn new(info: FftDescriptorInfo) -> Option<Self> {
452 let ptr = unsafe {
454 ffi::mpsgraph_fft_descriptor_new(
455 info.inverse,
456 info.scaling_mode,
457 info.round_to_odd_hermitean,
458 )
459 };
460 if ptr.is_null() {
461 None
462 } else {
463 Some(Self { ptr })
464 }
465 }
466}
467
/// Plain-value configuration for `ImToColDescriptor::new`.
///
/// Every field is forwarded to the FFI shim unchanged. The type derives equality and hashing
/// so configurations can be compared or used as cache keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ImToColDescriptorInfo {
    pub kernel_width: usize,
    pub kernel_height: usize,
    pub stride_in_x: usize,
    pub stride_in_y: usize,
    pub dilation_rate_in_x: usize,
    pub dilation_rate_in_y: usize,
    // Explicit padding amounts, one per side.
    pub padding_left: usize,
    pub padding_right: usize,
    pub padding_top: usize,
    pub padding_bottom: usize,
    /// Raw `tensor_named_data_layout` constant.
    pub data_layout: usize,
}
482
483impl Default for ImToColDescriptorInfo {
484 fn default() -> Self {
485 Self {
486 kernel_width: 1,
487 kernel_height: 1,
488 stride_in_x: 1,
489 stride_in_y: 1,
490 dilation_rate_in_x: 1,
491 dilation_rate_in_y: 1,
492 padding_left: 0,
493 padding_right: 0,
494 padding_top: 0,
495 padding_bottom: 0,
496 data_layout: tensor_named_data_layout::NHWC,
497 }
498 }
499}
500
501opaque_handle!(ImToColDescriptor);
502impl ImToColDescriptor {
503 #[must_use]
504 pub fn new(info: ImToColDescriptorInfo) -> Option<Self> {
505 let ptr = unsafe {
507 ffi::mpsgraph_im_to_col_descriptor_new(
508 info.kernel_width,
509 info.kernel_height,
510 info.stride_in_x,
511 info.stride_in_y,
512 info.dilation_rate_in_x,
513 info.dilation_rate_in_y,
514 info.padding_left,
515 info.padding_right,
516 info.padding_top,
517 info.padding_bottom,
518 info.data_layout,
519 )
520 };
521 if ptr.is_null() {
522 None
523 } else {
524 Some(Self { ptr })
525 }
526 }
527}
528
/// Plain-value configuration for `Pooling4DDescriptor::new`.
///
/// The arrays are passed to the FFI shim as pointer/length pairs. The type derives equality
/// and hashing so configurations can be compared or used as cache keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Pooling4DDescriptorInfo {
    /// Four kernel sizes (dimension order as expected by the shim — TODO confirm).
    pub kernel_sizes: [usize; 4],
    /// Four strides, in the same order as `kernel_sizes`.
    pub strides: [usize; 4],
    /// Four dilation rates, in the same order as `kernel_sizes`.
    pub dilation_rates: [usize; 4],
    /// Eight explicit padding amounts (presumably before/after per dimension — TODO confirm).
    pub padding_values: [usize; 8],
    /// Raw `padding_style` constant.
    pub padding_style: usize,
    pub ceil_mode: bool,
    pub include_zero_pad_to_average: bool,
    /// Raw `pooling_return_indices_mode` constant.
    pub return_indices_mode: usize,
    /// Raw `data_type` code for the returned indices.
    pub return_indices_data_type: u32,
}
541
542impl Default for Pooling4DDescriptorInfo {
543 fn default() -> Self {
544 Self {
545 kernel_sizes: [1, 1, 1, 1],
546 strides: [1, 1, 1, 1],
547 dilation_rates: [1, 1, 1, 1],
548 padding_values: [0, 0, 0, 0, 0, 0, 0, 0],
549 padding_style: padding_style::EXPLICIT,
550 ceil_mode: false,
551 include_zero_pad_to_average: false,
552 return_indices_mode: pooling_return_indices_mode::NONE,
553 return_indices_data_type: data_type::INT32,
554 }
555 }
556}
557
558opaque_handle!(Pooling4DDescriptor);
559impl Pooling4DDescriptor {
560 #[must_use]
561 pub fn new(info: Pooling4DDescriptorInfo) -> Option<Self> {
562 let ptr = unsafe {
564 ffi::mpsgraph_pooling4d_descriptor_new(
565 info.kernel_sizes.as_ptr(),
566 info.kernel_sizes.len(),
567 info.strides.as_ptr(),
568 info.strides.len(),
569 info.dilation_rates.as_ptr(),
570 info.dilation_rates.len(),
571 info.padding_values.as_ptr(),
572 info.padding_values.len(),
573 info.padding_style,
574 info.ceil_mode,
575 info.include_zero_pad_to_average,
576 info.return_indices_mode,
577 info.return_indices_data_type,
578 )
579 };
580 if ptr.is_null() {
581 None
582 } else {
583 Some(Self { ptr })
584 }
585 }
586}
587
588opaque_handle!(CreateSparseDescriptor);
589impl CreateSparseDescriptor {
590 #[must_use]
591 pub fn new(storage_type: u64, data_type: u32) -> Option<Self> {
592 let ptr = unsafe { ffi::mpsgraph_sparse_descriptor_new(storage_type, data_type) };
594 if ptr.is_null() {
595 None
596 } else {
597 Some(Self { ptr })
598 }
599 }
600}
601
/// Plain-value configuration for `StencilDescriptor::new`.
///
/// The arrays are passed to the FFI shim as pointer/length pairs. Only `PartialEq` is derived
/// (no `Eq`/`Hash`) because `padding_constant` is an `f32`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct StencilDescriptorInfo {
    /// Raw `reduction_mode` constant.
    pub reduction_mode: usize,
    /// Four signed offsets.
    pub offsets: [isize; 4],
    pub strides: [usize; 4],
    pub dilation_rates: [usize; 4],
    /// Eight explicit padding amounts (presumably before/after per dimension — TODO confirm).
    pub explicit_padding: [usize; 8],
    /// Raw `padding_mode` constant.
    pub boundary_mode: isize,
    /// Raw `padding_style` constant.
    pub padding_style: usize,
    /// Value forwarded as the shim's padding constant.
    pub padding_constant: f32,
}
613
614impl Default for StencilDescriptorInfo {
615 fn default() -> Self {
616 Self {
617 reduction_mode: reduction_mode::SUM,
618 offsets: [0, 0, 0, 0],
619 strides: [1, 1, 1, 1],
620 dilation_rates: [1, 1, 1, 1],
621 explicit_padding: [0, 0, 0, 0, 0, 0, 0, 0],
622 boundary_mode: padding_mode::ZERO,
623 padding_style: padding_style::EXPLICIT,
624 padding_constant: 0.0,
625 }
626 }
627}
628
629opaque_handle!(StencilDescriptor);
630impl StencilDescriptor {
631 #[must_use]
632 pub fn new(info: StencilDescriptorInfo) -> Option<Self> {
633 let ptr = unsafe {
635 ffi::mpsgraph_stencil_descriptor_new(
636 info.reduction_mode,
637 info.offsets.as_ptr(),
638 info.offsets.len(),
639 info.strides.as_ptr(),
640 info.strides.len(),
641 info.dilation_rates.as_ptr(),
642 info.dilation_rates.len(),
643 info.explicit_padding.as_ptr(),
644 info.explicit_padding.len(),
645 info.boundary_mode,
646 info.padding_style,
647 info.padding_constant,
648 )
649 };
650 if ptr.is_null() {
651 None
652 } else {
653 Some(Self { ptr })
654 }
655 }
656}
657
658impl Graph {
659 #[must_use]
660 pub fn convolution3d(
661 &self,
662 source: &Tensor,
663 weights: &Tensor,
664 descriptor: &Convolution3DDescriptor,
665 name: Option<&str>,
666 ) -> Option<Tensor> {
667 let name = optional_cstring(name);
668 let ptr = unsafe {
670 ffi::mpsgraph_graph_convolution3d(
671 self.as_ptr(),
672 source.as_ptr(),
673 weights.as_ptr(),
674 descriptor.as_ptr(),
675 cstring_ptr(&name),
676 )
677 };
678 wrap_tensor(ptr)
679 }
680
681 #[must_use]
682 pub fn convolution_transpose2d(
683 &self,
684 source: &Tensor,
685 weights: &Tensor,
686 output_shape: &[usize],
687 descriptor: &Convolution2DDescriptor,
688 name: Option<&str>,
689 ) -> Option<Tensor> {
690 let name = optional_cstring(name);
691 let ptr = unsafe {
693 ffi::mpsgraph_graph_convolution_transpose2d(
694 self.as_ptr(),
695 source.as_ptr(),
696 weights.as_ptr(),
697 output_shape.as_ptr(),
698 output_shape.len(),
699 descriptor.as_ptr(),
700 cstring_ptr(&name),
701 )
702 };
703 wrap_tensor(ptr)
704 }
705
706 #[must_use]
707 pub fn cumulative_sum(
708 &self,
709 tensor: &Tensor,
710 axis: isize,
711 exclusive: bool,
712 reverse: bool,
713 name: Option<&str>,
714 ) -> Option<Tensor> {
715 let name = optional_cstring(name);
716 let ptr = unsafe {
718 ffi::mpsgraph_graph_cumulative_sum(
719 self.as_ptr(),
720 tensor.as_ptr(),
721 axis,
722 exclusive,
723 reverse,
724 cstring_ptr(&name),
725 )
726 };
727 wrap_tensor(ptr)
728 }
729
730 #[must_use]
731 pub fn depthwise_convolution2d(
732 &self,
733 source: &Tensor,
734 weights: &Tensor,
735 descriptor: &DepthwiseConvolution2DDescriptor,
736 name: Option<&str>,
737 ) -> Option<Tensor> {
738 let name = optional_cstring(name);
739 let ptr = unsafe {
741 ffi::mpsgraph_graph_depthwise_convolution2d(
742 self.as_ptr(),
743 source.as_ptr(),
744 weights.as_ptr(),
745 descriptor.as_ptr(),
746 cstring_ptr(&name),
747 )
748 };
749 wrap_tensor(ptr)
750 }
751
752 #[must_use]
753 pub fn depthwise_convolution3d(
754 &self,
755 source: &Tensor,
756 weights: &Tensor,
757 descriptor: &DepthwiseConvolution3DDescriptor,
758 name: Option<&str>,
759 ) -> Option<Tensor> {
760 let name = optional_cstring(name);
761 let ptr = unsafe {
763 ffi::mpsgraph_graph_depthwise_convolution3d(
764 self.as_ptr(),
765 source.as_ptr(),
766 weights.as_ptr(),
767 descriptor.as_ptr(),
768 cstring_ptr(&name),
769 )
770 };
771 wrap_tensor(ptr)
772 }
773
774 #[must_use]
775 pub fn fast_fourier_transform(
776 &self,
777 tensor: &Tensor,
778 axes: &[usize],
779 descriptor: &FftDescriptor,
780 name: Option<&str>,
781 ) -> Option<Tensor> {
782 let name = optional_cstring(name);
783 let ptr = unsafe {
785 ffi::mpsgraph_graph_fast_fourier_transform(
786 self.as_ptr(),
787 tensor.as_ptr(),
788 axes.as_ptr(),
789 axes.len(),
790 descriptor.as_ptr(),
791 cstring_ptr(&name),
792 )
793 };
794 wrap_tensor(ptr)
795 }
796
797 #[must_use]
798 pub fn im_to_col(
799 &self,
800 source: &Tensor,
801 descriptor: &ImToColDescriptor,
802 name: Option<&str>,
803 ) -> Option<Tensor> {
804 let name = optional_cstring(name);
805 let ptr = unsafe {
807 ffi::mpsgraph_graph_im_to_col(
808 self.as_ptr(),
809 source.as_ptr(),
810 descriptor.as_ptr(),
811 cstring_ptr(&name),
812 )
813 };
814 wrap_tensor(ptr)
815 }
816
817 #[must_use]
818 pub fn band_part(
819 &self,
820 tensor: &Tensor,
821 num_lower: isize,
822 num_upper: isize,
823 name: Option<&str>,
824 ) -> Option<Tensor> {
825 let name = optional_cstring(name);
826 let ptr = unsafe {
828 ffi::mpsgraph_graph_band_part(
829 self.as_ptr(),
830 tensor.as_ptr(),
831 num_lower,
832 num_upper,
833 cstring_ptr(&name),
834 )
835 };
836 wrap_tensor(ptr)
837 }
838
839 #[must_use]
840 pub fn softmax_cross_entropy(
841 &self,
842 source: &Tensor,
843 labels: &Tensor,
844 axis: isize,
845 reduction_type: u64,
846 name: Option<&str>,
847 ) -> Option<Tensor> {
848 let name = optional_cstring(name);
849 let ptr = unsafe {
851 ffi::mpsgraph_graph_softmax_cross_entropy(
852 self.as_ptr(),
853 source.as_ptr(),
854 labels.as_ptr(),
855 axis,
856 reduction_type,
857 cstring_ptr(&name),
858 )
859 };
860 wrap_tensor(ptr)
861 }
862
863 #[must_use]
864 pub fn matrix_inverse(&self, tensor: &Tensor, name: Option<&str>) -> Option<Tensor> {
865 let name = optional_cstring(name);
866 let ptr = unsafe {
868 ffi::mpsgraph_graph_matrix_inverse(self.as_ptr(), tensor.as_ptr(), cstring_ptr(&name))
869 };
870 wrap_tensor(ptr)
871 }
872
873 #[must_use]
874 pub fn variable_bytes(
875 &self,
876 data: &[u8],
877 shape: &[usize],
878 data_type: u32,
879 name: Option<&str>,
880 ) -> Option<Tensor> {
881 let expected = checked_byte_len(shape, data_type)?;
882 if data.len() != expected {
883 return None;
884 }
885
886 let name = optional_cstring(name);
887 let ptr = unsafe {
889 ffi::mpsgraph_graph_variable_data(
890 self.as_ptr(),
891 data.as_ptr().cast(),
892 data.len(),
893 shape.as_ptr(),
894 shape.len(),
895 data_type,
896 cstring_ptr(&name),
897 )
898 };
899 wrap_tensor(ptr)
900 }
901
902 #[must_use]
903 pub fn variable_f32_slice(
904 &self,
905 values: &[f32],
906 shape: &[usize],
907 name: Option<&str>,
908 ) -> Option<Tensor> {
909 let bytes = unsafe {
911 core::slice::from_raw_parts(
912 values.as_ptr().cast::<u8>(),
913 core::mem::size_of_val(values),
914 )
915 };
916 self.variable_bytes(bytes, shape, data_type::FLOAT32, name)
917 }
918
919 #[must_use]
920 pub fn read_variable(&self, variable: &Tensor, name: Option<&str>) -> Option<Tensor> {
921 let name = optional_cstring(name);
922 let ptr = unsafe {
924 ffi::mpsgraph_graph_read_variable(self.as_ptr(), variable.as_ptr(), cstring_ptr(&name))
925 };
926 wrap_tensor(ptr)
927 }
928
929 #[must_use]
930 pub fn assign_variable(
931 &self,
932 variable: &Tensor,
933 value: &Tensor,
934 name: Option<&str>,
935 ) -> Option<Operation> {
936 let name = optional_cstring(name);
937 let ptr = unsafe {
939 ffi::mpsgraph_graph_assign_variable(
940 self.as_ptr(),
941 variable.as_ptr(),
942 value.as_ptr(),
943 cstring_ptr(&name),
944 )
945 };
946 wrap_operation(ptr)
947 }
948
949 #[must_use]
950 #[allow(clippy::too_many_arguments)]
951 pub fn non_maximum_suppression(
952 &self,
953 boxes: &Tensor,
954 scores: &Tensor,
955 iou_threshold: f32,
956 score_threshold: f32,
957 per_class_suppression: bool,
958 coordinate_mode: usize,
959 name: Option<&str>,
960 ) -> Option<Tensor> {
961 let name = optional_cstring(name);
962 let ptr = unsafe {
964 ffi::mpsgraph_graph_non_maximum_suppression(
965 self.as_ptr(),
966 boxes.as_ptr(),
967 scores.as_ptr(),
968 iou_threshold,
969 score_threshold,
970 per_class_suppression,
971 coordinate_mode,
972 cstring_ptr(&name),
973 )
974 };
975 wrap_tensor(ptr)
976 }
977
978 #[must_use]
979 pub fn non_zero_indices(&self, tensor: &Tensor, name: Option<&str>) -> Option<Tensor> {
980 let name = optional_cstring(name);
981 let ptr = unsafe {
983 ffi::mpsgraph_graph_non_zero_indices(self.as_ptr(), tensor.as_ptr(), cstring_ptr(&name))
984 };
985 wrap_tensor(ptr)
986 }
987
988 #[must_use]
989 pub fn one_hot(
990 &self,
991 indices: &Tensor,
992 depth: usize,
993 data_type: u32,
994 name: Option<&str>,
995 ) -> Option<Tensor> {
996 let name = optional_cstring(name);
997 let ptr = unsafe {
999 ffi::mpsgraph_graph_one_hot(
1000 self.as_ptr(),
1001 indices.as_ptr(),
1002 depth,
1003 data_type,
1004 cstring_ptr(&name),
1005 )
1006 };
1007 wrap_tensor(ptr)
1008 }
1009
1010 #[must_use]
1011 pub fn stochastic_gradient_descent(
1012 &self,
1013 learning_rate: &Tensor,
1014 values: &Tensor,
1015 gradient: &Tensor,
1016 name: Option<&str>,
1017 ) -> Option<Tensor> {
1018 let name = optional_cstring(name);
1019 let ptr = unsafe {
1021 ffi::mpsgraph_graph_stochastic_gradient_descent(
1022 self.as_ptr(),
1023 learning_rate.as_ptr(),
1024 values.as_ptr(),
1025 gradient.as_ptr(),
1026 cstring_ptr(&name),
1027 )
1028 };
1029 wrap_tensor(ptr)
1030 }
1031
1032 #[must_use]
1033 pub fn max_pooling4d(
1034 &self,
1035 source: &Tensor,
1036 descriptor: &Pooling4DDescriptor,
1037 name: Option<&str>,
1038 ) -> Option<Tensor> {
1039 let name = optional_cstring(name);
1040 let ptr = unsafe {
1042 ffi::mpsgraph_graph_max_pooling4d(
1043 self.as_ptr(),
1044 source.as_ptr(),
1045 descriptor.as_ptr(),
1046 cstring_ptr(&name),
1047 )
1048 };
1049 wrap_tensor(ptr)
1050 }
1051
1052 #[must_use]
1053 pub fn max_pooling4d_return_indices(
1054 &self,
1055 source: &Tensor,
1056 descriptor: &Pooling4DDescriptor,
1057 name: Option<&str>,
1058 ) -> Option<(Tensor, Tensor)> {
1059 let name = optional_cstring(name);
1060 let box_handle = unsafe {
1062 ffi::mpsgraph_graph_max_pooling4d_return_indices(
1063 self.as_ptr(),
1064 source.as_ptr(),
1065 descriptor.as_ptr(),
1066 cstring_ptr(&name),
1067 )
1068 };
1069 wrap_tensor_pair(box_handle)
1070 }
1071
1072 #[must_use]
1073 pub fn quantize(
1074 &self,
1075 tensor: &Tensor,
1076 scale: f64,
1077 zero_point: f64,
1078 data_type: u32,
1079 name: Option<&str>,
1080 ) -> Option<Tensor> {
1081 let name = optional_cstring(name);
1082 let ptr = unsafe {
1084 ffi::mpsgraph_graph_quantize(
1085 self.as_ptr(),
1086 tensor.as_ptr(),
1087 scale,
1088 zero_point,
1089 data_type,
1090 cstring_ptr(&name),
1091 )
1092 };
1093 wrap_tensor(ptr)
1094 }
1095
1096 #[must_use]
1097 pub fn dequantize(
1098 &self,
1099 tensor: &Tensor,
1100 scale: f64,
1101 zero_point: f64,
1102 data_type: u32,
1103 name: Option<&str>,
1104 ) -> Option<Tensor> {
1105 let name = optional_cstring(name);
1106 let ptr = unsafe {
1108 ffi::mpsgraph_graph_dequantize(
1109 self.as_ptr(),
1110 tensor.as_ptr(),
1111 scale,
1112 zero_point,
1113 data_type,
1114 cstring_ptr(&name),
1115 )
1116 };
1117 wrap_tensor(ptr)
1118 }
1119
1120 #[must_use]
1121 #[allow(clippy::too_many_arguments)]
1122 pub fn resize(
1123 &self,
1124 images: &Tensor,
1125 size: &[usize],
1126 mode: usize,
1127 center_result: bool,
1128 align_corners: bool,
1129 layout: usize,
1130 name: Option<&str>,
1131 ) -> Option<Tensor> {
1132 let name = optional_cstring(name);
1133 let ptr = unsafe {
1135 ffi::mpsgraph_graph_resize(
1136 self.as_ptr(),
1137 images.as_ptr(),
1138 size.as_ptr(),
1139 size.len(),
1140 mode,
1141 center_result,
1142 align_corners,
1143 layout,
1144 cstring_ptr(&name),
1145 )
1146 };
1147 wrap_tensor(ptr)
1148 }
1149
1150 #[must_use]
1151 #[allow(clippy::too_many_arguments)]
1152 pub fn resize_nearest(
1153 &self,
1154 images: &Tensor,
1155 size_tensor: &Tensor,
1156 nearest_rounding_mode: usize,
1157 center_result: bool,
1158 align_corners: bool,
1159 layout: usize,
1160 name: Option<&str>,
1161 ) -> Option<Tensor> {
1162 let name = optional_cstring(name);
1163 let ptr = unsafe {
1165 ffi::mpsgraph_graph_resize_nearest(
1166 self.as_ptr(),
1167 images.as_ptr(),
1168 size_tensor.as_ptr(),
1169 nearest_rounding_mode,
1170 center_result,
1171 align_corners,
1172 layout,
1173 cstring_ptr(&name),
1174 )
1175 };
1176 wrap_tensor(ptr)
1177 }
1178
1179 #[must_use]
1180 #[allow(clippy::too_many_arguments)]
1181 pub fn sample_grid(
1182 &self,
1183 source: &Tensor,
1184 coordinates: &Tensor,
1185 layout: usize,
1186 normalize_coordinates: bool,
1187 relative_coordinates: bool,
1188 align_corners: bool,
1189 padding_mode: isize,
1190 sampling_mode: usize,
1191 constant_value: f64,
1192 name: Option<&str>,
1193 ) -> Option<Tensor> {
1194 let name = optional_cstring(name);
1195 let ptr = unsafe {
1197 ffi::mpsgraph_graph_sample_grid(
1198 self.as_ptr(),
1199 source.as_ptr(),
1200 coordinates.as_ptr(),
1201 layout,
1202 normalize_coordinates,
1203 relative_coordinates,
1204 align_corners,
1205 padding_mode,
1206 sampling_mode,
1207 constant_value,
1208 cstring_ptr(&name),
1209 )
1210 };
1211 wrap_tensor(ptr)
1212 }
1213
1214 #[must_use]
1215 pub fn scatter_nd(
1216 &self,
1217 updates: &Tensor,
1218 indices: &Tensor,
1219 shape: &[usize],
1220 batch_dimensions: usize,
1221 mode: isize,
1222 name: Option<&str>,
1223 ) -> Option<Tensor> {
1224 let name = optional_cstring(name);
1225 let ptr = unsafe {
1227 ffi::mpsgraph_graph_scatter_nd(
1228 self.as_ptr(),
1229 updates.as_ptr(),
1230 indices.as_ptr(),
1231 shape.as_ptr(),
1232 shape.len(),
1233 batch_dimensions,
1234 mode,
1235 cstring_ptr(&name),
1236 )
1237 };
1238 wrap_tensor(ptr)
1239 }
1240
1241 #[must_use]
1242 pub fn scatter(
1243 &self,
1244 updates: &Tensor,
1245 indices: &Tensor,
1246 shape: &[usize],
1247 axis: isize,
1248 mode: isize,
1249 name: Option<&str>,
1250 ) -> Option<Tensor> {
1251 let name = optional_cstring(name);
1252 let ptr = unsafe {
1254 ffi::mpsgraph_graph_scatter(
1255 self.as_ptr(),
1256 updates.as_ptr(),
1257 indices.as_ptr(),
1258 shape.as_ptr(),
1259 shape.len(),
1260 axis,
1261 mode,
1262 cstring_ptr(&name),
1263 )
1264 };
1265 wrap_tensor(ptr)
1266 }
1267
1268 #[must_use]
1269 pub fn scatter_along_axis(
1270 &self,
1271 axis: isize,
1272 updates: &Tensor,
1273 indices: &Tensor,
1274 shape: &[usize],
1275 mode: isize,
1276 name: Option<&str>,
1277 ) -> Option<Tensor> {
1278 let name = optional_cstring(name);
1279 let ptr = unsafe {
1281 ffi::mpsgraph_graph_scatter_along_axis(
1282 self.as_ptr(),
1283 axis,
1284 updates.as_ptr(),
1285 indices.as_ptr(),
1286 shape.as_ptr(),
1287 shape.len(),
1288 mode,
1289 cstring_ptr(&name),
1290 )
1291 };
1292 wrap_tensor(ptr)
1293 }
1294
1295 #[must_use]
1296 pub fn sort(
1297 &self,
1298 tensor: &Tensor,
1299 axis: isize,
1300 descending: bool,
1301 name: Option<&str>,
1302 ) -> Option<Tensor> {
1303 let name = optional_cstring(name);
1304 let ptr = unsafe {
1306 ffi::mpsgraph_graph_sort(
1307 self.as_ptr(),
1308 tensor.as_ptr(),
1309 axis,
1310 descending,
1311 cstring_ptr(&name),
1312 )
1313 };
1314 wrap_tensor(ptr)
1315 }
1316
1317 #[must_use]
1318 pub fn arg_sort(
1319 &self,
1320 tensor: &Tensor,
1321 axis: isize,
1322 descending: bool,
1323 name: Option<&str>,
1324 ) -> Option<Tensor> {
1325 let name = optional_cstring(name);
1326 let ptr = unsafe {
1328 ffi::mpsgraph_graph_arg_sort(
1329 self.as_ptr(),
1330 tensor.as_ptr(),
1331 axis,
1332 descending,
1333 cstring_ptr(&name),
1334 )
1335 };
1336 wrap_tensor(ptr)
1337 }
1338
1339 #[must_use]
1340 pub fn sparse_tensor_with_descriptor(
1341 &self,
1342 descriptor: &CreateSparseDescriptor,
1343 tensors: &[&Tensor],
1344 shape: &[usize],
1345 name: Option<&str>,
1346 ) -> Option<Tensor> {
1347 let name = optional_cstring(name);
1348 let handles = tensors
1349 .iter()
1350 .map(|tensor| tensor.as_ptr())
1351 .collect::<Vec<_>>();
1352 let ptr = unsafe {
1354 ffi::mpsgraph_graph_sparse_tensor_with_descriptor(
1355 self.as_ptr(),
1356 descriptor.as_ptr(),
1357 handles.as_ptr(),
1358 handles.len(),
1359 shape.as_ptr(),
1360 shape.len(),
1361 cstring_ptr(&name),
1362 )
1363 };
1364 wrap_tensor(ptr)
1365 }
1366
1367 #[must_use]
1368 pub fn stencil(
1369 &self,
1370 source: &Tensor,
1371 weights: &Tensor,
1372 descriptor: &StencilDescriptor,
1373 name: Option<&str>,
1374 ) -> Option<Tensor> {
1375 let name = optional_cstring(name);
1376 let ptr = unsafe {
1378 ffi::mpsgraph_graph_stencil(
1379 self.as_ptr(),
1380 source.as_ptr(),
1381 weights.as_ptr(),
1382 descriptor.as_ptr(),
1383 cstring_ptr(&name),
1384 )
1385 };
1386 wrap_tensor(ptr)
1387 }
1388
1389 #[must_use]
1390 pub fn top_k_gradient(
1391 &self,
1392 gradient: &Tensor,
1393 source: &Tensor,
1394 k: usize,
1395 name: Option<&str>,
1396 ) -> Option<Tensor> {
1397 let name = optional_cstring(name);
1398 let ptr = unsafe {
1400 ffi::mpsgraph_graph_topk_gradient(
1401 self.as_ptr(),
1402 gradient.as_ptr(),
1403 source.as_ptr(),
1404 k,
1405 cstring_ptr(&name),
1406 )
1407 };
1408 wrap_tensor(ptr)
1409 }
1410}
1411
1412impl ExecutionDescriptor {
1413 pub unsafe fn wait_for_shared_event_raw(
1417 &self,
1418 event_handle: *mut c_void,
1419 value: u64,
1420 ) -> Result<()> {
1421 let ok = unsafe {
1423 ffi::mpsgraph_execution_descriptor_wait_for_event(self.as_ptr(), event_handle, value)
1424 };
1425 if ok {
1426 Ok(())
1427 } else {
1428 Err(Error::OperationFailed(
1429 "failed to register execution descriptor shared-event wait",
1430 ))
1431 }
1432 }
1433
1434 pub unsafe fn signal_shared_event_raw(
1438 &self,
1439 event_handle: *mut c_void,
1440 execution_stage: u64,
1441 value: u64,
1442 ) -> Result<()> {
1443 let ok = unsafe {
1445 ffi::mpsgraph_execution_descriptor_signal_event(
1446 self.as_ptr(),
1447 event_handle,
1448 execution_stage,
1449 value,
1450 )
1451 };
1452 if ok {
1453 Ok(())
1454 } else {
1455 Err(Error::OperationFailed(
1456 "failed to register execution descriptor shared-event signal",
1457 ))
1458 }
1459 }
1460}
1461
1462impl ExecutableExecutionDescriptor {
1463 pub unsafe fn wait_for_shared_event_raw(
1467 &self,
1468 event_handle: *mut c_void,
1469 value: u64,
1470 ) -> Result<()> {
1471 let ok = unsafe {
1473 ffi::mpsgraph_executable_execution_descriptor_wait_for_event(
1474 self.as_ptr(),
1475 event_handle,
1476 value,
1477 )
1478 };
1479 if ok {
1480 Ok(())
1481 } else {
1482 Err(Error::OperationFailed(
1483 "failed to register executable execution descriptor shared-event wait",
1484 ))
1485 }
1486 }
1487
1488 pub unsafe fn signal_shared_event_raw(
1492 &self,
1493 event_handle: *mut c_void,
1494 execution_stage: u64,
1495 value: u64,
1496 ) -> Result<()> {
1497 let ok = unsafe {
1499 ffi::mpsgraph_executable_execution_descriptor_signal_event(
1500 self.as_ptr(),
1501 event_handle,
1502 execution_stage,
1503 value,
1504 )
1505 };
1506 if ok {
1507 Ok(())
1508 } else {
1509 Err(Error::OperationFailed(
1510 "failed to register executable execution descriptor shared-event signal",
1511 ))
1512 }
1513 }
1514}