1use crate::error::{Error, ErrorKind, Result};
4
/// Scalar element type stored in a tensor buffer.
///
/// The variant name is also the text produced by the `Display` impl.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataType {
    /// 16-bit floating point (2 bytes per element).
    Float16,
    /// 32-bit floating point (4 bytes per element).
    Float32,
    /// 64-bit floating point (8 bytes per element).
    Float64,
    /// 32-bit signed integer.
    Int32,
    /// 16-bit signed integer.
    Int16,
    /// 8-bit signed integer.
    Int8,
    /// 32-bit unsigned integer.
    UInt32,
    /// 16-bit unsigned integer.
    UInt16,
    /// 8-bit unsigned integer.
    UInt8,
}

impl DataType {
    /// Width in bytes of one element of this type.
    pub fn byte_size(self) -> usize {
        match self {
            Self::Int8 | Self::UInt8 => 1,
            Self::Float16 | Self::Int16 | Self::UInt16 => 2,
            Self::Float32 | Self::Int32 | Self::UInt32 => 4,
            Self::Float64 => 8,
        }
    }
}

impl std::fmt::Display for DataType {
    /// Writes the bare variant name, e.g. `Float32`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Float16 => "Float16",
            Self::Float32 => "Float32",
            Self::Float64 => "Float64",
            Self::Int32 => "Int32",
            Self::Int16 => "Int16",
            Self::Int8 => "Int8",
            Self::UInt32 => "UInt32",
            Self::UInt16 => "UInt16",
            Self::UInt8 => "UInt8",
        };
        f.write_str(name)
    }
}
51
/// Number of elements described by `shape`: the product of all extents.
///
/// An empty shape is the empty product and yields `1`.
pub fn element_count(shape: &[usize]) -> usize {
    let mut total = 1usize;
    for &extent in shape {
        total *= extent;
    }
    total
}
56
/// Row-major (C-order) element strides for `shape`.
///
/// The stride of axis `i` is the product of all extents after `i`, so the
/// innermost axis always has stride `1`. An empty shape yields no strides.
pub fn compute_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = Vec::with_capacity(shape.len());
    let mut stride = 1usize;
    // Walk innermost-first; the outermost extent never contributes to any stride,
    // so it is deliberately not multiplied in.
    for (axis, &extent) in shape.iter().enumerate().rev() {
        strides.push(stride);
        if axis > 0 {
            stride *= extent;
        }
    }
    strides.reverse();
    strides
}
69
70pub fn validate_shape(data_len: usize, shape: &[usize]) -> Result<()> {
72 if shape.is_empty() {
73 return Err(Error::new(ErrorKind::InvalidShape, "shape must not be empty"));
74 }
75 if shape.contains(&0) {
76 return Err(Error::new(
77 ErrorKind::InvalidShape,
78 format!("shape contains zero dimension: {shape:?}"),
79 ));
80 }
81 let expected = element_count(shape);
82 if data_len != expected {
83 return Err(Error::new(
84 ErrorKind::InvalidShape,
85 format!("data length {data_len} does not match shape {shape:?} (expected {expected} elements)"),
86 ));
87 }
88 Ok(())
89}
90
#[cfg(target_vendor = "apple")]
mod platform {
    use super::*;
    use crate::ffi;
    use objc2::rc::Retained;
    use objc2::AnyThread;
    use objc2_core_ml::MLMultiArray;
    use std::ffi::c_void;
    use std::marker::PhantomData;
    use std::ptr::NonNull;

    /// A CoreML `MLMultiArray` that views a caller-owned slice without copying it.
    ///
    /// The lifetime `'a` ties the tensor to the borrowed slice, so the slice stays
    /// borrowed for as long as this value exists.
    pub struct BorrowedTensor<'a> {
        /// CoreML array whose storage is the borrowed slice (created with no deallocator).
        pub(crate) inner: Retained<MLMultiArray>,
        shape: Vec<usize>,
        data_type: DataType,
        _marker: PhantomData<&'a [u8]>,
    }

    impl std::fmt::Debug for BorrowedTensor<'_> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("BorrowedTensor")
                .field("shape", &self.shape)
                .field("data_type", &self.data_type)
                .finish()
        }
    }

    // Every typed constructor below fails when `shape` does not describe exactly
    // `data.len()` elements (see `validate_shape`) or when CoreML rejects the array.
    impl<'a> BorrowedTensor<'a> {
        /// Shared implementation of the typed `from_*` constructors.
        ///
        /// `T` must be the Rust scalar whose in-memory layout matches `data_type`;
        /// the public constructors pair them up.
        fn from_typed_slice<T>(data: &'a [T], shape: &[usize], data_type: DataType) -> Result<Self> {
            debug_assert_eq!(std::mem::size_of::<T>(), data_type.byte_size());
            validate_shape(data.len(), shape)?;

            let ns_shape = ffi::shape_to_nsarray(shape);
            let strides = compute_strides(shape);
            let ns_strides = ffi::shape_to_nsarray(&strides);
            let ml_dtype = objc2_core_ml::MLMultiArrayDataType(ffi::datatype_to_ml(data_type));

            // Slice pointers are never null, so this conversion cannot fail.
            let ptr = NonNull::from(data).cast::<c_void>();

            // SAFETY: `ptr` addresses `data.len()` initialized `T` values and
            // `validate_shape` established that `shape` describes exactly that many
            // elements, so the row-major `strides` stay inside the slice. `None` as
            // the deallocator means CoreML never frees the caller's memory, and `'a`
            // keeps the slice borrowed while `Self` exists.
            // NOTE(review): the slice is only shared-borrowed; anything that writes
            // through `inner` (e.g. using it as a prediction output) would mutate
            // immutable data — confirm no caller does that.
            let inner = unsafe {
                MLMultiArray::initWithDataPointer_shape_dataType_strides_deallocator_error(
                    MLMultiArray::alloc(),
                    ptr,
                    &ns_shape,
                    ml_dtype,
                    &ns_strides,
                    None,
                )
            }
            .map_err(|e| Error::from_nserror(ErrorKind::TensorCreate, &e))?;

            Ok(Self { inner, shape: shape.to_vec(), data_type, _marker: PhantomData })
        }

        /// Wraps `f32` data as a [`DataType::Float32`] tensor.
        pub fn from_f32(data: &'a [f32], shape: &[usize]) -> Result<Self> {
            Self::from_typed_slice(data, shape, DataType::Float32)
        }

        /// Wraps `i32` data as a [`DataType::Int32`] tensor.
        pub fn from_i32(data: &'a [i32], shape: &[usize]) -> Result<Self> {
            Self::from_typed_slice(data, shape, DataType::Int32)
        }

        /// Wraps `f64` data as a [`DataType::Float64`] tensor.
        pub fn from_f64(data: &'a [f64], shape: &[usize]) -> Result<Self> {
            Self::from_typed_slice(data, shape, DataType::Float64)
        }

        /// Wraps raw 16-bit float bit patterns as a [`DataType::Float16`] tensor.
        pub fn from_f16_bits(data: &'a [u16], shape: &[usize]) -> Result<Self> {
            Self::from_typed_slice(data, shape, DataType::Float16)
        }

        /// Wraps `i16` data as a [`DataType::Int16`] tensor.
        pub fn from_i16(data: &'a [i16], shape: &[usize]) -> Result<Self> {
            Self::from_typed_slice(data, shape, DataType::Int16)
        }

        /// Wraps `i8` data as a [`DataType::Int8`] tensor.
        pub fn from_i8(data: &'a [i8], shape: &[usize]) -> Result<Self> {
            Self::from_typed_slice(data, shape, DataType::Int8)
        }

        /// Wraps `u32` data as a [`DataType::UInt32`] tensor.
        pub fn from_u32(data: &'a [u32], shape: &[usize]) -> Result<Self> {
            Self::from_typed_slice(data, shape, DataType::UInt32)
        }

        /// Wraps `u16` data as a [`DataType::UInt16`] tensor.
        pub fn from_u16(data: &'a [u16], shape: &[usize]) -> Result<Self> {
            Self::from_typed_slice(data, shape, DataType::UInt16)
        }

        /// Wraps `u8` data as a [`DataType::UInt8`] tensor.
        pub fn from_u8(data: &'a [u8], shape: &[usize]) -> Result<Self> {
            Self::from_typed_slice(data, shape, DataType::UInt8)
        }

        /// Dimensions of the tensor, outermost first.
        pub fn shape(&self) -> &[usize] { &self.shape }
        /// Element type of the tensor.
        pub fn data_type(&self) -> DataType { self.data_type }
        /// Total number of elements (product of the shape).
        pub fn element_count(&self) -> usize { element_count(&self.shape) }
        /// Total payload size in bytes (`element_count * byte_size`).
        pub fn size_bytes(&self) -> usize { self.element_count() * self.data_type.byte_size() }
    }

    // SAFETY: the borrowed element types are plain scalars that are `Send + Sync`.
    // NOTE(review): this also assumes the `MLMultiArray` handle may be moved across
    // threads; confirm against CoreML's threading guarantees.
    unsafe impl Send for BorrowedTensor<'_> {}

    /// A tensor whose storage is allocated and owned by CoreML.
    pub struct OwnedTensor {
        pub(crate) inner: Retained<MLMultiArray>,
        shape: Vec<usize>,
        data_type: DataType,
    }

    impl std::fmt::Debug for OwnedTensor {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("OwnedTensor")
                .field("shape", &self.shape)
                .field("data_type", &self.data_type)
                .finish()
        }
    }

    impl OwnedTensor {
        /// Allocates a zero-filled tensor of `data_type` with the given `shape`.
        ///
        /// # Errors
        ///
        /// `InvalidShape` when `shape` is empty, contains a zero dimension, or its
        /// byte size overflows `usize`; `TensorCreate` when CoreML fails to allocate.
        #[allow(deprecated)] // `dataPointer` is deprecated upstream but is the buffer accessor this module uses.
        pub fn zeros(data_type: DataType, shape: &[usize]) -> Result<Self> {
            if shape.is_empty() {
                return Err(Error::new(ErrorKind::InvalidShape, "shape must not be empty"));
            }
            if shape.contains(&0) {
                return Err(Error::new(ErrorKind::InvalidShape, format!("shape contains zero dimension: {shape:?}")));
            }
            // Checked so that neither CoreML's allocation nor the zero-fill below
            // ever sees a wrapped size.
            let byte_count = shape
                .iter()
                .try_fold(data_type.byte_size(), |acc, &dim| acc.checked_mul(dim))
                .ok_or_else(|| {
                    Error::new(ErrorKind::InvalidShape, format!("shape {shape:?} is too large for {data_type}"))
                })?;

            let ns_shape = ffi::shape_to_nsarray(shape);
            let ml_dtype = objc2_core_ml::MLMultiArrayDataType(ffi::datatype_to_ml(data_type));

            let inner = unsafe {
                MLMultiArray::initWithShape_dataType_error(MLMultiArray::alloc(), &ns_shape, ml_dtype)
            }
            .map_err(|e| Error::from_nserror(ErrorKind::TensorCreate, &e))?;

            // NOTE(review): nothing here relies on `initWithShape:dataType:error:`
            // zeroing its buffer; clear it explicitly so `zeros` holds regardless.
            // SAFETY: CoreML just allocated storage for `shape` elements of
            // `data_type` — assumed (as `to_raw_bytes` also assumes) to be at least
            // `byte_count` contiguous bytes — and no other reference to it exists yet.
            unsafe {
                let base = inner.dataPointer().as_ptr().cast::<u8>();
                std::ptr::write_bytes(base, 0, byte_count);
            }

            Ok(Self { inner, shape: shape.to_vec(), data_type })
        }

        /// Dimensions of the tensor, outermost first.
        pub fn shape(&self) -> &[usize] { &self.shape }
        /// Element type of the tensor.
        pub fn data_type(&self) -> DataType { self.data_type }
        /// Total number of elements (product of the shape).
        pub fn element_count(&self) -> usize { element_count(&self.shape) }
        /// Total payload size in bytes (`element_count * byte_size`).
        pub fn size_bytes(&self) -> usize { self.element_count() * self.data_type.byte_size() }

        /// Copies the first `element_count` elements into `buf`, after checking
        /// that the tensor holds `expected` and that `buf` is long enough.
        ///
        /// `T` must be the Rust scalar whose layout matches `expected`.
        #[allow(deprecated)] // `dataPointer` is deprecated upstream but is the buffer accessor this module uses.
        fn copy_out<T: Copy>(&self, expected: DataType, buf: &mut [T]) -> Result<()> {
            if self.data_type != expected {
                return Err(Error::new(
                    ErrorKind::TensorCreate,
                    format!("tensor is {:?}, not {expected:?}", self.data_type),
                ));
            }
            let count = self.element_count();
            if buf.len() < count {
                return Err(Error::new(
                    ErrorKind::InvalidShape,
                    format!("buffer length {} < element count {count}", buf.len()),
                ));
            }
            // SAFETY: the tensor's element type matches `T` (checked above), `buf`
            // has room for `count` elements, and the source is CoreML-owned memory
            // that cannot overlap `buf`. Assumes contiguous row-major storage.
            unsafe {
                let src = self.inner.dataPointer().as_ptr() as *const T;
                std::ptr::copy_nonoverlapping(src, buf.as_mut_ptr(), count);
            }
            Ok(())
        }

        /// Copies a Float32 tensor into `buf` (which may be longer than needed).
        pub fn copy_to_f32(&self, buf: &mut [f32]) -> Result<()> {
            self.copy_out(DataType::Float32, buf)
        }

        /// Returns the contents of a Float32 tensor as a new `Vec`.
        pub fn to_vec_f32(&self) -> Result<Vec<f32>> {
            let mut buf = vec![0.0f32; self.element_count()];
            self.copy_to_f32(&mut buf)?;
            Ok(buf)
        }

        /// Copies an Int32 tensor into `buf` (which may be longer than needed).
        pub fn copy_to_i32(&self, buf: &mut [i32]) -> Result<()> {
            self.copy_out(DataType::Int32, buf)
        }

        /// Returns the contents of an Int32 tensor as a new `Vec`.
        pub fn to_vec_i32(&self) -> Result<Vec<i32>> {
            let mut buf = vec![0i32; self.element_count()];
            self.copy_to_i32(&mut buf)?;
            Ok(buf)
        }

        /// Copies a Float64 tensor into `buf` (which may be longer than needed).
        pub fn copy_to_f64(&self, buf: &mut [f64]) -> Result<()> {
            self.copy_out(DataType::Float64, buf)
        }

        /// Returns the contents of a Float64 tensor as a new `Vec`.
        pub fn to_vec_f64(&self) -> Result<Vec<f64>> {
            let mut buf = vec![0.0f64; self.element_count()];
            self.copy_to_f64(&mut buf)?;
            Ok(buf)
        }

        /// Returns the tensor payload as raw bytes, regardless of element type.
        #[allow(deprecated)] // `dataPointer` is deprecated upstream but is the buffer accessor this module uses.
        pub fn to_raw_bytes(&self) -> Result<Vec<u8>> {
            let byte_count = self.size_bytes();
            let mut buf = vec![0u8; byte_count];
            // SAFETY: CoreML owns at least `byte_count` bytes for this tensor
            // (assuming contiguous storage) and `buf` has exactly that capacity.
            unsafe {
                let src = self.inner.dataPointer().as_ptr() as *const u8;
                std::ptr::copy_nonoverlapping(src, buf.as_mut_ptr(), byte_count);
            }
            Ok(buf)
        }
    }

    // SAFETY: the tensor exclusively owns its `MLMultiArray` handle.
    // NOTE(review): assumes the handle may be moved across threads; confirm against
    // CoreML's threading guarantees.
    unsafe impl Send for OwnedTensor {}
}
446
447#[cfg(not(target_vendor = "apple"))]
450mod platform {
451 use super::*;
452
453 #[derive(Debug)]
454 pub struct BorrowedTensor<'a> {
455 shape: Vec<usize>,
456 data_type: DataType,
457 _marker: std::marker::PhantomData<&'a [u8]>,
458 }
459
460 impl<'a> BorrowedTensor<'a> {
461 pub fn from_f32(_data: &'a [f32], shape: &[usize]) -> Result<Self> {
462 validate_shape(_data.len(), shape)?;
463 Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
464 }
465 pub fn from_i32(_data: &'a [i32], shape: &[usize]) -> Result<Self> {
466 validate_shape(_data.len(), shape)?;
467 Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
468 }
469 pub fn from_f64(_data: &'a [f64], shape: &[usize]) -> Result<Self> {
470 validate_shape(_data.len(), shape)?;
471 Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
472 }
473 pub fn from_f16_bits(_data: &'a [u16], shape: &[usize]) -> Result<Self> {
474 validate_shape(_data.len(), shape)?;
475 Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
476 }
477 pub fn from_i16(_data: &'a [i16], shape: &[usize]) -> Result<Self> {
478 validate_shape(_data.len(), shape)?;
479 Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
480 }
481 pub fn from_i8(_data: &'a [i8], shape: &[usize]) -> Result<Self> {
482 validate_shape(_data.len(), shape)?;
483 Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
484 }
485 pub fn from_u32(_data: &'a [u32], shape: &[usize]) -> Result<Self> {
486 validate_shape(_data.len(), shape)?;
487 Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
488 }
489 pub fn from_u16(_data: &'a [u16], shape: &[usize]) -> Result<Self> {
490 validate_shape(_data.len(), shape)?;
491 Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
492 }
493 pub fn from_u8(_data: &'a [u8], shape: &[usize]) -> Result<Self> {
494 validate_shape(_data.len(), shape)?;
495 Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
496 }
497 pub fn shape(&self) -> &[usize] { &self.shape }
498 pub fn data_type(&self) -> DataType { self.data_type }
499 pub fn element_count(&self) -> usize { element_count(&self.shape) }
500 pub fn size_bytes(&self) -> usize { self.element_count() * self.data_type.byte_size() }
501 }
502
503 #[derive(Debug)]
504 pub struct OwnedTensor {
505 shape: Vec<usize>,
506 data_type: DataType,
507 }
508
509 impl OwnedTensor {
510 pub fn zeros(_data_type: DataType, shape: &[usize]) -> Result<Self> {
511 if shape.is_empty() || shape.iter().any(|&d| d == 0) {
512 return Err(Error::new(ErrorKind::InvalidShape, format!("invalid shape: {shape:?}")));
513 }
514 Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
515 }
516 pub fn shape(&self) -> &[usize] { &self.shape }
517 pub fn data_type(&self) -> DataType { self.data_type }
518 pub fn element_count(&self) -> usize { element_count(&self.shape) }
519 pub fn size_bytes(&self) -> usize { self.element_count() * self.data_type.byte_size() }
520 pub fn copy_to_f32(&self, _buf: &mut [f32]) -> Result<()> {
521 Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
522 }
523 pub fn to_vec_f32(&self) -> Result<Vec<f32>> {
524 Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
525 }
526 pub fn copy_to_i32(&self, _buf: &mut [i32]) -> Result<()> {
527 Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
528 }
529 pub fn to_vec_i32(&self) -> Result<Vec<i32>> {
530 Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
531 }
532 pub fn copy_to_f64(&self, _buf: &mut [f64]) -> Result<()> {
533 Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
534 }
535 pub fn to_vec_f64(&self) -> Result<Vec<f64>> {
536 Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
537 }
538 pub fn to_raw_bytes(&self) -> Result<Vec<u8>> {
539 Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
540 }
541 }
542}
543
544pub use platform::{BorrowedTensor, OwnedTensor};
545
546#[cfg(target_vendor = "apple")]
550pub trait AsMultiArray {
551 fn as_ml_multi_array(&self) -> &objc2::rc::Retained<objc2_core_ml::MLMultiArray>;
552}
553
554#[cfg(target_vendor = "apple")]
555impl AsMultiArray for BorrowedTensor<'_> {
556 fn as_ml_multi_array(&self) -> &objc2::rc::Retained<objc2_core_ml::MLMultiArray> {
557 &self.inner
558 }
559}
560
561#[cfg(target_vendor = "apple")]
562impl AsMultiArray for OwnedTensor {
563 fn as_ml_multi_array(&self) -> &objc2::rc::Retained<objc2_core_ml::MLMultiArray> {
564 &self.inner
565 }
566}
567
568#[cfg(not(target_vendor = "apple"))]
569pub trait AsMultiArray {}
570
571#[cfg(not(target_vendor = "apple"))]
572impl AsMultiArray for BorrowedTensor<'_> {}
573
574#[cfg(not(target_vendor = "apple"))]
575impl AsMultiArray for OwnedTensor {}
576
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `DataType` variant, for table-driven checks.
    const ALL_TYPES: [DataType; 9] = [
        DataType::Float16,
        DataType::Float32,
        DataType::Float64,
        DataType::Int32,
        DataType::Int16,
        DataType::Int8,
        DataType::UInt32,
        DataType::UInt16,
        DataType::UInt8,
    ];

    #[test]
    fn datatype_byte_sizes() {
        assert_eq!(DataType::Float16.byte_size(), 2);
        assert_eq!(DataType::Float32.byte_size(), 4);
        assert_eq!(DataType::Float64.byte_size(), 8);
        assert_eq!(DataType::Int32.byte_size(), 4);
        assert_eq!(DataType::Int16.byte_size(), 2);
        assert_eq!(DataType::Int8.byte_size(), 1);
        assert_eq!(DataType::UInt32.byte_size(), 4);
        assert_eq!(DataType::UInt16.byte_size(), 2);
        assert_eq!(DataType::UInt8.byte_size(), 1);
    }

    #[test]
    fn datatype_display() {
        assert_eq!(format!("{}", DataType::Float32), "Float32");
        // Display prints the variant name, which is also what derived Debug prints.
        for dt in ALL_TYPES {
            assert_eq!(dt.to_string(), format!("{dt:?}"));
        }
    }

    #[test]
    fn element_count_works() {
        assert_eq!(element_count(&[1, 128, 500]), 64000);
        assert_eq!(element_count(&[7]), 7);
        // The empty product is 1.
        assert_eq!(element_count(&[]), 1);
    }

    #[test]
    fn compute_strides_row_major() {
        assert_eq!(compute_strides(&[1, 128, 500]), vec![64000, 500, 1]);
        assert_eq!(compute_strides(&[2, 3, 4]), vec![12, 4, 1]);
    }

    #[test]
    fn compute_strides_edge_cases() {
        assert!(compute_strides(&[]).is_empty());
        assert_eq!(compute_strides(&[5]), vec![1]);
    }

    #[test]
    fn validate_shape_correct() {
        assert!(validate_shape(64000, &[1, 128, 500]).is_ok());
    }

    #[test]
    fn validate_shape_mismatch() {
        let err = validate_shape(100, &[1, 128, 500]).unwrap_err();
        assert_eq!(err.kind(), &ErrorKind::InvalidShape);
    }

    #[test]
    fn validate_shape_empty() {
        let err = validate_shape(0, &[]).unwrap_err();
        assert_eq!(err.kind(), &ErrorKind::InvalidShape);
    }

    #[test]
    fn validate_shape_zero_dim() {
        let err = validate_shape(0, &[1, 0, 500]).unwrap_err();
        assert_eq!(err.kind(), &ErrorKind::InvalidShape);
    }

    #[cfg(target_vendor = "apple")]
    mod apple_tests {
        use super::super::*;

        #[test]
        fn borrowed_tensor_from_f32() {
            let data = vec![1.0f32; 6];
            let tensor = BorrowedTensor::from_f32(&data, &[2, 3]).unwrap();
            assert_eq!(tensor.shape(), &[2, 3]);
            assert_eq!(tensor.data_type(), DataType::Float32);
            assert_eq!(tensor.element_count(), 6);
            assert_eq!(tensor.size_bytes(), 24);
        }

        #[test]
        fn borrowed_tensor_shape_mismatch() {
            let data = vec![1.0f32; 5];
            assert!(BorrowedTensor::from_f32(&data, &[2, 3]).is_err());
        }

        #[test]
        fn borrowed_tensor_from_i32() {
            let data = vec![42i32; 4];
            let tensor = BorrowedTensor::from_i32(&data, &[2, 2]).unwrap();
            assert_eq!(tensor.data_type(), DataType::Int32);
        }

        #[test]
        fn owned_tensor_zeros() {
            let tensor = OwnedTensor::zeros(DataType::Float32, &[2, 3]).unwrap();
            assert_eq!(tensor.shape(), &[2, 3]);
            let data = tensor.to_vec_f32().unwrap();
            assert_eq!(data, vec![0.0f32; 6]);
        }

        #[test]
        fn owned_tensor_empty_shape_fails() {
            assert!(OwnedTensor::zeros(DataType::Float32, &[]).is_err());
        }

        #[test]
        fn owned_tensor_zero_dim_fails() {
            assert!(OwnedTensor::zeros(DataType::Float32, &[1, 0]).is_err());
        }

        #[test]
        fn owned_tensor_copy_wrong_type() {
            let tensor = OwnedTensor::zeros(DataType::Int32, &[4]).unwrap();
            let mut buf = vec![0.0f32; 4];
            assert!(tensor.copy_to_f32(&mut buf).is_err());
        }

        #[test]
        fn owned_tensor_copy_short_buffer_fails() {
            let tensor = OwnedTensor::zeros(DataType::Float32, &[4]).unwrap();
            let mut buf = vec![0.0f32; 3];
            assert!(tensor.copy_to_f32(&mut buf).is_err());
        }

        #[test]
        fn borrowed_tensor_from_f64() {
            let data = vec![1.0f64; 6];
            let tensor = BorrowedTensor::from_f64(&data, &[2, 3]).unwrap();
            assert_eq!(tensor.data_type(), DataType::Float64);
        }

        #[test]
        fn borrowed_tensor_from_f16_bits() {
            // 0x3C00 is the bit pattern of 1.0 in 16-bit IEEE half precision.
            let data = vec![0x3C00u16; 4];
            let tensor = BorrowedTensor::from_f16_bits(&data, &[2, 2]).unwrap();
            assert_eq!(tensor.data_type(), DataType::Float16);
        }

        #[test]
        fn owned_tensor_i32_roundtrip() {
            let tensor = OwnedTensor::zeros(DataType::Int32, &[4]).unwrap();
            let data = tensor.to_vec_i32().unwrap();
            assert_eq!(data, vec![0i32; 4]);
        }

        #[test]
        fn owned_tensor_f64_roundtrip() {
            let tensor = OwnedTensor::zeros(DataType::Float64, &[3]).unwrap();
            assert_eq!(tensor.to_vec_f64().unwrap(), vec![0.0f64; 3]);
        }

        #[test]
        fn owned_tensor_raw_bytes() {
            let tensor = OwnedTensor::zeros(DataType::Float32, &[2]).unwrap();
            let bytes = tensor.to_raw_bytes().unwrap();
            assert_eq!(bytes.len(), 8);
        }

        #[test]
        fn borrowed_tensor_from_i16() {
            let data = vec![1i16; 4];
            let tensor = BorrowedTensor::from_i16(&data, &[2, 2]).unwrap();
            assert_eq!(tensor.data_type(), DataType::Int16);
            assert_eq!(tensor.element_count(), 4);
            assert_eq!(tensor.size_bytes(), 8);
        }

        #[test]
        fn borrowed_tensor_from_i8() {
            let data = vec![1i8; 4];
            let tensor = BorrowedTensor::from_i8(&data, &[2, 2]).unwrap();
            assert_eq!(tensor.data_type(), DataType::Int8);
            assert_eq!(tensor.element_count(), 4);
            assert_eq!(tensor.size_bytes(), 4);
        }

        #[test]
        fn borrowed_tensor_from_u32() {
            let data = vec![1u32; 4];
            let tensor = BorrowedTensor::from_u32(&data, &[2, 2]).unwrap();
            assert_eq!(tensor.data_type(), DataType::UInt32);
            assert_eq!(tensor.element_count(), 4);
            assert_eq!(tensor.size_bytes(), 16);
        }

        #[test]
        fn borrowed_tensor_from_u16() {
            let data = vec![1u16; 4];
            let tensor = BorrowedTensor::from_u16(&data, &[2, 2]).unwrap();
            assert_eq!(tensor.data_type(), DataType::UInt16);
            assert_eq!(tensor.element_count(), 4);
            assert_eq!(tensor.size_bytes(), 8);
        }

        #[test]
        fn borrowed_tensor_from_u8() {
            let data = vec![1u8; 4];
            let tensor = BorrowedTensor::from_u8(&data, &[2, 2]).unwrap();
            assert_eq!(tensor.data_type(), DataType::UInt8);
            assert_eq!(tensor.element_count(), 4);
            assert_eq!(tensor.size_bytes(), 4);
        }
    }

    #[cfg(not(target_vendor = "apple"))]
    mod fallback_tests {
        use super::super::*;

        #[test]
        fn borrowed_tensor_reports_unsupported_platform() {
            let data = vec![1.0f32; 6];
            let err = BorrowedTensor::from_f32(&data, &[2, 3]).unwrap_err();
            assert_eq!(err.kind(), &ErrorKind::UnsupportedPlatform);
        }

        #[test]
        fn borrowed_tensor_validates_shape_first() {
            let data = vec![1.0f32; 5];
            let err = BorrowedTensor::from_f32(&data, &[2, 3]).unwrap_err();
            assert_eq!(err.kind(), &ErrorKind::InvalidShape);
        }

        #[test]
        fn owned_tensor_reports_unsupported_platform() {
            let err = OwnedTensor::zeros(DataType::Float32, &[2, 3]).unwrap_err();
            assert_eq!(err.kind(), &ErrorKind::UnsupportedPlatform);
        }

        #[test]
        fn owned_tensor_rejects_invalid_shape() {
            let empty = OwnedTensor::zeros(DataType::Float32, &[]).unwrap_err();
            assert_eq!(empty.kind(), &ErrorKind::InvalidShape);
            let zero_dim = OwnedTensor::zeros(DataType::Float32, &[1, 0]).unwrap_err();
            assert_eq!(zero_dim.kind(), &ErrorKind::InvalidShape);
        }
    }
}