Skip to main content

apple_mpsgraph/
call.rs

1use crate::error::{Error, Result};
2use crate::ffi;
3use crate::graph::Tensor;
4use crate::types::{collect_owned_tensors, ShapedType};
5use core::ffi::{c_char, c_void};
6use core::ptr;
7use std::ffi::CString;
8
/// Exposes the NUL-terminated buffer owned by `value` as a raw C string pointer.
///
/// The returned pointer borrows from `value`: it is only valid while `value`
/// is alive and has not been mutated or dropped.
fn cstring_ptr(value: &CString) -> *const c_char {
    let borrowed = value.as_c_str();
    borrowed.as_ptr()
}
12
/// Converts an optional Rust string into an owned, NUL-terminated C string.
///
/// Returns `None` when `name` is `None`. It also returns `None` when `name`
/// contains an interior NUL byte: the conversion error is discarded rather
/// than reported.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let text = name?;
    CString::new(text).ok()
}
16
/// Returns a C string pointer for an optional name, or null when there is none.
///
/// A non-null result borrows from the `CString` inside `value`. It is only
/// valid while `value` is alive and unmodified.
// `&Option<CString>` is kept (rather than `Option<&CString>`, which the lint
// prefers) so call sites can pass `&name` for their owned `Option<CString>`.
#[allow(clippy::ref_option)]
fn optional_name_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(c_name) => c_name.as_ptr(),
        None => ptr::null(),
    }
}
21
22fn wrap_tensor_array(box_handle: *mut c_void) -> Option<Vec<Tensor>> {
23    if box_handle.is_null() {
24        None
25    } else {
26        Some(collect_owned_tensors(box_handle))
27    }
28}
29
30impl crate::graph::Graph {
31/// Calls the `MPSGraph` framework counterpart for `call`.
32    pub fn call(
33        &self,
34        symbol_name: &str,
35        input_tensors: &[&Tensor],
36        output_types: &[&ShapedType],
37        name: Option<&str>,
38    ) -> Result<Vec<Tensor>> {
39        let symbol_name = CString::new(symbol_name)
40            .map_err(|_| Error::OperationFailed("call symbol name contained NUL"))?;
41        let name = optional_cstring(name);
42        let input_handles = input_tensors
43            .iter()
44            .map(|tensor| tensor.as_ptr())
45            .collect::<Vec<_>>();
46        let output_type_handles = output_types
47            .iter()
48            .map(|output_type| output_type.as_ptr())
49            .collect::<Vec<_>>();
50        let input_ptr = if input_handles.is_empty() {
51            ptr::null()
52        } else {
53            input_handles.as_ptr()
54        };
55        let output_type_ptr = if output_type_handles.is_empty() {
56            ptr::null()
57        } else {
58            output_type_handles.as_ptr()
59        };
60        // SAFETY: all handles remain valid for the duration of the call.
61        let box_handle = unsafe {
62            ffi::mpsgraph_graph_call_symbol(
63                self.as_ptr(),
64                cstring_ptr(&symbol_name),
65                input_ptr,
66                input_handles.len(),
67                output_type_ptr,
68                output_type_handles.len(),
69                optional_name_ptr(&name),
70            )
71        };
72        wrap_tensor_array(box_handle).ok_or(Error::OperationFailed("failed to create call op"))
73    }
74}