1use crate::ffi;
2use crate::graph::Tensor;
3use core::ffi::{c_char, c_void};
4use core::ptr;
5use std::ffi::CString;
6
/// Converts an optional Rust string into an owned, NUL-terminated C string.
///
/// Returns `None` when `name` is `None`. It also returns `None` when `name` contains an
/// interior NUL byte, because `CString::new` rejects such input. In that case the error
/// is discarded, so the label is dropped silently rather than reported.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let label = name?;
    CString::new(label).ok()
}
10
/// Produces the raw `const char*` used to pass an optional name across the FFI boundary.
///
/// Yields a null pointer for `None`. Otherwise the pointer borrows from the `CString`
/// stored in `value`, so it is valid only while that `CString` is alive and unmodified.
// `&Option<CString>` (rather than `Option<&CString>`) is deliberate: call sites pass
// `&name` straight from the binding that keeps the C string alive across the FFI call.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(owned) => owned.as_ptr(),
        None => ptr::null(),
    }
}
15
16fn wrap_tensor(ptr: *mut c_void) -> Option<Tensor> {
17 if ptr.is_null() {
18 None
19 } else {
20 Some(Tensor::from_raw(ptr))
21 }
22}
23
24impl crate::graph::Graph {
25#[must_use]
27 pub fn gather_nd(
28 &self,
29 updates_tensor: &Tensor,
30 indices_tensor: &Tensor,
31 batch_dimensions: usize,
32 name: Option<&str>,
33 ) -> Option<Tensor> {
34 let name = optional_cstring(name);
35 let ptr = unsafe {
37 ffi::mpsgraph_graph_gather_nd(
38 self.as_ptr(),
39 updates_tensor.as_ptr(),
40 indices_tensor.as_ptr(),
41 batch_dimensions,
42 cstring_ptr(&name),
43 )
44 };
45 wrap_tensor(ptr)
46 }
47
48#[must_use]
50 pub fn gather(
51 &self,
52 updates_tensor: &Tensor,
53 indices_tensor: &Tensor,
54 axis: usize,
55 batch_dimensions: usize,
56 name: Option<&str>,
57 ) -> Option<Tensor> {
58 let name = optional_cstring(name);
59 let ptr = unsafe {
61 ffi::mpsgraph_graph_gather(
62 self.as_ptr(),
63 updates_tensor.as_ptr(),
64 indices_tensor.as_ptr(),
65 axis,
66 batch_dimensions,
67 cstring_ptr(&name),
68 )
69 };
70 wrap_tensor(ptr)
71 }
72
73#[must_use]
75 pub fn gather_along_axis(
76 &self,
77 axis: isize,
78 updates_tensor: &Tensor,
79 indices_tensor: &Tensor,
80 name: Option<&str>,
81 ) -> Option<Tensor> {
82 let name = optional_cstring(name);
83 let ptr = unsafe {
85 ffi::mpsgraph_graph_gather_along_axis(
86 self.as_ptr(),
87 axis,
88 updates_tensor.as_ptr(),
89 indices_tensor.as_ptr(),
90 cstring_ptr(&name),
91 )
92 };
93 wrap_tensor(ptr)
94 }
95
96#[must_use]
98 pub fn gather_along_axis_tensor(
99 &self,
100 axis_tensor: &Tensor,
101 updates_tensor: &Tensor,
102 indices_tensor: &Tensor,
103 name: Option<&str>,
104 ) -> Option<Tensor> {
105 let name = optional_cstring(name);
106 let ptr = unsafe {
108 ffi::mpsgraph_graph_gather_along_axis_tensor(
109 self.as_ptr(),
110 axis_tensor.as_ptr(),
111 updates_tensor.as_ptr(),
112 indices_tensor.as_ptr(),
113 cstring_ptr(&name),
114 )
115 };
116 wrap_tensor(ptr)
117 }
118}