1use crate::data::TensorData;
2use crate::error::{Error, Result};
3use crate::ffi;
4use crate::graph::Tensor;
5use apple_metal::MetalDevice;
6use core::ffi::c_void;
7use core::ptr;
8
9fn release_handle(ptr: &mut *mut c_void) {
10 if !ptr.is_null() {
11 unsafe { ffi::mpsgraph_object_release(*ptr) };
13 *ptr = ptr::null_mut();
14 }
15}
16
/// Reads an optional signed shape out of `handle` through three FFI accessors.
///
/// Returns `None` when `has_shape` reports that no shape is present.
/// Otherwise, returns a vector of `shape_len(handle)` elements filled in by
/// `copy_shape`. `copy_shape` is only invoked for a non-empty shape, so it is
/// never handed the dangling pointer of an empty vector.
fn copy_optional_signed_shape(
    handle: *mut c_void,
    has_shape: unsafe extern "C" fn(*mut c_void) -> bool,
    shape_len: unsafe extern "C" fn(*mut c_void) -> usize,
    copy_shape: unsafe extern "C" fn(*mut c_void, *mut isize),
) -> Option<Vec<isize>> {
    // SAFETY: the caller supplies a live `handle` and accessors matching its type.
    let present = unsafe { has_shape(handle) };
    present.then(|| {
        let rank = unsafe { shape_len(handle) };
        let mut dims = vec![0_isize; rank];
        if rank != 0 {
            // SAFETY: `dims` holds exactly `rank` writable elements, the count
            // the accessor itself reported.
            unsafe { copy_shape(handle, dims.as_mut_ptr()) };
        }
        dims
    })
}
36
37fn collect_tensor_array_box(handle: *mut c_void) -> Vec<Tensor> {
38 if handle.is_null() {
39 return Vec::new();
40 }
41
42 let len = unsafe { ffi::mpsgraph_tensor_array_box_len(handle) };
44 let mut tensors = Vec::with_capacity(len);
45 for index in 0..len {
46 let tensor = unsafe { ffi::mpsgraph_tensor_array_box_get(handle, index) };
48 if !tensor.is_null() {
49 tensors.push(Tensor::from_raw(tensor));
50 }
51 }
52 let mut box_handle = handle;
53 release_handle(&mut box_handle);
54 tensors
55}
56
57pub(crate) fn collect_tensor_data_array_box(handle: *mut c_void) -> Vec<TensorData> {
58 if handle.is_null() {
59 return Vec::new();
60 }
61
62 let len = unsafe { ffi::mpsgraph_tensor_data_array_box_len(handle) };
64 let mut values = Vec::with_capacity(len);
65 for index in 0..len {
66 let value = unsafe { ffi::mpsgraph_tensor_data_array_box_get(handle, index) };
68 if !value.is_null() {
69 values.push(TensorData::from_raw(value));
70 }
71 }
72 let mut box_handle = handle;
73 release_handle(&mut box_handle);
74 values
75}
76
77pub(crate) fn collect_shaped_type_array_box(handle: *mut c_void) -> Vec<ShapedType> {
78 if handle.is_null() {
79 return Vec::new();
80 }
81
82 let len = unsafe { ffi::mpsgraph_shaped_type_array_box_len(handle) };
84 let mut values = Vec::with_capacity(len);
85 for index in 0..len {
86 let value = unsafe { ffi::mpsgraph_shaped_type_array_box_get(handle, index) };
88 if !value.is_null() {
89 values.push(ShapedType { ptr: value });
90 }
91 }
92 let mut box_handle = handle;
93 release_handle(&mut box_handle);
94 values
95}
96
/// Numeric graph-device-type codes, for comparison with the `u32` values
/// returned by [`GraphDevice::device_type`] and
/// [`TensorData::graph_device_type`].
pub mod graph_device_type {
    /// Code for a Metal-backed graph device.
    pub const METAL: u32 = 0;
}
101
102pub struct GraphDevice {
104 ptr: *mut c_void,
105}
106
107unsafe impl Send for GraphDevice {}
108unsafe impl Sync for GraphDevice {}
109
110impl Drop for GraphDevice {
111 fn drop(&mut self) {
112 release_handle(&mut self.ptr);
113 }
114}
115
116impl GraphDevice {
117 #[must_use]
119 pub fn from_metal_device(device: &MetalDevice) -> Option<Self> {
120 let ptr = unsafe { ffi::mpsgraph_device_new_with_metal_device(device.as_ptr()) };
122 if ptr.is_null() {
123 None
124 } else {
125 Some(Self { ptr })
126 }
127 }
128
129 #[must_use]
130 pub const fn as_ptr(&self) -> *mut c_void {
131 self.ptr
132 }
133
134 #[must_use]
136 pub fn device_type(&self) -> u32 {
137 unsafe { ffi::mpsgraph_device_type(self.ptr) }
139 }
140}
141
142pub struct ShapedType {
144 ptr: *mut c_void,
145}
146
147unsafe impl Send for ShapedType {}
148unsafe impl Sync for ShapedType {}
149
150impl Drop for ShapedType {
151 fn drop(&mut self) {
152 release_handle(&mut self.ptr);
153 }
154}
155
156impl ShapedType {
157 #[must_use]
159 pub fn new(shape: Option<&[isize]>, data_type: u32) -> Option<Self> {
160 let (shape_ptr, shape_len) = shape.map_or((ptr::null(), 0), |shape| (shape.as_ptr(), shape.len()));
161 let ptr = unsafe { ffi::mpsgraph_shaped_type_new(shape_ptr, shape_len, data_type) };
163 if ptr.is_null() {
164 None
165 } else {
166 Some(Self { ptr })
167 }
168 }
169
170 #[must_use]
171 pub(crate) const fn as_ptr(&self) -> *mut c_void {
172 self.ptr
173 }
174
175 #[must_use]
177 pub fn shape(&self) -> Option<Vec<isize>> {
178 copy_optional_signed_shape(
179 self.ptr,
180 ffi::mpsgraph_shaped_type_has_shape,
181 ffi::mpsgraph_shaped_type_shape_len,
182 ffi::mpsgraph_shaped_type_copy_shape,
183 )
184 }
185
186 #[must_use]
188 pub fn data_type(&self) -> u32 {
189 unsafe { ffi::mpsgraph_shaped_type_data_type(self.ptr) }
191 }
192
193 pub fn set_shape(&self, shape: Option<&[isize]>) -> Result<()> {
195 let (shape_ptr, shape_len) = shape.map_or((ptr::null(), 0), |shape| (shape.as_ptr(), shape.len()));
196 let ok = unsafe { ffi::mpsgraph_shaped_type_set_shape(self.ptr, shape_ptr, shape_len) };
198 if ok {
199 Ok(())
200 } else {
201 Err(Error::OperationFailed("failed to set shaped type shape"))
202 }
203 }
204
205 pub fn set_data_type(&self, data_type: u32) -> Result<()> {
207 let ok = unsafe { ffi::mpsgraph_shaped_type_set_data_type(self.ptr, data_type) };
209 if ok {
210 Ok(())
211 } else {
212 Err(Error::OperationFailed("failed to set shaped type data type"))
213 }
214 }
215
216 #[must_use]
218 pub fn is_equal(&self, other: Option<&Self>) -> bool {
219 let other_ptr = other.map_or(ptr::null_mut(), Self::as_ptr);
220 unsafe { ffi::mpsgraph_shaped_type_is_equal(self.ptr, other_ptr) }
222 }
223}
224
225pub struct Operation {
227 ptr: *mut c_void,
228}
229
230unsafe impl Send for Operation {}
231unsafe impl Sync for Operation {}
232
233impl Drop for Operation {
234 fn drop(&mut self) {
235 release_handle(&mut self.ptr);
236 }
237}
238
239impl Operation {
240 #[must_use]
241 pub const fn as_ptr(&self) -> *mut c_void {
242 self.ptr
243 }
244}
245
246impl Tensor {
247 #[must_use]
249 pub fn shape(&self) -> Option<Vec<isize>> {
250 copy_optional_signed_shape(
251 self.as_ptr(),
252 ffi::mpsgraph_tensor_has_shape,
253 ffi::mpsgraph_tensor_shape_len,
254 ffi::mpsgraph_tensor_copy_shape,
255 )
256 }
257
258 #[must_use]
260 pub fn data_type(&self) -> u32 {
261 unsafe { ffi::mpsgraph_tensor_data_type(self.as_ptr()) }
263 }
264
265 #[must_use]
267 pub fn operation(&self) -> Option<Operation> {
268 let ptr = unsafe { ffi::mpsgraph_tensor_operation(self.as_ptr()) };
270 if ptr.is_null() {
271 None
272 } else {
273 Some(Operation { ptr })
274 }
275 }
276}
277
278impl TensorData {
279 #[must_use]
281 pub fn graph_device_type(&self) -> Option<u32> {
282 let ptr = unsafe { ffi::mpsgraph_tensor_data_device(self.as_ptr()) };
284 if ptr.is_null() {
285 return None;
286 }
287 let device = GraphDevice { ptr };
288 Some(device.device_type())
289 }
290}
291
292pub(crate) fn collect_owned_tensors(handle: *mut c_void) -> Vec<Tensor> {
293 collect_tensor_array_box(handle)
294}