1use crate::ffi;
2use crate::graph::Tensor;
3use crate::types::collect_owned_tensors;
4use core::ffi::{c_char, c_void};
5use std::ffi::CString;
6
/// Converts an optional node name into an owned, NUL-terminated C string.
///
/// Returns `None` both when no name is given and when `name` contains an
/// interior NUL byte (which `CString::new` rejects). In the latter case the
/// name is silently discarded, so the node ends up unnamed instead of the
/// caller receiving an error.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let text = name?;
    CString::new(text).ok()
}
10
/// Borrows the raw C pointer of an optional name, or a null pointer when absent.
///
/// The returned pointer is only valid while `value` is alive and unmodified.
// Taking `&Option<CString>` (rather than `Option<&CString>`) lets every call
// site write `cstring_ptr(&name)` directly; that is why `ref_option` is allowed.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(text) => text.as_ptr(),
        None => core::ptr::null(),
    }
}
15
16fn wrap_tensor(ptr: *mut c_void) -> Option<Tensor> {
17 if ptr.is_null() {
18 None
19 } else {
20 Some(Tensor::from_raw(ptr))
21 }
22}
23
24fn wrap_tensor_pair(box_handle: *mut c_void) -> Option<(Tensor, Tensor)> {
25 let mut values = collect_owned_tensors(box_handle);
26 if values.len() != 2 {
27 return None;
28 }
29 let second = values.pop()?;
30 let first = values.pop()?;
31 Some((first, second))
32}
33
/// Element-wise unary operation selector for `Graph::unary_arithmetic`.
///
/// The discriminants are part of the FFI contract: the value is passed to the
/// C shim as `op as u32`, so variants must never be renumbered or reordered.
#[repr(u32)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum UnaryArithmeticOp {
    Identity = 0,
    Exponent = 1,
    ExponentBase2 = 2,
    ExponentBase10 = 3,
    Logarithm = 4,
    LogarithmBase2 = 5,
    LogarithmBase10 = 6,
    Square = 7,
    SquareRoot = 8,
    Reciprocal = 9,
    Absolute = 10,
    Negative = 11,
    Sign = 12,
    SignBit = 13,
    Ceil = 14,
    Floor = 15,
    Round = 16,
    Rint = 17,
    Sin = 18,
    Cos = 19,
    Tan = 20,
    Sinh = 21,
    Cosh = 22,
    Tanh = 23,
    Asin = 24,
    Acos = 25,
    Atan = 26,
    Asinh = 27,
    Acosh = 28,
    Atanh = 29,
    IsNaN = 30,
    IsInfinite = 31,
}
70
/// Element-wise binary operation selector for `Graph::binary_arithmetic`.
///
/// The discriminants are part of the FFI contract: the value is passed to the
/// C shim as `op as u32`, so variants must never be renumbered or reordered.
#[repr(u32)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BinaryArithmeticOp {
    Addition = 0,
    Subtraction = 1,
    Multiplication = 2,
    Division = 3,
    DivisionNoNaN = 4,
    Power = 5,
    Minimum = 6,
    Maximum = 7,
    Equal = 8,
    NotEqual = 9,
    GreaterThan = 10,
    GreaterThanOrEqualTo = 11,
    LessThan = 12,
    LessThanOrEqualTo = 13,
    LogicalAnd = 14,
    LogicalOr = 15,
    Atan2 = 16,
    FloorModulo = 17,
}
93
/// Reduction selector for the single-axis `Graph::reduce_axis`.
///
/// Passed to the C shim as `op as u32`; the discriminants must stay fixed.
#[repr(u32)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReductionAxisOp {
    Sum = 0,
    Maximum = 1,
    Minimum = 2,
    Product = 3,
}
102
/// Reduction selector for the multi-axis `Graph::reduce_axes`.
///
/// Passed to the C shim as `op as u32`; the discriminants must stay fixed.
#[repr(u32)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReductionAxesOp {
    Sum = 0,
    Maximum = 1,
    Minimum = 2,
    Product = 3,
}
111
112impl crate::graph::Graph {
113 #[must_use]
114 pub fn unary_arithmetic(
115 &self,
116 op: UnaryArithmeticOp,
117 tensor: &Tensor,
118 name: Option<&str>,
119 ) -> Option<Tensor> {
120 let name = optional_cstring(name);
121 let ptr = unsafe {
123 ffi::mpsgraph_graph_arithmetic_unary(
124 self.as_ptr(),
125 op as u32,
126 tensor.as_ptr(),
127 cstring_ptr(&name),
128 )
129 };
130 wrap_tensor(ptr)
131 }
132
133 #[must_use]
134 pub fn binary_arithmetic(
135 &self,
136 op: BinaryArithmeticOp,
137 primary: &Tensor,
138 secondary: &Tensor,
139 name: Option<&str>,
140 ) -> Option<Tensor> {
141 let name = optional_cstring(name);
142 let ptr = unsafe {
144 ffi::mpsgraph_graph_arithmetic_binary(
145 self.as_ptr(),
146 op as u32,
147 primary.as_ptr(),
148 secondary.as_ptr(),
149 cstring_ptr(&name),
150 )
151 };
152 wrap_tensor(ptr)
153 }
154
155 #[must_use]
156 pub fn select(
157 &self,
158 predicate: &Tensor,
159 true_tensor: &Tensor,
160 false_tensor: &Tensor,
161 name: Option<&str>,
162 ) -> Option<Tensor> {
163 let name = optional_cstring(name);
164 let ptr = unsafe {
166 ffi::mpsgraph_graph_select(
167 self.as_ptr(),
168 predicate.as_ptr(),
169 true_tensor.as_ptr(),
170 false_tensor.as_ptr(),
171 cstring_ptr(&name),
172 )
173 };
174 wrap_tensor(ptr)
175 }
176
177 #[must_use]
178 pub fn relu_gradient(
179 &self,
180 gradient: &Tensor,
181 source: &Tensor,
182 name: Option<&str>,
183 ) -> Option<Tensor> {
184 let name = optional_cstring(name);
185 let ptr = unsafe {
187 ffi::mpsgraph_graph_relu_gradient(
188 self.as_ptr(),
189 gradient.as_ptr(),
190 source.as_ptr(),
191 cstring_ptr(&name),
192 )
193 };
194 wrap_tensor(ptr)
195 }
196
197 #[must_use]
198 pub fn sigmoid_gradient(
199 &self,
200 gradient: &Tensor,
201 source: &Tensor,
202 name: Option<&str>,
203 ) -> Option<Tensor> {
204 let name = optional_cstring(name);
205 let ptr = unsafe {
207 ffi::mpsgraph_graph_sigmoid_gradient(
208 self.as_ptr(),
209 gradient.as_ptr(),
210 source.as_ptr(),
211 cstring_ptr(&name),
212 )
213 };
214 wrap_tensor(ptr)
215 }
216
217 #[must_use]
218 pub fn softmax_gradient(
219 &self,
220 gradient: &Tensor,
221 source: &Tensor,
222 axis: isize,
223 name: Option<&str>,
224 ) -> Option<Tensor> {
225 let name = optional_cstring(name);
226 let ptr = unsafe {
228 ffi::mpsgraph_graph_softmax_gradient(
229 self.as_ptr(),
230 gradient.as_ptr(),
231 source.as_ptr(),
232 axis,
233 cstring_ptr(&name),
234 )
235 };
236 wrap_tensor(ptr)
237 }
238
239 #[must_use]
240 pub fn leaky_relu(
241 &self,
242 tensor: &Tensor,
243 alpha: f64,
244 name: Option<&str>,
245 ) -> Option<Tensor> {
246 let name = optional_cstring(name);
247 let ptr = unsafe {
249 ffi::mpsgraph_graph_leaky_relu_scalar(
250 self.as_ptr(),
251 tensor.as_ptr(),
252 alpha,
253 cstring_ptr(&name),
254 )
255 };
256 wrap_tensor(ptr)
257 }
258
259 #[must_use]
260 pub fn leaky_relu_tensor(
261 &self,
262 tensor: &Tensor,
263 alpha_tensor: &Tensor,
264 name: Option<&str>,
265 ) -> Option<Tensor> {
266 let name = optional_cstring(name);
267 let ptr = unsafe {
269 ffi::mpsgraph_graph_leaky_relu_tensor(
270 self.as_ptr(),
271 tensor.as_ptr(),
272 alpha_tensor.as_ptr(),
273 cstring_ptr(&name),
274 )
275 };
276 wrap_tensor(ptr)
277 }
278
279 #[must_use]
280 pub fn leaky_relu_gradient(
281 &self,
282 gradient: &Tensor,
283 source: &Tensor,
284 alpha_tensor: &Tensor,
285 name: Option<&str>,
286 ) -> Option<Tensor> {
287 let name = optional_cstring(name);
288 let ptr = unsafe {
290 ffi::mpsgraph_graph_leaky_relu_gradient(
291 self.as_ptr(),
292 gradient.as_ptr(),
293 source.as_ptr(),
294 alpha_tensor.as_ptr(),
295 cstring_ptr(&name),
296 )
297 };
298 wrap_tensor(ptr)
299 }
300
301 #[must_use]
302 pub fn reduce_axis(
303 &self,
304 op: ReductionAxisOp,
305 tensor: &Tensor,
306 axis: isize,
307 name: Option<&str>,
308 ) -> Option<Tensor> {
309 let name = optional_cstring(name);
310 let ptr = unsafe {
312 ffi::mpsgraph_graph_reduction_axis(
313 self.as_ptr(),
314 op as u32,
315 tensor.as_ptr(),
316 axis,
317 cstring_ptr(&name),
318 )
319 };
320 wrap_tensor(ptr)
321 }
322
323 #[must_use]
324 pub fn reduce_axes(
325 &self,
326 op: ReductionAxesOp,
327 tensor: &Tensor,
328 axes: &[usize],
329 name: Option<&str>,
330 ) -> Option<Tensor> {
331 let name = optional_cstring(name);
332 let ptr = unsafe {
334 ffi::mpsgraph_graph_reduction_axes(
335 self.as_ptr(),
336 op as u32,
337 tensor.as_ptr(),
338 axes.as_ptr(),
339 axes.len(),
340 cstring_ptr(&name),
341 )
342 };
343 wrap_tensor(ptr)
344 }
345
346 #[must_use]
347 pub fn concat_pair(
348 &self,
349 first: &Tensor,
350 second: &Tensor,
351 dimension: isize,
352 name: Option<&str>,
353 ) -> Option<Tensor> {
354 let name = optional_cstring(name);
355 let ptr = unsafe {
357 ffi::mpsgraph_graph_concat_pair(
358 self.as_ptr(),
359 first.as_ptr(),
360 second.as_ptr(),
361 dimension,
362 cstring_ptr(&name),
363 )
364 };
365 wrap_tensor(ptr)
366 }
367
368 #[must_use]
369 pub fn concat_tensors(
370 &self,
371 tensors: &[&Tensor],
372 dimension: isize,
373 interleave: bool,
374 name: Option<&str>,
375 ) -> Option<Tensor> {
376 let name = optional_cstring(name);
377 let handles = tensors.iter().map(|tensor| tensor.as_ptr()).collect::<Vec<_>>();
378 let ptr = unsafe {
380 ffi::mpsgraph_graph_concat_tensors(
381 self.as_ptr(),
382 handles.as_ptr(),
383 handles.len(),
384 dimension,
385 interleave,
386 cstring_ptr(&name),
387 )
388 };
389 wrap_tensor(ptr)
390 }
391
392 #[must_use]
393 pub fn split_sizes(
394 &self,
395 tensor: &Tensor,
396 split_sizes: &[usize],
397 axis: isize,
398 name: Option<&str>,
399 ) -> Vec<Tensor> {
400 let name = optional_cstring(name);
401 let box_handle = unsafe {
403 ffi::mpsgraph_graph_split_sizes(
404 self.as_ptr(),
405 tensor.as_ptr(),
406 split_sizes.as_ptr(),
407 split_sizes.len(),
408 axis,
409 cstring_ptr(&name),
410 )
411 };
412 collect_owned_tensors(box_handle)
413 }
414
415 #[must_use]
416 pub fn split_sizes_tensor(
417 &self,
418 tensor: &Tensor,
419 split_sizes_tensor: &Tensor,
420 axis: isize,
421 name: Option<&str>,
422 ) -> Vec<Tensor> {
423 let name = optional_cstring(name);
424 let box_handle = unsafe {
426 ffi::mpsgraph_graph_split_sizes_tensor(
427 self.as_ptr(),
428 tensor.as_ptr(),
429 split_sizes_tensor.as_ptr(),
430 axis,
431 cstring_ptr(&name),
432 )
433 };
434 collect_owned_tensors(box_handle)
435 }
436
437 #[must_use]
438 pub fn split_num(
439 &self,
440 tensor: &Tensor,
441 num_splits: usize,
442 axis: isize,
443 name: Option<&str>,
444 ) -> Vec<Tensor> {
445 let name = optional_cstring(name);
446 let box_handle = unsafe {
448 ffi::mpsgraph_graph_split_num(
449 self.as_ptr(),
450 tensor.as_ptr(),
451 num_splits,
452 axis,
453 cstring_ptr(&name),
454 )
455 };
456 collect_owned_tensors(box_handle)
457 }
458
459 #[must_use]
460 pub fn stack(&self, tensors: &[&Tensor], axis: isize, name: Option<&str>) -> Option<Tensor> {
461 let name = optional_cstring(name);
462 let handles = tensors.iter().map(|tensor| tensor.as_ptr()).collect::<Vec<_>>();
463 let ptr = unsafe {
465 ffi::mpsgraph_graph_stack(
466 self.as_ptr(),
467 handles.as_ptr(),
468 handles.len(),
469 axis,
470 cstring_ptr(&name),
471 )
472 };
473 wrap_tensor(ptr)
474 }
475
476 #[must_use]
477 pub fn pad(
478 &self,
479 tensor: &Tensor,
480 padding_mode: isize,
481 left_padding: &[isize],
482 right_padding: &[isize],
483 constant_value: f64,
484 name: Option<&str>,
485 ) -> Option<Tensor> {
486 let name = optional_cstring(name);
487 let ptr = unsafe {
489 ffi::mpsgraph_graph_pad(
490 self.as_ptr(),
491 tensor.as_ptr(),
492 padding_mode,
493 left_padding.as_ptr(),
494 left_padding.len(),
495 right_padding.as_ptr(),
496 right_padding.len(),
497 constant_value,
498 cstring_ptr(&name),
499 )
500 };
501 wrap_tensor(ptr)
502 }
503
504 #[must_use]
505 pub fn top_k(&self, source: &Tensor, k: usize, name: Option<&str>) -> Option<(Tensor, Tensor)> {
506 let name = optional_cstring(name);
507 let box_handle = unsafe {
509 ffi::mpsgraph_graph_top_k(self.as_ptr(), source.as_ptr(), k, cstring_ptr(&name))
510 };
511 wrap_tensor_pair(box_handle)
512 }
513
514 #[must_use]
515 pub fn top_k_tensor(
516 &self,
517 source: &Tensor,
518 k_tensor: &Tensor,
519 name: Option<&str>,
520 ) -> Option<(Tensor, Tensor)> {
521 let name = optional_cstring(name);
522 let box_handle = unsafe {
524 ffi::mpsgraph_graph_top_k_tensor(
525 self.as_ptr(),
526 source.as_ptr(),
527 k_tensor.as_ptr(),
528 cstring_ptr(&name),
529 )
530 };
531 wrap_tensor_pair(box_handle)
532 }
533}