Skip to main content

apple_mpsgraph/
control_flow.rs

1use crate::ffi;
2use crate::graph::Tensor;
3use crate::types::{collect_owned_tensors, Operation};
4use core::ffi::{c_char, c_void};
5use core::ptr;
6use std::ffi::CString;
7use std::mem;
8use std::panic::{catch_unwind, AssertUnwindSafe};
9
/// Converts an optional Rust string into an optional owned C string.
///
/// Returns `None` when `name` is `None`, and also when the text contains an
/// interior NUL byte (the `CString::new` error is discarded rather than reported).
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    CString::new(name?).ok()
}
13
/// Borrows the raw pointer of an optional C string, or null when there is none.
///
/// The returned pointer is only usable while `value` is alive, because it
/// points into the `CString` owned by the option.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(text) => text.as_ptr(),
        None => ptr::null(),
    }
}
18
19fn tensor_array_box_from_tensors(tensors: &[Tensor]) -> *mut c_void {
20    let handles = tensors.iter().map(Tensor::as_ptr).collect::<Vec<_>>();
21    let handles_ptr = if handles.is_empty() {
22        ptr::null()
23    } else {
24        handles.as_ptr()
25    };
26    // SAFETY: the handles stay valid for the duration of the bridge call and the Swift array retains them.
27    unsafe { ffi::mpsgraph_tensor_array_box_new(handles_ptr, handles.len()) }
28}
29
30fn wrap_tensor_array(box_handle: *mut c_void) -> Option<Vec<Tensor>> {
31    if box_handle.is_null() {
32        None
33    } else {
34        Some(collect_owned_tensors(box_handle))
35    }
36}
37
38fn into_owned_tensor_handle(tensor: Tensor) -> *mut c_void {
39    let ptr = tensor.as_ptr();
40    mem::forget(tensor);
41    ptr
42}
43
/// Value returned by the `before` callback of `Graph::while_loop`.
pub struct WhileBeforeResult {
    /// Tensor whose raw handle is returned to the bridge by the `before`
    /// trampoline; ownership of the handle is handed over at that point.
    /// NOTE(review): presumably the loop-continuation condition — confirm
    /// against the MPSGraph while-loop documentation.
    pub predicate: Tensor,
    /// Tensors that the `before` trampoline boxes and writes into the bridge's
    /// output slot (when that slot is non-null).
    pub results: Vec<Tensor>,
}
48
/// Borrowed state passed through the bridge to `zero_arg_tensor_array_trampoline`.
struct ZeroArgCallbackContext<'cb, F> {
    /// The user closure to invoke; borrowed mutably for the duration of the bridge call.
    callback: &'cb mut F,
}
52
53unsafe extern "C" fn zero_arg_tensor_array_trampoline<F>(context: *mut c_void) -> *mut c_void
54where
55    F: FnMut() -> Vec<Tensor>,
56{
57    let context = unsafe { &mut *context.cast::<ZeroArgCallbackContext<'_, F>>() };
58    let tensors = catch_unwind(AssertUnwindSafe(|| (context.callback)()))
59        .unwrap_or_else(|_| std::process::abort());
60    tensor_array_box_from_tensors(&tensors)
61}
62
/// Borrowed state passed through the bridge to `while_before_trampoline`.
struct WhileBeforeCallbackContext<'cb, F> {
    /// The user `before` closure; borrowed mutably for the duration of the bridge call.
    callback: &'cb mut F,
}
66
67unsafe extern "C" fn while_before_trampoline<F>(
68    context: *mut c_void,
69    input_box_handle: *mut c_void,
70    out_result_box_handle: *mut *mut c_void,
71) -> *mut c_void
72where
73    F: FnMut(&[Tensor]) -> WhileBeforeResult,
74{
75    let context = unsafe { &mut *context.cast::<WhileBeforeCallbackContext<'_, F>>() };
76    let inputs = collect_owned_tensors(input_box_handle);
77    match catch_unwind(AssertUnwindSafe(|| (context.callback)(&inputs))) {
78        Ok(result) => {
79            if !out_result_box_handle.is_null() {
80                // SAFETY: caller provided a valid output slot for the tensor-array box.
81                unsafe { *out_result_box_handle = tensor_array_box_from_tensors(&result.results) };
82            }
83            into_owned_tensor_handle(result.predicate)
84        }
85        Err(_) => std::process::abort(),
86    }
87}
88
/// Borrowed state passed through the bridge to `tensor_array_input_trampoline`.
struct TensorArrayInputCallbackContext<'cb, F> {
    /// The user closure taking a tensor slice; borrowed mutably for the duration of the bridge call.
    callback: &'cb mut F,
}
92
93unsafe extern "C" fn tensor_array_input_trampoline<F>(
94    context: *mut c_void,
95    input_box_handle: *mut c_void,
96) -> *mut c_void
97where
98    F: FnMut(&[Tensor]) -> Vec<Tensor>,
99{
100    let context = unsafe { &mut *context.cast::<TensorArrayInputCallbackContext<'_, F>>() };
101    let inputs = collect_owned_tensors(input_box_handle);
102    let tensors = catch_unwind(AssertUnwindSafe(|| (context.callback)(&inputs)))
103        .unwrap_or_else(|_| std::process::abort());
104    tensor_array_box_from_tensors(&tensors)
105}
106
/// Borrowed state passed through the bridge to `for_body_trampoline`.
struct ForBodyCallbackContext<'cb, F> {
    /// The user loop-body closure; borrowed mutably for the duration of the bridge call.
    callback: &'cb mut F,
}
110
111unsafe extern "C" fn for_body_trampoline<F>(
112    context: *mut c_void,
113    index_handle: *mut c_void,
114    input_box_handle: *mut c_void,
115) -> *mut c_void
116where
117    F: FnMut(&Tensor, &[Tensor]) -> Vec<Tensor>,
118{
119    if index_handle.is_null() {
120        return ptr::null_mut();
121    }
122    let context = unsafe { &mut *context.cast::<ForBodyCallbackContext<'_, F>>() };
123    let index = Tensor::from_raw(index_handle);
124    let inputs = collect_owned_tensors(input_box_handle);
125    let tensors = catch_unwind(AssertUnwindSafe(|| (context.callback)(&index, &inputs)))
126        .unwrap_or_else(|_| std::process::abort());
127    tensor_array_box_from_tensors(&tensors)
128}
129
130impl crate::graph::Graph {
131    pub fn control_dependency<F>(
132        &self,
133        operations: &[&Operation],
134        mut dependent_block: F,
135        name: Option<&str>,
136    ) -> Option<Vec<Tensor>>
137    where
138        F: FnMut() -> Vec<Tensor>,
139    {
140        let name = optional_cstring(name);
141        let operation_handles = operations.iter().map(|operation| operation.as_ptr()).collect::<Vec<_>>();
142        let operation_ptr = if operation_handles.is_empty() {
143            ptr::null()
144        } else {
145            operation_handles.as_ptr()
146        };
147        let mut context = ZeroArgCallbackContext {
148            callback: &mut dependent_block,
149        };
150        // SAFETY: the callback contexts and handles remain valid for the duration of the call.
151        let box_handle = unsafe {
152            ffi::mpsgraph_graph_control_dependency(
153                self.as_ptr(),
154                operation_ptr,
155                operation_handles.len(),
156                Some(zero_arg_tensor_array_trampoline::<F>),
157                ptr::from_mut(&mut context).cast(),
158                cstring_ptr(&name),
159            )
160        };
161        wrap_tensor_array(box_handle)
162    }
163
164    pub fn if_then<Then>(
165        &self,
166        predicate: &Tensor,
167        mut then_block: Then,
168        name: Option<&str>,
169    ) -> Option<Vec<Tensor>>
170    where
171        Then: FnMut() -> Vec<Tensor>,
172    {
173        let name = optional_cstring(name);
174        let mut then_context = ZeroArgCallbackContext {
175            callback: &mut then_block,
176        };
177        // SAFETY: the callback context and handles remain valid for the duration of the call.
178        let box_handle = unsafe {
179            ffi::mpsgraph_graph_if_then_else(
180                self.as_ptr(),
181                predicate.as_ptr(),
182                Some(zero_arg_tensor_array_trampoline::<Then>),
183                ptr::from_mut(&mut then_context).cast(),
184                None,
185                ptr::null_mut(),
186                cstring_ptr(&name),
187            )
188        };
189        wrap_tensor_array(box_handle)
190    }
191
192    pub fn if_then_else<Then, Else>(
193        &self,
194        predicate: &Tensor,
195        mut then_block: Then,
196        mut else_block: Else,
197        name: Option<&str>,
198    ) -> Option<Vec<Tensor>>
199    where
200        Then: FnMut() -> Vec<Tensor>,
201        Else: FnMut() -> Vec<Tensor>,
202    {
203        let name = optional_cstring(name);
204        let mut then_context = ZeroArgCallbackContext {
205            callback: &mut then_block,
206        };
207        let mut else_context = ZeroArgCallbackContext {
208            callback: &mut else_block,
209        };
210        // SAFETY: the callback contexts and handles remain valid for the duration of the call.
211        let box_handle = unsafe {
212            ffi::mpsgraph_graph_if_then_else(
213                self.as_ptr(),
214                predicate.as_ptr(),
215                Some(zero_arg_tensor_array_trampoline::<Then>),
216                ptr::from_mut(&mut then_context).cast(),
217                Some(zero_arg_tensor_array_trampoline::<Else>),
218                ptr::from_mut(&mut else_context).cast(),
219                cstring_ptr(&name),
220            )
221        };
222        wrap_tensor_array(box_handle)
223    }
224
225    pub fn while_loop<Before, After>(
226        &self,
227        initial_inputs: &[&Tensor],
228        mut before: Before,
229        mut after: After,
230        name: Option<&str>,
231    ) -> Option<Vec<Tensor>>
232    where
233        Before: FnMut(&[Tensor]) -> WhileBeforeResult,
234        After: FnMut(&[Tensor]) -> Vec<Tensor>,
235    {
236        let name = optional_cstring(name);
237        let input_handles = initial_inputs.iter().map(|tensor| tensor.as_ptr()).collect::<Vec<_>>();
238        let input_ptr = if input_handles.is_empty() {
239            ptr::null()
240        } else {
241            input_handles.as_ptr()
242        };
243        let mut before_context = WhileBeforeCallbackContext {
244            callback: &mut before,
245        };
246        let mut after_context = TensorArrayInputCallbackContext {
247            callback: &mut after,
248        };
249        // SAFETY: the callback contexts and handles remain valid for the duration of the call.
250        let box_handle = unsafe {
251            ffi::mpsgraph_graph_while_loop(
252                self.as_ptr(),
253                input_ptr,
254                input_handles.len(),
255                Some(while_before_trampoline::<Before>),
256                ptr::from_mut(&mut before_context).cast(),
257                Some(tensor_array_input_trampoline::<After>),
258                ptr::from_mut(&mut after_context).cast(),
259                cstring_ptr(&name),
260            )
261        };
262        wrap_tensor_array(box_handle)
263    }
264
265    #[allow(clippy::too_many_arguments)]
266    pub fn for_loop<Body>(
267        &self,
268        lower_bound: &Tensor,
269        upper_bound: &Tensor,
270        step: &Tensor,
271        initial_body_arguments: &[&Tensor],
272        mut body: Body,
273        name: Option<&str>,
274    ) -> Option<Vec<Tensor>>
275    where
276        Body: FnMut(&Tensor, &[Tensor]) -> Vec<Tensor>,
277    {
278        let name = optional_cstring(name);
279        let argument_handles = initial_body_arguments
280            .iter()
281            .map(|tensor| tensor.as_ptr())
282            .collect::<Vec<_>>();
283        let argument_ptr = if argument_handles.is_empty() {
284            ptr::null()
285        } else {
286            argument_handles.as_ptr()
287        };
288        let mut context = ForBodyCallbackContext {
289            callback: &mut body,
290        };
291        // SAFETY: the callback context and handles remain valid for the duration of the call.
292        let box_handle = unsafe {
293            ffi::mpsgraph_graph_for_loop(
294                self.as_ptr(),
295                lower_bound.as_ptr(),
296                upper_bound.as_ptr(),
297                step.as_ptr(),
298                argument_ptr,
299                argument_handles.len(),
300                Some(for_body_trampoline::<Body>),
301                ptr::from_mut(&mut context).cast(),
302                cstring_ptr(&name),
303            )
304        };
305        wrap_tensor_array(box_handle)
306    }
307
308    pub fn for_loop_iterations<Body>(
309        &self,
310        number_of_iterations: &Tensor,
311        initial_body_arguments: &[&Tensor],
312        mut body: Body,
313        name: Option<&str>,
314    ) -> Option<Vec<Tensor>>
315    where
316        Body: FnMut(&Tensor, &[Tensor]) -> Vec<Tensor>,
317    {
318        let name = optional_cstring(name);
319        let argument_handles = initial_body_arguments
320            .iter()
321            .map(|tensor| tensor.as_ptr())
322            .collect::<Vec<_>>();
323        let argument_ptr = if argument_handles.is_empty() {
324            ptr::null()
325        } else {
326            argument_handles.as_ptr()
327        };
328        let mut context = ForBodyCallbackContext {
329            callback: &mut body,
330        };
331        // SAFETY: the callback context and handles remain valid for the duration of the call.
332        let box_handle = unsafe {
333            ffi::mpsgraph_graph_for_loop_iterations(
334                self.as_ptr(),
335                number_of_iterations.as_ptr(),
336                argument_ptr,
337                argument_handles.len(),
338                Some(for_body_trampoline::<Body>),
339                ptr::from_mut(&mut context).cast(),
340                cstring_ptr(&name),
341            )
342        };
343        wrap_tensor_array(box_handle)
344    }
345}