1use crate::error::{Error, Result};
2use crate::ffi;
3use crate::graph::Tensor;
4use crate::types::{collect_owned_tensors, ShapedType};
5use core::ffi::{c_char, c_void};
6use core::ptr;
7use std::ffi::CString;
8
/// Borrows the NUL-terminated byte buffer of `value` as a raw C string pointer.
///
/// The returned pointer is only valid while `value` is alive and unmodified.
fn cstring_ptr(value: &CString) -> *const c_char {
    value.as_c_str().as_ptr()
}
12
/// Converts an optional Rust string into an optional owned C string.
///
/// Returns `None` when `name` is `None`. It also returns `None` when the string
/// contains an interior NUL byte, because such a string cannot be represented as
/// a C string. That conversion error is discarded rather than reported.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let text = name?;
    CString::new(text).ok()
}
16
/// Returns a raw pointer to the optional C string, or null when it is absent.
///
/// The returned pointer borrows from `value` and is only valid while it lives.
// The parameter is `&Option<CString>` (rather than `Option<&CStr>`) so call sites
// can pass `&name` directly; clippy's `ref_option` lint is silenced for that reason.
#[allow(clippy::ref_option)]
fn optional_name_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(c_name) => c_name.as_ptr(),
        None => ptr::null(),
    }
}
21
22fn wrap_tensor_array(box_handle: *mut c_void) -> Option<Vec<Tensor>> {
23 if box_handle.is_null() {
24 None
25 } else {
26 Some(collect_owned_tensors(box_handle))
27 }
28}
29
30impl crate::graph::Graph {
31 pub fn call(
32 &self,
33 symbol_name: &str,
34 input_tensors: &[&Tensor],
35 output_types: &[&ShapedType],
36 name: Option<&str>,
37 ) -> Result<Vec<Tensor>> {
38 let symbol_name = CString::new(symbol_name)
39 .map_err(|_| Error::OperationFailed("call symbol name contained NUL"))?;
40 let name = optional_cstring(name);
41 let input_handles = input_tensors
42 .iter()
43 .map(|tensor| tensor.as_ptr())
44 .collect::<Vec<_>>();
45 let output_type_handles = output_types
46 .iter()
47 .map(|output_type| output_type.as_ptr())
48 .collect::<Vec<_>>();
49 let input_ptr = if input_handles.is_empty() {
50 ptr::null()
51 } else {
52 input_handles.as_ptr()
53 };
54 let output_type_ptr = if output_type_handles.is_empty() {
55 ptr::null()
56 } else {
57 output_type_handles.as_ptr()
58 };
59 let box_handle = unsafe {
61 ffi::mpsgraph_graph_call_symbol(
62 self.as_ptr(),
63 cstring_ptr(&symbol_name),
64 input_ptr,
65 input_handles.len(),
66 output_type_ptr,
67 output_type_handles.len(),
68 optional_name_ptr(&name),
69 )
70 };
71 wrap_tensor_array(box_handle).ok_or(Error::OperationFailed("failed to create call op"))
72 }
73}