Skip to main content

apple_mpsgraph/
rnn.rs

1use crate::error::{Error, Result};
2use crate::ffi;
3use crate::graph::Tensor;
4use crate::types::collect_owned_tensors;
5use core::ffi::{c_char, c_void};
6use core::ptr;
7use std::ffi::CString;
8
/// Converts an optional name into an owned, NUL-terminated C string.
///
/// Returns `None` when no name was supplied. A name that contains an interior
/// NUL byte is cut at the first NUL rather than being dropped entirely, so the
/// caller's label still reaches the FFI layer (a C consumer would stop reading
/// at that byte anyway).
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let value = name?;
    // Keep only the part a C string can represent.
    let visible = value.split_once('\0').map_or(value, |(head, _)| head);
    // `visible` holds no NUL, so this conversion is not expected to fail;
    // `.ok()` keeps the helper panic-free regardless.
    CString::new(visible).ok()
}
12
/// Borrows the raw pointer of an optional C string, or null when it is `None`.
///
/// The returned pointer is only valid while `value` is alive and unmodified.
// Takes `&Option<CString>` (not `Option<&CString>`) so call sites can pass the
// owned option by reference directly; hence the lint exemption.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(owned) => owned.as_ptr(),
        None => ptr::null(),
    }
}
17
18fn optional_tensor_ptr(tensor: Option<&Tensor>) -> *mut c_void {
19    tensor.map_or(ptr::null_mut(), Tensor::as_ptr)
20}
21
22fn wrap_tensor_array(box_handle: *mut c_void) -> Option<Vec<Tensor>> {
23    if box_handle.is_null() {
24        None
25    } else {
26        Some(collect_owned_tensors(box_handle))
27    }
28}
29
/// `MPSGraphRNNActivation` constants, exposed as plain `usize` values.
pub mod rnn_activation {
    /// Raw value `0`: no activation.
    pub const NONE: usize = 0;
    /// Raw value `1`: ReLU.
    pub const RELU: usize = 1;
    /// Raw value `2`: tanh.
    pub const TANH: usize = 2;
    /// Raw value `3`: sigmoid.
    pub const SIGMOID: usize = 3;
    /// Raw value `4`: hard sigmoid.
    pub const HARD_SIGMOID: usize = 4;
}
38
// Declares a public owning wrapper type around a raw descriptor pointer.
//
// The generated type:
// - stores the pointer in a private `ptr` field,
// - releases it through `ffi::mpsgraph_object_release` on drop (null is skipped),
// - is marked `Send` and `Sync`,
// - exposes the raw pointer crate-internally through `as_ptr`.
macro_rules! descriptor_handle {
    ($handle:ident) => {
        /// Owning handle for a retained descriptor object; released on drop.
        pub struct $handle {
            ptr: *mut c_void,
        }

        impl $handle {
            /// Returns the wrapped raw pointer without transferring ownership.
            #[must_use]
            pub(crate) const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }

        impl Drop for $handle {
            fn drop(&mut self) {
                if self.ptr.is_null() {
                    return;
                }
                // SAFETY: `ptr` is a +1 retained Swift/ObjC object pointer owned by this wrapper.
                unsafe { ffi::mpsgraph_object_release(self.ptr) };
                // Clear the field so the released pointer cannot be observed again.
                self.ptr = ptr::null_mut();
            }
        }

        // NOTE(review): thread-safety is asserted here, not derived — confirm the
        // underlying objects tolerate use from several threads.
        unsafe impl Send for $handle {}
        unsafe impl Sync for $handle {}
    };
}
66
// Generates a `bool` property pair on a descriptor wrapper:
// a getter that forwards to `ffi::<read fn>(self.ptr)` and a setter that
// forwards to `ffi::<write fn>(self.ptr, value)`, mapping a `false` status to
// `Error::OperationFailed(<message>)`.
macro_rules! bool_getter_setter {
    ($get:ident, $set:ident, $ffi_read:ident, $ffi_write:ident, $failure:literal) => {
        /// Returns the current flag value as reported by the FFI getter.
        #[must_use]
        pub fn $get(&self) -> bool {
            // SAFETY: `self.ptr` is a live descriptor handle.
            unsafe { ffi::$ffi_read(self.ptr) }
        }

        /// Writes the flag through the FFI setter.
        ///
        /// # Errors
        ///
        /// Returns `Error::OperationFailed` when the FFI setter reports failure.
        pub fn $set(&self, value: bool) -> Result<()> {
            // SAFETY: `self.ptr` is a live descriptor handle.
            let accepted = unsafe { ffi::$ffi_write(self.ptr, value) };
            accepted
                .then_some(())
                .ok_or(Error::OperationFailed($failure))
        }
    };
}
86
// Generates a `usize` property pair on a descriptor wrapper:
// a getter that forwards to `ffi::<read fn>(self.ptr)` and a setter that
// forwards to `ffi::<write fn>(self.ptr, value)`, mapping a `false` status to
// `Error::OperationFailed(<message>)`. The value is passed through unchecked.
macro_rules! activation_getter_setter {
    ($get:ident, $set:ident, $ffi_read:ident, $ffi_write:ident, $failure:literal) => {
        /// Returns the current raw value as reported by the FFI getter.
        #[must_use]
        pub fn $get(&self) -> usize {
            // SAFETY: `self.ptr` is a live descriptor handle.
            unsafe { ffi::$ffi_read(self.ptr) }
        }

        /// Writes the raw value through the FFI setter.
        ///
        /// # Errors
        ///
        /// Returns `Error::OperationFailed` when the FFI setter reports failure.
        pub fn $set(&self, value: usize) -> Result<()> {
            // SAFETY: `self.ptr` is a live descriptor handle.
            let accepted = unsafe { ffi::$ffi_write(self.ptr, value) };
            accepted
                .then_some(())
                .ok_or(Error::OperationFailed($failure))
        }
    };
}
106
107descriptor_handle!(SingleGateRNNDescriptor);
108impl SingleGateRNNDescriptor {
109    #[must_use]
110    pub fn new() -> Option<Self> {
111        // SAFETY: pure constructor.
112        let ptr = unsafe { ffi::mpsgraph_single_gate_rnn_descriptor_new() };
113        if ptr.is_null() {
114            None
115        } else {
116            Some(Self { ptr })
117        }
118    }
119
120    bool_getter_setter!(reverse, set_reverse, mpsgraph_single_gate_rnn_descriptor_reverse, mpsgraph_single_gate_rnn_descriptor_set_reverse, "failed to set single-gate RNN reverse");
121    bool_getter_setter!(bidirectional, set_bidirectional, mpsgraph_single_gate_rnn_descriptor_bidirectional, mpsgraph_single_gate_rnn_descriptor_set_bidirectional, "failed to set single-gate RNN bidirectional");
122    bool_getter_setter!(training, set_training, mpsgraph_single_gate_rnn_descriptor_training, mpsgraph_single_gate_rnn_descriptor_set_training, "failed to set single-gate RNN training");
123    activation_getter_setter!(activation, set_activation, mpsgraph_single_gate_rnn_descriptor_activation, mpsgraph_single_gate_rnn_descriptor_set_activation, "failed to set single-gate RNN activation");
124}
125
126descriptor_handle!(LSTMDescriptor);
127impl LSTMDescriptor {
128    #[must_use]
129    pub fn new() -> Option<Self> {
130        // SAFETY: pure constructor.
131        let ptr = unsafe { ffi::mpsgraph_lstm_descriptor_new() };
132        if ptr.is_null() {
133            None
134        } else {
135            Some(Self { ptr })
136        }
137    }
138
139    bool_getter_setter!(reverse, set_reverse, mpsgraph_lstm_descriptor_reverse, mpsgraph_lstm_descriptor_set_reverse, "failed to set LSTM reverse");
140    bool_getter_setter!(bidirectional, set_bidirectional, mpsgraph_lstm_descriptor_bidirectional, mpsgraph_lstm_descriptor_set_bidirectional, "failed to set LSTM bidirectional");
141    bool_getter_setter!(produce_cell, set_produce_cell, mpsgraph_lstm_descriptor_produce_cell, mpsgraph_lstm_descriptor_set_produce_cell, "failed to set LSTM produceCell");
142    bool_getter_setter!(training, set_training, mpsgraph_lstm_descriptor_training, mpsgraph_lstm_descriptor_set_training, "failed to set LSTM training");
143    bool_getter_setter!(forget_gate_last, set_forget_gate_last, mpsgraph_lstm_descriptor_forget_gate_last, mpsgraph_lstm_descriptor_set_forget_gate_last, "failed to set LSTM forgetGateLast");
144    activation_getter_setter!(input_gate_activation, set_input_gate_activation, mpsgraph_lstm_descriptor_input_gate_activation, mpsgraph_lstm_descriptor_set_input_gate_activation, "failed to set LSTM inputGateActivation");
145    activation_getter_setter!(forget_gate_activation, set_forget_gate_activation, mpsgraph_lstm_descriptor_forget_gate_activation, mpsgraph_lstm_descriptor_set_forget_gate_activation, "failed to set LSTM forgetGateActivation");
146    activation_getter_setter!(cell_gate_activation, set_cell_gate_activation, mpsgraph_lstm_descriptor_cell_gate_activation, mpsgraph_lstm_descriptor_set_cell_gate_activation, "failed to set LSTM cellGateActivation");
147    activation_getter_setter!(output_gate_activation, set_output_gate_activation, mpsgraph_lstm_descriptor_output_gate_activation, mpsgraph_lstm_descriptor_set_output_gate_activation, "failed to set LSTM outputGateActivation");
148    activation_getter_setter!(activation, set_activation, mpsgraph_lstm_descriptor_activation, mpsgraph_lstm_descriptor_set_activation, "failed to set LSTM activation");
149}
150
151descriptor_handle!(GRUDescriptor);
152impl GRUDescriptor {
153    #[must_use]
154    pub fn new() -> Option<Self> {
155        // SAFETY: pure constructor.
156        let ptr = unsafe { ffi::mpsgraph_gru_descriptor_new() };
157        if ptr.is_null() {
158            None
159        } else {
160            Some(Self { ptr })
161        }
162    }
163
164    bool_getter_setter!(reverse, set_reverse, mpsgraph_gru_descriptor_reverse, mpsgraph_gru_descriptor_set_reverse, "failed to set GRU reverse");
165    bool_getter_setter!(bidirectional, set_bidirectional, mpsgraph_gru_descriptor_bidirectional, mpsgraph_gru_descriptor_set_bidirectional, "failed to set GRU bidirectional");
166    bool_getter_setter!(training, set_training, mpsgraph_gru_descriptor_training, mpsgraph_gru_descriptor_set_training, "failed to set GRU training");
167    bool_getter_setter!(reset_gate_first, set_reset_gate_first, mpsgraph_gru_descriptor_reset_gate_first, mpsgraph_gru_descriptor_set_reset_gate_first, "failed to set GRU resetGateFirst");
168    bool_getter_setter!(reset_after, set_reset_after, mpsgraph_gru_descriptor_reset_after, mpsgraph_gru_descriptor_set_reset_after, "failed to set GRU resetAfter");
169    bool_getter_setter!(flip_z, set_flip_z, mpsgraph_gru_descriptor_flip_z, mpsgraph_gru_descriptor_set_flip_z, "failed to set GRU flipZ");
170    activation_getter_setter!(update_gate_activation, set_update_gate_activation, mpsgraph_gru_descriptor_update_gate_activation, mpsgraph_gru_descriptor_set_update_gate_activation, "failed to set GRU updateGateActivation");
171    activation_getter_setter!(reset_gate_activation, set_reset_gate_activation, mpsgraph_gru_descriptor_reset_gate_activation, mpsgraph_gru_descriptor_set_reset_gate_activation, "failed to set GRU resetGateActivation");
172    activation_getter_setter!(output_gate_activation, set_output_gate_activation, mpsgraph_gru_descriptor_output_gate_activation, mpsgraph_gru_descriptor_set_output_gate_activation, "failed to set GRU outputGateActivation");
173}
174
175impl crate::graph::Graph {
176    #[allow(clippy::too_many_arguments)]
177    pub fn single_gate_rnn(
178        &self,
179        source: &Tensor,
180        recurrent_weight: &Tensor,
181        input_weight: Option<&Tensor>,
182        bias: Option<&Tensor>,
183        init_state: Option<&Tensor>,
184        mask: Option<&Tensor>,
185        descriptor: &SingleGateRNNDescriptor,
186        name: Option<&str>,
187    ) -> Option<Vec<Tensor>> {
188        let name = optional_cstring(name);
189        // SAFETY: all handles remain valid for the duration of the call.
190        let box_handle = unsafe {
191            ffi::mpsgraph_graph_single_gate_rnn(
192                self.as_ptr(),
193                source.as_ptr(),
194                recurrent_weight.as_ptr(),
195                optional_tensor_ptr(input_weight),
196                optional_tensor_ptr(bias),
197                optional_tensor_ptr(init_state),
198                optional_tensor_ptr(mask),
199                descriptor.as_ptr(),
200                cstring_ptr(&name),
201            )
202        };
203        wrap_tensor_array(box_handle)
204    }
205
206    #[allow(clippy::too_many_arguments)]
207    pub fn lstm(
208        &self,
209        source: &Tensor,
210        recurrent_weight: &Tensor,
211        input_weight: Option<&Tensor>,
212        bias: Option<&Tensor>,
213        init_state: Option<&Tensor>,
214        init_cell: Option<&Tensor>,
215        mask: Option<&Tensor>,
216        peephole: Option<&Tensor>,
217        descriptor: &LSTMDescriptor,
218        name: Option<&str>,
219    ) -> Option<Vec<Tensor>> {
220        let name = optional_cstring(name);
221        // SAFETY: all handles remain valid for the duration of the call.
222        let box_handle = unsafe {
223            ffi::mpsgraph_graph_lstm(
224                self.as_ptr(),
225                source.as_ptr(),
226                recurrent_weight.as_ptr(),
227                optional_tensor_ptr(input_weight),
228                optional_tensor_ptr(bias),
229                optional_tensor_ptr(init_state),
230                optional_tensor_ptr(init_cell),
231                optional_tensor_ptr(mask),
232                optional_tensor_ptr(peephole),
233                descriptor.as_ptr(),
234                cstring_ptr(&name),
235            )
236        };
237        wrap_tensor_array(box_handle)
238    }
239
240    #[allow(clippy::too_many_arguments)]
241    pub fn gru(
242        &self,
243        source: &Tensor,
244        recurrent_weight: &Tensor,
245        input_weight: Option<&Tensor>,
246        bias: Option<&Tensor>,
247        init_state: Option<&Tensor>,
248        mask: Option<&Tensor>,
249        secondary_bias: Option<&Tensor>,
250        descriptor: &GRUDescriptor,
251        name: Option<&str>,
252    ) -> Option<Vec<Tensor>> {
253        let name = optional_cstring(name);
254        // SAFETY: all handles remain valid for the duration of the call.
255        let box_handle = unsafe {
256            ffi::mpsgraph_graph_gru(
257                self.as_ptr(),
258                source.as_ptr(),
259                recurrent_weight.as_ptr(),
260                optional_tensor_ptr(input_weight),
261                optional_tensor_ptr(bias),
262                optional_tensor_ptr(init_state),
263                optional_tensor_ptr(mask),
264                optional_tensor_ptr(secondary_bias),
265                descriptor.as_ptr(),
266                cstring_ptr(&name),
267            )
268        };
269        wrap_tensor_array(box_handle)
270    }
271}