1use crate::error::{Error, Result};
2use crate::ffi;
3use crate::graph::Tensor;
4use crate::types::collect_owned_tensors;
5use core::ffi::{c_char, c_void};
6use core::ptr;
7use std::ffi::CString;
8
/// Converts an optional operation name into an owned, NUL-terminated C string.
///
/// Yields `None` when no name was supplied. It also yields `None` when `name`
/// contains an interior NUL byte (rejected by `CString::new`). In that case the
/// name is silently dropped rather than reported to the caller.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let text = name?;
    CString::new(text).ok()
}
12
/// Borrows a raw C-string pointer from an optional owned name.
///
/// Returns a null pointer for `None`. The pointer borrows from `value`. It is
/// only valid while that `Option<CString>` is alive and unmodified, so callers
/// must keep it bound to a local until the FFI call that consumes the pointer
/// has returned.
// Takes `&Option<CString>` so call sites can pass `&name` directly; that is
// the shape clippy's `ref_option` lint objects to, hence the allowance.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(owned) => owned.as_ptr(),
        None => ptr::null(),
    }
}
17
18fn optional_tensor_ptr(tensor: Option<&Tensor>) -> *mut c_void {
19 tensor.map_or(ptr::null_mut(), Tensor::as_ptr)
20}
21
22fn wrap_tensor_array(box_handle: *mut c_void) -> Option<Vec<Tensor>> {
23 if box_handle.is_null() {
24 None
25 } else {
26 Some(collect_owned_tensors(box_handle))
27 }
28}
29
/// Activation codes accepted by the `set_*_activation` setters of the RNN,
/// LSTM and GRU descriptors.
///
/// NOTE(review): the numeric values presumably mirror MPSGraph's
/// `MPSGraphRNNActivation` enumeration; confirm against the native shim.
pub mod rnn_activation {
    /// No activation is applied.
    pub const NONE: usize = 0;
    /// ReLU activation code.
    pub const RELU: usize = 1;
    /// Tanh activation code.
    pub const TANH: usize = 2;
    /// Sigmoid activation code.
    pub const SIGMOID: usize = 3;
    /// Hard-sigmoid activation code.
    pub const HARD_SIGMOID: usize = 4;
}
38
/// Declares an owning handle type around a native MPSGraph descriptor object.
///
/// The generated struct:
/// - stores the raw object pointer;
/// - releases it through `ffi::mpsgraph_object_release` when dropped (null is
///   skipped);
/// - exposes the pointer to the rest of the crate via `as_ptr`.
///
/// It also derives `Debug`, which prints the raw pointer value.
macro_rules! descriptor_handle {
    ($name:ident) => {
        #[doc = concat!("Owning handle to a native `", stringify!($name), "` object.")]
        #[derive(Debug)]
        pub struct $name {
            ptr: *mut c_void,
        }

        // SAFETY: the handle only stores an opaque pointer that Rust never
        // dereferences; every access goes through the FFI shim.
        // NOTE(review): the generated `set_*` methods mutate the native object
        // through `&self`, so `Sync` assumes the native side tolerates
        // concurrent access from several threads. Confirm this, or avoid
        // sharing a descriptor across threads while it is being configured.
        unsafe impl Send for $name {}
        // SAFETY: see the justification on the `Send` impl directly above.
        unsafe impl Sync for $name {}

        impl Drop for $name {
            fn drop(&mut self) {
                if self.ptr.is_null() {
                    return;
                }
                // SAFETY: `ptr` is non-null and held only by this handle, so it
                // is released exactly once here. This assumes it came from a
                // `*_new` constructor that handed us one retain.
                unsafe { ffi::mpsgraph_object_release(self.ptr) };
                self.ptr = ptr::null_mut();
            }
        }

        impl $name {
            /// Returns the raw native pointer without transferring ownership.
            #[must_use]
            pub(crate) const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }
    };
}
66
/// Generates a `bool` getter/setter pair that forwards to native descriptor
/// accessors.
///
/// The setter maps a `false` return from the native setter to
/// `Error::OperationFailed($msg)`.
macro_rules! bool_getter_setter {
    ($getter:ident, $setter:ident, $ffi_get:ident, $ffi_set:ident, $msg:literal) => {
        #[doc = concat!(
            "Reads the `",
            stringify!($getter),
            "` flag from the native descriptor."
        )]
        #[must_use]
        pub fn $getter(&self) -> bool {
            // SAFETY: relies on `self.ptr` being the live native descriptor
            // this handle was constructed with.
            unsafe { ffi::$ffi_get(self.ptr) }
        }

        #[doc = concat!(
            "Writes the `",
            stringify!($getter),
            "` flag on the native descriptor."
        )]
        ///
        /// # Errors
        ///
        /// Returns `Error::OperationFailed` when the native setter reports failure.
        pub fn $setter(&self, value: bool) -> Result<()> {
            // SAFETY: relies on `self.ptr` being the live native descriptor
            // this handle was constructed with.
            let accepted = unsafe { ffi::$ffi_set(self.ptr, value) };
            if !accepted {
                return Err(Error::OperationFailed($msg));
            }
            Ok(())
        }
    };
}
86
/// Generates a `usize` activation-code getter/setter pair that forwards to
/// native descriptor accessors.
///
/// The setter maps a `false` return from the native setter to
/// `Error::OperationFailed($msg)`. Values are passed through unvalidated.
/// NOTE(review): callers are presumably expected to use the `rnn_activation`
/// constants.
macro_rules! activation_getter_setter {
    ($getter:ident, $setter:ident, $ffi_get:ident, $ffi_set:ident, $msg:literal) => {
        #[doc = concat!(
            "Reads the `",
            stringify!($getter),
            "` activation code from the native descriptor."
        )]
        #[must_use]
        pub fn $getter(&self) -> usize {
            // SAFETY: relies on `self.ptr` being the live native descriptor
            // this handle was constructed with.
            unsafe { ffi::$ffi_get(self.ptr) }
        }

        #[doc = concat!(
            "Writes the `",
            stringify!($getter),
            "` activation code on the native descriptor."
        )]
        ///
        /// # Errors
        ///
        /// Returns `Error::OperationFailed` when the native setter reports failure.
        pub fn $setter(&self, value: usize) -> Result<()> {
            // SAFETY: relies on `self.ptr` being the live native descriptor
            // this handle was constructed with.
            let accepted = unsafe { ffi::$ffi_set(self.ptr, value) };
            if !accepted {
                return Err(Error::OperationFailed($msg));
            }
            Ok(())
        }
    };
}
106
107descriptor_handle!(SingleGateRNNDescriptor);
108impl SingleGateRNNDescriptor {
109 #[must_use]
110 pub fn new() -> Option<Self> {
111 let ptr = unsafe { ffi::mpsgraph_single_gate_rnn_descriptor_new() };
113 if ptr.is_null() {
114 None
115 } else {
116 Some(Self { ptr })
117 }
118 }
119
120 bool_getter_setter!(
121 reverse,
122 set_reverse,
123 mpsgraph_single_gate_rnn_descriptor_reverse,
124 mpsgraph_single_gate_rnn_descriptor_set_reverse,
125 "failed to set single-gate RNN reverse"
126 );
127 bool_getter_setter!(
128 bidirectional,
129 set_bidirectional,
130 mpsgraph_single_gate_rnn_descriptor_bidirectional,
131 mpsgraph_single_gate_rnn_descriptor_set_bidirectional,
132 "failed to set single-gate RNN bidirectional"
133 );
134 bool_getter_setter!(
135 training,
136 set_training,
137 mpsgraph_single_gate_rnn_descriptor_training,
138 mpsgraph_single_gate_rnn_descriptor_set_training,
139 "failed to set single-gate RNN training"
140 );
141 activation_getter_setter!(
142 activation,
143 set_activation,
144 mpsgraph_single_gate_rnn_descriptor_activation,
145 mpsgraph_single_gate_rnn_descriptor_set_activation,
146 "failed to set single-gate RNN activation"
147 );
148}
149
150descriptor_handle!(LSTMDescriptor);
151impl LSTMDescriptor {
152 #[must_use]
153 pub fn new() -> Option<Self> {
154 let ptr = unsafe { ffi::mpsgraph_lstm_descriptor_new() };
156 if ptr.is_null() {
157 None
158 } else {
159 Some(Self { ptr })
160 }
161 }
162
163 bool_getter_setter!(
164 reverse,
165 set_reverse,
166 mpsgraph_lstm_descriptor_reverse,
167 mpsgraph_lstm_descriptor_set_reverse,
168 "failed to set LSTM reverse"
169 );
170 bool_getter_setter!(
171 bidirectional,
172 set_bidirectional,
173 mpsgraph_lstm_descriptor_bidirectional,
174 mpsgraph_lstm_descriptor_set_bidirectional,
175 "failed to set LSTM bidirectional"
176 );
177 bool_getter_setter!(
178 produce_cell,
179 set_produce_cell,
180 mpsgraph_lstm_descriptor_produce_cell,
181 mpsgraph_lstm_descriptor_set_produce_cell,
182 "failed to set LSTM produceCell"
183 );
184 bool_getter_setter!(
185 training,
186 set_training,
187 mpsgraph_lstm_descriptor_training,
188 mpsgraph_lstm_descriptor_set_training,
189 "failed to set LSTM training"
190 );
191 bool_getter_setter!(
192 forget_gate_last,
193 set_forget_gate_last,
194 mpsgraph_lstm_descriptor_forget_gate_last,
195 mpsgraph_lstm_descriptor_set_forget_gate_last,
196 "failed to set LSTM forgetGateLast"
197 );
198 activation_getter_setter!(
199 input_gate_activation,
200 set_input_gate_activation,
201 mpsgraph_lstm_descriptor_input_gate_activation,
202 mpsgraph_lstm_descriptor_set_input_gate_activation,
203 "failed to set LSTM inputGateActivation"
204 );
205 activation_getter_setter!(
206 forget_gate_activation,
207 set_forget_gate_activation,
208 mpsgraph_lstm_descriptor_forget_gate_activation,
209 mpsgraph_lstm_descriptor_set_forget_gate_activation,
210 "failed to set LSTM forgetGateActivation"
211 );
212 activation_getter_setter!(
213 cell_gate_activation,
214 set_cell_gate_activation,
215 mpsgraph_lstm_descriptor_cell_gate_activation,
216 mpsgraph_lstm_descriptor_set_cell_gate_activation,
217 "failed to set LSTM cellGateActivation"
218 );
219 activation_getter_setter!(
220 output_gate_activation,
221 set_output_gate_activation,
222 mpsgraph_lstm_descriptor_output_gate_activation,
223 mpsgraph_lstm_descriptor_set_output_gate_activation,
224 "failed to set LSTM outputGateActivation"
225 );
226 activation_getter_setter!(
227 activation,
228 set_activation,
229 mpsgraph_lstm_descriptor_activation,
230 mpsgraph_lstm_descriptor_set_activation,
231 "failed to set LSTM activation"
232 );
233}
234
235descriptor_handle!(GRUDescriptor);
236impl GRUDescriptor {
237 #[must_use]
238 pub fn new() -> Option<Self> {
239 let ptr = unsafe { ffi::mpsgraph_gru_descriptor_new() };
241 if ptr.is_null() {
242 None
243 } else {
244 Some(Self { ptr })
245 }
246 }
247
248 bool_getter_setter!(
249 reverse,
250 set_reverse,
251 mpsgraph_gru_descriptor_reverse,
252 mpsgraph_gru_descriptor_set_reverse,
253 "failed to set GRU reverse"
254 );
255 bool_getter_setter!(
256 bidirectional,
257 set_bidirectional,
258 mpsgraph_gru_descriptor_bidirectional,
259 mpsgraph_gru_descriptor_set_bidirectional,
260 "failed to set GRU bidirectional"
261 );
262 bool_getter_setter!(
263 training,
264 set_training,
265 mpsgraph_gru_descriptor_training,
266 mpsgraph_gru_descriptor_set_training,
267 "failed to set GRU training"
268 );
269 bool_getter_setter!(
270 reset_gate_first,
271 set_reset_gate_first,
272 mpsgraph_gru_descriptor_reset_gate_first,
273 mpsgraph_gru_descriptor_set_reset_gate_first,
274 "failed to set GRU resetGateFirst"
275 );
276 bool_getter_setter!(
277 reset_after,
278 set_reset_after,
279 mpsgraph_gru_descriptor_reset_after,
280 mpsgraph_gru_descriptor_set_reset_after,
281 "failed to set GRU resetAfter"
282 );
283 bool_getter_setter!(
284 flip_z,
285 set_flip_z,
286 mpsgraph_gru_descriptor_flip_z,
287 mpsgraph_gru_descriptor_set_flip_z,
288 "failed to set GRU flipZ"
289 );
290 activation_getter_setter!(
291 update_gate_activation,
292 set_update_gate_activation,
293 mpsgraph_gru_descriptor_update_gate_activation,
294 mpsgraph_gru_descriptor_set_update_gate_activation,
295 "failed to set GRU updateGateActivation"
296 );
297 activation_getter_setter!(
298 reset_gate_activation,
299 set_reset_gate_activation,
300 mpsgraph_gru_descriptor_reset_gate_activation,
301 mpsgraph_gru_descriptor_set_reset_gate_activation,
302 "failed to set GRU resetGateActivation"
303 );
304 activation_getter_setter!(
305 output_gate_activation,
306 set_output_gate_activation,
307 mpsgraph_gru_descriptor_output_gate_activation,
308 mpsgraph_gru_descriptor_set_output_gate_activation,
309 "failed to set GRU outputGateActivation"
310 );
311}
312
313impl crate::graph::Graph {
314 #[allow(clippy::too_many_arguments)]
315 pub fn single_gate_rnn(
316 &self,
317 source: &Tensor,
318 recurrent_weight: &Tensor,
319 input_weight: Option<&Tensor>,
320 bias: Option<&Tensor>,
321 init_state: Option<&Tensor>,
322 mask: Option<&Tensor>,
323 descriptor: &SingleGateRNNDescriptor,
324 name: Option<&str>,
325 ) -> Option<Vec<Tensor>> {
326 let name = optional_cstring(name);
327 let box_handle = unsafe {
329 ffi::mpsgraph_graph_single_gate_rnn(
330 self.as_ptr(),
331 source.as_ptr(),
332 recurrent_weight.as_ptr(),
333 optional_tensor_ptr(input_weight),
334 optional_tensor_ptr(bias),
335 optional_tensor_ptr(init_state),
336 optional_tensor_ptr(mask),
337 descriptor.as_ptr(),
338 cstring_ptr(&name),
339 )
340 };
341 wrap_tensor_array(box_handle)
342 }
343
344 #[allow(clippy::too_many_arguments)]
345 pub fn lstm(
346 &self,
347 source: &Tensor,
348 recurrent_weight: &Tensor,
349 input_weight: Option<&Tensor>,
350 bias: Option<&Tensor>,
351 init_state: Option<&Tensor>,
352 init_cell: Option<&Tensor>,
353 mask: Option<&Tensor>,
354 peephole: Option<&Tensor>,
355 descriptor: &LSTMDescriptor,
356 name: Option<&str>,
357 ) -> Option<Vec<Tensor>> {
358 let name = optional_cstring(name);
359 let box_handle = unsafe {
361 ffi::mpsgraph_graph_lstm(
362 self.as_ptr(),
363 source.as_ptr(),
364 recurrent_weight.as_ptr(),
365 optional_tensor_ptr(input_weight),
366 optional_tensor_ptr(bias),
367 optional_tensor_ptr(init_state),
368 optional_tensor_ptr(init_cell),
369 optional_tensor_ptr(mask),
370 optional_tensor_ptr(peephole),
371 descriptor.as_ptr(),
372 cstring_ptr(&name),
373 )
374 };
375 wrap_tensor_array(box_handle)
376 }
377
378 #[allow(clippy::too_many_arguments)]
379 pub fn gru(
380 &self,
381 source: &Tensor,
382 recurrent_weight: &Tensor,
383 input_weight: Option<&Tensor>,
384 bias: Option<&Tensor>,
385 init_state: Option<&Tensor>,
386 mask: Option<&Tensor>,
387 secondary_bias: Option<&Tensor>,
388 descriptor: &GRUDescriptor,
389 name: Option<&str>,
390 ) -> Option<Vec<Tensor>> {
391 let name = optional_cstring(name);
392 let box_handle = unsafe {
394 ffi::mpsgraph_graph_gru(
395 self.as_ptr(),
396 source.as_ptr(),
397 recurrent_weight.as_ptr(),
398 optional_tensor_ptr(input_weight),
399 optional_tensor_ptr(bias),
400 optional_tensor_ptr(init_state),
401 optional_tensor_ptr(mask),
402 optional_tensor_ptr(secondary_bias),
403 descriptor.as_ptr(),
404 cstring_ptr(&name),
405 )
406 };
407 wrap_tensor_array(box_handle)
408 }
409}