1use crate::data::TensorData;
2use crate::error::{Error, Result};
3use crate::ffi;
4use apple_metal::{CommandQueue, MetalDevice};
5use core::ffi::{c_char, c_void};
6use core::ptr;
7use std::ffi::CString;
8
/// Element data types, as raw `u32` codes passed across the FFI.
///
/// The codes follow Metal Performance Shaders' `MPSDataType` encoding: the low
/// bits hold the element width in bits, and the high bits flag floating-point,
/// signed, normalized, or alternate encodings.
pub mod data_type {
    /// Flag bit marking floating-point element types.
    const FLOAT_BIT: u32 = 0x1000_0000;
    /// Flag bit marking signed integer element types.
    const SIGNED_BIT: u32 = 0x2000_0000;
    /// Flag bit marking normalized (unorm) element types.
    const NORMALIZED_BIT: u32 = 0x4000_0000;
    /// Flag bit marking alternate encodings (used for `BOOL`).
    const ALTERNATE_ENCODING_BIT: u32 = 0x8000_0000;

    /// Code meaning "no valid data type".
    pub const INVALID: u32 = 0;

    /// 16-bit floating point.
    pub const FLOAT16: u32 = FLOAT_BIT | 16;
    /// 32-bit floating point.
    pub const FLOAT32: u32 = FLOAT_BIT | 32;

    /// 8-bit signed integer.
    pub const INT8: u32 = SIGNED_BIT | 8;
    /// 16-bit signed integer.
    pub const INT16: u32 = SIGNED_BIT | 16;
    /// 32-bit signed integer.
    pub const INT32: u32 = SIGNED_BIT | 32;
    /// 64-bit signed integer.
    pub const INT64: u32 = SIGNED_BIT | 64;

    /// 8-bit unsigned integer (no flag bits set).
    pub const UINT8: u32 = 8;
    /// 16-bit unsigned integer.
    pub const UINT16: u32 = 16;
    /// 32-bit unsigned integer.
    pub const UINT32: u32 = 32;
    /// 64-bit unsigned integer.
    pub const UINT64: u32 = 64;

    /// Boolean stored in one byte: the 8-bit unsigned width with the alternate-encoding flag.
    pub const BOOL: u32 = ALTERNATE_ENCODING_BIT | UINT8;
    /// Normalized 8-bit unsigned value: the 8-bit unsigned width with the normalized flag.
    pub const UNORM8: u32 = NORMALIZED_BIT | UINT8;
}
25
26#[must_use]
28pub const fn data_type_size(data_type: u32) -> Option<usize> {
29 match data_type {
30 data_type::FLOAT16 | data_type::INT16 | data_type::UINT16 => Some(2),
31 data_type::FLOAT32 | data_type::INT32 | data_type::UINT32 => Some(4),
32 data_type::INT64 | data_type::UINT64 => Some(8),
33 data_type::INT8 | data_type::UINT8 | data_type::BOOL | data_type::UNORM8 => Some(1),
34 _ => None,
35 }
36}
37
/// Raw tensor layout codes passed to the FFI as `data_layout` / `weights_layout`.
///
/// NOTE(review): the values match the ordering of `MPSGraphTensorNamedDataLayout`;
/// confirm against the native bridge if new layouts are added.
pub mod tensor_named_data_layout {
    // Batched activation layouts.
    /// Batch, channels, height, width.
    pub const NCHW: usize = 0;
    /// Batch, height, width, channels.
    pub const NHWC: usize = 1;

    // Convolution weight layouts.
    /// Output channels, input channels, height, width.
    pub const OIHW: usize = 2;
    /// Height, width, input channels, output channels.
    pub const HWIO: usize = 3;

    // Unbatched layouts.
    /// Channels, height, width.
    pub const CHW: usize = 4;
    /// Height, width, channels.
    pub const HWC: usize = 5;
    /// Height, width.
    pub const HW: usize = 6;
}
48
/// Raw padding-style codes passed to the FFI as `padding_style`.
///
/// NOTE(review): the values match the ordering of `MPSGraphPaddingStyle`;
/// confirm against the native bridge before adding more.
pub mod padding_style {
    /// Use the explicit left/right/top/bottom padding values.
    pub const EXPLICIT: usize = 0;
    /// TensorFlow-style "VALID" padding.
    pub const TF_VALID: usize = 1;
    /// TensorFlow-style "SAME" padding.
    pub const TF_SAME: usize = 2;
    /// Explicit padding interpreted as offsets.
    pub const EXPLICIT_OFFSET: usize = 3;
    /// ONNX-style "SAME_LOWER" padding.
    pub const ONNX_SAME_LOWER: usize = 4;
}
57
/// Declares `pub struct $name` owning a single raw FFI object pointer.
///
/// The generated type exposes the pointer through `as_ptr` and, when dropped,
/// releases a non-null pointer exactly once via `ffi::mpsgraph_object_release`.
macro_rules! opaque_handle {
    ($name:ident) => {
        /// Owning wrapper around a raw FFI object pointer; a non-null pointer
        /// is released with `mpsgraph_object_release` when this value is dropped.
        pub struct $name {
            ptr: *mut c_void,
        }

        impl $name {
            /// Returns the wrapped raw pointer without transferring ownership.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }

        // SAFETY(review): these impls assert that the underlying object may be
        // used and released from any thread. Confirm this against the
        // Objective-C bridge's threading guarantees.
        unsafe impl Send for $name {}
        unsafe impl Sync for $name {}

        impl Drop for $name {
            fn drop(&mut self) {
                // Take the pointer out first so the handle can never be released twice.
                let raw = core::mem::replace(&mut self.ptr, ptr::null_mut());
                if raw.is_null() {
                    return;
                }
                // SAFETY: `raw` is the non-null pointer this handle owned, and it
                // has just been cleared from `self`, so it is released only once.
                unsafe { ffi::mpsgraph_object_release(raw) };
            }
        }
    };
}
85
86fn checked_byte_len(shape: &[usize], data_type: u32) -> Option<usize> {
87 let element_size = data_type_size(data_type)?;
88 shape
89 .iter()
90 .try_fold(element_size, |acc, dimension| acc.checked_mul(*dimension))
91}
92
/// Converts an optional node name into an owned C string.
///
/// Returns `None` both when no name was given and when the name contains an
/// interior NUL byte. In the NUL case the name is silently dropped rather than
/// reported.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let text = name?;
    CString::new(text).ok()
}
96
/// Borrows the C string pointer from an optional owned name, or returns null.
///
/// The returned pointer is only valid while `value` is alive and unmodified.
// Takes `&Option<CString>` (rather than `Option<&CString>`) so call sites can
// pass `&name` directly for the locals produced by `optional_cstring`.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(owned) => owned.as_ptr(),
        None => ptr::null(),
    }
}
101
102fn wrap_tensor(ptr: *mut c_void) -> Option<Tensor> {
103 if ptr.is_null() {
104 None
105 } else {
106 Some(Tensor { ptr })
107 }
108}
109
110fn wrap_tensor_data_results(
111 handles: Vec<*mut c_void>,
112 message: &'static str,
113) -> Result<Vec<TensorData>> {
114 let mut results = Vec::with_capacity(handles.len());
115 for handle in handles {
116 if handle.is_null() {
117 return Err(Error::OperationFailed(message));
118 }
119 results.push(TensorData::from_raw(handle));
120 }
121 Ok(results)
122}
123
/// Generates a `Graph` method that adds a two-input node via `ffi::$ffi_name`.
macro_rules! impl_binary_tensor_op {
    ($fn_name:ident, $ffi_name:ident) => {
        #[doc = concat!("Adds a two-input node built by `ffi::", stringify!($ffi_name), "`.")]
        #[doc = ""]
        #[doc = "`name` is dropped (node left unnamed) if it contains an interior NUL byte."]
        #[doc = "Returns `None` if the FFI returns a null tensor."]
        #[must_use]
        pub fn $fn_name(
            &self,
            primary: &Tensor,
            secondary: &Tensor,
            name: Option<&str>,
        ) -> Option<Tensor> {
            let label = optional_cstring(name);
            let (lhs, rhs) = (primary.as_ptr(), secondary.as_ptr());
            // SAFETY: all handles are owned by live wrappers borrowed for this
            // call, and `label` outlives the call, so its pointer stays valid.
            let raw = unsafe { ffi::$ffi_name(self.ptr, lhs, rhs, cstring_ptr(&label)) };
            wrap_tensor(raw)
        }
    };
}
147
/// Generates a `Graph` method that adds a single-input node via `ffi::$ffi_name`.
macro_rules! impl_unary_tensor_op {
    ($fn_name:ident, $ffi_name:ident) => {
        #[doc = concat!("Adds a single-input node built by `ffi::", stringify!($ffi_name), "`.")]
        #[doc = ""]
        #[doc = "`name` is dropped (node left unnamed) if it contains an interior NUL byte."]
        #[doc = "Returns `None` if the FFI returns a null tensor."]
        #[must_use]
        pub fn $fn_name(&self, tensor: &Tensor, name: Option<&str>) -> Option<Tensor> {
            let label = optional_cstring(name);
            let input = tensor.as_ptr();
            // SAFETY: the graph and input handles are owned by live wrappers, and
            // `label` outlives the call.
            let raw = unsafe { ffi::$ffi_name(self.ptr, input, cstring_ptr(&label)) };
            wrap_tensor(raw)
        }
    };
}
159
/// Generates a `Graph` method that adds a node parameterized by a list of axes
/// via `ffi::$ffi_name`.
macro_rules! impl_axes_tensor_op {
    ($fn_name:ident, $ffi_name:ident) => {
        #[doc = concat!("Adds a node over the given `axes`, built by `ffi::", stringify!($ffi_name), "`.")]
        #[doc = ""]
        #[doc = "`axes` is passed to the FFI as a pointer plus length."]
        #[doc = "`name` is dropped (node left unnamed) if it contains an interior NUL byte."]
        #[doc = "Returns `None` if the FFI returns a null tensor."]
        #[must_use]
        pub fn $fn_name(
            &self,
            tensor: &Tensor,
            axes: &[usize],
            name: Option<&str>,
        ) -> Option<Tensor> {
            let label = optional_cstring(name);
            let (axes_ptr, axes_len) = (axes.as_ptr(), axes.len());
            // SAFETY: `axes_ptr`/`axes_len` describe the borrowed `axes` slice, the
            // handles are owned by live wrappers, and `label` outlives the call.
            let raw = unsafe {
                ffi::$ffi_name(
                    self.ptr,
                    tensor.as_ptr(),
                    axes_ptr,
                    axes_len,
                    cstring_ptr(&label),
                )
            };
            wrap_tensor(raw)
        }
    };
}
184
185#[derive(Clone, Copy)]
187pub struct Feed<'a> {
188 pub tensor: &'a Tensor,
189 pub data: &'a TensorData,
190}
191
192impl<'a> Feed<'a> {
193 #[must_use]
194 pub const fn new(tensor: &'a Tensor, data: &'a TensorData) -> Self {
195 Self { tensor, data }
196 }
197}
198
199#[derive(Clone, Copy)]
201pub struct FeedDescription<'a> {
202 pub tensor: &'a Tensor,
203 pub shape: &'a [usize],
204 pub data_type: u32,
205}
206
207impl<'a> FeedDescription<'a> {
208 #[must_use]
209 pub const fn new(tensor: &'a Tensor, shape: &'a [usize], data_type: u32) -> Self {
210 Self {
211 tensor,
212 shape,
213 data_type,
214 }
215 }
216}
217
218#[derive(Debug, Clone, Copy)]
220pub struct Convolution2DDescriptorInfo {
221 pub stride_in_x: usize,
222 pub stride_in_y: usize,
223 pub dilation_rate_in_x: usize,
224 pub dilation_rate_in_y: usize,
225 pub groups: usize,
226 pub padding_left: usize,
227 pub padding_right: usize,
228 pub padding_top: usize,
229 pub padding_bottom: usize,
230 pub padding_style: usize,
231 pub data_layout: usize,
232 pub weights_layout: usize,
233}
234
235impl Default for Convolution2DDescriptorInfo {
236 fn default() -> Self {
237 Self {
238 stride_in_x: 1,
239 stride_in_y: 1,
240 dilation_rate_in_x: 1,
241 dilation_rate_in_y: 1,
242 groups: 1,
243 padding_left: 0,
244 padding_right: 0,
245 padding_top: 0,
246 padding_bottom: 0,
247 padding_style: padding_style::EXPLICIT,
248 data_layout: tensor_named_data_layout::NHWC,
249 weights_layout: tensor_named_data_layout::HWIO,
250 }
251 }
252}
253
254opaque_handle!(Convolution2DDescriptor);
255impl Convolution2DDescriptor {
256 #[must_use]
257 pub fn new(info: Convolution2DDescriptorInfo) -> Option<Self> {
258 let ptr = unsafe {
260 ffi::mpsgraph_convolution2d_descriptor_new(
261 info.stride_in_x,
262 info.stride_in_y,
263 info.dilation_rate_in_x,
264 info.dilation_rate_in_y,
265 info.groups,
266 info.padding_left,
267 info.padding_right,
268 info.padding_top,
269 info.padding_bottom,
270 info.padding_style,
271 info.data_layout,
272 info.weights_layout,
273 )
274 };
275 if ptr.is_null() {
276 None
277 } else {
278 Some(Self { ptr })
279 }
280 }
281}
282
283#[derive(Debug, Clone, Copy)]
285pub struct Pooling2DDescriptorInfo {
286 pub kernel_width: usize,
287 pub kernel_height: usize,
288 pub stride_in_x: usize,
289 pub stride_in_y: usize,
290 pub dilation_rate_in_x: usize,
291 pub dilation_rate_in_y: usize,
292 pub padding_left: usize,
293 pub padding_right: usize,
294 pub padding_top: usize,
295 pub padding_bottom: usize,
296 pub padding_style: usize,
297 pub data_layout: usize,
298}
299
300impl Pooling2DDescriptorInfo {
301 #[must_use]
302 pub const fn new(kernel_width: usize, kernel_height: usize) -> Self {
303 Self {
304 kernel_width,
305 kernel_height,
306 stride_in_x: 1,
307 stride_in_y: 1,
308 dilation_rate_in_x: 1,
309 dilation_rate_in_y: 1,
310 padding_left: 0,
311 padding_right: 0,
312 padding_top: 0,
313 padding_bottom: 0,
314 padding_style: padding_style::EXPLICIT,
315 data_layout: tensor_named_data_layout::NHWC,
316 }
317 }
318}
319
320opaque_handle!(Pooling2DDescriptor);
321impl Pooling2DDescriptor {
322 #[must_use]
323 pub fn new(info: Pooling2DDescriptorInfo) -> Option<Self> {
324 let ptr = unsafe {
326 ffi::mpsgraph_pooling2d_descriptor_new(
327 info.kernel_width,
328 info.kernel_height,
329 info.stride_in_x,
330 info.stride_in_y,
331 info.dilation_rate_in_x,
332 info.dilation_rate_in_y,
333 info.padding_left,
334 info.padding_right,
335 info.padding_top,
336 info.padding_bottom,
337 info.padding_style,
338 info.data_layout,
339 )
340 };
341 if ptr.is_null() {
342 None
343 } else {
344 Some(Self { ptr })
345 }
346 }
347}
348
349opaque_handle!(Graph);
350opaque_handle!(Tensor);
351
352impl Tensor {
353 pub(crate) const fn from_raw(ptr: *mut c_void) -> Self {
354 Self { ptr }
355 }
356}
357
358impl Graph {
359 #[must_use]
360 pub fn new() -> Option<Self> {
361 let ptr = unsafe { ffi::mpsgraph_graph_new() };
363 if ptr.is_null() {
364 None
365 } else {
366 Some(Self { ptr })
367 }
368 }
369
370 #[must_use]
371 pub fn placeholder(
372 &self,
373 shape: Option<&[usize]>,
374 data_type: u32,
375 name: Option<&str>,
376 ) -> Option<Tensor> {
377 let name = optional_cstring(name);
378 let (shape_ptr, shape_len) =
379 shape.map_or((ptr::null(), 0), |shape| (shape.as_ptr(), shape.len()));
380
381 let ptr = unsafe {
383 ffi::mpsgraph_graph_placeholder(
384 self.ptr,
385 shape_ptr,
386 shape_len,
387 data_type,
388 cstring_ptr(&name),
389 )
390 };
391 wrap_tensor(ptr)
392 }
393
394 #[must_use]
395 pub fn constant_bytes(&self, data: &[u8], shape: &[usize], data_type: u32) -> Option<Tensor> {
396 let expected = checked_byte_len(shape, data_type)?;
397 if data.len() != expected {
398 return None;
399 }
400
401 let ptr = unsafe {
403 ffi::mpsgraph_graph_constant_data(
404 self.ptr,
405 data.as_ptr().cast(),
406 data.len(),
407 shape.as_ptr(),
408 shape.len(),
409 data_type,
410 )
411 };
412 wrap_tensor(ptr)
413 }
414
415 #[must_use]
416 pub fn constant_f32_slice(&self, values: &[f32], shape: &[usize]) -> Option<Tensor> {
417 let bytes = unsafe {
419 core::slice::from_raw_parts(
420 values.as_ptr().cast::<u8>(),
421 core::mem::size_of_val(values),
422 )
423 };
424 self.constant_bytes(bytes, shape, data_type::FLOAT32)
425 }
426
427 #[must_use]
428 pub fn constant_scalar(&self, scalar: f64, data_type: u32) -> Option<Tensor> {
429 let ptr = unsafe { ffi::mpsgraph_graph_constant_scalar(self.ptr, scalar, data_type) };
431 wrap_tensor(ptr)
432 }
433
434 #[must_use]
435 pub fn constant_scalar_shaped(
436 &self,
437 scalar: f64,
438 shape: &[usize],
439 data_type: u32,
440 ) -> Option<Tensor> {
441 let ptr = unsafe {
443 ffi::mpsgraph_graph_constant_scalar_shaped(
444 self.ptr,
445 scalar,
446 shape.as_ptr(),
447 shape.len(),
448 data_type,
449 )
450 };
451 wrap_tensor(ptr)
452 }
453
454 impl_binary_tensor_op!(addition, mpsgraph_graph_addition);
455 impl_binary_tensor_op!(subtraction, mpsgraph_graph_subtraction);
456 impl_binary_tensor_op!(multiplication, mpsgraph_graph_multiplication);
457 impl_binary_tensor_op!(division, mpsgraph_graph_division);
458 impl_binary_tensor_op!(matrix_multiplication, mpsgraph_graph_matrix_multiplication);
459 impl_unary_tensor_op!(relu, mpsgraph_graph_relu);
460 impl_unary_tensor_op!(sigmoid, mpsgraph_graph_sigmoid);
461 impl_axes_tensor_op!(reduction_sum, mpsgraph_graph_reduction_sum);
462 impl_axes_tensor_op!(reduction_maximum, mpsgraph_graph_reduction_maximum);
463 impl_axes_tensor_op!(reduction_minimum, mpsgraph_graph_reduction_minimum);
464 impl_axes_tensor_op!(mean, mpsgraph_graph_mean);
465
466 #[must_use]
467 pub fn softmax(&self, tensor: &Tensor, axis: isize, name: Option<&str>) -> Option<Tensor> {
468 let name = optional_cstring(name);
469 let ptr = unsafe {
471 ffi::mpsgraph_graph_softmax(self.ptr, tensor.as_ptr(), axis, cstring_ptr(&name))
472 };
473 wrap_tensor(ptr)
474 }
475
476 #[must_use]
477 pub fn reshape(&self, tensor: &Tensor, shape: &[usize], name: Option<&str>) -> Option<Tensor> {
478 let name = optional_cstring(name);
479 let ptr = unsafe {
481 ffi::mpsgraph_graph_reshape(
482 self.ptr,
483 tensor.as_ptr(),
484 shape.as_ptr(),
485 shape.len(),
486 cstring_ptr(&name),
487 )
488 };
489 wrap_tensor(ptr)
490 }
491
492 #[must_use]
493 pub fn transpose(
494 &self,
495 tensor: &Tensor,
496 permutation: &[usize],
497 name: Option<&str>,
498 ) -> Option<Tensor> {
499 let name = optional_cstring(name);
500 let ptr = unsafe {
502 ffi::mpsgraph_graph_transpose(
503 self.ptr,
504 tensor.as_ptr(),
505 permutation.as_ptr(),
506 permutation.len(),
507 cstring_ptr(&name),
508 )
509 };
510 wrap_tensor(ptr)
511 }
512
513 #[must_use]
514 pub fn slice(
515 &self,
516 tensor: &Tensor,
517 dimension: usize,
518 start: isize,
519 length: isize,
520 name: Option<&str>,
521 ) -> Option<Tensor> {
522 let name = optional_cstring(name);
523 let ptr = unsafe {
525 ffi::mpsgraph_graph_slice(
526 self.ptr,
527 tensor.as_ptr(),
528 dimension,
529 start,
530 length,
531 cstring_ptr(&name),
532 )
533 };
534 wrap_tensor(ptr)
535 }
536
537 #[must_use]
538 pub fn broadcast(
539 &self,
540 tensor: &Tensor,
541 shape: &[usize],
542 name: Option<&str>,
543 ) -> Option<Tensor> {
544 let name = optional_cstring(name);
545 let ptr = unsafe {
547 ffi::mpsgraph_graph_broadcast(
548 self.ptr,
549 tensor.as_ptr(),
550 shape.as_ptr(),
551 shape.len(),
552 cstring_ptr(&name),
553 )
554 };
555 wrap_tensor(ptr)
556 }
557
558 #[must_use]
559 pub fn convolution2d(
560 &self,
561 source: &Tensor,
562 weights: &Tensor,
563 descriptor: &Convolution2DDescriptor,
564 name: Option<&str>,
565 ) -> Option<Tensor> {
566 let name = optional_cstring(name);
567 let ptr = unsafe {
569 ffi::mpsgraph_graph_convolution2d(
570 self.ptr,
571 source.as_ptr(),
572 weights.as_ptr(),
573 descriptor.as_ptr(),
574 cstring_ptr(&name),
575 )
576 };
577 wrap_tensor(ptr)
578 }
579
580 #[must_use]
581 pub fn max_pooling2d(
582 &self,
583 source: &Tensor,
584 descriptor: &Pooling2DDescriptor,
585 name: Option<&str>,
586 ) -> Option<Tensor> {
587 let name = optional_cstring(name);
588 let ptr = unsafe {
590 ffi::mpsgraph_graph_max_pooling2d(
591 self.ptr,
592 source.as_ptr(),
593 descriptor.as_ptr(),
594 cstring_ptr(&name),
595 )
596 };
597 wrap_tensor(ptr)
598 }
599
600 #[allow(clippy::too_many_arguments)]
601 #[must_use]
602 pub fn normalize(
603 &self,
604 tensor: &Tensor,
605 mean: &Tensor,
606 variance: &Tensor,
607 gamma: Option<&Tensor>,
608 beta: Option<&Tensor>,
609 epsilon: f32,
610 name: Option<&str>,
611 ) -> Option<Tensor> {
612 let name = optional_cstring(name);
613 let gamma_ptr = gamma.map_or(ptr::null_mut(), Tensor::as_ptr);
614 let beta_ptr = beta.map_or(ptr::null_mut(), Tensor::as_ptr);
615 let ptr = unsafe {
617 ffi::mpsgraph_graph_normalize(
618 self.ptr,
619 tensor.as_ptr(),
620 mean.as_ptr(),
621 variance.as_ptr(),
622 gamma_ptr,
623 beta_ptr,
624 epsilon,
625 cstring_ptr(&name),
626 )
627 };
628 wrap_tensor(ptr)
629 }
630
631 pub fn run(&self, feeds: &[Feed<'_>], targets: &[&Tensor]) -> Result<Vec<TensorData>> {
632 let feed_tensors = feeds
633 .iter()
634 .map(|feed| feed.tensor.as_ptr())
635 .collect::<Vec<_>>();
636 let feed_data = feeds
637 .iter()
638 .map(|feed| feed.data.as_ptr())
639 .collect::<Vec<_>>();
640 let target_tensors = targets
641 .iter()
642 .map(|tensor| tensor.as_ptr())
643 .collect::<Vec<_>>();
644 let mut results = vec![ptr::null_mut(); targets.len()];
645
646 let ok = unsafe {
648 ffi::mpsgraph_graph_run(
649 self.ptr,
650 feed_tensors.as_ptr(),
651 feed_data.as_ptr(),
652 feeds.len(),
653 target_tensors.as_ptr(),
654 targets.len(),
655 results.as_mut_ptr(),
656 )
657 };
658 if ok {
659 wrap_tensor_data_results(results, "failed to run graph")
660 } else {
661 Err(Error::OperationFailed("failed to run graph"))
662 }
663 }
664
665 pub fn run_with_command_queue(
666 &self,
667 command_queue: &CommandQueue,
668 feeds: &[Feed<'_>],
669 targets: &[&Tensor],
670 ) -> Result<Vec<TensorData>> {
671 let feed_tensors = feeds
672 .iter()
673 .map(|feed| feed.tensor.as_ptr())
674 .collect::<Vec<_>>();
675 let feed_data = feeds
676 .iter()
677 .map(|feed| feed.data.as_ptr())
678 .collect::<Vec<_>>();
679 let target_tensors = targets
680 .iter()
681 .map(|tensor| tensor.as_ptr())
682 .collect::<Vec<_>>();
683 let mut results = vec![ptr::null_mut(); targets.len()];
684
685 let ok = unsafe {
687 ffi::mpsgraph_graph_run_with_command_queue(
688 self.ptr,
689 command_queue.as_ptr(),
690 feed_tensors.as_ptr(),
691 feed_data.as_ptr(),
692 feeds.len(),
693 target_tensors.as_ptr(),
694 targets.len(),
695 results.as_mut_ptr(),
696 )
697 };
698 if ok {
699 wrap_tensor_data_results(results, "failed to run graph with command queue")
700 } else {
701 Err(Error::OperationFailed(
702 "failed to run graph with command queue",
703 ))
704 }
705 }
706
707 #[must_use]
708 pub fn compile(
709 &self,
710 device: &MetalDevice,
711 feeds: &[FeedDescription<'_>],
712 targets: &[&Tensor],
713 ) -> Option<Executable> {
714 let feed_tensors = feeds
715 .iter()
716 .map(|feed| feed.tensor.as_ptr())
717 .collect::<Vec<_>>();
718 let shape_lengths = feeds
719 .iter()
720 .map(|feed| feed.shape.len())
721 .collect::<Vec<_>>();
722 let data_types = feeds.iter().map(|feed| feed.data_type).collect::<Vec<_>>();
723 let flat_shapes = feeds
724 .iter()
725 .flat_map(|feed| feed.shape.iter().copied())
726 .collect::<Vec<_>>();
727 let target_tensors = targets
728 .iter()
729 .map(|tensor| tensor.as_ptr())
730 .collect::<Vec<_>>();
731
732 let ptr = unsafe {
734 ffi::mpsgraph_graph_compile(
735 self.ptr,
736 device.as_ptr(),
737 feed_tensors.as_ptr(),
738 feeds.len(),
739 flat_shapes.as_ptr(),
740 shape_lengths.as_ptr(),
741 data_types.as_ptr(),
742 target_tensors.as_ptr(),
743 targets.len(),
744 )
745 };
746 if ptr.is_null() {
747 None
748 } else {
749 Some(Executable::from_raw(ptr, targets.len()))
750 }
751 }
752}
753
754pub struct Executable {
756 ptr: *mut c_void,
757 output_count: usize,
758}
759
760unsafe impl Send for Executable {}
761unsafe impl Sync for Executable {}
762
763impl Drop for Executable {
764 fn drop(&mut self) {
765 if !self.ptr.is_null() {
766 unsafe { ffi::mpsgraph_object_release(self.ptr) };
768 self.ptr = ptr::null_mut();
769 }
770 }
771}
772
773impl Executable {
774 pub(crate) const fn from_raw(ptr: *mut c_void, output_count: usize) -> Self {
775 Self { ptr, output_count }
776 }
777
778 #[must_use]
779 pub const fn as_ptr(&self) -> *mut c_void {
780 self.ptr
781 }
782
783 #[must_use]
784 pub const fn output_count(&self) -> usize {
785 self.output_count
786 }
787
788 pub fn run(
789 &self,
790 command_queue: &CommandQueue,
791 inputs: &[&TensorData],
792 ) -> Result<Vec<TensorData>> {
793 let input_data = inputs
794 .iter()
795 .map(|tensor_data| tensor_data.as_ptr())
796 .collect::<Vec<_>>();
797 let mut results = vec![ptr::null_mut(); self.output_count];
798
799 let ok = unsafe {
801 ffi::mpsgraph_executable_run(
802 self.ptr,
803 command_queue.as_ptr(),
804 input_data.as_ptr(),
805 inputs.len(),
806 self.output_count,
807 results.as_mut_ptr(),
808 )
809 };
810 if ok {
811 wrap_tensor_data_results(results, "failed to run executable")
812 } else {
813 Err(Error::OperationFailed("failed to run executable"))
814 }
815 }
816}