kornia_tensor/
tensor.rs

1use thiserror::Error;
2
3use super::{
4    allocator::{CpuAllocator, TensorAllocator, TensorAllocatorError},
5    storage::TensorStorage,
6    view::TensorView,
7};
8
/// An error type for tensor operations.
#[derive(Error, Debug, PartialEq)]
pub enum TensorError {
    /// Error when the cast operation fails.
    #[error("Failed to cast data")]
    CastError,

    /// The number of elements in the data does not match the shape of the tensor.
    /// The payload is the expected element count computed from the shape.
    #[error("The number of elements in the data does not match the shape of the tensor: {0}")]
    InvalidShape(usize),

    /// Index out of bounds. The payload is the offending offset.
    #[error("Index out of bounds. The index {0} is out of bounds.")]
    IndexOutOfBounds(usize),

    /// Error with the tensor storage (propagated from the allocator layer).
    #[error("Error with the tensor storage: {0}")]
    StorageError(#[from] TensorAllocatorError),

    /// Dimension mismatch for operations requiring compatible shapes.
    #[error("Dimension mismatch: {0}")]
    DimensionMismatch(String),

    /// Unsupported operation for the given data type or tensor configuration.
    #[error("Unsupported operation: {0}")]
    UnsupportedOperation(String),
}
36
/// Compute the row-major (C-contiguous) strides from the shape of a tensor.
///
/// The innermost dimension gets stride 1; every outer stride is the product
/// of all inner dimension sizes.
///
/// # Arguments
///
/// * `shape` - The shape of the tensor.
///
/// # Returns
///
/// * `strides` - The strides of the tensor.
pub fn get_strides_from_shape<const N: usize>(shape: [usize; N]) -> [usize; N] {
    let mut strides = [0usize; N];
    let mut running = 1usize;
    // Walk the dimensions from innermost to outermost, accumulating the
    // product of the sizes seen so far.
    for (slot, &dim) in strides.iter_mut().zip(shape.iter()).rev() {
        *slot = running;
        running *= dim;
    }
    strides
}
55
/// A data structure to represent a multi-dimensional tensor.
///
/// NOTE(review): the data is held in a [`TensorStorage`]; an earlier doc
/// claimed an `arrow::ScalarBuffer` backing that can be shared across
/// thread boundaries — confirm the current sharing semantics in `storage.rs`.
///
/// # Attributes
///
/// * `storage` - The storage of the tensor.
/// * `shape` - The shape of the tensor.
/// * `strides` - The strides of the tensor data in memory.
///
/// # Example
///
/// ```
/// use kornia_tensor::{Tensor, CpuAllocator};
///
/// let data: Vec<u8> = vec![1, 2, 3, 4];
/// let t = Tensor::<u8, 2, CpuAllocator>::from_shape_vec([2, 2], data, CpuAllocator).unwrap();
/// assert_eq!(t.shape, [2, 2]);
/// ```
pub struct Tensor<T, const N: usize, A: TensorAllocator> {
    /// The storage of the tensor.
    pub storage: TensorStorage<T, A>,
    /// The shape of the tensor (elements per dimension).
    pub shape: [usize; N],
    /// The strides of the tensor data in memory, measured in elements
    /// (row-major as produced by `get_strides_from_shape`).
    pub strides: [usize; N],
}
84
85impl<T, const N: usize, A: TensorAllocator> Tensor<T, N, A>
86where
87    A: 'static,
88{
    /// Get the data of the tensor as a slice.
    ///
    /// # Returns
    ///
    /// A slice containing the data of the tensor.
    #[inline]
    pub fn as_slice(&self) -> &[T] {
        self.storage.as_slice()
    }

    /// Get the data of the tensor as a mutable slice.
    ///
    /// # Returns
    ///
    /// A mutable slice containing the data of the tensor.
    #[inline]
    pub fn as_slice_mut(&mut self) -> &mut [T] {
        self.storage.as_mut_slice()
    }

    /// Get the data of the tensor as a pointer.
    ///
    /// # Returns
    ///
    /// A pointer to the data of the tensor.
    #[inline]
    pub fn as_ptr(&self) -> *const T {
        self.storage.as_ptr()
    }

    /// Get the data of the tensor as a mutable pointer.
    ///
    /// # Returns
    ///
    /// A mutable pointer to the data of the tensor.
    #[inline]
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.storage.as_mut_ptr()
    }

    /// Consumes the tensor and returns the underlying vector.
    ///
    /// This method destroys the tensor and returns ownership of the underlying data.
    /// The returned vector will have a length equal to the total number of elements in the tensor.
    #[inline]
    pub fn into_vec(self) -> Vec<T> {
        self.storage.into_vec()
    }
138
139    /// Creates a new `Tensor` with the given shape and data.
140    ///
141    /// # Arguments
142    ///
143    /// * `shape` - An array containing the shape of the tensor.
144    /// * `data` - A vector containing the data of the tensor.
145    /// * `alloc` - The allocator to use.
146    ///
147    /// # Returns
148    ///
149    /// A new `Tensor` instance.
150    ///
151    /// # Errors
152    ///
153    /// If the number of elements in the data does not match the shape of the tensor, an error is returned.
154    ///
155    /// # Example
156    ///
157    /// ```
158    /// use kornia_tensor::{Tensor, CpuAllocator};
159    ///
160    /// let data: Vec<u8> = vec![1, 2, 3, 4];
161    /// let t = Tensor::<u8, 2, CpuAllocator>::from_shape_vec([2, 2], data, CpuAllocator).unwrap();
162    /// assert_eq!(t.shape, [2, 2]);
163    /// ```
164    pub fn from_shape_vec(shape: [usize; N], data: Vec<T>, alloc: A) -> Result<Self, TensorError> {
165        let numel = shape.iter().product::<usize>();
166        if numel != data.len() {
167            return Err(TensorError::InvalidShape(numel));
168        }
169        let storage = TensorStorage::from_vec(data, alloc);
170        let strides = get_strides_from_shape(shape);
171        Ok(Self {
172            storage,
173            shape,
174            strides,
175        })
176    }
177
178    /// Creates a new `Tensor` with the given shape and slice of data.
179    ///
180    /// # Arguments
181    ///
182    /// * `shape` - An array containing the shape of the tensor.
183    /// * `data` - A slice containing the data of the tensor.
184    /// * `alloc` - The allocator to use.
185    ///
186    /// # Returns
187    ///
188    /// A new `Tensor` instance.
189    ///
190    /// # Errors
191    ///
192    /// If the number of elements in the data does not match the shape of the tensor, an error is returned.
193    pub fn from_shape_slice(shape: [usize; N], data: &[T], alloc: A) -> Result<Self, TensorError>
194    where
195        T: Clone,
196    {
197        let numel = shape.iter().product::<usize>();
198        if numel != data.len() {
199            return Err(TensorError::InvalidShape(numel));
200        }
201        let storage = TensorStorage::from_vec(data.to_vec(), alloc);
202        let strides = get_strides_from_shape(shape);
203        Ok(Self {
204            storage,
205            shape,
206            strides,
207        })
208    }
209
    /// Creates a new `Tensor` with the given shape and raw parts.
    ///
    /// # Arguments
    ///
    /// * `shape` - An array containing the shape of the tensor.
    /// * `data` - A pointer to the data of the tensor.
    /// * `len` - The length of the data, in elements.
    /// * `alloc` - The allocator to use.
    ///
    /// # Safety
    ///
    /// The pointer must be non-null, properly aligned for `T`, and valid for
    /// `len` consecutive elements; the caller must also uphold whatever
    /// ownership contract `TensorStorage::from_raw_parts` requires.
    ///
    /// NOTE(review): unlike `from_shape_vec`, `len` is not validated against
    /// the element count implied by `shape` — confirm this is intentional.
    pub unsafe fn from_raw_parts(
        shape: [usize; N],
        data: *const T,
        len: usize,
        alloc: A,
    ) -> Result<Self, TensorError>
    where
        T: Clone,
    {
        let storage = TensorStorage::from_raw_parts(data, len, alloc);
        let strides = get_strides_from_shape(shape);
        Ok(Self {
            storage,
            shape,
            strides,
        })
    }
239
    /// Creates a new `Tensor` with the given shape and a default value.
    ///
    /// # Arguments
    ///
    /// * `shape` - An array containing the shape of the tensor.
    /// * `value` - The default value to fill the tensor with.
    /// * `alloc` - The allocator to use.
    ///
    /// # Returns
    ///
    /// A new `Tensor` instance.
    ///
    /// # Example
    ///
    /// ```
    /// use kornia_tensor::{Tensor, CpuAllocator};
    ///
    /// let t = Tensor::<u8, 1, CpuAllocator>::from_shape_val([4], 0, CpuAllocator);
    /// assert_eq!(t.as_slice(), vec![0, 0, 0, 0]);
    ///
    /// let t = Tensor::<u8, 2, CpuAllocator>::from_shape_val([2, 2], 1, CpuAllocator);
    /// assert_eq!(t.as_slice(), vec![1, 1, 1, 1]);
    ///
    /// let t = Tensor::<u8, 3, CpuAllocator>::from_shape_val([2, 1, 3], 2, CpuAllocator);
    /// assert_eq!(t.as_slice(), vec![2, 2, 2, 2, 2, 2]);
    /// ```
    pub fn from_shape_val(shape: [usize; N], value: T, alloc: A) -> Self
    where
        T: Clone,
    {
        let numel = shape.iter().product::<usize>();
        let data = vec![value; numel];
        let storage = TensorStorage::from_vec(data, alloc);
        let strides = get_strides_from_shape(shape);
        Self {
            storage,
            shape,
            strides,
        }
    }
280
281    /// Create a new `Tensor` with the given shape and a function to generate the data.
282    ///
283    /// The function `f` is called with the index of the element to generate.
284    ///
285    /// # Arguments
286    ///
287    /// * `shape` - An array containing the shape of the tensor.
288    /// * `f` - The function to generate the data.
289    ///
290    /// # Returns
291    ///
292    /// A new `Tensor` instance.
293    ///
294    /// # Example
295    ///
296    /// ```
297    /// use kornia_tensor::{Tensor, CpuAllocator};
298    ///
299    /// let t = Tensor::<u8, 1, CpuAllocator>::from_shape_fn([4], CpuAllocator, |[i]| i as u8);
300    /// assert_eq!(t.as_slice(), vec![0, 1, 2, 3]);
301    ///
302    /// let t = Tensor::<u8, 2, CpuAllocator>::from_shape_fn([2, 2], CpuAllocator, |[i, j]| (i * 2 + j) as u8);
303    /// assert_eq!(t.as_slice(), vec![0, 1, 2, 3]);
304    /// ```
305    pub fn from_shape_fn<F>(shape: [usize; N], alloc: A, f: F) -> Self
306    where
307        F: Fn([usize; N]) -> T,
308    {
309        let numel = shape.iter().product::<usize>();
310        let data: Vec<T> = (0..numel)
311            .map(|i| {
312                let mut index = [0; N];
313                let mut j = i;
314                for k in (0..N).rev() {
315                    index[k] = j % shape[k];
316                    j /= shape[k];
317                }
318                f(index)
319            })
320            .collect();
321        let storage = TensorStorage::from_vec(data, alloc);
322        let strides = get_strides_from_shape(shape);
323        Self {
324            storage,
325            shape,
326            strides,
327        }
328    }
329
330    /// Returns the number of elements in the tensor.
331    ///
332    /// # Returns
333    ///
334    /// The number of elements in the tensor.
335    #[inline]
336    pub fn numel(&self) -> usize {
337        self.storage.len() / std::mem::size_of::<T>()
338    }
339
340    /// Get the offset of the element at the given index.
341    ///
342    /// # Arguments
343    ///
344    /// * `index` - The list of indices to get the element from.
345    ///
346    /// # Returns
347    ///
348    /// The offset of the element at the given index.
349    pub fn get_iter_offset(&self, index: [usize; N]) -> Option<usize> {
350        let mut offset = 0;
351        for ((&idx, dim_size), stride) in index.iter().zip(self.shape).zip(self.strides) {
352            if idx >= dim_size {
353                return None;
354            }
355            offset += idx * stride;
356        }
357        Some(offset)
358    }
359
360    /// Get the offset of the element at the given index without checking dim sizes.
361    ///
362    /// # Arguments
363    ///
364    /// * `index` - The list of indices to get the element from.
365    ///
366    /// # Returns
367    ///
368    /// The offset of the element at the given index.
369    pub fn get_iter_offset_unchecked(&self, index: [usize; N]) -> usize {
370        let mut offset = 0;
371        for (&idx, stride) in index.iter().zip(self.strides) {
372            offset += idx * stride;
373        }
374        offset
375    }
376
377    /// Get the index of the element at the given offset without checking dim sizes. The reverse of `Self::get_iter_offset_unchecked`.
378    ///
379    /// # Arguments
380    ///
381    /// * `offset` - The offset of the element at the given index.
382    ///
383    /// # Returns
384    ///
385    /// The array of indices to get the element from.
386    pub fn get_index_unchecked(&self, offset: usize) -> [usize; N] {
387        let mut idx = [0; N];
388        let mut rem = offset;
389        for (dim_i, s) in self.strides.iter().enumerate() {
390            idx[dim_i] = rem / s;
391            rem = offset % s;
392        }
393
394        idx
395    }
396
397    /// Get the index of the element at the given offset. The reverse of `Self::get_iter_offset`.
398    ///
399    /// # Arguments
400    ///
401    /// * `offset` - The offset of the element at the given index.
402    ///
403    /// # Returns
404    ///
405    /// The array of indices to get the element from.
406    ///
407    /// # Errors
408    ///
409    /// If the offset is out of bounds (>= numel), an error is returned.
410    pub fn get_index(&self, offset: usize) -> Result<[usize; N], TensorError> {
411        if offset >= self.numel() {
412            return Err(TensorError::IndexOutOfBounds(offset));
413        }
414        let idx = self.get_index_unchecked(offset);
415
416        Ok(idx)
417    }
418
419    /// Get the element at the given index without checking if the index is out of bounds.
420    ///
421    /// # Arguments
422    ///
423    /// * `index` - The list of indices to get the element from.
424    ///
425    /// # Returns
426    ///
427    /// A reference to the element at the given index.
428    ///
429    /// # Example
430    ///
431    /// ```
432    /// use kornia_tensor::{Tensor, CpuAllocator};
433    ///
434    /// let data: Vec<u8> = vec![1, 2, 3, 4];
435    ///
436    /// let t = Tensor::<u8, 2, CpuAllocator>::from_shape_vec([2, 2], data, CpuAllocator).unwrap();
437    /// assert_eq!(*t.get_unchecked([0, 0]), 1);
438    /// assert_eq!(*t.get_unchecked([0, 1]), 2);
439    /// assert_eq!(*t.get_unchecked([1, 0]), 3);
440    /// assert_eq!(*t.get_unchecked([1, 1]), 4);
441    /// ```
442    pub fn get_unchecked(&self, index: [usize; N]) -> &T {
443        let offset = self.get_iter_offset_unchecked(index);
444        unsafe { self.storage.as_slice().get_unchecked(offset) }
445    }
446
447    /// Get the element at the given index, checking if the index is out of bounds.
448    ///
449    /// # Arguments
450    ///
451    /// * `index` - The list of indices to get the element from.
452    ///
453    /// # Returns
454    ///
455    /// A reference to the element at the given index.
456    ///
457    /// # Errors
458    ///
459    /// If the index is out of bounds, an error is returned.
460    ///
461    /// # Example
462    ///
463    /// ```
464    /// use kornia_tensor::{Tensor, CpuAllocator};
465    ///
466    /// let data: Vec<u8> = vec![1, 2, 3, 4];
467    ///
468    /// let t = Tensor::<u8, 2, CpuAllocator>::from_shape_vec([2, 2], data, CpuAllocator).unwrap();
469    ///
470    /// assert_eq!(t.get([0, 0]), Some(&1));
471    /// assert_eq!(t.get([0, 1]), Some(&2));
472    /// assert_eq!(t.get([1, 0]), Some(&3));
473    /// assert_eq!(t.get([1, 1]), Some(&4));
474    ///
475    /// assert!(t.get([2, 0]).is_none());
476    /// ```
477    pub fn get(&self, index: [usize; N]) -> Option<&T> {
478        self.get_iter_offset(index)
479            .and_then(|i| self.storage.as_slice().get(i))
480    }
481
482    /// Reshape the tensor to a new shape.
483    ///
484    /// # Arguments
485    ///
486    /// * `shape` - The new shape of the tensor.
487    ///
488    /// # Returns
489    ///
490    /// A new `TensorView` instance.
491    ///
492    /// # Errors
493    ///
494    /// If the number of elements in the new shape does not match the number of elements in the tensor, an error is returned.
495    ///
496    /// # Example
497    ///
498    /// ```
499    /// use kornia_tensor::{Tensor, CpuAllocator};
500    ///
501    /// let data: Vec<u8> = vec![1, 2, 3, 4];
502    ///
503    /// let t = Tensor::<u8, 1, CpuAllocator>::from_shape_vec([4], data, CpuAllocator).unwrap();
504    /// let t2 = t.reshape([2, 2]).unwrap();
505    /// assert_eq!(t2.shape, [2, 2]);
506    /// assert_eq!(t2.as_slice(), vec![1, 2, 3, 4]);
507    /// assert_eq!(t2.strides, [2, 1]);
508    /// assert_eq!(t2.numel(), 4);
509    /// ```
510    pub fn reshape<const M: usize>(
511        &self,
512        shape: [usize; M],
513    ) -> Result<TensorView<T, M, A>, TensorError> {
514        let numel = shape.iter().product::<usize>();
515        if numel != self.storage.len() {
516            return Err(TensorError::DimensionMismatch(format!(
517                "Cannot reshape tensor of shape {:?} with {} elements to shape {:?} with {} elements",
518                self.shape, self.storage.len(), shape, numel
519            )));
520        }
521
522        let strides = get_strides_from_shape(shape);
523
524        Ok(TensorView {
525            storage: &self.storage,
526            shape,
527            strides,
528        })
529    }
530
531    /// Permute the dimensions of the tensor.
532    ///
533    /// The permutation is given as an array of indices, where the value at each index is the new index of the dimension.
534    /// The data is not moved, only the order of the dimensions is changed.
535    ///
536    /// # Arguments
537    ///
538    /// * `axes` - The new order of the dimensions.
539    ///
540    /// # Returns
541    ///
542    /// A view of the tensor with the dimensions permuted.
543    pub fn permute_axes(&self, axes: [usize; N]) -> TensorView<T, N, A> {
544        let mut new_shape = [0; N];
545        let mut new_strides = [0; N];
546        for (i, &axis) in axes.iter().enumerate() {
547            new_shape[i] = self.shape[axis];
548            new_strides[i] = self.strides[axis];
549        }
550
551        TensorView {
552            storage: &self.storage,
553            shape: new_shape,
554            strides: new_strides,
555        }
556    }
557
    /// Return a view of the tensor.
    ///
    /// The view is a zero-copy reference to the tensor storage with the same
    /// shape and strides.
    ///
    /// # Returns
    ///
    /// A `TensorView` instance.
    pub fn view(&self) -> TensorView<T, N, A> {
        TensorView {
            storage: &self.storage,
            shape: self.shape,
            strides: self.strides,
        }
    }
572
    /// Create a new tensor with all elements set to zero.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `alloc` - The allocator to use.
    ///
    /// # Returns
    ///
    /// A new `Tensor` instance filled with `T::zero()`.
    pub fn zeros(shape: [usize; N], alloc: A) -> Tensor<T, N, A>
    where
        T: Clone + num_traits::Zero,
    {
        Self::from_shape_val(shape, T::zero(), alloc)
    }
588
589    /// Apply a function to each element of the tensor.
590    ///
591    /// # Arguments
592    ///
593    /// * `f` - The function to apply to each element.
594    ///
595    /// # Returns
596    ///
597    /// A new `Tensor` instance.
598    ///
599    /// # Example
600    ///
601    /// ```
602    /// use kornia_tensor::{Tensor, CpuAllocator};
603    ///
604    /// let data: Vec<u8> = vec![1, 2, 3, 4];
605    /// let t = Tensor::<u8, 1, CpuAllocator>::from_shape_vec([4], data, CpuAllocator).unwrap();
606    ///
607    /// let t2 = t.map(|x| *x + 1);
608    /// assert_eq!(t2.as_slice(), vec![2, 3, 4, 5]);
609    /// ```
610    pub fn map<U, F>(&self, f: F) -> Tensor<U, N, A>
611    where
612        F: Fn(&T) -> U,
613    {
614        let data: Vec<U> = self.as_slice().iter().map(f).collect();
615        let storage = TensorStorage::from_vec(data, self.storage.alloc().clone());
616
617        Tensor {
618            storage,
619            shape: self.shape,
620            strides: self.strides,
621        }
622    }
623
624    /// Cast the tensor to a new type.
625    ///
626    /// # Returns
627    ///
628    /// A new `Tensor` instance.
629    ///
630    /// # Example
631    ///
632    /// ```
633    /// use kornia_tensor::{Tensor, CpuAllocator};
634    ///
635    /// let data: Vec<u8> = vec![1, 2, 3, 4];
636    /// let t = Tensor::<u8, 1, CpuAllocator>::from_shape_vec([4], data, CpuAllocator).unwrap();
637    ///
638    /// let t2 = t.cast::<f32>();
639    /// assert_eq!(t2.as_slice(), vec![1.0, 2.0, 3.0, 4.0]);
640    /// ```
641    pub fn cast<U>(&self) -> Tensor<U, N, CpuAllocator>
642    where
643        U: From<T>,
644        T: Clone,
645    {
646        let mut data: Vec<U> = Vec::with_capacity(self.storage.len());
647        self.as_slice().iter().for_each(|x| {
648            data.push(U::from(x.clone()));
649        });
650        let storage = TensorStorage::from_vec(data, CpuAllocator);
651        Tensor {
652            storage,
653            shape: self.shape,
654            strides: self.strides,
655        }
656    }
657
658    /// Perform an element-wise operation on two tensors.
659    ///
660    /// # Arguments
661    ///
662    /// * `other` - The other tensor to perform the operation with.
663    /// * `op` - The operation to perform.
664    ///
665    /// # Returns
666    ///
667    /// A new `Tensor` instance.
668    ///
669    /// # Example
670    ///
671    /// ```
672    /// use kornia_tensor::{Tensor, CpuAllocator};
673    ///
674    /// let data1: Vec<u8> = vec![1, 2, 3, 4];
675    /// let t1 = Tensor::<u8, 1, CpuAllocator>::from_shape_vec([4], data1, CpuAllocator).unwrap();
676    ///
677    /// let data2: Vec<u8> = vec![1, 2, 3, 4];
678    /// let t2 = Tensor::<u8, 1, CpuAllocator>::from_shape_vec([4], data2, CpuAllocator).unwrap();
679    ///
680    /// let t3 = t1.element_wise_op(&t2, |a, b| *a + *b).unwrap();
681    /// assert_eq!(t3.as_slice(), vec![2, 4, 6, 8]);
682    ///
683    /// let t4 = t1.element_wise_op(&t2, |a, b| *a - *b).unwrap();
684    /// assert_eq!(t4.as_slice(), vec![0, 0, 0, 0]);
685    ///
686    /// let t5 = t1.element_wise_op(&t2, |a, b| *a * *b).unwrap();
687    /// assert_eq!(t5.as_slice(), vec![1, 4, 9, 16]);
688    ///
689    /// let t6 = t1.element_wise_op(&t2, |a, b| *a / *b).unwrap();
690    /// assert_eq!(t6.as_slice(), vec![1, 1, 1, 1]);
691    /// ```
692    pub fn element_wise_op<F>(
693        &self,
694        other: &Tensor<T, N, CpuAllocator>,
695        op: F,
696    ) -> Result<Tensor<T, N, CpuAllocator>, TensorError>
697    where
698        F: Fn(&T, &T) -> T,
699    {
700        if self.shape != other.shape {
701            return Err(TensorError::DimensionMismatch(format!(
702                "Shapes {:?} and {:?} are not compatible for element-wise operations",
703                self.shape, other.shape
704            )));
705        }
706
707        let data = self
708            .as_slice()
709            .iter()
710            .zip(other.as_slice().iter())
711            .map(|(a, b)| op(a, b))
712            .collect();
713
714        let storage = TensorStorage::from_vec(data, CpuAllocator);
715
716        Ok(Tensor {
717            storage,
718            shape: self.shape,
719            strides: self.strides,
720        })
721    }
722}
723
impl<T, const N: usize, A> Clone for Tensor<T, N, A>
where
    T: Clone,
    A: TensorAllocator + Clone + 'static,
{
    /// Clone the tensor. Shape and strides are `Copy`; the data follows
    /// `TensorStorage::clone` semantics (copy vs. share is defined in
    /// `storage.rs` — not visible here).
    fn clone(&self) -> Self {
        Self {
            storage: self.storage.clone(),
            shape: self.shape,
            strides: self.strides,
        }
    }
}
737
738impl<T, const N: usize, A> std::fmt::Display for Tensor<T, N, A>
739where
740    T: std::fmt::Display + std::fmt::LowerExp,
741    A: TensorAllocator + 'static,
742{
743    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
744        let width = self
745            .storage
746            .as_slice()
747            .iter()
748            .map(|v| format!("{v:.4}").len())
749            .max()
750            .unwrap();
751
752        let scientific = width > 8;
753
754        let should_mask: [bool; N] = self.shape.map(|s| s > 8);
755        let mut skip_until = 0;
756
757        for (i, v) in self.storage.as_slice().iter().enumerate() {
758            if i < skip_until {
759                continue;
760            }
761            let mut value = String::new();
762            let mut prefix = String::new();
763            let mut suffix = String::new();
764            let mut separator = ",".to_string();
765            let mut last_size = 1;
766            for (dim, (&size, maskable)) in self.shape.iter().zip(should_mask).enumerate().rev() {
767                let prod = size * last_size;
768                if i % prod == (3 * last_size) && maskable {
769                    let pad = if dim == (N - 1) { 0 } else { dim + 1 };
770                    value = format!("{}...", " ".repeat(pad));
771                    skip_until = i + (size - 4) * last_size;
772                    prefix = "".to_string();
773                    if dim != (N - 1) {
774                        separator = "\n".repeat(N - 1 - dim);
775                    }
776                    break;
777                } else if i % prod == 0 {
778                    prefix.push('[');
779                } else if (i + 1) % prod == 0 {
780                    suffix.push(']');
781                    separator.push('\n');
782                    if dim == 0 {
783                        separator = "".to_string();
784                    }
785                } else {
786                    break;
787                }
788                last_size = prod;
789            }
790            if !prefix.is_empty() {
791                prefix = format!("{prefix:>N$}");
792            }
793
794            if value.is_empty() {
795                value = if scientific {
796                    let num = format!("{v:.4e}");
797                    let (before, after) = num.split_once('e').unwrap();
798                    let after = if let Some(stripped) = after.strip_prefix('-') {
799                        format!("-{:0>2}", &stripped)
800                    } else {
801                        format!("+{:0>2}", &after)
802                    };
803                    format!("{before}e{after}")
804                } else {
805                    let rounded = format!("{v:.4}");
806                    format!("{rounded:>width$}")
807                }
808            };
809            write!(f, "{prefix}{value}{suffix}{separator}",)?;
810        }
811        Ok(())
812    }
813}
814
815#[cfg(test)]
816mod tests {
817    use crate::allocator::CpuAllocator;
818    use crate::tensor::{Tensor, TensorError};
819
    // A 1-D tensor built from a vec records shape, data, strides and numel.
    #[test]
    fn constructor_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1];
        let t = Tensor::<u8, 1, _>::from_shape_vec([1], data, CpuAllocator)?;
        assert_eq!(t.shape, [1]);
        assert_eq!(t.as_slice(), vec![1]);
        assert_eq!(t.strides, [1]);
        assert_eq!(t.numel(), 1);
        Ok(())
    }

    // A 2-D tensor gets row-major strides computed from its shape.
    #[test]
    fn constructor_2d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2];
        let t = Tensor::<u8, 2, _>::from_shape_vec([1, 2], data, CpuAllocator)?;
        assert_eq!(t.shape, [1, 2]);
        assert_eq!(t.as_slice(), vec![1, 2]);
        assert_eq!(t.strides, [2, 1]);
        assert_eq!(t.numel(), 2);
        Ok(())
    }
841
    // Checked indexing in 1-D: in-range gives Some, out-of-range gives None.
    #[test]
    fn get_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 1, _>::from_shape_vec([4], data, CpuAllocator)?;
        assert_eq!(t.get([0]), Some(&1));
        assert_eq!(t.get([1]), Some(&2));
        assert_eq!(t.get([2]), Some(&3));
        assert_eq!(t.get([3]), Some(&4));
        assert!(t.get([4]).is_none());
        Ok(())
    }

    // Checked indexing in 2-D, including per-dimension bounds rejection.
    #[test]
    fn get_2d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 2, _>::from_shape_vec([2, 2], data, CpuAllocator)?;
        assert_eq!(t.get([0, 0]), Some(&1));
        assert_eq!(t.get([0, 1]), Some(&2));
        assert_eq!(t.get([1, 0]), Some(&3));
        assert_eq!(t.get([1, 1]), Some(&4));
        assert!(t.get([2, 0]).is_none());
        assert!(t.get([0, 2]).is_none());
        Ok(())
    }

    // Checked indexing in 3-D; every axis is bounds-checked independently.
    #[test]
    fn get_3d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4, 5, 6];
        let t = Tensor::<u8, 3, _>::from_shape_vec([2, 1, 3], data, CpuAllocator)?;
        assert_eq!(t.get([0, 0, 0]), Some(&1));
        assert_eq!(t.get([0, 0, 1]), Some(&2));
        assert_eq!(t.get([0, 0, 2]), Some(&3));
        assert_eq!(t.get([1, 0, 0]), Some(&4));
        assert_eq!(t.get([1, 0, 1]), Some(&5));
        assert_eq!(t.get([1, 0, 2]), Some(&6));
        assert!(t.get([2, 0, 0]).is_none());
        assert!(t.get([0, 1, 0]).is_none());
        assert!(t.get([0, 0, 3]).is_none());
        Ok(())
    }
882
    // NOTE(review): these two tests exercise `get_unchecked` despite the
    // "checked" in their names — consider renaming for clarity.
    #[test]
    fn get_checked_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 1, _>::from_shape_vec([4], data, CpuAllocator)?;
        assert_eq!(*t.get_unchecked([0]), 1);
        assert_eq!(*t.get_unchecked([1]), 2);
        assert_eq!(*t.get_unchecked([2]), 3);
        assert_eq!(*t.get_unchecked([3]), 4);
        Ok(())
    }

    #[test]
    fn get_checked_2d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 2, _>::from_shape_vec([2, 2], data, CpuAllocator)?;
        assert_eq!(*t.get_unchecked([0, 0]), 1);
        assert_eq!(*t.get_unchecked([0, 1]), 2);
        assert_eq!(*t.get_unchecked([1, 0]), 3);
        assert_eq!(*t.get_unchecked([1, 1]), 4);
        Ok(())
    }
    // Reshaping 1-D -> 2-D yields a view with recomputed strides over the
    // same underlying data.
    #[test]
    fn reshape_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 1, _>::from_shape_vec([4], data, CpuAllocator)?;

        let view = t.reshape([2, 2])?;

        assert_eq!(view.shape, [2, 2]);
        assert_eq!(view.as_slice(), vec![1, 2, 3, 4]);
        assert_eq!(view.strides, [2, 1]);
        assert_eq!(view.numel(), 4);
        assert_eq!(view.as_contiguous().as_slice(), vec![1, 2, 3, 4]);
        Ok(())
    }

    // Reshaping 2-D -> 1-D (flattening) preserves element order.
    #[test]
    fn reshape_2d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 2, _>::from_shape_vec([2, 2], data, CpuAllocator)?;
        let t2 = t.reshape([4])?;

        assert_eq!(t2.shape, [4]);
        assert_eq!(t2.as_slice(), vec![1, 2, 3, 4]);
        assert_eq!(t2.strides, [1]);
        assert_eq!(t2.numel(), 4);
        assert_eq!(t2.as_contiguous().as_slice(), vec![1, 2, 3, 4]);
        Ok(())
    }

    // Element access through a reshaped view follows the new strides.
    #[test]
    fn reshape_get_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 1, _>::from_shape_vec([4], data, CpuAllocator)?;
        let view = t.reshape([2, 2])?;
        assert_eq!(*view.get_unchecked([0, 0]), 1);
        assert_eq!(*view.get_unchecked([0, 1]), 2);
        assert_eq!(*view.get_unchecked([1, 0]), 3);
        assert_eq!(*view.get_unchecked([1, 1]), 4);
        assert_eq!(view.numel(), 4);
        assert_eq!(view.as_contiguous().as_slice(), vec![1, 2, 3, 4]);
        Ok(())
    }
946
947    #[test]
948    fn permute_axes_1d() -> Result<(), TensorError> {
949        let data: Vec<u8> = vec![1, 2, 3, 4];
950        let t = Tensor::<u8, 1, _>::from_shape_vec([4], data, CpuAllocator)?;
951        let t2 = t.permute_axes([0]);
952        assert_eq!(t2.shape, [4]);
953        assert_eq!(t2.as_slice(), vec![1, 2, 3, 4]);
954        assert_eq!(t2.strides, [1]);
955        assert_eq!(t2.as_contiguous().as_slice(), vec![1, 2, 3, 4]);
956        Ok(())
957    }
958
959    #[test]
960    fn permute_axes_2d() -> Result<(), TensorError> {
961        let data: Vec<u8> = vec![1, 2, 3, 4];
962        let t = Tensor::<u8, 2, _>::from_shape_vec([2, 2], data, CpuAllocator)?;
963        let view = t.permute_axes([1, 0]);
964        assert_eq!(view.shape, [2, 2]);
965        assert_eq!(*view.get_unchecked([0, 0]), 1u8);
966        assert_eq!(*view.get_unchecked([1, 0]), 2u8);
967        assert_eq!(*view.get_unchecked([0, 1]), 3u8);
968        assert_eq!(*view.get_unchecked([1, 1]), 4u8);
969        assert_eq!(view.strides, [1, 2]);
970        assert_eq!(view.as_contiguous().as_slice(), vec![1, 3, 2, 4]);
971        Ok(())
972    }
973
974    #[test]
975    fn contiguous_2d() -> Result<(), TensorError> {
976        let data: Vec<u8> = vec![1, 2, 3, 4, 5, 6];
977        let t = Tensor::<u8, 2, _>::from_shape_vec([2, 3], data, CpuAllocator)?;
978
979        let view = t.permute_axes([1, 0]);
980
981        let contiguous = view.as_contiguous();
982
983        assert_eq!(contiguous.shape, [3, 2]);
984        assert_eq!(contiguous.strides, [2, 1]);
985        assert_eq!(contiguous.as_slice(), vec![1, 4, 2, 5, 3, 6]);
986
987        Ok(())
988    }
989
990    #[test]
991    fn zeros_1d() -> Result<(), TensorError> {
992        let t = Tensor::<u8, 1, _>::zeros([4], CpuAllocator);
993        assert_eq!(t.as_slice(), vec![0, 0, 0, 0]);
994        Ok(())
995    }
996
997    #[test]
998    fn zeros_2d() -> Result<(), TensorError> {
999        let t = Tensor::<u8, 2, _>::zeros([2, 2], CpuAllocator);
1000        assert_eq!(t.as_slice(), vec![0, 0, 0, 0]);
1001        Ok(())
1002    }
1003
1004    #[test]
1005    fn map_1d() -> Result<(), TensorError> {
1006        let data: Vec<u8> = vec![1, 2, 3, 4];
1007        let t = Tensor::<u8, 1, _>::from_shape_vec([4], data, CpuAllocator)?;
1008        let t2 = t.map(|x| *x + 1);
1009        assert_eq!(t2.as_slice(), vec![2, 3, 4, 5]);
1010        Ok(())
1011    }
1012
1013    #[test]
1014    fn map_2d() -> Result<(), TensorError> {
1015        let data: Vec<u8> = vec![1, 2, 3, 4];
1016        let t = Tensor::<u8, 2, _>::from_shape_vec([2, 2], data, CpuAllocator)?;
1017        let t2 = t.map(|x| *x + 1);
1018        assert_eq!(t2.as_slice(), vec![2, 3, 4, 5]);
1019        Ok(())
1020    }
1021
1022    #[test]
1023    fn from_shape_val_1d() -> Result<(), TensorError> {
1024        let t = Tensor::<u8, 1, _>::from_shape_val([4], 0, CpuAllocator);
1025        assert_eq!(t.as_slice(), vec![0, 0, 0, 0]);
1026        Ok(())
1027    }
1028
1029    #[test]
1030    fn from_shape_val_2d() -> Result<(), TensorError> {
1031        let t = Tensor::<u8, 2, _>::from_shape_val([2, 2], 1, CpuAllocator);
1032        assert_eq!(t.as_slice(), vec![1, 1, 1, 1]);
1033        Ok(())
1034    }
1035
1036    #[test]
1037    fn from_shape_val_3d() -> Result<(), TensorError> {
1038        let t = Tensor::<u8, 3, _>::from_shape_val([2, 1, 3], 2, CpuAllocator);
1039        assert_eq!(t.as_slice(), vec![2, 2, 2, 2, 2, 2]);
1040        Ok(())
1041    }
1042
1043    #[test]
1044    fn cast_1d() -> Result<(), TensorError> {
1045        let data: Vec<u8> = vec![1, 2, 3, 4];
1046        let t = Tensor::<u8, 1, _>::from_shape_vec([4], data, CpuAllocator)?;
1047        let t2 = t.cast::<u16>();
1048        assert_eq!(t2.as_slice(), vec![1, 2, 3, 4]);
1049        Ok(())
1050    }
1051
1052    #[test]
1053    fn cast_2d() -> Result<(), TensorError> {
1054        let data: Vec<u8> = vec![1, 2, 3, 4];
1055        let t = Tensor::<u8, 2, _>::from_shape_vec([2, 2], data, CpuAllocator)?;
1056        let t2 = t.cast::<u16>();
1057        assert_eq!(t2.as_slice(), vec![1, 2, 3, 4]);
1058        Ok(())
1059    }
1060
1061    #[test]
1062    fn from_shape_fn_1d() -> Result<(), TensorError> {
1063        let alloc = CpuAllocator;
1064        let t = Tensor::from_shape_fn([3, 3], alloc, |[i, j]| ((1 + i) * (1 + j)) as u8);
1065        assert_eq!(t.as_slice(), vec![1, 2, 3, 2, 4, 6, 3, 6, 9]);
1066        Ok(())
1067    }
1068
1069    #[test]
1070    fn from_shape_fn_2d() -> Result<(), TensorError> {
1071        let alloc = CpuAllocator;
1072        let t = Tensor::from_shape_fn([3, 3], alloc, |[i, j]| ((1 + i) * (1 + j)) as f32);
1073        assert_eq!(
1074            t.as_slice(),
1075            vec![1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0]
1076        );
1077        Ok(())
1078    }
1079
1080    #[test]
1081    fn from_shape_fn_3d() -> Result<(), TensorError> {
1082        let alloc = CpuAllocator;
1083        let t = Tensor::from_shape_fn([2, 3, 3], alloc, |[x, y, c]| {
1084            ((1 + x) * (1 + y) * (1 + c)) as i16
1085        });
1086        assert_eq!(
1087            t.as_slice(),
1088            vec![1, 2, 3, 2, 4, 6, 3, 6, 9, 2, 4, 6, 4, 8, 12, 6, 12, 18]
1089        );
1090        Ok(())
1091    }
1092
1093    #[test]
1094    fn view_1d() -> Result<(), TensorError> {
1095        let alloc = CpuAllocator;
1096        let data: Vec<u8> = vec![1, 2, 3, 4];
1097        let t = Tensor::<u8, 1, _>::from_shape_vec([4], data, alloc)?;
1098        let view = t.view();
1099
1100        // check that the view has the same data
1101        assert_eq!(view.as_slice(), t.as_slice());
1102
1103        // check that the data pointer is the same
1104        assert!(std::ptr::eq(view.as_ptr(), t.as_ptr()));
1105
1106        Ok(())
1107    }
1108
1109    #[test]
1110    fn from_slice() -> Result<(), TensorError> {
1111        let data: [u8; 4] = [1, 2, 3, 4];
1112        let t = Tensor::<u8, 2, _>::from_shape_slice([2, 2], &data, CpuAllocator)?;
1113
1114        assert_eq!(t.shape, [2, 2]);
1115        assert_eq!(t.as_slice(), &[1, 2, 3, 4]);
1116
1117        Ok(())
1118    }
1119
    #[test]
    fn display_2d() -> Result<(), TensorError> {
        // Pin the Display output for a 2x2 u8 tensor: nested brackets,
        // comma-separated values, one matrix row per line.
        let data: [u8; 4] = [1, 2, 3, 4];
        let t = Tensor::<u8, 2, _>::from_shape_slice([2, 2], &data, CpuAllocator)?;
        let disp = t.to_string();
        let lines = disp.lines().collect::<Vec<_>>();

        #[rustfmt::skip]
        assert_eq!(lines.as_slice(),
        ["[[1,2],",
         " [3,4]]"]);
        Ok(())
    }
1133
    #[test]
    fn display_3d() -> Result<(), TensorError> {
        // Pin the Display output for a 3-d tensor: the outermost axis is
        // separated by blank lines, and values are right-padded to the width
        // of the widest element (here 2 chars, so " 1" vs "11").
        let data: [u8; 12] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let t = Tensor::<u8, 3, _>::from_shape_slice([2, 3, 2], &data, CpuAllocator)?;
        let disp = t.to_string();
        let lines = disp.lines().collect::<Vec<_>>();

        #[rustfmt::skip]
        assert_eq!(lines.as_slice(),
        ["[[[ 1, 2],",
         "  [ 3, 4],",
         "  [ 5, 6]],",
         "",
         " [[ 7, 8],",
         "  [ 9,10],",
         "  [11,12]]]"]);
        Ok(())
    }
1152
    #[test]
    fn display_float() -> Result<(), TensorError> {
        // Pin the float formatting: values near 1.0 are rendered with four
        // decimal places (rounded), e.g. 1.00009 -> "1.0001".
        let data: [f32; 4] = [1.00001, 1.00009, 0.99991, 0.99999];
        let t = Tensor::<f32, 2, _>::from_shape_slice([2, 2], &data, CpuAllocator)?;
        let disp = t.to_string();
        let lines = disp.lines().collect::<Vec<_>>();

        #[rustfmt::skip]
        assert_eq!(lines.as_slice(),
        ["[[1.0000,1.0001],",
         " [0.9999,1.0000]]"]);
        Ok(())
    }
1166
    #[test]
    fn display_big_float() -> Result<(), TensorError> {
        // Pin the formatting when magnitudes differ widely: all values switch
        // to scientific notation with a 4-decimal mantissa (e.g. "1.0000e+03").
        let data: [f32; 4] = [1000.00001, 1.00009, 0.99991, 0.99999];
        let t = Tensor::<f32, 2, _>::from_shape_slice([2, 2], &data, CpuAllocator)?;
        let disp = t.to_string();
        let lines = disp.lines().collect::<Vec<_>>();

        #[rustfmt::skip]
        assert_eq!(lines.as_slice(),
        ["[[1.0000e+03,1.0001e+00],",
         " [9.9991e-01,9.9999e-01]]"]);
        Ok(())
    }
1180
    #[test]
    fn display_big_tensor() -> Result<(), TensorError> {
        // Pin the elision behavior for large tensors: along each axis only the
        // first three entries are printed, followed by "..." and the final
        // entry. The fixture below is the full expected rendering of a
        // 10x10x10 zero tensor.
        let data: [u8; 1000] = [0; 1000];
        let t = Tensor::<u8, 3, _>::from_shape_slice([10, 10, 10], &data, CpuAllocator)?;
        let disp = t.to_string();
        let lines = disp.lines().collect::<Vec<_>>();

        #[rustfmt::skip]
        assert_eq!(lines.as_slice(),
        ["[[[0,0,0,...,0],",
         "  [0,0,0,...,0],",
         "  [0,0,0,...,0],",
         "  ...",
         "  [0,0,0,...,0]],",
         "",
         " [[0,0,0,...,0],",
         "  [0,0,0,...,0],",
         "  [0,0,0,...,0],",
         "  ...",
         "  [0,0,0,...,0]],",
         "",
         " [[0,0,0,...,0],",
         "  [0,0,0,...,0],",
         "  [0,0,0,...,0],",
         "  ...",
         "  [0,0,0,...,0]],",
         "",
         " ...",
         "",
         " [[0,0,0,...,0],",
         "  [0,0,0,...,0],",
         "  [0,0,0,...,0],",
         "  ...",
         "  [0,0,0,...,0]]]"]);
        Ok(())
    }
1217
1218    #[test]
1219    fn get_index_unchecked_1d() -> Result<(), TensorError> {
1220        let data: Vec<u8> = vec![1, 2, 3, 4];
1221        let t = Tensor::<u8, 1, CpuAllocator>::from_shape_vec([4], data, CpuAllocator)?;
1222        assert_eq!(t.get_index_unchecked(0), [0]);
1223        assert_eq!(t.get_index_unchecked(1), [1]);
1224        assert_eq!(t.get_index_unchecked(2), [2]);
1225        assert_eq!(t.get_index_unchecked(3), [3]);
1226        Ok(())
1227    }
1228
1229    #[test]
1230    fn get_index_unchecked_2d() -> Result<(), TensorError> {
1231        let data: Vec<u8> = vec![1, 2, 3, 4];
1232        let t = Tensor::<u8, 2, CpuAllocator>::from_shape_vec([2, 2], data, CpuAllocator)?;
1233        assert_eq!(t.get_index_unchecked(0), [0, 0]);
1234        assert_eq!(t.get_index_unchecked(1), [0, 1]);
1235        assert_eq!(t.get_index_unchecked(2), [1, 0]);
1236        assert_eq!(t.get_index_unchecked(3), [1, 1]);
1237        Ok(())
1238    }
1239
1240    #[test]
1241    fn get_index_unchecked_3d() -> Result<(), TensorError> {
1242        let data: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
1243        let t = Tensor::<u8, 3, CpuAllocator>::from_shape_vec([2, 2, 3], data, CpuAllocator)?;
1244        assert_eq!(t.get_index_unchecked(0), [0, 0, 0]);
1245        assert_eq!(t.get_index_unchecked(1), [0, 0, 1]);
1246        assert_eq!(t.get_index_unchecked(2), [0, 0, 2]);
1247        assert_eq!(t.get_index_unchecked(3), [0, 1, 0]);
1248        assert_eq!(t.get_index_unchecked(4), [0, 1, 1]);
1249        assert_eq!(t.get_index_unchecked(5), [0, 1, 2]);
1250        assert_eq!(t.get_index_unchecked(6), [1, 0, 0]);
1251        assert_eq!(t.get_index_unchecked(7), [1, 0, 1]);
1252        assert_eq!(t.get_index_unchecked(8), [1, 0, 2]);
1253        assert_eq!(t.get_index_unchecked(9), [1, 1, 0]);
1254        assert_eq!(t.get_index_unchecked(10), [1, 1, 1]);
1255        assert_eq!(t.get_index_unchecked(11), [1, 1, 2]);
1256        Ok(())
1257    }
1258
1259    #[test]
1260    fn get_index_to_offset_and_back() -> Result<(), TensorError> {
1261        let data: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
1262        let t = Tensor::<u8, 3, CpuAllocator>::from_shape_vec([2, 2, 3], data, CpuAllocator)?;
1263        for offset in 0..12 {
1264            assert_eq!(
1265                t.get_iter_offset_unchecked(t.get_index_unchecked(offset)),
1266                offset
1267            );
1268        }
1269        Ok(())
1270    }
1271
1272    #[test]
1273    fn get_offset_to_index_and_back() -> Result<(), TensorError> {
1274        let data: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
1275        let t = Tensor::<u8, 3, CpuAllocator>::from_shape_vec([2, 2, 3], data, CpuAllocator)?;
1276        for ind in [
1277            [0, 0, 0],
1278            [0, 0, 1],
1279            [0, 0, 2],
1280            [0, 1, 0],
1281            [0, 1, 1],
1282            [0, 1, 2],
1283            [1, 0, 0],
1284            [1, 0, 1],
1285            [1, 0, 2],
1286            [1, 1, 0],
1287            [1, 1, 1],
1288            [1, 1, 2],
1289        ] {
1290            assert_eq!(t.get_index_unchecked(t.get_iter_offset_unchecked(ind)), ind);
1291        }
1292        Ok(())
1293    }
1294
1295    #[test]
1296    fn get_index_1d() -> Result<(), TensorError> {
1297        let data: Vec<u8> = vec![1, 2, 3, 4];
1298        let t = Tensor::<u8, 1, CpuAllocator>::from_shape_vec([4], data, CpuAllocator)?;
1299        assert_eq!(t.get_index(3), Ok([3]));
1300        assert!(t
1301            .get_index(4)
1302            .is_err_and(|x| x == TensorError::IndexOutOfBounds(4)));
1303        Ok(())
1304    }
1305
1306    #[test]
1307    fn get_index_2d() -> Result<(), TensorError> {
1308        let data: Vec<u8> = vec![1, 2, 3, 4];
1309        let t = Tensor::<u8, 2, CpuAllocator>::from_shape_vec([2, 2], data, CpuAllocator)?;
1310        assert_eq!(t.get_index_unchecked(3), [1, 1]);
1311        assert!(t
1312            .get_index(4)
1313            .is_err_and(|x| x == TensorError::IndexOutOfBounds(4)));
1314        Ok(())
1315    }
1316
1317    #[test]
1318    fn get_index_3d() -> Result<(), TensorError> {
1319        let data: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
1320        let t = Tensor::<u8, 3, CpuAllocator>::from_shape_vec([2, 2, 3], data, CpuAllocator)?;
1321        assert_eq!(t.get_index_unchecked(11), [1, 1, 2]);
1322        assert!(t
1323            .get_index(12)
1324            .is_err_and(|x| x == TensorError::IndexOutOfBounds(12)));
1325        Ok(())
1326    }
1327
    #[test]
    fn from_raw_parts() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        // SAFETY: the pointer comes from a live Vec of exactly `len` initialized
        // u8 elements, and 2 * 2 matches that length. Ownership of the buffer is
        // handed to the tensor, so the Vec must not free it (see forget below).
        let t = unsafe { Tensor::from_raw_parts([2, 2], data.as_ptr(), data.len(), CpuAllocator)? };
        // Prevent a double free: the tensor now owns the allocation, so the
        // Vec's destructor must not run.
        std::mem::forget(data);
        assert_eq!(t.shape, [2, 2]);
        assert_eq!(t.as_slice(), &[1, 2, 3, 4]);
        Ok(())
    }
1337}