Skip to main content

apple_mpsgraph/
rnn.rs

1use crate::error::{Error, Result};
2use crate::ffi;
3use crate::graph::Tensor;
4use crate::types::collect_owned_tensors;
5use core::ffi::{c_char, c_void};
6use core::ptr;
7use std::ffi::CString;
8
/// Converts an optional label into an owned, NUL-terminated C string.
///
/// Returns `None` when no label is given, and also when the label contains an
/// interior NUL byte: such a string cannot be represented as a C string, so the
/// label is dropped instead of being reported as an error.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let label = name?;
    CString::new(label).ok()
}
12
/// Borrows a raw `const char*` from an optional C string, or null when absent.
///
/// The pointer is only valid while `value` is alive and unmodified, so callers
/// must keep the `Option<CString>` bound to a local for the whole FFI call.
// Takes `&Option<CString>` (not `Option<&CString>`) so call sites can simply
// pass `&name`; that signature is what `clippy::ref_option` objects to.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(text) => text.as_ptr(),
        None => ptr::null(),
    }
}
17
18fn optional_tensor_ptr(tensor: Option<&Tensor>) -> *mut c_void {
19    tensor.map_or(ptr::null_mut(), Tensor::as_ptr)
20}
21
22fn wrap_tensor_array(box_handle: *mut c_void) -> Option<Vec<Tensor>> {
23    if box_handle.is_null() {
24        None
25    } else {
26        Some(collect_owned_tensors(box_handle))
27    }
28}
29
/// Raw integer values of the `MPSGraph` `MPSGraphRNNActivation` enumeration.
///
/// The descriptor activation setters in this module take a plain `usize`; these
/// constants give the accepted values readable names.
pub mod rnn_activation {
    /// No activation function (`MPSGraphRNNActivation` value 0).
    pub const NONE: usize = 0;
    /// Rectified linear unit (`MPSGraphRNNActivation` value 1).
    pub const RELU: usize = 1;
    /// Hyperbolic tangent (`MPSGraphRNNActivation` value 2).
    pub const TANH: usize = 2;
    /// Logistic sigmoid (`MPSGraphRNNActivation` value 3).
    pub const SIGMOID: usize = 3;
    /// Piecewise-linear "hard" sigmoid (`MPSGraphRNNActivation` value 4).
    pub const HARD_SIGMOID: usize = 4;
}
43
// Declares an owning wrapper around a retained `MPSGraph` descriptor object.
//
// The generated type holds one +1 object reference in `ptr`, releases it at most
// once when dropped, and exposes the raw pointer to the crate through `as_ptr`.
macro_rules! descriptor_handle {
    ($name:ident) => {
        /// Owning handle to the `MPSGraph` descriptor object of the same name.
        pub struct $name {
            ptr: *mut c_void,
        }

        // SAFETY: NOTE(review) — these impls assert that the wrapped descriptor may
        // be moved to and shared between threads; nothing visible here justifies
        // that. The property setters on these types take `&self`, so `Sync` also
        // permits concurrent mutation — confirm the native object tolerates it.
        unsafe impl Send for $name {}
        unsafe impl Sync for $name {}

        impl Drop for $name {
            fn drop(&mut self) {
                // Detach the pointer first so this handle can never release twice.
                let owned = self.ptr;
                self.ptr = ptr::null_mut();
                if owned.is_null() {
                    return;
                }
                // SAFETY: `owned` is the +1 retained object reference this wrapper
                // held; it has just been detached, so it is released exactly once.
                unsafe { ffi::mpsgraph_object_release(owned) };
            }
        }

        impl $name {
            /// Raw descriptor pointer for FFI calls; ownership stays with `self`.
            #[must_use]
            pub(crate) const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }
    };
}
72
// Expands to a `bool` property accessor pair for a descriptor wrapper: a
// `#[must_use]` getter forwarding to `ffi::$ffi_get`, and a setter forwarding to
// `ffi::$ffi_set` that turns a `false` return into `Error::OperationFailed($msg)`.
// Must be invoked inside an `impl` of a type with a `ptr: *mut c_void` field.
macro_rules! bool_getter_setter {
    ($getter:ident, $setter:ident, $ffi_get:ident, $ffi_set:ident, $msg:literal) => {
        /// Reads this boolean property from the underlying `MPSGraph` descriptor.
        #[must_use]
        pub fn $getter(&self) -> bool {
            let handle = self.ptr;
            // SAFETY: `handle` is the live descriptor pointer owned by `self`.
            unsafe { ffi::$ffi_get(handle) }
        }

        /// Writes this boolean property on the underlying `MPSGraph` descriptor.
        ///
        /// # Errors
        ///
        /// Returns `Error::OperationFailed` when the native setter reports failure.
        pub fn $setter(&self, value: bool) -> Result<()> {
            let handle = self.ptr;
            // SAFETY: `handle` is the live descriptor pointer owned by `self`.
            let accepted = unsafe { ffi::$ffi_set(handle, value) };
            if !accepted {
                return Err(Error::OperationFailed($msg));
            }
            Ok(())
        }
    };
}
94
// Expands to a `usize` activation accessor pair for a descriptor wrapper: a
// `#[must_use]` getter forwarding to `ffi::$ffi_get`, and a setter forwarding to
// `ffi::$ffi_set` that turns a `false` return into `Error::OperationFailed($msg)`.
// Values are raw `MPSGraphRNNActivation` integers (see `rnn_activation`).
// Must be invoked inside an `impl` of a type with a `ptr: *mut c_void` field.
macro_rules! activation_getter_setter {
    ($getter:ident, $setter:ident, $ffi_get:ident, $ffi_set:ident, $msg:literal) => {
        /// Reads this activation (a raw `rnn_activation` value) from the descriptor.
        #[must_use]
        pub fn $getter(&self) -> usize {
            let handle = self.ptr;
            // SAFETY: `handle` is the live descriptor pointer owned by `self`.
            unsafe { ffi::$ffi_get(handle) }
        }

        /// Writes this activation (a raw `rnn_activation` value) on the descriptor.
        ///
        /// # Errors
        ///
        /// Returns `Error::OperationFailed` when the native setter reports failure.
        pub fn $setter(&self, value: usize) -> Result<()> {
            let handle = self.ptr;
            // SAFETY: `handle` is the live descriptor pointer owned by `self`.
            let accepted = unsafe { ffi::$ffi_set(handle, value) };
            if !accepted {
                return Err(Error::OperationFailed($msg));
            }
            Ok(())
        }
    };
}
116
117descriptor_handle!(SingleGateRNNDescriptor);
118impl SingleGateRNNDescriptor {
119/// Calls the `MPSGraph` framework counterpart for `new`.
120    #[must_use]
121    pub fn new() -> Option<Self> {
122        // SAFETY: pure constructor.
123        let ptr = unsafe { ffi::mpsgraph_single_gate_rnn_descriptor_new() };
124        if ptr.is_null() {
125            None
126        } else {
127            Some(Self { ptr })
128        }
129    }
130
131    bool_getter_setter!(
132        reverse,
133        set_reverse,
134        mpsgraph_single_gate_rnn_descriptor_reverse,
135        mpsgraph_single_gate_rnn_descriptor_set_reverse,
136        "failed to set single-gate RNN reverse"
137    );
138    bool_getter_setter!(
139        bidirectional,
140        set_bidirectional,
141        mpsgraph_single_gate_rnn_descriptor_bidirectional,
142        mpsgraph_single_gate_rnn_descriptor_set_bidirectional,
143        "failed to set single-gate RNN bidirectional"
144    );
145    bool_getter_setter!(
146        training,
147        set_training,
148        mpsgraph_single_gate_rnn_descriptor_training,
149        mpsgraph_single_gate_rnn_descriptor_set_training,
150        "failed to set single-gate RNN training"
151    );
152    activation_getter_setter!(
153        activation,
154        set_activation,
155        mpsgraph_single_gate_rnn_descriptor_activation,
156        mpsgraph_single_gate_rnn_descriptor_set_activation,
157        "failed to set single-gate RNN activation"
158    );
159}
160
161descriptor_handle!(LSTMDescriptor);
162impl LSTMDescriptor {
163/// Calls the `MPSGraph` framework counterpart for `new`.
164    #[must_use]
165    pub fn new() -> Option<Self> {
166        // SAFETY: pure constructor.
167        let ptr = unsafe { ffi::mpsgraph_lstm_descriptor_new() };
168        if ptr.is_null() {
169            None
170        } else {
171            Some(Self { ptr })
172        }
173    }
174
175    bool_getter_setter!(
176        reverse,
177        set_reverse,
178        mpsgraph_lstm_descriptor_reverse,
179        mpsgraph_lstm_descriptor_set_reverse,
180        "failed to set LSTM reverse"
181    );
182    bool_getter_setter!(
183        bidirectional,
184        set_bidirectional,
185        mpsgraph_lstm_descriptor_bidirectional,
186        mpsgraph_lstm_descriptor_set_bidirectional,
187        "failed to set LSTM bidirectional"
188    );
189    bool_getter_setter!(
190        produce_cell,
191        set_produce_cell,
192        mpsgraph_lstm_descriptor_produce_cell,
193        mpsgraph_lstm_descriptor_set_produce_cell,
194        "failed to set LSTM produceCell"
195    );
196    bool_getter_setter!(
197        training,
198        set_training,
199        mpsgraph_lstm_descriptor_training,
200        mpsgraph_lstm_descriptor_set_training,
201        "failed to set LSTM training"
202    );
203    bool_getter_setter!(
204        forget_gate_last,
205        set_forget_gate_last,
206        mpsgraph_lstm_descriptor_forget_gate_last,
207        mpsgraph_lstm_descriptor_set_forget_gate_last,
208        "failed to set LSTM forgetGateLast"
209    );
210    activation_getter_setter!(
211        input_gate_activation,
212        set_input_gate_activation,
213        mpsgraph_lstm_descriptor_input_gate_activation,
214        mpsgraph_lstm_descriptor_set_input_gate_activation,
215        "failed to set LSTM inputGateActivation"
216    );
217    activation_getter_setter!(
218        forget_gate_activation,
219        set_forget_gate_activation,
220        mpsgraph_lstm_descriptor_forget_gate_activation,
221        mpsgraph_lstm_descriptor_set_forget_gate_activation,
222        "failed to set LSTM forgetGateActivation"
223    );
224    activation_getter_setter!(
225        cell_gate_activation,
226        set_cell_gate_activation,
227        mpsgraph_lstm_descriptor_cell_gate_activation,
228        mpsgraph_lstm_descriptor_set_cell_gate_activation,
229        "failed to set LSTM cellGateActivation"
230    );
231    activation_getter_setter!(
232        output_gate_activation,
233        set_output_gate_activation,
234        mpsgraph_lstm_descriptor_output_gate_activation,
235        mpsgraph_lstm_descriptor_set_output_gate_activation,
236        "failed to set LSTM outputGateActivation"
237    );
238    activation_getter_setter!(
239        activation,
240        set_activation,
241        mpsgraph_lstm_descriptor_activation,
242        mpsgraph_lstm_descriptor_set_activation,
243        "failed to set LSTM activation"
244    );
245}
246
247descriptor_handle!(GRUDescriptor);
248impl GRUDescriptor {
249/// Calls the `MPSGraph` framework counterpart for `new`.
250    #[must_use]
251    pub fn new() -> Option<Self> {
252        // SAFETY: pure constructor.
253        let ptr = unsafe { ffi::mpsgraph_gru_descriptor_new() };
254        if ptr.is_null() {
255            None
256        } else {
257            Some(Self { ptr })
258        }
259    }
260
261    bool_getter_setter!(
262        reverse,
263        set_reverse,
264        mpsgraph_gru_descriptor_reverse,
265        mpsgraph_gru_descriptor_set_reverse,
266        "failed to set GRU reverse"
267    );
268    bool_getter_setter!(
269        bidirectional,
270        set_bidirectional,
271        mpsgraph_gru_descriptor_bidirectional,
272        mpsgraph_gru_descriptor_set_bidirectional,
273        "failed to set GRU bidirectional"
274    );
275    bool_getter_setter!(
276        training,
277        set_training,
278        mpsgraph_gru_descriptor_training,
279        mpsgraph_gru_descriptor_set_training,
280        "failed to set GRU training"
281    );
282    bool_getter_setter!(
283        reset_gate_first,
284        set_reset_gate_first,
285        mpsgraph_gru_descriptor_reset_gate_first,
286        mpsgraph_gru_descriptor_set_reset_gate_first,
287        "failed to set GRU resetGateFirst"
288    );
289    bool_getter_setter!(
290        reset_after,
291        set_reset_after,
292        mpsgraph_gru_descriptor_reset_after,
293        mpsgraph_gru_descriptor_set_reset_after,
294        "failed to set GRU resetAfter"
295    );
296    bool_getter_setter!(
297        flip_z,
298        set_flip_z,
299        mpsgraph_gru_descriptor_flip_z,
300        mpsgraph_gru_descriptor_set_flip_z,
301        "failed to set GRU flipZ"
302    );
303    activation_getter_setter!(
304        update_gate_activation,
305        set_update_gate_activation,
306        mpsgraph_gru_descriptor_update_gate_activation,
307        mpsgraph_gru_descriptor_set_update_gate_activation,
308        "failed to set GRU updateGateActivation"
309    );
310    activation_getter_setter!(
311        reset_gate_activation,
312        set_reset_gate_activation,
313        mpsgraph_gru_descriptor_reset_gate_activation,
314        mpsgraph_gru_descriptor_set_reset_gate_activation,
315        "failed to set GRU resetGateActivation"
316    );
317    activation_getter_setter!(
318        output_gate_activation,
319        set_output_gate_activation,
320        mpsgraph_gru_descriptor_output_gate_activation,
321        mpsgraph_gru_descriptor_set_output_gate_activation,
322        "failed to set GRU outputGateActivation"
323    );
324}
325
326impl crate::graph::Graph {
327/// Calls the `MPSGraph` framework counterpart for `single_gate_rnn`.
328    #[allow(clippy::too_many_arguments)]
329    pub fn single_gate_rnn(
330        &self,
331        source: &Tensor,
332        recurrent_weight: &Tensor,
333        input_weight: Option<&Tensor>,
334        bias: Option<&Tensor>,
335        init_state: Option<&Tensor>,
336        mask: Option<&Tensor>,
337        descriptor: &SingleGateRNNDescriptor,
338        name: Option<&str>,
339    ) -> Option<Vec<Tensor>> {
340        let name = optional_cstring(name);
341        // SAFETY: all handles remain valid for the duration of the call.
342        let box_handle = unsafe {
343            ffi::mpsgraph_graph_single_gate_rnn(
344                self.as_ptr(),
345                source.as_ptr(),
346                recurrent_weight.as_ptr(),
347                optional_tensor_ptr(input_weight),
348                optional_tensor_ptr(bias),
349                optional_tensor_ptr(init_state),
350                optional_tensor_ptr(mask),
351                descriptor.as_ptr(),
352                cstring_ptr(&name),
353            )
354        };
355        wrap_tensor_array(box_handle)
356    }
357
358/// Calls the `MPSGraph` framework counterpart for `lstm`.
359    #[allow(clippy::too_many_arguments)]
360    pub fn lstm(
361        &self,
362        source: &Tensor,
363        recurrent_weight: &Tensor,
364        input_weight: Option<&Tensor>,
365        bias: Option<&Tensor>,
366        init_state: Option<&Tensor>,
367        init_cell: Option<&Tensor>,
368        mask: Option<&Tensor>,
369        peephole: Option<&Tensor>,
370        descriptor: &LSTMDescriptor,
371        name: Option<&str>,
372    ) -> Option<Vec<Tensor>> {
373        let name = optional_cstring(name);
374        // SAFETY: all handles remain valid for the duration of the call.
375        let box_handle = unsafe {
376            ffi::mpsgraph_graph_lstm(
377                self.as_ptr(),
378                source.as_ptr(),
379                recurrent_weight.as_ptr(),
380                optional_tensor_ptr(input_weight),
381                optional_tensor_ptr(bias),
382                optional_tensor_ptr(init_state),
383                optional_tensor_ptr(init_cell),
384                optional_tensor_ptr(mask),
385                optional_tensor_ptr(peephole),
386                descriptor.as_ptr(),
387                cstring_ptr(&name),
388            )
389        };
390        wrap_tensor_array(box_handle)
391    }
392
393/// Calls the `MPSGraph` framework counterpart for `gru`.
394    #[allow(clippy::too_many_arguments)]
395    pub fn gru(
396        &self,
397        source: &Tensor,
398        recurrent_weight: &Tensor,
399        input_weight: Option<&Tensor>,
400        bias: Option<&Tensor>,
401        init_state: Option<&Tensor>,
402        mask: Option<&Tensor>,
403        secondary_bias: Option<&Tensor>,
404        descriptor: &GRUDescriptor,
405        name: Option<&str>,
406    ) -> Option<Vec<Tensor>> {
407        let name = optional_cstring(name);
408        // SAFETY: all handles remain valid for the duration of the call.
409        let box_handle = unsafe {
410            ffi::mpsgraph_graph_gru(
411                self.as_ptr(),
412                source.as_ptr(),
413                recurrent_weight.as_ptr(),
414                optional_tensor_ptr(input_weight),
415                optional_tensor_ptr(bias),
416                optional_tensor_ptr(init_state),
417                optional_tensor_ptr(mask),
418                optional_tensor_ptr(secondary_bias),
419                descriptor.as_ptr(),
420                cstring_ptr(&name),
421            )
422        };
423        wrap_tensor_array(box_handle)
424    }
425}