Skip to main content

apple_mpsgraph/
control_flow.rs

1use crate::ffi;
2use crate::graph::Tensor;
3use crate::types::{collect_owned_tensors, Operation};
4use core::ffi::{c_char, c_void};
5use core::ptr;
6use std::ffi::CString;
7use std::mem;
8use std::panic::{catch_unwind, AssertUnwindSafe};
9
/// Converts an optional Rust name into an owned C string for the bridge.
///
/// Returns `None` when no name was given, and also when `name` contains an interior
/// NUL byte (such a string has no C representation). In the second case the name is
/// dropped silently rather than reported.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let text = name?;
    CString::new(text).ok()
}
13
/// Borrows the C pointer of an optional C string, or yields null when absent.
///
/// The returned pointer is only valid while `value` is alive and unmodified.
// Taking `&Option<CString>` lets call sites pass `&name` for a local `Option<CString>`
// directly; the pedantic `ref_option` lint would otherwise ask for `Option<&CString>`.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(text) => text.as_ptr(),
        None => ptr::null(),
    }
}
18
19fn tensor_array_box_from_tensors(tensors: &[Tensor]) -> *mut c_void {
20    let handles = tensors.iter().map(Tensor::as_ptr).collect::<Vec<_>>();
21    let handles_ptr = if handles.is_empty() {
22        ptr::null()
23    } else {
24        handles.as_ptr()
25    };
26    // SAFETY: the handles stay valid for the duration of the bridge call and the Swift array retains them.
27    unsafe { ffi::mpsgraph_tensor_array_box_new(handles_ptr, handles.len()) }
28}
29
30fn wrap_tensor_array(box_handle: *mut c_void) -> Option<Vec<Tensor>> {
31    if box_handle.is_null() {
32        None
33    } else {
34        Some(collect_owned_tensors(box_handle))
35    }
36}
37
38fn into_owned_tensor_handle(tensor: Tensor) -> *mut c_void {
39    let ptr = tensor.as_ptr();
40    mem::forget(tensor);
41    ptr
42}
43
44/// Mirrors the `MPSGraph` framework counterpart for `WhileBeforeResult`.
45pub struct WhileBeforeResult {
46/// Mirrors the `MPSGraph` framework property for `predicate`.
47    pub predicate: Tensor,
48/// Mirrors the `MPSGraph` framework property for `results`.
49    pub results: Vec<Tensor>,
50}
51
/// Opaque context passed to `zero_arg_tensor_array_trampoline`.
///
/// Holds a mutable borrow of the caller's closure for the duration of one bridge call.
struct ZeroArgCallbackContext<'cb, Callback> {
    /// Closure invoked every time the bridge calls the trampoline.
    callback: &'cb mut Callback,
}
55
56unsafe extern "C" fn zero_arg_tensor_array_trampoline<F>(context: *mut c_void) -> *mut c_void
57where
58    F: FnMut() -> Vec<Tensor>,
59{
60    // SAFETY: `context` is a pointer to `ZeroArgCallbackContext<'_, F>` set up by the safe wrapper at the callsite.
61    let context = unsafe { &mut *context.cast::<ZeroArgCallbackContext<'_, F>>() };
62    let tensors = catch_unwind(AssertUnwindSafe(|| (context.callback)()))
63        .unwrap_or_else(|_| std::process::abort());
64    tensor_array_box_from_tensors(&tensors)
65}
66
/// Opaque context passed to `while_before_trampoline`.
///
/// Holds a mutable borrow of the caller's `before` closure for the duration of one bridge call.
struct WhileBeforeCallbackContext<'cb, Callback> {
    /// Closure invoked every time the bridge calls the trampoline.
    callback: &'cb mut Callback,
}
70
71unsafe extern "C" fn while_before_trampoline<F>(
72    context: *mut c_void,
73    input_box_handle: *mut c_void,
74    out_result_box_handle: *mut *mut c_void,
75) -> *mut c_void
76where
77    F: FnMut(&[Tensor]) -> WhileBeforeResult,
78{
79    // SAFETY: `context` is a pointer to `WhileBeforeCallbackContext<'_, F>` set up by the safe wrapper at the callsite.
80    let context = unsafe { &mut *context.cast::<WhileBeforeCallbackContext<'_, F>>() };
81    let inputs = collect_owned_tensors(input_box_handle);
82    match catch_unwind(AssertUnwindSafe(|| (context.callback)(&inputs))) {
83        Ok(result) => {
84            if !out_result_box_handle.is_null() {
85                // SAFETY: caller provided a valid output slot for the tensor-array box.
86                unsafe { *out_result_box_handle = tensor_array_box_from_tensors(&result.results) };
87            }
88            into_owned_tensor_handle(result.predicate)
89        }
90        Err(_) => std::process::abort(),
91    }
92}
93
/// Opaque context passed to `tensor_array_input_trampoline`.
///
/// Holds a mutable borrow of the caller's closure for the duration of one bridge call.
struct TensorArrayInputCallbackContext<'cb, Callback> {
    /// Closure invoked every time the bridge calls the trampoline.
    callback: &'cb mut Callback,
}
97
98unsafe extern "C" fn tensor_array_input_trampoline<F>(
99    context: *mut c_void,
100    input_box_handle: *mut c_void,
101) -> *mut c_void
102where
103    F: FnMut(&[Tensor]) -> Vec<Tensor>,
104{
105    // SAFETY: `context` is a pointer to `TensorArrayInputCallbackContext<'_, F>` set up by the safe wrapper at the callsite.
106    let context = unsafe { &mut *context.cast::<TensorArrayInputCallbackContext<'_, F>>() };
107    let inputs = collect_owned_tensors(input_box_handle);
108    let tensors = catch_unwind(AssertUnwindSafe(|| (context.callback)(&inputs)))
109        .unwrap_or_else(|_| std::process::abort());
110    tensor_array_box_from_tensors(&tensors)
111}
112
/// Opaque context passed to `for_body_trampoline`.
///
/// Holds a mutable borrow of the caller's loop-body closure for the duration of one bridge
/// call.
struct ForBodyCallbackContext<'cb, Callback> {
    /// Closure invoked every time the bridge calls the trampoline.
    callback: &'cb mut Callback,
}
116
117unsafe extern "C" fn for_body_trampoline<F>(
118    context: *mut c_void,
119    index_handle: *mut c_void,
120    input_box_handle: *mut c_void,
121) -> *mut c_void
122where
123    F: FnMut(&Tensor, &[Tensor]) -> Vec<Tensor>,
124{
125    if index_handle.is_null() {
126        return ptr::null_mut();
127    }
128    // SAFETY: `context` is a pointer to `ForBodyCallbackContext<'_, F>` set up by the safe wrapper at the callsite.
129    let context = unsafe { &mut *context.cast::<ForBodyCallbackContext<'_, F>>() };
130    let index = Tensor::from_raw(index_handle);
131    let inputs = collect_owned_tensors(input_box_handle);
132    let tensors = catch_unwind(AssertUnwindSafe(|| (context.callback)(&index, &inputs)))
133        .unwrap_or_else(|_| std::process::abort());
134    tensor_array_box_from_tensors(&tensors)
135}
136
137impl crate::graph::Graph {
138/// Calls the `MPSGraph` framework counterpart for `control_dependency`.
139    pub fn control_dependency<F>(
140        &self,
141        operations: &[&Operation],
142        mut dependent_block: F,
143        name: Option<&str>,
144    ) -> Option<Vec<Tensor>>
145    where
146        F: FnMut() -> Vec<Tensor>,
147    {
148        let name = optional_cstring(name);
149        let operation_handles = operations
150            .iter()
151            .map(|operation| operation.as_ptr())
152            .collect::<Vec<_>>();
153        let operation_ptr = if operation_handles.is_empty() {
154            ptr::null()
155        } else {
156            operation_handles.as_ptr()
157        };
158        let mut context = ZeroArgCallbackContext {
159            callback: &mut dependent_block,
160        };
161        // SAFETY: the callback contexts and handles remain valid for the duration of the call.
162        let box_handle = unsafe {
163            ffi::mpsgraph_graph_control_dependency(
164                self.as_ptr(),
165                operation_ptr,
166                operation_handles.len(),
167                Some(zero_arg_tensor_array_trampoline::<F>),
168                ptr::from_mut(&mut context).cast(),
169                cstring_ptr(&name),
170            )
171        };
172        wrap_tensor_array(box_handle)
173    }
174
175/// Calls the `MPSGraph` framework counterpart for `if_then`.
176    pub fn if_then<Then>(
177        &self,
178        predicate: &Tensor,
179        mut then_block: Then,
180        name: Option<&str>,
181    ) -> Option<Vec<Tensor>>
182    where
183        Then: FnMut() -> Vec<Tensor>,
184    {
185        let name = optional_cstring(name);
186        let mut then_context = ZeroArgCallbackContext {
187            callback: &mut then_block,
188        };
189        // SAFETY: the callback context and handles remain valid for the duration of the call.
190        let box_handle = unsafe {
191            ffi::mpsgraph_graph_if_then_else(
192                self.as_ptr(),
193                predicate.as_ptr(),
194                Some(zero_arg_tensor_array_trampoline::<Then>),
195                ptr::from_mut(&mut then_context).cast(),
196                None,
197                ptr::null_mut(),
198                cstring_ptr(&name),
199            )
200        };
201        wrap_tensor_array(box_handle)
202    }
203
204/// Calls the `MPSGraph` framework counterpart for `if_then_else`.
205    pub fn if_then_else<Then, Else>(
206        &self,
207        predicate: &Tensor,
208        mut then_block: Then,
209        mut else_block: Else,
210        name: Option<&str>,
211    ) -> Option<Vec<Tensor>>
212    where
213        Then: FnMut() -> Vec<Tensor>,
214        Else: FnMut() -> Vec<Tensor>,
215    {
216        let name = optional_cstring(name);
217        let mut then_context = ZeroArgCallbackContext {
218            callback: &mut then_block,
219        };
220        let mut else_context = ZeroArgCallbackContext {
221            callback: &mut else_block,
222        };
223        // SAFETY: the callback contexts and handles remain valid for the duration of the call.
224        let box_handle = unsafe {
225            ffi::mpsgraph_graph_if_then_else(
226                self.as_ptr(),
227                predicate.as_ptr(),
228                Some(zero_arg_tensor_array_trampoline::<Then>),
229                ptr::from_mut(&mut then_context).cast(),
230                Some(zero_arg_tensor_array_trampoline::<Else>),
231                ptr::from_mut(&mut else_context).cast(),
232                cstring_ptr(&name),
233            )
234        };
235        wrap_tensor_array(box_handle)
236    }
237
238/// Calls the `MPSGraph` framework counterpart for `while_loop`.
239    pub fn while_loop<Before, After>(
240        &self,
241        initial_inputs: &[&Tensor],
242        mut before: Before,
243        mut after: After,
244        name: Option<&str>,
245    ) -> Option<Vec<Tensor>>
246    where
247        Before: FnMut(&[Tensor]) -> WhileBeforeResult,
248        After: FnMut(&[Tensor]) -> Vec<Tensor>,
249    {
250        let name = optional_cstring(name);
251        let input_handles = initial_inputs
252            .iter()
253            .map(|tensor| tensor.as_ptr())
254            .collect::<Vec<_>>();
255        let input_ptr = if input_handles.is_empty() {
256            ptr::null()
257        } else {
258            input_handles.as_ptr()
259        };
260        let mut before_context = WhileBeforeCallbackContext {
261            callback: &mut before,
262        };
263        let mut after_context = TensorArrayInputCallbackContext {
264            callback: &mut after,
265        };
266        // SAFETY: the callback contexts and handles remain valid for the duration of the call.
267        let box_handle = unsafe {
268            ffi::mpsgraph_graph_while_loop(
269                self.as_ptr(),
270                input_ptr,
271                input_handles.len(),
272                Some(while_before_trampoline::<Before>),
273                ptr::from_mut(&mut before_context).cast(),
274                Some(tensor_array_input_trampoline::<After>),
275                ptr::from_mut(&mut after_context).cast(),
276                cstring_ptr(&name),
277            )
278        };
279        wrap_tensor_array(box_handle)
280    }
281
282/// Calls the `MPSGraph` framework counterpart for `for_loop`.
283    #[allow(clippy::too_many_arguments)]
284    pub fn for_loop<Body>(
285        &self,
286        lower_bound: &Tensor,
287        upper_bound: &Tensor,
288        step: &Tensor,
289        initial_body_arguments: &[&Tensor],
290        mut body: Body,
291        name: Option<&str>,
292    ) -> Option<Vec<Tensor>>
293    where
294        Body: FnMut(&Tensor, &[Tensor]) -> Vec<Tensor>,
295    {
296        let name = optional_cstring(name);
297        let argument_handles = initial_body_arguments
298            .iter()
299            .map(|tensor| tensor.as_ptr())
300            .collect::<Vec<_>>();
301        let argument_ptr = if argument_handles.is_empty() {
302            ptr::null()
303        } else {
304            argument_handles.as_ptr()
305        };
306        let mut context = ForBodyCallbackContext {
307            callback: &mut body,
308        };
309        // SAFETY: the callback context and handles remain valid for the duration of the call.
310        let box_handle = unsafe {
311            ffi::mpsgraph_graph_for_loop(
312                self.as_ptr(),
313                lower_bound.as_ptr(),
314                upper_bound.as_ptr(),
315                step.as_ptr(),
316                argument_ptr,
317                argument_handles.len(),
318                Some(for_body_trampoline::<Body>),
319                ptr::from_mut(&mut context).cast(),
320                cstring_ptr(&name),
321            )
322        };
323        wrap_tensor_array(box_handle)
324    }
325
326/// Calls the `MPSGraph` framework counterpart for `for_loop_iterations`.
327    pub fn for_loop_iterations<Body>(
328        &self,
329        number_of_iterations: &Tensor,
330        initial_body_arguments: &[&Tensor],
331        mut body: Body,
332        name: Option<&str>,
333    ) -> Option<Vec<Tensor>>
334    where
335        Body: FnMut(&Tensor, &[Tensor]) -> Vec<Tensor>,
336    {
337        let name = optional_cstring(name);
338        let argument_handles = initial_body_arguments
339            .iter()
340            .map(|tensor| tensor.as_ptr())
341            .collect::<Vec<_>>();
342        let argument_ptr = if argument_handles.is_empty() {
343            ptr::null()
344        } else {
345            argument_handles.as_ptr()
346        };
347        let mut context = ForBodyCallbackContext {
348            callback: &mut body,
349        };
350        // SAFETY: the callback context and handles remain valid for the duration of the call.
351        let box_handle = unsafe {
352            ffi::mpsgraph_graph_for_loop_iterations(
353                self.as_ptr(),
354                number_of_iterations.as_ptr(),
355                argument_ptr,
356                argument_handles.len(),
357                Some(for_body_trampoline::<Body>),
358                ptr::from_mut(&mut context).cast(),
359                cstring_ptr(&name),
360            )
361        };
362        wrap_tensor_array(box_handle)
363    }
364}