Skip to main content

apple_mpsgraph/
rnn.rs

1use crate::error::{Error, Result};
2use crate::ffi;
3use crate::graph::Tensor;
4use crate::types::collect_owned_tensors;
5use core::ffi::{c_char, c_void};
6use core::ptr;
7use std::ffi::CString;
8
/// Converts an optional operation name into an owned C string.
///
/// Returns `None` when no name is given, and also when the name contains an
/// interior NUL byte (such a name cannot be represented as a C string and is
/// dropped rather than reported).
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let text = name?;
    CString::new(text).ok()
}
12
/// Borrows a raw `const char*` from an optional owned C string.
///
/// Yields a null pointer for `None`. The returned pointer is only valid while
/// the referenced `CString` is alive.
// Taking `&Option<CString>` (instead of `Option<&CString>`) lets callers pass
// `&name` for the owned local that must outlive the FFI call.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(owned) => owned.as_ptr(),
        None => ptr::null(),
    }
}
17
18fn optional_tensor_ptr(tensor: Option<&Tensor>) -> *mut c_void {
19    tensor.map_or(ptr::null_mut(), Tensor::as_ptr)
20}
21
22fn wrap_tensor_array(box_handle: *mut c_void) -> Option<Vec<Tensor>> {
23    if box_handle.is_null() {
24        None
25    } else {
26        Some(collect_owned_tensors(box_handle))
27    }
28}
29
/// `MPSGraphRNNActivation` constants.
///
/// Raw activation values accepted by the `usize` activation setters on the
/// RNN descriptor types in this module.
pub mod rnn_activation {
    /// No activation.
    pub const NONE: usize = 0;
    /// Rectified linear unit.
    pub const RELU: usize = 1;
    /// Hyperbolic tangent.
    pub const TANH: usize = 2;
    /// Logistic sigmoid.
    pub const SIGMOID: usize = 3;
    /// Hard (piecewise-linear) sigmoid.
    pub const HARD_SIGMOID: usize = 4;
}
38
// Declares an owning wrapper type around a retained native descriptor pointer.
//
// The generated type releases its pointer through `ffi::mpsgraph_object_release`
// when dropped and exposes the raw pointer to the crate via `as_ptr`.
macro_rules! descriptor_handle {
    ($name:ident) => {
        #[doc = concat!("Owned handle to a native `", stringify!($name), "` object.")]
        pub struct $name {
            ptr: *mut c_void,
        }

        // SAFETY (assumed): the wrapper only forwards its pointer to the FFI layer,
        // which is taken to accept the descriptor from any thread.
        // NOTE(review): the generated property setters take `&self`, so `Sync` allows
        // concurrent mutation of one descriptor — confirm the native setters are
        // thread-safe.
        unsafe impl Send for $name {}
        unsafe impl Sync for $name {}

        impl Drop for $name {
            fn drop(&mut self) {
                // Take the pointer out of the wrapper before releasing it.
                let handle = self.ptr;
                self.ptr = ptr::null_mut();
                if handle.is_null() {
                    return;
                }
                // SAFETY: `handle` is a +1 retained Swift/ObjC object pointer owned by this wrapper.
                unsafe { ffi::mpsgraph_object_release(handle) };
            }
        }

        impl $name {
            /// Returns the wrapped native pointer without transferring ownership.
            #[must_use]
            pub(crate) const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }
    };
}
66
// Generates a getter/setter pair for one descriptor property of type `$ty`.
//
// The getter forwards to `ffi::$ffi_get`. The setter forwards to `ffi::$ffi_set`
// and maps a `false` return to `Error::OperationFailed($msg)`. It expands inside
// an `impl` block of a type produced by `descriptor_handle!`, where `self.ptr` is
// the wrapped native descriptor.
macro_rules! descriptor_property {
    ($ty:ty, $getter:ident, $setter:ident, $ffi_get:ident, $ffi_set:ident, $msg:literal) => {
        #[doc = concat!("Returns this descriptor's `", stringify!($getter), "` value.")]
        #[must_use]
        pub fn $getter(&self) -> $ty {
            // SAFETY: `self.ptr` is a live descriptor handle.
            unsafe { ffi::$ffi_get(self.ptr) }
        }

        #[doc = concat!("Updates this descriptor's `", stringify!($getter), "` value.")]
        ///
        /// # Errors
        ///
        /// Returns `Error::OperationFailed` if the native setter reports failure.
        pub fn $setter(&self, value: $ty) -> Result<()> {
            // SAFETY: `self.ptr` is a live descriptor handle.
            let ok = unsafe { ffi::$ffi_set(self.ptr, value) };
            if ok {
                Ok(())
            } else {
                Err(Error::OperationFailed($msg))
            }
        }
    };
}

// Boolean descriptor property (e.g. `reverse`, `training`).
macro_rules! bool_getter_setter {
    ($getter:ident, $setter:ident, $ffi_get:ident, $ffi_set:ident, $msg:literal) => {
        descriptor_property!(bool, $getter, $setter, $ffi_get, $ffi_set, $msg);
    };
}

// Activation descriptor property, passed through as a raw `usize`
// (presumably one of the `rnn_activation` constants).
macro_rules! activation_getter_setter {
    ($getter:ident, $setter:ident, $ffi_get:ident, $ffi_set:ident, $msg:literal) => {
        descriptor_property!(usize, $getter, $setter, $ffi_get, $ffi_set, $msg);
    };
}
106
107descriptor_handle!(SingleGateRNNDescriptor);
108impl SingleGateRNNDescriptor {
109    #[must_use]
110    pub fn new() -> Option<Self> {
111        // SAFETY: pure constructor.
112        let ptr = unsafe { ffi::mpsgraph_single_gate_rnn_descriptor_new() };
113        if ptr.is_null() {
114            None
115        } else {
116            Some(Self { ptr })
117        }
118    }
119
120    bool_getter_setter!(
121        reverse,
122        set_reverse,
123        mpsgraph_single_gate_rnn_descriptor_reverse,
124        mpsgraph_single_gate_rnn_descriptor_set_reverse,
125        "failed to set single-gate RNN reverse"
126    );
127    bool_getter_setter!(
128        bidirectional,
129        set_bidirectional,
130        mpsgraph_single_gate_rnn_descriptor_bidirectional,
131        mpsgraph_single_gate_rnn_descriptor_set_bidirectional,
132        "failed to set single-gate RNN bidirectional"
133    );
134    bool_getter_setter!(
135        training,
136        set_training,
137        mpsgraph_single_gate_rnn_descriptor_training,
138        mpsgraph_single_gate_rnn_descriptor_set_training,
139        "failed to set single-gate RNN training"
140    );
141    activation_getter_setter!(
142        activation,
143        set_activation,
144        mpsgraph_single_gate_rnn_descriptor_activation,
145        mpsgraph_single_gate_rnn_descriptor_set_activation,
146        "failed to set single-gate RNN activation"
147    );
148}
149
150descriptor_handle!(LSTMDescriptor);
151impl LSTMDescriptor {
152    #[must_use]
153    pub fn new() -> Option<Self> {
154        // SAFETY: pure constructor.
155        let ptr = unsafe { ffi::mpsgraph_lstm_descriptor_new() };
156        if ptr.is_null() {
157            None
158        } else {
159            Some(Self { ptr })
160        }
161    }
162
163    bool_getter_setter!(
164        reverse,
165        set_reverse,
166        mpsgraph_lstm_descriptor_reverse,
167        mpsgraph_lstm_descriptor_set_reverse,
168        "failed to set LSTM reverse"
169    );
170    bool_getter_setter!(
171        bidirectional,
172        set_bidirectional,
173        mpsgraph_lstm_descriptor_bidirectional,
174        mpsgraph_lstm_descriptor_set_bidirectional,
175        "failed to set LSTM bidirectional"
176    );
177    bool_getter_setter!(
178        produce_cell,
179        set_produce_cell,
180        mpsgraph_lstm_descriptor_produce_cell,
181        mpsgraph_lstm_descriptor_set_produce_cell,
182        "failed to set LSTM produceCell"
183    );
184    bool_getter_setter!(
185        training,
186        set_training,
187        mpsgraph_lstm_descriptor_training,
188        mpsgraph_lstm_descriptor_set_training,
189        "failed to set LSTM training"
190    );
191    bool_getter_setter!(
192        forget_gate_last,
193        set_forget_gate_last,
194        mpsgraph_lstm_descriptor_forget_gate_last,
195        mpsgraph_lstm_descriptor_set_forget_gate_last,
196        "failed to set LSTM forgetGateLast"
197    );
198    activation_getter_setter!(
199        input_gate_activation,
200        set_input_gate_activation,
201        mpsgraph_lstm_descriptor_input_gate_activation,
202        mpsgraph_lstm_descriptor_set_input_gate_activation,
203        "failed to set LSTM inputGateActivation"
204    );
205    activation_getter_setter!(
206        forget_gate_activation,
207        set_forget_gate_activation,
208        mpsgraph_lstm_descriptor_forget_gate_activation,
209        mpsgraph_lstm_descriptor_set_forget_gate_activation,
210        "failed to set LSTM forgetGateActivation"
211    );
212    activation_getter_setter!(
213        cell_gate_activation,
214        set_cell_gate_activation,
215        mpsgraph_lstm_descriptor_cell_gate_activation,
216        mpsgraph_lstm_descriptor_set_cell_gate_activation,
217        "failed to set LSTM cellGateActivation"
218    );
219    activation_getter_setter!(
220        output_gate_activation,
221        set_output_gate_activation,
222        mpsgraph_lstm_descriptor_output_gate_activation,
223        mpsgraph_lstm_descriptor_set_output_gate_activation,
224        "failed to set LSTM outputGateActivation"
225    );
226    activation_getter_setter!(
227        activation,
228        set_activation,
229        mpsgraph_lstm_descriptor_activation,
230        mpsgraph_lstm_descriptor_set_activation,
231        "failed to set LSTM activation"
232    );
233}
234
235descriptor_handle!(GRUDescriptor);
236impl GRUDescriptor {
237    #[must_use]
238    pub fn new() -> Option<Self> {
239        // SAFETY: pure constructor.
240        let ptr = unsafe { ffi::mpsgraph_gru_descriptor_new() };
241        if ptr.is_null() {
242            None
243        } else {
244            Some(Self { ptr })
245        }
246    }
247
248    bool_getter_setter!(
249        reverse,
250        set_reverse,
251        mpsgraph_gru_descriptor_reverse,
252        mpsgraph_gru_descriptor_set_reverse,
253        "failed to set GRU reverse"
254    );
255    bool_getter_setter!(
256        bidirectional,
257        set_bidirectional,
258        mpsgraph_gru_descriptor_bidirectional,
259        mpsgraph_gru_descriptor_set_bidirectional,
260        "failed to set GRU bidirectional"
261    );
262    bool_getter_setter!(
263        training,
264        set_training,
265        mpsgraph_gru_descriptor_training,
266        mpsgraph_gru_descriptor_set_training,
267        "failed to set GRU training"
268    );
269    bool_getter_setter!(
270        reset_gate_first,
271        set_reset_gate_first,
272        mpsgraph_gru_descriptor_reset_gate_first,
273        mpsgraph_gru_descriptor_set_reset_gate_first,
274        "failed to set GRU resetGateFirst"
275    );
276    bool_getter_setter!(
277        reset_after,
278        set_reset_after,
279        mpsgraph_gru_descriptor_reset_after,
280        mpsgraph_gru_descriptor_set_reset_after,
281        "failed to set GRU resetAfter"
282    );
283    bool_getter_setter!(
284        flip_z,
285        set_flip_z,
286        mpsgraph_gru_descriptor_flip_z,
287        mpsgraph_gru_descriptor_set_flip_z,
288        "failed to set GRU flipZ"
289    );
290    activation_getter_setter!(
291        update_gate_activation,
292        set_update_gate_activation,
293        mpsgraph_gru_descriptor_update_gate_activation,
294        mpsgraph_gru_descriptor_set_update_gate_activation,
295        "failed to set GRU updateGateActivation"
296    );
297    activation_getter_setter!(
298        reset_gate_activation,
299        set_reset_gate_activation,
300        mpsgraph_gru_descriptor_reset_gate_activation,
301        mpsgraph_gru_descriptor_set_reset_gate_activation,
302        "failed to set GRU resetGateActivation"
303    );
304    activation_getter_setter!(
305        output_gate_activation,
306        set_output_gate_activation,
307        mpsgraph_gru_descriptor_output_gate_activation,
308        mpsgraph_gru_descriptor_set_output_gate_activation,
309        "failed to set GRU outputGateActivation"
310    );
311}
312
313impl crate::graph::Graph {
314    #[allow(clippy::too_many_arguments)]
315    pub fn single_gate_rnn(
316        &self,
317        source: &Tensor,
318        recurrent_weight: &Tensor,
319        input_weight: Option<&Tensor>,
320        bias: Option<&Tensor>,
321        init_state: Option<&Tensor>,
322        mask: Option<&Tensor>,
323        descriptor: &SingleGateRNNDescriptor,
324        name: Option<&str>,
325    ) -> Option<Vec<Tensor>> {
326        let name = optional_cstring(name);
327        // SAFETY: all handles remain valid for the duration of the call.
328        let box_handle = unsafe {
329            ffi::mpsgraph_graph_single_gate_rnn(
330                self.as_ptr(),
331                source.as_ptr(),
332                recurrent_weight.as_ptr(),
333                optional_tensor_ptr(input_weight),
334                optional_tensor_ptr(bias),
335                optional_tensor_ptr(init_state),
336                optional_tensor_ptr(mask),
337                descriptor.as_ptr(),
338                cstring_ptr(&name),
339            )
340        };
341        wrap_tensor_array(box_handle)
342    }
343
344    #[allow(clippy::too_many_arguments)]
345    pub fn lstm(
346        &self,
347        source: &Tensor,
348        recurrent_weight: &Tensor,
349        input_weight: Option<&Tensor>,
350        bias: Option<&Tensor>,
351        init_state: Option<&Tensor>,
352        init_cell: Option<&Tensor>,
353        mask: Option<&Tensor>,
354        peephole: Option<&Tensor>,
355        descriptor: &LSTMDescriptor,
356        name: Option<&str>,
357    ) -> Option<Vec<Tensor>> {
358        let name = optional_cstring(name);
359        // SAFETY: all handles remain valid for the duration of the call.
360        let box_handle = unsafe {
361            ffi::mpsgraph_graph_lstm(
362                self.as_ptr(),
363                source.as_ptr(),
364                recurrent_weight.as_ptr(),
365                optional_tensor_ptr(input_weight),
366                optional_tensor_ptr(bias),
367                optional_tensor_ptr(init_state),
368                optional_tensor_ptr(init_cell),
369                optional_tensor_ptr(mask),
370                optional_tensor_ptr(peephole),
371                descriptor.as_ptr(),
372                cstring_ptr(&name),
373            )
374        };
375        wrap_tensor_array(box_handle)
376    }
377
378    #[allow(clippy::too_many_arguments)]
379    pub fn gru(
380        &self,
381        source: &Tensor,
382        recurrent_weight: &Tensor,
383        input_weight: Option<&Tensor>,
384        bias: Option<&Tensor>,
385        init_state: Option<&Tensor>,
386        mask: Option<&Tensor>,
387        secondary_bias: Option<&Tensor>,
388        descriptor: &GRUDescriptor,
389        name: Option<&str>,
390    ) -> Option<Vec<Tensor>> {
391        let name = optional_cstring(name);
392        // SAFETY: all handles remain valid for the duration of the call.
393        let box_handle = unsafe {
394            ffi::mpsgraph_graph_gru(
395                self.as_ptr(),
396                source.as_ptr(),
397                recurrent_weight.as_ptr(),
398                optional_tensor_ptr(input_weight),
399                optional_tensor_ptr(bias),
400                optional_tensor_ptr(init_state),
401                optional_tensor_ptr(mask),
402                optional_tensor_ptr(secondary_bias),
403                descriptor.as_ptr(),
404                cstring_ptr(&name),
405            )
406        };
407        wrap_tensor_array(box_handle)
408    }
409}