Skip to main content

apple_mpsgraph/
specialized.rs

1use crate::error::{Error, Result};
2use crate::execution::{ExecutableExecutionDescriptor, ExecutionDescriptor};
3use crate::ffi;
4use crate::graph::{
5    data_type, data_type_size, padding_mode, padding_style, tensor_named_data_layout,
6    Convolution2DDescriptor, Graph, Tensor,
7};
8use crate::types::{collect_owned_tensors, Operation, ShapedType};
9use core::ffi::{c_char, c_void};
10use core::ptr;
11use std::ffi::CString;
12
13fn release_handle(ptr: &mut *mut c_void) {
14    if !ptr.is_null() {
15        // SAFETY: `ptr` is a +1 retained Swift/ObjC object pointer owned by this wrapper.
16        unsafe { ffi::mpsgraph_object_release(*ptr) };
17        *ptr = ptr::null_mut();
18    }
19}
20
21fn checked_byte_len(shape: &[usize], data_type: u32) -> Option<usize> {
22    let element_size = data_type_size(data_type)?;
23    shape
24        .iter()
25        .try_fold(element_size, |acc, dimension| acc.checked_mul(*dimension))
26}
27
/// Converts an optional operation name into an owned C string for the bridge.
///
/// A name with an interior NUL byte cannot be represented as a C string and is
/// treated as absent (`None`).
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    name.map(CString::new).and_then(Result::ok)
}
31
/// Returns a borrowed C pointer for an optional name, or null when there is none.
///
/// The pointer is valid only while `value` stays alive and unmodified.
// Takes `&Option<CString>` so callers can pass `&name` straight from
// `optional_cstring`; that is why `ref_option` is allowed here.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(text) => text.as_ptr(),
        None => ptr::null(),
    }
}
36
37fn wrap_tensor(ptr: *mut c_void) -> Option<Tensor> {
38    if ptr.is_null() {
39        None
40    } else {
41        Some(Tensor::from_raw(ptr))
42    }
43}
44
45fn wrap_operation(ptr: *mut c_void) -> Option<Operation> {
46    if ptr.is_null() {
47        None
48    } else {
49        Some(Operation::from_raw(ptr))
50    }
51}
52
53fn wrap_tensor_pair(box_handle: *mut c_void) -> Option<(Tensor, Tensor)> {
54    let mut values = collect_owned_tensors(box_handle);
55    if values.len() != 2 {
56        return None;
57    }
58    let second = values.pop()?;
59    let first = values.pop()?;
60    Some((first, second))
61}
62
/// Declares an owning wrapper type around a retained `MPSGraph` object pointer.
///
/// The generated type holds one raw pointer, releases it through `release_handle`
/// when dropped, and exposes it read-only via `as_ptr`. Any code constructing the
/// wrapper must hand it either null or a pointer whose retain count it owns.
macro_rules! opaque_handle {
    ($name:ident) => {
        /// Owning handle to the corresponding `MPSGraph` framework object.
        ///
        /// The wrapped object is released when this value is dropped.
        pub struct $name {
            ptr: *mut c_void,
        }

        // SAFETY: the wrapper is only an owned retain count on a bridged object and
        // only passes the pointer to the FFI bridge. NOTE(review): soundness relies on
        // the underlying `MPSGraph` object being usable and releasable from any
        // thread; confirm this against the framework's threading guarantees.
        unsafe impl Send for $name {}
        // SAFETY: same assumption as the `Send` impl. The generated API never
        // mutates `ptr` through a shared reference.
        unsafe impl Sync for $name {}

        impl Drop for $name {
            fn drop(&mut self) {
                release_handle(&mut self.ptr);
            }
        }

        impl $name {
            /// Returns the raw object pointer without transferring ownership.
            ///
            /// The pointer stays valid only while `self` is alive. Callers must not
            /// release it themselves.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }
    };
}
88
/// Raw values of the framework's `MPSGraphExecutionStage` enumeration.
pub mod execution_stage {
    /// `Completed` stage.
    pub const COMPLETED: u64 = 0;
}
94
/// Raw values of the framework's `MPSGraphReductionMode` enumeration.
pub mod reduction_mode {
    /// `Min` case.
    pub const MIN: usize = 0;
    /// `Max` case.
    pub const MAX: usize = 1;
    /// `Sum` case.
    pub const SUM: usize = 2;
    /// `Product` case.
    pub const PRODUCT: usize = 3;
    /// `ArgumentMin` case.
    pub const ARGUMENT_MIN: usize = 4;
    /// `ArgumentMax` case.
    pub const ARGUMENT_MAX: usize = 5;
}
110
/// Raw values of the framework's `MPSGraphPoolingReturnIndicesMode` enumeration.
pub mod pooling_return_indices_mode {
    /// `None` case.
    pub const NONE: usize = 0;
    /// `GlobalFlatten1D` case.
    pub const GLOBAL_FLATTEN_1D: usize = 1;
    /// `GlobalFlatten2D` case.
    pub const GLOBAL_FLATTEN_2D: usize = 2;
    /// `GlobalFlatten3D` case.
    pub const GLOBAL_FLATTEN_3D: usize = 3;
    /// `GlobalFlatten4D` case.
    pub const GLOBAL_FLATTEN_4D: usize = 4;
    /// `LocalFlatten1D` case.
    pub const LOCAL_FLATTEN_1D: usize = 5;
    /// `LocalFlatten2D` case.
    pub const LOCAL_FLATTEN_2D: usize = 6;
    /// `LocalFlatten3D` case.
    pub const LOCAL_FLATTEN_3D: usize = 7;
    /// `LocalFlatten4D` case.
    pub const LOCAL_FLATTEN_4D: usize = 8;
}
132
/// Raw values of the framework's `MPSGraphFFTScalingMode` enumeration.
pub mod fft_scaling_mode {
    /// `None` case.
    pub const NONE: usize = 0;
    /// `Size` case.
    pub const SIZE: usize = 1;
    /// `Unitary` case.
    pub const UNITARY: usize = 2;
}
142
/// Raw values of the framework's `MPSGraphLossReductionType` enumeration.
pub mod loss_reduction_type {
    /// `None` case.
    pub const NONE: u64 = 0;
    /// `Axis` case. It shares its raw value with [`NONE`].
    pub const AXIS: u64 = NONE;
    /// `Sum` case.
    pub const SUM: u64 = 1;
    /// `Mean` case.
    pub const MEAN: u64 = 2;
}
154
/// Raw values of the framework's `MPSGraphNonMaximumSuppressionCoordinateMode` enumeration.
pub mod non_maximum_suppression_coordinate_mode {
    /// `CornersHeightFirst` case.
    pub const CORNERS_HEIGHT_FIRST: usize = 0;
    /// `CornersWidthFirst` case.
    pub const CORNERS_WIDTH_FIRST: usize = 1;
    /// `CentersHeightFirst` case.
    pub const CENTERS_HEIGHT_FIRST: usize = 2;
    /// `CentersWidthFirst` case.
    pub const CENTERS_WIDTH_FIRST: usize = 3;
}
166
/// Raw values of the framework's `MPSGraphResizeMode` enumeration.
pub mod resize_mode {
    /// `Nearest` case.
    pub const NEAREST: usize = 0;
    /// `Bilinear` case.
    pub const BILINEAR: usize = 1;
}
174
/// Raw values of the framework's `MPSGraphResizeNearestRoundingMode` enumeration.
pub mod resize_nearest_rounding_mode {
    /// `RoundPreferCeil` case.
    pub const ROUND_PREFER_CEIL: usize = 0;
    /// `RoundPreferFloor` case.
    pub const ROUND_PREFER_FLOOR: usize = 1;
    /// `Ceil` case.
    pub const CEIL: usize = 2;
    /// `Floor` case.
    pub const FLOOR: usize = 3;
    /// `RoundToEven` case.
    pub const ROUND_TO_EVEN: usize = 4;
    /// `RoundToOdd` case.
    pub const ROUND_TO_ODD: usize = 5;
}
190
/// Raw values of the framework's `MPSGraphScatterMode` enumeration (signed).
pub mod scatter_mode {
    /// `Add` case.
    pub const ADD: isize = 0;
    /// `Sub` case.
    pub const SUB: isize = 1;
    /// `Mul` case.
    pub const MUL: isize = 2;
    /// `Div` case.
    pub const DIV: isize = 3;
    /// `Min` case.
    pub const MIN: isize = 4;
    /// `Max` case.
    pub const MAX: isize = 5;
    /// `Set` case.
    pub const SET: isize = 6;
}
208
/// Raw values of the framework's `MPSGraphSparseStorageType` enumeration.
pub mod sparse_storage_type {
    /// `COO` storage.
    pub const COO: u64 = 0;
    /// `CSC` storage.
    pub const CSC: u64 = 1;
    /// `CSR` storage.
    pub const CSR: u64 = 2;
}
218
219opaque_handle!(Object);
220impl Object {
221    fn retain_from(ptr: *mut c_void) -> Self {
222        // SAFETY: `ptr` belongs to a live `MPSGraphObject` subclass and the bridge retains it for this wrapper.
223        let ptr = unsafe { ffi::mpsgraph_object_retain(ptr) };
224        Self { ptr }
225    }
226}
227
228opaque_handle!(GraphType);
229impl GraphType {
230    fn retain_from(ptr: *mut c_void) -> Self {
231        // SAFETY: `ptr` belongs to a live `MPSGraphType` subclass and the bridge retains it for this wrapper.
232        let ptr = unsafe { ffi::mpsgraph_object_retain(ptr) };
233        Self { ptr }
234    }
235
236/// Calls the `MPSGraph` framework counterpart for `as_object`.
237    #[must_use]
238    pub fn as_object(&self) -> Object {
239        Object::retain_from(self.ptr)
240    }
241}
242
243opaque_handle!(VariableOp);
244impl VariableOp {
245/// Calls the `MPSGraph` framework counterpart for `shape`.
246    #[must_use]
247    pub fn shape(&self) -> Vec<isize> {
248        // SAFETY: `self.ptr` is a live variable-op handle.
249        let len = unsafe { ffi::mpsgraph_variable_op_shape_len(self.ptr) };
250        let mut shape = vec![0_isize; len];
251        if len > 0 {
252            // SAFETY: `shape` has space for exactly `len` elements.
253            unsafe { ffi::mpsgraph_variable_op_copy_shape(self.ptr, shape.as_mut_ptr()) };
254        }
255        shape
256    }
257
258/// Calls the `MPSGraph` framework counterpart for `data_type`.
259    #[must_use]
260    pub fn data_type(&self) -> u32 {
261        // SAFETY: `self.ptr` is a live variable-op handle.
262        unsafe { ffi::mpsgraph_variable_op_data_type(self.ptr) }
263    }
264
265/// Calls the `MPSGraph` framework counterpart for `as_object`.
266    #[must_use]
267    pub fn as_object(&self) -> Object {
268        Object::retain_from(self.ptr)
269    }
270
271/// Calls the `MPSGraph` framework counterpart for `as_operation`.
272    #[must_use]
273    pub fn as_operation(&self) -> Operation {
274        // SAFETY: `self.ptr` is a live variable-op handle and retains as an operation wrapper.
275        let ptr = unsafe { ffi::mpsgraph_object_retain(self.ptr) };
276        Operation::from_raw(ptr)
277    }
278}
279
280impl ShapedType {
281/// Calls the `MPSGraph` framework counterpart for `as_graph_type`.
282    #[must_use]
283    pub fn as_graph_type(&self) -> GraphType {
284        GraphType::retain_from(self.as_ptr())
285    }
286}
287
288impl Operation {
289/// Calls the `MPSGraph` framework counterpart for `as_variable`.
290    #[must_use]
291    pub fn as_variable(&self) -> Option<VariableOp> {
292        // SAFETY: `self.ptr` is a live operation handle.
293        let ptr = unsafe { ffi::mpsgraph_operation_as_variable(self.as_ptr()) };
294        if ptr.is_null() {
295            None
296        } else {
297            Some(VariableOp { ptr })
298        }
299    }
300}
301
/// Plain-data configuration for building a `Convolution3DDescriptor`.
///
/// Every field is forwarded unchanged to the bridge when the descriptor is created.
/// The type derives `PartialEq`, `Eq` and `Hash` so configurations can be compared
/// or used as cache keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Convolution3DDescriptorInfo {
    /// Stride along the x axis.
    pub stride_in_x: usize,
    /// Stride along the y axis.
    pub stride_in_y: usize,
    /// Stride along the z axis.
    pub stride_in_z: usize,
    /// Dilation rate along the x axis.
    pub dilation_rate_in_x: usize,
    /// Dilation rate along the y axis.
    pub dilation_rate_in_y: usize,
    /// Dilation rate along the z axis.
    pub dilation_rate_in_z: usize,
    /// Number of convolution groups.
    pub groups: usize,
    /// Explicit padding on the left edge.
    pub padding_left: usize,
    /// Explicit padding on the right edge.
    pub padding_right: usize,
    /// Explicit padding on the top edge.
    pub padding_top: usize,
    /// Explicit padding on the bottom edge.
    pub padding_bottom: usize,
    /// Explicit padding on the front edge.
    pub padding_front: usize,
    /// Explicit padding on the back edge.
    pub padding_back: usize,
    /// Padding style, as a `padding_style` constant (defaults to `EXPLICIT`).
    pub padding_style: usize,
    /// Data layout, as a `tensor_named_data_layout` constant (defaults to `NDHWC`).
    pub data_layout: usize,
    /// Weights layout, as a `tensor_named_data_layout` constant (defaults to `DHWIO`).
    pub weights_layout: usize,
}
338
339impl Default for Convolution3DDescriptorInfo {
340    fn default() -> Self {
341        Self {
342            stride_in_x: 1,
343            stride_in_y: 1,
344            stride_in_z: 1,
345            dilation_rate_in_x: 1,
346            dilation_rate_in_y: 1,
347            dilation_rate_in_z: 1,
348            groups: 1,
349            padding_left: 0,
350            padding_right: 0,
351            padding_top: 0,
352            padding_bottom: 0,
353            padding_front: 0,
354            padding_back: 0,
355            padding_style: padding_style::EXPLICIT,
356            data_layout: tensor_named_data_layout::NDHWC,
357            weights_layout: tensor_named_data_layout::DHWIO,
358        }
359    }
360}
361
362opaque_handle!(Convolution3DDescriptor);
363impl Convolution3DDescriptor {
364/// Calls the `MPSGraph` framework counterpart for `new`.
365    #[must_use]
366    pub fn new(info: Convolution3DDescriptorInfo) -> Option<Self> {
367        // SAFETY: all arguments are POD configuration values.
368        let ptr = unsafe {
369            ffi::mpsgraph_convolution3d_descriptor_new(
370                info.stride_in_x,
371                info.stride_in_y,
372                info.stride_in_z,
373                info.dilation_rate_in_x,
374                info.dilation_rate_in_y,
375                info.dilation_rate_in_z,
376                info.groups,
377                info.padding_left,
378                info.padding_right,
379                info.padding_top,
380                info.padding_bottom,
381                info.padding_front,
382                info.padding_back,
383                info.padding_style,
384                info.data_layout,
385                info.weights_layout,
386            )
387        };
388        if ptr.is_null() {
389            None
390        } else {
391            Some(Self { ptr })
392        }
393    }
394}
395
/// Plain-data configuration for building a `DepthwiseConvolution2DDescriptor`.
///
/// Every field is forwarded unchanged to the bridge when the descriptor is created.
/// The type derives `PartialEq`, `Eq` and `Hash` so configurations can be compared
/// or used as cache keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DepthwiseConvolution2DDescriptorInfo {
    /// Stride along the x axis.
    pub stride_in_x: usize,
    /// Stride along the y axis.
    pub stride_in_y: usize,
    /// Dilation rate along the x axis.
    pub dilation_rate_in_x: usize,
    /// Dilation rate along the y axis.
    pub dilation_rate_in_y: usize,
    /// Explicit padding on the left edge.
    pub padding_left: usize,
    /// Explicit padding on the right edge.
    pub padding_right: usize,
    /// Explicit padding on the top edge.
    pub padding_top: usize,
    /// Explicit padding on the bottom edge.
    pub padding_bottom: usize,
    /// Padding style, as a `padding_style` constant (defaults to `EXPLICIT`).
    pub padding_style: usize,
    /// Data layout, as a `tensor_named_data_layout` constant (defaults to `NHWC`).
    pub data_layout: usize,
    /// Weights layout, as a `tensor_named_data_layout` constant (defaults to `HWIO`).
    pub weights_layout: usize,
}
422
423impl Default for DepthwiseConvolution2DDescriptorInfo {
424    fn default() -> Self {
425        Self {
426            stride_in_x: 1,
427            stride_in_y: 1,
428            dilation_rate_in_x: 1,
429            dilation_rate_in_y: 1,
430            padding_left: 0,
431            padding_right: 0,
432            padding_top: 0,
433            padding_bottom: 0,
434            padding_style: padding_style::EXPLICIT,
435            data_layout: tensor_named_data_layout::NHWC,
436            weights_layout: tensor_named_data_layout::HWIO,
437        }
438    }
439}
440
441opaque_handle!(DepthwiseConvolution2DDescriptor);
442impl DepthwiseConvolution2DDescriptor {
443/// Calls the `MPSGraph` framework counterpart for `new`.
444    #[must_use]
445    pub fn new(info: DepthwiseConvolution2DDescriptorInfo) -> Option<Self> {
446        // SAFETY: all arguments are POD configuration values.
447        let ptr = unsafe {
448            ffi::mpsgraph_depthwise_convolution2d_descriptor_new(
449                info.stride_in_x,
450                info.stride_in_y,
451                info.dilation_rate_in_x,
452                info.dilation_rate_in_y,
453                info.padding_left,
454                info.padding_right,
455                info.padding_top,
456                info.padding_bottom,
457                info.padding_style,
458                info.data_layout,
459                info.weights_layout,
460            )
461        };
462        if ptr.is_null() {
463            None
464        } else {
465            Some(Self { ptr })
466        }
467    }
468}
469
/// Plain-data configuration for building a `DepthwiseConvolution3DDescriptor`.
///
/// The arrays are passed to the bridge as pointer/length pairs. The type derives
/// `PartialEq`, `Eq` and `Hash` so configurations can be compared or used as cache
/// keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DepthwiseConvolution3DDescriptorInfo {
    /// Three stride values, one per spatial axis.
    pub strides: [usize; 3],
    /// Three dilation rates, one per spatial axis.
    pub dilation_rates: [usize; 3],
    /// Six explicit padding amounts. NOTE(review): presumably a (before, after)
    /// pair per spatial axis; confirm the ordering against the bridge.
    pub padding_values: [usize; 6],
    /// Padding style, as a `padding_style` constant (defaults to `EXPLICIT`).
    pub padding_style: usize,
    /// Index of the channel axis (defaults to `-1`). NOTE(review): negative values
    /// presumably count from the last axis; confirm with the framework.
    pub channel_dimension_index: isize,
}
484
485impl Default for DepthwiseConvolution3DDescriptorInfo {
486    fn default() -> Self {
487        Self {
488            strides: [1, 1, 1],
489            dilation_rates: [1, 1, 1],
490            padding_values: [0, 0, 0, 0, 0, 0],
491            padding_style: padding_style::EXPLICIT,
492            channel_dimension_index: -1,
493        }
494    }
495}
496
497opaque_handle!(DepthwiseConvolution3DDescriptor);
498impl DepthwiseConvolution3DDescriptor {
499/// Calls the `MPSGraph` framework counterpart for `new`.
500    #[must_use]
501    pub fn new(info: DepthwiseConvolution3DDescriptorInfo) -> Option<Self> {
502        // SAFETY: all slices stay alive for the duration of the call.
503        let ptr = unsafe {
504            ffi::mpsgraph_depthwise_convolution3d_descriptor_new(
505                info.strides.as_ptr(),
506                info.strides.len(),
507                info.dilation_rates.as_ptr(),
508                info.dilation_rates.len(),
509                info.padding_values.as_ptr(),
510                info.padding_values.len(),
511                info.padding_style,
512                info.channel_dimension_index,
513            )
514        };
515        if ptr.is_null() {
516            None
517        } else {
518            Some(Self { ptr })
519        }
520    }
521}
522
/// Plain-data configuration for building an `FftDescriptor`.
///
/// Every field is forwarded unchanged to the bridge when the descriptor is created.
/// The type derives `PartialEq`, `Eq` and `Hash` so configurations can be compared
/// or used as cache keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FftDescriptorInfo {
    /// Whether the transform is the inverse transform (defaults to `false`).
    pub inverse: bool,
    /// Scaling mode, as an `fft_scaling_mode` constant (defaults to `NONE`).
    pub scaling_mode: usize,
    /// Framework `roundToOddHermitean` flag (defaults to `false`).
    pub round_to_odd_hermitean: bool,
}
533
534impl Default for FftDescriptorInfo {
535    fn default() -> Self {
536        Self {
537            inverse: false,
538            scaling_mode: fft_scaling_mode::NONE,
539            round_to_odd_hermitean: false,
540        }
541    }
542}
543
544opaque_handle!(FftDescriptor);
545impl FftDescriptor {
546/// Calls the `MPSGraph` framework counterpart for `new`.
547    #[must_use]
548    pub fn new(info: FftDescriptorInfo) -> Option<Self> {
549        // SAFETY: all arguments are POD configuration values.
550        let ptr = unsafe {
551            ffi::mpsgraph_fft_descriptor_new(
552                info.inverse,
553                info.scaling_mode,
554                info.round_to_odd_hermitean,
555            )
556        };
557        if ptr.is_null() {
558            None
559        } else {
560            Some(Self { ptr })
561        }
562    }
563}
564
/// Plain-data configuration for building an `ImToColDescriptor`.
///
/// Every field is forwarded unchanged to the bridge when the descriptor is created.
/// The type derives `PartialEq`, `Eq` and `Hash` so configurations can be compared
/// or used as cache keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ImToColDescriptorInfo {
    /// Kernel width.
    pub kernel_width: usize,
    /// Kernel height.
    pub kernel_height: usize,
    /// Stride along the x axis.
    pub stride_in_x: usize,
    /// Stride along the y axis.
    pub stride_in_y: usize,
    /// Dilation rate along the x axis.
    pub dilation_rate_in_x: usize,
    /// Dilation rate along the y axis.
    pub dilation_rate_in_y: usize,
    /// Explicit padding on the left edge.
    pub padding_left: usize,
    /// Explicit padding on the right edge.
    pub padding_right: usize,
    /// Explicit padding on the top edge.
    pub padding_top: usize,
    /// Explicit padding on the bottom edge.
    pub padding_bottom: usize,
    /// Data layout, as a `tensor_named_data_layout` constant (defaults to `NHWC`).
    pub data_layout: usize,
}
591
592impl Default for ImToColDescriptorInfo {
593    fn default() -> Self {
594        Self {
595            kernel_width: 1,
596            kernel_height: 1,
597            stride_in_x: 1,
598            stride_in_y: 1,
599            dilation_rate_in_x: 1,
600            dilation_rate_in_y: 1,
601            padding_left: 0,
602            padding_right: 0,
603            padding_top: 0,
604            padding_bottom: 0,
605            data_layout: tensor_named_data_layout::NHWC,
606        }
607    }
608}
609
610opaque_handle!(ImToColDescriptor);
611impl ImToColDescriptor {
612/// Calls the `MPSGraph` framework counterpart for `new`.
613    #[must_use]
614    pub fn new(info: ImToColDescriptorInfo) -> Option<Self> {
615        // SAFETY: all arguments are POD configuration values.
616        let ptr = unsafe {
617            ffi::mpsgraph_im_to_col_descriptor_new(
618                info.kernel_width,
619                info.kernel_height,
620                info.stride_in_x,
621                info.stride_in_y,
622                info.dilation_rate_in_x,
623                info.dilation_rate_in_y,
624                info.padding_left,
625                info.padding_right,
626                info.padding_top,
627                info.padding_bottom,
628                info.data_layout,
629            )
630        };
631        if ptr.is_null() {
632            None
633        } else {
634            Some(Self { ptr })
635        }
636    }
637}
638
/// Plain-data configuration for building a `Pooling4DDescriptor`.
///
/// The arrays are passed to the bridge as pointer/length pairs. The type derives
/// `PartialEq`, `Eq` and `Hash` so configurations can be compared or used as cache
/// keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Pooling4DDescriptorInfo {
    /// Four kernel sizes, one per pooled axis.
    pub kernel_sizes: [usize; 4],
    /// Four stride values, one per pooled axis.
    pub strides: [usize; 4],
    /// Four dilation rates, one per pooled axis.
    pub dilation_rates: [usize; 4],
    /// Eight explicit padding amounts. NOTE(review): presumably a (before, after)
    /// pair per axis; confirm the ordering against the bridge.
    pub padding_values: [usize; 8],
    /// Padding style, as a `padding_style` constant (defaults to `EXPLICIT`).
    pub padding_style: usize,
    /// Framework `ceilMode` flag (defaults to `false`).
    pub ceil_mode: bool,
    /// Framework `includeZeroPadToAverage` flag (defaults to `false`).
    pub include_zero_pad_to_average: bool,
    /// Index-return mode, as a `pooling_return_indices_mode` constant (defaults to `NONE`).
    pub return_indices_mode: usize,
    /// Data type for returned indices, as a `data_type` constant (defaults to `INT32`).
    pub return_indices_data_type: u32,
}
661
662impl Default for Pooling4DDescriptorInfo {
663    fn default() -> Self {
664        Self {
665            kernel_sizes: [1, 1, 1, 1],
666            strides: [1, 1, 1, 1],
667            dilation_rates: [1, 1, 1, 1],
668            padding_values: [0, 0, 0, 0, 0, 0, 0, 0],
669            padding_style: padding_style::EXPLICIT,
670            ceil_mode: false,
671            include_zero_pad_to_average: false,
672            return_indices_mode: pooling_return_indices_mode::NONE,
673            return_indices_data_type: data_type::INT32,
674        }
675    }
676}
677
678opaque_handle!(Pooling4DDescriptor);
679impl Pooling4DDescriptor {
680/// Calls the `MPSGraph` framework counterpart for `new`.
681    #[must_use]
682    pub fn new(info: Pooling4DDescriptorInfo) -> Option<Self> {
683        // SAFETY: all slices stay alive for the duration of the call.
684        let ptr = unsafe {
685            ffi::mpsgraph_pooling4d_descriptor_new(
686                info.kernel_sizes.as_ptr(),
687                info.kernel_sizes.len(),
688                info.strides.as_ptr(),
689                info.strides.len(),
690                info.dilation_rates.as_ptr(),
691                info.dilation_rates.len(),
692                info.padding_values.as_ptr(),
693                info.padding_values.len(),
694                info.padding_style,
695                info.ceil_mode,
696                info.include_zero_pad_to_average,
697                info.return_indices_mode,
698                info.return_indices_data_type,
699            )
700        };
701        if ptr.is_null() {
702            None
703        } else {
704            Some(Self { ptr })
705        }
706    }
707}
708
709opaque_handle!(CreateSparseDescriptor);
710impl CreateSparseDescriptor {
711/// Calls the `MPSGraph` framework counterpart for `new`.
712    #[must_use]
713    pub fn new(storage_type: u64, data_type: u32) -> Option<Self> {
714        // SAFETY: all arguments are POD configuration values.
715        let ptr = unsafe { ffi::mpsgraph_sparse_descriptor_new(storage_type, data_type) };
716        if ptr.is_null() {
717            None
718        } else {
719            Some(Self { ptr })
720        }
721    }
722}
723
/// Plain-data configuration for building a `StencilDescriptor`.
///
/// The arrays are passed to the bridge as pointer/length pairs. The type derives
/// `PartialEq` only, not `Eq`/`Hash`, because `padding_constant` is an `f32`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct StencilDescriptorInfo {
    /// Reduction, as a `reduction_mode` constant (defaults to `SUM`).
    pub reduction_mode: usize,
    /// Four signed offsets.
    pub offsets: [isize; 4],
    /// Four stride values.
    pub strides: [usize; 4],
    /// Four dilation rates.
    pub dilation_rates: [usize; 4],
    /// Eight explicit padding amounts. NOTE(review): presumably a (before, after)
    /// pair per axis; confirm the ordering against the bridge.
    pub explicit_padding: [usize; 8],
    /// Boundary handling, as a `padding_mode` constant (defaults to `ZERO`).
    pub boundary_mode: isize,
    /// Padding style, as a `padding_style` constant (defaults to `EXPLICIT`).
    pub padding_style: usize,
    /// Constant padding value (defaults to `0.0`).
    pub padding_constant: f32,
}
744
745impl Default for StencilDescriptorInfo {
746    fn default() -> Self {
747        Self {
748            reduction_mode: reduction_mode::SUM,
749            offsets: [0, 0, 0, 0],
750            strides: [1, 1, 1, 1],
751            dilation_rates: [1, 1, 1, 1],
752            explicit_padding: [0, 0, 0, 0, 0, 0, 0, 0],
753            boundary_mode: padding_mode::ZERO,
754            padding_style: padding_style::EXPLICIT,
755            padding_constant: 0.0,
756        }
757    }
758}
759
760opaque_handle!(StencilDescriptor);
761impl StencilDescriptor {
762/// Calls the `MPSGraph` framework counterpart for `new`.
763    #[must_use]
764    pub fn new(info: StencilDescriptorInfo) -> Option<Self> {
765        // SAFETY: all slices stay alive for the duration of the call.
766        let ptr = unsafe {
767            ffi::mpsgraph_stencil_descriptor_new(
768                info.reduction_mode,
769                info.offsets.as_ptr(),
770                info.offsets.len(),
771                info.strides.as_ptr(),
772                info.strides.len(),
773                info.dilation_rates.as_ptr(),
774                info.dilation_rates.len(),
775                info.explicit_padding.as_ptr(),
776                info.explicit_padding.len(),
777                info.boundary_mode,
778                info.padding_style,
779                info.padding_constant,
780            )
781        };
782        if ptr.is_null() {
783            None
784        } else {
785            Some(Self { ptr })
786        }
787    }
788}
789
790impl Graph {
791/// Calls the `MPSGraph` framework counterpart for `convolution3d`.
792    #[must_use]
793    pub fn convolution3d(
794        &self,
795        source: &Tensor,
796        weights: &Tensor,
797        descriptor: &Convolution3DDescriptor,
798        name: Option<&str>,
799    ) -> Option<Tensor> {
800        let name = optional_cstring(name);
801        // SAFETY: all handles remain valid for the duration of the call.
802        let ptr = unsafe {
803            ffi::mpsgraph_graph_convolution3d(
804                self.as_ptr(),
805                source.as_ptr(),
806                weights.as_ptr(),
807                descriptor.as_ptr(),
808                cstring_ptr(&name),
809            )
810        };
811        wrap_tensor(ptr)
812    }
813
814/// Calls the `MPSGraph` framework counterpart for `convolution_transpose2d`.
815    #[must_use]
816    pub fn convolution_transpose2d(
817        &self,
818        source: &Tensor,
819        weights: &Tensor,
820        output_shape: &[usize],
821        descriptor: &Convolution2DDescriptor,
822        name: Option<&str>,
823    ) -> Option<Tensor> {
824        let name = optional_cstring(name);
825        // SAFETY: all handles and slices remain valid for the duration of the call.
826        let ptr = unsafe {
827            ffi::mpsgraph_graph_convolution_transpose2d(
828                self.as_ptr(),
829                source.as_ptr(),
830                weights.as_ptr(),
831                output_shape.as_ptr(),
832                output_shape.len(),
833                descriptor.as_ptr(),
834                cstring_ptr(&name),
835            )
836        };
837        wrap_tensor(ptr)
838    }
839
840/// Calls the `MPSGraph` framework counterpart for `cumulative_sum`.
841    #[must_use]
842    pub fn cumulative_sum(
843        &self,
844        tensor: &Tensor,
845        axis: isize,
846        exclusive: bool,
847        reverse: bool,
848        name: Option<&str>,
849    ) -> Option<Tensor> {
850        let name = optional_cstring(name);
851        // SAFETY: all handles remain valid for the duration of the call.
852        let ptr = unsafe {
853            ffi::mpsgraph_graph_cumulative_sum(
854                self.as_ptr(),
855                tensor.as_ptr(),
856                axis,
857                exclusive,
858                reverse,
859                cstring_ptr(&name),
860            )
861        };
862        wrap_tensor(ptr)
863    }
864
865/// Calls the `MPSGraph` framework counterpart for `depthwise_convolution2d`.
866    #[must_use]
867    pub fn depthwise_convolution2d(
868        &self,
869        source: &Tensor,
870        weights: &Tensor,
871        descriptor: &DepthwiseConvolution2DDescriptor,
872        name: Option<&str>,
873    ) -> Option<Tensor> {
874        let name = optional_cstring(name);
875        // SAFETY: all handles remain valid for the duration of the call.
876        let ptr = unsafe {
877            ffi::mpsgraph_graph_depthwise_convolution2d(
878                self.as_ptr(),
879                source.as_ptr(),
880                weights.as_ptr(),
881                descriptor.as_ptr(),
882                cstring_ptr(&name),
883            )
884        };
885        wrap_tensor(ptr)
886    }
887
888/// Calls the `MPSGraph` framework counterpart for `depthwise_convolution3d`.
889    #[must_use]
890    pub fn depthwise_convolution3d(
891        &self,
892        source: &Tensor,
893        weights: &Tensor,
894        descriptor: &DepthwiseConvolution3DDescriptor,
895        name: Option<&str>,
896    ) -> Option<Tensor> {
897        let name = optional_cstring(name);
898        // SAFETY: all handles remain valid for the duration of the call.
899        let ptr = unsafe {
900            ffi::mpsgraph_graph_depthwise_convolution3d(
901                self.as_ptr(),
902                source.as_ptr(),
903                weights.as_ptr(),
904                descriptor.as_ptr(),
905                cstring_ptr(&name),
906            )
907        };
908        wrap_tensor(ptr)
909    }
910
911/// Calls the `MPSGraph` framework counterpart for `fast_fourier_transform`.
912    #[must_use]
913    pub fn fast_fourier_transform(
914        &self,
915        tensor: &Tensor,
916        axes: &[usize],
917        descriptor: &FftDescriptor,
918        name: Option<&str>,
919    ) -> Option<Tensor> {
920        let name = optional_cstring(name);
921        // SAFETY: all handles and slices remain valid for the duration of the call.
922        let ptr = unsafe {
923            ffi::mpsgraph_graph_fast_fourier_transform(
924                self.as_ptr(),
925                tensor.as_ptr(),
926                axes.as_ptr(),
927                axes.len(),
928                descriptor.as_ptr(),
929                cstring_ptr(&name),
930            )
931        };
932        wrap_tensor(ptr)
933    }
934
935/// Calls the `MPSGraph` framework counterpart for `im_to_col`.
936    #[must_use]
937    pub fn im_to_col(
938        &self,
939        source: &Tensor,
940        descriptor: &ImToColDescriptor,
941        name: Option<&str>,
942    ) -> Option<Tensor> {
943        let name = optional_cstring(name);
944        // SAFETY: all handles remain valid for the duration of the call.
945        let ptr = unsafe {
946            ffi::mpsgraph_graph_im_to_col(
947                self.as_ptr(),
948                source.as_ptr(),
949                descriptor.as_ptr(),
950                cstring_ptr(&name),
951            )
952        };
953        wrap_tensor(ptr)
954    }
955
956/// Calls the `MPSGraph` framework counterpart for `band_part`.
957    #[must_use]
958    pub fn band_part(
959        &self,
960        tensor: &Tensor,
961        num_lower: isize,
962        num_upper: isize,
963        name: Option<&str>,
964    ) -> Option<Tensor> {
965        let name = optional_cstring(name);
966        // SAFETY: all handles remain valid for the duration of the call.
967        let ptr = unsafe {
968            ffi::mpsgraph_graph_band_part(
969                self.as_ptr(),
970                tensor.as_ptr(),
971                num_lower,
972                num_upper,
973                cstring_ptr(&name),
974            )
975        };
976        wrap_tensor(ptr)
977    }
978
979/// Calls the `MPSGraph` framework counterpart for `softmax_cross_entropy`.
980    #[must_use]
981    pub fn softmax_cross_entropy(
982        &self,
983        source: &Tensor,
984        labels: &Tensor,
985        axis: isize,
986        reduction_type: u64,
987        name: Option<&str>,
988    ) -> Option<Tensor> {
989        let name = optional_cstring(name);
990        // SAFETY: all handles remain valid for the duration of the call.
991        let ptr = unsafe {
992            ffi::mpsgraph_graph_softmax_cross_entropy(
993                self.as_ptr(),
994                source.as_ptr(),
995                labels.as_ptr(),
996                axis,
997                reduction_type,
998                cstring_ptr(&name),
999            )
1000        };
1001        wrap_tensor(ptr)
1002    }
1003
1004/// Calls the `MPSGraph` framework counterpart for `matrix_inverse`.
1005    #[must_use]
1006    pub fn matrix_inverse(&self, tensor: &Tensor, name: Option<&str>) -> Option<Tensor> {
1007        let name = optional_cstring(name);
1008        // SAFETY: all handles remain valid for the duration of the call.
1009        let ptr = unsafe {
1010            ffi::mpsgraph_graph_matrix_inverse(self.as_ptr(), tensor.as_ptr(), cstring_ptr(&name))
1011        };
1012        wrap_tensor(ptr)
1013    }
1014
1015/// Calls the `MPSGraph` framework counterpart for `variable_bytes`.
1016    #[must_use]
1017    pub fn variable_bytes(
1018        &self,
1019        data: &[u8],
1020        shape: &[usize],
1021        data_type: u32,
1022        name: Option<&str>,
1023    ) -> Option<Tensor> {
1024        let expected = checked_byte_len(shape, data_type)?;
1025        if data.len() != expected {
1026            return None;
1027        }
1028
1029        let name = optional_cstring(name);
1030        // SAFETY: all handles and slices remain valid for the duration of the call.
1031        let ptr = unsafe {
1032            ffi::mpsgraph_graph_variable_data(
1033                self.as_ptr(),
1034                data.as_ptr().cast(),
1035                data.len(),
1036                shape.as_ptr(),
1037                shape.len(),
1038                data_type,
1039                cstring_ptr(&name),
1040            )
1041        };
1042        wrap_tensor(ptr)
1043    }
1044
1045/// Calls the `MPSGraph` framework counterpart for `variable_f32_slice`.
1046    #[must_use]
1047    pub fn variable_f32_slice(
1048        &self,
1049        values: &[f32],
1050        shape: &[usize],
1051        name: Option<&str>,
1052    ) -> Option<Tensor> {
1053        // SAFETY: `values` is a contiguous slice of `f32` that may be viewed as bytes.
1054        let bytes = unsafe {
1055            core::slice::from_raw_parts(
1056                values.as_ptr().cast::<u8>(),
1057                core::mem::size_of_val(values),
1058            )
1059        };
1060        self.variable_bytes(bytes, shape, data_type::FLOAT32, name)
1061    }
1062
1063/// Calls the `MPSGraph` framework counterpart for `read_variable`.
1064    #[must_use]
1065    pub fn read_variable(&self, variable: &Tensor, name: Option<&str>) -> Option<Tensor> {
1066        let name = optional_cstring(name);
1067        // SAFETY: all handles remain valid for the duration of the call.
1068        let ptr = unsafe {
1069            ffi::mpsgraph_graph_read_variable(self.as_ptr(), variable.as_ptr(), cstring_ptr(&name))
1070        };
1071        wrap_tensor(ptr)
1072    }
1073
1074/// Calls the `MPSGraph` framework counterpart for `assign_variable`.
1075    #[must_use]
1076    pub fn assign_variable(
1077        &self,
1078        variable: &Tensor,
1079        value: &Tensor,
1080        name: Option<&str>,
1081    ) -> Option<Operation> {
1082        let name = optional_cstring(name);
1083        // SAFETY: all handles remain valid for the duration of the call.
1084        let ptr = unsafe {
1085            ffi::mpsgraph_graph_assign_variable(
1086                self.as_ptr(),
1087                variable.as_ptr(),
1088                value.as_ptr(),
1089                cstring_ptr(&name),
1090            )
1091        };
1092        wrap_operation(ptr)
1093    }
1094
1095/// Calls the `MPSGraph` framework counterpart for `non_maximum_suppression`.
1096    #[must_use]
1097    #[allow(clippy::too_many_arguments)]
1098    pub fn non_maximum_suppression(
1099        &self,
1100        boxes: &Tensor,
1101        scores: &Tensor,
1102        iou_threshold: f32,
1103        score_threshold: f32,
1104        per_class_suppression: bool,
1105        coordinate_mode: usize,
1106        name: Option<&str>,
1107    ) -> Option<Tensor> {
1108        let name = optional_cstring(name);
1109        // SAFETY: all handles remain valid for the duration of the call.
1110        let ptr = unsafe {
1111            ffi::mpsgraph_graph_non_maximum_suppression(
1112                self.as_ptr(),
1113                boxes.as_ptr(),
1114                scores.as_ptr(),
1115                iou_threshold,
1116                score_threshold,
1117                per_class_suppression,
1118                coordinate_mode,
1119                cstring_ptr(&name),
1120            )
1121        };
1122        wrap_tensor(ptr)
1123    }
1124
1125/// Calls the `MPSGraph` framework counterpart for `non_zero_indices`.
1126    #[must_use]
1127    pub fn non_zero_indices(&self, tensor: &Tensor, name: Option<&str>) -> Option<Tensor> {
1128        let name = optional_cstring(name);
1129        // SAFETY: all handles remain valid for the duration of the call.
1130        let ptr = unsafe {
1131            ffi::mpsgraph_graph_non_zero_indices(self.as_ptr(), tensor.as_ptr(), cstring_ptr(&name))
1132        };
1133        wrap_tensor(ptr)
1134    }
1135
1136/// Calls the `MPSGraph` framework counterpart for `one_hot`.
1137    #[must_use]
1138    pub fn one_hot(
1139        &self,
1140        indices: &Tensor,
1141        depth: usize,
1142        data_type: u32,
1143        name: Option<&str>,
1144    ) -> Option<Tensor> {
1145        let name = optional_cstring(name);
1146        // SAFETY: all handles remain valid for the duration of the call.
1147        let ptr = unsafe {
1148            ffi::mpsgraph_graph_one_hot(
1149                self.as_ptr(),
1150                indices.as_ptr(),
1151                depth,
1152                data_type,
1153                cstring_ptr(&name),
1154            )
1155        };
1156        wrap_tensor(ptr)
1157    }
1158
1159/// Calls the `MPSGraph` framework counterpart for `stochastic_gradient_descent`.
1160    #[must_use]
1161    pub fn stochastic_gradient_descent(
1162        &self,
1163        learning_rate: &Tensor,
1164        values: &Tensor,
1165        gradient: &Tensor,
1166        name: Option<&str>,
1167    ) -> Option<Tensor> {
1168        let name = optional_cstring(name);
1169        // SAFETY: all handles remain valid for the duration of the call.
1170        let ptr = unsafe {
1171            ffi::mpsgraph_graph_stochastic_gradient_descent(
1172                self.as_ptr(),
1173                learning_rate.as_ptr(),
1174                values.as_ptr(),
1175                gradient.as_ptr(),
1176                cstring_ptr(&name),
1177            )
1178        };
1179        wrap_tensor(ptr)
1180    }
1181
1182/// Calls the `MPSGraph` framework counterpart for `max_pooling4d`.
1183    #[must_use]
1184    pub fn max_pooling4d(
1185        &self,
1186        source: &Tensor,
1187        descriptor: &Pooling4DDescriptor,
1188        name: Option<&str>,
1189    ) -> Option<Tensor> {
1190        let name = optional_cstring(name);
1191        // SAFETY: all handles remain valid for the duration of the call.
1192        let ptr = unsafe {
1193            ffi::mpsgraph_graph_max_pooling4d(
1194                self.as_ptr(),
1195                source.as_ptr(),
1196                descriptor.as_ptr(),
1197                cstring_ptr(&name),
1198            )
1199        };
1200        wrap_tensor(ptr)
1201    }
1202
1203/// Calls the `MPSGraph` framework counterpart for `max_pooling4d_return_indices`.
1204    #[must_use]
1205    pub fn max_pooling4d_return_indices(
1206        &self,
1207        source: &Tensor,
1208        descriptor: &Pooling4DDescriptor,
1209        name: Option<&str>,
1210    ) -> Option<(Tensor, Tensor)> {
1211        let name = optional_cstring(name);
1212        // SAFETY: all handles remain valid for the duration of the call.
1213        let box_handle = unsafe {
1214            ffi::mpsgraph_graph_max_pooling4d_return_indices(
1215                self.as_ptr(),
1216                source.as_ptr(),
1217                descriptor.as_ptr(),
1218                cstring_ptr(&name),
1219            )
1220        };
1221        wrap_tensor_pair(box_handle)
1222    }
1223
1224/// Calls the `MPSGraph` framework counterpart for `quantize`.
1225    #[must_use]
1226    pub fn quantize(
1227        &self,
1228        tensor: &Tensor,
1229        scale: f64,
1230        zero_point: f64,
1231        data_type: u32,
1232        name: Option<&str>,
1233    ) -> Option<Tensor> {
1234        let name = optional_cstring(name);
1235        // SAFETY: all handles remain valid for the duration of the call.
1236        let ptr = unsafe {
1237            ffi::mpsgraph_graph_quantize(
1238                self.as_ptr(),
1239                tensor.as_ptr(),
1240                scale,
1241                zero_point,
1242                data_type,
1243                cstring_ptr(&name),
1244            )
1245        };
1246        wrap_tensor(ptr)
1247    }
1248
1249/// Calls the `MPSGraph` framework counterpart for `dequantize`.
1250    #[must_use]
1251    pub fn dequantize(
1252        &self,
1253        tensor: &Tensor,
1254        scale: f64,
1255        zero_point: f64,
1256        data_type: u32,
1257        name: Option<&str>,
1258    ) -> Option<Tensor> {
1259        let name = optional_cstring(name);
1260        // SAFETY: all handles remain valid for the duration of the call.
1261        let ptr = unsafe {
1262            ffi::mpsgraph_graph_dequantize(
1263                self.as_ptr(),
1264                tensor.as_ptr(),
1265                scale,
1266                zero_point,
1267                data_type,
1268                cstring_ptr(&name),
1269            )
1270        };
1271        wrap_tensor(ptr)
1272    }
1273
1274/// Calls the `MPSGraph` framework counterpart for `resize`.
1275    #[must_use]
1276    #[allow(clippy::too_many_arguments)]
1277    pub fn resize(
1278        &self,
1279        images: &Tensor,
1280        size: &[usize],
1281        mode: usize,
1282        center_result: bool,
1283        align_corners: bool,
1284        layout: usize,
1285        name: Option<&str>,
1286    ) -> Option<Tensor> {
1287        let name = optional_cstring(name);
1288        // SAFETY: all handles and slices remain valid for the duration of the call.
1289        let ptr = unsafe {
1290            ffi::mpsgraph_graph_resize(
1291                self.as_ptr(),
1292                images.as_ptr(),
1293                size.as_ptr(),
1294                size.len(),
1295                mode,
1296                center_result,
1297                align_corners,
1298                layout,
1299                cstring_ptr(&name),
1300            )
1301        };
1302        wrap_tensor(ptr)
1303    }
1304
1305/// Calls the `MPSGraph` framework counterpart for `resize_nearest`.
1306    #[must_use]
1307    #[allow(clippy::too_many_arguments)]
1308    pub fn resize_nearest(
1309        &self,
1310        images: &Tensor,
1311        size_tensor: &Tensor,
1312        nearest_rounding_mode: usize,
1313        center_result: bool,
1314        align_corners: bool,
1315        layout: usize,
1316        name: Option<&str>,
1317    ) -> Option<Tensor> {
1318        let name = optional_cstring(name);
1319        // SAFETY: all handles remain valid for the duration of the call.
1320        let ptr = unsafe {
1321            ffi::mpsgraph_graph_resize_nearest(
1322                self.as_ptr(),
1323                images.as_ptr(),
1324                size_tensor.as_ptr(),
1325                nearest_rounding_mode,
1326                center_result,
1327                align_corners,
1328                layout,
1329                cstring_ptr(&name),
1330            )
1331        };
1332        wrap_tensor(ptr)
1333    }
1334
1335/// Calls the `MPSGraph` framework counterpart for `sample_grid`.
1336    #[must_use]
1337    #[allow(clippy::too_many_arguments)]
1338    pub fn sample_grid(
1339        &self,
1340        source: &Tensor,
1341        coordinates: &Tensor,
1342        layout: usize,
1343        normalize_coordinates: bool,
1344        relative_coordinates: bool,
1345        align_corners: bool,
1346        padding_mode: isize,
1347        sampling_mode: usize,
1348        constant_value: f64,
1349        name: Option<&str>,
1350    ) -> Option<Tensor> {
1351        let name = optional_cstring(name);
1352        // SAFETY: all handles remain valid for the duration of the call.
1353        let ptr = unsafe {
1354            ffi::mpsgraph_graph_sample_grid(
1355                self.as_ptr(),
1356                source.as_ptr(),
1357                coordinates.as_ptr(),
1358                layout,
1359                normalize_coordinates,
1360                relative_coordinates,
1361                align_corners,
1362                padding_mode,
1363                sampling_mode,
1364                constant_value,
1365                cstring_ptr(&name),
1366            )
1367        };
1368        wrap_tensor(ptr)
1369    }
1370
1371/// Calls the `MPSGraph` framework counterpart for `scatter_nd`.
1372    #[must_use]
1373    pub fn scatter_nd(
1374        &self,
1375        updates: &Tensor,
1376        indices: &Tensor,
1377        shape: &[usize],
1378        batch_dimensions: usize,
1379        mode: isize,
1380        name: Option<&str>,
1381    ) -> Option<Tensor> {
1382        let name = optional_cstring(name);
1383        // SAFETY: all handles and slices remain valid for the duration of the call.
1384        let ptr = unsafe {
1385            ffi::mpsgraph_graph_scatter_nd(
1386                self.as_ptr(),
1387                updates.as_ptr(),
1388                indices.as_ptr(),
1389                shape.as_ptr(),
1390                shape.len(),
1391                batch_dimensions,
1392                mode,
1393                cstring_ptr(&name),
1394            )
1395        };
1396        wrap_tensor(ptr)
1397    }
1398
1399/// Calls the `MPSGraph` framework counterpart for `scatter`.
1400    #[must_use]
1401    pub fn scatter(
1402        &self,
1403        updates: &Tensor,
1404        indices: &Tensor,
1405        shape: &[usize],
1406        axis: isize,
1407        mode: isize,
1408        name: Option<&str>,
1409    ) -> Option<Tensor> {
1410        let name = optional_cstring(name);
1411        // SAFETY: all handles and slices remain valid for the duration of the call.
1412        let ptr = unsafe {
1413            ffi::mpsgraph_graph_scatter(
1414                self.as_ptr(),
1415                updates.as_ptr(),
1416                indices.as_ptr(),
1417                shape.as_ptr(),
1418                shape.len(),
1419                axis,
1420                mode,
1421                cstring_ptr(&name),
1422            )
1423        };
1424        wrap_tensor(ptr)
1425    }
1426
1427/// Calls the `MPSGraph` framework counterpart for `scatter_along_axis`.
1428    #[must_use]
1429    pub fn scatter_along_axis(
1430        &self,
1431        axis: isize,
1432        updates: &Tensor,
1433        indices: &Tensor,
1434        shape: &[usize],
1435        mode: isize,
1436        name: Option<&str>,
1437    ) -> Option<Tensor> {
1438        let name = optional_cstring(name);
1439        // SAFETY: all handles and slices remain valid for the duration of the call.
1440        let ptr = unsafe {
1441            ffi::mpsgraph_graph_scatter_along_axis(
1442                self.as_ptr(),
1443                axis,
1444                updates.as_ptr(),
1445                indices.as_ptr(),
1446                shape.as_ptr(),
1447                shape.len(),
1448                mode,
1449                cstring_ptr(&name),
1450            )
1451        };
1452        wrap_tensor(ptr)
1453    }
1454
1455/// Calls the `MPSGraph` framework counterpart for `sort`.
1456    #[must_use]
1457    pub fn sort(
1458        &self,
1459        tensor: &Tensor,
1460        axis: isize,
1461        descending: bool,
1462        name: Option<&str>,
1463    ) -> Option<Tensor> {
1464        let name = optional_cstring(name);
1465        // SAFETY: all handles remain valid for the duration of the call.
1466        let ptr = unsafe {
1467            ffi::mpsgraph_graph_sort(
1468                self.as_ptr(),
1469                tensor.as_ptr(),
1470                axis,
1471                descending,
1472                cstring_ptr(&name),
1473            )
1474        };
1475        wrap_tensor(ptr)
1476    }
1477
1478/// Calls the `MPSGraph` framework counterpart for `arg_sort`.
1479    #[must_use]
1480    pub fn arg_sort(
1481        &self,
1482        tensor: &Tensor,
1483        axis: isize,
1484        descending: bool,
1485        name: Option<&str>,
1486    ) -> Option<Tensor> {
1487        let name = optional_cstring(name);
1488        // SAFETY: all handles remain valid for the duration of the call.
1489        let ptr = unsafe {
1490            ffi::mpsgraph_graph_arg_sort(
1491                self.as_ptr(),
1492                tensor.as_ptr(),
1493                axis,
1494                descending,
1495                cstring_ptr(&name),
1496            )
1497        };
1498        wrap_tensor(ptr)
1499    }
1500
1501/// Calls the `MPSGraph` framework counterpart for `sparse_tensor_with_descriptor`.
1502    #[must_use]
1503    pub fn sparse_tensor_with_descriptor(
1504        &self,
1505        descriptor: &CreateSparseDescriptor,
1506        tensors: &[&Tensor],
1507        shape: &[usize],
1508        name: Option<&str>,
1509    ) -> Option<Tensor> {
1510        let name = optional_cstring(name);
1511        let handles = tensors
1512            .iter()
1513            .map(|tensor| tensor.as_ptr())
1514            .collect::<Vec<_>>();
1515        // SAFETY: all handles and slices remain valid for the duration of the call.
1516        let ptr = unsafe {
1517            ffi::mpsgraph_graph_sparse_tensor_with_descriptor(
1518                self.as_ptr(),
1519                descriptor.as_ptr(),
1520                handles.as_ptr(),
1521                handles.len(),
1522                shape.as_ptr(),
1523                shape.len(),
1524                cstring_ptr(&name),
1525            )
1526        };
1527        wrap_tensor(ptr)
1528    }
1529
1530/// Calls the `MPSGraph` framework counterpart for `stencil`.
1531    #[must_use]
1532    pub fn stencil(
1533        &self,
1534        source: &Tensor,
1535        weights: &Tensor,
1536        descriptor: &StencilDescriptor,
1537        name: Option<&str>,
1538    ) -> Option<Tensor> {
1539        let name = optional_cstring(name);
1540        // SAFETY: all handles remain valid for the duration of the call.
1541        let ptr = unsafe {
1542            ffi::mpsgraph_graph_stencil(
1543                self.as_ptr(),
1544                source.as_ptr(),
1545                weights.as_ptr(),
1546                descriptor.as_ptr(),
1547                cstring_ptr(&name),
1548            )
1549        };
1550        wrap_tensor(ptr)
1551    }
1552
1553/// Calls the `MPSGraph` framework counterpart for `top_k_gradient`.
1554    #[must_use]
1555    pub fn top_k_gradient(
1556        &self,
1557        gradient: &Tensor,
1558        source: &Tensor,
1559        k: usize,
1560        name: Option<&str>,
1561    ) -> Option<Tensor> {
1562        let name = optional_cstring(name);
1563        // SAFETY: all handles remain valid for the duration of the call.
1564        let ptr = unsafe {
1565            ffi::mpsgraph_graph_topk_gradient(
1566                self.as_ptr(),
1567                gradient.as_ptr(),
1568                source.as_ptr(),
1569                k,
1570                cstring_ptr(&name),
1571            )
1572        };
1573        wrap_tensor(ptr)
1574    }
1575}
1576
1577impl ExecutionDescriptor {
1578    /// # Safety
1579    ///
1580    /// `event_handle` must be a valid `id<MTLSharedEvent>` for the lifetime of the call.
1581    pub unsafe fn wait_for_shared_event_raw(
1582        &self,
1583        event_handle: *mut c_void,
1584        value: u64,
1585    ) -> Result<()> {
1586        // SAFETY: caller guarantees `event_handle` is a valid shared-event pointer.
1587        let ok = unsafe {
1588            ffi::mpsgraph_execution_descriptor_wait_for_event(self.as_ptr(), event_handle, value)
1589        };
1590        if ok {
1591            Ok(())
1592        } else {
1593            Err(Error::OperationFailed(
1594                "failed to register execution descriptor shared-event wait",
1595            ))
1596        }
1597    }
1598
1599    /// # Safety
1600    ///
1601    /// `event_handle` must be a valid `id<MTLSharedEvent>` for the lifetime of the call.
1602    pub unsafe fn signal_shared_event_raw(
1603        &self,
1604        event_handle: *mut c_void,
1605        execution_stage: u64,
1606        value: u64,
1607    ) -> Result<()> {
1608        // SAFETY: caller guarantees `event_handle` is a valid shared-event pointer.
1609        let ok = unsafe {
1610            ffi::mpsgraph_execution_descriptor_signal_event(
1611                self.as_ptr(),
1612                event_handle,
1613                execution_stage,
1614                value,
1615            )
1616        };
1617        if ok {
1618            Ok(())
1619        } else {
1620            Err(Error::OperationFailed(
1621                "failed to register execution descriptor shared-event signal",
1622            ))
1623        }
1624    }
1625}
1626
1627impl ExecutableExecutionDescriptor {
1628    /// # Safety
1629    ///
1630    /// `event_handle` must be a valid `id<MTLSharedEvent>` for the lifetime of the call.
1631    pub unsafe fn wait_for_shared_event_raw(
1632        &self,
1633        event_handle: *mut c_void,
1634        value: u64,
1635    ) -> Result<()> {
1636        // SAFETY: caller guarantees `event_handle` is a valid shared-event pointer.
1637        let ok = unsafe {
1638            ffi::mpsgraph_executable_execution_descriptor_wait_for_event(
1639                self.as_ptr(),
1640                event_handle,
1641                value,
1642            )
1643        };
1644        if ok {
1645            Ok(())
1646        } else {
1647            Err(Error::OperationFailed(
1648                "failed to register executable execution descriptor shared-event wait",
1649            ))
1650        }
1651    }
1652
1653    /// # Safety
1654    ///
1655    /// `event_handle` must be a valid `id<MTLSharedEvent>` for the lifetime of the call.
1656    pub unsafe fn signal_shared_event_raw(
1657        &self,
1658        event_handle: *mut c_void,
1659        execution_stage: u64,
1660        value: u64,
1661    ) -> Result<()> {
1662        // SAFETY: caller guarantees `event_handle` is a valid shared-event pointer.
1663        let ok = unsafe {
1664            ffi::mpsgraph_executable_execution_descriptor_signal_event(
1665                self.as_ptr(),
1666                event_handle,
1667                execution_stage,
1668                value,
1669            )
1670        };
1671        if ok {
1672            Ok(())
1673        } else {
1674            Err(Error::OperationFailed(
1675                "failed to register executable execution descriptor shared-event signal",
1676            ))
1677        }
1678    }
1679}