1use crate::error::{Error, Result};
2use crate::ffi;
3use crate::graph::Tensor;
4use crate::types::{collect_owned_tensors, ShapedType};
5use core::ffi::{c_char, c_void};
6use core::ptr;
7use std::ffi::CString;
8
/// Borrows the NUL-terminated buffer of `value` as a raw C string pointer.
///
/// The returned pointer is only valid while `value` is alive and unmodified.
fn cstring_ptr(value: &CString) -> *const c_char {
    let borrowed = value.as_c_str();
    borrowed.as_ptr()
}
12
/// Converts an optional Rust string into an owned C string.
///
/// Returns `None` both when `name` is `None` and when the string contains an
/// interior NUL byte (the `CString::new` error is discarded).
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let text = name?;
    CString::new(text).ok()
}
16
/// Returns a raw pointer to the optional C string, or null when absent.
///
/// The pointer borrows from `value` and must not outlive it.
// `&Option<CString>` (rather than `Option<&CString>`) is kept so call sites can
// pass `&name` for a locally owned `Option<CString>` directly; hence the allow.
#[allow(clippy::ref_option)]
fn optional_name_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(owned) => owned.as_ptr(),
        None => ptr::null(),
    }
}
21
22fn wrap_tensor_array(box_handle: *mut c_void) -> Option<Vec<Tensor>> {
23 if box_handle.is_null() {
24 None
25 } else {
26 Some(collect_owned_tensors(box_handle))
27 }
28}
29
30impl crate::graph::Graph {
31 pub fn call(
32 &self,
33 symbol_name: &str,
34 input_tensors: &[&Tensor],
35 output_types: &[&ShapedType],
36 name: Option<&str>,
37 ) -> Result<Vec<Tensor>> {
38 let symbol_name =
39 CString::new(symbol_name).map_err(|_| Error::OperationFailed("call symbol name contained NUL"))?;
40 let name = optional_cstring(name);
41 let input_handles = input_tensors.iter().map(|tensor| tensor.as_ptr()).collect::<Vec<_>>();
42 let output_type_handles = output_types
43 .iter()
44 .map(|output_type| output_type.as_ptr())
45 .collect::<Vec<_>>();
46 let input_ptr = if input_handles.is_empty() {
47 ptr::null()
48 } else {
49 input_handles.as_ptr()
50 };
51 let output_type_ptr = if output_type_handles.is_empty() {
52 ptr::null()
53 } else {
54 output_type_handles.as_ptr()
55 };
56 let box_handle = unsafe {
58 ffi::mpsgraph_graph_call_symbol(
59 self.as_ptr(),
60 cstring_ptr(&symbol_name),
61 input_ptr,
62 input_handles.len(),
63 output_type_ptr,
64 output_type_handles.len(),
65 optional_name_ptr(&name),
66 )
67 };
68 wrap_tensor_array(box_handle).ok_or(Error::OperationFailed("failed to create call op"))
69 }
70}