1use crate::data::TensorData;
2use crate::error::{Error, Result};
3use crate::ffi;
4use crate::graph::{Executable, FeedDescription, Graph, Tensor};
5use crate::types::{collect_owned_tensors, collect_shaped_type_array_box, collect_tensor_data_array_box, ShapedType};
6use apple_metal::{CommandQueue, MetalDevice};
7use core::ffi::c_void;
8use core::ptr;
9use std::ffi::CString;
10
11fn release_handle(ptr: &mut *mut c_void) {
12 if !ptr.is_null() {
13 unsafe { ffi::mpsgraph_object_release(*ptr) };
15 *ptr = ptr::null_mut();
16 }
17}
18
19fn copy_string(
20 len: unsafe extern "C" fn(*mut c_void) -> usize,
21 copy: unsafe extern "C" fn(*mut c_void, *mut u8, usize) -> bool,
22 handle: *mut c_void,
23) -> Result<String> {
24 let len = unsafe { len(handle) };
26 let mut bytes = vec![0_u8; len];
27 let ok = unsafe { copy(handle, bytes.as_mut_ptr(), len) };
29 if ok {
30 String::from_utf8(bytes).map_err(|_| Error::OperationFailed("bridge returned invalid UTF-8"))
31 } else {
32 Err(Error::OperationFailed("failed to copy string from bridge"))
33 }
34}
35
/// Option bits forwarded to the bridge by `Graph::set_options` and
/// `Executable::set_options`.
pub mod graph_options {
    /// No option bits set.
    pub const NONE: u64 = 0;
    /// Bit 0: synchronize results.
    pub const SYNCHRONIZE_RESULTS: u64 = 1 << 0;
    /// Bit 1: verbose.
    pub const VERBOSE: u64 = 1 << 1;
    /// Default option set; the same value as `SYNCHRONIZE_RESULTS`.
    pub const DEFAULT: u64 = SYNCHRONIZE_RESULTS;
}
43
/// Optimization level values for `CompilationDescriptor::set_optimization_level`.
pub mod optimization {
    /// Optimization level 0.
    pub const LEVEL0: u64 = 0;
    /// Optimization level 1.
    pub const LEVEL1: u64 = 1;
}
49
/// Optimization profile values for `CompilationDescriptor::set_optimization_profile`.
pub mod optimization_profile {
    /// Favor performance.
    pub const PERFORMANCE: u64 = 0;
    /// Favor power efficiency.
    pub const POWER_EFFICIENCY: u64 = 1;
}
55
/// Flag values for `CompilationDescriptor::set_reduced_precision_fast_math`.
pub mod reduced_precision_fast_math {
    /// No reduced-precision shortcuts.
    pub const NONE: usize = 0;
    /// Bit 1: allow FP16 intermediates in the Conv2D Winograd transform.
    pub const ALLOW_FP16_CONV2D_WINOGRAD_TRANSFORM_INTERMEDIATE: usize = 1 << 1;
    /// Alias; currently the same bit as
    /// `ALLOW_FP16_CONV2D_WINOGRAD_TRANSFORM_INTERMEDIATE`.
    pub const ALLOW_FP16_INTERMEDIATES: usize =
        ALLOW_FP16_CONV2D_WINOGRAD_TRANSFORM_INTERMEDIATE;
    /// Default setting; the same value as `NONE`.
    pub const DEFAULT: usize = NONE;
}
63
/// Platform values for `ExecutableSerializationDescriptor::set_deployment_platform`.
pub mod deployment_platform {
    /// macOS.
    pub const MACOS: u64 = 0;
    /// iOS.
    pub const IOS: u64 = 1;
    /// tvOS.
    pub const TVOS: u64 = 2;
    /// visionOS.
    pub const VISIONOS: u64 = 3;
}
71
72pub struct CompilationDescriptor {
74 ptr: *mut c_void,
75}
76
77unsafe impl Send for CompilationDescriptor {}
78unsafe impl Sync for CompilationDescriptor {}
79
80impl Drop for CompilationDescriptor {
81 fn drop(&mut self) {
82 release_handle(&mut self.ptr);
83 }
84}
85
86impl CompilationDescriptor {
87 #[must_use]
88 pub fn new() -> Option<Self> {
89 let ptr = unsafe { ffi::mpsgraph_compilation_descriptor_new() };
91 if ptr.is_null() {
92 None
93 } else {
94 Some(Self { ptr })
95 }
96 }
97
98 #[must_use]
99 pub(crate) const fn as_ptr(&self) -> *mut c_void {
100 self.ptr
101 }
102
103 pub fn disable_type_inference(&self) -> Result<()> {
104 let ok = unsafe { ffi::mpsgraph_compilation_descriptor_disable_type_inference(self.ptr) };
106 if ok {
107 Ok(())
108 } else {
109 Err(Error::OperationFailed("failed to disable type inference"))
110 }
111 }
112
113 #[must_use]
114 pub fn optimization_level(&self) -> u64 {
115 unsafe { ffi::mpsgraph_compilation_descriptor_optimization_level(self.ptr) }
117 }
118
119 pub fn set_optimization_level(&self, value: u64) -> Result<()> {
120 let ok = unsafe { ffi::mpsgraph_compilation_descriptor_set_optimization_level(self.ptr, value) };
122 if ok {
123 Ok(())
124 } else {
125 Err(Error::OperationFailed("failed to set optimization level"))
126 }
127 }
128
129 #[must_use]
130 pub fn wait_for_compilation_completion(&self) -> bool {
131 unsafe { ffi::mpsgraph_compilation_descriptor_wait_for_completion(self.ptr) }
133 }
134
135 pub fn set_wait_for_compilation_completion(&self, value: bool) -> Result<()> {
136 let ok = unsafe { ffi::mpsgraph_compilation_descriptor_set_wait_for_completion(self.ptr, value) };
138 if ok {
139 Ok(())
140 } else {
141 Err(Error::OperationFailed("failed to set waitForCompilationCompletion"))
142 }
143 }
144
145 #[must_use]
146 pub fn optimization_profile(&self) -> u64 {
147 unsafe { ffi::mpsgraph_compilation_descriptor_optimization_profile(self.ptr) }
149 }
150
151 pub fn set_optimization_profile(&self, value: u64) -> Result<()> {
152 let ok = unsafe { ffi::mpsgraph_compilation_descriptor_set_optimization_profile(self.ptr, value) };
154 if ok {
155 Ok(())
156 } else {
157 Err(Error::OperationFailed("failed to set optimization profile"))
158 }
159 }
160
161 #[must_use]
162 pub fn reduced_precision_fast_math(&self) -> usize {
163 unsafe { ffi::mpsgraph_compilation_descriptor_reduced_precision_fast_math(self.ptr) }
165 }
166
167 pub fn set_reduced_precision_fast_math(&self, value: usize) -> Result<()> {
168 let ok = unsafe {
170 ffi::mpsgraph_compilation_descriptor_set_reduced_precision_fast_math(self.ptr, value)
171 };
172 if ok {
173 Ok(())
174 } else {
175 Err(Error::OperationFailed("failed to set reducedPrecisionFastMath"))
176 }
177 }
178}
179
180pub struct ExecutionDescriptor {
182 ptr: *mut c_void,
183}
184
185unsafe impl Send for ExecutionDescriptor {}
186unsafe impl Sync for ExecutionDescriptor {}
187
188impl Drop for ExecutionDescriptor {
189 fn drop(&mut self) {
190 release_handle(&mut self.ptr);
191 }
192}
193
194impl ExecutionDescriptor {
195 #[must_use]
196 pub fn new() -> Option<Self> {
197 let ptr = unsafe { ffi::mpsgraph_execution_descriptor_new() };
199 if ptr.is_null() {
200 None
201 } else {
202 Some(Self { ptr })
203 }
204 }
205
206 #[must_use]
207 pub const fn as_ptr(&self) -> *mut c_void {
208 self.ptr
209 }
210
211 #[must_use]
212 pub fn wait_until_completed(&self) -> bool {
213 unsafe { ffi::mpsgraph_execution_descriptor_wait_until_completed(self.ptr) }
215 }
216
217 pub fn set_wait_until_completed(&self, value: bool) -> Result<()> {
218 let ok = unsafe { ffi::mpsgraph_execution_descriptor_set_wait_until_completed(self.ptr, value) };
220 if ok {
221 Ok(())
222 } else {
223 Err(Error::OperationFailed("failed to set waitUntilCompleted"))
224 }
225 }
226
227 #[must_use]
228 pub fn compilation_descriptor(&self) -> Option<CompilationDescriptor> {
229 let ptr = unsafe { ffi::mpsgraph_execution_descriptor_compilation_descriptor(self.ptr) };
231 if ptr.is_null() {
232 None
233 } else {
234 Some(CompilationDescriptor { ptr })
235 }
236 }
237
238 pub fn set_compilation_descriptor(&self, descriptor: Option<&CompilationDescriptor>) -> Result<()> {
239 let descriptor_ptr = descriptor.map_or(ptr::null_mut(), CompilationDescriptor::as_ptr);
240 let ok = unsafe {
242 ffi::mpsgraph_execution_descriptor_set_compilation_descriptor(self.ptr, descriptor_ptr)
243 };
244 if ok {
245 Ok(())
246 } else {
247 Err(Error::OperationFailed("failed to set compilation descriptor"))
248 }
249 }
250}
251
252pub struct ExecutableExecutionDescriptor {
254 ptr: *mut c_void,
255}
256
257unsafe impl Send for ExecutableExecutionDescriptor {}
258unsafe impl Sync for ExecutableExecutionDescriptor {}
259
260impl Drop for ExecutableExecutionDescriptor {
261 fn drop(&mut self) {
262 release_handle(&mut self.ptr);
263 }
264}
265
266impl ExecutableExecutionDescriptor {
267 #[must_use]
268 pub fn new() -> Option<Self> {
269 let ptr = unsafe { ffi::mpsgraph_executable_execution_descriptor_new() };
271 if ptr.is_null() {
272 None
273 } else {
274 Some(Self { ptr })
275 }
276 }
277
278 #[must_use]
279 pub(crate) const fn as_ptr(&self) -> *mut c_void {
280 self.ptr
281 }
282
283 #[must_use]
284 pub fn wait_until_completed(&self) -> bool {
285 unsafe { ffi::mpsgraph_executable_execution_descriptor_wait_until_completed(self.ptr) }
287 }
288
289 pub fn set_wait_until_completed(&self, value: bool) -> Result<()> {
290 let ok = unsafe {
292 ffi::mpsgraph_executable_execution_descriptor_set_wait_until_completed(self.ptr, value)
293 };
294 if ok {
295 Ok(())
296 } else {
297 Err(Error::OperationFailed("failed to set executable waitUntilCompleted"))
298 }
299 }
300}
301
302pub struct ExecutableSerializationDescriptor {
304 ptr: *mut c_void,
305}
306
307unsafe impl Send for ExecutableSerializationDescriptor {}
308unsafe impl Sync for ExecutableSerializationDescriptor {}
309
310impl Drop for ExecutableSerializationDescriptor {
311 fn drop(&mut self) {
312 release_handle(&mut self.ptr);
313 }
314}
315
316impl ExecutableSerializationDescriptor {
317 #[must_use]
318 pub fn new() -> Option<Self> {
319 let ptr = unsafe { ffi::mpsgraph_executable_serialization_descriptor_new() };
321 if ptr.is_null() {
322 None
323 } else {
324 Some(Self { ptr })
325 }
326 }
327
328 #[must_use]
329 pub(crate) const fn as_ptr(&self) -> *mut c_void {
330 self.ptr
331 }
332
333 #[must_use]
334 pub fn append(&self) -> bool {
335 unsafe { ffi::mpsgraph_executable_serialization_descriptor_append(self.ptr) }
337 }
338
339 pub fn set_append(&self, value: bool) -> Result<()> {
340 let ok = unsafe { ffi::mpsgraph_executable_serialization_descriptor_set_append(self.ptr, value) };
342 if ok {
343 Ok(())
344 } else {
345 Err(Error::OperationFailed("failed to set append"))
346 }
347 }
348
349 #[must_use]
350 pub fn deployment_platform(&self) -> u64 {
351 unsafe { ffi::mpsgraph_executable_serialization_descriptor_deployment_platform(self.ptr) }
353 }
354
355 pub fn set_deployment_platform(&self, value: u64) -> Result<()> {
356 let ok = unsafe {
358 ffi::mpsgraph_executable_serialization_descriptor_set_deployment_platform(self.ptr, value)
359 };
360 if ok {
361 Ok(())
362 } else {
363 Err(Error::OperationFailed("failed to set deployment platform"))
364 }
365 }
366
367 pub fn minimum_deployment_target(&self) -> Result<String> {
368 copy_string(
369 ffi::mpsgraph_executable_serialization_descriptor_minimum_deployment_target_len,
370 ffi::mpsgraph_executable_serialization_descriptor_copy_minimum_deployment_target,
371 self.ptr,
372 )
373 }
374
375 pub fn set_minimum_deployment_target(&self, value: &str) -> Result<()> {
376 let value = CString::new(value).map_err(|_| Error::OperationFailed("minimum deployment target contained NUL"))?;
377 let ok = unsafe {
379 ffi::mpsgraph_executable_serialization_descriptor_set_minimum_deployment_target(
380 self.ptr,
381 value.as_ptr(),
382 )
383 };
384 if ok {
385 Ok(())
386 } else {
387 Err(Error::OperationFailed("failed to set minimum deployment target"))
388 }
389 }
390}
391
392impl Graph {
393 #[must_use]
395 pub fn options(&self) -> u64 {
396 unsafe { ffi::mpsgraph_graph_options(self.as_ptr()) }
398 }
399
400 pub fn set_options(&self, options: u64) -> Result<()> {
402 let ok = unsafe { ffi::mpsgraph_graph_set_options(self.as_ptr(), options) };
404 if ok {
405 Ok(())
406 } else {
407 Err(Error::OperationFailed("failed to set graph options"))
408 }
409 }
410
411 #[must_use]
413 pub fn placeholder_tensors(&self) -> Vec<Tensor> {
414 let box_handle = unsafe { ffi::mpsgraph_graph_placeholder_tensors(self.as_ptr()) };
416 collect_owned_tensors(box_handle)
417 }
418
419 #[must_use]
421 pub fn compile_with_descriptor(
422 &self,
423 device: Option<&MetalDevice>,
424 feeds: &[FeedDescription<'_>],
425 targets: &[&Tensor],
426 descriptor: Option<&CompilationDescriptor>,
427 ) -> Option<Executable> {
428 let feed_tensors = feeds.iter().map(|feed| feed.tensor.as_ptr()).collect::<Vec<_>>();
429 let shape_lengths = feeds.iter().map(|feed| feed.shape.len()).collect::<Vec<_>>();
430 let data_types = feeds.iter().map(|feed| feed.data_type).collect::<Vec<_>>();
431 let flat_shapes = feeds
432 .iter()
433 .flat_map(|feed| feed.shape.iter().copied())
434 .collect::<Vec<_>>();
435 let target_tensors = targets.iter().map(|tensor| tensor.as_ptr()).collect::<Vec<_>>();
436 let device_ptr = device.map_or(ptr::null_mut(), MetalDevice::as_ptr);
437 let descriptor_ptr = descriptor.map_or(ptr::null_mut(), CompilationDescriptor::as_ptr);
438
439 let ptr = unsafe {
441 ffi::mpsgraph_graph_compile_with_descriptor(
442 self.as_ptr(),
443 device_ptr,
444 feed_tensors.as_ptr(),
445 feeds.len(),
446 flat_shapes.as_ptr(),
447 shape_lengths.as_ptr(),
448 data_types.as_ptr(),
449 target_tensors.as_ptr(),
450 targets.len(),
451 descriptor_ptr,
452 )
453 };
454 if ptr.is_null() {
455 None
456 } else {
457 Some(Executable::from_raw(ptr, targets.len()))
458 }
459 }
460}
461
462impl Executable {
463 #[must_use]
465 pub fn options(&self) -> u64 {
466 unsafe { ffi::mpsgraph_executable_options(self.as_ptr()) }
468 }
469
470 pub fn set_options(&self, options: u64) -> Result<()> {
472 let ok = unsafe { ffi::mpsgraph_executable_set_options(self.as_ptr(), options) };
474 if ok {
475 Ok(())
476 } else {
477 Err(Error::OperationFailed("failed to set executable options"))
478 }
479 }
480
481 #[must_use]
483 pub fn feed_tensors(&self) -> Vec<Tensor> {
484 let box_handle = unsafe { ffi::mpsgraph_executable_feed_tensors(self.as_ptr()) };
486 collect_owned_tensors(box_handle)
487 }
488
489 #[must_use]
491 pub fn target_tensors(&self) -> Vec<Tensor> {
492 let box_handle = unsafe { ffi::mpsgraph_executable_target_tensors(self.as_ptr()) };
494 collect_owned_tensors(box_handle)
495 }
496
497 pub fn specialize(
499 &self,
500 device: Option<&MetalDevice>,
501 input_types: &[&ShapedType],
502 descriptor: Option<&CompilationDescriptor>,
503 ) -> Result<()> {
504 let input_type_handles = input_types
505 .iter()
506 .map(|value| value.as_ptr())
507 .collect::<Vec<_>>();
508 let device_ptr = device.map_or(ptr::null_mut(), MetalDevice::as_ptr);
509 let descriptor_ptr = descriptor.map_or(ptr::null_mut(), CompilationDescriptor::as_ptr);
510
511 let ok = unsafe {
513 ffi::mpsgraph_executable_specialize(
514 self.as_ptr(),
515 device_ptr,
516 input_type_handles.as_ptr(),
517 input_types.len(),
518 descriptor_ptr,
519 )
520 };
521 if ok {
522 Ok(())
523 } else {
524 Err(Error::OperationFailed("failed to specialize executable"))
525 }
526 }
527
528 pub fn output_types(
530 &self,
531 device: Option<&MetalDevice>,
532 input_types: &[&ShapedType],
533 descriptor: Option<&CompilationDescriptor>,
534 ) -> Result<Vec<ShapedType>> {
535 let input_type_handles = input_types
536 .iter()
537 .map(|value| value.as_ptr())
538 .collect::<Vec<_>>();
539 let device_ptr = device.map_or(ptr::null_mut(), MetalDevice::as_ptr);
540 let descriptor_ptr = descriptor.map_or(ptr::null_mut(), CompilationDescriptor::as_ptr);
541
542 let box_handle = unsafe {
544 ffi::mpsgraph_executable_get_output_types(
545 self.as_ptr(),
546 device_ptr,
547 input_type_handles.as_ptr(),
548 input_types.len(),
549 descriptor_ptr,
550 )
551 };
552 if box_handle.is_null() {
553 Err(Error::OperationFailed("failed to get executable output types"))
554 } else {
555 Ok(collect_shaped_type_array_box(box_handle))
556 }
557 }
558
559 pub fn run_with_descriptor(
561 &self,
562 command_queue: &CommandQueue,
563 inputs: &[&TensorData],
564 results: Option<&[&TensorData]>,
565 descriptor: Option<&ExecutableExecutionDescriptor>,
566 ) -> Result<Vec<TensorData>> {
567 let input_handles = inputs.iter().map(|value| value.as_ptr()).collect::<Vec<_>>();
568 let result_handles = results
569 .map(|values| values.iter().map(|value| value.as_ptr()).collect::<Vec<_>>())
570 .unwrap_or_default();
571 let descriptor_ptr = descriptor.map_or(ptr::null_mut(), ExecutableExecutionDescriptor::as_ptr);
572
573 let box_handle = unsafe {
575 ffi::mpsgraph_executable_run_with_descriptor(
576 self.as_ptr(),
577 command_queue.as_ptr(),
578 input_handles.as_ptr(),
579 inputs.len(),
580 result_handles.as_ptr(),
581 result_handles.len(),
582 descriptor_ptr,
583 )
584 };
585 if box_handle.is_null() {
586 Err(Error::OperationFailed("failed to run executable"))
587 } else {
588 Ok(collect_tensor_data_array_box(box_handle))
589 }
590 }
591
592 pub fn run_async_with_descriptor(
594 &self,
595 command_queue: &CommandQueue,
596 inputs: &[&TensorData],
597 results: Option<&[&TensorData]>,
598 descriptor: Option<&ExecutableExecutionDescriptor>,
599 ) -> Result<Vec<TensorData>> {
600 let input_handles = inputs.iter().map(|value| value.as_ptr()).collect::<Vec<_>>();
601 let result_handles = results
602 .map(|values| values.iter().map(|value| value.as_ptr()).collect::<Vec<_>>())
603 .unwrap_or_default();
604 let descriptor_ptr = descriptor.map_or(ptr::null_mut(), ExecutableExecutionDescriptor::as_ptr);
605
606 let box_handle = unsafe {
608 ffi::mpsgraph_executable_run_async_with_descriptor(
609 self.as_ptr(),
610 command_queue.as_ptr(),
611 input_handles.as_ptr(),
612 inputs.len(),
613 result_handles.as_ptr(),
614 result_handles.len(),
615 descriptor_ptr,
616 )
617 };
618 if box_handle.is_null() {
619 Err(Error::OperationFailed("failed to run executable asynchronously"))
620 } else {
621 Ok(collect_tensor_data_array_box(box_handle))
622 }
623 }
624
625 pub fn serialize_package(
627 &self,
628 path: &str,
629 descriptor: Option<&ExecutableSerializationDescriptor>,
630 ) -> Result<()> {
631 let path = CString::new(path).map_err(|_| Error::OperationFailed("package path contained NUL"))?;
632 let descriptor_ptr = descriptor.map_or(ptr::null_mut(), ExecutableSerializationDescriptor::as_ptr);
633 let ok = unsafe { ffi::mpsgraph_executable_serialize_package(self.as_ptr(), path.as_ptr(), descriptor_ptr) };
635 if ok {
636 Ok(())
637 } else {
638 Err(Error::OperationFailed("failed to serialize executable package"))
639 }
640 }
641
642 pub fn from_package(path: &str, descriptor: Option<&CompilationDescriptor>) -> Result<Self> {
644 let path = CString::new(path).map_err(|_| Error::OperationFailed("package path contained NUL"))?;
645 let descriptor_ptr = descriptor.map_or(ptr::null_mut(), CompilationDescriptor::as_ptr);
646 let ptr = unsafe { ffi::mpsgraph_executable_new_with_package(path.as_ptr(), descriptor_ptr) };
648 if ptr.is_null() {
649 return Err(Error::OperationFailed("failed to load executable package"));
650 }
651 let output_count = {
652 let box_handle = unsafe { ffi::mpsgraph_executable_target_tensors(ptr) };
654 collect_owned_tensors(box_handle).len()
655 };
656 Ok(Self::from_raw(ptr, output_count))
657 }
658}