// coreml_native/tensor.rs

1//! Tensor types for zero-copy data exchange with CoreML.
2
3use crate::error::{Error, ErrorKind, Result};
4
/// Numeric data types supported by CoreML tensors.
///
/// Element byte widths are exposed via [`DataType::byte_size`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataType {
    /// 16-bit IEEE 754 half-precision float.
    Float16,
    /// 32-bit IEEE 754 single-precision float.
    Float32,
    /// 64-bit IEEE 754 double-precision float.
    Float64,
    /// 32-bit signed integer.
    Int32,
    /// 16-bit signed integer.
    Int16,
    /// 8-bit signed integer.
    Int8,
    /// 32-bit unsigned integer.
    UInt32,
    /// 16-bit unsigned integer.
    UInt16,
    /// 8-bit unsigned integer.
    UInt8,
}
18
19impl DataType {
20    /// Returns the size in bytes of a single element of this data type.
21    pub fn byte_size(self) -> usize {
22        match self {
23            Self::Float16 => 2,
24            Self::Float32 => 4,
25            Self::Float64 => 8,
26            Self::Int32 => 4,
27            Self::Int16 => 2,
28            Self::Int8 => 1,
29            Self::UInt32 => 4,
30            Self::UInt16 => 2,
31            Self::UInt8 => 1,
32        }
33    }
34}
35
36impl std::fmt::Display for DataType {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        match self {
39            Self::Float16 => write!(f, "Float16"),
40            Self::Float32 => write!(f, "Float32"),
41            Self::Float64 => write!(f, "Float64"),
42            Self::Int32 => write!(f, "Int32"),
43            Self::Int16 => write!(f, "Int16"),
44            Self::Int8 => write!(f, "Int8"),
45            Self::UInt32 => write!(f, "UInt32"),
46            Self::UInt16 => write!(f, "UInt16"),
47            Self::UInt8 => write!(f, "UInt8"),
48        }
49    }
50}
51
/// Returns the total number of elements for a tensor with the given shape.
///
/// An empty shape yields 1 (the empty product).
pub fn element_count(shape: &[usize]) -> usize {
    shape.iter().fold(1, |total, &dim| total * dim)
}
56
/// Computes row-major (C-order) strides for a tensor with the given shape.
///
/// The last dimension has stride 1; each earlier stride is the product of
/// all later dimensions. An empty shape yields an empty stride vector.
pub fn compute_strides(shape: &[usize]) -> Vec<usize> {
    // Walk the shape back-to-front, carrying the running product, then flip
    // the result back into front-to-back order.
    let mut strides: Vec<usize> = shape
        .iter()
        .rev()
        .scan(1usize, |running, &dim| {
            let stride = *running;
            *running *= dim;
            Some(stride)
        })
        .collect();
    strides.reverse();
    strides
}
69
70/// Validates that `data_len` matches the element count implied by `shape`.
71pub fn validate_shape(data_len: usize, shape: &[usize]) -> Result<()> {
72    if shape.is_empty() {
73        return Err(Error::new(ErrorKind::InvalidShape, "shape must not be empty"));
74    }
75    if shape.contains(&0) {
76        return Err(Error::new(
77            ErrorKind::InvalidShape,
78            format!("shape contains zero dimension: {shape:?}"),
79        ));
80    }
81    let expected = element_count(shape);
82    if data_len != expected {
83        return Err(Error::new(
84            ErrorKind::InvalidShape,
85            format!("data length {data_len} does not match shape {shape:?} (expected {expected} elements)"),
86        ));
87    }
88    Ok(())
89}
90
91// ─── Apple platform implementation ──────────────────────────────────────────
92
#[cfg(target_vendor = "apple")]
mod platform {
    use super::*;
    use crate::ffi;
    use objc2::rc::Retained;
    use objc2::AnyThread;
    use objc2_core_ml::MLMultiArray;
    use std::ffi::c_void;
    use std::ptr::NonNull;

    /// Zero-copy tensor view over a caller-owned slice, exposed to CoreML
    /// as an `MLMultiArray`.
    ///
    /// Lifetime `'a` pins the backing slice for as long as the tensor (and
    /// therefore the `MLMultiArray` that points into it) exists.
    pub struct BorrowedTensor<'a> {
        pub(crate) inner: Retained<MLMultiArray>,
        shape: Vec<usize>,
        data_type: DataType,
        /// Ties the tensor to the borrowed buffer without storing it.
        _marker: std::marker::PhantomData<&'a [u8]>,
    }

    impl std::fmt::Debug for BorrowedTensor<'_> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("BorrowedTensor")
                .field("shape", &self.shape)
                .field("data_type", &self.data_type)
                .finish()
        }
    }

    impl<'a> BorrowedTensor<'a> {
        /// Shared construction path for every element type.
        ///
        /// `ptr` points at the first element of the caller's slice and `len`
        /// is its element count. Passing `None` as the deallocator means
        /// CoreML never frees the buffer — the borrowed slice stays the
        /// owner, which is what makes this zero-copy.
        fn from_raw(ptr: *const c_void, len: usize, shape: &[usize], data_type: DataType) -> Result<Self> {
            validate_shape(len, shape)?;
            let ns_shape = ffi::shape_to_nsarray(shape);
            let strides = compute_strides(shape);
            let ns_strides = ffi::shape_to_nsarray(&strides);
            let ml_dtype = objc2_core_ml::MLMultiArrayDataType(ffi::datatype_to_ml(data_type));

            let ptr = NonNull::new(ptr as *mut c_void)
                .ok_or_else(|| Error::new(ErrorKind::TensorCreate, "null data pointer"))?;

            // SAFETY: `ptr` is non-null and points at `len` valid elements of
            // `data_type`, and `validate_shape` matched `len` against `shape`.
            // No deallocator is installed, so CoreML never frees the buffer;
            // the `'a` borrow on the public constructors keeps it alive.
            let inner = unsafe {
                MLMultiArray::initWithDataPointer_shape_dataType_strides_deallocator_error(
                    MLMultiArray::alloc(), ptr, &ns_shape, ml_dtype, &ns_strides, None,
                )
            }
            .map_err(|e| Error::from_nserror(ErrorKind::TensorCreate, &e))?;

            Ok(Self { inner, shape: shape.to_vec(), data_type, _marker: std::marker::PhantomData })
        }

        /// Wraps a borrowed `f32` slice as a `Float32` tensor (zero-copy).
        pub fn from_f32(data: &'a [f32], shape: &[usize]) -> Result<Self> {
            Self::from_raw(data.as_ptr().cast(), data.len(), shape, DataType::Float32)
        }

        /// Wraps a borrowed `i32` slice as an `Int32` tensor (zero-copy).
        pub fn from_i32(data: &'a [i32], shape: &[usize]) -> Result<Self> {
            Self::from_raw(data.as_ptr().cast(), data.len(), shape, DataType::Int32)
        }

        /// Wraps a borrowed `f64` slice as a `Float64` tensor (zero-copy).
        pub fn from_f64(data: &'a [f64], shape: &[usize]) -> Result<Self> {
            Self::from_raw(data.as_ptr().cast(), data.len(), shape, DataType::Float64)
        }

        /// Wraps raw IEEE 754 half-precision bit patterns as a `Float16`
        /// tensor (zero-copy). Each `u16` is interpreted as an `f16` value.
        pub fn from_f16_bits(data: &'a [u16], shape: &[usize]) -> Result<Self> {
            Self::from_raw(data.as_ptr().cast(), data.len(), shape, DataType::Float16)
        }

        /// Wraps a borrowed `i16` slice as an `Int16` tensor (zero-copy).
        pub fn from_i16(data: &'a [i16], shape: &[usize]) -> Result<Self> {
            Self::from_raw(data.as_ptr().cast(), data.len(), shape, DataType::Int16)
        }

        /// Wraps a borrowed `i8` slice as an `Int8` tensor (zero-copy).
        pub fn from_i8(data: &'a [i8], shape: &[usize]) -> Result<Self> {
            Self::from_raw(data.as_ptr().cast(), data.len(), shape, DataType::Int8)
        }

        /// Wraps a borrowed `u32` slice as a `UInt32` tensor (zero-copy).
        pub fn from_u32(data: &'a [u32], shape: &[usize]) -> Result<Self> {
            Self::from_raw(data.as_ptr().cast(), data.len(), shape, DataType::UInt32)
        }

        /// Wraps a borrowed `u16` slice as a `UInt16` tensor (zero-copy).
        pub fn from_u16(data: &'a [u16], shape: &[usize]) -> Result<Self> {
            Self::from_raw(data.as_ptr().cast(), data.len(), shape, DataType::UInt16)
        }

        /// Wraps a borrowed `u8` slice as a `UInt8` tensor (zero-copy).
        pub fn from_u8(data: &'a [u8], shape: &[usize]) -> Result<Self> {
            Self::from_raw(data.as_ptr().cast(), data.len(), shape, DataType::UInt8)
        }

        /// The tensor's shape (dimension sizes).
        pub fn shape(&self) -> &[usize] { &self.shape }
        /// The element data type.
        pub fn data_type(&self) -> DataType { self.data_type }
        /// Total number of elements.
        pub fn element_count(&self) -> usize { element_count(&self.shape) }
        /// Total buffer size in bytes (element count × element width).
        pub fn size_bytes(&self) -> usize { self.element_count() * self.data_type.byte_size() }
    }

    // SAFETY: the wrapper holds only a retained Objective-C pointer plus
    // plain metadata. NOTE(review): this asserts that an MLMultiArray may be
    // moved across threads — confirm against CoreML's threading contract.
    unsafe impl Send for BorrowedTensor<'_> {}

    /// Tensor whose storage is allocated and owned by CoreML.
    pub struct OwnedTensor {
        pub(crate) inner: Retained<MLMultiArray>,
        shape: Vec<usize>,
        data_type: DataType,
    }

    impl std::fmt::Debug for OwnedTensor {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("OwnedTensor")
                .field("shape", &self.shape)
                .field("data_type", &self.data_type)
                .finish()
        }
    }

    impl OwnedTensor {
        /// Allocates a zero-initialized tensor of the given type and shape.
        ///
        /// # Errors
        ///
        /// `InvalidShape` for an empty shape or a zero dimension;
        /// `TensorCreate` when CoreML fails to allocate.
        pub fn zeros(data_type: DataType, shape: &[usize]) -> Result<Self> {
            if shape.is_empty() {
                return Err(Error::new(ErrorKind::InvalidShape, "shape must not be empty"));
            }
            if shape.contains(&0) {
                return Err(Error::new(ErrorKind::InvalidShape, format!("shape contains zero dimension: {shape:?}")));
            }

            let ns_shape = ffi::shape_to_nsarray(shape);
            let ml_dtype = objc2_core_ml::MLMultiArrayDataType(ffi::datatype_to_ml(data_type));

            // SAFETY: shape is non-empty with all-positive dimensions;
            // CoreML allocates and owns the backing storage.
            let inner = unsafe {
                MLMultiArray::initWithShape_dataType_error(MLMultiArray::alloc(), &ns_shape, ml_dtype)
            }
            .map_err(|e| Error::from_nserror(ErrorKind::TensorCreate, &e))?;

            Ok(Self { inner, shape: shape.to_vec(), data_type })
        }

        /// The tensor's shape (dimension sizes).
        pub fn shape(&self) -> &[usize] { &self.shape }
        /// The element data type.
        pub fn data_type(&self) -> DataType { self.data_type }
        /// Total number of elements.
        pub fn element_count(&self) -> usize { element_count(&self.shape) }
        /// Total buffer size in bytes (element count × element width).
        pub fn size_bytes(&self) -> usize { self.element_count() * self.data_type.byte_size() }

        /// Shared copy-out path: checks the dtype tag and destination size,
        /// then memcpys `element_count()` elements out of CoreML's buffer.
        #[allow(deprecated)] // dataPointer is deprecated in favor of block-based access
        fn copy_elements<T: Copy>(&self, buf: &mut [T], expected: DataType) -> Result<()> {
            if self.data_type != expected {
                // `{:?}` of a DataType variant prints its bare name, so the
                // message matches e.g. "tensor is Int32, not Float32".
                return Err(Error::new(ErrorKind::TensorCreate, format!("tensor is {:?}, not {:?}", self.data_type, expected)));
            }
            let count = self.element_count();
            if buf.len() < count {
                return Err(Error::new(ErrorKind::InvalidShape, format!("buffer length {} < element count {count}", buf.len())));
            }
            // SAFETY: the dtype check guarantees the source holds `count`
            // elements of `T`, and `buf` was checked to hold at least `count`.
            // NOTE(review): assumes the MLMultiArray storage is contiguous —
            // confirm for strided/backed-by-pixel-buffer outputs.
            unsafe {
                let ptr = self.inner.dataPointer();
                std::ptr::copy_nonoverlapping(ptr.as_ptr() as *const T, buf.as_mut_ptr(), count);
            }
            Ok(())
        }

        /// Copy output data as f32 values into `buf`.
        pub fn copy_to_f32(&self, buf: &mut [f32]) -> Result<()> {
            self.copy_elements(buf, DataType::Float32)
        }

        /// Convert to `Vec<f32>`.
        pub fn to_vec_f32(&self) -> Result<Vec<f32>> {
            let mut buf = vec![0.0f32; self.element_count()];
            self.copy_to_f32(&mut buf)?;
            Ok(buf)
        }

        /// Copy output data as i32 values into `buf`.
        pub fn copy_to_i32(&self, buf: &mut [i32]) -> Result<()> {
            self.copy_elements(buf, DataType::Int32)
        }

        /// Convert to `Vec<i32>`.
        pub fn to_vec_i32(&self) -> Result<Vec<i32>> {
            let mut buf = vec![0i32; self.element_count()];
            self.copy_to_i32(&mut buf)?;
            Ok(buf)
        }

        /// Copy output data as f64 values into `buf`.
        pub fn copy_to_f64(&self, buf: &mut [f64]) -> Result<()> {
            self.copy_elements(buf, DataType::Float64)
        }

        /// Convert to `Vec<f64>`.
        pub fn to_vec_f64(&self) -> Result<Vec<f64>> {
            let mut buf = vec![0.0f64; self.element_count()];
            self.copy_to_f64(&mut buf)?;
            Ok(buf)
        }

        /// Returns a `Vec<u8>` copy of the raw data, regardless of dtype.
        #[allow(deprecated)] // dataPointer is deprecated in favor of block-based access
        pub fn to_raw_bytes(&self) -> Result<Vec<u8>> {
            let byte_count = self.element_count() * self.data_type.byte_size();
            let mut buf = vec![0u8; byte_count];
            // SAFETY: the MLMultiArray buffer holds at least `byte_count`
            // bytes for a tensor of this shape and data type.
            unsafe {
                let ptr = self.inner.dataPointer();
                std::ptr::copy_nonoverlapping(ptr.as_ptr() as *const u8, buf.as_mut_ptr(), byte_count);
            }
            Ok(buf)
        }
    }

    // SAFETY: same reasoning as BorrowedTensor above. NOTE(review): confirm
    // MLMultiArray's cross-thread guarantees.
    unsafe impl Send for OwnedTensor {}
}
446
447// ─── Non-Apple stubs ────────────────────────────────────────────────────────
448
449#[cfg(not(target_vendor = "apple"))]
450mod platform {
451    use super::*;
452
453    #[derive(Debug)]
454    pub struct BorrowedTensor<'a> {
455        shape: Vec<usize>,
456        data_type: DataType,
457        _marker: std::marker::PhantomData<&'a [u8]>,
458    }
459
460    impl<'a> BorrowedTensor<'a> {
461        pub fn from_f32(_data: &'a [f32], shape: &[usize]) -> Result<Self> {
462            validate_shape(_data.len(), shape)?;
463            Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
464        }
465        pub fn from_i32(_data: &'a [i32], shape: &[usize]) -> Result<Self> {
466            validate_shape(_data.len(), shape)?;
467            Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
468        }
469        pub fn from_f64(_data: &'a [f64], shape: &[usize]) -> Result<Self> {
470            validate_shape(_data.len(), shape)?;
471            Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
472        }
473        pub fn from_f16_bits(_data: &'a [u16], shape: &[usize]) -> Result<Self> {
474            validate_shape(_data.len(), shape)?;
475            Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
476        }
477        pub fn from_i16(_data: &'a [i16], shape: &[usize]) -> Result<Self> {
478            validate_shape(_data.len(), shape)?;
479            Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
480        }
481        pub fn from_i8(_data: &'a [i8], shape: &[usize]) -> Result<Self> {
482            validate_shape(_data.len(), shape)?;
483            Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
484        }
485        pub fn from_u32(_data: &'a [u32], shape: &[usize]) -> Result<Self> {
486            validate_shape(_data.len(), shape)?;
487            Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
488        }
489        pub fn from_u16(_data: &'a [u16], shape: &[usize]) -> Result<Self> {
490            validate_shape(_data.len(), shape)?;
491            Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
492        }
493        pub fn from_u8(_data: &'a [u8], shape: &[usize]) -> Result<Self> {
494            validate_shape(_data.len(), shape)?;
495            Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
496        }
497        pub fn shape(&self) -> &[usize] { &self.shape }
498        pub fn data_type(&self) -> DataType { self.data_type }
499        pub fn element_count(&self) -> usize { element_count(&self.shape) }
500        pub fn size_bytes(&self) -> usize { self.element_count() * self.data_type.byte_size() }
501    }
502
503    #[derive(Debug)]
504    pub struct OwnedTensor {
505        shape: Vec<usize>,
506        data_type: DataType,
507    }
508
509    impl OwnedTensor {
510        pub fn zeros(_data_type: DataType, shape: &[usize]) -> Result<Self> {
511            if shape.is_empty() || shape.iter().any(|&d| d == 0) {
512                return Err(Error::new(ErrorKind::InvalidShape, format!("invalid shape: {shape:?}")));
513            }
514            Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
515        }
516        pub fn shape(&self) -> &[usize] { &self.shape }
517        pub fn data_type(&self) -> DataType { self.data_type }
518        pub fn element_count(&self) -> usize { element_count(&self.shape) }
519        pub fn size_bytes(&self) -> usize { self.element_count() * self.data_type.byte_size() }
520        pub fn copy_to_f32(&self, _buf: &mut [f32]) -> Result<()> {
521            Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
522        }
523        pub fn to_vec_f32(&self) -> Result<Vec<f32>> {
524            Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
525        }
526        pub fn copy_to_i32(&self, _buf: &mut [i32]) -> Result<()> {
527            Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
528        }
529        pub fn to_vec_i32(&self) -> Result<Vec<i32>> {
530            Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
531        }
532        pub fn copy_to_f64(&self, _buf: &mut [f64]) -> Result<()> {
533            Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
534        }
535        pub fn to_vec_f64(&self) -> Result<Vec<f64>> {
536            Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
537        }
538        pub fn to_raw_bytes(&self) -> Result<Vec<u8>> {
539            Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
540        }
541    }
542}
543
// Re-export the platform-appropriate tensor types under a stable path.
pub use platform::{BorrowedTensor, OwnedTensor};
545
/// Trait for types that can be used as prediction inputs.
///
/// Implemented by both `BorrowedTensor` and `OwnedTensor`.
#[cfg(target_vendor = "apple")]
pub trait AsMultiArray {
    /// Returns the underlying retained `MLMultiArray` handle.
    fn as_ml_multi_array(&self) -> &objc2::rc::Retained<objc2_core_ml::MLMultiArray>;
}
553
// Both tensor flavors expose their retained MLMultiArray directly.
#[cfg(target_vendor = "apple")]
impl AsMultiArray for BorrowedTensor<'_> {
    fn as_ml_multi_array(&self) -> &objc2::rc::Retained<objc2_core_ml::MLMultiArray> {
        &self.inner
    }
}

#[cfg(target_vendor = "apple")]
impl AsMultiArray for OwnedTensor {
    fn as_ml_multi_array(&self) -> &objc2::rc::Retained<objc2_core_ml::MLMultiArray> {
        &self.inner
    }
}
567
// On non-Apple targets the trait is an empty marker so downstream trait
// bounds still compile; there is no MLMultiArray to expose.
#[cfg(not(target_vendor = "apple"))]
pub trait AsMultiArray {}

#[cfg(not(target_vendor = "apple"))]
impl AsMultiArray for BorrowedTensor<'_> {}

#[cfg(not(target_vendor = "apple"))]
impl AsMultiArray for OwnedTensor {}
576
#[cfg(test)]
mod tests {
    use super::*;

    // ── Platform-independent helper tests ───────────────────────────────

    #[test]
    fn datatype_byte_sizes() {
        assert_eq!(DataType::Float16.byte_size(), 2);
        assert_eq!(DataType::Float32.byte_size(), 4);
        assert_eq!(DataType::Float64.byte_size(), 8);
        assert_eq!(DataType::Int32.byte_size(), 4);
        assert_eq!(DataType::Int16.byte_size(), 2);
        assert_eq!(DataType::Int8.byte_size(), 1);
        assert_eq!(DataType::UInt32.byte_size(), 4);
        assert_eq!(DataType::UInt16.byte_size(), 2);
        assert_eq!(DataType::UInt8.byte_size(), 1);
    }

    #[test]
    fn datatype_display() {
        assert_eq!(format!("{}", DataType::Float32), "Float32");
    }

    #[test]
    fn element_count_works() {
        assert_eq!(element_count(&[1, 128, 500]), 64000);
    }

    #[test]
    fn compute_strides_row_major() {
        // Row-major: the last dimension is contiguous (stride 1).
        assert_eq!(compute_strides(&[1, 128, 500]), vec![64000, 500, 1]);
    }

    #[test]
    fn validate_shape_correct() {
        assert!(validate_shape(64000, &[1, 128, 500]).is_ok());
    }

    #[test]
    fn validate_shape_mismatch() {
        let err = validate_shape(100, &[1, 128, 500]).unwrap_err();
        assert_eq!(err.kind(), &ErrorKind::InvalidShape);
    }

    #[test]
    fn validate_shape_empty() {
        assert!(validate_shape(0, &[]).is_err());
    }

    #[test]
    fn validate_shape_zero_dim() {
        assert!(validate_shape(0, &[1, 0, 500]).is_err());
    }

    // ── Tests that exercise the real CoreML runtime (Apple targets only) ─

    #[cfg(target_vendor = "apple")]
    mod apple_tests {
        use super::super::*;

        #[test]
        fn borrowed_tensor_from_f32() {
            let data = vec![1.0f32; 6];
            let tensor = BorrowedTensor::from_f32(&data, &[2, 3]).unwrap();
            assert_eq!(tensor.shape(), &[2, 3]);
            assert_eq!(tensor.data_type(), DataType::Float32);
            assert_eq!(tensor.element_count(), 6);
            assert_eq!(tensor.size_bytes(), 24);
        }

        #[test]
        fn borrowed_tensor_shape_mismatch() {
            let data = vec![1.0f32; 5];
            assert!(BorrowedTensor::from_f32(&data, &[2, 3]).is_err());
        }

        #[test]
        fn borrowed_tensor_from_i32() {
            let data = vec![42i32; 4];
            let tensor = BorrowedTensor::from_i32(&data, &[2, 2]).unwrap();
            assert_eq!(tensor.data_type(), DataType::Int32);
        }

        #[test]
        fn owned_tensor_zeros() {
            let tensor = OwnedTensor::zeros(DataType::Float32, &[2, 3]).unwrap();
            assert_eq!(tensor.shape(), &[2, 3]);
            let data = tensor.to_vec_f32().unwrap();
            assert_eq!(data, vec![0.0f32; 6]);
        }

        #[test]
        fn owned_tensor_empty_shape_fails() {
            assert!(OwnedTensor::zeros(DataType::Float32, &[]).is_err());
        }

        #[test]
        fn owned_tensor_zero_dim_fails() {
            assert!(OwnedTensor::zeros(DataType::Float32, &[1, 0]).is_err());
        }

        #[test]
        fn owned_tensor_copy_wrong_type() {
            // Copy-out must reject a dtype mismatch rather than reinterpret.
            let tensor = OwnedTensor::zeros(DataType::Int32, &[4]).unwrap();
            let mut buf = vec![0.0f32; 4];
            assert!(tensor.copy_to_f32(&mut buf).is_err());
        }

        #[test]
        fn borrowed_tensor_from_f64() {
            let data = vec![1.0f64; 6];
            let tensor = BorrowedTensor::from_f64(&data, &[2, 3]).unwrap();
            assert_eq!(tensor.data_type(), DataType::Float64);
        }

        #[test]
        fn borrowed_tensor_from_f16_bits() {
            // f16 representation of 1.0 is 0x3C00
            let data = vec![0x3C00u16; 4];
            let tensor = BorrowedTensor::from_f16_bits(&data, &[2, 2]).unwrap();
            assert_eq!(tensor.data_type(), DataType::Float16);
        }

        #[test]
        fn owned_tensor_i32_roundtrip() {
            let tensor = OwnedTensor::zeros(DataType::Int32, &[4]).unwrap();
            let data = tensor.to_vec_i32().unwrap();
            assert_eq!(data, vec![0i32; 4]);
        }

        #[test]
        fn owned_tensor_raw_bytes() {
            let tensor = OwnedTensor::zeros(DataType::Float32, &[2]).unwrap();
            let bytes = tensor.to_raw_bytes().unwrap();
            assert_eq!(bytes.len(), 8); // 2 elements * 4 bytes
        }

        #[test]
        fn borrowed_tensor_from_i16() {
            let data = vec![1i16; 4];
            let tensor = BorrowedTensor::from_i16(&data, &[2, 2]).unwrap();
            assert_eq!(tensor.data_type(), DataType::Int16);
            assert_eq!(tensor.element_count(), 4);
            assert_eq!(tensor.size_bytes(), 8);
        }

        #[test]
        fn borrowed_tensor_from_i8() {
            let data = vec![1i8; 4];
            let tensor = BorrowedTensor::from_i8(&data, &[2, 2]).unwrap();
            assert_eq!(tensor.data_type(), DataType::Int8);
            assert_eq!(tensor.element_count(), 4);
            assert_eq!(tensor.size_bytes(), 4);
        }

        #[test]
        fn borrowed_tensor_from_u32() {
            let data = vec![1u32; 4];
            let tensor = BorrowedTensor::from_u32(&data, &[2, 2]).unwrap();
            assert_eq!(tensor.data_type(), DataType::UInt32);
            assert_eq!(tensor.element_count(), 4);
            assert_eq!(tensor.size_bytes(), 16);
        }

        #[test]
        fn borrowed_tensor_from_u16() {
            let data = vec![1u16; 4];
            let tensor = BorrowedTensor::from_u16(&data, &[2, 2]).unwrap();
            assert_eq!(tensor.data_type(), DataType::UInt16);
            assert_eq!(tensor.element_count(), 4);
            assert_eq!(tensor.size_bytes(), 8);
        }

        #[test]
        fn borrowed_tensor_from_u8() {
            let data = vec![1u8; 4];
            let tensor = BorrowedTensor::from_u8(&data, &[2, 2]).unwrap();
            assert_eq!(tensor.data_type(), DataType::UInt8);
            assert_eq!(tensor.element_count(), 4);
            assert_eq!(tensor.size_bytes(), 4);
        }
    }
}