1use crate::ffi;
2use crate::graph::Tensor;
3use crate::types::{collect_owned_tensors, Operation};
4use core::ffi::{c_char, c_void};
5use core::ptr;
6use std::ffi::CString;
7use std::mem;
8use std::panic::{catch_unwind, AssertUnwindSafe};
9
/// Converts an optional Rust name into an owned C string for the FFI layer.
///
/// Returns `None` when no name is given, and also when the name contains an interior NUL
/// byte (such a name cannot be represented as a C string and is silently dropped).
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    name.map(CString::new).and_then(Result::ok)
}
13
// Takes `&Option<CString>` (rather than `Option<&CString>`) so call sites can pass the owned
// option they keep alive across the FFI call directly as `cstring_ptr(&name)`.
#[allow(clippy::ref_option)]
/// Borrows the raw C pointer of an optional C string, or null when absent.
///
/// The returned pointer is only valid while `value` is alive and unmodified.
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(c_string) => c_string.as_ptr(),
        None => ptr::null(),
    }
}
18
19fn tensor_array_box_from_tensors(tensors: &[Tensor]) -> *mut c_void {
20 let handles = tensors.iter().map(Tensor::as_ptr).collect::<Vec<_>>();
21 let handles_ptr = if handles.is_empty() {
22 ptr::null()
23 } else {
24 handles.as_ptr()
25 };
26 unsafe { ffi::mpsgraph_tensor_array_box_new(handles_ptr, handles.len()) }
28}
29
30fn wrap_tensor_array(box_handle: *mut c_void) -> Option<Vec<Tensor>> {
31 if box_handle.is_null() {
32 None
33 } else {
34 Some(collect_owned_tensors(box_handle))
35 }
36}
37
38fn into_owned_tensor_handle(tensor: Tensor) -> *mut c_void {
39 let ptr = tensor.as_ptr();
40 mem::forget(tensor);
41 ptr
42}
43
/// Value returned by the `before` callback of `Graph::while_loop`.
///
/// The `before` trampoline hands `predicate` back to the FFI as its return value (ownership
/// transferred) and packs `results` into the caller-provided output box when one is supplied.
pub struct WhileBeforeResult {
    /// Tensor returned to the FFI as the loop predicate (presumably decides whether the loop
    /// body runs again — confirm against the bridge).
    pub predicate: Tensor,
    /// Tensors written into the FFI's output box, if the bridge provided an output slot.
    pub results: Vec<Tensor>,
}
48
/// Context passed through the opaque FFI `context` pointer to
/// `zero_arg_tensor_array_trampoline`.
///
/// The `Graph` methods keep it as a local and pass its address for one FFI call.
struct ZeroArgCallbackContext<'a, F> {
    /// Caller's callback, invoked with no arguments each time the trampoline runs.
    callback: &'a mut F,
}
52
53unsafe extern "C" fn zero_arg_tensor_array_trampoline<F>(context: *mut c_void) -> *mut c_void
54where
55 F: FnMut() -> Vec<Tensor>,
56{
57 let context = unsafe { &mut *context.cast::<ZeroArgCallbackContext<'_, F>>() };
59 let tensors = catch_unwind(AssertUnwindSafe(|| (context.callback)()))
60 .unwrap_or_else(|_| std::process::abort());
61 tensor_array_box_from_tensors(&tensors)
62}
63
/// Context passed through the opaque FFI `context` pointer to `while_before_trampoline`.
struct WhileBeforeCallbackContext<'a, F> {
    /// Caller's `before` callback, invoked with the current loop inputs each time the
    /// trampoline runs.
    callback: &'a mut F,
}
67
68unsafe extern "C" fn while_before_trampoline<F>(
69 context: *mut c_void,
70 input_box_handle: *mut c_void,
71 out_result_box_handle: *mut *mut c_void,
72) -> *mut c_void
73where
74 F: FnMut(&[Tensor]) -> WhileBeforeResult,
75{
76 let context = unsafe { &mut *context.cast::<WhileBeforeCallbackContext<'_, F>>() };
78 let inputs = collect_owned_tensors(input_box_handle);
79 match catch_unwind(AssertUnwindSafe(|| (context.callback)(&inputs))) {
80 Ok(result) => {
81 if !out_result_box_handle.is_null() {
82 unsafe { *out_result_box_handle = tensor_array_box_from_tensors(&result.results) };
84 }
85 into_owned_tensor_handle(result.predicate)
86 }
87 Err(_) => std::process::abort(),
88 }
89}
90
/// Context passed through the opaque FFI `context` pointer to
/// `tensor_array_input_trampoline`.
struct TensorArrayInputCallbackContext<'a, F> {
    /// Caller's callback, invoked with the input tensors each time the trampoline runs.
    callback: &'a mut F,
}
94
95unsafe extern "C" fn tensor_array_input_trampoline<F>(
96 context: *mut c_void,
97 input_box_handle: *mut c_void,
98) -> *mut c_void
99where
100 F: FnMut(&[Tensor]) -> Vec<Tensor>,
101{
102 let context = unsafe { &mut *context.cast::<TensorArrayInputCallbackContext<'_, F>>() };
104 let inputs = collect_owned_tensors(input_box_handle);
105 let tensors = catch_unwind(AssertUnwindSafe(|| (context.callback)(&inputs)))
106 .unwrap_or_else(|_| std::process::abort());
107 tensor_array_box_from_tensors(&tensors)
108}
109
/// Context passed through the opaque FFI `context` pointer to `for_body_trampoline`.
struct ForBodyCallbackContext<'a, F> {
    /// Caller's loop-body callback, invoked with the index tensor and the carried arguments
    /// each time the trampoline runs.
    callback: &'a mut F,
}
113
114unsafe extern "C" fn for_body_trampoline<F>(
115 context: *mut c_void,
116 index_handle: *mut c_void,
117 input_box_handle: *mut c_void,
118) -> *mut c_void
119where
120 F: FnMut(&Tensor, &[Tensor]) -> Vec<Tensor>,
121{
122 if index_handle.is_null() {
123 return ptr::null_mut();
124 }
125 let context = unsafe { &mut *context.cast::<ForBodyCallbackContext<'_, F>>() };
127 let index = Tensor::from_raw(index_handle);
128 let inputs = collect_owned_tensors(input_box_handle);
129 let tensors = catch_unwind(AssertUnwindSafe(|| (context.callback)(&index, &inputs)))
130 .unwrap_or_else(|_| std::process::abort());
131 tensor_array_box_from_tensors(&tensors)
132}
133
134impl crate::graph::Graph {
135 pub fn control_dependency<F>(
136 &self,
137 operations: &[&Operation],
138 mut dependent_block: F,
139 name: Option<&str>,
140 ) -> Option<Vec<Tensor>>
141 where
142 F: FnMut() -> Vec<Tensor>,
143 {
144 let name = optional_cstring(name);
145 let operation_handles = operations
146 .iter()
147 .map(|operation| operation.as_ptr())
148 .collect::<Vec<_>>();
149 let operation_ptr = if operation_handles.is_empty() {
150 ptr::null()
151 } else {
152 operation_handles.as_ptr()
153 };
154 let mut context = ZeroArgCallbackContext {
155 callback: &mut dependent_block,
156 };
157 let box_handle = unsafe {
159 ffi::mpsgraph_graph_control_dependency(
160 self.as_ptr(),
161 operation_ptr,
162 operation_handles.len(),
163 Some(zero_arg_tensor_array_trampoline::<F>),
164 ptr::from_mut(&mut context).cast(),
165 cstring_ptr(&name),
166 )
167 };
168 wrap_tensor_array(box_handle)
169 }
170
171 pub fn if_then<Then>(
172 &self,
173 predicate: &Tensor,
174 mut then_block: Then,
175 name: Option<&str>,
176 ) -> Option<Vec<Tensor>>
177 where
178 Then: FnMut() -> Vec<Tensor>,
179 {
180 let name = optional_cstring(name);
181 let mut then_context = ZeroArgCallbackContext {
182 callback: &mut then_block,
183 };
184 let box_handle = unsafe {
186 ffi::mpsgraph_graph_if_then_else(
187 self.as_ptr(),
188 predicate.as_ptr(),
189 Some(zero_arg_tensor_array_trampoline::<Then>),
190 ptr::from_mut(&mut then_context).cast(),
191 None,
192 ptr::null_mut(),
193 cstring_ptr(&name),
194 )
195 };
196 wrap_tensor_array(box_handle)
197 }
198
199 pub fn if_then_else<Then, Else>(
200 &self,
201 predicate: &Tensor,
202 mut then_block: Then,
203 mut else_block: Else,
204 name: Option<&str>,
205 ) -> Option<Vec<Tensor>>
206 where
207 Then: FnMut() -> Vec<Tensor>,
208 Else: FnMut() -> Vec<Tensor>,
209 {
210 let name = optional_cstring(name);
211 let mut then_context = ZeroArgCallbackContext {
212 callback: &mut then_block,
213 };
214 let mut else_context = ZeroArgCallbackContext {
215 callback: &mut else_block,
216 };
217 let box_handle = unsafe {
219 ffi::mpsgraph_graph_if_then_else(
220 self.as_ptr(),
221 predicate.as_ptr(),
222 Some(zero_arg_tensor_array_trampoline::<Then>),
223 ptr::from_mut(&mut then_context).cast(),
224 Some(zero_arg_tensor_array_trampoline::<Else>),
225 ptr::from_mut(&mut else_context).cast(),
226 cstring_ptr(&name),
227 )
228 };
229 wrap_tensor_array(box_handle)
230 }
231
232 pub fn while_loop<Before, After>(
233 &self,
234 initial_inputs: &[&Tensor],
235 mut before: Before,
236 mut after: After,
237 name: Option<&str>,
238 ) -> Option<Vec<Tensor>>
239 where
240 Before: FnMut(&[Tensor]) -> WhileBeforeResult,
241 After: FnMut(&[Tensor]) -> Vec<Tensor>,
242 {
243 let name = optional_cstring(name);
244 let input_handles = initial_inputs
245 .iter()
246 .map(|tensor| tensor.as_ptr())
247 .collect::<Vec<_>>();
248 let input_ptr = if input_handles.is_empty() {
249 ptr::null()
250 } else {
251 input_handles.as_ptr()
252 };
253 let mut before_context = WhileBeforeCallbackContext {
254 callback: &mut before,
255 };
256 let mut after_context = TensorArrayInputCallbackContext {
257 callback: &mut after,
258 };
259 let box_handle = unsafe {
261 ffi::mpsgraph_graph_while_loop(
262 self.as_ptr(),
263 input_ptr,
264 input_handles.len(),
265 Some(while_before_trampoline::<Before>),
266 ptr::from_mut(&mut before_context).cast(),
267 Some(tensor_array_input_trampoline::<After>),
268 ptr::from_mut(&mut after_context).cast(),
269 cstring_ptr(&name),
270 )
271 };
272 wrap_tensor_array(box_handle)
273 }
274
275 #[allow(clippy::too_many_arguments)]
276 pub fn for_loop<Body>(
277 &self,
278 lower_bound: &Tensor,
279 upper_bound: &Tensor,
280 step: &Tensor,
281 initial_body_arguments: &[&Tensor],
282 mut body: Body,
283 name: Option<&str>,
284 ) -> Option<Vec<Tensor>>
285 where
286 Body: FnMut(&Tensor, &[Tensor]) -> Vec<Tensor>,
287 {
288 let name = optional_cstring(name);
289 let argument_handles = initial_body_arguments
290 .iter()
291 .map(|tensor| tensor.as_ptr())
292 .collect::<Vec<_>>();
293 let argument_ptr = if argument_handles.is_empty() {
294 ptr::null()
295 } else {
296 argument_handles.as_ptr()
297 };
298 let mut context = ForBodyCallbackContext {
299 callback: &mut body,
300 };
301 let box_handle = unsafe {
303 ffi::mpsgraph_graph_for_loop(
304 self.as_ptr(),
305 lower_bound.as_ptr(),
306 upper_bound.as_ptr(),
307 step.as_ptr(),
308 argument_ptr,
309 argument_handles.len(),
310 Some(for_body_trampoline::<Body>),
311 ptr::from_mut(&mut context).cast(),
312 cstring_ptr(&name),
313 )
314 };
315 wrap_tensor_array(box_handle)
316 }
317
318 pub fn for_loop_iterations<Body>(
319 &self,
320 number_of_iterations: &Tensor,
321 initial_body_arguments: &[&Tensor],
322 mut body: Body,
323 name: Option<&str>,
324 ) -> Option<Vec<Tensor>>
325 where
326 Body: FnMut(&Tensor, &[Tensor]) -> Vec<Tensor>,
327 {
328 let name = optional_cstring(name);
329 let argument_handles = initial_body_arguments
330 .iter()
331 .map(|tensor| tensor.as_ptr())
332 .collect::<Vec<_>>();
333 let argument_ptr = if argument_handles.is_empty() {
334 ptr::null()
335 } else {
336 argument_handles.as_ptr()
337 };
338 let mut context = ForBodyCallbackContext {
339 callback: &mut body,
340 };
341 let box_handle = unsafe {
343 ffi::mpsgraph_graph_for_loop_iterations(
344 self.as_ptr(),
345 number_of_iterations.as_ptr(),
346 argument_ptr,
347 argument_handles.len(),
348 Some(for_body_trampoline::<Body>),
349 ptr::from_mut(&mut context).cast(),
350 cstring_ptr(&name),
351 )
352 };
353 wrap_tensor_array(box_handle)
354 }
355}