1use thiserror::Error;
2
3use super::{
4 allocator::{CpuAllocator, TensorAllocator, TensorAllocatorError},
5 storage::TensorStorage,
6};
7
/// Errors that can occur while constructing or manipulating a tensor.
#[derive(Error, Debug)]
pub enum TensorError {
    /// The element count implied by the requested shape (carried as the
    /// payload) does not match the number of elements in the provided data.
    #[error("The number of elements in the data does not match the shape of the tensor: {0}")]
    InvalidShape(usize),

    /// An index coordinate (carried as the payload) was outside the valid
    /// range of its dimension.
    #[error("Index out of bounds. The index {0} is out of bounds.")]
    IndexOutOfBounds(usize),

    /// The underlying storage operation failed; wraps a
    /// [`TensorAllocatorError`] from the allocator layer.
    #[error("Error with the tensor storage: {0}")]
    StorageError(#[from] TensorAllocatorError),
}
19
/// Computes row-major (C-contiguous) strides for `shape`.
///
/// The last dimension has stride 1; each earlier dimension's stride is the
/// product of all later dimension sizes.
fn get_strides_from_shape<const N: usize>(shape: [usize; N]) -> [usize; N] {
    let mut strides = [0usize; N];
    let mut acc = 1usize;
    // Walk dimensions from innermost to outermost, accumulating the product
    // of the sizes seen so far.
    for (slot, &dim) in strides.iter_mut().zip(shape.iter()).rev() {
        *slot = acc;
        acc *= dim;
    }
    strides
}
38
/// A multi-dimensional, row-major tensor with `N` compile-time dimensions,
/// element type `T`, and a pluggable allocator `A` (CPU by default).
pub struct Tensor<T, const N: usize, A: TensorAllocator = CpuAllocator>
where
    T: arrow_buffer::ArrowNativeType,
{
    // Backing storage holding the flat element buffer.
    pub storage: TensorStorage<T, A>,
    // Extent of each of the `N` dimensions.
    pub shape: [usize; N],
    // Row-major strides, in elements (not bytes), per dimension.
    pub strides: [usize; N],
}
66
67impl<T, const N: usize, A> Tensor<T, N, A>
69where
70 T: arrow_buffer::ArrowNativeType + std::panic::RefUnwindSafe,
71 A: TensorAllocator,
72{
73 pub fn new_uninitialized(shape: [usize; N], alloc: A) -> Result<Self, TensorError> {
84 let numel = shape.iter().product::<usize>();
85 let strides = get_strides_from_shape(shape);
86 let storage = TensorStorage::new(numel, alloc)?;
87 Ok(Tensor {
88 storage,
89 shape,
90 strides,
91 })
92 }
93
94 pub fn as_slice(&self) -> &[T] {
100 let slice = self.storage.data.typed_data::<T>();
101 slice
102 }
103
104 pub fn as_slice_mut(&mut self) -> &mut [T] {
110 let slice = self.storage.data.typed_data::<T>();
112
113 unsafe { std::slice::from_raw_parts_mut(slice.as_ptr() as *mut T, slice.len()) }
115 }
116
117 pub fn from_shape_vec(shape: [usize; N], data: Vec<T>, alloc: A) -> Result<Self, TensorError> {
143 let numel = shape.iter().product::<usize>();
144 if numel != data.len() {
145 Err(TensorError::InvalidShape(numel))?;
146 }
147 let storage = TensorStorage::from_vec(data, alloc)?;
148 let strides = get_strides_from_shape(shape);
149 Ok(Tensor {
150 storage,
151 shape,
152 strides,
153 })
154 }
155
156 pub fn from_shape_val(shape: [usize; N], value: T, alloc: A) -> Result<Self, TensorError>
182 where
183 T: Copy,
184 {
185 let numel = shape.iter().product::<usize>();
186 let mut a = Self::new_uninitialized(shape, alloc)?;
187
188 for i in a.as_slice_mut().iter_mut().take(numel) {
189 *i = value;
190 }
191
192 Ok(a)
193 }
194
195 pub fn from_shape_fn<F>(shape: [usize; N], f: F, alloc: A) -> Self
220 where
221 F: Fn([usize; N]) -> T,
222 {
223 let numel = shape.iter().product::<usize>();
224 let data: Vec<T> = (0..numel)
225 .map(|i| {
226 let mut index = [0; N];
227 let mut j = i;
228 for k in (0..N).rev() {
229 index[k] = j % shape[k];
230 j /= shape[k];
231 }
232 f(index)
233 })
234 .collect();
235 let storage = TensorStorage::from_vec(data, alloc).unwrap();
236 let strides = get_strides_from_shape(shape);
237 Tensor {
238 storage,
239 shape,
240 strides,
241 }
242 }
243
244 pub fn numel(&self) -> usize {
246 self.storage.data.len()
247 }
248
249 pub fn get_iter_offset(&self, index: [usize; N]) -> usize {
251 let mut offset = 0;
252 for (i, &idx) in index.iter().enumerate() {
253 offset += idx * self.strides[i];
254 }
255 offset
256 }
257
258 pub fn get_unchecked(&self, index: [usize; N]) -> &T {
282 let offset = self.get_iter_offset(index);
283 &self.as_slice()[offset]
284 }
285
286 pub fn get(&self, index: [usize; N]) -> Result<&T, TensorError> {
317 let mut offset = 0;
318 for (i, &idx) in index.iter().enumerate() {
319 if idx >= self.shape[i] {
320 Err(TensorError::IndexOutOfBounds(idx))?;
321 }
322 offset += idx * self.strides[i];
323 }
324 Ok(&self.as_slice()[offset])
325 }
326
327 pub fn reshape<const M: usize>(
356 self,
357 shape: [usize; M],
358 ) -> Result<Tensor<T, M, A>, TensorError> {
359 let numel = shape.iter().product::<usize>();
360 if numel != self.storage.data.len() {
361 Err(TensorError::InvalidShape(numel))?;
362 }
363
364 let strides = get_strides_from_shape(shape);
365
366 Ok(Tensor {
367 storage: self.storage,
368 shape,
369 strides,
370 })
371 }
372
373 pub fn element_wise_op<F>(
408 &self,
409 other: &Tensor<T, N>,
410 op: F,
411 ) -> Result<Tensor<T, N>, TensorError>
412 where
413 F: Fn(&T, &T) -> T,
414 {
415 let data = self
416 .as_slice()
417 .iter()
418 .zip(other.as_slice().iter())
419 .map(|(a, b)| op(a, b))
420 .collect();
421
422 let storage = TensorStorage::from_vec(data, CpuAllocator)?;
423
424 Ok(Tensor {
425 storage,
426 shape: self.shape,
427 strides: self.strides,
428 })
429 }
430
431 pub fn add(&self, other: &Tensor<T, N>) -> Result<Tensor<T, N>, TensorError>
456 where
457 T: std::ops::Add<Output = T> + Copy,
458 {
459 self.element_wise_op(other, |a, b| *a + *b)
460 }
461
462 pub fn sub(&self, other: &Tensor<T, N>) -> Result<Tensor<T, N>, TensorError>
487 where
488 T: std::ops::Sub<Output = T> + Copy,
489 {
490 self.element_wise_op(other, |a, b| *a - *b)
491 }
492
493 pub fn mul(&self, other: &Tensor<T, N>) -> Result<Tensor<T, N>, TensorError>
518 where
519 T: std::ops::Mul<Output = T> + Copy,
520 {
521 self.element_wise_op(other, |a, b| *a * *b)
522 }
523
524 pub fn div(&self, other: &Tensor<T, N>) -> Result<Tensor<T, N>, TensorError>
549 where
550 T: std::ops::Div<Output = T> + Copy,
551 {
552 self.element_wise_op(other, |a, b| *a / *b)
553 }
554
555 pub fn zeros(shape: [usize; N], alloc: A) -> Tensor<T, N, A>
577 where
578 T: Default + Copy,
579 {
580 Self::from_shape_val(shape, T::default(), alloc).unwrap()
581 }
582
583 pub fn map<F>(&self, f: F) -> Result<Tensor<T, N>, TensorError>
605 where
606 F: Fn(&T) -> T,
607 {
608 let data: Vec<T> = self.as_slice().iter().map(f).collect();
609 let storage = TensorStorage::from_vec(data, CpuAllocator)?;
610 Ok(Tensor {
611 storage,
612 shape: self.shape,
613 strides: self.strides,
614 })
615 }
616
617 pub fn cast<U>(&self) -> Result<Tensor<U, N>, TensorError>
635 where
636 T: Copy + Into<U>,
637 U: arrow_buffer::ArrowNativeType + std::panic::RefUnwindSafe,
638 {
639 let data: Vec<U> = self.as_slice().iter().map(|x| (*x).into()).collect();
640 let storage = TensorStorage::from_vec(data, CpuAllocator)?;
641 Ok(Tensor {
642 storage,
643 shape: self.shape,
644 strides: self.strides,
645 })
646 }
647}
648
649impl<T, const N: usize, A> Clone for Tensor<T, N, A>
650where
651 T: arrow_buffer::ArrowNativeType + std::panic::RefUnwindSafe,
652 A: TensorAllocator,
653{
654 fn clone(&self) -> Self {
655 let mut cloned_tensor =
657 Self::new_uninitialized(self.shape, self.storage.alloc().clone()).unwrap();
658
659 for (a, b) in cloned_tensor
661 .as_slice_mut()
662 .iter_mut()
663 .zip(self.as_slice().iter())
664 {
665 *a = *b;
666 }
667
668 cloned_tensor
669 }
670}
671
#[cfg(test)]
mod tests {
    //! Unit tests for `Tensor` construction, indexing, arithmetic, reshaping,
    //! and casting. Test names follow the pattern `<op>_<rank>d`.
    use crate::tensor::allocator::CpuAllocator;
    use crate::tensor::{Tensor, TensorError};

    #[test]
    fn constructor_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1];
        let t = Tensor::<u8, 1>::from_shape_vec([1], data, CpuAllocator)?;
        assert_eq!(t.shape, [1]);
        assert_eq!(t.as_slice(), vec![1]);
        assert_eq!(t.strides, [1]);
        assert_eq!(t.numel(), 1);
        Ok(())
    }

    #[test]
    fn constructor_2d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2];
        let t = Tensor::<u8, 2>::from_shape_vec([1, 2], data, CpuAllocator)?;
        assert_eq!(t.shape, [1, 2]);
        assert_eq!(t.as_slice(), vec![1, 2]);
        assert_eq!(t.strides, [2, 1]);
        assert_eq!(t.numel(), 2);
        Ok(())
    }

    #[test]
    fn get_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 1>::from_shape_vec([4], data, CpuAllocator)?;
        assert_eq!(*t.get([0])?, 1);
        assert_eq!(*t.get([1])?, 2);
        assert_eq!(*t.get([2])?, 3);
        assert_eq!(*t.get([3])?, 4);
        assert!(t.get([4]).is_err());
        Ok(())
    }

    #[test]
    fn get_2d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 2>::from_shape_vec([2, 2], data, CpuAllocator)?;
        assert_eq!(*t.get([0, 0])?, 1);
        assert_eq!(*t.get([0, 1])?, 2);
        assert_eq!(*t.get([1, 0])?, 3);
        assert_eq!(*t.get([1, 1])?, 4);
        assert!(t.get([2, 0]).is_err());
        Ok(())
    }

    #[test]
    fn get_3d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4, 5, 6];
        let t = Tensor::<u8, 3>::from_shape_vec([2, 1, 3], data, CpuAllocator)?;
        assert_eq!(*t.get([0, 0, 0])?, 1);
        assert_eq!(*t.get([0, 0, 1])?, 2);
        assert_eq!(*t.get([0, 0, 2])?, 3);
        assert_eq!(*t.get([1, 0, 0])?, 4);
        assert_eq!(*t.get([1, 0, 1])?, 5);
        assert_eq!(*t.get([1, 0, 2])?, 6);
        assert!(t.get([2, 0, 0]).is_err());
        Ok(())
    }

    // Renamed from `get_checked_1d`: this exercises `get_unchecked`.
    #[test]
    fn get_unchecked_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 1>::from_shape_vec([4], data, CpuAllocator)?;
        assert_eq!(*t.get_unchecked([0]), 1);
        assert_eq!(*t.get_unchecked([1]), 2);
        assert_eq!(*t.get_unchecked([2]), 3);
        assert_eq!(*t.get_unchecked([3]), 4);
        Ok(())
    }

    // Renamed from `get_checked_2d`: this exercises `get_unchecked`.
    #[test]
    fn get_unchecked_2d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 2>::from_shape_vec([2, 2], data, CpuAllocator)?;
        assert_eq!(*t.get_unchecked([0, 0]), 1);
        assert_eq!(*t.get_unchecked([0, 1]), 2);
        assert_eq!(*t.get_unchecked([1, 0]), 3);
        assert_eq!(*t.get_unchecked([1, 1]), 4);
        Ok(())
    }

    #[test]
    fn add_1d() -> Result<(), TensorError> {
        let data1: Vec<u8> = vec![1, 2, 3, 4];
        let t1 = Tensor::<u8, 1>::from_shape_vec([4], data1, CpuAllocator)?;
        let data2: Vec<u8> = vec![1, 2, 3, 4];
        let t2 = Tensor::<u8, 1>::from_shape_vec([4], data2, CpuAllocator)?;
        let t3 = t1.add(&t2)?;
        assert_eq!(t3.as_slice(), vec![2, 4, 6, 8]);
        Ok(())
    }

    #[test]
    fn add_2d() -> Result<(), TensorError> {
        let data1: Vec<u8> = vec![1, 2, 3, 4];
        let t1 = Tensor::<u8, 2>::from_shape_vec([2, 2], data1, CpuAllocator)?;
        let data2: Vec<u8> = vec![1, 2, 3, 4];
        let t2 = Tensor::<u8, 2>::from_shape_vec([2, 2], data2, CpuAllocator)?;
        let t3 = t1.add(&t2)?;
        assert_eq!(t3.as_slice(), vec![2, 4, 6, 8]);
        Ok(())
    }

    #[test]
    fn add_3d() -> Result<(), TensorError> {
        let data1: Vec<u8> = vec![1, 2, 3, 4, 5, 6];
        let t1 = Tensor::<u8, 3>::from_shape_vec([2, 1, 3], data1, CpuAllocator)?;
        let data2: Vec<u8> = vec![1, 2, 3, 4, 5, 6];
        let t2 = Tensor::<u8, 3>::from_shape_vec([2, 1, 3], data2, CpuAllocator)?;
        let t3 = t1.add(&t2)?;
        assert_eq!(t3.as_slice(), vec![2, 4, 6, 8, 10, 12]);
        Ok(())
    }

    #[test]
    fn sub_1d() -> Result<(), TensorError> {
        let data1: Vec<u8> = vec![1, 2, 3, 4];
        let t1 = Tensor::<u8, 1>::from_shape_vec([4], data1, CpuAllocator)?;
        let data2: Vec<u8> = vec![1, 2, 3, 4];
        let t2 = Tensor::<u8, 1>::from_shape_vec([4], data2, CpuAllocator)?;
        let t3 = t1.sub(&t2)?;
        assert_eq!(t3.as_slice(), vec![0, 0, 0, 0]);
        Ok(())
    }

    #[test]
    fn sub_2d() -> Result<(), TensorError> {
        let data1: Vec<u8> = vec![1, 2, 3, 4];
        let t1 = Tensor::<u8, 2>::from_shape_vec([2, 2], data1, CpuAllocator)?;
        let data2: Vec<u8> = vec![1, 2, 3, 4];
        let t2 = Tensor::<u8, 2>::from_shape_vec([2, 2], data2, CpuAllocator)?;
        let t3 = t1.sub(&t2)?;
        assert_eq!(t3.as_slice(), vec![0, 0, 0, 0]);
        Ok(())
    }

    #[test]
    fn div_1d() -> Result<(), TensorError> {
        let data1: Vec<u8> = vec![1, 2, 3, 4];
        let t1 = Tensor::<u8, 1>::from_shape_vec([4], data1, CpuAllocator)?;
        let data2: Vec<u8> = vec![1, 2, 3, 4];
        let t2 = Tensor::<u8, 1>::from_shape_vec([4], data2, CpuAllocator)?;
        let t3 = t1.div(&t2)?;
        assert_eq!(t3.as_slice(), vec![1, 1, 1, 1]);
        Ok(())
    }

    #[test]
    fn div_2d() -> Result<(), TensorError> {
        let data1: Vec<u8> = vec![1, 2, 3, 4];
        let t1 = Tensor::<u8, 2>::from_shape_vec([2, 2], data1, CpuAllocator)?;
        let data2: Vec<u8> = vec![1, 2, 3, 4];
        let t2 = Tensor::<u8, 2>::from_shape_vec([2, 2], data2, CpuAllocator)?;
        let t3 = t1.div(&t2)?;
        assert_eq!(t3.as_slice(), vec![1, 1, 1, 1]);
        Ok(())
    }

    #[test]
    fn mul_1d() -> Result<(), TensorError> {
        let data1: Vec<u8> = vec![1, 2, 3, 4];
        let t1 = Tensor::<u8, 1>::from_shape_vec([4], data1, CpuAllocator)?;
        let data2: Vec<u8> = vec![1, 2, 3, 4];
        let t2 = Tensor::<u8, 1>::from_shape_vec([4], data2, CpuAllocator)?;
        let t3 = t1.mul(&t2)?;
        assert_eq!(t3.as_slice(), vec![1, 4, 9, 16]);
        Ok(())
    }

    #[test]
    fn mul_2d() -> Result<(), TensorError> {
        let data1: Vec<u8> = vec![1, 2, 3, 4];
        let t1 = Tensor::<u8, 2>::from_shape_vec([2, 2], data1, CpuAllocator)?;
        let data2: Vec<u8> = vec![1, 2, 3, 4];
        let t2 = Tensor::<u8, 2>::from_shape_vec([2, 2], data2, CpuAllocator)?;
        let t3 = t1.mul(&t2)?;
        assert_eq!(t3.as_slice(), vec![1, 4, 9, 16]);
        Ok(())
    }

    #[test]
    fn reshape_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 1>::from_shape_vec([4], data, CpuAllocator)?;
        let t2 = t.reshape([2, 2])?;
        assert_eq!(t2.shape, [2, 2]);
        assert_eq!(t2.as_slice(), vec![1, 2, 3, 4]);
        assert_eq!(t2.strides, [2, 1]);
        assert_eq!(t2.numel(), 4);
        Ok(())
    }

    #[test]
    fn reshape_2d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 2>::from_shape_vec([2, 2], data, CpuAllocator)?;
        let t2 = t.reshape([4])?;
        assert_eq!(t2.shape, [4]);
        assert_eq!(t2.as_slice(), vec![1, 2, 3, 4]);
        assert_eq!(t2.strides, [1]);
        assert_eq!(t2.numel(), 4);
        Ok(())
    }

    #[test]
    fn reshape_get_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 1>::from_shape_vec([4], data, CpuAllocator)?;
        let t2 = t.reshape([2, 2])?;
        assert_eq!(*t2.get([0, 0])?, 1);
        assert_eq!(*t2.get([0, 1])?, 2);
        assert_eq!(*t2.get([1, 0])?, 3);
        assert_eq!(*t2.get([1, 1])?, 4);
        assert_eq!(t2.numel(), 4);
        Ok(())
    }

    #[test]
    fn zeros_1d() -> Result<(), TensorError> {
        let t = Tensor::<u8, 1>::zeros([4], CpuAllocator);
        assert_eq!(t.as_slice(), vec![0, 0, 0, 0]);
        Ok(())
    }

    #[test]
    fn zeros_2d() -> Result<(), TensorError> {
        let t = Tensor::<u8, 2>::zeros([2, 2], CpuAllocator);
        assert_eq!(t.as_slice(), vec![0, 0, 0, 0]);
        Ok(())
    }

    #[test]
    fn map_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 1>::from_shape_vec([4], data, CpuAllocator)?;
        let t2 = t.map(|x| *x + 1)?;
        assert_eq!(t2.as_slice(), vec![2, 3, 4, 5]);
        Ok(())
    }

    #[test]
    fn map_2d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 2>::from_shape_vec([2, 2], data, CpuAllocator)?;
        let t2 = t.map(|x| *x + 1)?;
        assert_eq!(t2.as_slice(), vec![2, 3, 4, 5]);
        Ok(())
    }

    #[test]
    fn from_shape_val_1d() -> Result<(), TensorError> {
        let t = Tensor::<u8, 1>::from_shape_val([4], 0, CpuAllocator)?;
        assert_eq!(t.as_slice(), vec![0, 0, 0, 0]);
        Ok(())
    }

    #[test]
    fn from_shape_val_2d() -> Result<(), TensorError> {
        let t = Tensor::<u8, 2>::from_shape_val([2, 2], 1, CpuAllocator)?;
        assert_eq!(t.as_slice(), vec![1, 1, 1, 1]);
        Ok(())
    }

    #[test]
    fn from_shape_val_3d() -> Result<(), TensorError> {
        let t = Tensor::<u8, 3>::from_shape_val([2, 1, 3], 2, CpuAllocator)?;
        assert_eq!(t.as_slice(), vec![2, 2, 2, 2, 2, 2]);
        Ok(())
    }

    #[test]
    fn cast_1d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 1>::from_shape_vec([4], data, CpuAllocator)?;
        let t2 = t.cast::<u16>()?;
        assert_eq!(t2.as_slice(), vec![1, 2, 3, 4]);
        Ok(())
    }

    #[test]
    fn cast_2d() -> Result<(), TensorError> {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let t = Tensor::<u8, 2>::from_shape_vec([2, 2], data, CpuAllocator)?;
        let t2 = t.cast::<u16>()?;
        assert_eq!(t2.as_slice(), vec![1, 2, 3, 4]);
        Ok(())
    }

    // Renamed from `from_shape_fn_1d`: this exercises a 2-D `[3, 3]` shape.
    #[test]
    fn from_shape_fn_2d_u8() -> Result<(), TensorError> {
        let t = Tensor::from_shape_fn([3, 3], |[i, j]| ((1 + i) * (1 + j)) as u8, CpuAllocator);
        assert_eq!(t.as_slice(), vec![1, 2, 3, 2, 4, 6, 3, 6, 9]);
        Ok(())
    }

    // Renamed from `from_shape_fn_2d` to distinguish it from the `u8` variant.
    #[test]
    fn from_shape_fn_2d_f32() -> Result<(), TensorError> {
        let t = Tensor::from_shape_fn([3, 3], |[i, j]| ((1 + i) * (1 + j)) as f32, CpuAllocator);
        assert_eq!(
            t.as_slice(),
            vec![1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0]
        );
        Ok(())
    }

    #[test]
    fn from_shape_fn_3d() -> Result<(), TensorError> {
        let t = Tensor::from_shape_fn(
            [2, 3, 3],
            |[x, y, c]| ((1 + x) * (1 + y) * (1 + c)) as i16,
            CpuAllocator,
        );
        assert_eq!(
            t.as_slice(),
            vec![1, 2, 3, 2, 4, 6, 3, 6, 9, 2, 4, 6, 4, 8, 12, 6, 12, 18]
        );
        Ok(())
    }
}