1use crate::{
2 get_strides_from_shape, storage::TensorStorage, CpuAllocator, Tensor, TensorAllocator,
3};
4
5pub struct TensorView<'a, T, const N: usize, A: TensorAllocator> {
7 pub storage: &'a TensorStorage<T, A>,
9
10 pub shape: [usize; N],
12
13 pub strides: [usize; N],
15}
16
17impl<T, const N: usize, A: TensorAllocator + 'static> TensorView<'_, T, N, A> {
18 #[inline]
20 pub fn as_slice(&self) -> &[T] {
21 self.storage.as_slice()
22 }
23
24 #[inline]
26 pub fn as_ptr(&self) -> *const T {
27 self.storage.as_ptr()
28 }
29
30 #[inline]
32 pub fn numel(&self) -> usize {
33 self.storage.len() / std::mem::size_of::<T>()
34 }
35
36 pub fn get_unchecked(&self, index: [usize; N]) -> &T {
46 let offset = index
47 .iter()
48 .zip(self.strides.iter())
49 .fold(0, |acc, (i, s)| acc + i * s);
50 unsafe { self.storage.as_slice().get_unchecked(offset) }
51 }
52
53 pub fn as_contiguous(&self) -> Tensor<T, N, CpuAllocator>
59 where
60 T: Clone,
61 {
62 let mut data = Vec::<T>::with_capacity(self.numel());
63 let mut index = [0; N];
64
65 loop {
66 let val = self.get_unchecked(index);
67 data.push(val.clone());
68
69 let mut i = N - 1;
71 while i > 0 && index[i] == self.shape[i] - 1 {
72 index[i] = 0;
73 i -= 1;
74 }
75 if i == 0 && index[0] == self.shape[0] - 1 {
76 break;
77 }
78 index[i] += 1;
79 }
80
81 let strides = get_strides_from_shape(self.shape);
82
83 Tensor {
84 storage: TensorStorage::from_vec(data, CpuAllocator),
85 shape: self.shape,
86 strides,
87 }
88 }
89}
90
#[cfg(test)]
mod tests {
    use super::*;
    use crate::allocator::{CpuAllocator, TensorAllocatorError};

    #[test]
    fn test_tensor_view_from_vec() -> Result<(), TensorAllocatorError> {
        let expected: [u8; 8] = [1, 2, 3, 4, 5, 6, 7, 8];
        let storage = TensorStorage::from_vec(expected.to_vec(), CpuAllocator);

        let view = TensorView::<u8, 1, _> {
            storage: &storage,
            shape: [8],
            strides: [1],
        };

        assert_eq!(view.numel(), 8);
        assert!(!view.as_ptr().is_null());

        let data = view.as_slice();
        assert_eq!(data.len(), 8);
        assert_eq!(data, &expected[..]);

        // Unit stride over a 1-D view: the i-th index maps to the i-th element.
        for (i, want) in expected.iter().enumerate() {
            assert_eq!(view.get_unchecked([i]), want);
        }

        Ok(())
    }

    #[test]
    fn test_tensor_view_as_contiguous_transposed() {
        // Row-major 2x3 buffer [[1, 2, 3], [4, 5, 6]], viewed as its 3x2
        // transpose by swapping the strides.
        let storage = TensorStorage::from_vec(vec![1u8, 2, 3, 4, 5, 6], CpuAllocator);
        let view = TensorView::<u8, 2, _> {
            storage: &storage,
            shape: [3, 2],
            strides: [1, 3],
        };

        let t = view.as_contiguous();
        assert_eq!(t.shape, [3, 2]);
        assert_eq!(t.storage.as_slice(), &[1u8, 4, 2, 5, 3, 6][..]);
    }
}