1use crate::data::TensorData;
2use crate::error::{Error, Result};
3use crate::ffi;
4use apple_metal::{CommandQueue, MetalDevice};
5use core::ffi::{c_char, c_void};
6use core::ptr;
7use std::ffi::CString;
8
/// Raw `u32` data-type codes passed across the MPSGraph FFI boundary.
///
/// Every code is a class flag in the high bits OR-ed with the element width in
/// bits in the low bits (for example `FLOAT32 == 0x1000_0000 | 32`).
/// NOTE(review): the flag layout appears to mirror Apple's `MPSDataType`
/// encoding — confirm against the FFI side.
pub mod data_type {
    /// High-bit flag shared by the floating-point codes.
    const FLOAT_BIT: u32 = 0x1000_0000;
    /// High-bit flag shared by the signed-integer codes.
    const SIGNED_BIT: u32 = 0x2000_0000;
    /// High-bit flag used by the normalized code (`UNORM8`).
    const NORMALIZED_BIT: u32 = 0x4000_0000;
    /// High-bit flag used by the alternate-encoding code (`BOOL`).
    const ALTERNATE_ENCODING_BIT: u32 = 0x8000_0000;

    /// Code with no class flag and no width; `data_type_size` does not recognise it.
    pub const INVALID: u32 = 0;

    /// 32-bit floating point.
    pub const FLOAT32: u32 = FLOAT_BIT | 32;
    /// 16-bit (half-precision) floating point.
    pub const FLOAT16: u32 = FLOAT_BIT | 16;

    /// 8-bit signed integer.
    pub const INT8: u32 = SIGNED_BIT | 8;
    /// 16-bit signed integer.
    pub const INT16: u32 = SIGNED_BIT | 16;
    /// 32-bit signed integer.
    pub const INT32: u32 = SIGNED_BIT | 32;
    /// 64-bit signed integer.
    pub const INT64: u32 = SIGNED_BIT | 64;

    /// 8-bit unsigned integer (no class flag).
    pub const UINT8: u32 = 8;
    /// 16-bit unsigned integer (no class flag).
    pub const UINT16: u32 = 16;
    /// 32-bit unsigned integer (no class flag).
    pub const UINT32: u32 = 32;
    /// 64-bit unsigned integer (no class flag).
    pub const UINT64: u32 = 64;

    /// Boolean occupying one byte per element.
    pub const BOOL: u32 = ALTERNATE_ENCODING_BIT | 8;
    /// 8-bit unsigned normalized value.
    pub const UNORM8: u32 = NORMALIZED_BIT | 8;
}

/// Returns the size in bytes of one element of `data_type`.
///
/// Returns `None` for any code that is not one of the `data_type::*`
/// constants with a width (including `data_type::INVALID`).
#[must_use]
pub const fn data_type_size(data_type: u32) -> Option<usize> {
    let bytes = match data_type {
        data_type::INT8 | data_type::UINT8 | data_type::BOOL | data_type::UNORM8 => 1,
        data_type::FLOAT16 | data_type::INT16 | data_type::UINT16 => 2,
        data_type::FLOAT32 | data_type::INT32 | data_type::UINT32 => 4,
        data_type::INT64 | data_type::UINT64 => 8,
        _ => return None,
    };
    Some(bytes)
}
50
/// Tensor layout identifiers used for the `data_layout` / `weights_layout`
/// fields of the convolution and pooling descriptor infos.
///
/// Letters follow the usual convention (N batch, C channels, D depth,
/// H height, W width, O output channels, I input channels).
/// NOTE(review): the numeric values presumably mirror MPSGraph's
/// `MPSGraphTensorNamedDataLayout` — confirm on the FFI side.
pub mod tensor_named_data_layout {
    /// Batch, channels, height, width.
    pub const NCHW: usize = 0;
    /// Batch, height, width, channels.
    pub const NHWC: usize = 1;
    /// Output channels, input channels, height, width (weights).
    pub const OIHW: usize = 2;
    /// Height, width, input channels, output channels (weights).
    pub const HWIO: usize = 3;
    /// Channels, height, width.
    pub const CHW: usize = 4;
    /// Height, width, channels.
    pub const HWC: usize = 5;
    /// Height, width.
    pub const HW: usize = 6;
    /// Batch, channels, depth, height, width.
    pub const NCDHW: usize = 7;
    /// Batch, depth, height, width, channels.
    pub const NDHWC: usize = 8;
    /// Output channels, input channels, depth, height, width (weights).
    pub const OIDHW: usize = 9;
    /// Depth, height, width, input channels, output channels (weights).
    pub const DHWIO: usize = 10;
}
76
/// Padding-style codes for the `padding_style` field of the convolution and
/// pooling descriptor infos.
///
/// NOTE(review): values presumably mirror MPSGraph's `MPSGraphPaddingStyle`
/// — confirm on the FFI side.
pub mod padding_style {
    /// Explicit padding (the default in both descriptor info types).
    pub const EXPLICIT: usize = 0;
    /// TensorFlow-style `VALID` padding.
    pub const TF_VALID: usize = 1;
    /// TensorFlow-style `SAME` padding.
    pub const TF_SAME: usize = 2;
    /// Explicit-offset padding.
    pub const EXPLICIT_OFFSET: usize = 3;
    /// ONNX-style `SAME_LOWER` padding.
    pub const ONNX_SAME_LOWER: usize = 4;
}
90
/// Padding-mode codes, typed `isize` unlike the other code modules.
///
/// NOTE(review): values presumably mirror MPSGraph's `MPSGraphPaddingMode`
/// — confirm on the FFI side.
pub mod padding_mode {
    /// Pad with a constant value.
    pub const CONSTANT: isize = 0;
    /// Reflect padding.
    pub const REFLECT: isize = 1;
    /// Symmetric padding.
    pub const SYMMETRIC: isize = 2;
    /// Clamp-to-edge padding.
    pub const CLAMP_TO_EDGE: isize = 3;
    /// Zero padding.
    pub const ZERO: isize = 4;
    /// Periodic padding.
    pub const PERIODIC: isize = 5;
    /// Anti-periodic padding.
    pub const ANTI_PERIODIC: isize = 6;
}
108
/// Generates an owning wrapper type around a raw MPSGraph object pointer.
///
/// The generated type keeps the pointer in a private field and exposes it
/// through `as_ptr`. On drop, it hands a non-null pointer to
/// `ffi::mpsgraph_object_release`; a null pointer is never released.
macro_rules! opaque_handle {
    ($name:ident) => {
        /// Owning handle to an MPSGraph object; the object is released on drop.
        pub struct $name {
            /// Raw object pointer; null means there is nothing to release.
            ptr: *mut c_void,
        }

        // SAFETY: NOTE(review) — these impls assert that the wrapped object may be
        // moved to and shared between threads. Nothing in this file enforces that;
        // confirm against the threading guarantees of the underlying MPSGraph object.
        unsafe impl Send for $name {}
        unsafe impl Sync for $name {}

        impl Drop for $name {
            fn drop(&mut self) {
                // Null the field before releasing so no stale pointer stays in `self`.
                let raw = core::mem::replace(&mut self.ptr, ptr::null_mut());
                if !raw.is_null() {
                    // SAFETY: `raw` is non-null and this is the only place it is released.
                    unsafe { ffi::mpsgraph_object_release(raw) };
                }
            }
        }

        impl $name {
            /// Returns the raw object pointer without giving up ownership.
            ///
            /// The pointer is only valid while `self` is alive.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }
    };
}
138
139fn checked_byte_len(shape: &[usize], data_type: u32) -> Option<usize> {
140 let element_size = data_type_size(data_type)?;
141 shape
142 .iter()
143 .try_fold(element_size, |acc, dimension| acc.checked_mul(*dimension))
144}
145
/// Converts an optional name into an owned C string.
///
/// Returns `None` when `name` is `None`. It also returns `None` when the name
/// contains an interior NUL byte, so such names are silently dropped rather
/// than truncated.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let text = name?;
    CString::new(text).ok()
}
149
/// Borrows an optional C string as a raw pointer, using null for `None`.
///
/// The returned pointer is only valid while `value` is alive and unmodified.
// `&Option<CString>` is kept (rather than `Option<&CString>`) because every
// caller holds the owned option and passes it by reference.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(text) => text.as_ptr(),
        None => ptr::null(),
    }
}
154
155fn wrap_tensor(ptr: *mut c_void) -> Option<Tensor> {
156 if ptr.is_null() {
157 None
158 } else {
159 Some(Tensor { ptr })
160 }
161}
162
163fn wrap_tensor_data_results(
164 handles: Vec<*mut c_void>,
165 message: &'static str,
166) -> Result<Vec<TensorData>> {
167 let mut results = Vec::with_capacity(handles.len());
168 for handle in handles {
169 if handle.is_null() {
170 return Err(Error::OperationFailed(message));
171 }
172 results.push(TensorData::from_raw(handle));
173 }
174 Ok(results)
175}
176
/// Generates a `Graph` method that adds a two-input operation through the
/// given FFI function.
macro_rules! impl_binary_tensor_op {
    ($fn_name:ident, $ffi_name:ident) => {
        /// Adds an operation combining `primary` and `secondary` to the graph.
        ///
        /// `name` is an optional node name; it is passed as null if it contains an
        /// interior NUL byte. Returns `None` if the FFI call yields a null tensor.
        #[must_use]
        pub fn $fn_name(
            &self,
            primary: &Tensor,
            secondary: &Tensor,
            name: Option<&str>,
        ) -> Option<Tensor> {
            let label = optional_cstring(name);
            let lhs = primary.as_ptr();
            let rhs = secondary.as_ptr();
            // SAFETY: the graph and both tensors are live handles borrowed for the call,
            // and `label` outlives it. Assumes the FFI does not keep the name pointer.
            let raw = unsafe { ffi::$ffi_name(self.ptr, lhs, rhs, cstring_ptr(&label)) };
            wrap_tensor(raw)
        }
    };
}
201
/// Generates a `Graph` method that adds a single-input operation through the
/// given FFI function.
macro_rules! impl_unary_tensor_op {
    ($fn_name:ident, $ffi_name:ident) => {
        /// Adds an operation on `tensor` to the graph.
        ///
        /// `name` is an optional node name; it is passed as null if it contains an
        /// interior NUL byte. Returns `None` if the FFI call yields a null tensor.
        #[must_use]
        pub fn $fn_name(&self, tensor: &Tensor, name: Option<&str>) -> Option<Tensor> {
            let label = optional_cstring(name);
            let input = tensor.as_ptr();
            // SAFETY: the graph and `tensor` are live handles borrowed for the call, and
            // `label` outlives it. Assumes the FFI does not keep the name pointer.
            let raw = unsafe { ffi::$ffi_name(self.ptr, input, cstring_ptr(&label)) };
            wrap_tensor(raw)
        }
    };
}
214
/// Generates a `Graph` method that adds an operation over a list of axes
/// through the given FFI function.
macro_rules! impl_axes_tensor_op {
    ($fn_name:ident, $ffi_name:ident) => {
        /// Adds an operation over `axes` of `tensor` to the graph.
        ///
        /// `axes` is passed as a pointer plus length. `name` is an optional node name;
        /// it is passed as null if it contains an interior NUL byte. Returns `None`
        /// if the FFI call yields a null tensor.
        #[must_use]
        pub fn $fn_name(
            &self,
            tensor: &Tensor,
            axes: &[usize],
            name: Option<&str>,
        ) -> Option<Tensor> {
            let label = optional_cstring(name);
            let (axes_ptr, axes_len) = (axes.as_ptr(), axes.len());
            // SAFETY: `axes_ptr`/`axes_len` describe a borrowed slice, the graph and
            // `tensor` are live for the call, and `label` outlives it.
            let raw = unsafe {
                ffi::$ffi_name(self.ptr, tensor.as_ptr(), axes_ptr, axes_len, cstring_ptr(&label))
            };
            wrap_tensor(raw)
        }
    };
}
240
241#[derive(Clone, Copy)]
243pub struct Feed<'a> {
244pub tensor: &'a Tensor,
246pub data: &'a TensorData,
248}
249
250impl<'a> Feed<'a> {
251#[must_use]
253 pub const fn new(tensor: &'a Tensor, data: &'a TensorData) -> Self {
254 Self { tensor, data }
255 }
256}
257
258#[derive(Clone, Copy)]
260pub struct FeedDescription<'a> {
261pub tensor: &'a Tensor,
263pub shape: &'a [usize],
265pub data_type: u32,
267}
268
269impl<'a> FeedDescription<'a> {
270#[must_use]
272 pub const fn new(tensor: &'a Tensor, shape: &'a [usize], data_type: u32) -> Self {
273 Self {
274 tensor,
275 shape,
276 data_type,
277 }
278 }
279}
280
281#[derive(Debug, Clone, Copy)]
283pub struct Convolution2DDescriptorInfo {
284pub stride_in_x: usize,
286pub stride_in_y: usize,
288pub dilation_rate_in_x: usize,
290pub dilation_rate_in_y: usize,
292pub groups: usize,
294pub padding_left: usize,
296pub padding_right: usize,
298pub padding_top: usize,
300pub padding_bottom: usize,
302pub padding_style: usize,
304pub data_layout: usize,
306pub weights_layout: usize,
308}
309
310impl Default for Convolution2DDescriptorInfo {
311 fn default() -> Self {
312 Self {
313 stride_in_x: 1,
314 stride_in_y: 1,
315 dilation_rate_in_x: 1,
316 dilation_rate_in_y: 1,
317 groups: 1,
318 padding_left: 0,
319 padding_right: 0,
320 padding_top: 0,
321 padding_bottom: 0,
322 padding_style: padding_style::EXPLICIT,
323 data_layout: tensor_named_data_layout::NHWC,
324 weights_layout: tensor_named_data_layout::HWIO,
325 }
326 }
327}
328
329opaque_handle!(Convolution2DDescriptor);
330impl Convolution2DDescriptor {
331#[must_use]
333 pub fn new(info: Convolution2DDescriptorInfo) -> Option<Self> {
334 let ptr = unsafe {
336 ffi::mpsgraph_convolution2d_descriptor_new(
337 info.stride_in_x,
338 info.stride_in_y,
339 info.dilation_rate_in_x,
340 info.dilation_rate_in_y,
341 info.groups,
342 info.padding_left,
343 info.padding_right,
344 info.padding_top,
345 info.padding_bottom,
346 info.padding_style,
347 info.data_layout,
348 info.weights_layout,
349 )
350 };
351 if ptr.is_null() {
352 None
353 } else {
354 Some(Self { ptr })
355 }
356 }
357}
358
359#[derive(Debug, Clone, Copy)]
361pub struct Pooling2DDescriptorInfo {
362pub kernel_width: usize,
364pub kernel_height: usize,
366pub stride_in_x: usize,
368pub stride_in_y: usize,
370pub dilation_rate_in_x: usize,
372pub dilation_rate_in_y: usize,
374pub padding_left: usize,
376pub padding_right: usize,
378pub padding_top: usize,
380pub padding_bottom: usize,
382pub padding_style: usize,
384pub data_layout: usize,
386}
387
388impl Pooling2DDescriptorInfo {
389#[must_use]
391 pub const fn new(kernel_width: usize, kernel_height: usize) -> Self {
392 Self {
393 kernel_width,
394 kernel_height,
395 stride_in_x: 1,
396 stride_in_y: 1,
397 dilation_rate_in_x: 1,
398 dilation_rate_in_y: 1,
399 padding_left: 0,
400 padding_right: 0,
401 padding_top: 0,
402 padding_bottom: 0,
403 padding_style: padding_style::EXPLICIT,
404 data_layout: tensor_named_data_layout::NHWC,
405 }
406 }
407}
408
409opaque_handle!(Pooling2DDescriptor);
410impl Pooling2DDescriptor {
411#[must_use]
413 pub fn new(info: Pooling2DDescriptorInfo) -> Option<Self> {
414 let ptr = unsafe {
416 ffi::mpsgraph_pooling2d_descriptor_new(
417 info.kernel_width,
418 info.kernel_height,
419 info.stride_in_x,
420 info.stride_in_y,
421 info.dilation_rate_in_x,
422 info.dilation_rate_in_y,
423 info.padding_left,
424 info.padding_right,
425 info.padding_top,
426 info.padding_bottom,
427 info.padding_style,
428 info.data_layout,
429 )
430 };
431 if ptr.is_null() {
432 None
433 } else {
434 Some(Self { ptr })
435 }
436 }
437}
438
439opaque_handle!(Graph);
440opaque_handle!(Tensor);
441
442impl Tensor {
443 pub(crate) const fn from_raw(ptr: *mut c_void) -> Self {
444 Self { ptr }
445 }
446}
447
448impl Graph {
449#[must_use]
451 pub fn new() -> Option<Self> {
452 let ptr = unsafe { ffi::mpsgraph_graph_new() };
454 if ptr.is_null() {
455 None
456 } else {
457 Some(Self { ptr })
458 }
459 }
460
461#[must_use]
463 pub fn placeholder(
464 &self,
465 shape: Option<&[usize]>,
466 data_type: u32,
467 name: Option<&str>,
468 ) -> Option<Tensor> {
469 let name = optional_cstring(name);
470 let (shape_ptr, shape_len) =
471 shape.map_or((ptr::null(), 0), |shape| (shape.as_ptr(), shape.len()));
472
473 let ptr = unsafe {
475 ffi::mpsgraph_graph_placeholder(
476 self.ptr,
477 shape_ptr,
478 shape_len,
479 data_type,
480 cstring_ptr(&name),
481 )
482 };
483 wrap_tensor(ptr)
484 }
485
486#[must_use]
488 pub fn constant_bytes(&self, data: &[u8], shape: &[usize], data_type: u32) -> Option<Tensor> {
489 let expected = checked_byte_len(shape, data_type)?;
490 if data.len() != expected {
491 return None;
492 }
493
494 let ptr = unsafe {
496 ffi::mpsgraph_graph_constant_data(
497 self.ptr,
498 data.as_ptr().cast(),
499 data.len(),
500 shape.as_ptr(),
501 shape.len(),
502 data_type,
503 )
504 };
505 wrap_tensor(ptr)
506 }
507
508#[must_use]
510 pub fn constant_f32_slice(&self, values: &[f32], shape: &[usize]) -> Option<Tensor> {
511 let bytes = unsafe {
513 core::slice::from_raw_parts(
514 values.as_ptr().cast::<u8>(),
515 core::mem::size_of_val(values),
516 )
517 };
518 self.constant_bytes(bytes, shape, data_type::FLOAT32)
519 }
520
521#[must_use]
523 pub fn constant_scalar(&self, scalar: f64, data_type: u32) -> Option<Tensor> {
524 let ptr = unsafe { ffi::mpsgraph_graph_constant_scalar(self.ptr, scalar, data_type) };
526 wrap_tensor(ptr)
527 }
528
529#[must_use]
531 pub fn constant_scalar_shaped(
532 &self,
533 scalar: f64,
534 shape: &[usize],
535 data_type: u32,
536 ) -> Option<Tensor> {
537 let ptr = unsafe {
539 ffi::mpsgraph_graph_constant_scalar_shaped(
540 self.ptr,
541 scalar,
542 shape.as_ptr(),
543 shape.len(),
544 data_type,
545 )
546 };
547 wrap_tensor(ptr)
548 }
549
550 impl_binary_tensor_op!(addition, mpsgraph_graph_addition);
551 impl_binary_tensor_op!(subtraction, mpsgraph_graph_subtraction);
552 impl_binary_tensor_op!(multiplication, mpsgraph_graph_multiplication);
553 impl_binary_tensor_op!(division, mpsgraph_graph_division);
554 impl_binary_tensor_op!(matrix_multiplication, mpsgraph_graph_matrix_multiplication);
555 impl_unary_tensor_op!(relu, mpsgraph_graph_relu);
556 impl_unary_tensor_op!(sigmoid, mpsgraph_graph_sigmoid);
557 impl_axes_tensor_op!(reduction_sum, mpsgraph_graph_reduction_sum);
558 impl_axes_tensor_op!(reduction_maximum, mpsgraph_graph_reduction_maximum);
559 impl_axes_tensor_op!(reduction_minimum, mpsgraph_graph_reduction_minimum);
560 impl_axes_tensor_op!(mean, mpsgraph_graph_mean);
561
562#[must_use]
564 pub fn softmax(&self, tensor: &Tensor, axis: isize, name: Option<&str>) -> Option<Tensor> {
565 let name = optional_cstring(name);
566 let ptr = unsafe {
568 ffi::mpsgraph_graph_softmax(self.ptr, tensor.as_ptr(), axis, cstring_ptr(&name))
569 };
570 wrap_tensor(ptr)
571 }
572
573#[must_use]
575 pub fn reshape(&self, tensor: &Tensor, shape: &[usize], name: Option<&str>) -> Option<Tensor> {
576 let name = optional_cstring(name);
577 let ptr = unsafe {
579 ffi::mpsgraph_graph_reshape(
580 self.ptr,
581 tensor.as_ptr(),
582 shape.as_ptr(),
583 shape.len(),
584 cstring_ptr(&name),
585 )
586 };
587 wrap_tensor(ptr)
588 }
589
590#[must_use]
592 pub fn transpose(
593 &self,
594 tensor: &Tensor,
595 permutation: &[usize],
596 name: Option<&str>,
597 ) -> Option<Tensor> {
598 let name = optional_cstring(name);
599 let ptr = unsafe {
601 ffi::mpsgraph_graph_transpose(
602 self.ptr,
603 tensor.as_ptr(),
604 permutation.as_ptr(),
605 permutation.len(),
606 cstring_ptr(&name),
607 )
608 };
609 wrap_tensor(ptr)
610 }
611
612#[must_use]
614 pub fn slice(
615 &self,
616 tensor: &Tensor,
617 dimension: usize,
618 start: isize,
619 length: isize,
620 name: Option<&str>,
621 ) -> Option<Tensor> {
622 let name = optional_cstring(name);
623 let ptr = unsafe {
625 ffi::mpsgraph_graph_slice(
626 self.ptr,
627 tensor.as_ptr(),
628 dimension,
629 start,
630 length,
631 cstring_ptr(&name),
632 )
633 };
634 wrap_tensor(ptr)
635 }
636
637#[must_use]
639 pub fn broadcast(
640 &self,
641 tensor: &Tensor,
642 shape: &[usize],
643 name: Option<&str>,
644 ) -> Option<Tensor> {
645 let name = optional_cstring(name);
646 let ptr = unsafe {
648 ffi::mpsgraph_graph_broadcast(
649 self.ptr,
650 tensor.as_ptr(),
651 shape.as_ptr(),
652 shape.len(),
653 cstring_ptr(&name),
654 )
655 };
656 wrap_tensor(ptr)
657 }
658
659#[must_use]
661 pub fn convolution2d(
662 &self,
663 source: &Tensor,
664 weights: &Tensor,
665 descriptor: &Convolution2DDescriptor,
666 name: Option<&str>,
667 ) -> Option<Tensor> {
668 let name = optional_cstring(name);
669 let ptr = unsafe {
671 ffi::mpsgraph_graph_convolution2d(
672 self.ptr,
673 source.as_ptr(),
674 weights.as_ptr(),
675 descriptor.as_ptr(),
676 cstring_ptr(&name),
677 )
678 };
679 wrap_tensor(ptr)
680 }
681
682#[must_use]
684 pub fn max_pooling2d(
685 &self,
686 source: &Tensor,
687 descriptor: &Pooling2DDescriptor,
688 name: Option<&str>,
689 ) -> Option<Tensor> {
690 let name = optional_cstring(name);
691 let ptr = unsafe {
693 ffi::mpsgraph_graph_max_pooling2d(
694 self.ptr,
695 source.as_ptr(),
696 descriptor.as_ptr(),
697 cstring_ptr(&name),
698 )
699 };
700 wrap_tensor(ptr)
701 }
702
703#[allow(clippy::too_many_arguments)]
705 #[must_use]
706 pub fn normalize(
707 &self,
708 tensor: &Tensor,
709 mean: &Tensor,
710 variance: &Tensor,
711 gamma: Option<&Tensor>,
712 beta: Option<&Tensor>,
713 epsilon: f32,
714 name: Option<&str>,
715 ) -> Option<Tensor> {
716 let name = optional_cstring(name);
717 let gamma_ptr = gamma.map_or(ptr::null_mut(), Tensor::as_ptr);
718 let beta_ptr = beta.map_or(ptr::null_mut(), Tensor::as_ptr);
719 let ptr = unsafe {
721 ffi::mpsgraph_graph_normalize(
722 self.ptr,
723 tensor.as_ptr(),
724 mean.as_ptr(),
725 variance.as_ptr(),
726 gamma_ptr,
727 beta_ptr,
728 epsilon,
729 cstring_ptr(&name),
730 )
731 };
732 wrap_tensor(ptr)
733 }
734
735pub fn run(&self, feeds: &[Feed<'_>], targets: &[&Tensor]) -> Result<Vec<TensorData>> {
737 let feed_tensors = feeds
738 .iter()
739 .map(|feed| feed.tensor.as_ptr())
740 .collect::<Vec<_>>();
741 let feed_data = feeds
742 .iter()
743 .map(|feed| feed.data.as_ptr())
744 .collect::<Vec<_>>();
745 let target_tensors = targets
746 .iter()
747 .map(|tensor| tensor.as_ptr())
748 .collect::<Vec<_>>();
749 let mut results = vec![ptr::null_mut(); targets.len()];
750
751 let ok = unsafe {
753 ffi::mpsgraph_graph_run(
754 self.ptr,
755 feed_tensors.as_ptr(),
756 feed_data.as_ptr(),
757 feeds.len(),
758 target_tensors.as_ptr(),
759 targets.len(),
760 results.as_mut_ptr(),
761 )
762 };
763 if ok {
764 wrap_tensor_data_results(results, "failed to run graph")
765 } else {
766 Err(Error::OperationFailed("failed to run graph"))
767 }
768 }
769
770pub fn run_with_command_queue(
772 &self,
773 command_queue: &CommandQueue,
774 feeds: &[Feed<'_>],
775 targets: &[&Tensor],
776 ) -> Result<Vec<TensorData>> {
777 let feed_tensors = feeds
778 .iter()
779 .map(|feed| feed.tensor.as_ptr())
780 .collect::<Vec<_>>();
781 let feed_data = feeds
782 .iter()
783 .map(|feed| feed.data.as_ptr())
784 .collect::<Vec<_>>();
785 let target_tensors = targets
786 .iter()
787 .map(|tensor| tensor.as_ptr())
788 .collect::<Vec<_>>();
789 let mut results = vec![ptr::null_mut(); targets.len()];
790
791 let ok = unsafe {
793 ffi::mpsgraph_graph_run_with_command_queue(
794 self.ptr,
795 command_queue.as_ptr(),
796 feed_tensors.as_ptr(),
797 feed_data.as_ptr(),
798 feeds.len(),
799 target_tensors.as_ptr(),
800 targets.len(),
801 results.as_mut_ptr(),
802 )
803 };
804 if ok {
805 wrap_tensor_data_results(results, "failed to run graph with command queue")
806 } else {
807 Err(Error::OperationFailed(
808 "failed to run graph with command queue",
809 ))
810 }
811 }
812
813#[must_use]
815 pub fn compile(
816 &self,
817 device: &MetalDevice,
818 feeds: &[FeedDescription<'_>],
819 targets: &[&Tensor],
820 ) -> Option<Executable> {
821 let feed_tensors = feeds
822 .iter()
823 .map(|feed| feed.tensor.as_ptr())
824 .collect::<Vec<_>>();
825 let shape_lengths = feeds
826 .iter()
827 .map(|feed| feed.shape.len())
828 .collect::<Vec<_>>();
829 let data_types = feeds.iter().map(|feed| feed.data_type).collect::<Vec<_>>();
830 let flat_shapes = feeds
831 .iter()
832 .flat_map(|feed| feed.shape.iter().copied())
833 .collect::<Vec<_>>();
834 let target_tensors = targets
835 .iter()
836 .map(|tensor| tensor.as_ptr())
837 .collect::<Vec<_>>();
838
839 let ptr = unsafe {
841 ffi::mpsgraph_graph_compile(
842 self.ptr,
843 device.as_ptr(),
844 feed_tensors.as_ptr(),
845 feeds.len(),
846 flat_shapes.as_ptr(),
847 shape_lengths.as_ptr(),
848 data_types.as_ptr(),
849 target_tensors.as_ptr(),
850 targets.len(),
851 )
852 };
853 if ptr.is_null() {
854 None
855 } else {
856 Some(Executable::from_raw(ptr, targets.len()))
857 }
858 }
859}
860
861pub struct Executable {
863 ptr: *mut c_void,
864 output_count: usize,
865}
866
867unsafe impl Send for Executable {}
868unsafe impl Sync for Executable {}
869
870impl Drop for Executable {
871 fn drop(&mut self) {
872 if !self.ptr.is_null() {
873 unsafe { ffi::mpsgraph_object_release(self.ptr) };
875 self.ptr = ptr::null_mut();
876 }
877 }
878}
879
880impl Executable {
881 pub(crate) const fn from_raw(ptr: *mut c_void, output_count: usize) -> Self {
882 Self { ptr, output_count }
883 }
884
885#[must_use]
887 pub const fn as_ptr(&self) -> *mut c_void {
888 self.ptr
889 }
890
891#[must_use]
893 pub const fn output_count(&self) -> usize {
894 self.output_count
895 }
896
897pub fn run(
899 &self,
900 command_queue: &CommandQueue,
901 inputs: &[&TensorData],
902 ) -> Result<Vec<TensorData>> {
903 let input_data = inputs
904 .iter()
905 .map(|tensor_data| tensor_data.as_ptr())
906 .collect::<Vec<_>>();
907 let mut results = vec![ptr::null_mut(); self.output_count];
908
909 let ok = unsafe {
911 ffi::mpsgraph_executable_run(
912 self.ptr,
913 command_queue.as_ptr(),
914 input_data.as_ptr(),
915 inputs.len(),
916 self.output_count,
917 results.as_mut_ptr(),
918 )
919 };
920 if ok {
921 wrap_tensor_data_results(results, "failed to run executable")
922 } else {
923 Err(Error::OperationFailed("failed to run executable"))
924 }
925 }
926}