Skip to main content

apple_mpsgraph/
control_flow.rs

1use crate::ffi;
2use crate::graph::Tensor;
3use crate::types::{collect_owned_tensors, Operation};
4use core::ffi::{c_char, c_void};
5use core::ptr;
6use std::ffi::CString;
7use std::mem;
8use std::panic::{catch_unwind, AssertUnwindSafe};
9
/// Converts an optional operation name into an owned C string for the bridge.
///
/// Returns `None` when no name is supplied. A name containing an interior NUL
/// byte also yields `None`, so such a name is silently dropped rather than reported.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let text = name?;
    CString::new(text).ok()
}
13
/// Returns a pointer to the NUL-terminated name, or null when there is no name.
///
/// The pointer borrows from `value`. It is valid only while `value` is alive and
/// unmodified.
// The parameter is `&Option<CString>` so call sites can pass `&name` directly.
// That signature is exactly what `clippy::ref_option` flags, hence the allow.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(text) => text.as_ptr(),
        None => ptr::null(),
    }
}
18
19fn tensor_array_box_from_tensors(tensors: &[Tensor]) -> *mut c_void {
20    let handles = tensors.iter().map(Tensor::as_ptr).collect::<Vec<_>>();
21    let handles_ptr = if handles.is_empty() {
22        ptr::null()
23    } else {
24        handles.as_ptr()
25    };
26    // SAFETY: the handles stay valid for the duration of the bridge call and the Swift array retains them.
27    unsafe { ffi::mpsgraph_tensor_array_box_new(handles_ptr, handles.len()) }
28}
29
30fn wrap_tensor_array(box_handle: *mut c_void) -> Option<Vec<Tensor>> {
31    if box_handle.is_null() {
32        None
33    } else {
34        Some(collect_owned_tensors(box_handle))
35    }
36}
37
38fn into_owned_tensor_handle(tensor: Tensor) -> *mut c_void {
39    let ptr = tensor.as_ptr();
40    mem::forget(tensor);
41    ptr
42}
43
44pub struct WhileBeforeResult {
45    pub predicate: Tensor,
46    pub results: Vec<Tensor>,
47}
48
49struct ZeroArgCallbackContext<'a, F> {
50    callback: &'a mut F,
51}
52
53unsafe extern "C" fn zero_arg_tensor_array_trampoline<F>(context: *mut c_void) -> *mut c_void
54where
55    F: FnMut() -> Vec<Tensor>,
56{
57    // SAFETY: `context` is a pointer to `ZeroArgCallbackContext<'_, F>` set up by the safe wrapper at the callsite.
58    let context = unsafe { &mut *context.cast::<ZeroArgCallbackContext<'_, F>>() };
59    let tensors = catch_unwind(AssertUnwindSafe(|| (context.callback)()))
60        .unwrap_or_else(|_| std::process::abort());
61    tensor_array_box_from_tensors(&tensors)
62}
63
64struct WhileBeforeCallbackContext<'a, F> {
65    callback: &'a mut F,
66}
67
68unsafe extern "C" fn while_before_trampoline<F>(
69    context: *mut c_void,
70    input_box_handle: *mut c_void,
71    out_result_box_handle: *mut *mut c_void,
72) -> *mut c_void
73where
74    F: FnMut(&[Tensor]) -> WhileBeforeResult,
75{
76    // SAFETY: `context` is a pointer to `WhileBeforeCallbackContext<'_, F>` set up by the safe wrapper at the callsite.
77    let context = unsafe { &mut *context.cast::<WhileBeforeCallbackContext<'_, F>>() };
78    let inputs = collect_owned_tensors(input_box_handle);
79    match catch_unwind(AssertUnwindSafe(|| (context.callback)(&inputs))) {
80        Ok(result) => {
81            if !out_result_box_handle.is_null() {
82                // SAFETY: caller provided a valid output slot for the tensor-array box.
83                unsafe { *out_result_box_handle = tensor_array_box_from_tensors(&result.results) };
84            }
85            into_owned_tensor_handle(result.predicate)
86        }
87        Err(_) => std::process::abort(),
88    }
89}
90
91struct TensorArrayInputCallbackContext<'a, F> {
92    callback: &'a mut F,
93}
94
95unsafe extern "C" fn tensor_array_input_trampoline<F>(
96    context: *mut c_void,
97    input_box_handle: *mut c_void,
98) -> *mut c_void
99where
100    F: FnMut(&[Tensor]) -> Vec<Tensor>,
101{
102    // SAFETY: `context` is a pointer to `TensorArrayInputCallbackContext<'_, F>` set up by the safe wrapper at the callsite.
103    let context = unsafe { &mut *context.cast::<TensorArrayInputCallbackContext<'_, F>>() };
104    let inputs = collect_owned_tensors(input_box_handle);
105    let tensors = catch_unwind(AssertUnwindSafe(|| (context.callback)(&inputs)))
106        .unwrap_or_else(|_| std::process::abort());
107    tensor_array_box_from_tensors(&tensors)
108}
109
110struct ForBodyCallbackContext<'a, F> {
111    callback: &'a mut F,
112}
113
114unsafe extern "C" fn for_body_trampoline<F>(
115    context: *mut c_void,
116    index_handle: *mut c_void,
117    input_box_handle: *mut c_void,
118) -> *mut c_void
119where
120    F: FnMut(&Tensor, &[Tensor]) -> Vec<Tensor>,
121{
122    if index_handle.is_null() {
123        return ptr::null_mut();
124    }
125    // SAFETY: `context` is a pointer to `ForBodyCallbackContext<'_, F>` set up by the safe wrapper at the callsite.
126    let context = unsafe { &mut *context.cast::<ForBodyCallbackContext<'_, F>>() };
127    let index = Tensor::from_raw(index_handle);
128    let inputs = collect_owned_tensors(input_box_handle);
129    let tensors = catch_unwind(AssertUnwindSafe(|| (context.callback)(&index, &inputs)))
130        .unwrap_or_else(|_| std::process::abort());
131    tensor_array_box_from_tensors(&tensors)
132}
133
134impl crate::graph::Graph {
135    pub fn control_dependency<F>(
136        &self,
137        operations: &[&Operation],
138        mut dependent_block: F,
139        name: Option<&str>,
140    ) -> Option<Vec<Tensor>>
141    where
142        F: FnMut() -> Vec<Tensor>,
143    {
144        let name = optional_cstring(name);
145        let operation_handles = operations
146            .iter()
147            .map(|operation| operation.as_ptr())
148            .collect::<Vec<_>>();
149        let operation_ptr = if operation_handles.is_empty() {
150            ptr::null()
151        } else {
152            operation_handles.as_ptr()
153        };
154        let mut context = ZeroArgCallbackContext {
155            callback: &mut dependent_block,
156        };
157        // SAFETY: the callback contexts and handles remain valid for the duration of the call.
158        let box_handle = unsafe {
159            ffi::mpsgraph_graph_control_dependency(
160                self.as_ptr(),
161                operation_ptr,
162                operation_handles.len(),
163                Some(zero_arg_tensor_array_trampoline::<F>),
164                ptr::from_mut(&mut context).cast(),
165                cstring_ptr(&name),
166            )
167        };
168        wrap_tensor_array(box_handle)
169    }
170
171    pub fn if_then<Then>(
172        &self,
173        predicate: &Tensor,
174        mut then_block: Then,
175        name: Option<&str>,
176    ) -> Option<Vec<Tensor>>
177    where
178        Then: FnMut() -> Vec<Tensor>,
179    {
180        let name = optional_cstring(name);
181        let mut then_context = ZeroArgCallbackContext {
182            callback: &mut then_block,
183        };
184        // SAFETY: the callback context and handles remain valid for the duration of the call.
185        let box_handle = unsafe {
186            ffi::mpsgraph_graph_if_then_else(
187                self.as_ptr(),
188                predicate.as_ptr(),
189                Some(zero_arg_tensor_array_trampoline::<Then>),
190                ptr::from_mut(&mut then_context).cast(),
191                None,
192                ptr::null_mut(),
193                cstring_ptr(&name),
194            )
195        };
196        wrap_tensor_array(box_handle)
197    }
198
199    pub fn if_then_else<Then, Else>(
200        &self,
201        predicate: &Tensor,
202        mut then_block: Then,
203        mut else_block: Else,
204        name: Option<&str>,
205    ) -> Option<Vec<Tensor>>
206    where
207        Then: FnMut() -> Vec<Tensor>,
208        Else: FnMut() -> Vec<Tensor>,
209    {
210        let name = optional_cstring(name);
211        let mut then_context = ZeroArgCallbackContext {
212            callback: &mut then_block,
213        };
214        let mut else_context = ZeroArgCallbackContext {
215            callback: &mut else_block,
216        };
217        // SAFETY: the callback contexts and handles remain valid for the duration of the call.
218        let box_handle = unsafe {
219            ffi::mpsgraph_graph_if_then_else(
220                self.as_ptr(),
221                predicate.as_ptr(),
222                Some(zero_arg_tensor_array_trampoline::<Then>),
223                ptr::from_mut(&mut then_context).cast(),
224                Some(zero_arg_tensor_array_trampoline::<Else>),
225                ptr::from_mut(&mut else_context).cast(),
226                cstring_ptr(&name),
227            )
228        };
229        wrap_tensor_array(box_handle)
230    }
231
232    pub fn while_loop<Before, After>(
233        &self,
234        initial_inputs: &[&Tensor],
235        mut before: Before,
236        mut after: After,
237        name: Option<&str>,
238    ) -> Option<Vec<Tensor>>
239    where
240        Before: FnMut(&[Tensor]) -> WhileBeforeResult,
241        After: FnMut(&[Tensor]) -> Vec<Tensor>,
242    {
243        let name = optional_cstring(name);
244        let input_handles = initial_inputs
245            .iter()
246            .map(|tensor| tensor.as_ptr())
247            .collect::<Vec<_>>();
248        let input_ptr = if input_handles.is_empty() {
249            ptr::null()
250        } else {
251            input_handles.as_ptr()
252        };
253        let mut before_context = WhileBeforeCallbackContext {
254            callback: &mut before,
255        };
256        let mut after_context = TensorArrayInputCallbackContext {
257            callback: &mut after,
258        };
259        // SAFETY: the callback contexts and handles remain valid for the duration of the call.
260        let box_handle = unsafe {
261            ffi::mpsgraph_graph_while_loop(
262                self.as_ptr(),
263                input_ptr,
264                input_handles.len(),
265                Some(while_before_trampoline::<Before>),
266                ptr::from_mut(&mut before_context).cast(),
267                Some(tensor_array_input_trampoline::<After>),
268                ptr::from_mut(&mut after_context).cast(),
269                cstring_ptr(&name),
270            )
271        };
272        wrap_tensor_array(box_handle)
273    }
274
275    #[allow(clippy::too_many_arguments)]
276    pub fn for_loop<Body>(
277        &self,
278        lower_bound: &Tensor,
279        upper_bound: &Tensor,
280        step: &Tensor,
281        initial_body_arguments: &[&Tensor],
282        mut body: Body,
283        name: Option<&str>,
284    ) -> Option<Vec<Tensor>>
285    where
286        Body: FnMut(&Tensor, &[Tensor]) -> Vec<Tensor>,
287    {
288        let name = optional_cstring(name);
289        let argument_handles = initial_body_arguments
290            .iter()
291            .map(|tensor| tensor.as_ptr())
292            .collect::<Vec<_>>();
293        let argument_ptr = if argument_handles.is_empty() {
294            ptr::null()
295        } else {
296            argument_handles.as_ptr()
297        };
298        let mut context = ForBodyCallbackContext {
299            callback: &mut body,
300        };
301        // SAFETY: the callback context and handles remain valid for the duration of the call.
302        let box_handle = unsafe {
303            ffi::mpsgraph_graph_for_loop(
304                self.as_ptr(),
305                lower_bound.as_ptr(),
306                upper_bound.as_ptr(),
307                step.as_ptr(),
308                argument_ptr,
309                argument_handles.len(),
310                Some(for_body_trampoline::<Body>),
311                ptr::from_mut(&mut context).cast(),
312                cstring_ptr(&name),
313            )
314        };
315        wrap_tensor_array(box_handle)
316    }
317
318    pub fn for_loop_iterations<Body>(
319        &self,
320        number_of_iterations: &Tensor,
321        initial_body_arguments: &[&Tensor],
322        mut body: Body,
323        name: Option<&str>,
324    ) -> Option<Vec<Tensor>>
325    where
326        Body: FnMut(&Tensor, &[Tensor]) -> Vec<Tensor>,
327    {
328        let name = optional_cstring(name);
329        let argument_handles = initial_body_arguments
330            .iter()
331            .map(|tensor| tensor.as_ptr())
332            .collect::<Vec<_>>();
333        let argument_ptr = if argument_handles.is_empty() {
334            ptr::null()
335        } else {
336            argument_handles.as_ptr()
337        };
338        let mut context = ForBodyCallbackContext {
339            callback: &mut body,
340        };
341        // SAFETY: the callback context and handles remain valid for the duration of the call.
342        let box_handle = unsafe {
343            ffi::mpsgraph_graph_for_loop_iterations(
344                self.as_ptr(),
345                number_of_iterations.as_ptr(),
346                argument_ptr,
347                argument_handles.len(),
348                Some(for_body_trampoline::<Body>),
349                ptr::from_mut(&mut context).cast(),
350                cstring_ptr(&name),
351            )
352        };
353        wrap_tensor_array(box_handle)
354    }
355}