1use crate::error::{Error, Result};
2use crate::ffi;
3use crate::graph::Tensor;
4use crate::types::{collect_owned_tensors, ShapedType};
5use core::ffi::{c_char, c_void};
6use core::ptr;
7use std::ffi::CString;
8
/// Borrows the NUL-terminated buffer of `value` as a raw C string pointer.
///
/// The pointer is only valid while `value` is alive and unmodified.
fn cstring_ptr(value: &CString) -> *const c_char {
    let borrowed = value.as_c_str();
    borrowed.as_ptr()
}
12
/// Converts an optional name into an owned C string.
///
/// Returns `None` when `name` is `None`, and also when the name contains an
/// interior NUL byte (the conversion error is discarded).
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let text = name?;
    CString::new(text).ok()
}
16
/// Returns the C string pointer of an optional name, or null when there is no name.
///
/// The pointer borrows from `value` and is only valid while it is alive.
// Takes `&Option<CString>` (rather than the `Option<&CString>` clippy prefers) so call
// sites can keep the owned `Option<CString>` in a local and simply pass `&name`.
#[allow(clippy::ref_option)]
fn optional_name_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(owned) => owned.as_ptr(),
        None => ptr::null(),
    }
}
21
22fn wrap_tensor_array(box_handle: *mut c_void) -> Option<Vec<Tensor>> {
23 if box_handle.is_null() {
24 None
25 } else {
26 Some(collect_owned_tensors(box_handle))
27 }
28}
29
30impl crate::graph::Graph {
31pub fn call(
33 &self,
34 symbol_name: &str,
35 input_tensors: &[&Tensor],
36 output_types: &[&ShapedType],
37 name: Option<&str>,
38 ) -> Result<Vec<Tensor>> {
39 let symbol_name = CString::new(symbol_name)
40 .map_err(|_| Error::OperationFailed("call symbol name contained NUL"))?;
41 let name = optional_cstring(name);
42 let input_handles = input_tensors
43 .iter()
44 .map(|tensor| tensor.as_ptr())
45 .collect::<Vec<_>>();
46 let output_type_handles = output_types
47 .iter()
48 .map(|output_type| output_type.as_ptr())
49 .collect::<Vec<_>>();
50 let input_ptr = if input_handles.is_empty() {
51 ptr::null()
52 } else {
53 input_handles.as_ptr()
54 };
55 let output_type_ptr = if output_type_handles.is_empty() {
56 ptr::null()
57 } else {
58 output_type_handles.as_ptr()
59 };
60 let box_handle = unsafe {
62 ffi::mpsgraph_graph_call_symbol(
63 self.as_ptr(),
64 cstring_ptr(&symbol_name),
65 input_ptr,
66 input_handles.len(),
67 output_type_ptr,
68 output_type_handles.len(),
69 optional_name_ptr(&name),
70 )
71 };
72 wrap_tensor_array(box_handle).ok_or(Error::OperationFailed("failed to create call op"))
73 }
74}