// kornia_tensor/storage.rs

1use std::{alloc::Layout, ptr::NonNull};
2
3use crate::allocator::TensorAllocator;
4
/// Definition of the buffer for a tensor.
///
/// Owns a raw allocation described by `layout` and releases it through
/// `alloc` when dropped.
pub struct TensorStorage<T, A: TensorAllocator> {
    /// The pointer to the tensor memory which must be non null.
    pub(crate) ptr: NonNull<T>,
    /// The length of the tensor memory in bytes (not elements).
    pub(crate) len: usize,
    /// The layout of the tensor memory.
    pub(crate) layout: Layout,
    /// The allocator used to allocate/deallocate the tensor memory.
    pub(crate) alloc: A,
}
16
17impl<T, A: TensorAllocator> TensorStorage<T, A> {
18    /// Returns the pointer to the tensor memory.
19    #[inline]
20    pub fn as_ptr(&self) -> *const T {
21        self.ptr.as_ptr()
22    }
23
24    /// Returns the pointer to the tensor memory.
25    #[inline]
26    pub fn as_mut_ptr(&mut self) -> *mut T {
27        self.ptr.as_ptr()
28    }
29
30    /// Returns the data pointer as a slice.
31    pub fn as_slice(&self) -> &[T] {
32        unsafe { std::slice::from_raw_parts(self.as_ptr(), self.len / std::mem::size_of::<T>()) }
33    }
34
35    /// Returns the data pointer as a mutable slice.
36    pub fn as_mut_slice(&mut self) -> &mut [T] {
37        unsafe {
38            std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.len / std::mem::size_of::<T>())
39        }
40    }
41
42    /// Returns the number of bytes contained in this `TensorStorage`.
43    #[inline]
44    pub fn len(&self) -> usize {
45        self.len
46    }
47
48    /// Returns true if the `TensorStorage` has a length of 0.
49    #[inline]
50    pub fn is_empty(&self) -> bool {
51        self.len == 0
52    }
53
54    /// Returns the layout of the tensor buffer.
55    #[inline]
56    pub fn layout(&self) -> Layout {
57        self.layout
58    }
59
60    /// Returns the allocator of the tensor buffer.
61    #[inline]
62    pub fn alloc(&self) -> &A {
63        &self.alloc
64    }
65
66    // TODO: use the allocator somehow
67    /// Creates a new tensor buffer from a vector.
68    pub fn from_vec(value: Vec<T>, alloc: A) -> Self {
69        //let buf = arrow_buffer::Buffer::from_vec(value);
70        // Safety
71        // Vec::as_ptr guaranteed to not be null
72        let ptr = unsafe { NonNull::new_unchecked(value.as_ptr() as _) };
73        let len = value.len() * std::mem::size_of::<T>();
74        // Safety
75        // Vec guaranteed to have a valid layout matching that of `Layout::array`
76        // This is based on `RawVec::current_memory`
77        let layout = unsafe { Layout::array::<T>(value.capacity()).unwrap_unchecked() };
78        std::mem::forget(value);
79
80        Self {
81            ptr,
82            len,
83            layout,
84            alloc,
85        }
86    }
87
88    /// Creates a new tensor buffer from a raw pointer.
89    ///
90    /// # Safety
91    ///
92    /// The pointer must be non-null and the length must be valid.
93    pub unsafe fn from_raw_parts(data: *const T, len: usize, alloc: A) -> Self {
94        let ptr = NonNull::new_unchecked(data as _);
95        let layout = Layout::from_size_align_unchecked(len, std::mem::size_of::<T>());
96        Self {
97            ptr,
98            len,
99            layout,
100            alloc,
101        }
102    }
103
104    /// Converts the `TensorStorage` into a `Vec<T>`.
105    ///
106    /// Returns `Err(self)` if the buffer does not have the same layout as the destination Vec.
107    pub fn into_vec(self) -> Vec<T> {
108        // TODO: check if the buffer is a cpu buffer or comes from a custom allocator
109        let _layout = &self.layout;
110
111        let vec_capacity = self.layout.size() / std::mem::size_of::<T>();
112        //match Layout::array::<T>(vec_capacity) {
113        //    Ok(expected) if layout == &expected => {}
114        //    e => return Err(TensorAllocatorError::LayoutError(e.unwrap_err())),
115        //}
116
117        let length = self.len;
118        let ptr = self.ptr;
119        let vec_len = length / std::mem::size_of::<T>();
120
121        // Safety
122        std::mem::forget(self);
123        unsafe { Vec::from_raw_parts(ptr.as_ptr(), vec_len, vec_capacity) }
124    }
125}
126
// Safety:
// TensorStorage is thread safe if the allocator is thread safe.
// NOTE(review): these impls place no `Send`/`Sync` bounds on `T`, so e.g.
// `TensorStorage<Rc<_>, _>` would become transferable across threads.
// Consider requiring `T: Send` / `T: Sync` respectively — confirm with
// downstream users first, since tightening bounds is a breaking change.
unsafe impl<T, A: TensorAllocator> Send for TensorStorage<T, A> {}
unsafe impl<T, A: TensorAllocator> Sync for TensorStorage<T, A> {}
131
132impl<T, A: TensorAllocator> Drop for TensorStorage<T, A> {
133    fn drop(&mut self) {
134        self.alloc
135            .dealloc(self.ptr.as_ptr() as *mut u8, self.layout);
136    }
137}
/// Clones the storage by copying the data into a freshly allocated buffer
/// (via `from_vec`), reusing a clone of the allocator. Infallible.
impl<T, A> Clone for TensorStorage<T, A>
where
    T: Clone,
    A: TensorAllocator,
{
    fn clone(&self) -> Self {
        Self::from_vec(self.as_slice().to_vec(), self.alloc.clone())
    }
}
148
#[cfg(test)]
mod tests {

    use super::TensorStorage;
    use crate::allocator::{CpuAllocator, TensorAllocatorError};
    use crate::TensorAllocator;
    use std::alloc::Layout;
    use std::cell::RefCell;
    use std::ptr::NonNull;
    use std::rc::Rc;

    #[test]
    fn test_tensor_buffer_create_raw() -> Result<(), TensorAllocatorError> {
        let size = 8;
        let allocator = CpuAllocator;
        let layout = Layout::array::<u8>(size).map_err(TensorAllocatorError::LayoutError)?;
        let ptr =
            NonNull::new(allocator.alloc(layout)?).ok_or(TensorAllocatorError::NullPointer)?;
        let ptr_raw = ptr.as_ptr();

        // `len` is stored in bytes; for u8 the byte count equals the element count.
        let buffer = TensorStorage {
            alloc: allocator,
            len: size * std::mem::size_of::<u8>(),
            layout,
            ptr,
        };

        assert_eq!(buffer.ptr.as_ptr(), ptr_raw);
        assert!(!ptr_raw.is_null());
        assert_eq!(buffer.layout, layout);
        assert_eq!(buffer.len(), size * std::mem::size_of::<u8>());
        assert!(!buffer.is_empty());

        Ok(())
    }

    #[test]
    fn test_tensor_buffer_ptr() -> Result<(), TensorAllocatorError> {
        let size = 8;
        let allocator = CpuAllocator;
        let layout = Layout::array::<u8>(size).map_err(TensorAllocatorError::LayoutError)?;
        let ptr =
            NonNull::new(allocator.alloc(layout)?).ok_or(TensorAllocatorError::NullPointer)?;

        // check alignment
        let ptr_raw = ptr.as_ptr() as usize;
        let alignment = std::mem::align_of::<u8>();
        assert_eq!(ptr_raw % alignment, 0);

        Ok(())
    }

    #[test]
    fn test_tensor_buffer_create_f32() -> Result<(), TensorAllocatorError> {
        let size = 8;
        let allocator = CpuAllocator;
        let layout = Layout::array::<f32>(size).map_err(TensorAllocatorError::LayoutError)?;
        let ptr =
            NonNull::new(allocator.alloc(layout)?).ok_or(TensorAllocatorError::NullPointer)?;

        let buffer = TensorStorage {
            alloc: allocator,
            // Fix: `len` is documented as a byte count, so it must be scaled
            // by `size_of::<f32>()` rather than set to the element count.
            len: size * std::mem::size_of::<f32>(),
            layout,
            ptr: ptr.cast::<f32>(),
        };

        assert_eq!(buffer.as_ptr(), ptr.as_ptr() as *const f32);
        assert_eq!(buffer.layout, layout);
        assert_eq!(buffer.len(), size * std::mem::size_of::<f32>());

        Ok(())
    }

    #[test]
    fn test_tensor_buffer_lifecycle() -> Result<(), TensorAllocatorError> {
        /// A simple allocator that counts the number of bytes allocated and deallocated.
        #[derive(Clone)]
        struct TestAllocator {
            bytes_allocated: Rc<RefCell<i32>>,
        }

        impl TensorAllocator for TestAllocator {
            fn alloc(&self, layout: Layout) -> Result<*mut u8, TensorAllocatorError> {
                *self.bytes_allocated.borrow_mut() += layout.size() as i32;
                CpuAllocator.alloc(layout)
            }
            fn dealloc(&self, ptr: *mut u8, layout: Layout) {
                *self.bytes_allocated.borrow_mut() -= layout.size() as i32;
                CpuAllocator.dealloc(ptr, layout)
            }
        }

        let allocator = TestAllocator {
            bytes_allocated: Rc::new(RefCell::new(0)),
        };
        assert_eq!(*allocator.bytes_allocated.borrow(), 0);

        let size = 1024;

        // TensorStorage::from_vec() -> TensorStorage::into_vec()
        // TensorStorage::from_vec() currently does not use the custom allocator, so the
        // bytes_allocated value should not change.
        {
            let vec = Vec::<u8>::with_capacity(size);
            let vec_ptr = vec.as_ptr();
            let vec_capacity = vec.capacity();

            let buffer = TensorStorage::from_vec(vec, allocator.clone());
            assert_eq!(*allocator.bytes_allocated.borrow(), 0);

            let result_vec = buffer.into_vec();
            assert_eq!(*allocator.bytes_allocated.borrow(), 0);

            // The allocation must round-trip without a copy or reallocation.
            assert_eq!(result_vec.capacity(), vec_capacity);
            assert!(std::ptr::eq(result_vec.as_ptr(), vec_ptr));
        }
        assert_eq!(*allocator.bytes_allocated.borrow(), 0);

        Ok(())
    }

    #[test]
    fn test_tensor_buffer_from_vec() -> Result<(), TensorAllocatorError> {
        let vec: Vec<i32> = vec![1, 2, 3, 4, 5];
        let vec_ptr = vec.as_ptr();
        let vec_len = vec.len();

        let buffer = TensorStorage::<_, CpuAllocator>::from_vec(vec, CpuAllocator);

        // check NO copy
        let buffer_ptr = buffer.as_ptr();
        assert!(std::ptr::eq(buffer_ptr, vec_ptr));

        // check alignment
        let buffer_ptr = buffer.as_ptr() as usize;
        let alignment = std::mem::align_of::<i32>();
        assert_eq!(buffer_ptr % alignment, 0);

        // check accessors
        let data = buffer.as_slice();
        assert_eq!(data.len(), vec_len);
        assert_eq!(data[0], 1);
        assert_eq!(data[1], 2);
        assert_eq!(data[2], 3);
        assert_eq!(data[3], 4);
        assert_eq!(data[4], 5);

        assert_eq!(data.first(), Some(&1));
        assert_eq!(data.get(1), Some(&2));
        assert_eq!(data.get(2), Some(&3));
        assert_eq!(data.get(3), Some(&4));
        assert_eq!(data.get(4), Some(&5));
        assert_eq!(data.get(5), None);

        // SAFETY: all indices are < data.len() == 5, checked above.
        unsafe {
            assert_eq!(data.get_unchecked(0), &1);
            assert_eq!(data.get_unchecked(1), &2);
            assert_eq!(data.get_unchecked(2), &3);
            assert_eq!(data.get_unchecked(3), &4);
            assert_eq!(data.get_unchecked(4), &5);
        }

        Ok(())
    }

    #[test]
    fn test_tensor_buffer_into_vec() -> Result<(), TensorAllocatorError> {
        let vec: Vec<i32> = vec![1, 2, 3, 4, 5];
        let vec_ptr = vec.as_ptr();
        let vec_cap = vec.capacity();

        let buffer = TensorStorage::<_, CpuAllocator>::from_vec(vec, CpuAllocator);

        // convert back to vec
        let result_vec = buffer.into_vec();

        // check NO copy
        assert_eq!(result_vec.capacity(), vec_cap);
        assert!(std::ptr::eq(result_vec.as_ptr(), vec_ptr));

        Ok(())
    }

    #[test]
    fn test_tensor_mutability() -> Result<(), TensorAllocatorError> {
        let vec: Vec<i32> = vec![1, 2, 3, 4, 5];
        let mut buffer = TensorStorage::<_, CpuAllocator>::from_vec(vec, CpuAllocator);
        let ptr_mut = buffer.as_mut_ptr();
        // SAFETY: index 0 is in bounds of the 5-element buffer and we hold
        // exclusive access via `&mut buffer`.
        unsafe {
            *ptr_mut.add(0) = 10;
        }
        assert_eq!(buffer.into_vec(), vec![10, 2, 3, 4, 5]);
        Ok(())
    }
}