1use crate::ffi;
2use crate::graph::Tensor;
3use crate::types::collect_owned_tensors;
4use core::ffi::{c_char, c_void};
5use std::ffi::CString;
6
/// Converts an optional operation name into a C string for the FFI layer.
///
/// `None` stays `None`. A Rust `&str` may contain interior NUL bytes, which a C string
/// cannot represent. Rather than discarding the whole name in that case, the name is
/// truncated at the first NUL, which is exactly the prefix a C reader would see anyway.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    name.and_then(|value| {
        // '\0' is ASCII, so `find` returns a char boundary and the slice below is valid UTF-8.
        let end = value.find('\0').unwrap_or(value.len());
        // Cannot fail: the prefix contains no NUL byte. `.ok()` keeps this free of panic paths.
        CString::new(&value[..end]).ok()
    })
}
10
/// Borrows a pointer to the NUL-terminated contents of an optional C string.
///
/// Returns a null pointer for `None`. The returned pointer is only valid while `value` is
/// alive and unmodified.
// Call sites hold an owned `Option<CString>` local and pass `&name`; taking `&Option` keeps
// those sites terse, hence the lint exemption.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(c_string) => c_string.as_ptr(),
        None => core::ptr::null(),
    }
}
17
18fn wrap_tensor(ptr: *mut c_void) -> Option<Tensor> {
19 if ptr.is_null() {
20 None
21 } else {
22 Some(Tensor::from_raw(ptr))
23 }
24}
25
26fn wrap_tensor_pair(box_handle: *mut c_void) -> Option<(Tensor, Tensor)> {
27 let mut values = collect_owned_tensors(box_handle);
28 if values.len() != 2 {
29 return None;
30 }
31 let second = values.pop()?;
32 let first = values.pop()?;
33 Some((first, second))
34}
35
/// Selects the operation applied by `Graph::unary_arithmetic`.
///
/// The discriminant is cast with `as u32` and handed across the FFI boundary, so these
/// explicit values must stay in sync with the native shim. Never renumber existing variants.
#[repr(u32)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum UnaryArithmeticOp {
    Identity = 0,
    Exponent = 1,
    ExponentBase2 = 2,
    ExponentBase10 = 3,
    Logarithm = 4,
    LogarithmBase2 = 5,
    LogarithmBase10 = 6,
    Square = 7,
    SquareRoot = 8,
    Reciprocal = 9,
    Absolute = 10,
    Negative = 11,
    Sign = 12,
    SignBit = 13,
    Ceil = 14,
    Floor = 15,
    Round = 16,
    Rint = 17,
    Sin = 18,
    Cos = 19,
    Tan = 20,
    Sinh = 21,
    Cosh = 22,
    Tanh = 23,
    Asin = 24,
    Acos = 25,
    Atan = 26,
    Asinh = 27,
    Acosh = 28,
    Atanh = 29,
    IsNaN = 30,
    IsInfinite = 31,
}
72
/// Selects the operation applied by `Graph::binary_arithmetic`.
///
/// The discriminant is cast with `as u32` and handed across the FFI boundary, so these
/// explicit values must stay in sync with the native shim. Never renumber existing variants.
#[repr(u32)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BinaryArithmeticOp {
    Addition = 0,
    Subtraction = 1,
    Multiplication = 2,
    Division = 3,
    DivisionNoNaN = 4,
    Power = 5,
    Minimum = 6,
    Maximum = 7,
    Equal = 8,
    NotEqual = 9,
    GreaterThan = 10,
    GreaterThanOrEqualTo = 11,
    LessThan = 12,
    LessThanOrEqualTo = 13,
    LogicalAnd = 14,
    LogicalOr = 15,
    Atan2 = 16,
    FloorModulo = 17,
}
95
/// Selects the reduction applied by `Graph::reduce_axis` (single axis).
///
/// Cast with `as u32` for the FFI call, so the explicit discriminants must match the shim.
#[repr(u32)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReductionAxisOp {
    Sum = 0,
    Maximum = 1,
    Minimum = 2,
    Product = 3,
}
104
/// Selects the reduction applied by `Graph::reduce_axes` (multiple axes).
///
/// Cast with `as u32` for the FFI call, so the explicit discriminants must match the shim.
#[repr(u32)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReductionAxesOp {
    Sum = 0,
    Maximum = 1,
    Minimum = 2,
    Product = 3,
}
113
114impl crate::graph::Graph {
115 #[must_use]
116 pub fn unary_arithmetic(
117 &self,
118 op: UnaryArithmeticOp,
119 tensor: &Tensor,
120 name: Option<&str>,
121 ) -> Option<Tensor> {
122 let name = optional_cstring(name);
123 let ptr = unsafe {
125 ffi::mpsgraph_graph_arithmetic_unary(
126 self.as_ptr(),
127 op as u32,
128 tensor.as_ptr(),
129 cstring_ptr(&name),
130 )
131 };
132 wrap_tensor(ptr)
133 }
134
135 #[must_use]
136 pub fn binary_arithmetic(
137 &self,
138 op: BinaryArithmeticOp,
139 primary: &Tensor,
140 secondary: &Tensor,
141 name: Option<&str>,
142 ) -> Option<Tensor> {
143 let name = optional_cstring(name);
144 let ptr = unsafe {
146 ffi::mpsgraph_graph_arithmetic_binary(
147 self.as_ptr(),
148 op as u32,
149 primary.as_ptr(),
150 secondary.as_ptr(),
151 cstring_ptr(&name),
152 )
153 };
154 wrap_tensor(ptr)
155 }
156
157 #[must_use]
158 pub fn select(
159 &self,
160 predicate: &Tensor,
161 true_tensor: &Tensor,
162 false_tensor: &Tensor,
163 name: Option<&str>,
164 ) -> Option<Tensor> {
165 let name = optional_cstring(name);
166 let ptr = unsafe {
168 ffi::mpsgraph_graph_select(
169 self.as_ptr(),
170 predicate.as_ptr(),
171 true_tensor.as_ptr(),
172 false_tensor.as_ptr(),
173 cstring_ptr(&name),
174 )
175 };
176 wrap_tensor(ptr)
177 }
178
179 #[must_use]
180 pub fn relu_gradient(
181 &self,
182 gradient: &Tensor,
183 source: &Tensor,
184 name: Option<&str>,
185 ) -> Option<Tensor> {
186 let name = optional_cstring(name);
187 let ptr = unsafe {
189 ffi::mpsgraph_graph_relu_gradient(
190 self.as_ptr(),
191 gradient.as_ptr(),
192 source.as_ptr(),
193 cstring_ptr(&name),
194 )
195 };
196 wrap_tensor(ptr)
197 }
198
199 #[must_use]
200 pub fn sigmoid_gradient(
201 &self,
202 gradient: &Tensor,
203 source: &Tensor,
204 name: Option<&str>,
205 ) -> Option<Tensor> {
206 let name = optional_cstring(name);
207 let ptr = unsafe {
209 ffi::mpsgraph_graph_sigmoid_gradient(
210 self.as_ptr(),
211 gradient.as_ptr(),
212 source.as_ptr(),
213 cstring_ptr(&name),
214 )
215 };
216 wrap_tensor(ptr)
217 }
218
219 #[must_use]
220 pub fn softmax_gradient(
221 &self,
222 gradient: &Tensor,
223 source: &Tensor,
224 axis: isize,
225 name: Option<&str>,
226 ) -> Option<Tensor> {
227 let name = optional_cstring(name);
228 let ptr = unsafe {
230 ffi::mpsgraph_graph_softmax_gradient(
231 self.as_ptr(),
232 gradient.as_ptr(),
233 source.as_ptr(),
234 axis,
235 cstring_ptr(&name),
236 )
237 };
238 wrap_tensor(ptr)
239 }
240
241 #[must_use]
242 pub fn leaky_relu(&self, tensor: &Tensor, alpha: f64, name: Option<&str>) -> Option<Tensor> {
243 let name = optional_cstring(name);
244 let ptr = unsafe {
246 ffi::mpsgraph_graph_leaky_relu_scalar(
247 self.as_ptr(),
248 tensor.as_ptr(),
249 alpha,
250 cstring_ptr(&name),
251 )
252 };
253 wrap_tensor(ptr)
254 }
255
256 #[must_use]
257 pub fn leaky_relu_tensor(
258 &self,
259 tensor: &Tensor,
260 alpha_tensor: &Tensor,
261 name: Option<&str>,
262 ) -> Option<Tensor> {
263 let name = optional_cstring(name);
264 let ptr = unsafe {
266 ffi::mpsgraph_graph_leaky_relu_tensor(
267 self.as_ptr(),
268 tensor.as_ptr(),
269 alpha_tensor.as_ptr(),
270 cstring_ptr(&name),
271 )
272 };
273 wrap_tensor(ptr)
274 }
275
276 #[must_use]
277 pub fn leaky_relu_gradient(
278 &self,
279 gradient: &Tensor,
280 source: &Tensor,
281 alpha_tensor: &Tensor,
282 name: Option<&str>,
283 ) -> Option<Tensor> {
284 let name = optional_cstring(name);
285 let ptr = unsafe {
287 ffi::mpsgraph_graph_leaky_relu_gradient(
288 self.as_ptr(),
289 gradient.as_ptr(),
290 source.as_ptr(),
291 alpha_tensor.as_ptr(),
292 cstring_ptr(&name),
293 )
294 };
295 wrap_tensor(ptr)
296 }
297
298 #[must_use]
299 pub fn reduce_axis(
300 &self,
301 op: ReductionAxisOp,
302 tensor: &Tensor,
303 axis: isize,
304 name: Option<&str>,
305 ) -> Option<Tensor> {
306 let name = optional_cstring(name);
307 let ptr = unsafe {
309 ffi::mpsgraph_graph_reduction_axis(
310 self.as_ptr(),
311 op as u32,
312 tensor.as_ptr(),
313 axis,
314 cstring_ptr(&name),
315 )
316 };
317 wrap_tensor(ptr)
318 }
319
320 #[must_use]
321 pub fn reduce_axes(
322 &self,
323 op: ReductionAxesOp,
324 tensor: &Tensor,
325 axes: &[usize],
326 name: Option<&str>,
327 ) -> Option<Tensor> {
328 let name = optional_cstring(name);
329 let ptr = unsafe {
331 ffi::mpsgraph_graph_reduction_axes(
332 self.as_ptr(),
333 op as u32,
334 tensor.as_ptr(),
335 axes.as_ptr(),
336 axes.len(),
337 cstring_ptr(&name),
338 )
339 };
340 wrap_tensor(ptr)
341 }
342
343 #[must_use]
344 pub fn concat_pair(
345 &self,
346 first: &Tensor,
347 second: &Tensor,
348 dimension: isize,
349 name: Option<&str>,
350 ) -> Option<Tensor> {
351 let name = optional_cstring(name);
352 let ptr = unsafe {
354 ffi::mpsgraph_graph_concat_pair(
355 self.as_ptr(),
356 first.as_ptr(),
357 second.as_ptr(),
358 dimension,
359 cstring_ptr(&name),
360 )
361 };
362 wrap_tensor(ptr)
363 }
364
365 #[must_use]
366 pub fn concat_tensors(
367 &self,
368 tensors: &[&Tensor],
369 dimension: isize,
370 interleave: bool,
371 name: Option<&str>,
372 ) -> Option<Tensor> {
373 let name = optional_cstring(name);
374 let handles = tensors
375 .iter()
376 .map(|tensor| tensor.as_ptr())
377 .collect::<Vec<_>>();
378 let ptr = unsafe {
380 ffi::mpsgraph_graph_concat_tensors(
381 self.as_ptr(),
382 handles.as_ptr(),
383 handles.len(),
384 dimension,
385 interleave,
386 cstring_ptr(&name),
387 )
388 };
389 wrap_tensor(ptr)
390 }
391
392 #[must_use]
393 pub fn split_sizes(
394 &self,
395 tensor: &Tensor,
396 split_sizes: &[usize],
397 axis: isize,
398 name: Option<&str>,
399 ) -> Vec<Tensor> {
400 let name = optional_cstring(name);
401 let box_handle = unsafe {
403 ffi::mpsgraph_graph_split_sizes(
404 self.as_ptr(),
405 tensor.as_ptr(),
406 split_sizes.as_ptr(),
407 split_sizes.len(),
408 axis,
409 cstring_ptr(&name),
410 )
411 };
412 collect_owned_tensors(box_handle)
413 }
414
415 #[must_use]
416 pub fn split_sizes_tensor(
417 &self,
418 tensor: &Tensor,
419 split_sizes_tensor: &Tensor,
420 axis: isize,
421 name: Option<&str>,
422 ) -> Vec<Tensor> {
423 let name = optional_cstring(name);
424 let box_handle = unsafe {
426 ffi::mpsgraph_graph_split_sizes_tensor(
427 self.as_ptr(),
428 tensor.as_ptr(),
429 split_sizes_tensor.as_ptr(),
430 axis,
431 cstring_ptr(&name),
432 )
433 };
434 collect_owned_tensors(box_handle)
435 }
436
437 #[must_use]
438 pub fn split_num(
439 &self,
440 tensor: &Tensor,
441 num_splits: usize,
442 axis: isize,
443 name: Option<&str>,
444 ) -> Vec<Tensor> {
445 let name = optional_cstring(name);
446 let box_handle = unsafe {
448 ffi::mpsgraph_graph_split_num(
449 self.as_ptr(),
450 tensor.as_ptr(),
451 num_splits,
452 axis,
453 cstring_ptr(&name),
454 )
455 };
456 collect_owned_tensors(box_handle)
457 }
458
459 #[must_use]
460 pub fn stack(&self, tensors: &[&Tensor], axis: isize, name: Option<&str>) -> Option<Tensor> {
461 let name = optional_cstring(name);
462 let handles = tensors
463 .iter()
464 .map(|tensor| tensor.as_ptr())
465 .collect::<Vec<_>>();
466 let ptr = unsafe {
468 ffi::mpsgraph_graph_stack(
469 self.as_ptr(),
470 handles.as_ptr(),
471 handles.len(),
472 axis,
473 cstring_ptr(&name),
474 )
475 };
476 wrap_tensor(ptr)
477 }
478
479 #[must_use]
480 pub fn pad(
481 &self,
482 tensor: &Tensor,
483 padding_mode: isize,
484 left_padding: &[isize],
485 right_padding: &[isize],
486 constant_value: f64,
487 name: Option<&str>,
488 ) -> Option<Tensor> {
489 let name = optional_cstring(name);
490 let ptr = unsafe {
492 ffi::mpsgraph_graph_pad(
493 self.as_ptr(),
494 tensor.as_ptr(),
495 padding_mode,
496 left_padding.as_ptr(),
497 left_padding.len(),
498 right_padding.as_ptr(),
499 right_padding.len(),
500 constant_value,
501 cstring_ptr(&name),
502 )
503 };
504 wrap_tensor(ptr)
505 }
506
507 #[must_use]
508 pub fn top_k(&self, source: &Tensor, k: usize, name: Option<&str>) -> Option<(Tensor, Tensor)> {
509 let name = optional_cstring(name);
510 let box_handle = unsafe {
512 ffi::mpsgraph_graph_top_k(self.as_ptr(), source.as_ptr(), k, cstring_ptr(&name))
513 };
514 wrap_tensor_pair(box_handle)
515 }
516
517 #[must_use]
518 pub fn top_k_tensor(
519 &self,
520 source: &Tensor,
521 k_tensor: &Tensor,
522 name: Option<&str>,
523 ) -> Option<(Tensor, Tensor)> {
524 let name = optional_cstring(name);
525 let box_handle = unsafe {
527 ffi::mpsgraph_graph_top_k_tensor(
528 self.as_ptr(),
529 source.as_ptr(),
530 k_tensor.as_ptr(),
531 cstring_ptr(&name),
532 )
533 };
534 wrap_tensor_pair(box_handle)
535 }
536}