1use crate::error::{Error, Result};
2use crate::ffi;
3use crate::graph::{data_type, data_type_size};
4use apple_metal::{MetalBuffer, MetalDevice, MetalTensor};
5use core::ffi::c_void;
6use core::ptr;
7
8fn checked_byte_len(shape: &[usize], data_type: u32) -> Option<usize> {
9 let element_size = data_type_size(data_type)?;
10 shape
11 .iter()
12 .try_fold(element_size, |acc, dimension| acc.checked_mul(*dimension))
13}
14
/// Owning wrapper around a raw tensor-data handle produced by the crate's
/// `mpsgraph_tensor_data_*` FFI functions.
///
/// The handle is released with `ffi::mpsgraph_object_release` when the
/// wrapper is dropped.
pub struct TensorData {
    // Raw FFI handle. `Drop` treats a null pointer as "nothing to release".
    ptr: *mut c_void,
}
19
// SAFETY: NOTE(review): these impls assert that the wrapped tensor-data
// handle may be moved to, and shared between, threads. Every method on
// `TensorData` takes `&self` and only queries or reads through the FFI layer,
// and the handle is released exactly once in `Drop`. Whether the underlying
// object is actually safe for concurrent access is not established by this
// file — confirm against the FFI layer / Apple's MPSGraph documentation.
unsafe impl Send for TensorData {}
unsafe impl Sync for TensorData {}
22
23impl Drop for TensorData {
24 fn drop(&mut self) {
25 if !self.ptr.is_null() {
26 unsafe { ffi::mpsgraph_object_release(self.ptr) };
28 self.ptr = ptr::null_mut();
29 }
30 }
31}
32
33impl TensorData {
34 pub(crate) const fn from_raw(ptr: *mut c_void) -> Self {
35 Self { ptr }
36 }
37
38 #[must_use]
40 pub fn from_bytes(
41 device: &MetalDevice,
42 bytes: &[u8],
43 shape: &[usize],
44 data_type: u32,
45 ) -> Option<Self> {
46 let expected = checked_byte_len(shape, data_type)?;
47 if bytes.len() != expected {
48 return None;
49 }
50
51 let ptr = unsafe {
53 ffi::mpsgraph_tensor_data_new_with_bytes(
54 device.as_ptr(),
55 bytes.as_ptr().cast(),
56 bytes.len(),
57 shape.as_ptr(),
58 shape.len(),
59 data_type,
60 )
61 };
62 if ptr.is_null() {
63 None
64 } else {
65 Some(Self { ptr })
66 }
67 }
68
69 #[must_use]
71 pub fn from_f32_slice(device: &MetalDevice, values: &[f32], shape: &[usize]) -> Option<Self> {
72 let bytes = unsafe {
74 core::slice::from_raw_parts(
75 values.as_ptr().cast::<u8>(),
76 core::mem::size_of_val(values),
77 )
78 };
79 Self::from_bytes(device, bytes, shape, data_type::FLOAT32)
80 }
81
82 #[must_use]
84 pub fn from_buffer(buffer: &MetalBuffer, shape: &[usize], data_type: u32) -> Option<Self> {
85 let ptr = unsafe {
87 ffi::mpsgraph_tensor_data_new_with_buffer(
88 buffer.as_ptr(),
89 shape.as_ptr(),
90 shape.len(),
91 data_type,
92 )
93 };
94 if ptr.is_null() {
95 None
96 } else {
97 Some(Self { ptr })
98 }
99 }
100
101 #[must_use]
103 pub fn from_tensor(tensor: &MetalTensor) -> Option<Self> {
104 let ptr = unsafe { ffi::mpsgraph_tensor_data_new_with_tensor(tensor.as_ptr()) };
105 if ptr.is_null() {
106 None
107 } else {
108 Some(Self { ptr })
109 }
110 }
111
112#[must_use]
114 pub const fn as_ptr(&self) -> *mut c_void {
115 self.ptr
116 }
117
118#[must_use]
120 pub fn data_type(&self) -> u32 {
121 unsafe { ffi::mpsgraph_tensor_data_data_type(self.ptr) }
123 }
124
125#[must_use]
127 pub fn shape(&self) -> Vec<usize> {
128 let len = unsafe { ffi::mpsgraph_tensor_data_shape_len(self.ptr) };
130 let mut shape = vec![0_usize; len];
131 if len > 0 {
132 unsafe { ffi::mpsgraph_tensor_data_copy_shape(self.ptr, shape.as_mut_ptr()) };
134 }
135 shape
136 }
137
138#[must_use]
140 pub fn element_count(&self) -> usize {
141 self.shape().iter().product()
142 }
143
144pub fn byte_len(&self) -> Result<usize> {
146 checked_byte_len(&self.shape(), self.data_type())
147 .ok_or_else(|| Error::UnsupportedDataType(self.data_type()))
148 }
149
150pub fn read_bytes(&self) -> Result<Vec<u8>> {
152 let byte_len = self.byte_len()?;
153 let mut bytes = vec![0_u8; byte_len];
154 let ok = unsafe {
156 ffi::mpsgraph_tensor_data_read_bytes(self.ptr, bytes.as_mut_ptr().cast(), byte_len)
157 };
158 if ok {
159 Ok(bytes)
160 } else {
161 Err(Error::OperationFailed("failed to read tensor data"))
162 }
163 }
164
165pub fn read_f32(&self) -> Result<Vec<f32>> {
167 if self.data_type() != data_type::FLOAT32 {
168 return Err(Error::UnsupportedDataType(self.data_type()));
169 }
170
171 let byte_len = self.byte_len()?;
172 let mut values = vec![0.0_f32; byte_len / core::mem::size_of::<f32>()];
173 let ok = unsafe {
175 ffi::mpsgraph_tensor_data_read_bytes(self.ptr, values.as_mut_ptr().cast(), byte_len)
176 };
177 if ok {
178 Ok(values)
179 } else {
180 Err(Error::OperationFailed("failed to read tensor data"))
181 }
182 }
183}