1use thiserror::Error;
2
3use super::{
4 allocator::{CpuAllocator, TensorAllocator, TensorAllocatorError},
5 storage::TensorStorage,
6 view::TensorView,
7};
8
/// Errors that can occur when constructing or operating on a [`Tensor`].
#[derive(Error, Debug, PartialEq)]
pub enum TensorError {
    /// The tensor data could not be cast to the requested type.
    #[error("Failed to cast data")]
    CastError,

    /// The element count implied by the shape (carried as the payload) does
    /// not match the length of the provided data.
    #[error("The number of elements in the data does not match the shape of the tensor: {0}")]
    InvalidShape(usize),

    /// A linear offset was outside the valid element range of the tensor.
    #[error("Index out of bounds. The index {0} is out of bounds.")]
    IndexOutOfBounds(usize),

    /// A storage/allocation failure, converted from [`TensorAllocatorError`].
    #[error("Error with the tensor storage: {0}")]
    StorageError(#[from] TensorAllocatorError),

    /// Shapes or dimensions are incompatible for the attempted operation.
    #[error("Dimension mismatch: {0}")]
    DimensionMismatch(String),

    /// The requested operation is not supported.
    #[error("Unsupported operation: {0}")]
    UnsupportedOperation(String),
}
36
/// Computes row-major (C-contiguous) strides for the given `shape`.
///
/// The innermost (last) dimension has stride 1, and each earlier dimension's
/// stride is the product of all later dimension sizes.
pub fn get_strides_from_shape<const N: usize>(shape: [usize; N]) -> [usize; N] {
    let mut strides = [0usize; N];
    let mut acc = 1usize;
    // Walk dimensions from innermost to outermost, accumulating the running
    // product of dimension sizes.
    for (slot, &dim) in strides.iter_mut().zip(shape.iter()).rev() {
        *slot = acc;
        acc *= dim;
    }
    strides
}
55
/// A multi-dimensional tensor of element type `T` with static rank `N`,
/// backed by a [`TensorStorage`] managed by allocator `A`.
pub struct Tensor<T, const N: usize, A: TensorAllocator> {
    /// Owned element buffer together with its allocator.
    pub storage: TensorStorage<T, A>,
    /// Size of each of the `N` dimensions.
    pub shape: [usize; N],
    /// Elements to step in storage per unit index along each dimension
    /// (row-major when freshly constructed).
    pub strides: [usize; N],
}
84
85impl<T, const N: usize, A: TensorAllocator> Tensor<T, N, A> {
86 #[inline]
92 pub fn as_slice(&self) -> &[T] {
93 self.storage.as_slice()
94 }
95
96 #[inline]
102 pub fn as_slice_mut(&mut self) -> &mut [T] {
103 self.storage.as_mut_slice()
104 }
105
106 #[inline]
112 pub fn as_ptr(&self) -> *const T {
113 self.storage.as_ptr()
114 }
115
116 #[inline]
122 pub fn as_mut_ptr(&mut self) -> *mut T {
123 self.storage.as_mut_ptr()
124 }
125
126 #[inline]
132 pub fn into_vec(self) -> Vec<T> {
133 self.storage.into_vec()
134 }
135
136 pub fn from_shape_vec(shape: [usize; N], data: Vec<T>, alloc: A) -> Result<Self, TensorError> {
162 let numel = shape.iter().product::<usize>();
163 if numel != data.len() {
164 return Err(TensorError::InvalidShape(numel));
165 }
166 let storage = TensorStorage::from_vec(data, alloc);
167 let strides = get_strides_from_shape(shape);
168 Ok(Self {
169 storage,
170 shape,
171 strides,
172 })
173 }
174
175 pub fn from_shape_slice(shape: [usize; N], data: &[T], alloc: A) -> Result<Self, TensorError>
191 where
192 T: Clone,
193 {
194 let numel = shape.iter().product::<usize>();
195 if numel != data.len() {
196 return Err(TensorError::InvalidShape(numel));
197 }
198 let storage = TensorStorage::from_vec(data.to_vec(), alloc);
199 let strides = get_strides_from_shape(shape);
200 Ok(Self {
201 storage,
202 shape,
203 strides,
204 })
205 }
206
207 pub unsafe fn from_raw_parts(
220 shape: [usize; N],
221 data: *const T,
222 len: usize,
223 alloc: A,
224 ) -> Result<Self, TensorError>
225 where
226 T: Clone,
227 {
228 let storage = TensorStorage::from_raw_parts(data, len, alloc);
229 let strides = get_strides_from_shape(shape);
230 Ok(Self {
231 storage,
232 shape,
233 strides,
234 })
235 }
236
237 pub fn from_shape_val(shape: [usize; N], value: T, alloc: A) -> Self
264 where
265 T: Clone,
266 {
267 let numel = shape.iter().product::<usize>();
268 let data = vec![value; numel];
269 let storage = TensorStorage::from_vec(data, alloc);
270 let strides = get_strides_from_shape(shape);
271 Self {
272 storage,
273 shape,
274 strides,
275 }
276 }
277
278 pub fn from_shape_fn<F>(shape: [usize; N], alloc: A, f: F) -> Self
303 where
304 F: Fn([usize; N]) -> T,
305 {
306 let numel = shape.iter().product::<usize>();
307 let data: Vec<T> = (0..numel)
308 .map(|i| {
309 let mut index = [0; N];
310 let mut j = i;
311 for k in (0..N).rev() {
312 index[k] = j % shape[k];
313 j /= shape[k];
314 }
315 f(index)
316 })
317 .collect();
318 let storage = TensorStorage::from_vec(data, alloc);
319 let strides = get_strides_from_shape(shape);
320 Self {
321 storage,
322 shape,
323 strides,
324 }
325 }
326
327 #[inline]
333 pub fn numel(&self) -> usize {
334 self.storage.len() / std::mem::size_of::<T>()
335 }
336
337 pub fn get_iter_offset(&self, index: [usize; N]) -> Option<usize> {
347 let mut offset = 0;
348 for ((&idx, dim_size), stride) in index.iter().zip(self.shape).zip(self.strides) {
349 if idx >= dim_size {
350 return None;
351 }
352 offset += idx * stride;
353 }
354 Some(offset)
355 }
356
357 pub fn get_iter_offset_unchecked(&self, index: [usize; N]) -> usize {
367 let mut offset = 0;
368 for (&idx, stride) in index.iter().zip(self.strides) {
369 offset += idx * stride;
370 }
371 offset
372 }
373
374 pub fn get_index_unchecked(&self, offset: usize) -> [usize; N] {
384 let mut idx = [0; N];
385 let mut rem = offset;
386 for (dim_i, s) in self.strides.iter().enumerate() {
387 idx[dim_i] = rem / s;
388 rem = offset % s;
389 }
390
391 idx
392 }
393
394 pub fn get_index(&self, offset: usize) -> Result<[usize; N], TensorError> {
408 if offset >= self.numel() {
409 return Err(TensorError::IndexOutOfBounds(offset));
410 }
411 let idx = self.get_index_unchecked(offset);
412
413 Ok(idx)
414 }
415
416 pub fn get_unchecked(&self, index: [usize; N]) -> &T {
440 let offset = self.get_iter_offset_unchecked(index);
441 unsafe { self.storage.as_slice().get_unchecked(offset) }
442 }
443
444 pub fn get(&self, index: [usize; N]) -> Option<&T> {
475 self.get_iter_offset(index)
476 .and_then(|i| self.storage.as_slice().get(i))
477 }
478
479 pub fn reshape<const M: usize>(
508 &self,
509 shape: [usize; M],
510 ) -> Result<TensorView<'_, T, M, A>, TensorError> {
511 let numel = shape.iter().product::<usize>();
512 if numel != self.storage.len() {
513 return Err(TensorError::DimensionMismatch(format!(
514 "Cannot reshape tensor of shape {:?} with {} elements to shape {:?} with {} elements",
515 self.shape, self.storage.len(), shape, numel
516 )));
517 }
518
519 let strides = get_strides_from_shape(shape);
520
521 Ok(TensorView {
522 storage: &self.storage,
523 shape,
524 strides,
525 })
526 }
527
528 pub fn permute_axes(&self, axes: [usize; N]) -> TensorView<'_, T, N, A> {
541 let mut new_shape = [0; N];
542 let mut new_strides = [0; N];
543 for (i, &axis) in axes.iter().enumerate() {
544 new_shape[i] = self.shape[axis];
545 new_strides[i] = self.strides[axis];
546 }
547
548 TensorView {
549 storage: &self.storage,
550 shape: new_shape,
551 strides: new_strides,
552 }
553 }
554
555 pub fn view(&self) -> TensorView<'_, T, N, A> {
563 TensorView {
564 storage: &self.storage,
565 shape: self.shape,
566 strides: self.strides,
567 }
568 }
569
570 pub fn zeros(shape: [usize; N], alloc: A) -> Tensor<T, N, A>
579 where
580 T: Clone + num_traits::Zero,
581 {
582 Self::from_shape_val(shape, T::zero(), alloc)
584 }
585
586 pub fn map<U, F>(&self, f: F) -> Tensor<U, N, A>
608 where
609 F: Fn(&T) -> U,
610 {
611 let data: Vec<U> = self.as_slice().iter().map(f).collect();
612 let storage = TensorStorage::from_vec(data, self.storage.alloc().clone());
613
614 Tensor {
615 storage,
616 shape: self.shape,
617 strides: self.strides,
618 }
619 }
620
621 pub fn is_standard_layout(&self) -> bool {
638 let mut expected_stride: usize = 1;
639 for (&dim, &stride) in self.shape.iter().rev().zip(self.strides.iter().rev()) {
640 if stride != expected_stride {
641 return false;
642 }
643 expected_stride = expected_stride.saturating_mul(dim);
644 }
645 true
646 }
647
648 pub fn to_standard_layout(&self, alloc: A) -> Result<Self, TensorError>
674 where
675 T: Clone + std::fmt::Debug,
676 {
677 if self.is_standard_layout() {
678 return Ok(self.clone());
679 }
680
681 let total_elems: usize = self.shape.iter().product();
682 let mut flat = Vec::with_capacity(total_elems);
683 let mut idx = [0; N];
684 let slice = self.storage.as_slice();
685
686 for _ in 0..total_elems {
687 let offset = idx
688 .iter()
689 .zip(self.strides.iter())
690 .map(|(&i, &s)| i * s)
691 .sum::<usize>();
692
693 flat.push(slice[offset].clone());
694
695 for dim in (0..N).rev() {
697 idx[dim] += 1;
698 if idx[dim] < self.shape[dim] {
699 break;
700 } else {
701 idx[dim] = 0;
702 }
703 }
704 }
705
706 Tensor::from_shape_vec(self.shape, flat, alloc).map_err(|_| {
707 TensorError::DimensionMismatch(format!(
708 "Cannot construct tensor of shape {:?} with {:?} elements",
709 self.shape, total_elems,
710 ))
711 })
712 }
713
714 pub fn cast<U>(&self) -> Tensor<U, N, CpuAllocator>
732 where
733 U: From<T>,
734 T: Clone,
735 {
736 let mut data: Vec<U> = Vec::with_capacity(self.storage.len());
737 self.as_slice().iter().for_each(|x| {
738 data.push(U::from(x.clone()));
739 });
740 let storage = TensorStorage::from_vec(data, CpuAllocator);
741 Tensor {
742 storage,
743 shape: self.shape,
744 strides: self.strides,
745 }
746 }
747
748 pub fn element_wise_op<F>(
783 &self,
784 other: &Tensor<T, N, CpuAllocator>,
785 op: F,
786 ) -> Result<Tensor<T, N, CpuAllocator>, TensorError>
787 where
788 F: Fn(&T, &T) -> T,
789 {
790 if self.shape != other.shape {
791 return Err(TensorError::DimensionMismatch(format!(
792 "Shapes {:?} and {:?} are not compatible for element-wise operations",
793 self.shape, other.shape
794 )));
795 }
796
797 let data = self
798 .as_slice()
799 .iter()
800 .zip(other.as_slice().iter())
801 .map(|(a, b)| op(a, b))
802 .collect();
803
804 let storage = TensorStorage::from_vec(data, CpuAllocator);
805
806 Ok(Tensor {
807 storage,
808 shape: self.shape,
809 strides: self.strides,
810 })
811 }
812}
813
814impl<T, const N: usize, A> Clone for Tensor<T, N, A>
815where
816 T: Clone,
817 A: TensorAllocator + Clone,
818{
819 fn clone(&self) -> Self {
820 Self {
821 storage: self.storage.clone(),
822 shape: self.shape,
823 strides: self.strides,
824 }
825 }
826}
827
impl<T, const N: usize, A> std::fmt::Display for Tensor<T, N, A>
where
    T: std::fmt::Display + std::fmt::LowerExp,
    A: TensorAllocator,
{
    // Pretty-prints the tensor numpy-style with nested brackets: values are
    // fixed-point with 4 decimals, switching to scientific notation when any
    // value formats wider than 8 characters, and any dimension longer than 8
    // entries is elided with "..." after its first three entries.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Column width = widest formatted element.
        // NOTE(review): `.max().unwrap()` panics on an empty tensor — confirm
        // zero-element tensors are never displayed.
        let width = self
            .storage
            .as_slice()
            .iter()
            .map(|v| format!("{v:.4}").len())
            .max()
            .unwrap();

        let scientific = width > 8;

        // Per-dimension flag: dimensions with more than 8 entries get masked.
        let should_mask: [bool; N] = self.shape.map(|s| s > 8);
        let mut skip_until = 0;

        for (i, v) in self.storage.as_slice().iter().enumerate() {
            // Skip elements hidden by a previously emitted "..." mask.
            if i < skip_until {
                continue;
            }
            let mut value = String::new();
            let mut prefix = String::new();
            let mut suffix = String::new();
            let mut separator = ",".to_string();
            let mut last_size = 1;
            // Walk dimensions innermost-first to decide opening/closing
            // brackets and whether to emit an ellipsis for this position.
            for (dim, (&size, maskable)) in self.shape.iter().zip(should_mask).enumerate().rev() {
                let prod = size * last_size;
                // Fourth entry of a masked dimension: print "..." and jump
                // directly to that dimension's last entry.
                if i % prod == (3 * last_size) && maskable {
                    let pad = if dim == (N - 1) { 0 } else { dim + 1 };
                    value = format!("{}...", " ".repeat(pad));
                    skip_until = i + (size - 4) * last_size;
                    prefix = "".to_string();
                    if dim != (N - 1) {
                        separator = "\n".repeat(N - 1 - dim);
                    }
                    break;
                } else if i % prod == 0 {
                    // First element along this dimension: open a bracket.
                    prefix.push('[');
                } else if (i + 1) % prod == 0 {
                    // Last element along this dimension: close a bracket and
                    // break the line (no trailing separator at the very end).
                    suffix.push(']');
                    separator.push('\n');
                    if dim == 0 {
                        separator = "".to_string();
                    }
                } else {
                    break;
                }
                last_size = prod;
            }
            if !prefix.is_empty() {
                // Right-align the run of brackets to rank N for indentation.
                prefix = format!("{prefix:>N$}");
            }

            if value.is_empty() {
                value = if scientific {
                    // Normalize the exponent to sign + two digits,
                    // e.g. "1.0000e+03" / "9.9991e-01".
                    let num = format!("{v:.4e}");
                    let (before, after) = num.split_once('e').unwrap();
                    let after = if let Some(stripped) = after.strip_prefix('-') {
                        format!("-{:0>2}", &stripped)
                    } else {
                        format!("+{:0>2}", &after)
                    };
                    format!("{before}e{after}")
                } else {
                    // Fixed-point, right-aligned to the common column width.
                    let rounded = format!("{v:.4}");
                    format!("{rounded:>width$}")
                }
            };
            write!(f, "{prefix}{value}{suffix}{separator}",)?;
        }
        Ok(())
    }
}
904
#[cfg(test)]
mod tests {
    // Unit tests covering construction, indexing, offset/index round-trips,
    // views (reshape/permute), layout conversion, and Display formatting.
    use crate::allocator::CpuAllocator;
    use crate::tensor::{Tensor, TensorError};

    #[test]
    fn constructor_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1];
        let t = Tensor::<u8, 1, _>::from_shape_vec([1], data, CpuAllocator)?;
        assert_eq!(t.shape, [1]);
        assert_eq!(t.as_slice(), vec![1]);
        assert_eq!(t.strides, [1]);
        assert_eq!(t.numel(), 1);
        Ok(())
    }

    #[test]
    fn constructor_2d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2];
        let t = Tensor::<u8, 2, _>::from_shape_vec([1, 2], data, CpuAllocator)?;
        assert_eq!(t.shape, [1, 2]);
        assert_eq!(t.as_slice(), vec![1, 2]);
        assert_eq!(t.strides, [2, 1]);
        assert_eq!(t.numel(), 2);
        Ok(())
    }

    #[test]
    fn get_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 1, _>::from_shape_vec([4], data, CpuAllocator)?;
        assert_eq!(t.get([0]), Some(&1));
        assert_eq!(t.get([1]), Some(&2));
        assert_eq!(t.get([2]), Some(&3));
        assert_eq!(t.get([3]), Some(&4));
        assert!(t.get([4]).is_none());
        Ok(())
    }

    #[test]
    fn get_2d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 2, _>::from_shape_vec([2, 2], data, CpuAllocator)?;
        assert_eq!(t.get([0, 0]), Some(&1));
        assert_eq!(t.get([0, 1]), Some(&2));
        assert_eq!(t.get([1, 0]), Some(&3));
        assert_eq!(t.get([1, 1]), Some(&4));
        assert!(t.get([2, 0]).is_none());
        assert!(t.get([0, 2]).is_none());
        Ok(())
    }

    #[test]
    fn get_3d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4, 5, 6];
        let t = Tensor::<u8, 3, _>::from_shape_vec([2, 1, 3], data, CpuAllocator)?;
        assert_eq!(t.get([0, 0, 0]), Some(&1));
        assert_eq!(t.get([0, 0, 1]), Some(&2));
        assert_eq!(t.get([0, 0, 2]), Some(&3));
        assert_eq!(t.get([1, 0, 0]), Some(&4));
        assert_eq!(t.get([1, 0, 1]), Some(&5));
        assert_eq!(t.get([1, 0, 2]), Some(&6));
        assert!(t.get([2, 0, 0]).is_none());
        assert!(t.get([0, 1, 0]).is_none());
        assert!(t.get([0, 0, 3]).is_none());
        Ok(())
    }

    #[test]
    fn get_checked_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 1, _>::from_shape_vec([4], data, CpuAllocator)?;
        assert_eq!(*t.get_unchecked([0]), 1);
        assert_eq!(*t.get_unchecked([1]), 2);
        assert_eq!(*t.get_unchecked([2]), 3);
        assert_eq!(*t.get_unchecked([3]), 4);
        Ok(())
    }

    #[test]
    fn get_checked_2d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 2, _>::from_shape_vec([2, 2], data, CpuAllocator)?;
        assert_eq!(*t.get_unchecked([0, 0]), 1);
        assert_eq!(*t.get_unchecked([0, 1]), 2);
        assert_eq!(*t.get_unchecked([1, 0]), 3);
        assert_eq!(*t.get_unchecked([1, 1]), 4);
        Ok(())
    }
    #[test]
    fn reshape_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 1, _>::from_shape_vec([4], data, CpuAllocator)?;

        let view = t.reshape([2, 2])?;

        assert_eq!(view.shape, [2, 2]);
        assert_eq!(view.as_slice(), vec![1, 2, 3, 4]);
        assert_eq!(view.strides, [2, 1]);
        assert_eq!(view.numel(), 4);
        assert_eq!(view.as_contiguous().as_slice(), vec![1, 2, 3, 4]);
        Ok(())
    }

    #[test]
    fn reshape_2d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 2, _>::from_shape_vec([2, 2], data, CpuAllocator)?;
        let t2 = t.reshape([4])?;

        assert_eq!(t2.shape, [4]);
        assert_eq!(t2.as_slice(), vec![1, 2, 3, 4]);
        assert_eq!(t2.strides, [1]);
        assert_eq!(t2.numel(), 4);
        assert_eq!(t2.as_contiguous().as_slice(), vec![1, 2, 3, 4]);
        Ok(())
    }

    #[test]
    fn reshape_get_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 1, _>::from_shape_vec([4], data, CpuAllocator)?;
        let view = t.reshape([2, 2])?;
        assert_eq!(*view.get_unchecked([0, 0]), 1);
        assert_eq!(*view.get_unchecked([0, 1]), 2);
        assert_eq!(*view.get_unchecked([1, 0]), 3);
        assert_eq!(*view.get_unchecked([1, 1]), 4);
        assert_eq!(view.numel(), 4);
        assert_eq!(view.as_contiguous().as_slice(), vec![1, 2, 3, 4]);
        Ok(())
    }

    #[test]
    fn permute_axes_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 1, _>::from_shape_vec([4], data, CpuAllocator)?;
        let t2 = t.permute_axes([0]);
        assert_eq!(t2.shape, [4]);
        assert_eq!(t2.as_slice(), vec![1, 2, 3, 4]);
        assert_eq!(t2.strides, [1]);
        assert_eq!(t2.as_contiguous().as_slice(), vec![1, 2, 3, 4]);
        Ok(())
    }

    #[test]
    fn permute_axes_2d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 2, _>::from_shape_vec([2, 2], data, CpuAllocator)?;
        let view = t.permute_axes([1, 0]);
        assert_eq!(view.shape, [2, 2]);
        assert_eq!(*view.get_unchecked([0, 0]), 1u8);
        assert_eq!(*view.get_unchecked([1, 0]), 2u8);
        assert_eq!(*view.get_unchecked([0, 1]), 3u8);
        assert_eq!(*view.get_unchecked([1, 1]), 4u8);
        assert_eq!(view.strides, [1, 2]);
        assert_eq!(view.as_contiguous().as_slice(), vec![1, 3, 2, 4]);
        Ok(())
    }

    #[test]
    fn contiguous_2d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4, 5, 6];
        let t = Tensor::<u8, 2, _>::from_shape_vec([2, 3], data, CpuAllocator)?;

        // Transposing makes the view non-contiguous; as_contiguous must copy.
        let view = t.permute_axes([1, 0]);

        let contiguous = view.as_contiguous();

        assert_eq!(contiguous.shape, [3, 2]);
        assert_eq!(contiguous.strides, [2, 1]);
        assert_eq!(contiguous.as_slice(), vec![1, 4, 2, 5, 3, 6]);

        Ok(())
    }

    #[test]
    fn zeros_1d() -> Result<(), TensorError> {
        let t = Tensor::<u8, 1, _>::zeros([4], CpuAllocator);
        assert_eq!(t.as_slice(), vec![0, 0, 0, 0]);
        Ok(())
    }

    #[test]
    fn zeros_2d() -> Result<(), TensorError> {
        let t = Tensor::<u8, 2, _>::zeros([2, 2], CpuAllocator);
        assert_eq!(t.as_slice(), vec![0, 0, 0, 0]);
        Ok(())
    }

    #[test]
    fn map_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 1, _>::from_shape_vec([4], data, CpuAllocator)?;
        let t2 = t.map(|x| *x + 1);
        assert_eq!(t2.as_slice(), vec![2, 3, 4, 5]);
        Ok(())
    }

    #[test]
    fn map_2d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 2, _>::from_shape_vec([2, 2], data, CpuAllocator)?;
        let t2 = t.map(|x| *x + 1);
        assert_eq!(t2.as_slice(), vec![2, 3, 4, 5]);
        Ok(())
    }

    #[test]
    fn from_shape_val_1d() -> Result<(), TensorError> {
        let t = Tensor::<u8, 1, _>::from_shape_val([4], 0, CpuAllocator);
        assert_eq!(t.as_slice(), vec![0, 0, 0, 0]);
        Ok(())
    }

    #[test]
    fn from_shape_val_2d() -> Result<(), TensorError> {
        let t = Tensor::<u8, 2, _>::from_shape_val([2, 2], 1, CpuAllocator);
        assert_eq!(t.as_slice(), vec![1, 1, 1, 1]);
        Ok(())
    }

    #[test]
    fn from_shape_val_3d() -> Result<(), TensorError> {
        let t = Tensor::<u8, 3, _>::from_shape_val([2, 1, 3], 2, CpuAllocator);
        assert_eq!(t.as_slice(), vec![2, 2, 2, 2, 2, 2]);
        Ok(())
    }

    #[test]
    fn cast_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 1, _>::from_shape_vec([4], data, CpuAllocator)?;
        let t2 = t.cast::<u16>();
        assert_eq!(t2.as_slice(), vec![1, 2, 3, 4]);
        Ok(())
    }

    #[test]
    fn cast_2d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 2, _>::from_shape_vec([2, 2], data, CpuAllocator)?;
        let t2 = t.cast::<u16>();
        assert_eq!(t2.as_slice(), vec![1, 2, 3, 4]);
        Ok(())
    }

    #[test]
    fn from_shape_fn_1d() -> Result<(), TensorError> {
        let alloc = CpuAllocator;
        let t = Tensor::from_shape_fn([3, 3], alloc, |[i, j]| ((1 + i) * (1 + j)) as u8);
        assert_eq!(t.as_slice(), vec![1, 2, 3, 2, 4, 6, 3, 6, 9]);
        Ok(())
    }

    #[test]
    fn from_shape_fn_2d() -> Result<(), TensorError> {
        let alloc = CpuAllocator;
        let t = Tensor::from_shape_fn([3, 3], alloc, |[i, j]| ((1 + i) * (1 + j)) as f32);
        assert_eq!(
            t.as_slice(),
            vec![1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0]
        );
        Ok(())
    }

    #[test]
    fn from_shape_fn_3d() -> Result<(), TensorError> {
        let alloc = CpuAllocator;
        let t = Tensor::from_shape_fn([2, 3, 3], alloc, |[x, y, c]| {
            ((1 + x) * (1 + y) * (1 + c)) as i16
        });
        assert_eq!(
            t.as_slice(),
            vec![1, 2, 3, 2, 4, 6, 3, 6, 9, 2, 4, 6, 4, 8, 12, 6, 12, 18]
        );
        Ok(())
    }

    #[test]
    fn view_1d() -> Result<(), TensorError> {
        let alloc = CpuAllocator;
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 1, _>::from_shape_vec([4], data, alloc)?;
        let view = t.view();

        assert_eq!(view.as_slice(), t.as_slice());

        // The view must borrow the same buffer, not copy it.
        assert!(std::ptr::eq(view.as_ptr(), t.as_ptr()));

        Ok(())
    }

    #[test]
    fn from_slice() -> Result<(), TensorError> {
        let data: [u8; 4] = [1, 2, 3, 4];
        let t = Tensor::<u8, 2, _>::from_shape_slice([2, 2], &data, CpuAllocator)?;

        assert_eq!(t.shape, [2, 2]);
        assert_eq!(t.as_slice(), &[1, 2, 3, 4]);

        Ok(())
    }

    #[test]
    fn display_2d() -> Result<(), TensorError> {
        let data: [u8; 4] = [1, 2, 3, 4];
        let t = Tensor::<u8, 2, _>::from_shape_slice([2, 2], &data, CpuAllocator)?;
        let disp = t.to_string();
        let lines = disp.lines().collect::<Vec<_>>();

        #[rustfmt::skip]
        assert_eq!(lines.as_slice(),
            ["[[1,2],",
             " [3,4]]"]);
        Ok(())
    }

    #[test]
    fn display_3d() -> Result<(), TensorError> {
        let data: [u8; 12] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let t = Tensor::<u8, 3, _>::from_shape_slice([2, 3, 2], &data, CpuAllocator)?;
        let disp = t.to_string();
        let lines = disp.lines().collect::<Vec<_>>();

        #[rustfmt::skip]
        assert_eq!(lines.as_slice(),
            ["[[[ 1, 2],",
             "  [ 3, 4],",
             "  [ 5, 6]],",
             "",
             " [[ 7, 8],",
             "  [ 9,10],",
             "  [11,12]]]"]);
        Ok(())
    }

    #[test]
    fn display_float() -> Result<(), TensorError> {
        let data: [f32; 4] = [1.00001, 1.00009, 0.99991, 0.99999];
        let t = Tensor::<f32, 2, _>::from_shape_slice([2, 2], &data, CpuAllocator)?;
        let disp = t.to_string();
        let lines = disp.lines().collect::<Vec<_>>();

        #[rustfmt::skip]
        assert_eq!(lines.as_slice(),
            ["[[1.0000,1.0001],",
             " [0.9999,1.0000]]"]);
        Ok(())
    }

    #[test]
    fn display_big_float() -> Result<(), TensorError> {
        // Values wider than 8 characters force scientific notation.
        let data: [f32; 4] = [1000.00001, 1.00009, 0.99991, 0.99999];
        let t = Tensor::<f32, 2, _>::from_shape_slice([2, 2], &data, CpuAllocator)?;
        let disp = t.to_string();
        let lines = disp.lines().collect::<Vec<_>>();

        #[rustfmt::skip]
        assert_eq!(lines.as_slice(),
            ["[[1.0000e+03,1.0001e+00],",
             " [9.9991e-01,9.9999e-01]]"]);
        Ok(())
    }

    #[test]
    fn display_big_tensor() -> Result<(), TensorError> {
        // Dimensions longer than 8 entries are elided with "...".
        let data: [u8; 1000] = [0; 1000];
        let t = Tensor::<u8, 3, _>::from_shape_slice([10, 10, 10], &data, CpuAllocator)?;
        let disp = t.to_string();
        let lines = disp.lines().collect::<Vec<_>>();

        #[rustfmt::skip]
        assert_eq!(lines.as_slice(),
            ["[[[0,0,0,...,0],",
             "  [0,0,0,...,0],",
             "  [0,0,0,...,0],",
             "  ...",
             "  [0,0,0,...,0]],",
             "",
             " [[0,0,0,...,0],",
             "  [0,0,0,...,0],",
             "  [0,0,0,...,0],",
             "  ...",
             "  [0,0,0,...,0]],",
             "",
             " [[0,0,0,...,0],",
             "  [0,0,0,...,0],",
             "  [0,0,0,...,0],",
             "  ...",
             "  [0,0,0,...,0]],",
             "",
             " ...",
             "",
             " [[0,0,0,...,0],",
             "  [0,0,0,...,0],",
             "  [0,0,0,...,0],",
             "  ...",
             "  [0,0,0,...,0]]]"]);
        Ok(())
    }

    #[test]
    fn get_index_unchecked_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 1, CpuAllocator>::from_shape_vec([4], data, CpuAllocator)?;
        assert_eq!(t.get_index_unchecked(0), [0]);
        assert_eq!(t.get_index_unchecked(1), [1]);
        assert_eq!(t.get_index_unchecked(2), [2]);
        assert_eq!(t.get_index_unchecked(3), [3]);
        Ok(())
    }

    #[test]
    fn get_index_unchecked_2d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 2, CpuAllocator>::from_shape_vec([2, 2], data, CpuAllocator)?;
        assert_eq!(t.get_index_unchecked(0), [0, 0]);
        assert_eq!(t.get_index_unchecked(1), [0, 1]);
        assert_eq!(t.get_index_unchecked(2), [1, 0]);
        assert_eq!(t.get_index_unchecked(3), [1, 1]);
        Ok(())
    }

    #[test]
    fn get_index_unchecked_3d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let t = Tensor::<u8, 3, CpuAllocator>::from_shape_vec([2, 2, 3], data, CpuAllocator)?;
        assert_eq!(t.get_index_unchecked(0), [0, 0, 0]);
        assert_eq!(t.get_index_unchecked(1), [0, 0, 1]);
        assert_eq!(t.get_index_unchecked(2), [0, 0, 2]);
        assert_eq!(t.get_index_unchecked(3), [0, 1, 0]);
        assert_eq!(t.get_index_unchecked(4), [0, 1, 1]);
        assert_eq!(t.get_index_unchecked(5), [0, 1, 2]);
        assert_eq!(t.get_index_unchecked(6), [1, 0, 0]);
        assert_eq!(t.get_index_unchecked(7), [1, 0, 1]);
        assert_eq!(t.get_index_unchecked(8), [1, 0, 2]);
        assert_eq!(t.get_index_unchecked(9), [1, 1, 0]);
        assert_eq!(t.get_index_unchecked(10), [1, 1, 1]);
        assert_eq!(t.get_index_unchecked(11), [1, 1, 2]);
        Ok(())
    }

    #[test]
    fn get_index_to_offset_and_back() -> Result<(), TensorError> {
        // offset -> index -> offset must round-trip for standard layouts.
        let data: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let t = Tensor::<u8, 3, CpuAllocator>::from_shape_vec([2, 2, 3], data, CpuAllocator)?;
        for offset in 0..12 {
            assert_eq!(
                t.get_iter_offset_unchecked(t.get_index_unchecked(offset)),
                offset
            );
        }
        Ok(())
    }

    #[test]
    fn get_offset_to_index_and_back() -> Result<(), TensorError> {
        // index -> offset -> index must round-trip for standard layouts.
        let data: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let t = Tensor::<u8, 3, CpuAllocator>::from_shape_vec([2, 2, 3], data, CpuAllocator)?;
        for ind in [
            [0, 0, 0],
            [0, 0, 1],
            [0, 0, 2],
            [0, 1, 0],
            [0, 1, 1],
            [0, 1, 2],
            [1, 0, 0],
            [1, 0, 1],
            [1, 0, 2],
            [1, 1, 0],
            [1, 1, 1],
            [1, 1, 2],
        ] {
            assert_eq!(t.get_index_unchecked(t.get_iter_offset_unchecked(ind)), ind);
        }
        Ok(())
    }

    #[test]
    fn get_index_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 1, CpuAllocator>::from_shape_vec([4], data, CpuAllocator)?;
        assert_eq!(t.get_index(3), Ok([3]));
        assert!(t
            .get_index(4)
            .is_err_and(|x| x == TensorError::IndexOutOfBounds(4)));
        Ok(())
    }

    #[test]
    fn get_index_2d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 2, CpuAllocator>::from_shape_vec([2, 2], data, CpuAllocator)?;
        assert_eq!(t.get_index_unchecked(3), [1, 1]);
        assert!(t
            .get_index(4)
            .is_err_and(|x| x == TensorError::IndexOutOfBounds(4)));
        Ok(())
    }

    #[test]
    fn get_index_3d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let t = Tensor::<u8, 3, CpuAllocator>::from_shape_vec([2, 2, 3], data, CpuAllocator)?;
        assert_eq!(t.get_index_unchecked(11), [1, 1, 2]);
        assert!(t
            .get_index(12)
            .is_err_and(|x| x == TensorError::IndexOutOfBounds(12)));
        Ok(())
    }

    #[test]
    fn from_raw_parts() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = unsafe { Tensor::from_raw_parts([2, 2], data.as_ptr(), data.len(), CpuAllocator)? };
        // The tensor took ownership of the buffer; prevent a double free.
        std::mem::forget(data);
        assert_eq!(t.shape, [2, 2]);
        assert_eq!(t.as_slice(), &[1, 2, 3, 4]);
        Ok(())
    }

    #[test]
    fn contiguous_tensor_is_standard_layout_true() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let t = Tensor::<u8, 3, CpuAllocator>::from_shape_vec([2, 2, 3], data, CpuAllocator)?;
        assert!(t.is_standard_layout());
        Ok(())
    }

    #[test]
    fn broken_stride_is_standard_layout_false() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let mut t = Tensor::<u8, 3, CpuAllocator>::from_shape_vec([2, 2, 3], data, CpuAllocator)?;
        t.strides = [10, 5, 1];
        assert!(!t.is_standard_layout());
        Ok(())
    }

    #[test]
    fn contiguous_tensor_roundtrip() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let t =
            Tensor::<u8, 3, CpuAllocator>::from_shape_vec([2, 2, 3], data.clone(), CpuAllocator)?;
        assert!(t.is_standard_layout());
        match t.to_standard_layout(CpuAllocator) {
            Ok(t2) => {
                assert!(t2.is_standard_layout());
                assert_eq!(t2.storage.as_slice(), data.as_slice());
            }
            Err(e) => return Err(e),
        }
        Ok(())
    }

    #[test]
    fn non_contiguous_to_standard_layout() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let mut t =
            Tensor::<u8, 3, CpuAllocator>::from_shape_vec([2, 2, 3], data.clone(), CpuAllocator)?;
        t.strides = [1, 6, 2];
        assert!(!t.is_standard_layout());
        match t.to_standard_layout(CpuAllocator) {
            Ok(t2) => {
                assert!(t2.is_standard_layout());
            }
            Err(e) => return Err(e),
        }
        Ok(())
    }
}