1use crate::data::TensorData;
2use crate::error::{Error, Result};
3use crate::ffi;
4use crate::graph::{Executable, FeedDescription, Graph, Tensor};
5use crate::types::{
6 collect_owned_tensors, collect_shaped_type_array_box, collect_tensor_data_array_box, ShapedType,
7};
8use apple_metal::{CommandQueue, MetalDevice};
9use core::ffi::c_void;
10use core::ptr;
11use std::ffi::CString;
12
13fn release_handle(ptr: &mut *mut c_void) {
14 if !ptr.is_null() {
15 unsafe { ffi::mpsgraph_object_release(*ptr) };
17 *ptr = ptr::null_mut();
18 }
19}
20
21fn copy_string(
22 len: unsafe extern "C" fn(*mut c_void) -> usize,
23 copy: unsafe extern "C" fn(*mut c_void, *mut u8, usize) -> bool,
24 handle: *mut c_void,
25) -> Result<String> {
26 let len = unsafe { len(handle) };
28 let mut bytes = vec![0_u8; len];
29 let ok = unsafe { copy(handle, bytes.as_mut_ptr(), len) };
31 if ok {
32 String::from_utf8(bytes)
33 .map_err(|_| Error::OperationFailed("bridge returned invalid UTF-8"))
34 } else {
35 Err(Error::OperationFailed("failed to copy string from bridge"))
36 }
37}
38
/// Raw `u64` graph option values.
///
/// NOTE(review): presumably the values taken by `Graph::set_options` and
/// `Executable::set_options`, which forward a bare `u64` to the bridge, and
/// presumably mirroring Apple's `MPSGraphOptions` — confirm against the bridge.
pub mod graph_options {
    /// No options (`0`).
    pub const NONE: u64 = 0;
    /// Value `1`; also used as `DEFAULT`. Presumably makes results synchronize — confirm.
    pub const SYNCHRONIZE_RESULTS: u64 = 1;
    /// Value `2`. Presumably enables verbose output — confirm.
    pub const VERBOSE: u64 = 2;
    /// Default option; identical to `SYNCHRONIZE_RESULTS`.
    pub const DEFAULT: u64 = SYNCHRONIZE_RESULTS;
}
46
/// Raw `u64` optimization-level values.
///
/// NOTE(review): presumably the values exchanged by
/// `CompilationDescriptor::optimization_level` / `set_optimization_level`,
/// which pass a bare `u64` through the bridge — confirm.
pub mod optimization {
    /// Level `0`.
    pub const LEVEL0: u64 = 0;
    /// Level `1`.
    pub const LEVEL1: u64 = 1;
}
52
/// Raw `u64` optimization-profile values.
///
/// NOTE(review): presumably the values exchanged by
/// `CompilationDescriptor::optimization_profile` / `set_optimization_profile`,
/// which pass a bare `u64` through the bridge — confirm.
pub mod optimization_profile {
    /// Profile `0`. Presumably favors performance — confirm.
    pub const PERFORMANCE: u64 = 0;
    /// Profile `1`. Presumably favors power efficiency — confirm.
    pub const POWER_EFFICIENCY: u64 = 1;
}
58
/// Raw `usize` reduced-precision fast-math values.
///
/// These are `usize`, unlike the `u64` constants in the sibling modules; that
/// matches the `usize` taken and returned by
/// `CompilationDescriptor::set_reduced_precision_fast_math` /
/// `reduced_precision_fast_math`.
///
/// NOTE(review): presumably a bit mask mirroring Apple's
/// `MPSGraphReducedPrecisionFastMath` — confirm against the bridge.
pub mod reduced_precision_fast_math {
    /// No bits set (`0`).
    pub const NONE: usize = 0;
    /// Bit 1 (`1 << 1`, i.e. `2`).
    pub const ALLOW_FP16_CONV2D_WINOGRAD_TRANSFORM_INTERMEDIATE: usize = 1 << 1;
    /// Alias with the same value as `ALLOW_FP16_CONV2D_WINOGRAD_TRANSFORM_INTERMEDIATE`.
    pub const ALLOW_FP16_INTERMEDIATES: usize = ALLOW_FP16_CONV2D_WINOGRAD_TRANSFORM_INTERMEDIATE;
    /// Default value; identical to `NONE`.
    pub const DEFAULT: usize = NONE;
}
66
/// Raw `u64` deployment-platform values.
///
/// NOTE(review): presumably the values exchanged by
/// `ExecutableSerializationDescriptor::deployment_platform` /
/// `set_deployment_platform`, which pass a bare `u64` through the bridge — confirm.
pub mod deployment_platform {
    /// Platform value `0`.
    pub const MACOS: u64 = 0;
    /// Platform value `1`.
    pub const IOS: u64 = 1;
    /// Platform value `2`.
    pub const TVOS: u64 = 2;
    /// Platform value `3`.
    pub const VISIONOS: u64 = 3;
}
74
75pub struct CompilationDescriptor {
77 ptr: *mut c_void,
78}
79
80unsafe impl Send for CompilationDescriptor {}
81unsafe impl Sync for CompilationDescriptor {}
82
83impl Drop for CompilationDescriptor {
84 fn drop(&mut self) {
85 release_handle(&mut self.ptr);
86 }
87}
88
89impl CompilationDescriptor {
90 #[must_use]
91 pub fn new() -> Option<Self> {
92 let ptr = unsafe { ffi::mpsgraph_compilation_descriptor_new() };
94 if ptr.is_null() {
95 None
96 } else {
97 Some(Self { ptr })
98 }
99 }
100
101 #[must_use]
102 pub(crate) const fn as_ptr(&self) -> *mut c_void {
103 self.ptr
104 }
105
106 pub fn disable_type_inference(&self) -> Result<()> {
107 let ok = unsafe { ffi::mpsgraph_compilation_descriptor_disable_type_inference(self.ptr) };
109 if ok {
110 Ok(())
111 } else {
112 Err(Error::OperationFailed("failed to disable type inference"))
113 }
114 }
115
116 #[must_use]
117 pub fn optimization_level(&self) -> u64 {
118 unsafe { ffi::mpsgraph_compilation_descriptor_optimization_level(self.ptr) }
120 }
121
122 pub fn set_optimization_level(&self, value: u64) -> Result<()> {
123 let ok =
125 unsafe { ffi::mpsgraph_compilation_descriptor_set_optimization_level(self.ptr, value) };
126 if ok {
127 Ok(())
128 } else {
129 Err(Error::OperationFailed("failed to set optimization level"))
130 }
131 }
132
133 #[must_use]
134 pub fn wait_for_compilation_completion(&self) -> bool {
135 unsafe { ffi::mpsgraph_compilation_descriptor_wait_for_completion(self.ptr) }
137 }
138
139 pub fn set_wait_for_compilation_completion(&self, value: bool) -> Result<()> {
140 let ok = unsafe {
142 ffi::mpsgraph_compilation_descriptor_set_wait_for_completion(self.ptr, value)
143 };
144 if ok {
145 Ok(())
146 } else {
147 Err(Error::OperationFailed(
148 "failed to set waitForCompilationCompletion",
149 ))
150 }
151 }
152
153 #[must_use]
154 pub fn optimization_profile(&self) -> u64 {
155 unsafe { ffi::mpsgraph_compilation_descriptor_optimization_profile(self.ptr) }
157 }
158
159 pub fn set_optimization_profile(&self, value: u64) -> Result<()> {
160 let ok = unsafe {
162 ffi::mpsgraph_compilation_descriptor_set_optimization_profile(self.ptr, value)
163 };
164 if ok {
165 Ok(())
166 } else {
167 Err(Error::OperationFailed("failed to set optimization profile"))
168 }
169 }
170
171 #[must_use]
172 pub fn reduced_precision_fast_math(&self) -> usize {
173 unsafe { ffi::mpsgraph_compilation_descriptor_reduced_precision_fast_math(self.ptr) }
175 }
176
177 pub fn set_reduced_precision_fast_math(&self, value: usize) -> Result<()> {
178 let ok = unsafe {
180 ffi::mpsgraph_compilation_descriptor_set_reduced_precision_fast_math(self.ptr, value)
181 };
182 if ok {
183 Ok(())
184 } else {
185 Err(Error::OperationFailed(
186 "failed to set reducedPrecisionFastMath",
187 ))
188 }
189 }
190
191 pub fn set_callable(&self, symbol_name: &str, executable: Option<&Executable>) -> Result<()> {
192 let symbol_name = CString::new(symbol_name)
193 .map_err(|_| Error::OperationFailed("call symbol name contained NUL"))?;
194 let executable_ptr = executable.map_or(ptr::null_mut(), Executable::as_ptr);
195 let ok = unsafe {
197 ffi::mpsgraph_compilation_descriptor_set_callable(
198 self.ptr,
199 symbol_name.as_ptr(),
200 executable_ptr,
201 )
202 };
203 if ok {
204 Ok(())
205 } else {
206 Err(Error::OperationFailed(
207 "failed to set compilation descriptor callable",
208 ))
209 }
210 }
211}
212
213pub struct ExecutionDescriptor {
215 ptr: *mut c_void,
216}
217
218unsafe impl Send for ExecutionDescriptor {}
219unsafe impl Sync for ExecutionDescriptor {}
220
221impl Drop for ExecutionDescriptor {
222 fn drop(&mut self) {
223 release_handle(&mut self.ptr);
224 }
225}
226
227impl ExecutionDescriptor {
228 #[must_use]
229 pub fn new() -> Option<Self> {
230 let ptr = unsafe { ffi::mpsgraph_execution_descriptor_new() };
232 if ptr.is_null() {
233 None
234 } else {
235 Some(Self { ptr })
236 }
237 }
238
239 #[must_use]
240 pub const fn as_ptr(&self) -> *mut c_void {
241 self.ptr
242 }
243
244 #[must_use]
245 pub fn wait_until_completed(&self) -> bool {
246 unsafe { ffi::mpsgraph_execution_descriptor_wait_until_completed(self.ptr) }
248 }
249
250 pub fn set_wait_until_completed(&self, value: bool) -> Result<()> {
251 let ok =
253 unsafe { ffi::mpsgraph_execution_descriptor_set_wait_until_completed(self.ptr, value) };
254 if ok {
255 Ok(())
256 } else {
257 Err(Error::OperationFailed("failed to set waitUntilCompleted"))
258 }
259 }
260
261 #[must_use]
262 pub fn compilation_descriptor(&self) -> Option<CompilationDescriptor> {
263 let ptr = unsafe { ffi::mpsgraph_execution_descriptor_compilation_descriptor(self.ptr) };
265 if ptr.is_null() {
266 None
267 } else {
268 Some(CompilationDescriptor { ptr })
269 }
270 }
271
272 pub fn set_compilation_descriptor(
273 &self,
274 descriptor: Option<&CompilationDescriptor>,
275 ) -> Result<()> {
276 let descriptor_ptr = descriptor.map_or(ptr::null_mut(), CompilationDescriptor::as_ptr);
277 let ok = unsafe {
279 ffi::mpsgraph_execution_descriptor_set_compilation_descriptor(self.ptr, descriptor_ptr)
280 };
281 if ok {
282 Ok(())
283 } else {
284 Err(Error::OperationFailed(
285 "failed to set compilation descriptor",
286 ))
287 }
288 }
289}
290
291pub struct ExecutableExecutionDescriptor {
293 ptr: *mut c_void,
294}
295
296unsafe impl Send for ExecutableExecutionDescriptor {}
297unsafe impl Sync for ExecutableExecutionDescriptor {}
298
299impl Drop for ExecutableExecutionDescriptor {
300 fn drop(&mut self) {
301 release_handle(&mut self.ptr);
302 }
303}
304
305impl ExecutableExecutionDescriptor {
306 #[must_use]
307 pub fn new() -> Option<Self> {
308 let ptr = unsafe { ffi::mpsgraph_executable_execution_descriptor_new() };
310 if ptr.is_null() {
311 None
312 } else {
313 Some(Self { ptr })
314 }
315 }
316
317 #[must_use]
318 pub(crate) const fn as_ptr(&self) -> *mut c_void {
319 self.ptr
320 }
321
322 #[must_use]
323 pub fn wait_until_completed(&self) -> bool {
324 unsafe { ffi::mpsgraph_executable_execution_descriptor_wait_until_completed(self.ptr) }
326 }
327
328 pub fn set_wait_until_completed(&self, value: bool) -> Result<()> {
329 let ok = unsafe {
331 ffi::mpsgraph_executable_execution_descriptor_set_wait_until_completed(self.ptr, value)
332 };
333 if ok {
334 Ok(())
335 } else {
336 Err(Error::OperationFailed(
337 "failed to set executable waitUntilCompleted",
338 ))
339 }
340 }
341}
342
343pub struct ExecutableSerializationDescriptor {
345 ptr: *mut c_void,
346}
347
348unsafe impl Send for ExecutableSerializationDescriptor {}
349unsafe impl Sync for ExecutableSerializationDescriptor {}
350
351impl Drop for ExecutableSerializationDescriptor {
352 fn drop(&mut self) {
353 release_handle(&mut self.ptr);
354 }
355}
356
357impl ExecutableSerializationDescriptor {
358 #[must_use]
359 pub fn new() -> Option<Self> {
360 let ptr = unsafe { ffi::mpsgraph_executable_serialization_descriptor_new() };
362 if ptr.is_null() {
363 None
364 } else {
365 Some(Self { ptr })
366 }
367 }
368
369 #[must_use]
370 pub(crate) const fn as_ptr(&self) -> *mut c_void {
371 self.ptr
372 }
373
374 #[must_use]
375 pub fn append(&self) -> bool {
376 unsafe { ffi::mpsgraph_executable_serialization_descriptor_append(self.ptr) }
378 }
379
380 pub fn set_append(&self, value: bool) -> Result<()> {
381 let ok = unsafe {
383 ffi::mpsgraph_executable_serialization_descriptor_set_append(self.ptr, value)
384 };
385 if ok {
386 Ok(())
387 } else {
388 Err(Error::OperationFailed("failed to set append"))
389 }
390 }
391
392 #[must_use]
393 pub fn deployment_platform(&self) -> u64 {
394 unsafe { ffi::mpsgraph_executable_serialization_descriptor_deployment_platform(self.ptr) }
396 }
397
398 pub fn set_deployment_platform(&self, value: u64) -> Result<()> {
399 let ok = unsafe {
401 ffi::mpsgraph_executable_serialization_descriptor_set_deployment_platform(
402 self.ptr, value,
403 )
404 };
405 if ok {
406 Ok(())
407 } else {
408 Err(Error::OperationFailed("failed to set deployment platform"))
409 }
410 }
411
412 pub fn minimum_deployment_target(&self) -> Result<String> {
413 copy_string(
414 ffi::mpsgraph_executable_serialization_descriptor_minimum_deployment_target_len,
415 ffi::mpsgraph_executable_serialization_descriptor_copy_minimum_deployment_target,
416 self.ptr,
417 )
418 }
419
420 pub fn set_minimum_deployment_target(&self, value: &str) -> Result<()> {
421 let value = CString::new(value)
422 .map_err(|_| Error::OperationFailed("minimum deployment target contained NUL"))?;
423 let ok = unsafe {
425 ffi::mpsgraph_executable_serialization_descriptor_set_minimum_deployment_target(
426 self.ptr,
427 value.as_ptr(),
428 )
429 };
430 if ok {
431 Ok(())
432 } else {
433 Err(Error::OperationFailed(
434 "failed to set minimum deployment target",
435 ))
436 }
437 }
438}
439
440impl Graph {
441 #[must_use]
443 pub fn options(&self) -> u64 {
444 unsafe { ffi::mpsgraph_graph_options(self.as_ptr()) }
446 }
447
448 pub fn set_options(&self, options: u64) -> Result<()> {
450 let ok = unsafe { ffi::mpsgraph_graph_set_options(self.as_ptr(), options) };
452 if ok {
453 Ok(())
454 } else {
455 Err(Error::OperationFailed("failed to set graph options"))
456 }
457 }
458
459 #[must_use]
461 pub fn placeholder_tensors(&self) -> Vec<Tensor> {
462 let box_handle = unsafe { ffi::mpsgraph_graph_placeholder_tensors(self.as_ptr()) };
464 collect_owned_tensors(box_handle)
465 }
466
467 #[must_use]
469 pub fn compile_with_descriptor(
470 &self,
471 device: Option<&MetalDevice>,
472 feeds: &[FeedDescription<'_>],
473 targets: &[&Tensor],
474 descriptor: Option<&CompilationDescriptor>,
475 ) -> Option<Executable> {
476 let feed_tensors = feeds
477 .iter()
478 .map(|feed| feed.tensor.as_ptr())
479 .collect::<Vec<_>>();
480 let shape_lengths = feeds
481 .iter()
482 .map(|feed| feed.shape.len())
483 .collect::<Vec<_>>();
484 let data_types = feeds.iter().map(|feed| feed.data_type).collect::<Vec<_>>();
485 let flat_shapes = feeds
486 .iter()
487 .flat_map(|feed| feed.shape.iter().copied())
488 .collect::<Vec<_>>();
489 let target_tensors = targets
490 .iter()
491 .map(|tensor| tensor.as_ptr())
492 .collect::<Vec<_>>();
493 let device_ptr = device.map_or(ptr::null_mut(), MetalDevice::as_ptr);
494 let descriptor_ptr = descriptor.map_or(ptr::null_mut(), CompilationDescriptor::as_ptr);
495
496 let ptr = unsafe {
498 ffi::mpsgraph_graph_compile_with_descriptor(
499 self.as_ptr(),
500 device_ptr,
501 feed_tensors.as_ptr(),
502 feeds.len(),
503 flat_shapes.as_ptr(),
504 shape_lengths.as_ptr(),
505 data_types.as_ptr(),
506 target_tensors.as_ptr(),
507 targets.len(),
508 descriptor_ptr,
509 )
510 };
511 if ptr.is_null() {
512 None
513 } else {
514 Some(Executable::from_raw(ptr, targets.len()))
515 }
516 }
517}
518
519impl Executable {
520 #[must_use]
522 pub fn options(&self) -> u64 {
523 unsafe { ffi::mpsgraph_executable_options(self.as_ptr()) }
525 }
526
527 pub fn set_options(&self, options: u64) -> Result<()> {
529 let ok = unsafe { ffi::mpsgraph_executable_set_options(self.as_ptr(), options) };
531 if ok {
532 Ok(())
533 } else {
534 Err(Error::OperationFailed("failed to set executable options"))
535 }
536 }
537
538 #[must_use]
540 pub fn feed_tensors(&self) -> Vec<Tensor> {
541 let box_handle = unsafe { ffi::mpsgraph_executable_feed_tensors(self.as_ptr()) };
543 collect_owned_tensors(box_handle)
544 }
545
546 #[must_use]
548 pub fn target_tensors(&self) -> Vec<Tensor> {
549 let box_handle = unsafe { ffi::mpsgraph_executable_target_tensors(self.as_ptr()) };
551 collect_owned_tensors(box_handle)
552 }
553
554 pub fn specialize(
556 &self,
557 device: Option<&MetalDevice>,
558 input_types: &[&ShapedType],
559 descriptor: Option<&CompilationDescriptor>,
560 ) -> Result<()> {
561 let input_type_handles = input_types
562 .iter()
563 .map(|value| value.as_ptr())
564 .collect::<Vec<_>>();
565 let device_ptr = device.map_or(ptr::null_mut(), MetalDevice::as_ptr);
566 let descriptor_ptr = descriptor.map_or(ptr::null_mut(), CompilationDescriptor::as_ptr);
567
568 let ok = unsafe {
570 ffi::mpsgraph_executable_specialize(
571 self.as_ptr(),
572 device_ptr,
573 input_type_handles.as_ptr(),
574 input_types.len(),
575 descriptor_ptr,
576 )
577 };
578 if ok {
579 Ok(())
580 } else {
581 Err(Error::OperationFailed("failed to specialize executable"))
582 }
583 }
584
585 pub fn output_types(
587 &self,
588 device: Option<&MetalDevice>,
589 input_types: &[&ShapedType],
590 descriptor: Option<&CompilationDescriptor>,
591 ) -> Result<Vec<ShapedType>> {
592 let input_type_handles = input_types
593 .iter()
594 .map(|value| value.as_ptr())
595 .collect::<Vec<_>>();
596 let device_ptr = device.map_or(ptr::null_mut(), MetalDevice::as_ptr);
597 let descriptor_ptr = descriptor.map_or(ptr::null_mut(), CompilationDescriptor::as_ptr);
598
599 let box_handle = unsafe {
601 ffi::mpsgraph_executable_get_output_types(
602 self.as_ptr(),
603 device_ptr,
604 input_type_handles.as_ptr(),
605 input_types.len(),
606 descriptor_ptr,
607 )
608 };
609 if box_handle.is_null() {
610 Err(Error::OperationFailed(
611 "failed to get executable output types",
612 ))
613 } else {
614 Ok(collect_shaped_type_array_box(box_handle))
615 }
616 }
617
618 pub fn run_with_descriptor(
620 &self,
621 command_queue: &CommandQueue,
622 inputs: &[&TensorData],
623 results: Option<&[&TensorData]>,
624 descriptor: Option<&ExecutableExecutionDescriptor>,
625 ) -> Result<Vec<TensorData>> {
626 let input_handles = inputs
627 .iter()
628 .map(|value| value.as_ptr())
629 .collect::<Vec<_>>();
630 let result_handles = results
631 .map(|values| {
632 values
633 .iter()
634 .map(|value| value.as_ptr())
635 .collect::<Vec<_>>()
636 })
637 .unwrap_or_default();
638 let descriptor_ptr =
639 descriptor.map_or(ptr::null_mut(), ExecutableExecutionDescriptor::as_ptr);
640
641 let box_handle = unsafe {
643 ffi::mpsgraph_executable_run_with_descriptor(
644 self.as_ptr(),
645 command_queue.as_ptr(),
646 input_handles.as_ptr(),
647 inputs.len(),
648 result_handles.as_ptr(),
649 result_handles.len(),
650 descriptor_ptr,
651 )
652 };
653 if box_handle.is_null() {
654 Err(Error::OperationFailed("failed to run executable"))
655 } else {
656 Ok(collect_tensor_data_array_box(box_handle))
657 }
658 }
659
660 pub fn run_async_with_descriptor(
662 &self,
663 command_queue: &CommandQueue,
664 inputs: &[&TensorData],
665 results: Option<&[&TensorData]>,
666 descriptor: Option<&ExecutableExecutionDescriptor>,
667 ) -> Result<Vec<TensorData>> {
668 let input_handles = inputs
669 .iter()
670 .map(|value| value.as_ptr())
671 .collect::<Vec<_>>();
672 let result_handles = results
673 .map(|values| {
674 values
675 .iter()
676 .map(|value| value.as_ptr())
677 .collect::<Vec<_>>()
678 })
679 .unwrap_or_default();
680 let descriptor_ptr =
681 descriptor.map_or(ptr::null_mut(), ExecutableExecutionDescriptor::as_ptr);
682
683 let box_handle = unsafe {
685 ffi::mpsgraph_executable_run_async_with_descriptor(
686 self.as_ptr(),
687 command_queue.as_ptr(),
688 input_handles.as_ptr(),
689 inputs.len(),
690 result_handles.as_ptr(),
691 result_handles.len(),
692 descriptor_ptr,
693 )
694 };
695 if box_handle.is_null() {
696 Err(Error::OperationFailed(
697 "failed to run executable asynchronously",
698 ))
699 } else {
700 Ok(collect_tensor_data_array_box(box_handle))
701 }
702 }
703
704 pub fn serialize_package(
706 &self,
707 path: &str,
708 descriptor: Option<&ExecutableSerializationDescriptor>,
709 ) -> Result<()> {
710 let path =
711 CString::new(path).map_err(|_| Error::OperationFailed("package path contained NUL"))?;
712 let descriptor_ptr =
713 descriptor.map_or(ptr::null_mut(), ExecutableSerializationDescriptor::as_ptr);
714 let ok = unsafe {
716 ffi::mpsgraph_executable_serialize_package(self.as_ptr(), path.as_ptr(), descriptor_ptr)
717 };
718 if ok {
719 Ok(())
720 } else {
721 Err(Error::OperationFailed(
722 "failed to serialize executable package",
723 ))
724 }
725 }
726
727 pub fn from_package(path: &str, descriptor: Option<&CompilationDescriptor>) -> Result<Self> {
729 let path =
730 CString::new(path).map_err(|_| Error::OperationFailed("package path contained NUL"))?;
731 let descriptor_ptr = descriptor.map_or(ptr::null_mut(), CompilationDescriptor::as_ptr);
732 let ptr =
734 unsafe { ffi::mpsgraph_executable_new_with_package(path.as_ptr(), descriptor_ptr) };
735 if ptr.is_null() {
736 return Err(Error::OperationFailed("failed to load executable package"));
737 }
738 let output_count = {
739 let box_handle = unsafe { ffi::mpsgraph_executable_target_tensors(ptr) };
741 collect_owned_tensors(box_handle).len()
742 };
743 Ok(Self::from_raw(ptr, output_count))
744 }
745}