1use crate::data::TensorData;
2use crate::error::{Error, Result};
3use crate::ffi;
4use crate::graph::{Executable, FeedDescription, Graph, Tensor};
5use crate::types::{collect_owned_tensors, collect_shaped_type_array_box, collect_tensor_data_array_box, ShapedType};
6use apple_metal::{CommandQueue, MetalDevice};
7use core::ffi::c_void;
8use core::ptr;
9use std::ffi::CString;
10
11fn release_handle(ptr: &mut *mut c_void) {
12 if !ptr.is_null() {
13 unsafe { ffi::mpsgraph_object_release(*ptr) };
15 *ptr = ptr::null_mut();
16 }
17}
18
19fn copy_string(
20 len: unsafe extern "C" fn(*mut c_void) -> usize,
21 copy: unsafe extern "C" fn(*mut c_void, *mut u8, usize) -> bool,
22 handle: *mut c_void,
23) -> Result<String> {
24 let len = unsafe { len(handle) };
26 let mut bytes = vec![0_u8; len];
27 let ok = unsafe { copy(handle, bytes.as_mut_ptr(), len) };
29 if ok {
30 String::from_utf8(bytes).map_err(|_| Error::OperationFailed("bridge returned invalid UTF-8"))
31 } else {
32 Err(Error::OperationFailed("failed to copy string from bridge"))
33 }
34}
35
/// Raw `u64` option values.
///
/// NOTE(review): presumably these mirror Apple's `MPSGraphOptions` and are the
/// values expected by the `options` / `set_options` accessors — confirm.
pub mod graph_options {
    /// Value `0`.
    pub const NONE: u64 = 0;
    /// Value `1`.
    pub const SYNCHRONIZE_RESULTS: u64 = 1;
    /// Value `2`.
    pub const VERBOSE: u64 = 2;
    /// Alias for [`SYNCHRONIZE_RESULTS`].
    pub const DEFAULT: u64 = SYNCHRONIZE_RESULTS;
}
43
/// Raw `u64` optimization level values.
///
/// NOTE(review): presumably intended for
/// `CompilationDescriptor::set_optimization_level` — confirm.
pub mod optimization {
    /// Value `0`.
    pub const LEVEL0: u64 = 0;
    /// Value `1`.
    pub const LEVEL1: u64 = 1;
}
49
/// Raw `u64` optimization profile values.
///
/// NOTE(review): presumably intended for
/// `CompilationDescriptor::set_optimization_profile` — confirm.
pub mod optimization_profile {
    /// Value `0`.
    pub const PERFORMANCE: u64 = 0;
    /// Value `1`.
    pub const POWER_EFFICIENCY: u64 = 1;
}
55
/// Raw `usize` values for the reduced-precision fast-math setting.
///
/// NOTE(review): presumably intended for
/// `CompilationDescriptor::set_reduced_precision_fast_math` — confirm.
pub mod reduced_precision_fast_math {
    /// No bits set.
    pub const NONE: usize = 0;
    /// Bit 1 (`1 << 1`).
    pub const ALLOW_FP16_CONV2D_WINOGRAD_TRANSFORM_INTERMEDIATE: usize = 1 << 1;
    /// Alias for [`ALLOW_FP16_CONV2D_WINOGRAD_TRANSFORM_INTERMEDIATE`].
    pub const ALLOW_FP16_INTERMEDIATES: usize = ALLOW_FP16_CONV2D_WINOGRAD_TRANSFORM_INTERMEDIATE;
    /// Alias for [`NONE`].
    pub const DEFAULT: usize = NONE;
}
63
/// Raw `u64` deployment platform values.
///
/// NOTE(review): presumably intended for
/// `ExecutableSerializationDescriptor::set_deployment_platform` — confirm.
pub mod deployment_platform {
    /// Value `0`.
    pub const MACOS: u64 = 0;
    /// Value `1`.
    pub const IOS: u64 = 1;
    /// Value `2`.
    pub const TVOS: u64 = 2;
    /// Value `3`.
    pub const VISIONOS: u64 = 3;
}
71
/// Owning wrapper around a compilation-descriptor handle from the bridge.
///
/// The handle is released when the wrapper is dropped.
#[derive(Debug)]
pub struct CompilationDescriptor {
    /// Opaque bridge handle; non-null for every value constructed in this file.
    ptr: *mut c_void,
}

// NOTE(review): these impls assert that the bridged object may be moved to and
// shared between threads. The setters take `&self`, so soundness relies on the
// underlying object tolerating concurrent access — confirm against the bridge.
unsafe impl Send for CompilationDescriptor {}
unsafe impl Sync for CompilationDescriptor {}
79
80impl Drop for CompilationDescriptor {
81 fn drop(&mut self) {
82 release_handle(&mut self.ptr);
83 }
84}
85
86impl CompilationDescriptor {
87 #[must_use]
88 pub fn new() -> Option<Self> {
89 let ptr = unsafe { ffi::mpsgraph_compilation_descriptor_new() };
91 if ptr.is_null() {
92 None
93 } else {
94 Some(Self { ptr })
95 }
96 }
97
98 #[must_use]
99 pub(crate) const fn as_ptr(&self) -> *mut c_void {
100 self.ptr
101 }
102
103 pub fn disable_type_inference(&self) -> Result<()> {
104 let ok = unsafe { ffi::mpsgraph_compilation_descriptor_disable_type_inference(self.ptr) };
106 if ok {
107 Ok(())
108 } else {
109 Err(Error::OperationFailed("failed to disable type inference"))
110 }
111 }
112
113 #[must_use]
114 pub fn optimization_level(&self) -> u64 {
115 unsafe { ffi::mpsgraph_compilation_descriptor_optimization_level(self.ptr) }
117 }
118
119 pub fn set_optimization_level(&self, value: u64) -> Result<()> {
120 let ok = unsafe { ffi::mpsgraph_compilation_descriptor_set_optimization_level(self.ptr, value) };
122 if ok {
123 Ok(())
124 } else {
125 Err(Error::OperationFailed("failed to set optimization level"))
126 }
127 }
128
129 #[must_use]
130 pub fn wait_for_compilation_completion(&self) -> bool {
131 unsafe { ffi::mpsgraph_compilation_descriptor_wait_for_completion(self.ptr) }
133 }
134
135 pub fn set_wait_for_compilation_completion(&self, value: bool) -> Result<()> {
136 let ok = unsafe { ffi::mpsgraph_compilation_descriptor_set_wait_for_completion(self.ptr, value) };
138 if ok {
139 Ok(())
140 } else {
141 Err(Error::OperationFailed("failed to set waitForCompilationCompletion"))
142 }
143 }
144
145 #[must_use]
146 pub fn optimization_profile(&self) -> u64 {
147 unsafe { ffi::mpsgraph_compilation_descriptor_optimization_profile(self.ptr) }
149 }
150
151 pub fn set_optimization_profile(&self, value: u64) -> Result<()> {
152 let ok = unsafe { ffi::mpsgraph_compilation_descriptor_set_optimization_profile(self.ptr, value) };
154 if ok {
155 Ok(())
156 } else {
157 Err(Error::OperationFailed("failed to set optimization profile"))
158 }
159 }
160
161 #[must_use]
162 pub fn reduced_precision_fast_math(&self) -> usize {
163 unsafe { ffi::mpsgraph_compilation_descriptor_reduced_precision_fast_math(self.ptr) }
165 }
166
167 pub fn set_reduced_precision_fast_math(&self, value: usize) -> Result<()> {
168 let ok = unsafe {
170 ffi::mpsgraph_compilation_descriptor_set_reduced_precision_fast_math(self.ptr, value)
171 };
172 if ok {
173 Ok(())
174 } else {
175 Err(Error::OperationFailed("failed to set reducedPrecisionFastMath"))
176 }
177 }
178
179 pub fn set_callable(&self, symbol_name: &str, executable: Option<&Executable>) -> Result<()> {
180 let symbol_name =
181 CString::new(symbol_name).map_err(|_| Error::OperationFailed("call symbol name contained NUL"))?;
182 let executable_ptr = executable.map_or(ptr::null_mut(), Executable::as_ptr);
183 let ok = unsafe {
185 ffi::mpsgraph_compilation_descriptor_set_callable(
186 self.ptr,
187 symbol_name.as_ptr(),
188 executable_ptr,
189 )
190 };
191 if ok {
192 Ok(())
193 } else {
194 Err(Error::OperationFailed("failed to set compilation descriptor callable"))
195 }
196 }
197}
198
/// Owning wrapper around an execution-descriptor handle from the bridge.
///
/// The handle is released when the wrapper is dropped.
#[derive(Debug)]
pub struct ExecutionDescriptor {
    /// Opaque bridge handle; non-null for every value constructed in this file.
    ptr: *mut c_void,
}

// NOTE(review): these impls assert that the bridged object may be moved to and
// shared between threads. The setters take `&self`, so soundness relies on the
// underlying object tolerating concurrent access — confirm against the bridge.
unsafe impl Send for ExecutionDescriptor {}
unsafe impl Sync for ExecutionDescriptor {}
206
207impl Drop for ExecutionDescriptor {
208 fn drop(&mut self) {
209 release_handle(&mut self.ptr);
210 }
211}
212
213impl ExecutionDescriptor {
214 #[must_use]
215 pub fn new() -> Option<Self> {
216 let ptr = unsafe { ffi::mpsgraph_execution_descriptor_new() };
218 if ptr.is_null() {
219 None
220 } else {
221 Some(Self { ptr })
222 }
223 }
224
225 #[must_use]
226 pub const fn as_ptr(&self) -> *mut c_void {
227 self.ptr
228 }
229
230 #[must_use]
231 pub fn wait_until_completed(&self) -> bool {
232 unsafe { ffi::mpsgraph_execution_descriptor_wait_until_completed(self.ptr) }
234 }
235
236 pub fn set_wait_until_completed(&self, value: bool) -> Result<()> {
237 let ok = unsafe { ffi::mpsgraph_execution_descriptor_set_wait_until_completed(self.ptr, value) };
239 if ok {
240 Ok(())
241 } else {
242 Err(Error::OperationFailed("failed to set waitUntilCompleted"))
243 }
244 }
245
246 #[must_use]
247 pub fn compilation_descriptor(&self) -> Option<CompilationDescriptor> {
248 let ptr = unsafe { ffi::mpsgraph_execution_descriptor_compilation_descriptor(self.ptr) };
250 if ptr.is_null() {
251 None
252 } else {
253 Some(CompilationDescriptor { ptr })
254 }
255 }
256
257 pub fn set_compilation_descriptor(&self, descriptor: Option<&CompilationDescriptor>) -> Result<()> {
258 let descriptor_ptr = descriptor.map_or(ptr::null_mut(), CompilationDescriptor::as_ptr);
259 let ok = unsafe {
261 ffi::mpsgraph_execution_descriptor_set_compilation_descriptor(self.ptr, descriptor_ptr)
262 };
263 if ok {
264 Ok(())
265 } else {
266 Err(Error::OperationFailed("failed to set compilation descriptor"))
267 }
268 }
269}
270
/// Owning wrapper around an executable-execution-descriptor handle from the bridge.
///
/// The handle is released when the wrapper is dropped.
#[derive(Debug)]
pub struct ExecutableExecutionDescriptor {
    /// Opaque bridge handle; non-null for every value constructed in this file.
    ptr: *mut c_void,
}

// NOTE(review): these impls assert that the bridged object may be moved to and
// shared between threads. The setters take `&self`, so soundness relies on the
// underlying object tolerating concurrent access — confirm against the bridge.
unsafe impl Send for ExecutableExecutionDescriptor {}
unsafe impl Sync for ExecutableExecutionDescriptor {}
278
279impl Drop for ExecutableExecutionDescriptor {
280 fn drop(&mut self) {
281 release_handle(&mut self.ptr);
282 }
283}
284
285impl ExecutableExecutionDescriptor {
286 #[must_use]
287 pub fn new() -> Option<Self> {
288 let ptr = unsafe { ffi::mpsgraph_executable_execution_descriptor_new() };
290 if ptr.is_null() {
291 None
292 } else {
293 Some(Self { ptr })
294 }
295 }
296
297 #[must_use]
298 pub(crate) const fn as_ptr(&self) -> *mut c_void {
299 self.ptr
300 }
301
302 #[must_use]
303 pub fn wait_until_completed(&self) -> bool {
304 unsafe { ffi::mpsgraph_executable_execution_descriptor_wait_until_completed(self.ptr) }
306 }
307
308 pub fn set_wait_until_completed(&self, value: bool) -> Result<()> {
309 let ok = unsafe {
311 ffi::mpsgraph_executable_execution_descriptor_set_wait_until_completed(self.ptr, value)
312 };
313 if ok {
314 Ok(())
315 } else {
316 Err(Error::OperationFailed("failed to set executable waitUntilCompleted"))
317 }
318 }
319}
320
/// Owning wrapper around an executable-serialization-descriptor handle from the bridge.
///
/// The handle is released when the wrapper is dropped.
#[derive(Debug)]
pub struct ExecutableSerializationDescriptor {
    /// Opaque bridge handle; non-null for every value constructed in this file.
    ptr: *mut c_void,
}

// NOTE(review): these impls assert that the bridged object may be moved to and
// shared between threads. The setters take `&self`, so soundness relies on the
// underlying object tolerating concurrent access — confirm against the bridge.
unsafe impl Send for ExecutableSerializationDescriptor {}
unsafe impl Sync for ExecutableSerializationDescriptor {}
328
329impl Drop for ExecutableSerializationDescriptor {
330 fn drop(&mut self) {
331 release_handle(&mut self.ptr);
332 }
333}
334
335impl ExecutableSerializationDescriptor {
336 #[must_use]
337 pub fn new() -> Option<Self> {
338 let ptr = unsafe { ffi::mpsgraph_executable_serialization_descriptor_new() };
340 if ptr.is_null() {
341 None
342 } else {
343 Some(Self { ptr })
344 }
345 }
346
347 #[must_use]
348 pub(crate) const fn as_ptr(&self) -> *mut c_void {
349 self.ptr
350 }
351
352 #[must_use]
353 pub fn append(&self) -> bool {
354 unsafe { ffi::mpsgraph_executable_serialization_descriptor_append(self.ptr) }
356 }
357
358 pub fn set_append(&self, value: bool) -> Result<()> {
359 let ok = unsafe { ffi::mpsgraph_executable_serialization_descriptor_set_append(self.ptr, value) };
361 if ok {
362 Ok(())
363 } else {
364 Err(Error::OperationFailed("failed to set append"))
365 }
366 }
367
368 #[must_use]
369 pub fn deployment_platform(&self) -> u64 {
370 unsafe { ffi::mpsgraph_executable_serialization_descriptor_deployment_platform(self.ptr) }
372 }
373
374 pub fn set_deployment_platform(&self, value: u64) -> Result<()> {
375 let ok = unsafe {
377 ffi::mpsgraph_executable_serialization_descriptor_set_deployment_platform(self.ptr, value)
378 };
379 if ok {
380 Ok(())
381 } else {
382 Err(Error::OperationFailed("failed to set deployment platform"))
383 }
384 }
385
386 pub fn minimum_deployment_target(&self) -> Result<String> {
387 copy_string(
388 ffi::mpsgraph_executable_serialization_descriptor_minimum_deployment_target_len,
389 ffi::mpsgraph_executable_serialization_descriptor_copy_minimum_deployment_target,
390 self.ptr,
391 )
392 }
393
394 pub fn set_minimum_deployment_target(&self, value: &str) -> Result<()> {
395 let value = CString::new(value).map_err(|_| Error::OperationFailed("minimum deployment target contained NUL"))?;
396 let ok = unsafe {
398 ffi::mpsgraph_executable_serialization_descriptor_set_minimum_deployment_target(
399 self.ptr,
400 value.as_ptr(),
401 )
402 };
403 if ok {
404 Ok(())
405 } else {
406 Err(Error::OperationFailed("failed to set minimum deployment target"))
407 }
408 }
409}
410
411impl Graph {
412 #[must_use]
414 pub fn options(&self) -> u64 {
415 unsafe { ffi::mpsgraph_graph_options(self.as_ptr()) }
417 }
418
419 pub fn set_options(&self, options: u64) -> Result<()> {
421 let ok = unsafe { ffi::mpsgraph_graph_set_options(self.as_ptr(), options) };
423 if ok {
424 Ok(())
425 } else {
426 Err(Error::OperationFailed("failed to set graph options"))
427 }
428 }
429
430 #[must_use]
432 pub fn placeholder_tensors(&self) -> Vec<Tensor> {
433 let box_handle = unsafe { ffi::mpsgraph_graph_placeholder_tensors(self.as_ptr()) };
435 collect_owned_tensors(box_handle)
436 }
437
438 #[must_use]
440 pub fn compile_with_descriptor(
441 &self,
442 device: Option<&MetalDevice>,
443 feeds: &[FeedDescription<'_>],
444 targets: &[&Tensor],
445 descriptor: Option<&CompilationDescriptor>,
446 ) -> Option<Executable> {
447 let feed_tensors = feeds.iter().map(|feed| feed.tensor.as_ptr()).collect::<Vec<_>>();
448 let shape_lengths = feeds.iter().map(|feed| feed.shape.len()).collect::<Vec<_>>();
449 let data_types = feeds.iter().map(|feed| feed.data_type).collect::<Vec<_>>();
450 let flat_shapes = feeds
451 .iter()
452 .flat_map(|feed| feed.shape.iter().copied())
453 .collect::<Vec<_>>();
454 let target_tensors = targets.iter().map(|tensor| tensor.as_ptr()).collect::<Vec<_>>();
455 let device_ptr = device.map_or(ptr::null_mut(), MetalDevice::as_ptr);
456 let descriptor_ptr = descriptor.map_or(ptr::null_mut(), CompilationDescriptor::as_ptr);
457
458 let ptr = unsafe {
460 ffi::mpsgraph_graph_compile_with_descriptor(
461 self.as_ptr(),
462 device_ptr,
463 feed_tensors.as_ptr(),
464 feeds.len(),
465 flat_shapes.as_ptr(),
466 shape_lengths.as_ptr(),
467 data_types.as_ptr(),
468 target_tensors.as_ptr(),
469 targets.len(),
470 descriptor_ptr,
471 )
472 };
473 if ptr.is_null() {
474 None
475 } else {
476 Some(Executable::from_raw(ptr, targets.len()))
477 }
478 }
479}
480
481impl Executable {
482 #[must_use]
484 pub fn options(&self) -> u64 {
485 unsafe { ffi::mpsgraph_executable_options(self.as_ptr()) }
487 }
488
489 pub fn set_options(&self, options: u64) -> Result<()> {
491 let ok = unsafe { ffi::mpsgraph_executable_set_options(self.as_ptr(), options) };
493 if ok {
494 Ok(())
495 } else {
496 Err(Error::OperationFailed("failed to set executable options"))
497 }
498 }
499
500 #[must_use]
502 pub fn feed_tensors(&self) -> Vec<Tensor> {
503 let box_handle = unsafe { ffi::mpsgraph_executable_feed_tensors(self.as_ptr()) };
505 collect_owned_tensors(box_handle)
506 }
507
508 #[must_use]
510 pub fn target_tensors(&self) -> Vec<Tensor> {
511 let box_handle = unsafe { ffi::mpsgraph_executable_target_tensors(self.as_ptr()) };
513 collect_owned_tensors(box_handle)
514 }
515
516 pub fn specialize(
518 &self,
519 device: Option<&MetalDevice>,
520 input_types: &[&ShapedType],
521 descriptor: Option<&CompilationDescriptor>,
522 ) -> Result<()> {
523 let input_type_handles = input_types
524 .iter()
525 .map(|value| value.as_ptr())
526 .collect::<Vec<_>>();
527 let device_ptr = device.map_or(ptr::null_mut(), MetalDevice::as_ptr);
528 let descriptor_ptr = descriptor.map_or(ptr::null_mut(), CompilationDescriptor::as_ptr);
529
530 let ok = unsafe {
532 ffi::mpsgraph_executable_specialize(
533 self.as_ptr(),
534 device_ptr,
535 input_type_handles.as_ptr(),
536 input_types.len(),
537 descriptor_ptr,
538 )
539 };
540 if ok {
541 Ok(())
542 } else {
543 Err(Error::OperationFailed("failed to specialize executable"))
544 }
545 }
546
547 pub fn output_types(
549 &self,
550 device: Option<&MetalDevice>,
551 input_types: &[&ShapedType],
552 descriptor: Option<&CompilationDescriptor>,
553 ) -> Result<Vec<ShapedType>> {
554 let input_type_handles = input_types
555 .iter()
556 .map(|value| value.as_ptr())
557 .collect::<Vec<_>>();
558 let device_ptr = device.map_or(ptr::null_mut(), MetalDevice::as_ptr);
559 let descriptor_ptr = descriptor.map_or(ptr::null_mut(), CompilationDescriptor::as_ptr);
560
561 let box_handle = unsafe {
563 ffi::mpsgraph_executable_get_output_types(
564 self.as_ptr(),
565 device_ptr,
566 input_type_handles.as_ptr(),
567 input_types.len(),
568 descriptor_ptr,
569 )
570 };
571 if box_handle.is_null() {
572 Err(Error::OperationFailed("failed to get executable output types"))
573 } else {
574 Ok(collect_shaped_type_array_box(box_handle))
575 }
576 }
577
578 pub fn run_with_descriptor(
580 &self,
581 command_queue: &CommandQueue,
582 inputs: &[&TensorData],
583 results: Option<&[&TensorData]>,
584 descriptor: Option<&ExecutableExecutionDescriptor>,
585 ) -> Result<Vec<TensorData>> {
586 let input_handles = inputs.iter().map(|value| value.as_ptr()).collect::<Vec<_>>();
587 let result_handles = results
588 .map(|values| values.iter().map(|value| value.as_ptr()).collect::<Vec<_>>())
589 .unwrap_or_default();
590 let descriptor_ptr = descriptor.map_or(ptr::null_mut(), ExecutableExecutionDescriptor::as_ptr);
591
592 let box_handle = unsafe {
594 ffi::mpsgraph_executable_run_with_descriptor(
595 self.as_ptr(),
596 command_queue.as_ptr(),
597 input_handles.as_ptr(),
598 inputs.len(),
599 result_handles.as_ptr(),
600 result_handles.len(),
601 descriptor_ptr,
602 )
603 };
604 if box_handle.is_null() {
605 Err(Error::OperationFailed("failed to run executable"))
606 } else {
607 Ok(collect_tensor_data_array_box(box_handle))
608 }
609 }
610
611 pub fn run_async_with_descriptor(
613 &self,
614 command_queue: &CommandQueue,
615 inputs: &[&TensorData],
616 results: Option<&[&TensorData]>,
617 descriptor: Option<&ExecutableExecutionDescriptor>,
618 ) -> Result<Vec<TensorData>> {
619 let input_handles = inputs.iter().map(|value| value.as_ptr()).collect::<Vec<_>>();
620 let result_handles = results
621 .map(|values| values.iter().map(|value| value.as_ptr()).collect::<Vec<_>>())
622 .unwrap_or_default();
623 let descriptor_ptr = descriptor.map_or(ptr::null_mut(), ExecutableExecutionDescriptor::as_ptr);
624
625 let box_handle = unsafe {
627 ffi::mpsgraph_executable_run_async_with_descriptor(
628 self.as_ptr(),
629 command_queue.as_ptr(),
630 input_handles.as_ptr(),
631 inputs.len(),
632 result_handles.as_ptr(),
633 result_handles.len(),
634 descriptor_ptr,
635 )
636 };
637 if box_handle.is_null() {
638 Err(Error::OperationFailed("failed to run executable asynchronously"))
639 } else {
640 Ok(collect_tensor_data_array_box(box_handle))
641 }
642 }
643
644 pub fn serialize_package(
646 &self,
647 path: &str,
648 descriptor: Option<&ExecutableSerializationDescriptor>,
649 ) -> Result<()> {
650 let path = CString::new(path).map_err(|_| Error::OperationFailed("package path contained NUL"))?;
651 let descriptor_ptr = descriptor.map_or(ptr::null_mut(), ExecutableSerializationDescriptor::as_ptr);
652 let ok = unsafe { ffi::mpsgraph_executable_serialize_package(self.as_ptr(), path.as_ptr(), descriptor_ptr) };
654 if ok {
655 Ok(())
656 } else {
657 Err(Error::OperationFailed("failed to serialize executable package"))
658 }
659 }
660
661 pub fn from_package(path: &str, descriptor: Option<&CompilationDescriptor>) -> Result<Self> {
663 let path = CString::new(path).map_err(|_| Error::OperationFailed("package path contained NUL"))?;
664 let descriptor_ptr = descriptor.map_or(ptr::null_mut(), CompilationDescriptor::as_ptr);
665 let ptr = unsafe { ffi::mpsgraph_executable_new_with_package(path.as_ptr(), descriptor_ptr) };
667 if ptr.is_null() {
668 return Err(Error::OperationFailed("failed to load executable package"));
669 }
670 let output_count = {
671 let box_handle = unsafe { ffi::mpsgraph_executable_target_tensors(ptr) };
673 collect_owned_tensors(box_handle).len()
674 };
675 Ok(Self::from_raw(ptr, output_count))
676 }
677}