1use crate::error::{Error, Result};
2use crate::ffi;
3use crate::graph::Tensor;
4use crate::types::collect_owned_tensors;
5use core::ffi::{c_char, c_void};
6use core::ptr;
7use std::ffi::CString;
8
/// Converts an optional Rust string into an owned, NUL-terminated C string.
///
/// Returns `None` when `name` is `None`. It also returns `None` when `name`
/// contains an interior NUL byte, which `CString::new` rejects; such a name is
/// therefore silently dropped rather than reported.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let text = name?;
    CString::new(text).ok()
}
12
/// Borrows the raw pointer of an optional C string, or returns null when absent.
///
/// The returned pointer is valid only while `value` is alive and not mutated.
// The signature takes `&Option<CString>` so call sites can pass `&name` for a
// local `Option<CString>` directly; `ref_option` would otherwise flag it.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    if let Some(text) = value {
        text.as_ptr()
    } else {
        ptr::null()
    }
}
17
18fn optional_tensor_ptr(tensor: Option<&Tensor>) -> *mut c_void {
19 tensor.map_or(ptr::null_mut(), Tensor::as_ptr)
20}
21
22fn wrap_tensor_array(box_handle: *mut c_void) -> Option<Vec<Tensor>> {
23 if box_handle.is_null() {
24 None
25 } else {
26 Some(collect_owned_tensors(box_handle))
27 }
28}
29
/// Numeric activation selectors intended for the RNN descriptors'
/// `*_activation` getters and setters, which exchange plain `usize` values.
///
/// NOTE(review): the values appear to mirror `MPSGraphRNNActivation`
/// (None, ReLU, Tanh, Sigmoid, HardSigmoid). Confirm against the MPSGraph headers.
pub mod rnn_activation {
    /// No activation.
    pub const NONE: usize = 0;
    /// ReLU activation.
    pub const RELU: usize = 1;
    /// Tanh activation.
    pub const TANH: usize = 2;
    /// Sigmoid activation.
    pub const SIGMOID: usize = 3;
    /// Hard-sigmoid activation.
    pub const HARD_SIGMOID: usize = 4;
}
43
/// Generates an owning wrapper type around a raw MPSGraph descriptor handle.
///
/// The generated type:
/// * stores the handle as a private `ptr: *mut c_void`;
/// * releases a non-null handle through `ffi::mpsgraph_object_release` on drop;
/// * exposes the handle to the rest of the crate via `as_ptr`;
/// * implements `Debug`, so the public type can appear in diagnostics.
macro_rules! descriptor_handle {
    ($name:ident) => {
        #[doc = concat!("Owned handle to an MPSGraph `", stringify!($name), "` object.")]
        pub struct $name {
            /// Raw object handle. The `new` constructors in this file never store
            /// null, but `Drop` tolerates it.
            ptr: *mut c_void,
        }

        // SAFETY: NOTE(review) these impls assert that the underlying descriptor
        // object may be moved to, and shared between, threads. The generated
        // setters call FFI setters through `&self`, presumably mutating the
        // object, so concurrent use through `Sync` relies on the FFI layer being
        // thread-safe. Confirm.
        unsafe impl Send for $name {}
        unsafe impl Sync for $name {}

        impl Drop for $name {
            fn drop(&mut self) {
                if !self.ptr.is_null() {
                    // SAFETY: a non-null `ptr` is the handle this wrapper owns; it is
                    // released exactly once here and then cleared.
                    unsafe { ffi::mpsgraph_object_release(self.ptr) };
                    self.ptr = ptr::null_mut();
                }
            }
        }

        impl ::std::fmt::Debug for $name {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
                f.debug_struct(stringify!($name))
                    .field("ptr", &self.ptr)
                    .finish()
            }
        }

        impl $name {
            /// Returns the raw object handle without transferring ownership.
            #[must_use]
            pub(crate) const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }
    };
}
72
/// Generates a `bool` property getter and a fallible setter on a descriptor
/// wrapper.
///
/// Arguments, in order:
/// * `$getter`: name of the generated `fn(&self) -> bool`;
/// * `$setter`: name of the generated `fn(&self, bool) -> Result<()>`;
/// * `$ffi_get` / `$ffi_set`: `ffi` functions called with the wrapper's `ptr`;
/// * `$msg`: message placed in `Error::OperationFailed` when `$ffi_set`
///   returns `false`.
///
/// Expand it inside an `impl` block of a type that has a `ptr` field, with
/// `ffi`, `Error`, and `Result` in scope at the call site.
macro_rules! bool_getter_setter {
    ($getter:ident, $setter:ident, $ffi_get:ident, $ffi_set:ident, $msg:literal) => {
        #[doc = concat!("Reads the `", stringify!($getter), "` flag via `ffi::", stringify!($ffi_get), "`.")]
        #[must_use]
        pub fn $getter(&self) -> bool {
            let handle = self.ptr;
            // SAFETY: `handle` is the wrapper's own object pointer, which stays
            // alive while `&self` is borrowed.
            unsafe { ffi::$ffi_get(handle) }
        }

        #[doc = concat!("Writes the `", stringify!($getter), "` flag via `ffi::", stringify!($ffi_set), "`.")]
        ///
        /// # Errors
        ///
        /// Returns `Error::OperationFailed` carrying this setter's fixed message
        /// when the FFI call reports failure by returning `false`.
        pub fn $setter(&self, value: bool) -> Result<()> {
            let handle = self.ptr;
            // SAFETY: same handle invariant as the getter; `value` is passed by value.
            let accepted = unsafe { ffi::$ffi_set(handle, value) };
            if !accepted {
                return Err(Error::OperationFailed($msg));
            }
            Ok(())
        }
    };
}
94
/// Generates a `usize` activation-selector getter and a fallible setter on a
/// descriptor wrapper.
///
/// The values are presumably the `rnn_activation` constants; this macro does
/// not validate them and forwards whatever the caller passes.
///
/// Arguments, in order:
/// * `$getter`: name of the generated `fn(&self) -> usize`;
/// * `$setter`: name of the generated `fn(&self, usize) -> Result<()>`;
/// * `$ffi_get` / `$ffi_set`: `ffi` functions called with the wrapper's `ptr`;
/// * `$msg`: message placed in `Error::OperationFailed` when `$ffi_set`
///   returns `false`.
macro_rules! activation_getter_setter {
    ($getter:ident, $setter:ident, $ffi_get:ident, $ffi_set:ident, $msg:literal) => {
        #[doc = concat!("Reads the `", stringify!($getter), "` selector via `ffi::", stringify!($ffi_get), "`.")]
        #[must_use]
        pub fn $getter(&self) -> usize {
            let handle = self.ptr;
            // SAFETY: `handle` is the wrapper's own object pointer, which stays
            // alive while `&self` is borrowed.
            unsafe { ffi::$ffi_get(handle) }
        }

        #[doc = concat!("Writes the `", stringify!($getter), "` selector via `ffi::", stringify!($ffi_set), "`.")]
        ///
        /// # Errors
        ///
        /// Returns `Error::OperationFailed` carrying this setter's fixed message
        /// when the FFI call reports failure by returning `false`.
        pub fn $setter(&self, value: usize) -> Result<()> {
            let handle = self.ptr;
            // SAFETY: same handle invariant as the getter; `value` is passed by value.
            let accepted = unsafe { ffi::$ffi_set(handle, value) };
            if !accepted {
                return Err(Error::OperationFailed($msg));
            }
            Ok(())
        }
    };
}
116
117descriptor_handle!(SingleGateRNNDescriptor);
118impl SingleGateRNNDescriptor {
119#[must_use]
121 pub fn new() -> Option<Self> {
122 let ptr = unsafe { ffi::mpsgraph_single_gate_rnn_descriptor_new() };
124 if ptr.is_null() {
125 None
126 } else {
127 Some(Self { ptr })
128 }
129 }
130
131 bool_getter_setter!(
132 reverse,
133 set_reverse,
134 mpsgraph_single_gate_rnn_descriptor_reverse,
135 mpsgraph_single_gate_rnn_descriptor_set_reverse,
136 "failed to set single-gate RNN reverse"
137 );
138 bool_getter_setter!(
139 bidirectional,
140 set_bidirectional,
141 mpsgraph_single_gate_rnn_descriptor_bidirectional,
142 mpsgraph_single_gate_rnn_descriptor_set_bidirectional,
143 "failed to set single-gate RNN bidirectional"
144 );
145 bool_getter_setter!(
146 training,
147 set_training,
148 mpsgraph_single_gate_rnn_descriptor_training,
149 mpsgraph_single_gate_rnn_descriptor_set_training,
150 "failed to set single-gate RNN training"
151 );
152 activation_getter_setter!(
153 activation,
154 set_activation,
155 mpsgraph_single_gate_rnn_descriptor_activation,
156 mpsgraph_single_gate_rnn_descriptor_set_activation,
157 "failed to set single-gate RNN activation"
158 );
159}
160
161descriptor_handle!(LSTMDescriptor);
162impl LSTMDescriptor {
163#[must_use]
165 pub fn new() -> Option<Self> {
166 let ptr = unsafe { ffi::mpsgraph_lstm_descriptor_new() };
168 if ptr.is_null() {
169 None
170 } else {
171 Some(Self { ptr })
172 }
173 }
174
175 bool_getter_setter!(
176 reverse,
177 set_reverse,
178 mpsgraph_lstm_descriptor_reverse,
179 mpsgraph_lstm_descriptor_set_reverse,
180 "failed to set LSTM reverse"
181 );
182 bool_getter_setter!(
183 bidirectional,
184 set_bidirectional,
185 mpsgraph_lstm_descriptor_bidirectional,
186 mpsgraph_lstm_descriptor_set_bidirectional,
187 "failed to set LSTM bidirectional"
188 );
189 bool_getter_setter!(
190 produce_cell,
191 set_produce_cell,
192 mpsgraph_lstm_descriptor_produce_cell,
193 mpsgraph_lstm_descriptor_set_produce_cell,
194 "failed to set LSTM produceCell"
195 );
196 bool_getter_setter!(
197 training,
198 set_training,
199 mpsgraph_lstm_descriptor_training,
200 mpsgraph_lstm_descriptor_set_training,
201 "failed to set LSTM training"
202 );
203 bool_getter_setter!(
204 forget_gate_last,
205 set_forget_gate_last,
206 mpsgraph_lstm_descriptor_forget_gate_last,
207 mpsgraph_lstm_descriptor_set_forget_gate_last,
208 "failed to set LSTM forgetGateLast"
209 );
210 activation_getter_setter!(
211 input_gate_activation,
212 set_input_gate_activation,
213 mpsgraph_lstm_descriptor_input_gate_activation,
214 mpsgraph_lstm_descriptor_set_input_gate_activation,
215 "failed to set LSTM inputGateActivation"
216 );
217 activation_getter_setter!(
218 forget_gate_activation,
219 set_forget_gate_activation,
220 mpsgraph_lstm_descriptor_forget_gate_activation,
221 mpsgraph_lstm_descriptor_set_forget_gate_activation,
222 "failed to set LSTM forgetGateActivation"
223 );
224 activation_getter_setter!(
225 cell_gate_activation,
226 set_cell_gate_activation,
227 mpsgraph_lstm_descriptor_cell_gate_activation,
228 mpsgraph_lstm_descriptor_set_cell_gate_activation,
229 "failed to set LSTM cellGateActivation"
230 );
231 activation_getter_setter!(
232 output_gate_activation,
233 set_output_gate_activation,
234 mpsgraph_lstm_descriptor_output_gate_activation,
235 mpsgraph_lstm_descriptor_set_output_gate_activation,
236 "failed to set LSTM outputGateActivation"
237 );
238 activation_getter_setter!(
239 activation,
240 set_activation,
241 mpsgraph_lstm_descriptor_activation,
242 mpsgraph_lstm_descriptor_set_activation,
243 "failed to set LSTM activation"
244 );
245}
246
247descriptor_handle!(GRUDescriptor);
248impl GRUDescriptor {
249#[must_use]
251 pub fn new() -> Option<Self> {
252 let ptr = unsafe { ffi::mpsgraph_gru_descriptor_new() };
254 if ptr.is_null() {
255 None
256 } else {
257 Some(Self { ptr })
258 }
259 }
260
261 bool_getter_setter!(
262 reverse,
263 set_reverse,
264 mpsgraph_gru_descriptor_reverse,
265 mpsgraph_gru_descriptor_set_reverse,
266 "failed to set GRU reverse"
267 );
268 bool_getter_setter!(
269 bidirectional,
270 set_bidirectional,
271 mpsgraph_gru_descriptor_bidirectional,
272 mpsgraph_gru_descriptor_set_bidirectional,
273 "failed to set GRU bidirectional"
274 );
275 bool_getter_setter!(
276 training,
277 set_training,
278 mpsgraph_gru_descriptor_training,
279 mpsgraph_gru_descriptor_set_training,
280 "failed to set GRU training"
281 );
282 bool_getter_setter!(
283 reset_gate_first,
284 set_reset_gate_first,
285 mpsgraph_gru_descriptor_reset_gate_first,
286 mpsgraph_gru_descriptor_set_reset_gate_first,
287 "failed to set GRU resetGateFirst"
288 );
289 bool_getter_setter!(
290 reset_after,
291 set_reset_after,
292 mpsgraph_gru_descriptor_reset_after,
293 mpsgraph_gru_descriptor_set_reset_after,
294 "failed to set GRU resetAfter"
295 );
296 bool_getter_setter!(
297 flip_z,
298 set_flip_z,
299 mpsgraph_gru_descriptor_flip_z,
300 mpsgraph_gru_descriptor_set_flip_z,
301 "failed to set GRU flipZ"
302 );
303 activation_getter_setter!(
304 update_gate_activation,
305 set_update_gate_activation,
306 mpsgraph_gru_descriptor_update_gate_activation,
307 mpsgraph_gru_descriptor_set_update_gate_activation,
308 "failed to set GRU updateGateActivation"
309 );
310 activation_getter_setter!(
311 reset_gate_activation,
312 set_reset_gate_activation,
313 mpsgraph_gru_descriptor_reset_gate_activation,
314 mpsgraph_gru_descriptor_set_reset_gate_activation,
315 "failed to set GRU resetGateActivation"
316 );
317 activation_getter_setter!(
318 output_gate_activation,
319 set_output_gate_activation,
320 mpsgraph_gru_descriptor_output_gate_activation,
321 mpsgraph_gru_descriptor_set_output_gate_activation,
322 "failed to set GRU outputGateActivation"
323 );
324}
325
326impl crate::graph::Graph {
327#[allow(clippy::too_many_arguments)]
329 pub fn single_gate_rnn(
330 &self,
331 source: &Tensor,
332 recurrent_weight: &Tensor,
333 input_weight: Option<&Tensor>,
334 bias: Option<&Tensor>,
335 init_state: Option<&Tensor>,
336 mask: Option<&Tensor>,
337 descriptor: &SingleGateRNNDescriptor,
338 name: Option<&str>,
339 ) -> Option<Vec<Tensor>> {
340 let name = optional_cstring(name);
341 let box_handle = unsafe {
343 ffi::mpsgraph_graph_single_gate_rnn(
344 self.as_ptr(),
345 source.as_ptr(),
346 recurrent_weight.as_ptr(),
347 optional_tensor_ptr(input_weight),
348 optional_tensor_ptr(bias),
349 optional_tensor_ptr(init_state),
350 optional_tensor_ptr(mask),
351 descriptor.as_ptr(),
352 cstring_ptr(&name),
353 )
354 };
355 wrap_tensor_array(box_handle)
356 }
357
358#[allow(clippy::too_many_arguments)]
360 pub fn lstm(
361 &self,
362 source: &Tensor,
363 recurrent_weight: &Tensor,
364 input_weight: Option<&Tensor>,
365 bias: Option<&Tensor>,
366 init_state: Option<&Tensor>,
367 init_cell: Option<&Tensor>,
368 mask: Option<&Tensor>,
369 peephole: Option<&Tensor>,
370 descriptor: &LSTMDescriptor,
371 name: Option<&str>,
372 ) -> Option<Vec<Tensor>> {
373 let name = optional_cstring(name);
374 let box_handle = unsafe {
376 ffi::mpsgraph_graph_lstm(
377 self.as_ptr(),
378 source.as_ptr(),
379 recurrent_weight.as_ptr(),
380 optional_tensor_ptr(input_weight),
381 optional_tensor_ptr(bias),
382 optional_tensor_ptr(init_state),
383 optional_tensor_ptr(init_cell),
384 optional_tensor_ptr(mask),
385 optional_tensor_ptr(peephole),
386 descriptor.as_ptr(),
387 cstring_ptr(&name),
388 )
389 };
390 wrap_tensor_array(box_handle)
391 }
392
393#[allow(clippy::too_many_arguments)]
395 pub fn gru(
396 &self,
397 source: &Tensor,
398 recurrent_weight: &Tensor,
399 input_weight: Option<&Tensor>,
400 bias: Option<&Tensor>,
401 init_state: Option<&Tensor>,
402 mask: Option<&Tensor>,
403 secondary_bias: Option<&Tensor>,
404 descriptor: &GRUDescriptor,
405 name: Option<&str>,
406 ) -> Option<Vec<Tensor>> {
407 let name = optional_cstring(name);
408 let box_handle = unsafe {
410 ffi::mpsgraph_graph_gru(
411 self.as_ptr(),
412 source.as_ptr(),
413 recurrent_weight.as_ptr(),
414 optional_tensor_ptr(input_weight),
415 optional_tensor_ptr(bias),
416 optional_tensor_ptr(init_state),
417 optional_tensor_ptr(mask),
418 optional_tensor_ptr(secondary_bias),
419 descriptor.as_ptr(),
420 cstring_ptr(&name),
421 )
422 };
423 wrap_tensor_array(box_handle)
424 }
425}