1use crate::data::TensorData;
2use crate::error::{Error, Result};
3use crate::ffi;
4use crate::graph::Tensor;
5use apple_metal::MetalDevice;
6use core::ffi::c_void;
7use core::ptr;
8
9fn release_handle(ptr: &mut *mut c_void) {
10 if !ptr.is_null() {
11 unsafe { ffi::mpsgraph_object_release(*ptr) };
13 *ptr = ptr::null_mut();
14 }
15}
16
/// Reads an optional signed shape out of an FFI object through three accessor functions.
///
/// * `has_shape` reports whether `handle` carries a shape at all; if not, returns `None`.
/// * `shape_len` gives the number of dimensions.
/// * `copy_shape` writes exactly that many dimensions into the provided buffer; it is
///   skipped entirely for a zero-length shape so it never sees an empty buffer.
fn copy_optional_signed_shape(
    handle: *mut c_void,
    has_shape: unsafe extern "C" fn(*mut c_void) -> bool,
    shape_len: unsafe extern "C" fn(*mut c_void) -> usize,
    copy_shape: unsafe extern "C" fn(*mut c_void, *mut isize),
) -> Option<Vec<isize>> {
    // SAFETY: the caller supplies a handle that matches the accessor functions.
    let present = unsafe { has_shape(handle) };
    present.then(|| {
        // SAFETY: same handle/accessor pairing as above.
        let rank = unsafe { shape_len(handle) };
        let mut dims = vec![0_isize; rank];
        if !dims.is_empty() {
            // SAFETY: `dims` has room for exactly `rank` elements, as reported by `shape_len`.
            unsafe { copy_shape(handle, dims.as_mut_ptr()) };
        }
        dims
    })
}
36
37fn collect_tensor_array_box(handle: *mut c_void) -> Vec<Tensor> {
38 if handle.is_null() {
39 return Vec::new();
40 }
41
42 let len = unsafe { ffi::mpsgraph_tensor_array_box_len(handle) };
44 let mut tensors = Vec::with_capacity(len);
45 for index in 0..len {
46 let tensor = unsafe { ffi::mpsgraph_tensor_array_box_get(handle, index) };
48 if !tensor.is_null() {
49 tensors.push(Tensor::from_raw(tensor));
50 }
51 }
52 let mut box_handle = handle;
53 release_handle(&mut box_handle);
54 tensors
55}
56
57pub(crate) fn collect_tensor_data_array_box(handle: *mut c_void) -> Vec<TensorData> {
58 if handle.is_null() {
59 return Vec::new();
60 }
61
62 let len = unsafe { ffi::mpsgraph_tensor_data_array_box_len(handle) };
64 let mut values = Vec::with_capacity(len);
65 for index in 0..len {
66 let value = unsafe { ffi::mpsgraph_tensor_data_array_box_get(handle, index) };
68 if !value.is_null() {
69 values.push(TensorData::from_raw(value));
70 }
71 }
72 let mut box_handle = handle;
73 release_handle(&mut box_handle);
74 values
75}
76
77pub(crate) fn collect_shaped_type_array_box(handle: *mut c_void) -> Vec<ShapedType> {
78 if handle.is_null() {
79 return Vec::new();
80 }
81
82 let len = unsafe { ffi::mpsgraph_shaped_type_array_box_len(handle) };
84 let mut values = Vec::with_capacity(len);
85 for index in 0..len {
86 let value = unsafe { ffi::mpsgraph_shaped_type_array_box_get(handle, index) };
88 if !value.is_null() {
89 values.push(ShapedType { ptr: value });
90 }
91 }
92 let mut box_handle = handle;
93 release_handle(&mut box_handle);
94 values
95}
96
/// Numeric device-type codes for graph devices.
///
/// NOTE(review): presumably these are the values reported by
/// `GraphDevice::device_type` — confirm against the FFI shim.
pub mod graph_device_type {
    /// Code for a Metal-backed graph device.
    pub const METAL: u32 = 0;
}
101
102pub struct GraphDevice {
104 ptr: *mut c_void,
105}
106
107unsafe impl Send for GraphDevice {}
108unsafe impl Sync for GraphDevice {}
109
110impl Drop for GraphDevice {
111 fn drop(&mut self) {
112 release_handle(&mut self.ptr);
113 }
114}
115
116impl GraphDevice {
117 #[must_use]
119 pub fn from_metal_device(device: &MetalDevice) -> Option<Self> {
120 let ptr = unsafe { ffi::mpsgraph_device_new_with_metal_device(device.as_ptr()) };
122 if ptr.is_null() {
123 None
124 } else {
125 Some(Self { ptr })
126 }
127 }
128
129 #[must_use]
130 pub const fn as_ptr(&self) -> *mut c_void {
131 self.ptr
132 }
133
134 #[must_use]
136 pub fn device_type(&self) -> u32 {
137 unsafe { ffi::mpsgraph_device_type(self.ptr) }
139 }
140}
141
142pub struct ShapedType {
144 ptr: *mut c_void,
145}
146
147unsafe impl Send for ShapedType {}
148unsafe impl Sync for ShapedType {}
149
150impl Drop for ShapedType {
151 fn drop(&mut self) {
152 release_handle(&mut self.ptr);
153 }
154}
155
156impl ShapedType {
157 #[must_use]
159 pub fn new(shape: Option<&[isize]>, data_type: u32) -> Option<Self> {
160 let (shape_ptr, shape_len) =
161 shape.map_or((ptr::null(), 0), |shape| (shape.as_ptr(), shape.len()));
162 let ptr = unsafe { ffi::mpsgraph_shaped_type_new(shape_ptr, shape_len, data_type) };
164 if ptr.is_null() {
165 None
166 } else {
167 Some(Self { ptr })
168 }
169 }
170
171 #[must_use]
172 pub(crate) const fn as_ptr(&self) -> *mut c_void {
173 self.ptr
174 }
175
176 #[must_use]
178 pub fn shape(&self) -> Option<Vec<isize>> {
179 copy_optional_signed_shape(
180 self.ptr,
181 ffi::mpsgraph_shaped_type_has_shape,
182 ffi::mpsgraph_shaped_type_shape_len,
183 ffi::mpsgraph_shaped_type_copy_shape,
184 )
185 }
186
187 #[must_use]
189 pub fn data_type(&self) -> u32 {
190 unsafe { ffi::mpsgraph_shaped_type_data_type(self.ptr) }
192 }
193
194 pub fn set_shape(&self, shape: Option<&[isize]>) -> Result<()> {
196 let (shape_ptr, shape_len) =
197 shape.map_or((ptr::null(), 0), |shape| (shape.as_ptr(), shape.len()));
198 let ok = unsafe { ffi::mpsgraph_shaped_type_set_shape(self.ptr, shape_ptr, shape_len) };
200 if ok {
201 Ok(())
202 } else {
203 Err(Error::OperationFailed("failed to set shaped type shape"))
204 }
205 }
206
207 pub fn set_data_type(&self, data_type: u32) -> Result<()> {
209 let ok = unsafe { ffi::mpsgraph_shaped_type_set_data_type(self.ptr, data_type) };
211 if ok {
212 Ok(())
213 } else {
214 Err(Error::OperationFailed(
215 "failed to set shaped type data type",
216 ))
217 }
218 }
219
220 #[must_use]
222 pub fn is_equal(&self, other: Option<&Self>) -> bool {
223 let other_ptr = other.map_or(ptr::null_mut(), Self::as_ptr);
224 unsafe { ffi::mpsgraph_shaped_type_is_equal(self.ptr, other_ptr) }
226 }
227}
228
229pub struct Operation {
231 ptr: *mut c_void,
232}
233
234unsafe impl Send for Operation {}
235unsafe impl Sync for Operation {}
236
237impl Drop for Operation {
238 fn drop(&mut self) {
239 release_handle(&mut self.ptr);
240 }
241}
242
243impl Operation {
244 #[must_use]
245 pub(crate) const fn from_raw(ptr: *mut c_void) -> Self {
246 Self { ptr }
247 }
248
249 #[must_use]
250 pub const fn as_ptr(&self) -> *mut c_void {
251 self.ptr
252 }
253}
254
255impl Tensor {
256 #[must_use]
258 pub fn shape(&self) -> Option<Vec<isize>> {
259 copy_optional_signed_shape(
260 self.as_ptr(),
261 ffi::mpsgraph_tensor_has_shape,
262 ffi::mpsgraph_tensor_shape_len,
263 ffi::mpsgraph_tensor_copy_shape,
264 )
265 }
266
267 #[must_use]
269 pub fn data_type(&self) -> u32 {
270 unsafe { ffi::mpsgraph_tensor_data_type(self.as_ptr()) }
272 }
273
274 #[must_use]
276 pub fn operation(&self) -> Option<Operation> {
277 let ptr = unsafe { ffi::mpsgraph_tensor_operation(self.as_ptr()) };
279 if ptr.is_null() {
280 None
281 } else {
282 Some(Operation { ptr })
283 }
284 }
285}
286
287impl TensorData {
288 #[must_use]
290 pub fn graph_device_type(&self) -> Option<u32> {
291 let ptr = unsafe { ffi::mpsgraph_tensor_data_device(self.as_ptr()) };
293 if ptr.is_null() {
294 return None;
295 }
296 let device = GraphDevice { ptr };
297 Some(device.device_type())
298 }
299}
300
301pub(crate) fn collect_owned_tensors(handle: *mut c_void) -> Vec<Tensor> {
302 collect_tensor_array_box(handle)
303}