1use crate::ffi;
2use crate::graph::Tensor;
3use crate::types::{collect_owned_tensors, Operation};
4use core::ffi::{c_char, c_void};
5use core::ptr;
6use std::ffi::CString;
7use std::mem;
8use std::panic::{catch_unwind, AssertUnwindSafe};
9
/// Converts an optional operation name into an owned C string for the FFI bridge.
///
/// Yields `None` when `name` is `None`, and also when the name contains an interior NUL
/// byte (which `CString` cannot represent) — in that case the name is silently discarded
/// rather than reported as an error.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let text = name?;
    CString::new(text).ok()
}
13
/// Borrows a C-string pointer out of an optional owned name, or null when there is none.
///
/// The returned pointer is only valid while `value` is alive and unmodified.
// `&Option<CString>` (rather than `Option<&CStr>`) lets call sites pass `&name` directly
// from the `Option<CString>` locals they keep alive across the FFI call.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(owned) => owned.as_ptr(),
        None => ptr::null(),
    }
}
18
19fn tensor_array_box_from_tensors(tensors: &[Tensor]) -> *mut c_void {
20 let handles = tensors.iter().map(Tensor::as_ptr).collect::<Vec<_>>();
21 let handles_ptr = if handles.is_empty() {
22 ptr::null()
23 } else {
24 handles.as_ptr()
25 };
26 unsafe { ffi::mpsgraph_tensor_array_box_new(handles_ptr, handles.len()) }
28}
29
30fn wrap_tensor_array(box_handle: *mut c_void) -> Option<Vec<Tensor>> {
31 if box_handle.is_null() {
32 None
33 } else {
34 Some(collect_owned_tensors(box_handle))
35 }
36}
37
38fn into_owned_tensor_handle(tensor: Tensor) -> *mut c_void {
39 let ptr = tensor.as_ptr();
40 mem::forget(tensor);
41 ptr
42}
43
44pub struct WhileBeforeResult {
46pub predicate: Tensor,
48pub results: Vec<Tensor>,
50}
51
52struct ZeroArgCallbackContext<'a, F> {
53 callback: &'a mut F,
54}
55
56unsafe extern "C" fn zero_arg_tensor_array_trampoline<F>(context: *mut c_void) -> *mut c_void
57where
58 F: FnMut() -> Vec<Tensor>,
59{
60 let context = unsafe { &mut *context.cast::<ZeroArgCallbackContext<'_, F>>() };
62 let tensors = catch_unwind(AssertUnwindSafe(|| (context.callback)()))
63 .unwrap_or_else(|_| std::process::abort());
64 tensor_array_box_from_tensors(&tensors)
65}
66
67struct WhileBeforeCallbackContext<'a, F> {
68 callback: &'a mut F,
69}
70
71unsafe extern "C" fn while_before_trampoline<F>(
72 context: *mut c_void,
73 input_box_handle: *mut c_void,
74 out_result_box_handle: *mut *mut c_void,
75) -> *mut c_void
76where
77 F: FnMut(&[Tensor]) -> WhileBeforeResult,
78{
79 let context = unsafe { &mut *context.cast::<WhileBeforeCallbackContext<'_, F>>() };
81 let inputs = collect_owned_tensors(input_box_handle);
82 match catch_unwind(AssertUnwindSafe(|| (context.callback)(&inputs))) {
83 Ok(result) => {
84 if !out_result_box_handle.is_null() {
85 unsafe { *out_result_box_handle = tensor_array_box_from_tensors(&result.results) };
87 }
88 into_owned_tensor_handle(result.predicate)
89 }
90 Err(_) => std::process::abort(),
91 }
92}
93
94struct TensorArrayInputCallbackContext<'a, F> {
95 callback: &'a mut F,
96}
97
98unsafe extern "C" fn tensor_array_input_trampoline<F>(
99 context: *mut c_void,
100 input_box_handle: *mut c_void,
101) -> *mut c_void
102where
103 F: FnMut(&[Tensor]) -> Vec<Tensor>,
104{
105 let context = unsafe { &mut *context.cast::<TensorArrayInputCallbackContext<'_, F>>() };
107 let inputs = collect_owned_tensors(input_box_handle);
108 let tensors = catch_unwind(AssertUnwindSafe(|| (context.callback)(&inputs)))
109 .unwrap_or_else(|_| std::process::abort());
110 tensor_array_box_from_tensors(&tensors)
111}
112
113struct ForBodyCallbackContext<'a, F> {
114 callback: &'a mut F,
115}
116
117unsafe extern "C" fn for_body_trampoline<F>(
118 context: *mut c_void,
119 index_handle: *mut c_void,
120 input_box_handle: *mut c_void,
121) -> *mut c_void
122where
123 F: FnMut(&Tensor, &[Tensor]) -> Vec<Tensor>,
124{
125 if index_handle.is_null() {
126 return ptr::null_mut();
127 }
128 let context = unsafe { &mut *context.cast::<ForBodyCallbackContext<'_, F>>() };
130 let index = Tensor::from_raw(index_handle);
131 let inputs = collect_owned_tensors(input_box_handle);
132 let tensors = catch_unwind(AssertUnwindSafe(|| (context.callback)(&index, &inputs)))
133 .unwrap_or_else(|_| std::process::abort());
134 tensor_array_box_from_tensors(&tensors)
135}
136
137impl crate::graph::Graph {
138pub fn control_dependency<F>(
140 &self,
141 operations: &[&Operation],
142 mut dependent_block: F,
143 name: Option<&str>,
144 ) -> Option<Vec<Tensor>>
145 where
146 F: FnMut() -> Vec<Tensor>,
147 {
148 let name = optional_cstring(name);
149 let operation_handles = operations
150 .iter()
151 .map(|operation| operation.as_ptr())
152 .collect::<Vec<_>>();
153 let operation_ptr = if operation_handles.is_empty() {
154 ptr::null()
155 } else {
156 operation_handles.as_ptr()
157 };
158 let mut context = ZeroArgCallbackContext {
159 callback: &mut dependent_block,
160 };
161 let box_handle = unsafe {
163 ffi::mpsgraph_graph_control_dependency(
164 self.as_ptr(),
165 operation_ptr,
166 operation_handles.len(),
167 Some(zero_arg_tensor_array_trampoline::<F>),
168 ptr::from_mut(&mut context).cast(),
169 cstring_ptr(&name),
170 )
171 };
172 wrap_tensor_array(box_handle)
173 }
174
175pub fn if_then<Then>(
177 &self,
178 predicate: &Tensor,
179 mut then_block: Then,
180 name: Option<&str>,
181 ) -> Option<Vec<Tensor>>
182 where
183 Then: FnMut() -> Vec<Tensor>,
184 {
185 let name = optional_cstring(name);
186 let mut then_context = ZeroArgCallbackContext {
187 callback: &mut then_block,
188 };
189 let box_handle = unsafe {
191 ffi::mpsgraph_graph_if_then_else(
192 self.as_ptr(),
193 predicate.as_ptr(),
194 Some(zero_arg_tensor_array_trampoline::<Then>),
195 ptr::from_mut(&mut then_context).cast(),
196 None,
197 ptr::null_mut(),
198 cstring_ptr(&name),
199 )
200 };
201 wrap_tensor_array(box_handle)
202 }
203
204pub fn if_then_else<Then, Else>(
206 &self,
207 predicate: &Tensor,
208 mut then_block: Then,
209 mut else_block: Else,
210 name: Option<&str>,
211 ) -> Option<Vec<Tensor>>
212 where
213 Then: FnMut() -> Vec<Tensor>,
214 Else: FnMut() -> Vec<Tensor>,
215 {
216 let name = optional_cstring(name);
217 let mut then_context = ZeroArgCallbackContext {
218 callback: &mut then_block,
219 };
220 let mut else_context = ZeroArgCallbackContext {
221 callback: &mut else_block,
222 };
223 let box_handle = unsafe {
225 ffi::mpsgraph_graph_if_then_else(
226 self.as_ptr(),
227 predicate.as_ptr(),
228 Some(zero_arg_tensor_array_trampoline::<Then>),
229 ptr::from_mut(&mut then_context).cast(),
230 Some(zero_arg_tensor_array_trampoline::<Else>),
231 ptr::from_mut(&mut else_context).cast(),
232 cstring_ptr(&name),
233 )
234 };
235 wrap_tensor_array(box_handle)
236 }
237
238pub fn while_loop<Before, After>(
240 &self,
241 initial_inputs: &[&Tensor],
242 mut before: Before,
243 mut after: After,
244 name: Option<&str>,
245 ) -> Option<Vec<Tensor>>
246 where
247 Before: FnMut(&[Tensor]) -> WhileBeforeResult,
248 After: FnMut(&[Tensor]) -> Vec<Tensor>,
249 {
250 let name = optional_cstring(name);
251 let input_handles = initial_inputs
252 .iter()
253 .map(|tensor| tensor.as_ptr())
254 .collect::<Vec<_>>();
255 let input_ptr = if input_handles.is_empty() {
256 ptr::null()
257 } else {
258 input_handles.as_ptr()
259 };
260 let mut before_context = WhileBeforeCallbackContext {
261 callback: &mut before,
262 };
263 let mut after_context = TensorArrayInputCallbackContext {
264 callback: &mut after,
265 };
266 let box_handle = unsafe {
268 ffi::mpsgraph_graph_while_loop(
269 self.as_ptr(),
270 input_ptr,
271 input_handles.len(),
272 Some(while_before_trampoline::<Before>),
273 ptr::from_mut(&mut before_context).cast(),
274 Some(tensor_array_input_trampoline::<After>),
275 ptr::from_mut(&mut after_context).cast(),
276 cstring_ptr(&name),
277 )
278 };
279 wrap_tensor_array(box_handle)
280 }
281
282#[allow(clippy::too_many_arguments)]
284 pub fn for_loop<Body>(
285 &self,
286 lower_bound: &Tensor,
287 upper_bound: &Tensor,
288 step: &Tensor,
289 initial_body_arguments: &[&Tensor],
290 mut body: Body,
291 name: Option<&str>,
292 ) -> Option<Vec<Tensor>>
293 where
294 Body: FnMut(&Tensor, &[Tensor]) -> Vec<Tensor>,
295 {
296 let name = optional_cstring(name);
297 let argument_handles = initial_body_arguments
298 .iter()
299 .map(|tensor| tensor.as_ptr())
300 .collect::<Vec<_>>();
301 let argument_ptr = if argument_handles.is_empty() {
302 ptr::null()
303 } else {
304 argument_handles.as_ptr()
305 };
306 let mut context = ForBodyCallbackContext {
307 callback: &mut body,
308 };
309 let box_handle = unsafe {
311 ffi::mpsgraph_graph_for_loop(
312 self.as_ptr(),
313 lower_bound.as_ptr(),
314 upper_bound.as_ptr(),
315 step.as_ptr(),
316 argument_ptr,
317 argument_handles.len(),
318 Some(for_body_trampoline::<Body>),
319 ptr::from_mut(&mut context).cast(),
320 cstring_ptr(&name),
321 )
322 };
323 wrap_tensor_array(box_handle)
324 }
325
326pub fn for_loop_iterations<Body>(
328 &self,
329 number_of_iterations: &Tensor,
330 initial_body_arguments: &[&Tensor],
331 mut body: Body,
332 name: Option<&str>,
333 ) -> Option<Vec<Tensor>>
334 where
335 Body: FnMut(&Tensor, &[Tensor]) -> Vec<Tensor>,
336 {
337 let name = optional_cstring(name);
338 let argument_handles = initial_body_arguments
339 .iter()
340 .map(|tensor| tensor.as_ptr())
341 .collect::<Vec<_>>();
342 let argument_ptr = if argument_handles.is_empty() {
343 ptr::null()
344 } else {
345 argument_handles.as_ptr()
346 };
347 let mut context = ForBodyCallbackContext {
348 callback: &mut body,
349 };
350 let box_handle = unsafe {
352 ffi::mpsgraph_graph_for_loop_iterations(
353 self.as_ptr(),
354 number_of_iterations.as_ptr(),
355 argument_ptr,
356 argument_handles.len(),
357 Some(for_body_trampoline::<Body>),
358 ptr::from_mut(&mut context).cast(),
359 cstring_ptr(&name),
360 )
361 };
362 wrap_tensor_array(box_handle)
363 }
364}