Skip to main content

apple_mpsgraph/
gather.rs

1use crate::ffi;
2use crate::graph::Tensor;
3use core::ffi::{c_char, c_void};
4use core::ptr;
5use std::ffi::CString;
6
/// Converts an optional Rust string into an owned, NUL-terminated C string.
///
/// Returns `None` when `name` is `None`. Also returns `None` when the string
/// contains an interior NUL byte, which cannot be represented as a C string.
/// In that case the name is dropped rather than reported as an error.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let text = name?;
    CString::new(text).ok()
}
10
/// Borrows a raw C-string pointer out of an optional owned `CString`.
///
/// Yields a null pointer for `None`. The returned pointer is only valid while
/// the referenced `CString` is alive and unmodified.
// Takes `&Option<CString>` (not `Option<&CString>`) so call sites can pass
// `&name` directly; the pedantic `ref_option` lint is silenced for that reason.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(owned) => owned.as_ptr(),
        None => ptr::null(),
    }
}
15
16fn wrap_tensor(ptr: *mut c_void) -> Option<Tensor> {
17    if ptr.is_null() {
18        None
19    } else {
20        Some(Tensor::from_raw(ptr))
21    }
22}
23
24impl crate::graph::Graph {
25/// Calls the `MPSGraph` framework counterpart for `gather_nd`.
26    #[must_use]
27    pub fn gather_nd(
28        &self,
29        updates_tensor: &Tensor,
30        indices_tensor: &Tensor,
31        batch_dimensions: usize,
32        name: Option<&str>,
33    ) -> Option<Tensor> {
34        let name = optional_cstring(name);
35        // SAFETY: all handles remain valid for the duration of the call.
36        let ptr = unsafe {
37            ffi::mpsgraph_graph_gather_nd(
38                self.as_ptr(),
39                updates_tensor.as_ptr(),
40                indices_tensor.as_ptr(),
41                batch_dimensions,
42                cstring_ptr(&name),
43            )
44        };
45        wrap_tensor(ptr)
46    }
47
48/// Calls the `MPSGraph` framework counterpart for `gather`.
49    #[must_use]
50    pub fn gather(
51        &self,
52        updates_tensor: &Tensor,
53        indices_tensor: &Tensor,
54        axis: usize,
55        batch_dimensions: usize,
56        name: Option<&str>,
57    ) -> Option<Tensor> {
58        let name = optional_cstring(name);
59        // SAFETY: all handles remain valid for the duration of the call.
60        let ptr = unsafe {
61            ffi::mpsgraph_graph_gather(
62                self.as_ptr(),
63                updates_tensor.as_ptr(),
64                indices_tensor.as_ptr(),
65                axis,
66                batch_dimensions,
67                cstring_ptr(&name),
68            )
69        };
70        wrap_tensor(ptr)
71    }
72
73/// Calls the `MPSGraph` framework counterpart for `gather_along_axis`.
74    #[must_use]
75    pub fn gather_along_axis(
76        &self,
77        axis: isize,
78        updates_tensor: &Tensor,
79        indices_tensor: &Tensor,
80        name: Option<&str>,
81    ) -> Option<Tensor> {
82        let name = optional_cstring(name);
83        // SAFETY: all handles remain valid for the duration of the call.
84        let ptr = unsafe {
85            ffi::mpsgraph_graph_gather_along_axis(
86                self.as_ptr(),
87                axis,
88                updates_tensor.as_ptr(),
89                indices_tensor.as_ptr(),
90                cstring_ptr(&name),
91            )
92        };
93        wrap_tensor(ptr)
94    }
95
96/// Calls the `MPSGraph` framework counterpart for `gather_along_axis_tensor`.
97    #[must_use]
98    pub fn gather_along_axis_tensor(
99        &self,
100        axis_tensor: &Tensor,
101        updates_tensor: &Tensor,
102        indices_tensor: &Tensor,
103        name: Option<&str>,
104    ) -> Option<Tensor> {
105        let name = optional_cstring(name);
106        // SAFETY: all handles remain valid for the duration of the call.
107        let ptr = unsafe {
108            ffi::mpsgraph_graph_gather_along_axis_tensor(
109                self.as_ptr(),
110                axis_tensor.as_ptr(),
111                updates_tensor.as_ptr(),
112                indices_tensor.as_ptr(),
113                cstring_ptr(&name),
114            )
115        };
116        wrap_tensor(ptr)
117    }
118}