1use crate::data::TensorData;
2use crate::error::{Error, Result};
3use crate::ffi;
4use apple_metal::{CommandQueue, MetalDevice};
5use core::ffi::{c_char, c_void};
6use core::ptr;
7use std::ffi::CString;
8
/// Raw MPSGraph data-type codes as passed across the FFI boundary.
///
/// NOTE(review): the values look like `class-bits | bit-width` (e.g. float
/// flag `0x1000_0000`, signed flag `0x2000_0000`, width in the low byte) —
/// confirm against the native MPSDataType definition.
pub mod data_type {
    pub const INVALID: u32 = 0;
    pub const FLOAT32: u32 = 0x1000_0020;
    pub const FLOAT16: u32 = 0x1000_0010;
    pub const INT8: u32 = 0x2000_0008;
    pub const INT16: u32 = 0x2000_0010;
    pub const INT32: u32 = 0x2000_0020;
    pub const INT64: u32 = 0x2000_0040;
    pub const UINT8: u32 = 0x0000_0008;
    pub const UINT16: u32 = 0x0000_0010;
    pub const UINT32: u32 = 0x0000_0020;
    pub const UINT64: u32 = 0x0000_0040;
    pub const BOOL: u32 = 0x8000_0008;
    pub const UNORM8: u32 = 0x4000_0008;
}

/// Size in bytes of a single element of the given [`data_type`] code.
///
/// Returns `None` for [`data_type::INVALID`] or any unrecognized code.
#[must_use]
pub const fn data_type_size(data_type: u32) -> Option<usize> {
    let bytes = match data_type {
        data_type::INT8 | data_type::UINT8 | data_type::BOOL | data_type::UNORM8 => 1,
        data_type::FLOAT16 | data_type::INT16 | data_type::UINT16 => 2,
        data_type::FLOAT32 | data_type::INT32 | data_type::UINT32 => 4,
        data_type::INT64 | data_type::UINT64 => 8,
        _ => return None,
    };
    Some(bytes)
}
37
/// Named tensor memory layouts passed to convolution/pooling descriptors.
///
/// NOTE(review): the values presumably mirror the raw discriminants of
/// MPSGraph's `TensorNamedDataLayout` — confirm against the FFI layer.
pub mod tensor_named_data_layout {
    pub const NCHW: usize = 0;
    pub const NHWC: usize = 1;
    pub const OIHW: usize = 2;
    pub const HWIO: usize = 3;
    pub const CHW: usize = 4;
    pub const HWC: usize = 5;
    pub const HW: usize = 6;
}
48
/// Padding-style selectors for convolution/pooling descriptors.
///
/// NOTE(review): values presumably mirror the raw discriminants of MPSGraph's
/// `PaddingStyle` (explicit padding vs. TensorFlow/ONNX auto-padding rules) —
/// confirm against the FFI layer.
pub mod padding_style {
    pub const EXPLICIT: usize = 0;
    pub const TF_VALID: usize = 1;
    pub const TF_SAME: usize = 2;
    pub const EXPLICIT_OFFSET: usize = 3;
    pub const ONNX_SAME_LOWER: usize = 4;
}
57
/// Declares an owning wrapper type around a raw native object pointer.
///
/// The generated type releases the underlying object exactly once through
/// `ffi::mpsgraph_object_release` when dropped and exposes the raw pointer
/// (without transferring ownership) through `as_ptr`.
macro_rules! opaque_handle {
    ($name:ident) => {
        pub struct $name {
            ptr: *mut c_void,
        }

        // NOTE(review): these impls assert the wrapped native object may be
        // used and released from any thread — confirm against the FFI
        // contract before relying on cross-thread use.
        unsafe impl Send for $name {}
        unsafe impl Sync for $name {}

        impl Drop for $name {
            fn drop(&mut self) {
                if !self.ptr.is_null() {
                    // SAFETY: `ptr` is non-null and owned by this wrapper;
                    // nulling it below guarantees at most one release.
                    unsafe { ffi::mpsgraph_object_release(self.ptr) };
                    self.ptr = ptr::null_mut();
                }
            }
        }

        impl $name {
            /// Returns the raw object pointer; ownership stays with `self`.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }
    };
}
85
86fn checked_byte_len(shape: &[usize], data_type: u32) -> Option<usize> {
87 let element_size = data_type_size(data_type)?;
88 shape
89 .iter()
90 .try_fold(element_size, |acc, dimension| acc.checked_mul(*dimension))
91}
92
/// Converts an optional Rust string into an owned `CString`.
///
/// Returns `None` both when no name was given and when the name contains an
/// interior NUL byte (which `CString::new` rejects).
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    match name {
        Some(text) => CString::new(text).ok(),
        None => None,
    }
}
96
/// Borrows the C-string pointer out of an optional `CString`, or a null
/// pointer when absent.
///
/// The pointer is only valid while the referenced `CString` is alive.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(text) => text.as_ptr(),
        None => ptr::null(),
    }
}
101
102fn wrap_tensor(ptr: *mut c_void) -> Option<Tensor> {
103 if ptr.is_null() {
104 None
105 } else {
106 Some(Tensor { ptr })
107 }
108}
109
110fn wrap_tensor_data_results(
111 handles: Vec<*mut c_void>,
112 message: &'static str,
113) -> Result<Vec<TensorData>> {
114 let mut results = Vec::with_capacity(handles.len());
115 for handle in handles {
116 if handle.is_null() {
117 return Err(Error::OperationFailed(message));
118 }
119 results.push(TensorData::from_raw(handle));
120 }
121 Ok(results)
122}
123
/// Generates a `Graph` method for a binary tensor op backed by an FFI call of
/// shape `(graph, primary, secondary, name) -> tensor`.
///
/// The generated method returns `None` when the FFI returns a null handle.
macro_rules! impl_binary_tensor_op {
    ($fn_name:ident, $ffi_name:ident) => {
        #[must_use]
        pub fn $fn_name(
            &self,
            primary: &Tensor,
            secondary: &Tensor,
            name: Option<&str>,
        ) -> Option<Tensor> {
            // Keep the CString alive across the FFI call so the name pointer
            // stays valid.
            let name = optional_cstring(name);
            let ptr = unsafe {
                ffi::$ffi_name(
                    self.ptr,
                    primary.as_ptr(),
                    secondary.as_ptr(),
                    cstring_ptr(&name),
                )
            };
            wrap_tensor(ptr)
        }
    };
}
147
/// Generates a `Graph` method for a unary tensor op backed by an FFI call of
/// shape `(graph, tensor, name) -> tensor`.
///
/// The generated method returns `None` when the FFI returns a null handle.
macro_rules! impl_unary_tensor_op {
    ($fn_name:ident, $ffi_name:ident) => {
        #[must_use]
        pub fn $fn_name(&self, tensor: &Tensor, name: Option<&str>) -> Option<Tensor> {
            // Keep the CString alive across the FFI call.
            let name = optional_cstring(name);
            let ptr = unsafe { ffi::$ffi_name(self.ptr, tensor.as_ptr(), cstring_ptr(&name)) };
            wrap_tensor(ptr)
        }
    };
}
159
/// Generates a `Graph` method for an axes-parameterized tensor op (e.g.
/// reductions) backed by an FFI call of shape
/// `(graph, tensor, axes_ptr, axes_len, name) -> tensor`.
///
/// The generated method returns `None` when the FFI returns a null handle.
macro_rules! impl_axes_tensor_op {
    ($fn_name:ident, $ffi_name:ident) => {
        #[must_use]
        pub fn $fn_name(
            &self,
            tensor: &Tensor,
            axes: &[usize],
            name: Option<&str>,
        ) -> Option<Tensor> {
            // Keep the CString alive across the FFI call.
            let name = optional_cstring(name);
            let ptr = unsafe {
                ffi::$ffi_name(
                    self.ptr,
                    tensor.as_ptr(),
                    axes.as_ptr(),
                    axes.len(),
                    cstring_ptr(&name),
                )
            };
            wrap_tensor(ptr)
        }
    };
}
184
/// Binds runtime data to a graph tensor for a single eager run
/// (see `Graph::run`).
#[derive(Clone, Copy)]
pub struct Feed<'a> {
    // The graph tensor to feed (typically a placeholder).
    pub tensor: &'a Tensor,
    // The data supplied for that tensor.
    pub data: &'a TensorData,
}

impl<'a> Feed<'a> {
    /// Creates a feed binding `data` to `tensor`.
    #[must_use]
    pub const fn new(tensor: &'a Tensor, data: &'a TensorData) -> Self {
        Self { tensor, data }
    }
}
198
/// Describes the fixed shape and data type of a feed tensor for ahead-of-time
/// compilation (see `Graph::compile`).
#[derive(Clone, Copy)]
pub struct FeedDescription<'a> {
    // The graph tensor being described.
    pub tensor: &'a Tensor,
    // The concrete shape the compiled executable will expect.
    pub shape: &'a [usize],
    // One of the codes from the `data_type` module.
    pub data_type: u32,
}

impl<'a> FeedDescription<'a> {
    /// Creates a description fixing `tensor` to `shape` and `data_type`.
    #[must_use]
    pub const fn new(tensor: &'a Tensor, shape: &'a [usize], data_type: u32) -> Self {
        Self {
            tensor,
            shape,
            data_type,
        }
    }
}
217
/// Plain-data configuration for a 2-D convolution, mirroring the parameters
/// of `ffi::mpsgraph_convolution2d_descriptor_new`.
#[derive(Debug, Clone, Copy)]
pub struct Convolution2DDescriptorInfo {
    pub stride_in_x: usize,
    pub stride_in_y: usize,
    pub dilation_rate_in_x: usize,
    pub dilation_rate_in_y: usize,
    // Number of feature-map groups (1 = ordinary convolution).
    pub groups: usize,
    // Explicit padding, used with `padding_style::EXPLICIT`.
    pub padding_left: usize,
    pub padding_right: usize,
    pub padding_top: usize,
    pub padding_bottom: usize,
    // One of the `padding_style` constants.
    pub padding_style: usize,
    // Layout of the source tensor (`tensor_named_data_layout` constant).
    pub data_layout: usize,
    // Layout of the weights tensor (`tensor_named_data_layout` constant).
    pub weights_layout: usize,
}

impl Default for Convolution2DDescriptorInfo {
    /// Unit strides/dilations, no grouping, zero explicit padding, NHWC
    /// source with HWIO weights.
    fn default() -> Self {
        Self {
            stride_in_x: 1,
            stride_in_y: 1,
            dilation_rate_in_x: 1,
            dilation_rate_in_y: 1,
            groups: 1,
            padding_left: 0,
            padding_right: 0,
            padding_top: 0,
            padding_bottom: 0,
            padding_style: padding_style::EXPLICIT,
            data_layout: tensor_named_data_layout::NHWC,
            weights_layout: tensor_named_data_layout::HWIO,
        }
    }
}
253
254opaque_handle!(Convolution2DDescriptor);
255impl Convolution2DDescriptor {
256 #[must_use]
257 pub fn new(info: Convolution2DDescriptorInfo) -> Option<Self> {
258 let ptr = unsafe {
260 ffi::mpsgraph_convolution2d_descriptor_new(
261 info.stride_in_x,
262 info.stride_in_y,
263 info.dilation_rate_in_x,
264 info.dilation_rate_in_y,
265 info.groups,
266 info.padding_left,
267 info.padding_right,
268 info.padding_top,
269 info.padding_bottom,
270 info.padding_style,
271 info.data_layout,
272 info.weights_layout,
273 )
274 };
275 if ptr.is_null() {
276 None
277 } else {
278 Some(Self { ptr })
279 }
280 }
281}
282
/// Plain-data configuration for a 2-D pooling window, mirroring the
/// parameters of `ffi::mpsgraph_pooling2d_descriptor_new`.
#[derive(Debug, Clone, Copy)]
pub struct Pooling2DDescriptorInfo {
    pub kernel_width: usize,
    pub kernel_height: usize,
    pub stride_in_x: usize,
    pub stride_in_y: usize,
    pub dilation_rate_in_x: usize,
    pub dilation_rate_in_y: usize,
    // Explicit padding, used with `padding_style::EXPLICIT`.
    pub padding_left: usize,
    pub padding_right: usize,
    pub padding_top: usize,
    pub padding_bottom: usize,
    // One of the `padding_style` constants.
    pub padding_style: usize,
    // Layout of the source tensor (`tensor_named_data_layout` constant).
    pub data_layout: usize,
}

impl Pooling2DDescriptorInfo {
    /// Creates an info with the given kernel size and defaults matching
    /// `Convolution2DDescriptorInfo::default`: unit strides/dilations, zero
    /// explicit padding, NHWC layout.
    #[must_use]
    pub const fn new(kernel_width: usize, kernel_height: usize) -> Self {
        Self {
            kernel_width,
            kernel_height,
            stride_in_x: 1,
            stride_in_y: 1,
            dilation_rate_in_x: 1,
            dilation_rate_in_y: 1,
            padding_left: 0,
            padding_right: 0,
            padding_top: 0,
            padding_bottom: 0,
            padding_style: padding_style::EXPLICIT,
            data_layout: tensor_named_data_layout::NHWC,
        }
    }
}
319
320opaque_handle!(Pooling2DDescriptor);
321impl Pooling2DDescriptor {
322 #[must_use]
323 pub fn new(info: Pooling2DDescriptorInfo) -> Option<Self> {
324 let ptr = unsafe {
326 ffi::mpsgraph_pooling2d_descriptor_new(
327 info.kernel_width,
328 info.kernel_height,
329 info.stride_in_x,
330 info.stride_in_y,
331 info.dilation_rate_in_x,
332 info.dilation_rate_in_y,
333 info.padding_left,
334 info.padding_right,
335 info.padding_top,
336 info.padding_bottom,
337 info.padding_style,
338 info.data_layout,
339 )
340 };
341 if ptr.is_null() {
342 None
343 } else {
344 Some(Self { ptr })
345 }
346 }
347}
348
// Owning handles for the graph itself and for tensors created on it.
opaque_handle!(Graph);
opaque_handle!(Tensor);

impl Graph {
    /// Creates an empty compute graph, or `None` if the native constructor
    /// returned a null handle.
    #[must_use]
    pub fn new() -> Option<Self> {
        let ptr = unsafe { ffi::mpsgraph_graph_new() };
        if ptr.is_null() {
            None
        } else {
            Some(Self { ptr })
        }
    }

    /// Adds a placeholder (externally fed) tensor to the graph.
    ///
    /// A `shape` of `None` is forwarded as a null pointer with length 0 —
    /// presumably requesting a shape-unknown placeholder; confirm against the
    /// FFI layer. Returns `None` when the FFI returns a null handle.
    #[must_use]
    pub fn placeholder(
        &self,
        shape: Option<&[usize]>,
        data_type: u32,
        name: Option<&str>,
    ) -> Option<Tensor> {
        // Keep the CString alive across the FFI call.
        let name = optional_cstring(name);
        let (shape_ptr, shape_len) =
            shape.map_or((ptr::null(), 0), |shape| (shape.as_ptr(), shape.len()));

        let ptr = unsafe {
            ffi::mpsgraph_graph_placeholder(
                self.ptr,
                shape_ptr,
                shape_len,
                data_type,
                cstring_ptr(&name),
            )
        };
        wrap_tensor(ptr)
    }

    /// Adds a constant tensor from raw bytes.
    ///
    /// Returns `None` when `data_type` is unknown, when `data.len()` does not
    /// equal `element size * product(shape)` (overflow-checked), or when the
    /// FFI returns a null handle.
    #[must_use]
    pub fn constant_bytes(&self, data: &[u8], shape: &[usize], data_type: u32) -> Option<Tensor> {
        let expected = checked_byte_len(shape, data_type)?;
        if data.len() != expected {
            return None;
        }

        let ptr = unsafe {
            ffi::mpsgraph_graph_constant_data(
                self.ptr,
                data.as_ptr().cast(),
                data.len(),
                shape.as_ptr(),
                shape.len(),
                data_type,
            )
        };
        wrap_tensor(ptr)
    }

    /// Adds a `FLOAT32` constant tensor from a slice of `f32` values.
    ///
    /// The byte length is validated against `shape` by `constant_bytes`.
    #[must_use]
    pub fn constant_f32_slice(&self, values: &[f32], shape: &[usize]) -> Option<Tensor> {
        // SAFETY: any `f32` may be viewed as 4 plain bytes; the byte slice
        // covers exactly `size_of_val(values)` bytes of the same allocation
        // and lives no longer than `values`.
        let bytes = unsafe {
            core::slice::from_raw_parts(
                values.as_ptr().cast::<u8>(),
                core::mem::size_of_val(values),
            )
        };
        self.constant_bytes(bytes, shape, data_type::FLOAT32)
    }

    /// Adds a scalar constant of the given data type.
    #[must_use]
    pub fn constant_scalar(&self, scalar: f64, data_type: u32) -> Option<Tensor> {
        let ptr = unsafe { ffi::mpsgraph_graph_constant_scalar(self.ptr, scalar, data_type) };
        wrap_tensor(ptr)
    }

    /// Adds a constant tensor of the given shape with every element set to
    /// `scalar`.
    #[must_use]
    pub fn constant_scalar_shaped(
        &self,
        scalar: f64,
        shape: &[usize],
        data_type: u32,
    ) -> Option<Tensor> {
        let ptr = unsafe {
            ffi::mpsgraph_graph_constant_scalar_shaped(
                self.ptr,
                scalar,
                shape.as_ptr(),
                shape.len(),
                data_type,
            )
        };
        wrap_tensor(ptr)
    }

    // Binary ops: (primary, secondary, name) -> tensor.
    impl_binary_tensor_op!(addition, mpsgraph_graph_addition);
    impl_binary_tensor_op!(subtraction, mpsgraph_graph_subtraction);
    impl_binary_tensor_op!(multiplication, mpsgraph_graph_multiplication);
    impl_binary_tensor_op!(division, mpsgraph_graph_division);
    impl_binary_tensor_op!(matrix_multiplication, mpsgraph_graph_matrix_multiplication);
    // Unary ops: (tensor, name) -> tensor.
    impl_unary_tensor_op!(relu, mpsgraph_graph_relu);
    impl_unary_tensor_op!(sigmoid, mpsgraph_graph_sigmoid);
    // Axes-parameterized ops: (tensor, axes, name) -> tensor.
    impl_axes_tensor_op!(reduction_sum, mpsgraph_graph_reduction_sum);
    impl_axes_tensor_op!(reduction_maximum, mpsgraph_graph_reduction_maximum);
    impl_axes_tensor_op!(reduction_minimum, mpsgraph_graph_reduction_minimum);
    impl_axes_tensor_op!(mean, mpsgraph_graph_mean);

    /// Adds a softmax op along `axis`. The axis is signed — presumably
    /// negative values index from the end, as in MPSGraph; confirm against
    /// the FFI layer.
    #[must_use]
    pub fn softmax(&self, tensor: &Tensor, axis: isize, name: Option<&str>) -> Option<Tensor> {
        let name = optional_cstring(name);
        let ptr = unsafe {
            ffi::mpsgraph_graph_softmax(self.ptr, tensor.as_ptr(), axis, cstring_ptr(&name))
        };
        wrap_tensor(ptr)
    }

    /// Adds a reshape op to the given target `shape`.
    #[must_use]
    pub fn reshape(&self, tensor: &Tensor, shape: &[usize], name: Option<&str>) -> Option<Tensor> {
        let name = optional_cstring(name);
        let ptr = unsafe {
            ffi::mpsgraph_graph_reshape(
                self.ptr,
                tensor.as_ptr(),
                shape.as_ptr(),
                shape.len(),
                cstring_ptr(&name),
            )
        };
        wrap_tensor(ptr)
    }

    /// Adds a transpose op applying the given axis `permutation`.
    #[must_use]
    pub fn transpose(
        &self,
        tensor: &Tensor,
        permutation: &[usize],
        name: Option<&str>,
    ) -> Option<Tensor> {
        let name = optional_cstring(name);
        let ptr = unsafe {
            ffi::mpsgraph_graph_transpose(
                self.ptr,
                tensor.as_ptr(),
                permutation.as_ptr(),
                permutation.len(),
                cstring_ptr(&name),
            )
        };
        wrap_tensor(ptr)
    }

    /// Adds a slice op taking `length` elements from `start` along
    /// `dimension`. `start` and `length` are signed — semantics of negative
    /// values are defined by the FFI layer; confirm before relying on them.
    #[must_use]
    pub fn slice(
        &self,
        tensor: &Tensor,
        dimension: usize,
        start: isize,
        length: isize,
        name: Option<&str>,
    ) -> Option<Tensor> {
        let name = optional_cstring(name);
        let ptr = unsafe {
            ffi::mpsgraph_graph_slice(
                self.ptr,
                tensor.as_ptr(),
                dimension,
                start,
                length,
                cstring_ptr(&name),
            )
        };
        wrap_tensor(ptr)
    }

    /// Adds a broadcast op expanding `tensor` to the given `shape`.
    #[must_use]
    pub fn broadcast(
        &self,
        tensor: &Tensor,
        shape: &[usize],
        name: Option<&str>,
    ) -> Option<Tensor> {
        let name = optional_cstring(name);
        let ptr = unsafe {
            ffi::mpsgraph_graph_broadcast(
                self.ptr,
                tensor.as_ptr(),
                shape.as_ptr(),
                shape.len(),
                cstring_ptr(&name),
            )
        };
        wrap_tensor(ptr)
    }

    /// Adds a 2-D convolution of `source` with `weights`, configured by
    /// `descriptor`.
    #[must_use]
    pub fn convolution2d(
        &self,
        source: &Tensor,
        weights: &Tensor,
        descriptor: &Convolution2DDescriptor,
        name: Option<&str>,
    ) -> Option<Tensor> {
        let name = optional_cstring(name);
        let ptr = unsafe {
            ffi::mpsgraph_graph_convolution2d(
                self.ptr,
                source.as_ptr(),
                weights.as_ptr(),
                descriptor.as_ptr(),
                cstring_ptr(&name),
            )
        };
        wrap_tensor(ptr)
    }

    /// Adds a 2-D max-pooling op over `source`, configured by `descriptor`.
    #[must_use]
    pub fn max_pooling2d(
        &self,
        source: &Tensor,
        descriptor: &Pooling2DDescriptor,
        name: Option<&str>,
    ) -> Option<Tensor> {
        let name = optional_cstring(name);
        let ptr = unsafe {
            ffi::mpsgraph_graph_max_pooling2d(
                self.ptr,
                source.as_ptr(),
                descriptor.as_ptr(),
                cstring_ptr(&name),
            )
        };
        wrap_tensor(ptr)
    }

    /// Adds a normalization op: `(tensor - mean) / sqrt(variance + epsilon)`
    /// with optional scale/shift — exact formula is defined by the FFI layer;
    /// confirm. `gamma`/`beta` of `None` are forwarded as null pointers.
    #[allow(clippy::too_many_arguments)]
    #[must_use]
    pub fn normalize(
        &self,
        tensor: &Tensor,
        mean: &Tensor,
        variance: &Tensor,
        gamma: Option<&Tensor>,
        beta: Option<&Tensor>,
        epsilon: f32,
        name: Option<&str>,
    ) -> Option<Tensor> {
        let name = optional_cstring(name);
        let gamma_ptr = gamma.map_or(ptr::null_mut(), Tensor::as_ptr);
        let beta_ptr = beta.map_or(ptr::null_mut(), Tensor::as_ptr);
        let ptr = unsafe {
            ffi::mpsgraph_graph_normalize(
                self.ptr,
                tensor.as_ptr(),
                mean.as_ptr(),
                variance.as_ptr(),
                gamma_ptr,
                beta_ptr,
                epsilon,
                cstring_ptr(&name),
            )
        };
        wrap_tensor(ptr)
    }

    /// Runs the graph eagerly on a default device/queue chosen by the FFI
    /// layer, binding `feeds` and materializing one result per target.
    ///
    /// # Errors
    /// Returns `Error::OperationFailed` when the native run reports failure
    /// or when any returned result handle is null.
    pub fn run(&self, feeds: &[Feed<'_>], targets: &[&Tensor]) -> Result<Vec<TensorData>> {
        // Marshal borrowed handles into parallel raw-pointer arrays for FFI.
        let feed_tensors = feeds
            .iter()
            .map(|feed| feed.tensor.as_ptr())
            .collect::<Vec<_>>();
        let feed_data = feeds
            .iter()
            .map(|feed| feed.data.as_ptr())
            .collect::<Vec<_>>();
        let target_tensors = targets
            .iter()
            .map(|tensor| tensor.as_ptr())
            .collect::<Vec<_>>();
        // Output buffer: one handle slot per target, filled by the FFI call.
        let mut results = vec![ptr::null_mut(); targets.len()];

        let ok = unsafe {
            ffi::mpsgraph_graph_run(
                self.ptr,
                feed_tensors.as_ptr(),
                feed_data.as_ptr(),
                feeds.len(),
                target_tensors.as_ptr(),
                targets.len(),
                results.as_mut_ptr(),
            )
        };
        if ok {
            wrap_tensor_data_results(results, "failed to run graph")
        } else {
            Err(Error::OperationFailed("failed to run graph"))
        }
    }

    /// Same as [`Self::run`], but submits work on the caller-supplied Metal
    /// command queue.
    ///
    /// # Errors
    /// Returns `Error::OperationFailed` when the native run reports failure
    /// or when any returned result handle is null.
    pub fn run_with_command_queue(
        &self,
        command_queue: &CommandQueue,
        feeds: &[Feed<'_>],
        targets: &[&Tensor],
    ) -> Result<Vec<TensorData>> {
        // Marshal borrowed handles into parallel raw-pointer arrays for FFI.
        let feed_tensors = feeds
            .iter()
            .map(|feed| feed.tensor.as_ptr())
            .collect::<Vec<_>>();
        let feed_data = feeds
            .iter()
            .map(|feed| feed.data.as_ptr())
            .collect::<Vec<_>>();
        let target_tensors = targets
            .iter()
            .map(|tensor| tensor.as_ptr())
            .collect::<Vec<_>>();
        let mut results = vec![ptr::null_mut(); targets.len()];

        let ok = unsafe {
            ffi::mpsgraph_graph_run_with_command_queue(
                self.ptr,
                command_queue.as_ptr(),
                feed_tensors.as_ptr(),
                feed_data.as_ptr(),
                feeds.len(),
                target_tensors.as_ptr(),
                targets.len(),
                results.as_mut_ptr(),
            )
        };
        if ok {
            wrap_tensor_data_results(results, "failed to run graph with command queue")
        } else {
            Err(Error::OperationFailed(
                "failed to run graph with command queue",
            ))
        }
    }

    /// Compiles the graph for `device`, fixing each feed's shape and data
    /// type as given in `feeds`.
    ///
    /// Shapes are flattened into one contiguous array, with per-feed lengths
    /// passed alongside so the FFI layer can split them again. The returned
    /// `Executable` remembers `targets.len()` to size its output buffer.
    /// Returns `None` when compilation yields a null handle.
    #[must_use]
    pub fn compile(
        &self,
        device: &MetalDevice,
        feeds: &[FeedDescription<'_>],
        targets: &[&Tensor],
    ) -> Option<Executable> {
        let feed_tensors = feeds
            .iter()
            .map(|feed| feed.tensor.as_ptr())
            .collect::<Vec<_>>();
        let shape_lengths = feeds
            .iter()
            .map(|feed| feed.shape.len())
            .collect::<Vec<_>>();
        let data_types = feeds.iter().map(|feed| feed.data_type).collect::<Vec<_>>();
        let flat_shapes = feeds
            .iter()
            .flat_map(|feed| feed.shape.iter().copied())
            .collect::<Vec<_>>();
        let target_tensors = targets
            .iter()
            .map(|tensor| tensor.as_ptr())
            .collect::<Vec<_>>();

        let ptr = unsafe {
            ffi::mpsgraph_graph_compile(
                self.ptr,
                device.as_ptr(),
                feed_tensors.as_ptr(),
                feeds.len(),
                flat_shapes.as_ptr(),
                shape_lengths.as_ptr(),
                data_types.as_ptr(),
                target_tensors.as_ptr(),
                targets.len(),
            )
        };
        if ptr.is_null() {
            None
        } else {
            Some(Executable {
                ptr,
                output_count: targets.len(),
            })
        }
    }
}
750
/// A graph compiled for a specific Metal device (see `Graph::compile`),
/// ready for repeated execution via [`Executable::run`].
pub struct Executable {
    // Raw handle to the native executable; released exactly once on drop.
    ptr: *mut c_void,
    // Number of target tensors the graph was compiled with; sizes the
    // result buffer in `run`.
    output_count: usize,
}

// NOTE(review): these impls assert the native executable may be used and
// released from any thread — confirm against the FFI contract.
unsafe impl Send for Executable {}
unsafe impl Sync for Executable {}

impl Drop for Executable {
    fn drop(&mut self) {
        if !self.ptr.is_null() {
            // SAFETY: `ptr` is non-null and owned by this wrapper; nulling it
            // below guarantees at most one release.
            unsafe { ffi::mpsgraph_object_release(self.ptr) };
            self.ptr = ptr::null_mut();
        }
    }
}
769
770impl Executable {
771 #[must_use]
772 pub const fn as_ptr(&self) -> *mut c_void {
773 self.ptr
774 }
775
776 #[must_use]
777 pub const fn output_count(&self) -> usize {
778 self.output_count
779 }
780
781 pub fn run(
782 &self,
783 command_queue: &CommandQueue,
784 inputs: &[&TensorData],
785 ) -> Result<Vec<TensorData>> {
786 let input_data = inputs
787 .iter()
788 .map(|tensor_data| tensor_data.as_ptr())
789 .collect::<Vec<_>>();
790 let mut results = vec![ptr::null_mut(); self.output_count];
791
792 let ok = unsafe {
794 ffi::mpsgraph_executable_run(
795 self.ptr,
796 command_queue.as_ptr(),
797 input_data.as_ptr(),
798 inputs.len(),
799 self.output_count,
800 results.as_mut_ptr(),
801 )
802 };
803 if ok {
804 wrap_tensor_data_results(results, "failed to run executable")
805 } else {
806 Err(Error::OperationFailed("failed to run executable"))
807 }
808 }
809}