Skip to main content

apple_mpsgraph/
gather.rs

1use crate::ffi;
2use crate::graph::Tensor;
3use core::ffi::{c_char, c_void};
4use core::ptr;
5use std::ffi::CString;
6
/// Converts an optional name into an owned C string.
///
/// Returns `None` when `name` is `None`, and also when `CString::new` rejects
/// the text (it fails on an interior NUL byte), so an unusable name is
/// dropped rather than reported to the caller.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let text = name?;
    CString::new(text).ok()
}
10
/// Returns the raw C-string pointer for an optional owned name.
///
/// Yields a null pointer for `None`. For `Some`, the pointer borrows from the
/// `CString` stored in `value`, so it is only usable while that storage is
/// alive and unmodified.
// The lint is silenced because callers keep an `Option<CString>` local and
// pass it by reference, which is exactly the shape this signature accepts.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(owned) => owned.as_ptr(),
        None => ptr::null(),
    }
}
15
16fn wrap_tensor(ptr: *mut c_void) -> Option<Tensor> {
17    if ptr.is_null() {
18        None
19    } else {
20        Some(Tensor::from_raw(ptr))
21    }
22}
23
24impl crate::graph::Graph {
25    #[must_use]
26    pub fn gather_nd(
27        &self,
28        updates_tensor: &Tensor,
29        indices_tensor: &Tensor,
30        batch_dimensions: usize,
31        name: Option<&str>,
32    ) -> Option<Tensor> {
33        let name = optional_cstring(name);
34        // SAFETY: all handles remain valid for the duration of the call.
35        let ptr = unsafe {
36            ffi::mpsgraph_graph_gather_nd(
37                self.as_ptr(),
38                updates_tensor.as_ptr(),
39                indices_tensor.as_ptr(),
40                batch_dimensions,
41                cstring_ptr(&name),
42            )
43        };
44        wrap_tensor(ptr)
45    }
46
47    #[must_use]
48    pub fn gather(
49        &self,
50        updates_tensor: &Tensor,
51        indices_tensor: &Tensor,
52        axis: usize,
53        batch_dimensions: usize,
54        name: Option<&str>,
55    ) -> Option<Tensor> {
56        let name = optional_cstring(name);
57        // SAFETY: all handles remain valid for the duration of the call.
58        let ptr = unsafe {
59            ffi::mpsgraph_graph_gather(
60                self.as_ptr(),
61                updates_tensor.as_ptr(),
62                indices_tensor.as_ptr(),
63                axis,
64                batch_dimensions,
65                cstring_ptr(&name),
66            )
67        };
68        wrap_tensor(ptr)
69    }
70
71    #[must_use]
72    pub fn gather_along_axis(
73        &self,
74        axis: isize,
75        updates_tensor: &Tensor,
76        indices_tensor: &Tensor,
77        name: Option<&str>,
78    ) -> Option<Tensor> {
79        let name = optional_cstring(name);
80        // SAFETY: all handles remain valid for the duration of the call.
81        let ptr = unsafe {
82            ffi::mpsgraph_graph_gather_along_axis(
83                self.as_ptr(),
84                axis,
85                updates_tensor.as_ptr(),
86                indices_tensor.as_ptr(),
87                cstring_ptr(&name),
88            )
89        };
90        wrap_tensor(ptr)
91    }
92
93    #[must_use]
94    pub fn gather_along_axis_tensor(
95        &self,
96        axis_tensor: &Tensor,
97        updates_tensor: &Tensor,
98        indices_tensor: &Tensor,
99        name: Option<&str>,
100    ) -> Option<Tensor> {
101        let name = optional_cstring(name);
102        // SAFETY: all handles remain valid for the duration of the call.
103        let ptr = unsafe {
104            ffi::mpsgraph_graph_gather_along_axis_tensor(
105                self.as_ptr(),
106                axis_tensor.as_ptr(),
107                updates_tensor.as_ptr(),
108                indices_tensor.as_ptr(),
109                cstring_ptr(&name),
110            )
111        };
112        wrap_tensor(ptr)
113    }
114}