1use crate::data::TensorData;
2use crate::error::{Error, Result};
3use crate::ffi;
4use crate::graph::{Executable, FeedDescription, Graph, Tensor};
5use crate::types::{
6 collect_owned_tensors, collect_shaped_type_array_box, collect_tensor_data_array_box, ShapedType,
7};
8use apple_metal::{CommandQueue, MetalDevice};
9use core::ffi::c_void;
10use core::ptr;
11use std::ffi::CString;
12
13fn release_handle(ptr: &mut *mut c_void) {
14 if !ptr.is_null() {
15 unsafe { ffi::mpsgraph_object_release(*ptr) };
17 *ptr = ptr::null_mut();
18 }
19}
20
21fn copy_string(
22 len: unsafe extern "C" fn(*mut c_void) -> usize,
23 copy: unsafe extern "C" fn(*mut c_void, *mut u8, usize) -> bool,
24 handle: *mut c_void,
25) -> Result<String> {
26 let len = unsafe { len(handle) };
28 let mut bytes = vec![0_u8; len];
29 let ok = unsafe { copy(handle, bytes.as_mut_ptr(), len) };
31 if ok {
32 String::from_utf8(bytes)
33 .map_err(|_| Error::OperationFailed("bridge returned invalid UTF-8"))
34 } else {
35 Err(Error::OperationFailed("failed to copy string from bridge"))
36 }
37}
38
/// Option flags accepted by `Graph::set_options` and `Executable::set_options`.
///
/// The values correspond to Apple's `MPSGraphOptions` enumeration.
pub mod graph_options {
    /// `MPSGraphOptionsNone`: no flags set.
    pub const NONE: u64 = 0;
    /// `MPSGraphOptionsSynchronizeResults`.
    pub const SYNCHRONIZE_RESULTS: u64 = 0b01;
    /// `MPSGraphOptionsVerbose`.
    pub const VERBOSE: u64 = 0b10;
    /// `MPSGraphOptionsDefault`; the same value as [`SYNCHRONIZE_RESULTS`].
    pub const DEFAULT: u64 = SYNCHRONIZE_RESULTS;
}
50
/// Optimization levels for `CompilationDescriptor::set_optimization_level`.
///
/// The values correspond to Apple's `MPSGraphOptimization` enumeration.
pub mod optimization {
    /// `MPSGraphOptimizationLevel0`.
    pub const LEVEL0: u64 = 0;
    /// `MPSGraphOptimizationLevel1`.
    pub const LEVEL1: u64 = 1;
}
58
/// Optimization profiles for `CompilationDescriptor::set_optimization_profile`.
///
/// The values correspond to Apple's `MPSGraphOptimizationProfile` enumeration.
pub mod optimization_profile {
    /// `MPSGraphOptimizationProfilePerformance`.
    pub const PERFORMANCE: u64 = 0;
    /// `MPSGraphOptimizationProfilePowerEfficiency`.
    pub const POWER_EFFICIENCY: u64 = 1;
}
66
/// Bit flags for `CompilationDescriptor::set_reduced_precision_fast_math`.
///
/// The values correspond to Apple's `MPSGraphReducedPrecisionFastMath` options.
pub mod reduced_precision_fast_math {
    /// No reduced-precision fast math.
    pub const NONE: usize = 0;
    /// `MPSGraphReducedPrecisionFastMathAllowFP16Conv2DWinogradTransformIntermediate` (bit 1).
    pub const ALLOW_FP16_CONV2D_WINOGRAD_TRANSFORM_INTERMEDIATE: usize = 0b10;
    /// Alias of [`ALLOW_FP16_CONV2D_WINOGRAD_TRANSFORM_INTERMEDIATE`].
    pub const ALLOW_FP16_INTERMEDIATES: usize = ALLOW_FP16_CONV2D_WINOGRAD_TRANSFORM_INTERMEDIATE;
    /// Default setting; the same value as [`NONE`].
    pub const DEFAULT: usize = NONE;
}
78
/// Deployment platforms for `ExecutableSerializationDescriptor::set_deployment_platform`.
///
/// The values correspond to Apple's `MPSGraphDeploymentPlatform` enumeration.
pub mod deployment_platform {
    /// `MPSGraphDeploymentPlatformMacOS`.
    pub const MACOS: u64 = 0;
    /// `MPSGraphDeploymentPlatformIOS`.
    pub const IOS: u64 = 1;
    /// `MPSGraphDeploymentPlatformTvOS`.
    pub const TVOS: u64 = 2;
    /// `MPSGraphDeploymentPlatformVisionOS`.
    pub const VISIONOS: u64 = 3;
}
90
91pub struct CompilationDescriptor {
93 ptr: *mut c_void,
94}
95
96unsafe impl Send for CompilationDescriptor {}
97unsafe impl Sync for CompilationDescriptor {}
98
99impl Drop for CompilationDescriptor {
100 fn drop(&mut self) {
101 release_handle(&mut self.ptr);
102 }
103}
104
105impl CompilationDescriptor {
106#[must_use]
108 pub fn new() -> Option<Self> {
109 let ptr = unsafe { ffi::mpsgraph_compilation_descriptor_new() };
111 if ptr.is_null() {
112 None
113 } else {
114 Some(Self { ptr })
115 }
116 }
117
118 #[must_use]
119 pub(crate) const fn as_ptr(&self) -> *mut c_void {
120 self.ptr
121 }
122
123pub fn disable_type_inference(&self) -> Result<()> {
125 let ok = unsafe { ffi::mpsgraph_compilation_descriptor_disable_type_inference(self.ptr) };
127 if ok {
128 Ok(())
129 } else {
130 Err(Error::OperationFailed("failed to disable type inference"))
131 }
132 }
133
134#[must_use]
136 pub fn optimization_level(&self) -> u64 {
137 unsafe { ffi::mpsgraph_compilation_descriptor_optimization_level(self.ptr) }
139 }
140
141pub fn set_optimization_level(&self, value: u64) -> Result<()> {
143 let ok =
145 unsafe { ffi::mpsgraph_compilation_descriptor_set_optimization_level(self.ptr, value) };
146 if ok {
147 Ok(())
148 } else {
149 Err(Error::OperationFailed("failed to set optimization level"))
150 }
151 }
152
153#[must_use]
155 pub fn wait_for_compilation_completion(&self) -> bool {
156 unsafe { ffi::mpsgraph_compilation_descriptor_wait_for_completion(self.ptr) }
158 }
159
160pub fn set_wait_for_compilation_completion(&self, value: bool) -> Result<()> {
162 let ok = unsafe {
164 ffi::mpsgraph_compilation_descriptor_set_wait_for_completion(self.ptr, value)
165 };
166 if ok {
167 Ok(())
168 } else {
169 Err(Error::OperationFailed(
170 "failed to set waitForCompilationCompletion",
171 ))
172 }
173 }
174
175#[must_use]
177 pub fn optimization_profile(&self) -> u64 {
178 unsafe { ffi::mpsgraph_compilation_descriptor_optimization_profile(self.ptr) }
180 }
181
182pub fn set_optimization_profile(&self, value: u64) -> Result<()> {
184 let ok = unsafe {
186 ffi::mpsgraph_compilation_descriptor_set_optimization_profile(self.ptr, value)
187 };
188 if ok {
189 Ok(())
190 } else {
191 Err(Error::OperationFailed("failed to set optimization profile"))
192 }
193 }
194
195#[must_use]
197 pub fn reduced_precision_fast_math(&self) -> usize {
198 unsafe { ffi::mpsgraph_compilation_descriptor_reduced_precision_fast_math(self.ptr) }
200 }
201
202pub fn set_reduced_precision_fast_math(&self, value: usize) -> Result<()> {
204 let ok = unsafe {
206 ffi::mpsgraph_compilation_descriptor_set_reduced_precision_fast_math(self.ptr, value)
207 };
208 if ok {
209 Ok(())
210 } else {
211 Err(Error::OperationFailed(
212 "failed to set reducedPrecisionFastMath",
213 ))
214 }
215 }
216
217pub fn set_callable(&self, symbol_name: &str, executable: Option<&Executable>) -> Result<()> {
219 let symbol_name = CString::new(symbol_name)
220 .map_err(|_| Error::OperationFailed("call symbol name contained NUL"))?;
221 let executable_ptr = executable.map_or(ptr::null_mut(), Executable::as_ptr);
222 let ok = unsafe {
224 ffi::mpsgraph_compilation_descriptor_set_callable(
225 self.ptr,
226 symbol_name.as_ptr(),
227 executable_ptr,
228 )
229 };
230 if ok {
231 Ok(())
232 } else {
233 Err(Error::OperationFailed(
234 "failed to set compilation descriptor callable",
235 ))
236 }
237 }
238}
239
240pub struct ExecutionDescriptor {
242 ptr: *mut c_void,
243}
244
245unsafe impl Send for ExecutionDescriptor {}
246unsafe impl Sync for ExecutionDescriptor {}
247
248impl Drop for ExecutionDescriptor {
249 fn drop(&mut self) {
250 release_handle(&mut self.ptr);
251 }
252}
253
254impl ExecutionDescriptor {
255#[must_use]
257 pub fn new() -> Option<Self> {
258 let ptr = unsafe { ffi::mpsgraph_execution_descriptor_new() };
260 if ptr.is_null() {
261 None
262 } else {
263 Some(Self { ptr })
264 }
265 }
266
267#[must_use]
269 pub const fn as_ptr(&self) -> *mut c_void {
270 self.ptr
271 }
272
273#[must_use]
275 pub fn wait_until_completed(&self) -> bool {
276 unsafe { ffi::mpsgraph_execution_descriptor_wait_until_completed(self.ptr) }
278 }
279
280pub fn set_wait_until_completed(&self, value: bool) -> Result<()> {
282 let ok =
284 unsafe { ffi::mpsgraph_execution_descriptor_set_wait_until_completed(self.ptr, value) };
285 if ok {
286 Ok(())
287 } else {
288 Err(Error::OperationFailed("failed to set waitUntilCompleted"))
289 }
290 }
291
292#[must_use]
294 pub fn compilation_descriptor(&self) -> Option<CompilationDescriptor> {
295 let ptr = unsafe { ffi::mpsgraph_execution_descriptor_compilation_descriptor(self.ptr) };
297 if ptr.is_null() {
298 None
299 } else {
300 Some(CompilationDescriptor { ptr })
301 }
302 }
303
304pub fn set_compilation_descriptor(
306 &self,
307 descriptor: Option<&CompilationDescriptor>,
308 ) -> Result<()> {
309 let descriptor_ptr = descriptor.map_or(ptr::null_mut(), CompilationDescriptor::as_ptr);
310 let ok = unsafe {
312 ffi::mpsgraph_execution_descriptor_set_compilation_descriptor(self.ptr, descriptor_ptr)
313 };
314 if ok {
315 Ok(())
316 } else {
317 Err(Error::OperationFailed(
318 "failed to set compilation descriptor",
319 ))
320 }
321 }
322}
323
324pub struct ExecutableExecutionDescriptor {
326 ptr: *mut c_void,
327}
328
329unsafe impl Send for ExecutableExecutionDescriptor {}
330unsafe impl Sync for ExecutableExecutionDescriptor {}
331
332impl Drop for ExecutableExecutionDescriptor {
333 fn drop(&mut self) {
334 release_handle(&mut self.ptr);
335 }
336}
337
338impl ExecutableExecutionDescriptor {
339#[must_use]
341 pub fn new() -> Option<Self> {
342 let ptr = unsafe { ffi::mpsgraph_executable_execution_descriptor_new() };
344 if ptr.is_null() {
345 None
346 } else {
347 Some(Self { ptr })
348 }
349 }
350
351 #[must_use]
352 pub(crate) const fn as_ptr(&self) -> *mut c_void {
353 self.ptr
354 }
355
356#[must_use]
358 pub fn wait_until_completed(&self) -> bool {
359 unsafe { ffi::mpsgraph_executable_execution_descriptor_wait_until_completed(self.ptr) }
361 }
362
363pub fn set_wait_until_completed(&self, value: bool) -> Result<()> {
365 let ok = unsafe {
367 ffi::mpsgraph_executable_execution_descriptor_set_wait_until_completed(self.ptr, value)
368 };
369 if ok {
370 Ok(())
371 } else {
372 Err(Error::OperationFailed(
373 "failed to set executable waitUntilCompleted",
374 ))
375 }
376 }
377}
378
379pub struct ExecutableSerializationDescriptor {
381 ptr: *mut c_void,
382}
383
384unsafe impl Send for ExecutableSerializationDescriptor {}
385unsafe impl Sync for ExecutableSerializationDescriptor {}
386
387impl Drop for ExecutableSerializationDescriptor {
388 fn drop(&mut self) {
389 release_handle(&mut self.ptr);
390 }
391}
392
393impl ExecutableSerializationDescriptor {
394#[must_use]
396 pub fn new() -> Option<Self> {
397 let ptr = unsafe { ffi::mpsgraph_executable_serialization_descriptor_new() };
399 if ptr.is_null() {
400 None
401 } else {
402 Some(Self { ptr })
403 }
404 }
405
406 #[must_use]
407 pub(crate) const fn as_ptr(&self) -> *mut c_void {
408 self.ptr
409 }
410
411#[must_use]
413 pub fn append(&self) -> bool {
414 unsafe { ffi::mpsgraph_executable_serialization_descriptor_append(self.ptr) }
416 }
417
418pub fn set_append(&self, value: bool) -> Result<()> {
420 let ok = unsafe {
422 ffi::mpsgraph_executable_serialization_descriptor_set_append(self.ptr, value)
423 };
424 if ok {
425 Ok(())
426 } else {
427 Err(Error::OperationFailed("failed to set append"))
428 }
429 }
430
431#[must_use]
433 pub fn deployment_platform(&self) -> u64 {
434 unsafe { ffi::mpsgraph_executable_serialization_descriptor_deployment_platform(self.ptr) }
436 }
437
438pub fn set_deployment_platform(&self, value: u64) -> Result<()> {
440 let ok = unsafe {
442 ffi::mpsgraph_executable_serialization_descriptor_set_deployment_platform(
443 self.ptr, value,
444 )
445 };
446 if ok {
447 Ok(())
448 } else {
449 Err(Error::OperationFailed("failed to set deployment platform"))
450 }
451 }
452
453pub fn minimum_deployment_target(&self) -> Result<String> {
455 copy_string(
456 ffi::mpsgraph_executable_serialization_descriptor_minimum_deployment_target_len,
457 ffi::mpsgraph_executable_serialization_descriptor_copy_minimum_deployment_target,
458 self.ptr,
459 )
460 }
461
462pub fn set_minimum_deployment_target(&self, value: &str) -> Result<()> {
464 let value = CString::new(value)
465 .map_err(|_| Error::OperationFailed("minimum deployment target contained NUL"))?;
466 let ok = unsafe {
468 ffi::mpsgraph_executable_serialization_descriptor_set_minimum_deployment_target(
469 self.ptr,
470 value.as_ptr(),
471 )
472 };
473 if ok {
474 Ok(())
475 } else {
476 Err(Error::OperationFailed(
477 "failed to set minimum deployment target",
478 ))
479 }
480 }
481}
482
483impl Graph {
484 #[must_use]
486 pub fn options(&self) -> u64 {
487 unsafe { ffi::mpsgraph_graph_options(self.as_ptr()) }
489 }
490
491 pub fn set_options(&self, options: u64) -> Result<()> {
493 let ok = unsafe { ffi::mpsgraph_graph_set_options(self.as_ptr(), options) };
495 if ok {
496 Ok(())
497 } else {
498 Err(Error::OperationFailed("failed to set graph options"))
499 }
500 }
501
502 #[must_use]
504 pub fn placeholder_tensors(&self) -> Vec<Tensor> {
505 let box_handle = unsafe { ffi::mpsgraph_graph_placeholder_tensors(self.as_ptr()) };
507 collect_owned_tensors(box_handle)
508 }
509
510 #[must_use]
512 pub fn compile_with_descriptor(
513 &self,
514 device: Option<&MetalDevice>,
515 feeds: &[FeedDescription<'_>],
516 targets: &[&Tensor],
517 descriptor: Option<&CompilationDescriptor>,
518 ) -> Option<Executable> {
519 let feed_tensors = feeds
520 .iter()
521 .map(|feed| feed.tensor.as_ptr())
522 .collect::<Vec<_>>();
523 let shape_lengths = feeds
524 .iter()
525 .map(|feed| feed.shape.len())
526 .collect::<Vec<_>>();
527 let data_types = feeds.iter().map(|feed| feed.data_type).collect::<Vec<_>>();
528 let flat_shapes = feeds
529 .iter()
530 .flat_map(|feed| feed.shape.iter().copied())
531 .collect::<Vec<_>>();
532 let target_tensors = targets
533 .iter()
534 .map(|tensor| tensor.as_ptr())
535 .collect::<Vec<_>>();
536 let device_ptr = device.map_or(ptr::null_mut(), MetalDevice::as_ptr);
537 let descriptor_ptr = descriptor.map_or(ptr::null_mut(), CompilationDescriptor::as_ptr);
538
539 let ptr = unsafe {
541 ffi::mpsgraph_graph_compile_with_descriptor(
542 self.as_ptr(),
543 device_ptr,
544 feed_tensors.as_ptr(),
545 feeds.len(),
546 flat_shapes.as_ptr(),
547 shape_lengths.as_ptr(),
548 data_types.as_ptr(),
549 target_tensors.as_ptr(),
550 targets.len(),
551 descriptor_ptr,
552 )
553 };
554 if ptr.is_null() {
555 None
556 } else {
557 Some(Executable::from_raw(ptr, targets.len()))
558 }
559 }
560}
561
562impl Executable {
563 #[must_use]
565 pub fn options(&self) -> u64 {
566 unsafe { ffi::mpsgraph_executable_options(self.as_ptr()) }
568 }
569
570 pub fn set_options(&self, options: u64) -> Result<()> {
572 let ok = unsafe { ffi::mpsgraph_executable_set_options(self.as_ptr(), options) };
574 if ok {
575 Ok(())
576 } else {
577 Err(Error::OperationFailed("failed to set executable options"))
578 }
579 }
580
581 #[must_use]
583 pub fn feed_tensors(&self) -> Vec<Tensor> {
584 let box_handle = unsafe { ffi::mpsgraph_executable_feed_tensors(self.as_ptr()) };
586 collect_owned_tensors(box_handle)
587 }
588
589 #[must_use]
591 pub fn target_tensors(&self) -> Vec<Tensor> {
592 let box_handle = unsafe { ffi::mpsgraph_executable_target_tensors(self.as_ptr()) };
594 collect_owned_tensors(box_handle)
595 }
596
597 pub fn specialize(
599 &self,
600 device: Option<&MetalDevice>,
601 input_types: &[&ShapedType],
602 descriptor: Option<&CompilationDescriptor>,
603 ) -> Result<()> {
604 let input_type_handles = input_types
605 .iter()
606 .map(|value| value.as_ptr())
607 .collect::<Vec<_>>();
608 let device_ptr = device.map_or(ptr::null_mut(), MetalDevice::as_ptr);
609 let descriptor_ptr = descriptor.map_or(ptr::null_mut(), CompilationDescriptor::as_ptr);
610
611 let ok = unsafe {
613 ffi::mpsgraph_executable_specialize(
614 self.as_ptr(),
615 device_ptr,
616 input_type_handles.as_ptr(),
617 input_types.len(),
618 descriptor_ptr,
619 )
620 };
621 if ok {
622 Ok(())
623 } else {
624 Err(Error::OperationFailed("failed to specialize executable"))
625 }
626 }
627
628 pub fn output_types(
630 &self,
631 device: Option<&MetalDevice>,
632 input_types: &[&ShapedType],
633 descriptor: Option<&CompilationDescriptor>,
634 ) -> Result<Vec<ShapedType>> {
635 let input_type_handles = input_types
636 .iter()
637 .map(|value| value.as_ptr())
638 .collect::<Vec<_>>();
639 let device_ptr = device.map_or(ptr::null_mut(), MetalDevice::as_ptr);
640 let descriptor_ptr = descriptor.map_or(ptr::null_mut(), CompilationDescriptor::as_ptr);
641
642 let box_handle = unsafe {
644 ffi::mpsgraph_executable_get_output_types(
645 self.as_ptr(),
646 device_ptr,
647 input_type_handles.as_ptr(),
648 input_types.len(),
649 descriptor_ptr,
650 )
651 };
652 if box_handle.is_null() {
653 Err(Error::OperationFailed(
654 "failed to get executable output types",
655 ))
656 } else {
657 Ok(collect_shaped_type_array_box(box_handle))
658 }
659 }
660
661 pub fn run_with_descriptor(
663 &self,
664 command_queue: &CommandQueue,
665 inputs: &[&TensorData],
666 results: Option<&[&TensorData]>,
667 descriptor: Option<&ExecutableExecutionDescriptor>,
668 ) -> Result<Vec<TensorData>> {
669 let input_handles = inputs
670 .iter()
671 .map(|value| value.as_ptr())
672 .collect::<Vec<_>>();
673 let result_handles = results
674 .map(|values| {
675 values
676 .iter()
677 .map(|value| value.as_ptr())
678 .collect::<Vec<_>>()
679 })
680 .unwrap_or_default();
681 let descriptor_ptr =
682 descriptor.map_or(ptr::null_mut(), ExecutableExecutionDescriptor::as_ptr);
683
684 let box_handle = unsafe {
686 ffi::mpsgraph_executable_run_with_descriptor(
687 self.as_ptr(),
688 command_queue.as_ptr(),
689 input_handles.as_ptr(),
690 inputs.len(),
691 result_handles.as_ptr(),
692 result_handles.len(),
693 descriptor_ptr,
694 )
695 };
696 if box_handle.is_null() {
697 Err(Error::OperationFailed("failed to run executable"))
698 } else {
699 Ok(collect_tensor_data_array_box(box_handle))
700 }
701 }
702
703 pub fn run_async_with_descriptor(
705 &self,
706 command_queue: &CommandQueue,
707 inputs: &[&TensorData],
708 results: Option<&[&TensorData]>,
709 descriptor: Option<&ExecutableExecutionDescriptor>,
710 ) -> Result<Vec<TensorData>> {
711 let input_handles = inputs
712 .iter()
713 .map(|value| value.as_ptr())
714 .collect::<Vec<_>>();
715 let result_handles = results
716 .map(|values| {
717 values
718 .iter()
719 .map(|value| value.as_ptr())
720 .collect::<Vec<_>>()
721 })
722 .unwrap_or_default();
723 let descriptor_ptr =
724 descriptor.map_or(ptr::null_mut(), ExecutableExecutionDescriptor::as_ptr);
725
726 let box_handle = unsafe {
728 ffi::mpsgraph_executable_run_async_with_descriptor(
729 self.as_ptr(),
730 command_queue.as_ptr(),
731 input_handles.as_ptr(),
732 inputs.len(),
733 result_handles.as_ptr(),
734 result_handles.len(),
735 descriptor_ptr,
736 )
737 };
738 if box_handle.is_null() {
739 Err(Error::OperationFailed(
740 "failed to run executable asynchronously",
741 ))
742 } else {
743 Ok(collect_tensor_data_array_box(box_handle))
744 }
745 }
746
747 pub fn serialize_package(
749 &self,
750 path: &str,
751 descriptor: Option<&ExecutableSerializationDescriptor>,
752 ) -> Result<()> {
753 let path =
754 CString::new(path).map_err(|_| Error::OperationFailed("package path contained NUL"))?;
755 let descriptor_ptr =
756 descriptor.map_or(ptr::null_mut(), ExecutableSerializationDescriptor::as_ptr);
757 let ok = unsafe {
759 ffi::mpsgraph_executable_serialize_package(self.as_ptr(), path.as_ptr(), descriptor_ptr)
760 };
761 if ok {
762 Ok(())
763 } else {
764 Err(Error::OperationFailed(
765 "failed to serialize executable package",
766 ))
767 }
768 }
769
770 pub fn from_package(path: &str, descriptor: Option<&CompilationDescriptor>) -> Result<Self> {
772 let path =
773 CString::new(path).map_err(|_| Error::OperationFailed("package path contained NUL"))?;
774 let descriptor_ptr = descriptor.map_or(ptr::null_mut(), CompilationDescriptor::as_ptr);
775 let ptr =
777 unsafe { ffi::mpsgraph_executable_new_with_package(path.as_ptr(), descriptor_ptr) };
778 if ptr.is_null() {
779 return Err(Error::OperationFailed("failed to load executable package"));
780 }
781 let output_count = {
782 let box_handle = unsafe { ffi::mpsgraph_executable_target_tensors(ptr) };
784 collect_owned_tensors(box_handle).len()
785 };
786 Ok(Self::from_raw(ptr, output_count))
787 }
788}