1use thiserror::Error;
2
3use super::{
4 allocator::{CpuAllocator, TensorAllocator, TensorAllocatorError},
5 storage::TensorStorage,
6 view::TensorView,
7};
8
9#[derive(Error, Debug, PartialEq)]
11pub enum TensorError {
12 #[error("Failed to cast data")]
14 CastError,
15
16 #[error("The number of elements in the data does not match the shape of the tensor: {0}")]
18 InvalidShape(usize),
19
20 #[error("Index out of bounds. The index {0} is out of bounds.")]
22 IndexOutOfBounds(usize),
23
24 #[error("Error with the tensor storage: {0}")]
26 StorageError(#[from] TensorAllocatorError),
27
28 #[error("Dimension mismatch: {0}")]
30 DimensionMismatch(String),
31
32 #[error("Unsupported operation: {0}")]
34 UnsupportedOperation(String),
35}
36
/// Computes contiguous row-major strides (in elements) for `shape`.
///
/// The last dimension gets stride 1. Every earlier dimension's stride is the
/// product of all later extents, e.g. `[2, 3, 4]` gives `[12, 4, 1]`.
pub fn get_strides_from_shape<const N: usize>(shape: [usize; N]) -> [usize; N] {
    let mut strides = [0usize; N];
    let mut running = 1usize;
    // Walk from the innermost dimension outwards, accumulating the extent product.
    for (slot, &extent) in strides.iter_mut().zip(shape.iter()).rev() {
        *slot = running;
        running *= extent;
    }
    strides
}
55
56pub struct Tensor<T, const N: usize, A: TensorAllocator> {
77 pub storage: TensorStorage<T, A>,
79 pub shape: [usize; N],
81 pub strides: [usize; N],
83}
84
85impl<T, const N: usize, A: TensorAllocator> Tensor<T, N, A>
86where
87 A: 'static,
88{
89 #[inline]
95 pub fn as_slice(&self) -> &[T] {
96 self.storage.as_slice()
97 }
98
99 #[inline]
105 pub fn as_slice_mut(&mut self) -> &mut [T] {
106 self.storage.as_mut_slice()
107 }
108
109 #[inline]
115 pub fn as_ptr(&self) -> *const T {
116 self.storage.as_ptr()
117 }
118
119 #[inline]
125 pub fn as_mut_ptr(&mut self) -> *mut T {
126 self.storage.as_mut_ptr()
127 }
128
129 #[inline]
135 pub fn into_vec(self) -> Vec<T> {
136 self.storage.into_vec()
137 }
138
139 pub fn from_shape_vec(shape: [usize; N], data: Vec<T>, alloc: A) -> Result<Self, TensorError> {
165 let numel = shape.iter().product::<usize>();
166 if numel != data.len() {
167 return Err(TensorError::InvalidShape(numel));
168 }
169 let storage = TensorStorage::from_vec(data, alloc);
170 let strides = get_strides_from_shape(shape);
171 Ok(Self {
172 storage,
173 shape,
174 strides,
175 })
176 }
177
178 pub fn from_shape_slice(shape: [usize; N], data: &[T], alloc: A) -> Result<Self, TensorError>
194 where
195 T: Clone,
196 {
197 let numel = shape.iter().product::<usize>();
198 if numel != data.len() {
199 return Err(TensorError::InvalidShape(numel));
200 }
201 let storage = TensorStorage::from_vec(data.to_vec(), alloc);
202 let strides = get_strides_from_shape(shape);
203 Ok(Self {
204 storage,
205 shape,
206 strides,
207 })
208 }
209
210 pub unsafe fn from_raw_parts(
223 shape: [usize; N],
224 data: *const T,
225 len: usize,
226 alloc: A,
227 ) -> Result<Self, TensorError>
228 where
229 T: Clone,
230 {
231 let storage = TensorStorage::from_raw_parts(data, len, alloc);
232 let strides = get_strides_from_shape(shape);
233 Ok(Self {
234 storage,
235 shape,
236 strides,
237 })
238 }
239
240 pub fn from_shape_val(shape: [usize; N], value: T, alloc: A) -> Self
267 where
268 T: Clone,
269 {
270 let numel = shape.iter().product::<usize>();
271 let data = vec![value; numel];
272 let storage = TensorStorage::from_vec(data, alloc);
273 let strides = get_strides_from_shape(shape);
274 Self {
275 storage,
276 shape,
277 strides,
278 }
279 }
280
281 pub fn from_shape_fn<F>(shape: [usize; N], alloc: A, f: F) -> Self
306 where
307 F: Fn([usize; N]) -> T,
308 {
309 let numel = shape.iter().product::<usize>();
310 let data: Vec<T> = (0..numel)
311 .map(|i| {
312 let mut index = [0; N];
313 let mut j = i;
314 for k in (0..N).rev() {
315 index[k] = j % shape[k];
316 j /= shape[k];
317 }
318 f(index)
319 })
320 .collect();
321 let storage = TensorStorage::from_vec(data, alloc);
322 let strides = get_strides_from_shape(shape);
323 Self {
324 storage,
325 shape,
326 strides,
327 }
328 }
329
330 #[inline]
336 pub fn numel(&self) -> usize {
337 self.storage.len() / std::mem::size_of::<T>()
338 }
339
340 pub fn get_iter_offset(&self, index: [usize; N]) -> Option<usize> {
350 let mut offset = 0;
351 for ((&idx, dim_size), stride) in index.iter().zip(self.shape).zip(self.strides) {
352 if idx >= dim_size {
353 return None;
354 }
355 offset += idx * stride;
356 }
357 Some(offset)
358 }
359
360 pub fn get_iter_offset_unchecked(&self, index: [usize; N]) -> usize {
370 let mut offset = 0;
371 for (&idx, stride) in index.iter().zip(self.strides) {
372 offset += idx * stride;
373 }
374 offset
375 }
376
377 pub fn get_index_unchecked(&self, offset: usize) -> [usize; N] {
387 let mut idx = [0; N];
388 let mut rem = offset;
389 for (dim_i, s) in self.strides.iter().enumerate() {
390 idx[dim_i] = rem / s;
391 rem = offset % s;
392 }
393
394 idx
395 }
396
397 pub fn get_index(&self, offset: usize) -> Result<[usize; N], TensorError> {
411 if offset >= self.numel() {
412 return Err(TensorError::IndexOutOfBounds(offset));
413 }
414 let idx = self.get_index_unchecked(offset);
415
416 Ok(idx)
417 }
418
419 pub fn get_unchecked(&self, index: [usize; N]) -> &T {
443 let offset = self.get_iter_offset_unchecked(index);
444 unsafe { self.storage.as_slice().get_unchecked(offset) }
445 }
446
447 pub fn get(&self, index: [usize; N]) -> Option<&T> {
478 self.get_iter_offset(index)
479 .and_then(|i| self.storage.as_slice().get(i))
480 }
481
482 pub fn reshape<const M: usize>(
511 &self,
512 shape: [usize; M],
513 ) -> Result<TensorView<T, M, A>, TensorError> {
514 let numel = shape.iter().product::<usize>();
515 if numel != self.storage.len() {
516 return Err(TensorError::DimensionMismatch(format!(
517 "Cannot reshape tensor of shape {:?} with {} elements to shape {:?} with {} elements",
518 self.shape, self.storage.len(), shape, numel
519 )));
520 }
521
522 let strides = get_strides_from_shape(shape);
523
524 Ok(TensorView {
525 storage: &self.storage,
526 shape,
527 strides,
528 })
529 }
530
531 pub fn permute_axes(&self, axes: [usize; N]) -> TensorView<T, N, A> {
544 let mut new_shape = [0; N];
545 let mut new_strides = [0; N];
546 for (i, &axis) in axes.iter().enumerate() {
547 new_shape[i] = self.shape[axis];
548 new_strides[i] = self.strides[axis];
549 }
550
551 TensorView {
552 storage: &self.storage,
553 shape: new_shape,
554 strides: new_strides,
555 }
556 }
557
558 pub fn view(&self) -> TensorView<T, N, A> {
566 TensorView {
567 storage: &self.storage,
568 shape: self.shape,
569 strides: self.strides,
570 }
571 }
572
573 pub fn zeros(shape: [usize; N], alloc: A) -> Tensor<T, N, A>
582 where
583 T: Clone + num_traits::Zero,
584 {
585 Self::from_shape_val(shape, T::zero(), alloc)
587 }
588
589 pub fn map<U, F>(&self, f: F) -> Tensor<U, N, A>
611 where
612 F: Fn(&T) -> U,
613 {
614 let data: Vec<U> = self.as_slice().iter().map(f).collect();
615 let storage = TensorStorage::from_vec(data, self.storage.alloc().clone());
616
617 Tensor {
618 storage,
619 shape: self.shape,
620 strides: self.strides,
621 }
622 }
623
624 pub fn cast<U>(&self) -> Tensor<U, N, CpuAllocator>
642 where
643 U: From<T>,
644 T: Clone,
645 {
646 let mut data: Vec<U> = Vec::with_capacity(self.storage.len());
647 self.as_slice().iter().for_each(|x| {
648 data.push(U::from(x.clone()));
649 });
650 let storage = TensorStorage::from_vec(data, CpuAllocator);
651 Tensor {
652 storage,
653 shape: self.shape,
654 strides: self.strides,
655 }
656 }
657
658 pub fn element_wise_op<F>(
693 &self,
694 other: &Tensor<T, N, CpuAllocator>,
695 op: F,
696 ) -> Result<Tensor<T, N, CpuAllocator>, TensorError>
697 where
698 F: Fn(&T, &T) -> T,
699 {
700 if self.shape != other.shape {
701 return Err(TensorError::DimensionMismatch(format!(
702 "Shapes {:?} and {:?} are not compatible for element-wise operations",
703 self.shape, other.shape
704 )));
705 }
706
707 let data = self
708 .as_slice()
709 .iter()
710 .zip(other.as_slice().iter())
711 .map(|(a, b)| op(a, b))
712 .collect();
713
714 let storage = TensorStorage::from_vec(data, CpuAllocator);
715
716 Ok(Tensor {
717 storage,
718 shape: self.shape,
719 strides: self.strides,
720 })
721 }
722}
723
724impl<T, const N: usize, A> Clone for Tensor<T, N, A>
725where
726 T: Clone,
727 A: TensorAllocator + Clone + 'static,
728{
729 fn clone(&self) -> Self {
730 Self {
731 storage: self.storage.clone(),
732 shape: self.shape,
733 strides: self.strides,
734 }
735 }
736}
737
738impl<T, const N: usize, A> std::fmt::Display for Tensor<T, N, A>
739where
740 T: std::fmt::Display + std::fmt::LowerExp,
741 A: TensorAllocator + 'static,
742{
743 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
744 let width = self
745 .storage
746 .as_slice()
747 .iter()
748 .map(|v| format!("{v:.4}").len())
749 .max()
750 .unwrap();
751
752 let scientific = width > 8;
753
754 let should_mask: [bool; N] = self.shape.map(|s| s > 8);
755 let mut skip_until = 0;
756
757 for (i, v) in self.storage.as_slice().iter().enumerate() {
758 if i < skip_until {
759 continue;
760 }
761 let mut value = String::new();
762 let mut prefix = String::new();
763 let mut suffix = String::new();
764 let mut separator = ",".to_string();
765 let mut last_size = 1;
766 for (dim, (&size, maskable)) in self.shape.iter().zip(should_mask).enumerate().rev() {
767 let prod = size * last_size;
768 if i % prod == (3 * last_size) && maskable {
769 let pad = if dim == (N - 1) { 0 } else { dim + 1 };
770 value = format!("{}...", " ".repeat(pad));
771 skip_until = i + (size - 4) * last_size;
772 prefix = "".to_string();
773 if dim != (N - 1) {
774 separator = "\n".repeat(N - 1 - dim);
775 }
776 break;
777 } else if i % prod == 0 {
778 prefix.push('[');
779 } else if (i + 1) % prod == 0 {
780 suffix.push(']');
781 separator.push('\n');
782 if dim == 0 {
783 separator = "".to_string();
784 }
785 } else {
786 break;
787 }
788 last_size = prod;
789 }
790 if !prefix.is_empty() {
791 prefix = format!("{prefix:>N$}");
792 }
793
794 if value.is_empty() {
795 value = if scientific {
796 let num = format!("{v:.4e}");
797 let (before, after) = num.split_once('e').unwrap();
798 let after = if let Some(stripped) = after.strip_prefix('-') {
799 format!("-{:0>2}", &stripped)
800 } else {
801 format!("+{:0>2}", &after)
802 };
803 format!("{before}e{after}")
804 } else {
805 let rounded = format!("{v:.4}");
806 format!("{rounded:>width$}")
807 }
808 };
809 write!(f, "{prefix}{value}{suffix}{separator}",)?;
810 }
811 Ok(())
812 }
813}
814
#[cfg(test)]
mod tests {
    use crate::allocator::CpuAllocator;
    use crate::tensor::{Tensor, TensorError};

    #[test]
    fn constructor_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1];
        let t = Tensor::<u8, 1, _>::from_shape_vec([1], data, CpuAllocator)?;
        assert_eq!(t.shape, [1]);
        assert_eq!(t.as_slice(), vec![1]);
        assert_eq!(t.strides, [1]);
        assert_eq!(t.numel(), 1);
        Ok(())
    }

    #[test]
    fn constructor_2d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2];
        let t = Tensor::<u8, 2, _>::from_shape_vec([1, 2], data, CpuAllocator)?;
        assert_eq!(t.shape, [1, 2]);
        assert_eq!(t.as_slice(), vec![1, 2]);
        assert_eq!(t.strides, [2, 1]);
        assert_eq!(t.numel(), 2);
        Ok(())
    }

    #[test]
    fn get_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 1, _>::from_shape_vec([4], data, CpuAllocator)?;
        assert_eq!(t.get([0]), Some(&1));
        assert_eq!(t.get([1]), Some(&2));
        assert_eq!(t.get([2]), Some(&3));
        assert_eq!(t.get([3]), Some(&4));
        assert!(t.get([4]).is_none());
        Ok(())
    }

    #[test]
    fn get_2d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 2, _>::from_shape_vec([2, 2], data, CpuAllocator)?;
        assert_eq!(t.get([0, 0]), Some(&1));
        assert_eq!(t.get([0, 1]), Some(&2));
        assert_eq!(t.get([1, 0]), Some(&3));
        assert_eq!(t.get([1, 1]), Some(&4));
        assert!(t.get([2, 0]).is_none());
        assert!(t.get([0, 2]).is_none());
        Ok(())
    }

    #[test]
    fn get_3d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4, 5, 6];
        let t = Tensor::<u8, 3, _>::from_shape_vec([2, 1, 3], data, CpuAllocator)?;
        assert_eq!(t.get([0, 0, 0]), Some(&1));
        assert_eq!(t.get([0, 0, 1]), Some(&2));
        assert_eq!(t.get([0, 0, 2]), Some(&3));
        assert_eq!(t.get([1, 0, 0]), Some(&4));
        assert_eq!(t.get([1, 0, 1]), Some(&5));
        assert_eq!(t.get([1, 0, 2]), Some(&6));
        assert!(t.get([2, 0, 0]).is_none());
        assert!(t.get([0, 1, 0]).is_none());
        assert!(t.get([0, 0, 3]).is_none());
        Ok(())
    }

    #[test]
    fn get_unchecked_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 1, _>::from_shape_vec([4], data, CpuAllocator)?;
        assert_eq!(*t.get_unchecked([0]), 1);
        assert_eq!(*t.get_unchecked([1]), 2);
        assert_eq!(*t.get_unchecked([2]), 3);
        assert_eq!(*t.get_unchecked([3]), 4);
        Ok(())
    }

    #[test]
    fn get_unchecked_2d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 2, _>::from_shape_vec([2, 2], data, CpuAllocator)?;
        assert_eq!(*t.get_unchecked([0, 0]), 1);
        assert_eq!(*t.get_unchecked([0, 1]), 2);
        assert_eq!(*t.get_unchecked([1, 0]), 3);
        assert_eq!(*t.get_unchecked([1, 1]), 4);
        Ok(())
    }

    #[test]
    fn reshape_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 1, _>::from_shape_vec([4], data, CpuAllocator)?;

        let view = t.reshape([2, 2])?;

        assert_eq!(view.shape, [2, 2]);
        assert_eq!(view.as_slice(), vec![1, 2, 3, 4]);
        assert_eq!(view.strides, [2, 1]);
        assert_eq!(view.numel(), 4);
        assert_eq!(view.as_contiguous().as_slice(), vec![1, 2, 3, 4]);
        Ok(())
    }

    #[test]
    fn reshape_2d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 2, _>::from_shape_vec([2, 2], data, CpuAllocator)?;
        let t2 = t.reshape([4])?;

        assert_eq!(t2.shape, [4]);
        assert_eq!(t2.as_slice(), vec![1, 2, 3, 4]);
        assert_eq!(t2.strides, [1]);
        assert_eq!(t2.numel(), 4);
        assert_eq!(t2.as_contiguous().as_slice(), vec![1, 2, 3, 4]);
        Ok(())
    }

    #[test]
    fn reshape_get_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 1, _>::from_shape_vec([4], data, CpuAllocator)?;
        let view = t.reshape([2, 2])?;
        assert_eq!(*view.get_unchecked([0, 0]), 1);
        assert_eq!(*view.get_unchecked([0, 1]), 2);
        assert_eq!(*view.get_unchecked([1, 0]), 3);
        assert_eq!(*view.get_unchecked([1, 1]), 4);
        assert_eq!(view.numel(), 4);
        assert_eq!(view.as_contiguous().as_slice(), vec![1, 2, 3, 4]);
        Ok(())
    }

    #[test]
    fn permute_axes_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 1, _>::from_shape_vec([4], data, CpuAllocator)?;
        let t2 = t.permute_axes([0]);
        assert_eq!(t2.shape, [4]);
        assert_eq!(t2.as_slice(), vec![1, 2, 3, 4]);
        assert_eq!(t2.strides, [1]);
        assert_eq!(t2.as_contiguous().as_slice(), vec![1, 2, 3, 4]);
        Ok(())
    }

    #[test]
    fn permute_axes_2d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 2, _>::from_shape_vec([2, 2], data, CpuAllocator)?;
        let view = t.permute_axes([1, 0]);
        assert_eq!(view.shape, [2, 2]);
        assert_eq!(*view.get_unchecked([0, 0]), 1u8);
        assert_eq!(*view.get_unchecked([1, 0]), 2u8);
        assert_eq!(*view.get_unchecked([0, 1]), 3u8);
        assert_eq!(*view.get_unchecked([1, 1]), 4u8);
        assert_eq!(view.strides, [1, 2]);
        assert_eq!(view.as_contiguous().as_slice(), vec![1, 3, 2, 4]);
        Ok(())
    }

    #[test]
    fn contiguous_2d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4, 5, 6];
        let t = Tensor::<u8, 2, _>::from_shape_vec([2, 3], data, CpuAllocator)?;

        let view = t.permute_axes([1, 0]);

        let contiguous = view.as_contiguous();

        assert_eq!(contiguous.shape, [3, 2]);
        assert_eq!(contiguous.strides, [2, 1]);
        assert_eq!(contiguous.as_slice(), vec![1, 4, 2, 5, 3, 6]);

        Ok(())
    }

    #[test]
    fn zeros_1d() -> Result<(), TensorError> {
        let t = Tensor::<u8, 1, _>::zeros([4], CpuAllocator);
        assert_eq!(t.as_slice(), vec![0, 0, 0, 0]);
        Ok(())
    }

    #[test]
    fn zeros_2d() -> Result<(), TensorError> {
        let t = Tensor::<u8, 2, _>::zeros([2, 2], CpuAllocator);
        assert_eq!(t.as_slice(), vec![0, 0, 0, 0]);
        Ok(())
    }

    #[test]
    fn map_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 1, _>::from_shape_vec([4], data, CpuAllocator)?;
        let t2 = t.map(|x| *x + 1);
        assert_eq!(t2.as_slice(), vec![2, 3, 4, 5]);
        Ok(())
    }

    #[test]
    fn map_2d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 2, _>::from_shape_vec([2, 2], data, CpuAllocator)?;
        let t2 = t.map(|x| *x + 1);
        assert_eq!(t2.as_slice(), vec![2, 3, 4, 5]);
        Ok(())
    }

    #[test]
    fn from_shape_val_1d() -> Result<(), TensorError> {
        let t = Tensor::<u8, 1, _>::from_shape_val([4], 0, CpuAllocator);
        assert_eq!(t.as_slice(), vec![0, 0, 0, 0]);
        Ok(())
    }

    #[test]
    fn from_shape_val_2d() -> Result<(), TensorError> {
        let t = Tensor::<u8, 2, _>::from_shape_val([2, 2], 1, CpuAllocator);
        assert_eq!(t.as_slice(), vec![1, 1, 1, 1]);
        Ok(())
    }

    #[test]
    fn from_shape_val_3d() -> Result<(), TensorError> {
        let t = Tensor::<u8, 3, _>::from_shape_val([2, 1, 3], 2, CpuAllocator);
        assert_eq!(t.as_slice(), vec![2, 2, 2, 2, 2, 2]);
        Ok(())
    }

    #[test]
    fn cast_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 1, _>::from_shape_vec([4], data, CpuAllocator)?;
        let t2 = t.cast::<u16>();
        assert_eq!(t2.as_slice(), vec![1, 2, 3, 4]);
        Ok(())
    }

    #[test]
    fn cast_2d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 2, _>::from_shape_vec([2, 2], data, CpuAllocator)?;
        let t2 = t.cast::<u16>();
        assert_eq!(t2.as_slice(), vec![1, 2, 3, 4]);
        Ok(())
    }

    #[test]
    fn from_shape_fn_1d() -> Result<(), TensorError> {
        let alloc = CpuAllocator;
        let t = Tensor::from_shape_fn([5], alloc, |[i]| (2 * i) as u8);
        assert_eq!(t.shape, [5]);
        assert_eq!(t.as_slice(), vec![0, 2, 4, 6, 8]);
        Ok(())
    }

    #[test]
    fn from_shape_fn_2d_u8() -> Result<(), TensorError> {
        let alloc = CpuAllocator;
        let t = Tensor::from_shape_fn([3, 3], alloc, |[i, j]| ((1 + i) * (1 + j)) as u8);
        assert_eq!(t.as_slice(), vec![1, 2, 3, 2, 4, 6, 3, 6, 9]);
        Ok(())
    }

    #[test]
    fn from_shape_fn_2d() -> Result<(), TensorError> {
        let alloc = CpuAllocator;
        let t = Tensor::from_shape_fn([3, 3], alloc, |[i, j]| ((1 + i) * (1 + j)) as f32);
        assert_eq!(
            t.as_slice(),
            vec![1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0]
        );
        Ok(())
    }

    #[test]
    fn from_shape_fn_3d() -> Result<(), TensorError> {
        let alloc = CpuAllocator;
        let t = Tensor::from_shape_fn([2, 3, 3], alloc, |[x, y, c]| {
            ((1 + x) * (1 + y) * (1 + c)) as i16
        });
        assert_eq!(
            t.as_slice(),
            vec![1, 2, 3, 2, 4, 6, 3, 6, 9, 2, 4, 6, 4, 8, 12, 6, 12, 18]
        );
        Ok(())
    }

    #[test]
    fn view_1d() -> Result<(), TensorError> {
        let alloc = CpuAllocator;
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 1, _>::from_shape_vec([4], data, alloc)?;
        let view = t.view();

        // The view must expose the same elements...
        assert_eq!(view.as_slice(), t.as_slice());

        // ...backed by the very same memory, not a copy.
        assert!(std::ptr::eq(view.as_ptr(), t.as_ptr()));

        Ok(())
    }

    #[test]
    fn from_slice() -> Result<(), TensorError> {
        let data: [u8; 4] = [1, 2, 3, 4];
        let t = Tensor::<u8, 2, _>::from_shape_slice([2, 2], &data, CpuAllocator)?;

        assert_eq!(t.shape, [2, 2]);
        assert_eq!(t.as_slice(), &[1, 2, 3, 4]);

        Ok(())
    }

    #[test]
    fn display_1d() -> Result<(), TensorError> {
        let data: [u8; 3] = [1, 2, 3];
        let t = Tensor::<u8, 1, _>::from_shape_slice([3], &data, CpuAllocator)?;
        assert_eq!(t.to_string(), "[1,2,3]");
        Ok(())
    }

    #[test]
    fn display_2d() -> Result<(), TensorError> {
        let data: [u8; 4] = [1, 2, 3, 4];
        let t = Tensor::<u8, 2, _>::from_shape_slice([2, 2], &data, CpuAllocator)?;
        let disp = t.to_string();
        let lines = disp.lines().collect::<Vec<_>>();

        #[rustfmt::skip]
        assert_eq!(lines.as_slice(),
        ["[[1,2],",
         " [3,4]]"]);
        Ok(())
    }

    #[test]
    fn display_3d() -> Result<(), TensorError> {
        let data: [u8; 12] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let t = Tensor::<u8, 3, _>::from_shape_slice([2, 3, 2], &data, CpuAllocator)?;
        let disp = t.to_string();
        let lines = disp.lines().collect::<Vec<_>>();

        #[rustfmt::skip]
        assert_eq!(lines.as_slice(),
        ["[[[ 1, 2],",
         "  [ 3, 4],",
         "  [ 5, 6]],",
         "",
         " [[ 7, 8],",
         "  [ 9,10],",
         "  [11,12]]]"]);
        Ok(())
    }

    #[test]
    fn display_float() -> Result<(), TensorError> {
        let data: [f32; 4] = [1.00001, 1.00009, 0.99991, 0.99999];
        let t = Tensor::<f32, 2, _>::from_shape_slice([2, 2], &data, CpuAllocator)?;
        let disp = t.to_string();
        let lines = disp.lines().collect::<Vec<_>>();

        #[rustfmt::skip]
        assert_eq!(lines.as_slice(),
        ["[[1.0000,1.0001],",
         " [0.9999,1.0000]]"]);
        Ok(())
    }

    #[test]
    fn display_big_float() -> Result<(), TensorError> {
        let data: [f32; 4] = [1000.00001, 1.00009, 0.99991, 0.99999];
        let t = Tensor::<f32, 2, _>::from_shape_slice([2, 2], &data, CpuAllocator)?;
        let disp = t.to_string();
        let lines = disp.lines().collect::<Vec<_>>();

        #[rustfmt::skip]
        assert_eq!(lines.as_slice(),
        ["[[1.0000e+03,1.0001e+00],",
         " [9.9991e-01,9.9999e-01]]"]);
        Ok(())
    }

    #[test]
    fn display_big_tensor() -> Result<(), TensorError> {
        let data: [u8; 1000] = [0; 1000];
        let t = Tensor::<u8, 3, _>::from_shape_slice([10, 10, 10], &data, CpuAllocator)?;
        let disp = t.to_string();
        let lines = disp.lines().collect::<Vec<_>>();

        #[rustfmt::skip]
        assert_eq!(lines.as_slice(),
        ["[[[0,0,0,...,0],",
         "  [0,0,0,...,0],",
         "  [0,0,0,...,0],",
         "  ...",
         "  [0,0,0,...,0]],",
         "",
         " [[0,0,0,...,0],",
         "  [0,0,0,...,0],",
         "  [0,0,0,...,0],",
         "  ...",
         "  [0,0,0,...,0]],",
         "",
         " [[0,0,0,...,0],",
         "  [0,0,0,...,0],",
         "  [0,0,0,...,0],",
         "  ...",
         "  [0,0,0,...,0]],",
         "",
         " ...",
         "",
         " [[0,0,0,...,0],",
         "  [0,0,0,...,0],",
         "  [0,0,0,...,0],",
         "  ...",
         "  [0,0,0,...,0]]]"]);
        Ok(())
    }

    #[test]
    fn get_index_unchecked_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 1, CpuAllocator>::from_shape_vec([4], data, CpuAllocator)?;
        assert_eq!(t.get_index_unchecked(0), [0]);
        assert_eq!(t.get_index_unchecked(1), [1]);
        assert_eq!(t.get_index_unchecked(2), [2]);
        assert_eq!(t.get_index_unchecked(3), [3]);
        Ok(())
    }

    #[test]
    fn get_index_unchecked_2d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 2, CpuAllocator>::from_shape_vec([2, 2], data, CpuAllocator)?;
        assert_eq!(t.get_index_unchecked(0), [0, 0]);
        assert_eq!(t.get_index_unchecked(1), [0, 1]);
        assert_eq!(t.get_index_unchecked(2), [1, 0]);
        assert_eq!(t.get_index_unchecked(3), [1, 1]);
        Ok(())
    }

    #[test]
    fn get_index_unchecked_3d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let t = Tensor::<u8, 3, CpuAllocator>::from_shape_vec([2, 2, 3], data, CpuAllocator)?;
        assert_eq!(t.get_index_unchecked(0), [0, 0, 0]);
        assert_eq!(t.get_index_unchecked(1), [0, 0, 1]);
        assert_eq!(t.get_index_unchecked(2), [0, 0, 2]);
        assert_eq!(t.get_index_unchecked(3), [0, 1, 0]);
        assert_eq!(t.get_index_unchecked(4), [0, 1, 1]);
        assert_eq!(t.get_index_unchecked(5), [0, 1, 2]);
        assert_eq!(t.get_index_unchecked(6), [1, 0, 0]);
        assert_eq!(t.get_index_unchecked(7), [1, 0, 1]);
        assert_eq!(t.get_index_unchecked(8), [1, 0, 2]);
        assert_eq!(t.get_index_unchecked(9), [1, 1, 0]);
        assert_eq!(t.get_index_unchecked(10), [1, 1, 1]);
        assert_eq!(t.get_index_unchecked(11), [1, 1, 2]);
        Ok(())
    }

    #[test]
    fn get_index_to_offset_and_back() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let t = Tensor::<u8, 3, CpuAllocator>::from_shape_vec([2, 2, 3], data, CpuAllocator)?;
        for offset in 0..12 {
            assert_eq!(
                t.get_iter_offset_unchecked(t.get_index_unchecked(offset)),
                offset
            );
        }
        Ok(())
    }

    #[test]
    fn get_offset_to_index_and_back() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let t = Tensor::<u8, 3, CpuAllocator>::from_shape_vec([2, 2, 3], data, CpuAllocator)?;
        for ind in [
            [0, 0, 0],
            [0, 0, 1],
            [0, 0, 2],
            [0, 1, 0],
            [0, 1, 1],
            [0, 1, 2],
            [1, 0, 0],
            [1, 0, 1],
            [1, 0, 2],
            [1, 1, 0],
            [1, 1, 1],
            [1, 1, 2],
        ] {
            assert_eq!(t.get_index_unchecked(t.get_iter_offset_unchecked(ind)), ind);
        }
        Ok(())
    }

    #[test]
    fn get_index_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 1, CpuAllocator>::from_shape_vec([4], data, CpuAllocator)?;
        assert_eq!(t.get_index(3), Ok([3]));
        assert!(t
            .get_index(4)
            .is_err_and(|x| x == TensorError::IndexOutOfBounds(4)));
        Ok(())
    }

    #[test]
    fn get_index_2d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 2, CpuAllocator>::from_shape_vec([2, 2], data, CpuAllocator)?;
        assert_eq!(t.get_index(3), Ok([1, 1]));
        assert!(t
            .get_index(4)
            .is_err_and(|x| x == TensorError::IndexOutOfBounds(4)));
        Ok(())
    }

    #[test]
    fn get_index_3d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let t = Tensor::<u8, 3, CpuAllocator>::from_shape_vec([2, 2, 3], data, CpuAllocator)?;
        assert_eq!(t.get_index(11), Ok([1, 1, 2]));
        assert!(t
            .get_index(12)
            .is_err_and(|x| x == TensorError::IndexOutOfBounds(12)));
        Ok(())
    }

    #[test]
    fn from_raw_parts() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = unsafe { Tensor::from_raw_parts([2, 2], data.as_ptr(), data.len(), CpuAllocator)? };
        // The tensor now owns the buffer, so the Vec must not free it as well.
        std::mem::forget(data);
        assert_eq!(t.shape, [2, 2]);
        assert_eq!(t.as_slice(), &[1, 2, 3, 4]);
        Ok(())
    }
}