// kornia_tensor/view.rs

1use crate::{
2    get_strides_from_shape, storage::TensorStorage, CpuAllocator, Tensor, TensorAllocator,
3};
4
5/// A view into a tensor.
6pub struct TensorView<'a, T, const N: usize, A: TensorAllocator> {
7    /// Reference to the storage held by the another tensor.
8    pub storage: &'a TensorStorage<T, A>,
9
10    /// The shape of the tensor.
11    pub shape: [usize; N],
12
13    /// The strides of the tensor.
14    pub strides: [usize; N],
15}
16
17impl<T, const N: usize, A: TensorAllocator + 'static> TensorView<'_, T, N, A> {
18    /// Returns the data slice of the tensor.
19    #[inline]
20    pub fn as_slice(&self) -> &[T] {
21        self.storage.as_slice()
22    }
23
24    /// Returns the data pointer of the tensor.
25    #[inline]
26    pub fn as_ptr(&self) -> *const T {
27        self.storage.as_ptr()
28    }
29
30    /// Returns the length of the tensor.
31    #[inline]
32    pub fn numel(&self) -> usize {
33        self.storage.len() / std::mem::size_of::<T>()
34    }
35
36    /// Get the element at the given index.
37    ///
38    /// # Returns
39    ///
40    /// A reference to the element at the given index.
41    ///
42    /// # Safety
43    ///
44    /// The caller must ensure that the index is within the bounds of the tensor.
45    pub fn get_unchecked(&self, index: [usize; N]) -> &T {
46        let offset = index
47            .iter()
48            .zip(self.strides.iter())
49            .fold(0, |acc, (i, s)| acc + i * s);
50        unsafe { self.storage.as_slice().get_unchecked(offset) }
51    }
52
53    /// Convert the view an owned tensor with contiguous memory.
54    ///
55    /// # Returns
56    ///
57    /// A new `Tensor` instance with contiguous memory.
58    pub fn as_contiguous(&self) -> Tensor<T, N, CpuAllocator>
59    where
60        T: Clone,
61    {
62        let mut data = Vec::<T>::with_capacity(self.numel());
63        let mut index = [0; N];
64
65        loop {
66            let val = self.get_unchecked(index);
67            data.push(val.clone());
68
69            // Increment index
70            let mut i = N - 1;
71            while i > 0 && index[i] == self.shape[i] - 1 {
72                index[i] = 0;
73                i -= 1;
74            }
75            if i == 0 && index[0] == self.shape[0] - 1 {
76                break;
77            }
78            index[i] += 1;
79        }
80
81        let strides = get_strides_from_shape(self.shape);
82
83        Tensor {
84            storage: TensorStorage::from_vec(data, CpuAllocator),
85            shape: self.shape,
86            strides,
87        }
88    }
89}
90
#[cfg(test)]
mod tests {
    use super::*;
    use crate::allocator::{CpuAllocator, TensorAllocatorError};

    #[test]
    fn test_tensor_view_from_vec() -> Result<(), TensorAllocatorError> {
        let vec = vec![1, 2, 3, 4, 5, 6, 7, 8];
        let storage = TensorStorage::from_vec(vec, CpuAllocator);

        let view = TensorView::<u8, 1, _> {
            storage: &storage,
            shape: [8],
            strides: [1],
        };

        assert_eq!(view.numel(), 8);
        assert!(!view.as_ptr().is_null());

        // check slice
        let data = view.as_slice();
        assert_eq!(data, &[1u8, 2, 3, 4, 5, 6, 7, 8]);

        // check get_unchecked: a contiguous 1-D view maps index i to element i + 1
        for (i, expected) in (1u8..=8).enumerate() {
            assert_eq!(view.get_unchecked([i]), &expected);
        }

        Ok(())
    }

    #[test]
    fn test_tensor_view_as_contiguous_transposed() {
        // Row-major 2x3 buffer:
        //   [[1, 2, 3],
        //    [4, 5, 6]]
        let storage = TensorStorage::from_vec(vec![1u8, 2, 3, 4, 5, 6], CpuAllocator);

        // Transposed (3x2) view over the same buffer, expressed purely through strides.
        let view = TensorView::<u8, 2, _> {
            storage: &storage,
            shape: [3, 2],
            strides: [1, 3],
        };

        assert_eq!(view.numel(), 6);
        assert_eq!(view.get_unchecked([0, 1]), &4);
        assert_eq!(view.get_unchecked([2, 0]), &3);

        // Materializing the view must copy elements in the view's row-major order.
        let tensor = view.as_contiguous();
        assert_eq!(tensor.shape, [3, 2]);
        assert_eq!(tensor.storage.as_slice(), &[1u8, 4, 2, 5, 3, 6]);
    }
}