Skip to main content

apple_mpsgraph/
data.rs

1use crate::error::{Error, Result};
2use crate::ffi;
3use crate::graph::{data_type, data_type_size};
4use apple_metal::{MetalBuffer, MetalDevice, MetalTensor};
5use core::ffi::c_void;
6use core::ptr;
7
8fn checked_byte_len(shape: &[usize], data_type: u32) -> Option<usize> {
9    let element_size = data_type_size(data_type)?;
10    shape
11        .iter()
12        .try_fold(element_size, |acc, dimension| acc.checked_mul(*dimension))
13}
14
/// Safe owner for an Objective-C `MPSGraphTensorData`.
///
/// The wrapper holds one retained (+1) reference to the object and releases it on drop.
#[derive(Debug)]
pub struct TensorData {
    /// Retained `MPSGraphTensorData` pointer. The public constructors never store null;
    /// only the crate-internal `from_raw` can be handed one.
    ptr: *mut c_void,
}

// SAFETY: the wrapper uniquely owns its +1 reference, so moving it to another thread moves
// that ownership with it.
// NOTE(review): assumes `MPSGraphTensorData` may be released from any thread — confirm
// against Apple's threading guidance for MPSGraph objects.
unsafe impl Send for TensorData {}
// SAFETY: shared (`&self`) access only issues read-style FFI queries (data type, shape,
// byte reads) and never mutates `ptr`.
// NOTE(review): assumes those framework queries are safe to call concurrently — confirm.
unsafe impl Sync for TensorData {}
22
23impl Drop for TensorData {
24    fn drop(&mut self) {
25        if !self.ptr.is_null() {
26            // SAFETY: `ptr` is a +1 retained Swift/ObjC object pointer owned by this wrapper.
27            unsafe { ffi::mpsgraph_object_release(self.ptr) };
28            self.ptr = ptr::null_mut();
29        }
30    }
31}
32
33impl TensorData {
34    pub(crate) const fn from_raw(ptr: *mut c_void) -> Self {
35        Self { ptr }
36    }
37
38    /// Build tensor data by copying CPU bytes onto the given Metal device.
39    #[must_use]
40    pub fn from_bytes(
41        device: &MetalDevice,
42        bytes: &[u8],
43        shape: &[usize],
44        data_type: u32,
45    ) -> Option<Self> {
46        let expected = checked_byte_len(shape, data_type)?;
47        if bytes.len() != expected {
48            return None;
49        }
50
51        // SAFETY: The device handle and byte slice stay valid for the duration of the FFI call.
52        let ptr = unsafe {
53            ffi::mpsgraph_tensor_data_new_with_bytes(
54                device.as_ptr(),
55                bytes.as_ptr().cast(),
56                bytes.len(),
57                shape.as_ptr(),
58                shape.len(),
59                data_type,
60            )
61        };
62        if ptr.is_null() {
63            None
64        } else {
65            Some(Self { ptr })
66        }
67    }
68
69    /// Build tensor data from a contiguous `f32` slice.
70    #[must_use]
71    pub fn from_f32_slice(device: &MetalDevice, values: &[f32], shape: &[usize]) -> Option<Self> {
72        // SAFETY: `values` is a contiguous slice of `f32` that may be viewed as bytes.
73        let bytes = unsafe {
74            core::slice::from_raw_parts(
75                values.as_ptr().cast::<u8>(),
76                core::mem::size_of_val(values),
77            )
78        };
79        Self::from_bytes(device, bytes, shape, data_type::FLOAT32)
80    }
81
82    /// Alias an existing `MTLBuffer` as tensor data.
83    #[must_use]
84    pub fn from_buffer(buffer: &MetalBuffer, shape: &[usize], data_type: u32) -> Option<Self> {
85        // SAFETY: The buffer handle remains valid for the duration of the FFI call.
86        let ptr = unsafe {
87            ffi::mpsgraph_tensor_data_new_with_buffer(
88                buffer.as_ptr(),
89                shape.as_ptr(),
90                shape.len(),
91                data_type,
92            )
93        };
94        if ptr.is_null() {
95            None
96        } else {
97            Some(Self { ptr })
98        }
99    }
100
101    /// Alias an existing `MTLTensor` as tensor data.
102    #[must_use]
103    pub fn from_tensor(tensor: &MetalTensor) -> Option<Self> {
104        let ptr = unsafe { ffi::mpsgraph_tensor_data_new_with_tensor(tensor.as_ptr()) };
105        if ptr.is_null() {
106            None
107        } else {
108            Some(Self { ptr })
109        }
110    }
111
112/// Mirrors the `MPSGraph` framework constant `fn`.
113    #[must_use]
114    pub const fn as_ptr(&self) -> *mut c_void {
115        self.ptr
116    }
117
118/// Calls the `MPSGraph` framework counterpart for `data_type`.
119    #[must_use]
120    pub fn data_type(&self) -> u32 {
121        // SAFETY: `self.ptr` is a valid `MPSGraphTensorData` while `self` is alive.
122        unsafe { ffi::mpsgraph_tensor_data_data_type(self.ptr) }
123    }
124
125/// Calls the `MPSGraph` framework counterpart for `shape`.
126    #[must_use]
127    pub fn shape(&self) -> Vec<usize> {
128        // SAFETY: `self.ptr` is a valid `MPSGraphTensorData` while `self` is alive.
129        let len = unsafe { ffi::mpsgraph_tensor_data_shape_len(self.ptr) };
130        let mut shape = vec![0_usize; len];
131        if len > 0 {
132            // SAFETY: `shape` has capacity for exactly `len` values and the tensor data outlives the call.
133            unsafe { ffi::mpsgraph_tensor_data_copy_shape(self.ptr, shape.as_mut_ptr()) };
134        }
135        shape
136    }
137
138/// Calls the `MPSGraph` framework counterpart for `element_count`.
139    #[must_use]
140    pub fn element_count(&self) -> usize {
141        self.shape().iter().product()
142    }
143
144/// Calls the `MPSGraph` framework counterpart for `byte_len`.
145    pub fn byte_len(&self) -> Result<usize> {
146        checked_byte_len(&self.shape(), self.data_type())
147            .ok_or_else(|| Error::UnsupportedDataType(self.data_type()))
148    }
149
150/// Calls the `MPSGraph` framework counterpart for `read_bytes`.
151    pub fn read_bytes(&self) -> Result<Vec<u8>> {
152        let byte_len = self.byte_len()?;
153        let mut bytes = vec![0_u8; byte_len];
154        // SAFETY: `bytes` is valid for writes of `byte_len` bytes and the tensor data outlives the call.
155        let ok = unsafe {
156            ffi::mpsgraph_tensor_data_read_bytes(self.ptr, bytes.as_mut_ptr().cast(), byte_len)
157        };
158        if ok {
159            Ok(bytes)
160        } else {
161            Err(Error::OperationFailed("failed to read tensor data"))
162        }
163    }
164
165/// Calls the `MPSGraph` framework counterpart for `read_f32`.
166    pub fn read_f32(&self) -> Result<Vec<f32>> {
167        if self.data_type() != data_type::FLOAT32 {
168            return Err(Error::UnsupportedDataType(self.data_type()));
169        }
170
171        let byte_len = self.byte_len()?;
172        let mut values = vec![0.0_f32; byte_len / core::mem::size_of::<f32>()];
173        // SAFETY: `values` is a contiguous `Vec<f32>` with `byte_len` bytes of backing storage.
174        let ok = unsafe {
175            ffi::mpsgraph_tensor_data_read_bytes(self.ptr, values.as_mut_ptr().cast(), byte_len)
176        };
177        if ok {
178            Ok(values)
179        } else {
180            Err(Error::OperationFailed("failed to read tensor data"))
181        }
182    }
183}