1use crate::ffi;
2use crate::graph::Tensor;
3use crate::types::{collect_owned_tensors, Operation};
4use core::ffi::{c_char, c_void};
5use core::ptr;
6use std::ffi::CString;
7use std::mem;
8use std::panic::{catch_unwind, AssertUnwindSafe};
9
/// Converts an optional Rust name into an owned, NUL-terminated C string for the bridge.
///
/// Returns `None` only when `name` is `None`. A C string ends at its first NUL byte, so a
/// name containing an interior `'\0'` is truncated at that byte. Previously the whole name
/// was silently discarded in that case.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    name.map(|value| {
        let visible = value.find('\0').map_or(value, |end| &value[..end]);
        CString::new(visible).expect("prefix before the first NUL byte contains no NUL byte")
    })
}
13
/// Borrows the raw C pointer of an optional `CString`, or null when it is absent.
///
/// The pointer is only valid while `value` stays alive and unmodified, so callers keep the
/// owning `Option<CString>` in a local for the whole FFI call.
// Takes `&Option<CString>` so call sites can pass `&name` exactly as produced by
// `optional_cstring`; that signature is why clippy's `ref_option` lint is silenced.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(text) => text.as_ptr(),
        None => ptr::null(),
    }
}
18
19fn tensor_array_box_from_tensors(tensors: &[Tensor]) -> *mut c_void {
20 let handles = tensors.iter().map(Tensor::as_ptr).collect::<Vec<_>>();
21 let handles_ptr = if handles.is_empty() {
22 ptr::null()
23 } else {
24 handles.as_ptr()
25 };
26 unsafe { ffi::mpsgraph_tensor_array_box_new(handles_ptr, handles.len()) }
28}
29
30fn wrap_tensor_array(box_handle: *mut c_void) -> Option<Vec<Tensor>> {
31 if box_handle.is_null() {
32 None
33 } else {
34 Some(collect_owned_tensors(box_handle))
35 }
36}
37
38fn into_owned_tensor_handle(tensor: Tensor) -> *mut c_void {
39 let ptr = tensor.as_ptr();
40 mem::forget(tensor);
41 ptr
42}
43
/// Value returned by the `before` closure passed to `Graph::while_loop`.
pub struct WhileBeforeResult {
    /// Loop-condition tensor; handed back to the bridge as the `before` block's return value.
    pub predicate: Tensor,
    /// Tensors written to the bridge's result-box out-parameter. Presumably these feed the
    /// `after` block, or become the loop output once the predicate is false — confirm
    /// against the bridge.
    pub results: Vec<Tensor>,
}
48
/// State passed, as an opaque `*mut c_void`, to `zero_arg_tensor_array_trampoline`.
struct ZeroArgCallbackContext<'a, F> {
    /// The caller's `FnMut` closure, borrowed mutably from the calling `Graph` method's stack.
    callback: &'a mut F,
}
52
53unsafe extern "C" fn zero_arg_tensor_array_trampoline<F>(context: *mut c_void) -> *mut c_void
54where
55 F: FnMut() -> Vec<Tensor>,
56{
57 let context = unsafe { &mut *context.cast::<ZeroArgCallbackContext<'_, F>>() };
58 let tensors = catch_unwind(AssertUnwindSafe(|| (context.callback)()))
59 .unwrap_or_else(|_| std::process::abort());
60 tensor_array_box_from_tensors(&tensors)
61}
62
/// State passed, as an opaque `*mut c_void`, to `while_before_trampoline`.
struct WhileBeforeCallbackContext<'a, F> {
    /// The caller's `before` closure, borrowed mutably from `Graph::while_loop`'s stack.
    callback: &'a mut F,
}
66
67unsafe extern "C" fn while_before_trampoline<F>(
68 context: *mut c_void,
69 input_box_handle: *mut c_void,
70 out_result_box_handle: *mut *mut c_void,
71) -> *mut c_void
72where
73 F: FnMut(&[Tensor]) -> WhileBeforeResult,
74{
75 let context = unsafe { &mut *context.cast::<WhileBeforeCallbackContext<'_, F>>() };
76 let inputs = collect_owned_tensors(input_box_handle);
77 match catch_unwind(AssertUnwindSafe(|| (context.callback)(&inputs))) {
78 Ok(result) => {
79 if !out_result_box_handle.is_null() {
80 unsafe { *out_result_box_handle = tensor_array_box_from_tensors(&result.results) };
82 }
83 into_owned_tensor_handle(result.predicate)
84 }
85 Err(_) => std::process::abort(),
86 }
87}
88
/// State passed, as an opaque `*mut c_void`, to `tensor_array_input_trampoline`.
struct TensorArrayInputCallbackContext<'a, F> {
    /// The caller's closure (the `after` block of `Graph::while_loop`), borrowed mutably
    /// from the calling method's stack.
    callback: &'a mut F,
}
92
93unsafe extern "C" fn tensor_array_input_trampoline<F>(
94 context: *mut c_void,
95 input_box_handle: *mut c_void,
96) -> *mut c_void
97where
98 F: FnMut(&[Tensor]) -> Vec<Tensor>,
99{
100 let context = unsafe { &mut *context.cast::<TensorArrayInputCallbackContext<'_, F>>() };
101 let inputs = collect_owned_tensors(input_box_handle);
102 let tensors = catch_unwind(AssertUnwindSafe(|| (context.callback)(&inputs)))
103 .unwrap_or_else(|_| std::process::abort());
104 tensor_array_box_from_tensors(&tensors)
105}
106
/// State passed, as an opaque `*mut c_void`, to `for_body_trampoline`.
struct ForBodyCallbackContext<'a, F> {
    /// The caller's loop-body closure, borrowed mutably from the calling `Graph` method's stack.
    callback: &'a mut F,
}
110
111unsafe extern "C" fn for_body_trampoline<F>(
112 context: *mut c_void,
113 index_handle: *mut c_void,
114 input_box_handle: *mut c_void,
115) -> *mut c_void
116where
117 F: FnMut(&Tensor, &[Tensor]) -> Vec<Tensor>,
118{
119 if index_handle.is_null() {
120 return ptr::null_mut();
121 }
122 let context = unsafe { &mut *context.cast::<ForBodyCallbackContext<'_, F>>() };
123 let index = Tensor::from_raw(index_handle);
124 let inputs = collect_owned_tensors(input_box_handle);
125 let tensors = catch_unwind(AssertUnwindSafe(|| (context.callback)(&index, &inputs)))
126 .unwrap_or_else(|_| std::process::abort());
127 tensor_array_box_from_tensors(&tensors)
128}
129
130impl crate::graph::Graph {
131 pub fn control_dependency<F>(
132 &self,
133 operations: &[&Operation],
134 mut dependent_block: F,
135 name: Option<&str>,
136 ) -> Option<Vec<Tensor>>
137 where
138 F: FnMut() -> Vec<Tensor>,
139 {
140 let name = optional_cstring(name);
141 let operation_handles = operations.iter().map(|operation| operation.as_ptr()).collect::<Vec<_>>();
142 let operation_ptr = if operation_handles.is_empty() {
143 ptr::null()
144 } else {
145 operation_handles.as_ptr()
146 };
147 let mut context = ZeroArgCallbackContext {
148 callback: &mut dependent_block,
149 };
150 let box_handle = unsafe {
152 ffi::mpsgraph_graph_control_dependency(
153 self.as_ptr(),
154 operation_ptr,
155 operation_handles.len(),
156 Some(zero_arg_tensor_array_trampoline::<F>),
157 ptr::from_mut(&mut context).cast(),
158 cstring_ptr(&name),
159 )
160 };
161 wrap_tensor_array(box_handle)
162 }
163
164 pub fn if_then<Then>(
165 &self,
166 predicate: &Tensor,
167 mut then_block: Then,
168 name: Option<&str>,
169 ) -> Option<Vec<Tensor>>
170 where
171 Then: FnMut() -> Vec<Tensor>,
172 {
173 let name = optional_cstring(name);
174 let mut then_context = ZeroArgCallbackContext {
175 callback: &mut then_block,
176 };
177 let box_handle = unsafe {
179 ffi::mpsgraph_graph_if_then_else(
180 self.as_ptr(),
181 predicate.as_ptr(),
182 Some(zero_arg_tensor_array_trampoline::<Then>),
183 ptr::from_mut(&mut then_context).cast(),
184 None,
185 ptr::null_mut(),
186 cstring_ptr(&name),
187 )
188 };
189 wrap_tensor_array(box_handle)
190 }
191
192 pub fn if_then_else<Then, Else>(
193 &self,
194 predicate: &Tensor,
195 mut then_block: Then,
196 mut else_block: Else,
197 name: Option<&str>,
198 ) -> Option<Vec<Tensor>>
199 where
200 Then: FnMut() -> Vec<Tensor>,
201 Else: FnMut() -> Vec<Tensor>,
202 {
203 let name = optional_cstring(name);
204 let mut then_context = ZeroArgCallbackContext {
205 callback: &mut then_block,
206 };
207 let mut else_context = ZeroArgCallbackContext {
208 callback: &mut else_block,
209 };
210 let box_handle = unsafe {
212 ffi::mpsgraph_graph_if_then_else(
213 self.as_ptr(),
214 predicate.as_ptr(),
215 Some(zero_arg_tensor_array_trampoline::<Then>),
216 ptr::from_mut(&mut then_context).cast(),
217 Some(zero_arg_tensor_array_trampoline::<Else>),
218 ptr::from_mut(&mut else_context).cast(),
219 cstring_ptr(&name),
220 )
221 };
222 wrap_tensor_array(box_handle)
223 }
224
225 pub fn while_loop<Before, After>(
226 &self,
227 initial_inputs: &[&Tensor],
228 mut before: Before,
229 mut after: After,
230 name: Option<&str>,
231 ) -> Option<Vec<Tensor>>
232 where
233 Before: FnMut(&[Tensor]) -> WhileBeforeResult,
234 After: FnMut(&[Tensor]) -> Vec<Tensor>,
235 {
236 let name = optional_cstring(name);
237 let input_handles = initial_inputs.iter().map(|tensor| tensor.as_ptr()).collect::<Vec<_>>();
238 let input_ptr = if input_handles.is_empty() {
239 ptr::null()
240 } else {
241 input_handles.as_ptr()
242 };
243 let mut before_context = WhileBeforeCallbackContext {
244 callback: &mut before,
245 };
246 let mut after_context = TensorArrayInputCallbackContext {
247 callback: &mut after,
248 };
249 let box_handle = unsafe {
251 ffi::mpsgraph_graph_while_loop(
252 self.as_ptr(),
253 input_ptr,
254 input_handles.len(),
255 Some(while_before_trampoline::<Before>),
256 ptr::from_mut(&mut before_context).cast(),
257 Some(tensor_array_input_trampoline::<After>),
258 ptr::from_mut(&mut after_context).cast(),
259 cstring_ptr(&name),
260 )
261 };
262 wrap_tensor_array(box_handle)
263 }
264
265 #[allow(clippy::too_many_arguments)]
266 pub fn for_loop<Body>(
267 &self,
268 lower_bound: &Tensor,
269 upper_bound: &Tensor,
270 step: &Tensor,
271 initial_body_arguments: &[&Tensor],
272 mut body: Body,
273 name: Option<&str>,
274 ) -> Option<Vec<Tensor>>
275 where
276 Body: FnMut(&Tensor, &[Tensor]) -> Vec<Tensor>,
277 {
278 let name = optional_cstring(name);
279 let argument_handles = initial_body_arguments
280 .iter()
281 .map(|tensor| tensor.as_ptr())
282 .collect::<Vec<_>>();
283 let argument_ptr = if argument_handles.is_empty() {
284 ptr::null()
285 } else {
286 argument_handles.as_ptr()
287 };
288 let mut context = ForBodyCallbackContext {
289 callback: &mut body,
290 };
291 let box_handle = unsafe {
293 ffi::mpsgraph_graph_for_loop(
294 self.as_ptr(),
295 lower_bound.as_ptr(),
296 upper_bound.as_ptr(),
297 step.as_ptr(),
298 argument_ptr,
299 argument_handles.len(),
300 Some(for_body_trampoline::<Body>),
301 ptr::from_mut(&mut context).cast(),
302 cstring_ptr(&name),
303 )
304 };
305 wrap_tensor_array(box_handle)
306 }
307
308 pub fn for_loop_iterations<Body>(
309 &self,
310 number_of_iterations: &Tensor,
311 initial_body_arguments: &[&Tensor],
312 mut body: Body,
313 name: Option<&str>,
314 ) -> Option<Vec<Tensor>>
315 where
316 Body: FnMut(&Tensor, &[Tensor]) -> Vec<Tensor>,
317 {
318 let name = optional_cstring(name);
319 let argument_handles = initial_body_arguments
320 .iter()
321 .map(|tensor| tensor.as_ptr())
322 .collect::<Vec<_>>();
323 let argument_ptr = if argument_handles.is_empty() {
324 ptr::null()
325 } else {
326 argument_handles.as_ptr()
327 };
328 let mut context = ForBodyCallbackContext {
329 callback: &mut body,
330 };
331 let box_handle = unsafe {
333 ffi::mpsgraph_graph_for_loop_iterations(
334 self.as_ptr(),
335 number_of_iterations.as_ptr(),
336 argument_ptr,
337 argument_handles.len(),
338 Some(for_body_trampoline::<Body>),
339 ptr::from_mut(&mut context).cast(),
340 cstring_ptr(&name),
341 )
342 };
343 wrap_tensor_array(box_handle)
344 }
345}