1use crate::ffi;
2use crate::graph::Tensor;
3use core::ffi::{c_char, c_void};
4use core::ptr;
5use std::ffi::CString;
6
/// Converts an optional Rust string slice into an optional owned C string.
///
/// Returns `None` when `name` is `None`, and also when `CString::new`
/// rejects the text (it fails on interior NUL bytes); that error is
/// discarded rather than reported.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let text = name?;
    CString::new(text).ok()
}
10
/// Returns the raw pointer of the contained C string, or a null pointer
/// when `value` is `None`.
///
/// The pointer borrows from `value`, so it is only usable while that
/// `CString` is still alive.
// NOTE(review): the lint is presumably silenced because callers in this file
// pass `&name` (an `&Option<CString>`) directly — confirm before changing.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(text) => text.as_ptr(),
        None => ptr::null(),
    }
}
15
16fn wrap_tensor(ptr: *mut c_void) -> Option<Tensor> {
17 if ptr.is_null() {
18 None
19 } else {
20 Some(Tensor::from_raw(ptr))
21 }
22}
23
24impl crate::graph::Graph {
25 #[must_use]
26 pub fn gather_nd(
27 &self,
28 updates_tensor: &Tensor,
29 indices_tensor: &Tensor,
30 batch_dimensions: usize,
31 name: Option<&str>,
32 ) -> Option<Tensor> {
33 let name = optional_cstring(name);
34 let ptr = unsafe {
36 ffi::mpsgraph_graph_gather_nd(
37 self.as_ptr(),
38 updates_tensor.as_ptr(),
39 indices_tensor.as_ptr(),
40 batch_dimensions,
41 cstring_ptr(&name),
42 )
43 };
44 wrap_tensor(ptr)
45 }
46
47 #[must_use]
48 pub fn gather(
49 &self,
50 updates_tensor: &Tensor,
51 indices_tensor: &Tensor,
52 axis: usize,
53 batch_dimensions: usize,
54 name: Option<&str>,
55 ) -> Option<Tensor> {
56 let name = optional_cstring(name);
57 let ptr = unsafe {
59 ffi::mpsgraph_graph_gather(
60 self.as_ptr(),
61 updates_tensor.as_ptr(),
62 indices_tensor.as_ptr(),
63 axis,
64 batch_dimensions,
65 cstring_ptr(&name),
66 )
67 };
68 wrap_tensor(ptr)
69 }
70
71 #[must_use]
72 pub fn gather_along_axis(
73 &self,
74 axis: isize,
75 updates_tensor: &Tensor,
76 indices_tensor: &Tensor,
77 name: Option<&str>,
78 ) -> Option<Tensor> {
79 let name = optional_cstring(name);
80 let ptr = unsafe {
82 ffi::mpsgraph_graph_gather_along_axis(
83 self.as_ptr(),
84 axis,
85 updates_tensor.as_ptr(),
86 indices_tensor.as_ptr(),
87 cstring_ptr(&name),
88 )
89 };
90 wrap_tensor(ptr)
91 }
92
93 #[must_use]
94 pub fn gather_along_axis_tensor(
95 &self,
96 axis_tensor: &Tensor,
97 updates_tensor: &Tensor,
98 indices_tensor: &Tensor,
99 name: Option<&str>,
100 ) -> Option<Tensor> {
101 let name = optional_cstring(name);
102 let ptr = unsafe {
104 ffi::mpsgraph_graph_gather_along_axis_tensor(
105 self.as_ptr(),
106 axis_tensor.as_ptr(),
107 updates_tensor.as_ptr(),
108 indices_tensor.as_ptr(),
109 cstring_ptr(&name),
110 )
111 };
112 wrap_tensor(ptr)
113 }
114}