Skip to main content

apple_mpsgraph/
call.rs

1use crate::error::{Error, Result};
2use crate::ffi;
3use crate::graph::Tensor;
4use crate::types::{collect_owned_tensors, ShapedType};
5use core::ffi::{c_char, c_void};
6use core::ptr;
7use std::ffi::CString;
8
/// Returns the raw, NUL-terminated C pointer backing `value`.
///
/// The pointer borrows from `value`; it must not be used after `value` is dropped.
fn cstring_ptr(value: &CString) -> *const c_char {
    value.as_c_str().as_ptr()
}
12
/// Converts an optional Rust string into an optional owned C string.
///
/// Returns `None` when `name` is `None`, and also when the text contains an
/// interior NUL byte (the conversion error is discarded rather than reported).
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    match name {
        Some(text) => CString::new(text).ok(),
        None => None,
    }
}
16
/// Returns the C pointer for an optional name, or a null pointer when absent.
///
/// The returned pointer borrows from the `CString` inside `value` and is only
/// valid while that `CString` is alive.
// Takes `&Option<CString>` (hence the lint allowance) so callers can pass a
// reference to the `Option<CString>` local they already own.
#[allow(clippy::ref_option)]
fn optional_name_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(name) => name.as_ptr(),
        None => ptr::null(),
    }
}
21
22fn wrap_tensor_array(box_handle: *mut c_void) -> Option<Vec<Tensor>> {
23    if box_handle.is_null() {
24        None
25    } else {
26        Some(collect_owned_tensors(box_handle))
27    }
28}
29
30impl crate::graph::Graph {
31    pub fn call(
32        &self,
33        symbol_name: &str,
34        input_tensors: &[&Tensor],
35        output_types: &[&ShapedType],
36        name: Option<&str>,
37    ) -> Result<Vec<Tensor>> {
38        let symbol_name =
39            CString::new(symbol_name).map_err(|_| Error::OperationFailed("call symbol name contained NUL"))?;
40        let name = optional_cstring(name);
41        let input_handles = input_tensors.iter().map(|tensor| tensor.as_ptr()).collect::<Vec<_>>();
42        let output_type_handles = output_types
43            .iter()
44            .map(|output_type| output_type.as_ptr())
45            .collect::<Vec<_>>();
46        let input_ptr = if input_handles.is_empty() {
47            ptr::null()
48        } else {
49            input_handles.as_ptr()
50        };
51        let output_type_ptr = if output_type_handles.is_empty() {
52            ptr::null()
53        } else {
54            output_type_handles.as_ptr()
55        };
56        // SAFETY: all handles remain valid for the duration of the call.
57        let box_handle = unsafe {
58            ffi::mpsgraph_graph_call_symbol(
59                self.as_ptr(),
60                cstring_ptr(&symbol_name),
61                input_ptr,
62                input_handles.len(),
63                output_type_ptr,
64                output_type_handles.len(),
65                optional_name_ptr(&name),
66            )
67        };
68        wrap_tensor_array(box_handle).ok_or(Error::OperationFailed("failed to create call op"))
69    }
70}