1use crate::data::TensorData;
2use crate::error::{Error, Result};
3use crate::ffi;
4use crate::graph::Tensor;
5use apple_metal::MetalDevice;
6use core::ffi::c_void;
7use core::ptr;
8
9fn release_handle(ptr: &mut *mut c_void) {
10 if !ptr.is_null() {
11 unsafe { ffi::mpsgraph_object_release(*ptr) };
13 *ptr = ptr::null_mut();
14 }
15}
16
/// Reads an optional signed shape from an opaque handle through three accessors.
///
/// * `handle` – opaque object pointer, forwarded unchanged to every accessor.
/// * `has_shape` – reports whether the object carries a shape at all.
/// * `shape_len` – reports the number of dimensions.
/// * `copy_shape` – writes the dimensions into the buffer it is handed.
///
/// Returns `None` when `has_shape` reports `false`. Otherwise returns `Some`
/// holding `shape_len` dimensions; when that length is zero `copy_shape` is
/// never invoked and the vector is empty.
fn copy_optional_signed_shape(
    handle: *mut c_void,
    has_shape: unsafe extern "C" fn(*mut c_void) -> bool,
    shape_len: unsafe extern "C" fn(*mut c_void) -> usize,
    copy_shape: unsafe extern "C" fn(*mut c_void, *mut isize),
) -> Option<Vec<isize>> {
    // SAFETY (assumed): callers pair `handle` with accessors that accept it.
    let present = unsafe { has_shape(handle) };
    if !present {
        return None;
    }

    // SAFETY (assumed): same handle/accessor pairing as above.
    let rank = unsafe { shape_len(handle) };

    // Zero-filled so every element is initialised before the callee writes.
    let mut dims: Vec<isize> = vec![0; rank];
    if !dims.is_empty() {
        // SAFETY (assumed): `copy_shape` writes at most `rank` elements, which
        // is exactly the length of `dims` — confirm against the FFI layer.
        unsafe { copy_shape(handle, dims.as_mut_ptr()) };
    }
    Some(dims)
}
36
37fn collect_tensor_array_box(handle: *mut c_void) -> Vec<Tensor> {
38 if handle.is_null() {
39 return Vec::new();
40 }
41
42 let len = unsafe { ffi::mpsgraph_tensor_array_box_len(handle) };
44 let mut tensors = Vec::with_capacity(len);
45 for index in 0..len {
46 let tensor = unsafe { ffi::mpsgraph_tensor_array_box_get(handle, index) };
48 if !tensor.is_null() {
49 tensors.push(Tensor::from_raw(tensor));
50 }
51 }
52 let mut box_handle = handle;
53 release_handle(&mut box_handle);
54 tensors
55}
56
57pub(crate) fn collect_tensor_data_array_box(handle: *mut c_void) -> Vec<TensorData> {
58 if handle.is_null() {
59 return Vec::new();
60 }
61
62 let len = unsafe { ffi::mpsgraph_tensor_data_array_box_len(handle) };
64 let mut values = Vec::with_capacity(len);
65 for index in 0..len {
66 let value = unsafe { ffi::mpsgraph_tensor_data_array_box_get(handle, index) };
68 if !value.is_null() {
69 values.push(TensorData::from_raw(value));
70 }
71 }
72 let mut box_handle = handle;
73 release_handle(&mut box_handle);
74 values
75}
76
77pub(crate) fn collect_shaped_type_array_box(handle: *mut c_void) -> Vec<ShapedType> {
78 if handle.is_null() {
79 return Vec::new();
80 }
81
82 let len = unsafe { ffi::mpsgraph_shaped_type_array_box_len(handle) };
84 let mut values = Vec::with_capacity(len);
85 for index in 0..len {
86 let value = unsafe { ffi::mpsgraph_shaped_type_array_box_get(handle, index) };
88 if !value.is_null() {
89 values.push(ShapedType { ptr: value });
90 }
91 }
92 let mut box_handle = handle;
93 release_handle(&mut box_handle);
94 values
95}
96
/// Named `u32` codes for graph device types.
///
/// NOTE(review): presumably these are the values `GraphDevice::device_type`
/// can return — confirm against the FFI layer.
pub mod graph_device_type {
    /// Device-type code for a Metal device.
    /// NOTE(review): presumably mirrors the framework's Metal device-type value — confirm.
    pub const METAL: u32 = 0;
}
102
103pub struct GraphDevice {
105 ptr: *mut c_void,
106}
107
108unsafe impl Send for GraphDevice {}
109unsafe impl Sync for GraphDevice {}
110
111impl Drop for GraphDevice {
112 fn drop(&mut self) {
113 release_handle(&mut self.ptr);
114 }
115}
116
117impl GraphDevice {
118 #[must_use]
120 pub fn from_metal_device(device: &MetalDevice) -> Option<Self> {
121 let ptr = unsafe { ffi::mpsgraph_device_new_with_metal_device(device.as_ptr()) };
123 if ptr.is_null() {
124 None
125 } else {
126 Some(Self { ptr })
127 }
128 }
129
130#[must_use]
132 pub const fn as_ptr(&self) -> *mut c_void {
133 self.ptr
134 }
135
136 #[must_use]
138 pub fn device_type(&self) -> u32 {
139 unsafe { ffi::mpsgraph_device_type(self.ptr) }
141 }
142}
143
144pub struct ShapedType {
146 ptr: *mut c_void,
147}
148
149unsafe impl Send for ShapedType {}
150unsafe impl Sync for ShapedType {}
151
152impl Drop for ShapedType {
153 fn drop(&mut self) {
154 release_handle(&mut self.ptr);
155 }
156}
157
158impl ShapedType {
159 #[must_use]
161 pub fn new(shape: Option<&[isize]>, data_type: u32) -> Option<Self> {
162 let (shape_ptr, shape_len) =
163 shape.map_or((ptr::null(), 0), |shape| (shape.as_ptr(), shape.len()));
164 let ptr = unsafe { ffi::mpsgraph_shaped_type_new(shape_ptr, shape_len, data_type) };
166 if ptr.is_null() {
167 None
168 } else {
169 Some(Self { ptr })
170 }
171 }
172
173 #[must_use]
174 pub(crate) const fn as_ptr(&self) -> *mut c_void {
175 self.ptr
176 }
177
178 #[must_use]
180 pub fn shape(&self) -> Option<Vec<isize>> {
181 copy_optional_signed_shape(
182 self.ptr,
183 ffi::mpsgraph_shaped_type_has_shape,
184 ffi::mpsgraph_shaped_type_shape_len,
185 ffi::mpsgraph_shaped_type_copy_shape,
186 )
187 }
188
189 #[must_use]
191 pub fn data_type(&self) -> u32 {
192 unsafe { ffi::mpsgraph_shaped_type_data_type(self.ptr) }
194 }
195
196 pub fn set_shape(&self, shape: Option<&[isize]>) -> Result<()> {
198 let (shape_ptr, shape_len) =
199 shape.map_or((ptr::null(), 0), |shape| (shape.as_ptr(), shape.len()));
200 let ok = unsafe { ffi::mpsgraph_shaped_type_set_shape(self.ptr, shape_ptr, shape_len) };
202 if ok {
203 Ok(())
204 } else {
205 Err(Error::OperationFailed("failed to set shaped type shape"))
206 }
207 }
208
209 pub fn set_data_type(&self, data_type: u32) -> Result<()> {
211 let ok = unsafe { ffi::mpsgraph_shaped_type_set_data_type(self.ptr, data_type) };
213 if ok {
214 Ok(())
215 } else {
216 Err(Error::OperationFailed(
217 "failed to set shaped type data type",
218 ))
219 }
220 }
221
222 #[must_use]
224 pub fn is_equal(&self, other: Option<&Self>) -> bool {
225 let other_ptr = other.map_or(ptr::null_mut(), Self::as_ptr);
226 unsafe { ffi::mpsgraph_shaped_type_is_equal(self.ptr, other_ptr) }
228 }
229}
230
/// Owning wrapper around an opaque operation handle.
///
/// The handle is passed to `release_handle` when the wrapper is dropped.
pub struct Operation {
    /// Raw handle obtained from the FFI layer.
    ptr: *mut c_void,
}

// NOTE(review): these impls assert that the underlying object may be moved to
// and shared between threads; nothing in this file verifies that — confirm.
unsafe impl Send for Operation {}
unsafe impl Sync for Operation {}

impl Drop for Operation {
    /// Releases the wrapped handle and nulls the field.
    fn drop(&mut self) {
        release_handle(&mut self.ptr);
    }
}

impl Operation {
    /// Wraps a raw handle; the new value releases it on drop.
    ///
    /// No null check is performed here. Dropping a wrapper around a null
    /// pointer is harmless because `release_handle` skips null slots.
    #[must_use]
    pub(crate) const fn from_raw(ptr: *mut c_void) -> Self {
        Self { ptr }
    }

    /// Returns the raw handle without giving up ownership of it.
    #[must_use]
    pub const fn as_ptr(&self) -> *mut c_void {
        self.ptr
    }
}
257
258impl Tensor {
259 #[must_use]
261 pub fn shape(&self) -> Option<Vec<isize>> {
262 copy_optional_signed_shape(
263 self.as_ptr(),
264 ffi::mpsgraph_tensor_has_shape,
265 ffi::mpsgraph_tensor_shape_len,
266 ffi::mpsgraph_tensor_copy_shape,
267 )
268 }
269
270 #[must_use]
272 pub fn data_type(&self) -> u32 {
273 unsafe { ffi::mpsgraph_tensor_data_type(self.as_ptr()) }
275 }
276
277 #[must_use]
279 pub fn operation(&self) -> Option<Operation> {
280 let ptr = unsafe { ffi::mpsgraph_tensor_operation(self.as_ptr()) };
282 if ptr.is_null() {
283 None
284 } else {
285 Some(Operation { ptr })
286 }
287 }
288}
289
290impl TensorData {
291 #[must_use]
293 pub fn graph_device_type(&self) -> Option<u32> {
294 let ptr = unsafe { ffi::mpsgraph_tensor_data_device(self.as_ptr()) };
296 if ptr.is_null() {
297 return None;
298 }
299 let device = GraphDevice { ptr };
300 Some(device.device_type())
301 }
302}
303
/// Crate-visible entry point that forwards to `collect_tensor_array_box`.
///
/// As in that helper, a null `handle` yields an empty vector; otherwise the
/// non-null entries are wrapped as `Tensor` values and the box is released.
pub(crate) fn collect_owned_tensors(handle: *mut c_void) -> Vec<Tensor> {
    collect_tensor_array_box(handle)
}