1use crate::error::{Error, Result};
2use crate::ffi;
3use crate::graph::Tensor;
4use crate::types::collect_owned_tensors;
5use core::ffi::{c_char, c_void};
6use core::ptr;
7use std::ffi::CString;
8
/// Converts an optional Rust string into an owned C string suitable for FFI.
///
/// Returns `None` when `name` is `None`, and also when `name` contains an
/// interior NUL byte (such a string cannot be represented as a C string, so the
/// name is dropped rather than reported as an error).
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let text = name?;
    CString::new(text).ok()
}
12
/// Returns a borrowed C-string pointer for an optional owned name, or null when absent.
///
/// The returned pointer is only valid while `value` is alive and unmodified, so
/// callers must keep the `Option<CString>` bound until the FFI call returns.
// Takes `&Option<CString>` so call sites can pass `&name` directly; that is why
// clippy's `ref_option` lint is allowed here.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(owned) => owned.as_ptr(),
        None => ptr::null(),
    }
}
17
18fn optional_tensor_ptr(tensor: Option<&Tensor>) -> *mut c_void {
19 tensor.map_or(ptr::null_mut(), Tensor::as_ptr)
20}
21
22fn wrap_tensor_array(box_handle: *mut c_void) -> Option<Vec<Tensor>> {
23 if box_handle.is_null() {
24 None
25 } else {
26 Some(collect_owned_tensors(box_handle))
27 }
28}
29
/// Integer activation codes for the `*activation` getters/setters on the RNN
/// descriptors in this module.
///
/// NOTE(review): these values appear to mirror Apple's `MPSGraphRNNActivation`
/// enum (None, ReLU, Tanh, Sigmoid, HardSigmoid); confirm against the SDK headers.
pub mod rnn_activation {
    /// No activation.
    pub const NONE: usize = 0;
    /// Rectified linear unit.
    pub const RELU: usize = 1;
    /// Hyperbolic tangent.
    pub const TANH: usize = 2;
    /// Logistic sigmoid.
    pub const SIGMOID: usize = 3;
    /// Piecewise-linear approximation of the sigmoid.
    pub const HARD_SIGMOID: usize = 4;
}
38
/// Declares a wrapper type holding an opaque descriptor pointer.
///
/// The generated type releases its pointer with `ffi::mpsgraph_object_release`
/// when dropped (null pointers are skipped) and exposes the raw pointer to the
/// crate through `as_ptr`.
macro_rules! descriptor_handle {
    ($name:ident) => {
        /// Handle wrapping an opaque descriptor pointer; released on drop.
        pub struct $name {
            ptr: *mut c_void,
        }

        impl $name {
            /// Returns the raw descriptor pointer without giving up ownership.
            #[must_use]
            pub(crate) const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }

        // SAFETY: NOTE(review) these impls assert that the underlying descriptor
        // object may be moved to and shared between threads; that relies on the
        // FFI object tolerating cross-thread access. Confirm.
        unsafe impl Send for $name {}
        unsafe impl Sync for $name {}

        impl Drop for $name {
            fn drop(&mut self) {
                if self.ptr.is_null() {
                    return;
                }
                // SAFETY: `ptr` is private to this handle and `drop` runs once, so
                // this is the only release the wrapper issues for it.
                unsafe { ffi::mpsgraph_object_release(self.ptr) };
                self.ptr = ptr::null_mut();
            }
        }
    };
}
66
/// Generates a `bool` property getter and a fallible setter that forward to FFI.
///
/// The setter maps a `false` return from `$ffi_set` to
/// `Error::OperationFailed($msg)`. The invoking type must have a `ptr` field
/// holding the descriptor pointer.
macro_rules! bool_getter_setter {
    ($getter:ident, $setter:ident, $ffi_get:ident, $ffi_set:ident, $msg:literal) => {
        /// Reads this flag from the underlying descriptor.
        #[must_use]
        pub fn $getter(&self) -> bool {
            // SAFETY: assumes `self.ptr` is the descriptor pointer held by `self`.
            unsafe { ffi::$ffi_get(self.ptr) }
        }

        /// Writes this flag on the underlying descriptor.
        ///
        /// # Errors
        ///
        /// Returns `Error::OperationFailed` when the FFI setter reports failure.
        pub fn $setter(&self, value: bool) -> Result<()> {
            // SAFETY: assumes `self.ptr` is the descriptor pointer held by `self`.
            let applied = unsafe { ffi::$ffi_set(self.ptr, value) };
            if !applied {
                return Err(Error::OperationFailed($msg));
            }
            Ok(())
        }
    };
}
86
/// Generates an activation-code getter and a validating, fallible setter that
/// forward to FFI.
///
/// Activation codes are the `usize` constants in `rnn_activation`. The setter
/// rejects codes above `rnn_activation::HARD_SIGMOID` before calling FFI. It
/// maps both that case and a `false` return from `$ffi_set` to
/// `Error::OperationFailed($msg)`. The invoking type must have a `ptr` field
/// holding the descriptor pointer.
macro_rules! activation_getter_setter {
    ($getter:ident, $setter:ident, $ffi_get:ident, $ffi_set:ident, $msg:literal) => {
        /// Returns the raw activation code stored on the descriptor (compare
        /// against the `rnn_activation` constants).
        #[must_use]
        pub fn $getter(&self) -> usize {
            // SAFETY: assumes `self.ptr` is the descriptor pointer held by `self`.
            unsafe { ffi::$ffi_get(self.ptr) }
        }

        /// Sets the activation code on the descriptor.
        ///
        /// # Errors
        ///
        /// Returns `Error::OperationFailed` if `value` exceeds
        /// `rnn_activation::HARD_SIGMOID` or the FFI setter reports failure.
        pub fn $setter(&self, value: usize) -> Result<()> {
            // Codes above the last named constant have no meaning in this API;
            // reject them instead of forwarding an arbitrary integer across FFI.
            if value > rnn_activation::HARD_SIGMOID {
                return Err(Error::OperationFailed($msg));
            }
            // SAFETY: assumes `self.ptr` is the descriptor pointer held by `self`.
            let ok = unsafe { ffi::$ffi_set(self.ptr, value) };
            if ok {
                Ok(())
            } else {
                Err(Error::OperationFailed($msg))
            }
        }
    };
}
106
107descriptor_handle!(SingleGateRNNDescriptor);
108impl SingleGateRNNDescriptor {
109 #[must_use]
110 pub fn new() -> Option<Self> {
111 let ptr = unsafe { ffi::mpsgraph_single_gate_rnn_descriptor_new() };
113 if ptr.is_null() {
114 None
115 } else {
116 Some(Self { ptr })
117 }
118 }
119
120 bool_getter_setter!(reverse, set_reverse, mpsgraph_single_gate_rnn_descriptor_reverse, mpsgraph_single_gate_rnn_descriptor_set_reverse, "failed to set single-gate RNN reverse");
121 bool_getter_setter!(bidirectional, set_bidirectional, mpsgraph_single_gate_rnn_descriptor_bidirectional, mpsgraph_single_gate_rnn_descriptor_set_bidirectional, "failed to set single-gate RNN bidirectional");
122 bool_getter_setter!(training, set_training, mpsgraph_single_gate_rnn_descriptor_training, mpsgraph_single_gate_rnn_descriptor_set_training, "failed to set single-gate RNN training");
123 activation_getter_setter!(activation, set_activation, mpsgraph_single_gate_rnn_descriptor_activation, mpsgraph_single_gate_rnn_descriptor_set_activation, "failed to set single-gate RNN activation");
124}
125
126descriptor_handle!(LSTMDescriptor);
127impl LSTMDescriptor {
128 #[must_use]
129 pub fn new() -> Option<Self> {
130 let ptr = unsafe { ffi::mpsgraph_lstm_descriptor_new() };
132 if ptr.is_null() {
133 None
134 } else {
135 Some(Self { ptr })
136 }
137 }
138
139 bool_getter_setter!(reverse, set_reverse, mpsgraph_lstm_descriptor_reverse, mpsgraph_lstm_descriptor_set_reverse, "failed to set LSTM reverse");
140 bool_getter_setter!(bidirectional, set_bidirectional, mpsgraph_lstm_descriptor_bidirectional, mpsgraph_lstm_descriptor_set_bidirectional, "failed to set LSTM bidirectional");
141 bool_getter_setter!(produce_cell, set_produce_cell, mpsgraph_lstm_descriptor_produce_cell, mpsgraph_lstm_descriptor_set_produce_cell, "failed to set LSTM produceCell");
142 bool_getter_setter!(training, set_training, mpsgraph_lstm_descriptor_training, mpsgraph_lstm_descriptor_set_training, "failed to set LSTM training");
143 bool_getter_setter!(forget_gate_last, set_forget_gate_last, mpsgraph_lstm_descriptor_forget_gate_last, mpsgraph_lstm_descriptor_set_forget_gate_last, "failed to set LSTM forgetGateLast");
144 activation_getter_setter!(input_gate_activation, set_input_gate_activation, mpsgraph_lstm_descriptor_input_gate_activation, mpsgraph_lstm_descriptor_set_input_gate_activation, "failed to set LSTM inputGateActivation");
145 activation_getter_setter!(forget_gate_activation, set_forget_gate_activation, mpsgraph_lstm_descriptor_forget_gate_activation, mpsgraph_lstm_descriptor_set_forget_gate_activation, "failed to set LSTM forgetGateActivation");
146 activation_getter_setter!(cell_gate_activation, set_cell_gate_activation, mpsgraph_lstm_descriptor_cell_gate_activation, mpsgraph_lstm_descriptor_set_cell_gate_activation, "failed to set LSTM cellGateActivation");
147 activation_getter_setter!(output_gate_activation, set_output_gate_activation, mpsgraph_lstm_descriptor_output_gate_activation, mpsgraph_lstm_descriptor_set_output_gate_activation, "failed to set LSTM outputGateActivation");
148 activation_getter_setter!(activation, set_activation, mpsgraph_lstm_descriptor_activation, mpsgraph_lstm_descriptor_set_activation, "failed to set LSTM activation");
149}
150
151descriptor_handle!(GRUDescriptor);
152impl GRUDescriptor {
153 #[must_use]
154 pub fn new() -> Option<Self> {
155 let ptr = unsafe { ffi::mpsgraph_gru_descriptor_new() };
157 if ptr.is_null() {
158 None
159 } else {
160 Some(Self { ptr })
161 }
162 }
163
164 bool_getter_setter!(reverse, set_reverse, mpsgraph_gru_descriptor_reverse, mpsgraph_gru_descriptor_set_reverse, "failed to set GRU reverse");
165 bool_getter_setter!(bidirectional, set_bidirectional, mpsgraph_gru_descriptor_bidirectional, mpsgraph_gru_descriptor_set_bidirectional, "failed to set GRU bidirectional");
166 bool_getter_setter!(training, set_training, mpsgraph_gru_descriptor_training, mpsgraph_gru_descriptor_set_training, "failed to set GRU training");
167 bool_getter_setter!(reset_gate_first, set_reset_gate_first, mpsgraph_gru_descriptor_reset_gate_first, mpsgraph_gru_descriptor_set_reset_gate_first, "failed to set GRU resetGateFirst");
168 bool_getter_setter!(reset_after, set_reset_after, mpsgraph_gru_descriptor_reset_after, mpsgraph_gru_descriptor_set_reset_after, "failed to set GRU resetAfter");
169 bool_getter_setter!(flip_z, set_flip_z, mpsgraph_gru_descriptor_flip_z, mpsgraph_gru_descriptor_set_flip_z, "failed to set GRU flipZ");
170 activation_getter_setter!(update_gate_activation, set_update_gate_activation, mpsgraph_gru_descriptor_update_gate_activation, mpsgraph_gru_descriptor_set_update_gate_activation, "failed to set GRU updateGateActivation");
171 activation_getter_setter!(reset_gate_activation, set_reset_gate_activation, mpsgraph_gru_descriptor_reset_gate_activation, mpsgraph_gru_descriptor_set_reset_gate_activation, "failed to set GRU resetGateActivation");
172 activation_getter_setter!(output_gate_activation, set_output_gate_activation, mpsgraph_gru_descriptor_output_gate_activation, mpsgraph_gru_descriptor_set_output_gate_activation, "failed to set GRU outputGateActivation");
173}
174
175impl crate::graph::Graph {
176 #[allow(clippy::too_many_arguments)]
177 pub fn single_gate_rnn(
178 &self,
179 source: &Tensor,
180 recurrent_weight: &Tensor,
181 input_weight: Option<&Tensor>,
182 bias: Option<&Tensor>,
183 init_state: Option<&Tensor>,
184 mask: Option<&Tensor>,
185 descriptor: &SingleGateRNNDescriptor,
186 name: Option<&str>,
187 ) -> Option<Vec<Tensor>> {
188 let name = optional_cstring(name);
189 let box_handle = unsafe {
191 ffi::mpsgraph_graph_single_gate_rnn(
192 self.as_ptr(),
193 source.as_ptr(),
194 recurrent_weight.as_ptr(),
195 optional_tensor_ptr(input_weight),
196 optional_tensor_ptr(bias),
197 optional_tensor_ptr(init_state),
198 optional_tensor_ptr(mask),
199 descriptor.as_ptr(),
200 cstring_ptr(&name),
201 )
202 };
203 wrap_tensor_array(box_handle)
204 }
205
206 #[allow(clippy::too_many_arguments)]
207 pub fn lstm(
208 &self,
209 source: &Tensor,
210 recurrent_weight: &Tensor,
211 input_weight: Option<&Tensor>,
212 bias: Option<&Tensor>,
213 init_state: Option<&Tensor>,
214 init_cell: Option<&Tensor>,
215 mask: Option<&Tensor>,
216 peephole: Option<&Tensor>,
217 descriptor: &LSTMDescriptor,
218 name: Option<&str>,
219 ) -> Option<Vec<Tensor>> {
220 let name = optional_cstring(name);
221 let box_handle = unsafe {
223 ffi::mpsgraph_graph_lstm(
224 self.as_ptr(),
225 source.as_ptr(),
226 recurrent_weight.as_ptr(),
227 optional_tensor_ptr(input_weight),
228 optional_tensor_ptr(bias),
229 optional_tensor_ptr(init_state),
230 optional_tensor_ptr(init_cell),
231 optional_tensor_ptr(mask),
232 optional_tensor_ptr(peephole),
233 descriptor.as_ptr(),
234 cstring_ptr(&name),
235 )
236 };
237 wrap_tensor_array(box_handle)
238 }
239
240 #[allow(clippy::too_many_arguments)]
241 pub fn gru(
242 &self,
243 source: &Tensor,
244 recurrent_weight: &Tensor,
245 input_weight: Option<&Tensor>,
246 bias: Option<&Tensor>,
247 init_state: Option<&Tensor>,
248 mask: Option<&Tensor>,
249 secondary_bias: Option<&Tensor>,
250 descriptor: &GRUDescriptor,
251 name: Option<&str>,
252 ) -> Option<Vec<Tensor>> {
253 let name = optional_cstring(name);
254 let box_handle = unsafe {
256 ffi::mpsgraph_graph_gru(
257 self.as_ptr(),
258 source.as_ptr(),
259 recurrent_weight.as_ptr(),
260 optional_tensor_ptr(input_weight),
261 optional_tensor_ptr(bias),
262 optional_tensor_ptr(init_state),
263 optional_tensor_ptr(mask),
264 optional_tensor_ptr(secondary_bias),
265 descriptor.as_ptr(),
266 cstring_ptr(&name),
267 )
268 };
269 wrap_tensor_array(box_handle)
270 }
271}