Skip to main content

apple_mpsgraph/
specialized.rs

1use crate::error::{Error, Result};
2use crate::execution::{ExecutableExecutionDescriptor, ExecutionDescriptor};
3use crate::ffi;
4use crate::graph::{
5    data_type, data_type_size, padding_mode, padding_style, tensor_named_data_layout,
6    Convolution2DDescriptor, Graph, Tensor,
7};
8use crate::types::{collect_owned_tensors, Operation, ShapedType};
9use core::ffi::{c_char, c_void};
10use core::ptr;
11use std::ffi::CString;
12
13fn release_handle(ptr: &mut *mut c_void) {
14    if !ptr.is_null() {
15        // SAFETY: `ptr` is a +1 retained Swift/ObjC object pointer owned by this wrapper.
16        unsafe { ffi::mpsgraph_object_release(*ptr) };
17        *ptr = ptr::null_mut();
18    }
19}
20
21fn checked_byte_len(shape: &[usize], data_type: u32) -> Option<usize> {
22    let element_size = data_type_size(data_type)?;
23    shape
24        .iter()
25        .try_fold(element_size, |acc, dimension| acc.checked_mul(*dimension))
26}
27
/// Converts an optional name into an owned, NUL-terminated C string.
///
/// A name that contains an interior NUL byte cannot be represented and is silently mapped to
/// `None`.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    match name {
        Some(text) => CString::new(text).ok(),
        None => None,
    }
}
31
/// Borrows a C pointer to the string inside `value`, or null when there is none.
///
/// The pointer is only valid while `value` is alive and unmodified.
// Takes `&Option<CString>` rather than `Option<&CString>` so call sites can pass `&name`
// directly; `clippy::ref_option` flags that signature, hence the allow.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(owned) => owned.as_ptr(),
        None => ptr::null(),
    }
}
36
37fn wrap_tensor(ptr: *mut c_void) -> Option<Tensor> {
38    if ptr.is_null() {
39        None
40    } else {
41        Some(Tensor::from_raw(ptr))
42    }
43}
44
45fn wrap_operation(ptr: *mut c_void) -> Option<Operation> {
46    if ptr.is_null() {
47        None
48    } else {
49        Some(Operation::from_raw(ptr))
50    }
51}
52
53fn wrap_tensor_pair(box_handle: *mut c_void) -> Option<(Tensor, Tensor)> {
54    let mut values = collect_owned_tensors(box_handle);
55    if values.len() != 2 {
56        return None;
57    }
58    let second = values.pop()?;
59    let first = values.pop()?;
60    Some((first, second))
61}
62
/// Declares a public owning wrapper type around a bridged Swift/Objective-C object pointer.
///
/// The generated type:
/// * stores the raw pointer privately,
/// * releases it through `release_handle` when dropped (a null pointer is a no-op), and
/// * exposes the pointer through `as_ptr` for passing back across the FFI boundary.
///
/// Outer attributes written before the type name (for example `#[doc = "..."]` or
/// `#[derive(Debug)]`) are forwarded to the generated struct, as in
/// `opaque_handle!(#[doc = "Docs."] Foo);`. The plain `opaque_handle!(Foo);` form is unchanged.
macro_rules! opaque_handle {
    ($(#[$meta:meta])* $name:ident) => {
        $(#[$meta])*
        pub struct $name {
            ptr: *mut c_void,
        }

        // SAFETY: the wrapper owns exactly one retained object reference and the generated code
        // only mutates the pointer field in `drop`. NOTE(review): moving/sharing the wrapper
        // across threads also assumes the underlying MPSGraph object may be used from any
        // thread — confirm per wrapped type.
        unsafe impl Send for $name {}
        // SAFETY: `&self` methods generated here only read the pointer value; see the `Send`
        // impl for the cross-thread assumption about the wrapped object.
        unsafe impl Sync for $name {}

        impl Drop for $name {
            fn drop(&mut self) {
                release_handle(&mut self.ptr);
            }
        }

        impl $name {
            /// Returns the raw bridged object pointer without transferring ownership.
            ///
            /// The pointer is only guaranteed valid while `self` is alive.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }
    };
}
86
pub mod execution_stage {
    //! Raw `MPSGraphExecutionStage` constants.

    /// `MPSGraphExecutionStage` value `0`.
    pub const COMPLETED: u64 = 0;
}
91
pub mod reduction_mode {
    //! Raw `MPSGraphReductionMode` constants.

    // Plain reductions.
    pub const MIN: usize = 0;
    pub const MAX: usize = 1;
    pub const SUM: usize = 2;
    pub const PRODUCT: usize = 3;

    // Index-returning reductions.
    pub const ARGUMENT_MIN: usize = 4;
    pub const ARGUMENT_MAX: usize = 5;
}
101
pub mod pooling_return_indices_mode {
    //! Raw `MPSGraphPoolingReturnIndicesMode` constants.

    pub const NONE: usize = 0;

    // Global flattening variants.
    pub const GLOBAL_FLATTEN_1D: usize = 1;
    pub const GLOBAL_FLATTEN_2D: usize = 2;
    pub const GLOBAL_FLATTEN_3D: usize = 3;
    pub const GLOBAL_FLATTEN_4D: usize = 4;

    // Local flattening variants.
    pub const LOCAL_FLATTEN_1D: usize = 5;
    pub const LOCAL_FLATTEN_2D: usize = 6;
    pub const LOCAL_FLATTEN_3D: usize = 7;
    pub const LOCAL_FLATTEN_4D: usize = 8;
}
114
pub mod fft_scaling_mode {
    //! Raw `MPSGraphFFTScalingMode` constants.

    pub const NONE: usize = 0;
    pub const SIZE: usize = 1;
    pub const UNITARY: usize = 2;
}
121
pub mod loss_reduction_type {
    //! Raw `MPSGraphLossReductionType` constants.

    pub const NONE: u64 = 0;
    /// Same value as [`NONE`].
    // NOTE(review): presumably mirrors a deprecated alias of `NONE` in Apple's header — confirm.
    pub const AXIS: u64 = 0;
    pub const SUM: u64 = 1;
    pub const MEAN: u64 = 2;
}
129
pub mod non_maximum_suppression_coordinate_mode {
    //! Raw `MPSGraphNonMaximumSuppressionCoordinateMode` constants.

    // Corner-based box encodings.
    pub const CORNERS_HEIGHT_FIRST: usize = 0;
    pub const CORNERS_WIDTH_FIRST: usize = 1;

    // Center-based box encodings.
    pub const CENTERS_HEIGHT_FIRST: usize = 2;
    pub const CENTERS_WIDTH_FIRST: usize = 3;
}
137
pub mod resize_mode {
    //! Raw `MPSGraphResizeMode` constants.

    pub const NEAREST: usize = 0;
    pub const BILINEAR: usize = 1;
}
143
pub mod resize_nearest_rounding_mode {
    //! Raw `MPSGraphResizeNearestRoundingMode` constants.

    // Round-to-nearest with a tie-breaking preference.
    pub const ROUND_PREFER_CEIL: usize = 0;
    pub const ROUND_PREFER_FLOOR: usize = 1;

    // Directed rounding.
    pub const CEIL: usize = 2;
    pub const FLOOR: usize = 3;

    // Parity-based tie breaking.
    pub const ROUND_TO_EVEN: usize = 4;
    pub const ROUND_TO_ODD: usize = 5;
}
153
pub mod scatter_mode {
    //! Raw `MPSGraphScatterMode` constants (signed, matching an `NSInteger`-style enum).

    // Arithmetic combiners.
    pub const ADD: isize = 0;
    pub const SUB: isize = 1;
    pub const MUL: isize = 2;
    pub const DIV: isize = 3;

    // Extremum combiners.
    pub const MIN: isize = 4;
    pub const MAX: isize = 5;

    // Overwrite.
    pub const SET: isize = 6;
}
164
pub mod sparse_storage_type {
    //! Raw `MPSGraphSparseStorageType` constants.

    pub const COO: u64 = 0;
    pub const CSC: u64 = 1;
    pub const CSR: u64 = 2;
}
171
172opaque_handle!(Object);
173impl Object {
174    fn retain_from(ptr: *mut c_void) -> Self {
175        // SAFETY: `ptr` belongs to a live `MPSGraphObject` subclass and the bridge retains it for this wrapper.
176        let ptr = unsafe { ffi::mpsgraph_object_retain(ptr) };
177        Self { ptr }
178    }
179}
180
181opaque_handle!(GraphType);
182impl GraphType {
183    fn retain_from(ptr: *mut c_void) -> Self {
184        // SAFETY: `ptr` belongs to a live `MPSGraphType` subclass and the bridge retains it for this wrapper.
185        let ptr = unsafe { ffi::mpsgraph_object_retain(ptr) };
186        Self { ptr }
187    }
188
189    #[must_use]
190    pub fn as_object(&self) -> Object {
191        Object::retain_from(self.ptr)
192    }
193}
194
195opaque_handle!(VariableOp);
196impl VariableOp {
197    #[must_use]
198    pub fn shape(&self) -> Vec<isize> {
199        // SAFETY: `self.ptr` is a live variable-op handle.
200        let len = unsafe { ffi::mpsgraph_variable_op_shape_len(self.ptr) };
201        let mut shape = vec![0_isize; len];
202        if len > 0 {
203            // SAFETY: `shape` has space for exactly `len` elements.
204            unsafe { ffi::mpsgraph_variable_op_copy_shape(self.ptr, shape.as_mut_ptr()) };
205        }
206        shape
207    }
208
209    #[must_use]
210    pub fn data_type(&self) -> u32 {
211        // SAFETY: `self.ptr` is a live variable-op handle.
212        unsafe { ffi::mpsgraph_variable_op_data_type(self.ptr) }
213    }
214
215    #[must_use]
216    pub fn as_object(&self) -> Object {
217        Object::retain_from(self.ptr)
218    }
219
220    #[must_use]
221    pub fn as_operation(&self) -> Operation {
222        // SAFETY: `self.ptr` is a live variable-op handle and retains as an operation wrapper.
223        let ptr = unsafe { ffi::mpsgraph_object_retain(self.ptr) };
224        Operation::from_raw(ptr)
225    }
226}
227
228impl ShapedType {
229    #[must_use]
230    pub fn as_graph_type(&self) -> GraphType {
231        GraphType::retain_from(self.as_ptr())
232    }
233}
234
235impl Operation {
236    #[must_use]
237    pub fn as_variable(&self) -> Option<VariableOp> {
238        // SAFETY: `self.ptr` is a live operation handle.
239        let ptr = unsafe { ffi::mpsgraph_operation_as_variable(self.as_ptr()) };
240        if ptr.is_null() {
241            None
242        } else {
243            Some(VariableOp { ptr })
244        }
245    }
246}
247
/// Plain configuration for building a `Convolution3DDescriptor`.
///
/// Every field is forwarded unchanged to the bridge by `Convolution3DDescriptor::new`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Convolution3DDescriptorInfo {
    /// Stride along the X axis.
    pub stride_in_x: usize,
    /// Stride along the Y axis.
    pub stride_in_y: usize,
    /// Stride along the Z axis.
    pub stride_in_z: usize,
    /// Dilation rate along the X axis.
    pub dilation_rate_in_x: usize,
    /// Dilation rate along the Y axis.
    pub dilation_rate_in_y: usize,
    /// Dilation rate along the Z axis.
    pub dilation_rate_in_z: usize,
    /// Number of convolution groups.
    pub groups: usize,
    /// Left padding.
    pub padding_left: usize,
    /// Right padding.
    pub padding_right: usize,
    /// Top padding.
    pub padding_top: usize,
    /// Bottom padding.
    pub padding_bottom: usize,
    /// Front padding.
    pub padding_front: usize,
    /// Back padding.
    pub padding_back: usize,
    /// Padding style; a `padding_style` constant.
    pub padding_style: usize,
    /// Source data layout; a `tensor_named_data_layout` constant.
    pub data_layout: usize,
    /// Weights layout; a `tensor_named_data_layout` constant.
    pub weights_layout: usize,
}
267
268impl Default for Convolution3DDescriptorInfo {
269    fn default() -> Self {
270        Self {
271            stride_in_x: 1,
272            stride_in_y: 1,
273            stride_in_z: 1,
274            dilation_rate_in_x: 1,
275            dilation_rate_in_y: 1,
276            dilation_rate_in_z: 1,
277            groups: 1,
278            padding_left: 0,
279            padding_right: 0,
280            padding_top: 0,
281            padding_bottom: 0,
282            padding_front: 0,
283            padding_back: 0,
284            padding_style: padding_style::EXPLICIT,
285            data_layout: tensor_named_data_layout::NDHWC,
286            weights_layout: tensor_named_data_layout::DHWIO,
287        }
288    }
289}
290
291opaque_handle!(Convolution3DDescriptor);
292impl Convolution3DDescriptor {
293    #[must_use]
294    pub fn new(info: Convolution3DDescriptorInfo) -> Option<Self> {
295        // SAFETY: all arguments are POD configuration values.
296        let ptr = unsafe {
297            ffi::mpsgraph_convolution3d_descriptor_new(
298                info.stride_in_x,
299                info.stride_in_y,
300                info.stride_in_z,
301                info.dilation_rate_in_x,
302                info.dilation_rate_in_y,
303                info.dilation_rate_in_z,
304                info.groups,
305                info.padding_left,
306                info.padding_right,
307                info.padding_top,
308                info.padding_bottom,
309                info.padding_front,
310                info.padding_back,
311                info.padding_style,
312                info.data_layout,
313                info.weights_layout,
314            )
315        };
316        if ptr.is_null() {
317            None
318        } else {
319            Some(Self { ptr })
320        }
321    }
322}
323
/// Plain configuration for building a `DepthwiseConvolution2DDescriptor`.
///
/// Every field is forwarded unchanged to the bridge by `DepthwiseConvolution2DDescriptor::new`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DepthwiseConvolution2DDescriptorInfo {
    /// Stride along the X axis.
    pub stride_in_x: usize,
    /// Stride along the Y axis.
    pub stride_in_y: usize,
    /// Dilation rate along the X axis.
    pub dilation_rate_in_x: usize,
    /// Dilation rate along the Y axis.
    pub dilation_rate_in_y: usize,
    /// Left padding.
    pub padding_left: usize,
    /// Right padding.
    pub padding_right: usize,
    /// Top padding.
    pub padding_top: usize,
    /// Bottom padding.
    pub padding_bottom: usize,
    /// Padding style; a `padding_style` constant.
    pub padding_style: usize,
    /// Source data layout; a `tensor_named_data_layout` constant.
    pub data_layout: usize,
    /// Weights layout; a `tensor_named_data_layout` constant.
    pub weights_layout: usize,
}
338
339impl Default for DepthwiseConvolution2DDescriptorInfo {
340    fn default() -> Self {
341        Self {
342            stride_in_x: 1,
343            stride_in_y: 1,
344            dilation_rate_in_x: 1,
345            dilation_rate_in_y: 1,
346            padding_left: 0,
347            padding_right: 0,
348            padding_top: 0,
349            padding_bottom: 0,
350            padding_style: padding_style::EXPLICIT,
351            data_layout: tensor_named_data_layout::NHWC,
352            weights_layout: tensor_named_data_layout::HWIO,
353        }
354    }
355}
356
357opaque_handle!(DepthwiseConvolution2DDescriptor);
358impl DepthwiseConvolution2DDescriptor {
359    #[must_use]
360    pub fn new(info: DepthwiseConvolution2DDescriptorInfo) -> Option<Self> {
361        // SAFETY: all arguments are POD configuration values.
362        let ptr = unsafe {
363            ffi::mpsgraph_depthwise_convolution2d_descriptor_new(
364                info.stride_in_x,
365                info.stride_in_y,
366                info.dilation_rate_in_x,
367                info.dilation_rate_in_y,
368                info.padding_left,
369                info.padding_right,
370                info.padding_top,
371                info.padding_bottom,
372                info.padding_style,
373                info.data_layout,
374                info.weights_layout,
375            )
376        };
377        if ptr.is_null() {
378            None
379        } else {
380            Some(Self { ptr })
381        }
382    }
383}
384
/// Plain configuration for building a `DepthwiseConvolution3DDescriptor`.
///
/// The arrays are passed to the bridge as pointer/length pairs by
/// `DepthwiseConvolution3DDescriptor::new`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DepthwiseConvolution3DDescriptorInfo {
    /// Per-axis strides (three values).
    pub strides: [usize; 3],
    /// Per-axis dilation rates (three values).
    pub dilation_rates: [usize; 3],
    /// Explicit padding values (six values).
    pub padding_values: [usize; 6],
    /// Padding style; a `padding_style` constant.
    pub padding_style: usize,
    /// Channel axis index.
    // NOTE(review): negative values presumably count from the last axis (the default is -1) —
    // confirm against Apple's MPSGraph documentation.
    pub channel_dimension_index: isize,
}
393
394impl Default for DepthwiseConvolution3DDescriptorInfo {
395    fn default() -> Self {
396        Self {
397            strides: [1, 1, 1],
398            dilation_rates: [1, 1, 1],
399            padding_values: [0, 0, 0, 0, 0, 0],
400            padding_style: padding_style::EXPLICIT,
401            channel_dimension_index: -1,
402        }
403    }
404}
405
406opaque_handle!(DepthwiseConvolution3DDescriptor);
407impl DepthwiseConvolution3DDescriptor {
408    #[must_use]
409    pub fn new(info: DepthwiseConvolution3DDescriptorInfo) -> Option<Self> {
410        // SAFETY: all slices stay alive for the duration of the call.
411        let ptr = unsafe {
412            ffi::mpsgraph_depthwise_convolution3d_descriptor_new(
413                info.strides.as_ptr(),
414                info.strides.len(),
415                info.dilation_rates.as_ptr(),
416                info.dilation_rates.len(),
417                info.padding_values.as_ptr(),
418                info.padding_values.len(),
419                info.padding_style,
420                info.channel_dimension_index,
421            )
422        };
423        if ptr.is_null() {
424            None
425        } else {
426            Some(Self { ptr })
427        }
428    }
429}
430
/// Plain configuration for building an `FftDescriptor`.
///
/// Every field is forwarded unchanged to the bridge by `FftDescriptor::new`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FftDescriptorInfo {
    /// Whether the inverse transform is requested.
    pub inverse: bool,
    /// Scaling mode; a `fft_scaling_mode` constant.
    pub scaling_mode: usize,
    /// Forwarded as the descriptor's round-to-odd-Hermitean flag.
    pub round_to_odd_hermitean: bool,
}
437
438impl Default for FftDescriptorInfo {
439    fn default() -> Self {
440        Self {
441            inverse: false,
442            scaling_mode: fft_scaling_mode::NONE,
443            round_to_odd_hermitean: false,
444        }
445    }
446}
447
448opaque_handle!(FftDescriptor);
449impl FftDescriptor {
450    #[must_use]
451    pub fn new(info: FftDescriptorInfo) -> Option<Self> {
452        // SAFETY: all arguments are POD configuration values.
453        let ptr = unsafe {
454            ffi::mpsgraph_fft_descriptor_new(
455                info.inverse,
456                info.scaling_mode,
457                info.round_to_odd_hermitean,
458            )
459        };
460        if ptr.is_null() {
461            None
462        } else {
463            Some(Self { ptr })
464        }
465    }
466}
467
/// Plain configuration for building an `ImToColDescriptor`.
///
/// Every field is forwarded unchanged to the bridge by `ImToColDescriptor::new`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ImToColDescriptorInfo {
    /// Kernel width.
    pub kernel_width: usize,
    /// Kernel height.
    pub kernel_height: usize,
    /// Stride along the X axis.
    pub stride_in_x: usize,
    /// Stride along the Y axis.
    pub stride_in_y: usize,
    /// Dilation rate along the X axis.
    pub dilation_rate_in_x: usize,
    /// Dilation rate along the Y axis.
    pub dilation_rate_in_y: usize,
    /// Left padding.
    pub padding_left: usize,
    /// Right padding.
    pub padding_right: usize,
    /// Top padding.
    pub padding_top: usize,
    /// Bottom padding.
    pub padding_bottom: usize,
    /// Source data layout; a `tensor_named_data_layout` constant.
    pub data_layout: usize,
}
482
483impl Default for ImToColDescriptorInfo {
484    fn default() -> Self {
485        Self {
486            kernel_width: 1,
487            kernel_height: 1,
488            stride_in_x: 1,
489            stride_in_y: 1,
490            dilation_rate_in_x: 1,
491            dilation_rate_in_y: 1,
492            padding_left: 0,
493            padding_right: 0,
494            padding_top: 0,
495            padding_bottom: 0,
496            data_layout: tensor_named_data_layout::NHWC,
497        }
498    }
499}
500
501opaque_handle!(ImToColDescriptor);
502impl ImToColDescriptor {
503    #[must_use]
504    pub fn new(info: ImToColDescriptorInfo) -> Option<Self> {
505        // SAFETY: all arguments are POD configuration values.
506        let ptr = unsafe {
507            ffi::mpsgraph_im_to_col_descriptor_new(
508                info.kernel_width,
509                info.kernel_height,
510                info.stride_in_x,
511                info.stride_in_y,
512                info.dilation_rate_in_x,
513                info.dilation_rate_in_y,
514                info.padding_left,
515                info.padding_right,
516                info.padding_top,
517                info.padding_bottom,
518                info.data_layout,
519            )
520        };
521        if ptr.is_null() {
522            None
523        } else {
524            Some(Self { ptr })
525        }
526    }
527}
528
/// Plain configuration for building a `Pooling4DDescriptor`.
///
/// The arrays are passed to the bridge as pointer/length pairs by `Pooling4DDescriptor::new`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Pooling4DDescriptorInfo {
    /// Kernel size per axis (four values).
    pub kernel_sizes: [usize; 4],
    /// Stride per axis (four values).
    pub strides: [usize; 4],
    /// Dilation rate per axis (four values).
    pub dilation_rates: [usize; 4],
    /// Explicit padding values (eight values).
    pub padding_values: [usize; 8],
    /// Padding style; a `padding_style` constant.
    pub padding_style: usize,
    /// Forwarded as the descriptor's ceil-mode flag.
    pub ceil_mode: bool,
    /// Forwarded as the descriptor's include-zero-pad-to-average flag.
    pub include_zero_pad_to_average: bool,
    /// Index-return mode; a `pooling_return_indices_mode` constant.
    pub return_indices_mode: usize,
    /// Element type of returned indices; a `data_type` constant.
    pub return_indices_data_type: u32,
}
541
542impl Default for Pooling4DDescriptorInfo {
543    fn default() -> Self {
544        Self {
545            kernel_sizes: [1, 1, 1, 1],
546            strides: [1, 1, 1, 1],
547            dilation_rates: [1, 1, 1, 1],
548            padding_values: [0, 0, 0, 0, 0, 0, 0, 0],
549            padding_style: padding_style::EXPLICIT,
550            ceil_mode: false,
551            include_zero_pad_to_average: false,
552            return_indices_mode: pooling_return_indices_mode::NONE,
553            return_indices_data_type: data_type::INT32,
554        }
555    }
556}
557
558opaque_handle!(Pooling4DDescriptor);
559impl Pooling4DDescriptor {
560    #[must_use]
561    pub fn new(info: Pooling4DDescriptorInfo) -> Option<Self> {
562        // SAFETY: all slices stay alive for the duration of the call.
563        let ptr = unsafe {
564            ffi::mpsgraph_pooling4d_descriptor_new(
565                info.kernel_sizes.as_ptr(),
566                info.kernel_sizes.len(),
567                info.strides.as_ptr(),
568                info.strides.len(),
569                info.dilation_rates.as_ptr(),
570                info.dilation_rates.len(),
571                info.padding_values.as_ptr(),
572                info.padding_values.len(),
573                info.padding_style,
574                info.ceil_mode,
575                info.include_zero_pad_to_average,
576                info.return_indices_mode,
577                info.return_indices_data_type,
578            )
579        };
580        if ptr.is_null() {
581            None
582        } else {
583            Some(Self { ptr })
584        }
585    }
586}
587
588opaque_handle!(CreateSparseDescriptor);
589impl CreateSparseDescriptor {
590    #[must_use]
591    pub fn new(storage_type: u64, data_type: u32) -> Option<Self> {
592        // SAFETY: all arguments are POD configuration values.
593        let ptr = unsafe { ffi::mpsgraph_sparse_descriptor_new(storage_type, data_type) };
594        if ptr.is_null() {
595            None
596        } else {
597            Some(Self { ptr })
598        }
599    }
600}
601
/// Plain configuration for building a `StencilDescriptor`.
///
/// The arrays are passed to the bridge as pointer/length pairs by `StencilDescriptor::new`.
/// Only `PartialEq` is derived (not `Eq`/`Hash`) because `padding_constant` is an `f32`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct StencilDescriptorInfo {
    /// Reduction mode; a `reduction_mode` constant.
    pub reduction_mode: usize,
    /// Per-axis offsets (four values).
    pub offsets: [isize; 4],
    /// Per-axis strides (four values).
    pub strides: [usize; 4],
    /// Per-axis dilation rates (four values).
    pub dilation_rates: [usize; 4],
    /// Explicit padding values (eight values).
    pub explicit_padding: [usize; 8],
    /// Boundary handling; a `padding_mode` constant.
    pub boundary_mode: isize,
    /// Padding style; a `padding_style` constant.
    pub padding_style: usize,
    /// Padding value.
    // NOTE(review): presumably only used by a constant boundary mode — confirm.
    pub padding_constant: f32,
}
613
614impl Default for StencilDescriptorInfo {
615    fn default() -> Self {
616        Self {
617            reduction_mode: reduction_mode::SUM,
618            offsets: [0, 0, 0, 0],
619            strides: [1, 1, 1, 1],
620            dilation_rates: [1, 1, 1, 1],
621            explicit_padding: [0, 0, 0, 0, 0, 0, 0, 0],
622            boundary_mode: padding_mode::ZERO,
623            padding_style: padding_style::EXPLICIT,
624            padding_constant: 0.0,
625        }
626    }
627}
628
629opaque_handle!(StencilDescriptor);
630impl StencilDescriptor {
631    #[must_use]
632    pub fn new(info: StencilDescriptorInfo) -> Option<Self> {
633        // SAFETY: all slices stay alive for the duration of the call.
634        let ptr = unsafe {
635            ffi::mpsgraph_stencil_descriptor_new(
636                info.reduction_mode,
637                info.offsets.as_ptr(),
638                info.offsets.len(),
639                info.strides.as_ptr(),
640                info.strides.len(),
641                info.dilation_rates.as_ptr(),
642                info.dilation_rates.len(),
643                info.explicit_padding.as_ptr(),
644                info.explicit_padding.len(),
645                info.boundary_mode,
646                info.padding_style,
647                info.padding_constant,
648            )
649        };
650        if ptr.is_null() {
651            None
652        } else {
653            Some(Self { ptr })
654        }
655    }
656}
657
658impl Graph {
659    #[must_use]
660    pub fn convolution3d(
661        &self,
662        source: &Tensor,
663        weights: &Tensor,
664        descriptor: &Convolution3DDescriptor,
665        name: Option<&str>,
666    ) -> Option<Tensor> {
667        let name = optional_cstring(name);
668        // SAFETY: all handles remain valid for the duration of the call.
669        let ptr = unsafe {
670            ffi::mpsgraph_graph_convolution3d(
671                self.as_ptr(),
672                source.as_ptr(),
673                weights.as_ptr(),
674                descriptor.as_ptr(),
675                cstring_ptr(&name),
676            )
677        };
678        wrap_tensor(ptr)
679    }
680
681    #[must_use]
682    pub fn convolution_transpose2d(
683        &self,
684        source: &Tensor,
685        weights: &Tensor,
686        output_shape: &[usize],
687        descriptor: &Convolution2DDescriptor,
688        name: Option<&str>,
689    ) -> Option<Tensor> {
690        let name = optional_cstring(name);
691        // SAFETY: all handles and slices remain valid for the duration of the call.
692        let ptr = unsafe {
693            ffi::mpsgraph_graph_convolution_transpose2d(
694                self.as_ptr(),
695                source.as_ptr(),
696                weights.as_ptr(),
697                output_shape.as_ptr(),
698                output_shape.len(),
699                descriptor.as_ptr(),
700                cstring_ptr(&name),
701            )
702        };
703        wrap_tensor(ptr)
704    }
705
706    #[must_use]
707    pub fn cumulative_sum(
708        &self,
709        tensor: &Tensor,
710        axis: isize,
711        exclusive: bool,
712        reverse: bool,
713        name: Option<&str>,
714    ) -> Option<Tensor> {
715        let name = optional_cstring(name);
716        // SAFETY: all handles remain valid for the duration of the call.
717        let ptr = unsafe {
718            ffi::mpsgraph_graph_cumulative_sum(
719                self.as_ptr(),
720                tensor.as_ptr(),
721                axis,
722                exclusive,
723                reverse,
724                cstring_ptr(&name),
725            )
726        };
727        wrap_tensor(ptr)
728    }
729
730    #[must_use]
731    pub fn depthwise_convolution2d(
732        &self,
733        source: &Tensor,
734        weights: &Tensor,
735        descriptor: &DepthwiseConvolution2DDescriptor,
736        name: Option<&str>,
737    ) -> Option<Tensor> {
738        let name = optional_cstring(name);
739        // SAFETY: all handles remain valid for the duration of the call.
740        let ptr = unsafe {
741            ffi::mpsgraph_graph_depthwise_convolution2d(
742                self.as_ptr(),
743                source.as_ptr(),
744                weights.as_ptr(),
745                descriptor.as_ptr(),
746                cstring_ptr(&name),
747            )
748        };
749        wrap_tensor(ptr)
750    }
751
752    #[must_use]
753    pub fn depthwise_convolution3d(
754        &self,
755        source: &Tensor,
756        weights: &Tensor,
757        descriptor: &DepthwiseConvolution3DDescriptor,
758        name: Option<&str>,
759    ) -> Option<Tensor> {
760        let name = optional_cstring(name);
761        // SAFETY: all handles remain valid for the duration of the call.
762        let ptr = unsafe {
763            ffi::mpsgraph_graph_depthwise_convolution3d(
764                self.as_ptr(),
765                source.as_ptr(),
766                weights.as_ptr(),
767                descriptor.as_ptr(),
768                cstring_ptr(&name),
769            )
770        };
771        wrap_tensor(ptr)
772    }
773
774    #[must_use]
775    pub fn fast_fourier_transform(
776        &self,
777        tensor: &Tensor,
778        axes: &[usize],
779        descriptor: &FftDescriptor,
780        name: Option<&str>,
781    ) -> Option<Tensor> {
782        let name = optional_cstring(name);
783        // SAFETY: all handles and slices remain valid for the duration of the call.
784        let ptr = unsafe {
785            ffi::mpsgraph_graph_fast_fourier_transform(
786                self.as_ptr(),
787                tensor.as_ptr(),
788                axes.as_ptr(),
789                axes.len(),
790                descriptor.as_ptr(),
791                cstring_ptr(&name),
792            )
793        };
794        wrap_tensor(ptr)
795    }
796
797    #[must_use]
798    pub fn im_to_col(
799        &self,
800        source: &Tensor,
801        descriptor: &ImToColDescriptor,
802        name: Option<&str>,
803    ) -> Option<Tensor> {
804        let name = optional_cstring(name);
805        // SAFETY: all handles remain valid for the duration of the call.
806        let ptr = unsafe {
807            ffi::mpsgraph_graph_im_to_col(
808                self.as_ptr(),
809                source.as_ptr(),
810                descriptor.as_ptr(),
811                cstring_ptr(&name),
812            )
813        };
814        wrap_tensor(ptr)
815    }
816
817    #[must_use]
818    pub fn band_part(
819        &self,
820        tensor: &Tensor,
821        num_lower: isize,
822        num_upper: isize,
823        name: Option<&str>,
824    ) -> Option<Tensor> {
825        let name = optional_cstring(name);
826        // SAFETY: all handles remain valid for the duration of the call.
827        let ptr = unsafe {
828            ffi::mpsgraph_graph_band_part(
829                self.as_ptr(),
830                tensor.as_ptr(),
831                num_lower,
832                num_upper,
833                cstring_ptr(&name),
834            )
835        };
836        wrap_tensor(ptr)
837    }
838
839    #[must_use]
840    pub fn softmax_cross_entropy(
841        &self,
842        source: &Tensor,
843        labels: &Tensor,
844        axis: isize,
845        reduction_type: u64,
846        name: Option<&str>,
847    ) -> Option<Tensor> {
848        let name = optional_cstring(name);
849        // SAFETY: all handles remain valid for the duration of the call.
850        let ptr = unsafe {
851            ffi::mpsgraph_graph_softmax_cross_entropy(
852                self.as_ptr(),
853                source.as_ptr(),
854                labels.as_ptr(),
855                axis,
856                reduction_type,
857                cstring_ptr(&name),
858            )
859        };
860        wrap_tensor(ptr)
861    }
862
863    #[must_use]
864    pub fn matrix_inverse(&self, tensor: &Tensor, name: Option<&str>) -> Option<Tensor> {
865        let name = optional_cstring(name);
866        // SAFETY: all handles remain valid for the duration of the call.
867        let ptr = unsafe {
868            ffi::mpsgraph_graph_matrix_inverse(self.as_ptr(), tensor.as_ptr(), cstring_ptr(&name))
869        };
870        wrap_tensor(ptr)
871    }
872
873    #[must_use]
874    pub fn variable_bytes(
875        &self,
876        data: &[u8],
877        shape: &[usize],
878        data_type: u32,
879        name: Option<&str>,
880    ) -> Option<Tensor> {
881        let expected = checked_byte_len(shape, data_type)?;
882        if data.len() != expected {
883            return None;
884        }
885
886        let name = optional_cstring(name);
887        // SAFETY: all handles and slices remain valid for the duration of the call.
888        let ptr = unsafe {
889            ffi::mpsgraph_graph_variable_data(
890                self.as_ptr(),
891                data.as_ptr().cast(),
892                data.len(),
893                shape.as_ptr(),
894                shape.len(),
895                data_type,
896                cstring_ptr(&name),
897            )
898        };
899        wrap_tensor(ptr)
900    }
901
902    #[must_use]
903    pub fn variable_f32_slice(
904        &self,
905        values: &[f32],
906        shape: &[usize],
907        name: Option<&str>,
908    ) -> Option<Tensor> {
909        // SAFETY: `values` is a contiguous slice of `f32` that may be viewed as bytes.
910        let bytes = unsafe {
911            core::slice::from_raw_parts(
912                values.as_ptr().cast::<u8>(),
913                core::mem::size_of_val(values),
914            )
915        };
916        self.variable_bytes(bytes, shape, data_type::FLOAT32, name)
917    }
918
919    #[must_use]
920    pub fn read_variable(&self, variable: &Tensor, name: Option<&str>) -> Option<Tensor> {
921        let name = optional_cstring(name);
922        // SAFETY: all handles remain valid for the duration of the call.
923        let ptr = unsafe {
924            ffi::mpsgraph_graph_read_variable(self.as_ptr(), variable.as_ptr(), cstring_ptr(&name))
925        };
926        wrap_tensor(ptr)
927    }
928
929    #[must_use]
930    pub fn assign_variable(
931        &self,
932        variable: &Tensor,
933        value: &Tensor,
934        name: Option<&str>,
935    ) -> Option<Operation> {
936        let name = optional_cstring(name);
937        // SAFETY: all handles remain valid for the duration of the call.
938        let ptr = unsafe {
939            ffi::mpsgraph_graph_assign_variable(
940                self.as_ptr(),
941                variable.as_ptr(),
942                value.as_ptr(),
943                cstring_ptr(&name),
944            )
945        };
946        wrap_operation(ptr)
947    }
948
949    #[must_use]
950    #[allow(clippy::too_many_arguments)]
951    pub fn non_maximum_suppression(
952        &self,
953        boxes: &Tensor,
954        scores: &Tensor,
955        iou_threshold: f32,
956        score_threshold: f32,
957        per_class_suppression: bool,
958        coordinate_mode: usize,
959        name: Option<&str>,
960    ) -> Option<Tensor> {
961        let name = optional_cstring(name);
962        // SAFETY: all handles remain valid for the duration of the call.
963        let ptr = unsafe {
964            ffi::mpsgraph_graph_non_maximum_suppression(
965                self.as_ptr(),
966                boxes.as_ptr(),
967                scores.as_ptr(),
968                iou_threshold,
969                score_threshold,
970                per_class_suppression,
971                coordinate_mode,
972                cstring_ptr(&name),
973            )
974        };
975        wrap_tensor(ptr)
976    }
977
978    #[must_use]
979    pub fn non_zero_indices(&self, tensor: &Tensor, name: Option<&str>) -> Option<Tensor> {
980        let name = optional_cstring(name);
981        // SAFETY: all handles remain valid for the duration of the call.
982        let ptr = unsafe {
983            ffi::mpsgraph_graph_non_zero_indices(self.as_ptr(), tensor.as_ptr(), cstring_ptr(&name))
984        };
985        wrap_tensor(ptr)
986    }
987
988    #[must_use]
989    pub fn one_hot(
990        &self,
991        indices: &Tensor,
992        depth: usize,
993        data_type: u32,
994        name: Option<&str>,
995    ) -> Option<Tensor> {
996        let name = optional_cstring(name);
997        // SAFETY: all handles remain valid for the duration of the call.
998        let ptr = unsafe {
999            ffi::mpsgraph_graph_one_hot(
1000                self.as_ptr(),
1001                indices.as_ptr(),
1002                depth,
1003                data_type,
1004                cstring_ptr(&name),
1005            )
1006        };
1007        wrap_tensor(ptr)
1008    }
1009
1010    #[must_use]
1011    pub fn stochastic_gradient_descent(
1012        &self,
1013        learning_rate: &Tensor,
1014        values: &Tensor,
1015        gradient: &Tensor,
1016        name: Option<&str>,
1017    ) -> Option<Tensor> {
1018        let name = optional_cstring(name);
1019        // SAFETY: all handles remain valid for the duration of the call.
1020        let ptr = unsafe {
1021            ffi::mpsgraph_graph_stochastic_gradient_descent(
1022                self.as_ptr(),
1023                learning_rate.as_ptr(),
1024                values.as_ptr(),
1025                gradient.as_ptr(),
1026                cstring_ptr(&name),
1027            )
1028        };
1029        wrap_tensor(ptr)
1030    }
1031
1032    #[must_use]
1033    pub fn max_pooling4d(
1034        &self,
1035        source: &Tensor,
1036        descriptor: &Pooling4DDescriptor,
1037        name: Option<&str>,
1038    ) -> Option<Tensor> {
1039        let name = optional_cstring(name);
1040        // SAFETY: all handles remain valid for the duration of the call.
1041        let ptr = unsafe {
1042            ffi::mpsgraph_graph_max_pooling4d(
1043                self.as_ptr(),
1044                source.as_ptr(),
1045                descriptor.as_ptr(),
1046                cstring_ptr(&name),
1047            )
1048        };
1049        wrap_tensor(ptr)
1050    }
1051
1052    #[must_use]
1053    pub fn max_pooling4d_return_indices(
1054        &self,
1055        source: &Tensor,
1056        descriptor: &Pooling4DDescriptor,
1057        name: Option<&str>,
1058    ) -> Option<(Tensor, Tensor)> {
1059        let name = optional_cstring(name);
1060        // SAFETY: all handles remain valid for the duration of the call.
1061        let box_handle = unsafe {
1062            ffi::mpsgraph_graph_max_pooling4d_return_indices(
1063                self.as_ptr(),
1064                source.as_ptr(),
1065                descriptor.as_ptr(),
1066                cstring_ptr(&name),
1067            )
1068        };
1069        wrap_tensor_pair(box_handle)
1070    }
1071
1072    #[must_use]
1073    pub fn quantize(
1074        &self,
1075        tensor: &Tensor,
1076        scale: f64,
1077        zero_point: f64,
1078        data_type: u32,
1079        name: Option<&str>,
1080    ) -> Option<Tensor> {
1081        let name = optional_cstring(name);
1082        // SAFETY: all handles remain valid for the duration of the call.
1083        let ptr = unsafe {
1084            ffi::mpsgraph_graph_quantize(
1085                self.as_ptr(),
1086                tensor.as_ptr(),
1087                scale,
1088                zero_point,
1089                data_type,
1090                cstring_ptr(&name),
1091            )
1092        };
1093        wrap_tensor(ptr)
1094    }
1095
1096    #[must_use]
1097    pub fn dequantize(
1098        &self,
1099        tensor: &Tensor,
1100        scale: f64,
1101        zero_point: f64,
1102        data_type: u32,
1103        name: Option<&str>,
1104    ) -> Option<Tensor> {
1105        let name = optional_cstring(name);
1106        // SAFETY: all handles remain valid for the duration of the call.
1107        let ptr = unsafe {
1108            ffi::mpsgraph_graph_dequantize(
1109                self.as_ptr(),
1110                tensor.as_ptr(),
1111                scale,
1112                zero_point,
1113                data_type,
1114                cstring_ptr(&name),
1115            )
1116        };
1117        wrap_tensor(ptr)
1118    }
1119
1120    #[must_use]
1121    #[allow(clippy::too_many_arguments)]
1122    pub fn resize(
1123        &self,
1124        images: &Tensor,
1125        size: &[usize],
1126        mode: usize,
1127        center_result: bool,
1128        align_corners: bool,
1129        layout: usize,
1130        name: Option<&str>,
1131    ) -> Option<Tensor> {
1132        let name = optional_cstring(name);
1133        // SAFETY: all handles and slices remain valid for the duration of the call.
1134        let ptr = unsafe {
1135            ffi::mpsgraph_graph_resize(
1136                self.as_ptr(),
1137                images.as_ptr(),
1138                size.as_ptr(),
1139                size.len(),
1140                mode,
1141                center_result,
1142                align_corners,
1143                layout,
1144                cstring_ptr(&name),
1145            )
1146        };
1147        wrap_tensor(ptr)
1148    }
1149
1150    #[must_use]
1151    #[allow(clippy::too_many_arguments)]
1152    pub fn resize_nearest(
1153        &self,
1154        images: &Tensor,
1155        size_tensor: &Tensor,
1156        nearest_rounding_mode: usize,
1157        center_result: bool,
1158        align_corners: bool,
1159        layout: usize,
1160        name: Option<&str>,
1161    ) -> Option<Tensor> {
1162        let name = optional_cstring(name);
1163        // SAFETY: all handles remain valid for the duration of the call.
1164        let ptr = unsafe {
1165            ffi::mpsgraph_graph_resize_nearest(
1166                self.as_ptr(),
1167                images.as_ptr(),
1168                size_tensor.as_ptr(),
1169                nearest_rounding_mode,
1170                center_result,
1171                align_corners,
1172                layout,
1173                cstring_ptr(&name),
1174            )
1175        };
1176        wrap_tensor(ptr)
1177    }
1178
1179    #[must_use]
1180    #[allow(clippy::too_many_arguments)]
1181    pub fn sample_grid(
1182        &self,
1183        source: &Tensor,
1184        coordinates: &Tensor,
1185        layout: usize,
1186        normalize_coordinates: bool,
1187        relative_coordinates: bool,
1188        align_corners: bool,
1189        padding_mode: isize,
1190        sampling_mode: usize,
1191        constant_value: f64,
1192        name: Option<&str>,
1193    ) -> Option<Tensor> {
1194        let name = optional_cstring(name);
1195        // SAFETY: all handles remain valid for the duration of the call.
1196        let ptr = unsafe {
1197            ffi::mpsgraph_graph_sample_grid(
1198                self.as_ptr(),
1199                source.as_ptr(),
1200                coordinates.as_ptr(),
1201                layout,
1202                normalize_coordinates,
1203                relative_coordinates,
1204                align_corners,
1205                padding_mode,
1206                sampling_mode,
1207                constant_value,
1208                cstring_ptr(&name),
1209            )
1210        };
1211        wrap_tensor(ptr)
1212    }
1213
1214    #[must_use]
1215    pub fn scatter_nd(
1216        &self,
1217        updates: &Tensor,
1218        indices: &Tensor,
1219        shape: &[usize],
1220        batch_dimensions: usize,
1221        mode: isize,
1222        name: Option<&str>,
1223    ) -> Option<Tensor> {
1224        let name = optional_cstring(name);
1225        // SAFETY: all handles and slices remain valid for the duration of the call.
1226        let ptr = unsafe {
1227            ffi::mpsgraph_graph_scatter_nd(
1228                self.as_ptr(),
1229                updates.as_ptr(),
1230                indices.as_ptr(),
1231                shape.as_ptr(),
1232                shape.len(),
1233                batch_dimensions,
1234                mode,
1235                cstring_ptr(&name),
1236            )
1237        };
1238        wrap_tensor(ptr)
1239    }
1240
1241    #[must_use]
1242    pub fn scatter(
1243        &self,
1244        updates: &Tensor,
1245        indices: &Tensor,
1246        shape: &[usize],
1247        axis: isize,
1248        mode: isize,
1249        name: Option<&str>,
1250    ) -> Option<Tensor> {
1251        let name = optional_cstring(name);
1252        // SAFETY: all handles and slices remain valid for the duration of the call.
1253        let ptr = unsafe {
1254            ffi::mpsgraph_graph_scatter(
1255                self.as_ptr(),
1256                updates.as_ptr(),
1257                indices.as_ptr(),
1258                shape.as_ptr(),
1259                shape.len(),
1260                axis,
1261                mode,
1262                cstring_ptr(&name),
1263            )
1264        };
1265        wrap_tensor(ptr)
1266    }
1267
1268    #[must_use]
1269    pub fn scatter_along_axis(
1270        &self,
1271        axis: isize,
1272        updates: &Tensor,
1273        indices: &Tensor,
1274        shape: &[usize],
1275        mode: isize,
1276        name: Option<&str>,
1277    ) -> Option<Tensor> {
1278        let name = optional_cstring(name);
1279        // SAFETY: all handles and slices remain valid for the duration of the call.
1280        let ptr = unsafe {
1281            ffi::mpsgraph_graph_scatter_along_axis(
1282                self.as_ptr(),
1283                axis,
1284                updates.as_ptr(),
1285                indices.as_ptr(),
1286                shape.as_ptr(),
1287                shape.len(),
1288                mode,
1289                cstring_ptr(&name),
1290            )
1291        };
1292        wrap_tensor(ptr)
1293    }
1294
1295    #[must_use]
1296    pub fn sort(
1297        &self,
1298        tensor: &Tensor,
1299        axis: isize,
1300        descending: bool,
1301        name: Option<&str>,
1302    ) -> Option<Tensor> {
1303        let name = optional_cstring(name);
1304        // SAFETY: all handles remain valid for the duration of the call.
1305        let ptr = unsafe {
1306            ffi::mpsgraph_graph_sort(
1307                self.as_ptr(),
1308                tensor.as_ptr(),
1309                axis,
1310                descending,
1311                cstring_ptr(&name),
1312            )
1313        };
1314        wrap_tensor(ptr)
1315    }
1316
1317    #[must_use]
1318    pub fn arg_sort(
1319        &self,
1320        tensor: &Tensor,
1321        axis: isize,
1322        descending: bool,
1323        name: Option<&str>,
1324    ) -> Option<Tensor> {
1325        let name = optional_cstring(name);
1326        // SAFETY: all handles remain valid for the duration of the call.
1327        let ptr = unsafe {
1328            ffi::mpsgraph_graph_arg_sort(
1329                self.as_ptr(),
1330                tensor.as_ptr(),
1331                axis,
1332                descending,
1333                cstring_ptr(&name),
1334            )
1335        };
1336        wrap_tensor(ptr)
1337    }
1338
1339    #[must_use]
1340    pub fn sparse_tensor_with_descriptor(
1341        &self,
1342        descriptor: &CreateSparseDescriptor,
1343        tensors: &[&Tensor],
1344        shape: &[usize],
1345        name: Option<&str>,
1346    ) -> Option<Tensor> {
1347        let name = optional_cstring(name);
1348        let handles = tensors
1349            .iter()
1350            .map(|tensor| tensor.as_ptr())
1351            .collect::<Vec<_>>();
1352        // SAFETY: all handles and slices remain valid for the duration of the call.
1353        let ptr = unsafe {
1354            ffi::mpsgraph_graph_sparse_tensor_with_descriptor(
1355                self.as_ptr(),
1356                descriptor.as_ptr(),
1357                handles.as_ptr(),
1358                handles.len(),
1359                shape.as_ptr(),
1360                shape.len(),
1361                cstring_ptr(&name),
1362            )
1363        };
1364        wrap_tensor(ptr)
1365    }
1366
1367    #[must_use]
1368    pub fn stencil(
1369        &self,
1370        source: &Tensor,
1371        weights: &Tensor,
1372        descriptor: &StencilDescriptor,
1373        name: Option<&str>,
1374    ) -> Option<Tensor> {
1375        let name = optional_cstring(name);
1376        // SAFETY: all handles remain valid for the duration of the call.
1377        let ptr = unsafe {
1378            ffi::mpsgraph_graph_stencil(
1379                self.as_ptr(),
1380                source.as_ptr(),
1381                weights.as_ptr(),
1382                descriptor.as_ptr(),
1383                cstring_ptr(&name),
1384            )
1385        };
1386        wrap_tensor(ptr)
1387    }
1388
1389    #[must_use]
1390    pub fn top_k_gradient(
1391        &self,
1392        gradient: &Tensor,
1393        source: &Tensor,
1394        k: usize,
1395        name: Option<&str>,
1396    ) -> Option<Tensor> {
1397        let name = optional_cstring(name);
1398        // SAFETY: all handles remain valid for the duration of the call.
1399        let ptr = unsafe {
1400            ffi::mpsgraph_graph_topk_gradient(
1401                self.as_ptr(),
1402                gradient.as_ptr(),
1403                source.as_ptr(),
1404                k,
1405                cstring_ptr(&name),
1406            )
1407        };
1408        wrap_tensor(ptr)
1409    }
1410}
1411
1412impl ExecutionDescriptor {
1413    /// # Safety
1414    ///
1415    /// `event_handle` must be a valid `id<MTLSharedEvent>` for the lifetime of the call.
1416    pub unsafe fn wait_for_shared_event_raw(
1417        &self,
1418        event_handle: *mut c_void,
1419        value: u64,
1420    ) -> Result<()> {
1421        // SAFETY: caller guarantees `event_handle` is a valid shared-event pointer.
1422        let ok = unsafe {
1423            ffi::mpsgraph_execution_descriptor_wait_for_event(self.as_ptr(), event_handle, value)
1424        };
1425        if ok {
1426            Ok(())
1427        } else {
1428            Err(Error::OperationFailed(
1429                "failed to register execution descriptor shared-event wait",
1430            ))
1431        }
1432    }
1433
1434    /// # Safety
1435    ///
1436    /// `event_handle` must be a valid `id<MTLSharedEvent>` for the lifetime of the call.
1437    pub unsafe fn signal_shared_event_raw(
1438        &self,
1439        event_handle: *mut c_void,
1440        execution_stage: u64,
1441        value: u64,
1442    ) -> Result<()> {
1443        // SAFETY: caller guarantees `event_handle` is a valid shared-event pointer.
1444        let ok = unsafe {
1445            ffi::mpsgraph_execution_descriptor_signal_event(
1446                self.as_ptr(),
1447                event_handle,
1448                execution_stage,
1449                value,
1450            )
1451        };
1452        if ok {
1453            Ok(())
1454        } else {
1455            Err(Error::OperationFailed(
1456                "failed to register execution descriptor shared-event signal",
1457            ))
1458        }
1459    }
1460}
1461
1462impl ExecutableExecutionDescriptor {
1463    /// # Safety
1464    ///
1465    /// `event_handle` must be a valid `id<MTLSharedEvent>` for the lifetime of the call.
1466    pub unsafe fn wait_for_shared_event_raw(
1467        &self,
1468        event_handle: *mut c_void,
1469        value: u64,
1470    ) -> Result<()> {
1471        // SAFETY: caller guarantees `event_handle` is a valid shared-event pointer.
1472        let ok = unsafe {
1473            ffi::mpsgraph_executable_execution_descriptor_wait_for_event(
1474                self.as_ptr(),
1475                event_handle,
1476                value,
1477            )
1478        };
1479        if ok {
1480            Ok(())
1481        } else {
1482            Err(Error::OperationFailed(
1483                "failed to register executable execution descriptor shared-event wait",
1484            ))
1485        }
1486    }
1487
1488    /// # Safety
1489    ///
1490    /// `event_handle` must be a valid `id<MTLSharedEvent>` for the lifetime of the call.
1491    pub unsafe fn signal_shared_event_raw(
1492        &self,
1493        event_handle: *mut c_void,
1494        execution_stage: u64,
1495        value: u64,
1496    ) -> Result<()> {
1497        // SAFETY: caller guarantees `event_handle` is a valid shared-event pointer.
1498        let ok = unsafe {
1499            ffi::mpsgraph_executable_execution_descriptor_signal_event(
1500                self.as_ptr(),
1501                event_handle,
1502                execution_stage,
1503                value,
1504            )
1505        };
1506        if ok {
1507            Ok(())
1508        } else {
1509            Err(Error::OperationFailed(
1510                "failed to register executable execution descriptor shared-event signal",
1511            ))
1512        }
1513    }
1514}