1use crate::ffi;
2use crate::graph::Tensor;
3use crate::types::collect_owned_tensors;
4use core::ffi::{c_char, c_void};
5use std::ffi::CString;
6
/// Converts an optional Rust string into an owned, NUL-terminated C string.
///
/// Returns `None` when `name` is `None`, and also when the string contains an interior
/// NUL byte (which `CString` cannot represent). Such names are dropped silently rather
/// than reported, because the name is only an optional label.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let text = name?;
    CString::new(text).ok()
}
10
/// Returns a raw pointer to the C string held in `value`, or null when it is `None`.
///
/// The pointer borrows from `value`, so the `Option<CString>` must stay alive for as
/// long as the pointer is in use.
// The `&Option<CString>` parameter (instead of `Option<&CStr>`) keeps every call site a
// plain `cstring_ptr(&name)`; clippy's `ref_option` lint would otherwise flag it.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(text) => text.as_ptr(),
        None => core::ptr::null(),
    }
}
17
18fn wrap_tensor(ptr: *mut c_void) -> Option<Tensor> {
19 if ptr.is_null() {
20 None
21 } else {
22 Some(Tensor::from_raw(ptr))
23 }
24}
25
26fn wrap_tensor_pair(box_handle: *mut c_void) -> Option<(Tensor, Tensor)> {
27 let mut values = collect_owned_tensors(box_handle);
28 if values.len() != 2 {
29 return None;
30 }
31 let second = values.pop()?;
32 let first = values.pop()?;
33 Some((first, second))
34}
35
/// Selects the element-wise operation applied by `Graph::unary_arithmetic`.
///
/// The enum is `#[repr(u32)]`, and its discriminant is forwarded to the native layer
/// unchanged (`op as u32`). The explicit values below must therefore match the numbering
/// the FFI shim expects: do not reorder or renumber variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum UnaryArithmeticOp {
    Identity = 0,
    Exponent = 1,
    ExponentBase2 = 2,
    ExponentBase10 = 3,
    Logarithm = 4,
    LogarithmBase2 = 5,
    LogarithmBase10 = 6,
    Square = 7,
    SquareRoot = 8,
    Reciprocal = 9,
    Absolute = 10,
    Negative = 11,
    Sign = 12,
    SignBit = 13,
    Ceil = 14,
    Floor = 15,
    Round = 16,
    Rint = 17,
    Sin = 18,
    Cos = 19,
    Tan = 20,
    Sinh = 21,
    Cosh = 22,
    Tanh = 23,
    Asin = 24,
    Acos = 25,
    Atan = 26,
    Asinh = 27,
    Acosh = 28,
    Atanh = 29,
    IsNaN = 30,
    IsInfinite = 31,
}
105
/// Selects the element-wise operation applied by `Graph::binary_arithmetic`.
///
/// The enum is `#[repr(u32)]`, and its discriminant is forwarded to the native layer
/// unchanged (`op as u32`). The explicit values below must therefore match the numbering
/// the FFI shim expects: do not reorder or renumber variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum BinaryArithmeticOp {
    Addition = 0,
    Subtraction = 1,
    Multiplication = 2,
    Division = 3,
    DivisionNoNaN = 4,
    Power = 5,
    Minimum = 6,
    Maximum = 7,
    Equal = 8,
    NotEqual = 9,
    GreaterThan = 10,
    GreaterThanOrEqualTo = 11,
    LessThan = 12,
    LessThanOrEqualTo = 13,
    LogicalAnd = 14,
    LogicalOr = 15,
    Atan2 = 16,
    FloorModulo = 17,
}
147
/// Selects the reduction applied by `Graph::reduce_axis` (single-axis reduction).
///
/// The enum is `#[repr(u32)]`, and its discriminant is forwarded to the native layer
/// unchanged (`op as u32`). The explicit values must match the FFI shim's numbering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum ReductionAxisOp {
    Sum = 0,
    Maximum = 1,
    Minimum = 2,
    Product = 3,
}
161
/// Selects the reduction applied by `Graph::reduce_axes` (multi-axis reduction).
///
/// The enum is `#[repr(u32)]`, and its discriminant is forwarded to the native layer
/// unchanged (`op as u32`). The explicit values must match the FFI shim's numbering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum ReductionAxesOp {
    Sum = 0,
    Maximum = 1,
    Minimum = 2,
    Product = 3,
}
175
176impl crate::graph::Graph {
177#[must_use]
179 pub fn unary_arithmetic(
180 &self,
181 op: UnaryArithmeticOp,
182 tensor: &Tensor,
183 name: Option<&str>,
184 ) -> Option<Tensor> {
185 let name = optional_cstring(name);
186 let ptr = unsafe {
188 ffi::mpsgraph_graph_arithmetic_unary(
189 self.as_ptr(),
190 op as u32,
191 tensor.as_ptr(),
192 cstring_ptr(&name),
193 )
194 };
195 wrap_tensor(ptr)
196 }
197
198#[must_use]
200 pub fn binary_arithmetic(
201 &self,
202 op: BinaryArithmeticOp,
203 primary: &Tensor,
204 secondary: &Tensor,
205 name: Option<&str>,
206 ) -> Option<Tensor> {
207 let name = optional_cstring(name);
208 let ptr = unsafe {
210 ffi::mpsgraph_graph_arithmetic_binary(
211 self.as_ptr(),
212 op as u32,
213 primary.as_ptr(),
214 secondary.as_ptr(),
215 cstring_ptr(&name),
216 )
217 };
218 wrap_tensor(ptr)
219 }
220
221#[must_use]
223 pub fn select(
224 &self,
225 predicate: &Tensor,
226 true_tensor: &Tensor,
227 false_tensor: &Tensor,
228 name: Option<&str>,
229 ) -> Option<Tensor> {
230 let name = optional_cstring(name);
231 let ptr = unsafe {
233 ffi::mpsgraph_graph_select(
234 self.as_ptr(),
235 predicate.as_ptr(),
236 true_tensor.as_ptr(),
237 false_tensor.as_ptr(),
238 cstring_ptr(&name),
239 )
240 };
241 wrap_tensor(ptr)
242 }
243
244#[must_use]
246 pub fn relu_gradient(
247 &self,
248 gradient: &Tensor,
249 source: &Tensor,
250 name: Option<&str>,
251 ) -> Option<Tensor> {
252 let name = optional_cstring(name);
253 let ptr = unsafe {
255 ffi::mpsgraph_graph_relu_gradient(
256 self.as_ptr(),
257 gradient.as_ptr(),
258 source.as_ptr(),
259 cstring_ptr(&name),
260 )
261 };
262 wrap_tensor(ptr)
263 }
264
265#[must_use]
267 pub fn sigmoid_gradient(
268 &self,
269 gradient: &Tensor,
270 source: &Tensor,
271 name: Option<&str>,
272 ) -> Option<Tensor> {
273 let name = optional_cstring(name);
274 let ptr = unsafe {
276 ffi::mpsgraph_graph_sigmoid_gradient(
277 self.as_ptr(),
278 gradient.as_ptr(),
279 source.as_ptr(),
280 cstring_ptr(&name),
281 )
282 };
283 wrap_tensor(ptr)
284 }
285
286#[must_use]
288 pub fn softmax_gradient(
289 &self,
290 gradient: &Tensor,
291 source: &Tensor,
292 axis: isize,
293 name: Option<&str>,
294 ) -> Option<Tensor> {
295 let name = optional_cstring(name);
296 let ptr = unsafe {
298 ffi::mpsgraph_graph_softmax_gradient(
299 self.as_ptr(),
300 gradient.as_ptr(),
301 source.as_ptr(),
302 axis,
303 cstring_ptr(&name),
304 )
305 };
306 wrap_tensor(ptr)
307 }
308
309#[must_use]
311 pub fn leaky_relu(&self, tensor: &Tensor, alpha: f64, name: Option<&str>) -> Option<Tensor> {
312 let name = optional_cstring(name);
313 let ptr = unsafe {
315 ffi::mpsgraph_graph_leaky_relu_scalar(
316 self.as_ptr(),
317 tensor.as_ptr(),
318 alpha,
319 cstring_ptr(&name),
320 )
321 };
322 wrap_tensor(ptr)
323 }
324
325#[must_use]
327 pub fn leaky_relu_tensor(
328 &self,
329 tensor: &Tensor,
330 alpha_tensor: &Tensor,
331 name: Option<&str>,
332 ) -> Option<Tensor> {
333 let name = optional_cstring(name);
334 let ptr = unsafe {
336 ffi::mpsgraph_graph_leaky_relu_tensor(
337 self.as_ptr(),
338 tensor.as_ptr(),
339 alpha_tensor.as_ptr(),
340 cstring_ptr(&name),
341 )
342 };
343 wrap_tensor(ptr)
344 }
345
346#[must_use]
348 pub fn leaky_relu_gradient(
349 &self,
350 gradient: &Tensor,
351 source: &Tensor,
352 alpha_tensor: &Tensor,
353 name: Option<&str>,
354 ) -> Option<Tensor> {
355 let name = optional_cstring(name);
356 let ptr = unsafe {
358 ffi::mpsgraph_graph_leaky_relu_gradient(
359 self.as_ptr(),
360 gradient.as_ptr(),
361 source.as_ptr(),
362 alpha_tensor.as_ptr(),
363 cstring_ptr(&name),
364 )
365 };
366 wrap_tensor(ptr)
367 }
368
369#[must_use]
371 pub fn reduce_axis(
372 &self,
373 op: ReductionAxisOp,
374 tensor: &Tensor,
375 axis: isize,
376 name: Option<&str>,
377 ) -> Option<Tensor> {
378 let name = optional_cstring(name);
379 let ptr = unsafe {
381 ffi::mpsgraph_graph_reduction_axis(
382 self.as_ptr(),
383 op as u32,
384 tensor.as_ptr(),
385 axis,
386 cstring_ptr(&name),
387 )
388 };
389 wrap_tensor(ptr)
390 }
391
392#[must_use]
394 pub fn reduce_axes(
395 &self,
396 op: ReductionAxesOp,
397 tensor: &Tensor,
398 axes: &[usize],
399 name: Option<&str>,
400 ) -> Option<Tensor> {
401 let name = optional_cstring(name);
402 let ptr = unsafe {
404 ffi::mpsgraph_graph_reduction_axes(
405 self.as_ptr(),
406 op as u32,
407 tensor.as_ptr(),
408 axes.as_ptr(),
409 axes.len(),
410 cstring_ptr(&name),
411 )
412 };
413 wrap_tensor(ptr)
414 }
415
416#[must_use]
418 pub fn concat_pair(
419 &self,
420 first: &Tensor,
421 second: &Tensor,
422 dimension: isize,
423 name: Option<&str>,
424 ) -> Option<Tensor> {
425 let name = optional_cstring(name);
426 let ptr = unsafe {
428 ffi::mpsgraph_graph_concat_pair(
429 self.as_ptr(),
430 first.as_ptr(),
431 second.as_ptr(),
432 dimension,
433 cstring_ptr(&name),
434 )
435 };
436 wrap_tensor(ptr)
437 }
438
439#[must_use]
441 pub fn concat_tensors(
442 &self,
443 tensors: &[&Tensor],
444 dimension: isize,
445 interleave: bool,
446 name: Option<&str>,
447 ) -> Option<Tensor> {
448 let name = optional_cstring(name);
449 let handles = tensors
450 .iter()
451 .map(|tensor| tensor.as_ptr())
452 .collect::<Vec<_>>();
453 let ptr = unsafe {
455 ffi::mpsgraph_graph_concat_tensors(
456 self.as_ptr(),
457 handles.as_ptr(),
458 handles.len(),
459 dimension,
460 interleave,
461 cstring_ptr(&name),
462 )
463 };
464 wrap_tensor(ptr)
465 }
466
467#[must_use]
469 pub fn split_sizes(
470 &self,
471 tensor: &Tensor,
472 split_sizes: &[usize],
473 axis: isize,
474 name: Option<&str>,
475 ) -> Vec<Tensor> {
476 let name = optional_cstring(name);
477 let box_handle = unsafe {
479 ffi::mpsgraph_graph_split_sizes(
480 self.as_ptr(),
481 tensor.as_ptr(),
482 split_sizes.as_ptr(),
483 split_sizes.len(),
484 axis,
485 cstring_ptr(&name),
486 )
487 };
488 collect_owned_tensors(box_handle)
489 }
490
491#[must_use]
493 pub fn split_sizes_tensor(
494 &self,
495 tensor: &Tensor,
496 split_sizes_tensor: &Tensor,
497 axis: isize,
498 name: Option<&str>,
499 ) -> Vec<Tensor> {
500 let name = optional_cstring(name);
501 let box_handle = unsafe {
503 ffi::mpsgraph_graph_split_sizes_tensor(
504 self.as_ptr(),
505 tensor.as_ptr(),
506 split_sizes_tensor.as_ptr(),
507 axis,
508 cstring_ptr(&name),
509 )
510 };
511 collect_owned_tensors(box_handle)
512 }
513
514#[must_use]
516 pub fn split_num(
517 &self,
518 tensor: &Tensor,
519 num_splits: usize,
520 axis: isize,
521 name: Option<&str>,
522 ) -> Vec<Tensor> {
523 let name = optional_cstring(name);
524 let box_handle = unsafe {
526 ffi::mpsgraph_graph_split_num(
527 self.as_ptr(),
528 tensor.as_ptr(),
529 num_splits,
530 axis,
531 cstring_ptr(&name),
532 )
533 };
534 collect_owned_tensors(box_handle)
535 }
536
537#[must_use]
539 pub fn stack(&self, tensors: &[&Tensor], axis: isize, name: Option<&str>) -> Option<Tensor> {
540 let name = optional_cstring(name);
541 let handles = tensors
542 .iter()
543 .map(|tensor| tensor.as_ptr())
544 .collect::<Vec<_>>();
545 let ptr = unsafe {
547 ffi::mpsgraph_graph_stack(
548 self.as_ptr(),
549 handles.as_ptr(),
550 handles.len(),
551 axis,
552 cstring_ptr(&name),
553 )
554 };
555 wrap_tensor(ptr)
556 }
557
558#[must_use]
560 pub fn pad(
561 &self,
562 tensor: &Tensor,
563 padding_mode: isize,
564 left_padding: &[isize],
565 right_padding: &[isize],
566 constant_value: f64,
567 name: Option<&str>,
568 ) -> Option<Tensor> {
569 let name = optional_cstring(name);
570 let ptr = unsafe {
572 ffi::mpsgraph_graph_pad(
573 self.as_ptr(),
574 tensor.as_ptr(),
575 padding_mode,
576 left_padding.as_ptr(),
577 left_padding.len(),
578 right_padding.as_ptr(),
579 right_padding.len(),
580 constant_value,
581 cstring_ptr(&name),
582 )
583 };
584 wrap_tensor(ptr)
585 }
586
587#[must_use]
589 pub fn top_k(&self, source: &Tensor, k: usize, name: Option<&str>) -> Option<(Tensor, Tensor)> {
590 let name = optional_cstring(name);
591 let box_handle = unsafe {
593 ffi::mpsgraph_graph_top_k(self.as_ptr(), source.as_ptr(), k, cstring_ptr(&name))
594 };
595 wrap_tensor_pair(box_handle)
596 }
597
598#[must_use]
600 pub fn top_k_tensor(
601 &self,
602 source: &Tensor,
603 k_tensor: &Tensor,
604 name: Option<&str>,
605 ) -> Option<(Tensor, Tensor)> {
606 let name = optional_cstring(name);
607 let box_handle = unsafe {
609 ffi::mpsgraph_graph_top_k_tensor(
610 self.as_ptr(),
611 source.as_ptr(),
612 k_tensor.as_ptr(),
613 cstring_ptr(&name),
614 )
615 };
616 wrap_tensor_pair(box_handle)
617 }
618}