kornia_core/
storage.rs

1use std::{alloc::Layout, ptr::NonNull};
2
3use crate::allocator::TensorAllocator;
4
5/// Definition of the buffer for a tensor.
6pub struct TensorStorage<T, A: TensorAllocator> {
7    /// The pointer to the tensor memory which must be non null.
8    pub(crate) ptr: NonNull<T>,
9    /// The length of the tensor memory in bytes.
10    pub(crate) len: usize,
11    /// The layout of the tensor memory.
12    pub(crate) layout: Layout,
13    /// The allocator used to allocate/deallocate the tensor memory.
14    pub(crate) alloc: A,
15}
16
17impl<T, A: TensorAllocator> TensorStorage<T, A> {
18    /// Returns the pointer to the tensor memory.
19    #[inline]
20    pub fn as_ptr(&self) -> *const T {
21        self.ptr.as_ptr()
22    }
23
24    /// Returns the pointer to the tensor memory.
25    #[inline]
26    pub fn as_mut_ptr(&mut self) -> *mut T {
27        self.ptr.as_ptr()
28    }
29
30    /// Returns the data pointer as a slice.
31    pub fn as_slice(&self) -> &[T] {
32        unsafe { std::slice::from_raw_parts(self.as_ptr(), self.len / std::mem::size_of::<T>()) }
33    }
34
35    /// Returns the data pointer as a mutable slice.
36    pub fn as_mut_slice(&mut self) -> &mut [T] {
37        unsafe {
38            std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.len / std::mem::size_of::<T>())
39        }
40    }
41
42    /// Returns the number of bytes contained in this `TensorStorage`.
43    #[inline]
44    pub fn len(&self) -> usize {
45        self.len
46    }
47
48    /// Returns true if the `TensorStorage` has a length of 0.
49    #[inline]
50    pub fn is_empty(&self) -> bool {
51        self.len == 0
52    }
53
54    /// Returns the layout of the tensor buffer.
55    #[inline]
56    pub fn layout(&self) -> Layout {
57        self.layout
58    }
59
60    /// Returns the allocator of the tensor buffer.
61    #[inline]
62    pub fn alloc(&self) -> &A {
63        &self.alloc
64    }
65
66    // TODO: use the allocator somehow
67    /// Creates a new tensor buffer from a vector.
68    pub fn from_vec(value: Vec<T>, alloc: A) -> Self {
69        //let buf = arrow_buffer::Buffer::from_vec(value);
70        // Safety
71        // Vec::as_ptr guaranteed to not be null
72        let ptr = unsafe { NonNull::new_unchecked(value.as_ptr() as _) };
73        let len = value.len() * std::mem::size_of::<T>();
74        // Safety
75        // Vec guaranteed to have a valid layout matching that of `Layout::array`
76        // This is based on `RawVec::current_memory`
77        let layout = unsafe { Layout::array::<T>(value.capacity()).unwrap_unchecked() };
78        std::mem::forget(value);
79
80        Self {
81            ptr,
82            len,
83            layout,
84            alloc,
85        }
86    }
87
88    /// Converts the `TensorStorage` into a `Vec<T>`.
89    ///
90    /// Returns `Err(self)` if the buffer does not have the same layout as the destination Vec.
91    pub fn into_vec(self) -> Vec<T> {
92        // TODO: check if the buffer is a cpu buffer or comes from a custom allocator
93        let _layout = &self.layout;
94
95        let vec_capacity = self.layout.size() / std::mem::size_of::<T>();
96        //match Layout::array::<T>(vec_capacity) {
97        //    Ok(expected) if layout == &expected => {}
98        //    e => return Err(TensorAllocatorError::LayoutError(e.unwrap_err())),
99        //}
100
101        let length = self.len;
102        let ptr = self.ptr;
103        let vec_len = length / std::mem::size_of::<T>();
104
105        // Safety
106        std::mem::forget(self);
107        unsafe { Vec::from_raw_parts(ptr.as_ptr(), vec_len, vec_capacity) }
108    }
109}
110
111// TODO: pass the allocator to constructor
112impl<T, A: TensorAllocator> From<Vec<T>> for TensorStorage<T, A>
113where
114    A: Default,
115{
116    /// Creates a new tensor buffer from a vector.
117    fn from(value: Vec<T>) -> Self {
118        // Safety
119        // Vec::as_ptr guaranteed to not be null
120        let ptr = unsafe { NonNull::new_unchecked(value.as_ptr() as *mut T) };
121        let len = value.len() * std::mem::size_of::<T>();
122        // Safety
123        // Vec guaranteed to have a valid layout matching that of `Layout::array`
124        // This is based on `RawVec::current_memory`
125        let layout = unsafe { Layout::array::<T>(value.capacity()).unwrap_unchecked() };
126        std::mem::forget(value);
127
128        Self {
129            ptr,
130            len,
131            layout,
132            alloc: A::default(),
133        }
134    }
135}
136// Safety:
137// TensorStorage is thread safe if the allocator is thread safe.
138unsafe impl<T, A: TensorAllocator> Send for TensorStorage<T, A> {}
139unsafe impl<T, A: TensorAllocator> Sync for TensorStorage<T, A> {}
140
141impl<T, A: TensorAllocator> Drop for TensorStorage<T, A> {
142    fn drop(&mut self) {
143        self.alloc
144            .dealloc(self.ptr.as_ptr() as *mut u8, self.layout);
145    }
146}
147/// A new `TensorStorage` instance with cloned data if successful, otherwise an error.
148impl<T, A> Clone for TensorStorage<T, A>
149where
150    T: Clone,
151    A: TensorAllocator + 'static,
152{
153    fn clone(&self) -> Self {
154        let mut new_vec = Vec::<T>::with_capacity(self.len());
155
156        for i in self.as_slice() {
157            new_vec.push(i.clone());
158        }
159
160        Self::from_vec(new_vec, self.alloc.clone())
161    }
162}
163
#[cfg(test)]
mod tests {

    use super::TensorStorage;
    use crate::allocator::{CpuAllocator, TensorAllocatorError};
    use crate::TensorAllocator;
    use std::alloc::Layout;
    use std::cell::RefCell;
    use std::ptr::NonNull;
    use std::rc::Rc;

    #[test]
    fn test_tensor_buffer_create_raw() -> Result<(), TensorAllocatorError> {
        let size = 8;
        let allocator = CpuAllocator;
        let layout = Layout::array::<u8>(size).map_err(TensorAllocatorError::LayoutError)?;
        let ptr =
            NonNull::new(allocator.alloc(layout)?).ok_or(TensorAllocatorError::NullPointer)?;
        let ptr_raw = ptr.as_ptr();

        let buffer = TensorStorage {
            alloc: allocator,
            len: size * std::mem::size_of::<u8>(),
            layout,
            ptr,
        };

        assert_eq!(buffer.ptr.as_ptr(), ptr_raw);
        assert!(!ptr_raw.is_null());
        assert_eq!(buffer.layout, layout);
        assert_eq!(buffer.len(), size);
        assert!(!buffer.is_empty());
        assert_eq!(buffer.len(), size * std::mem::size_of::<u8>());

        Ok(())
    }

    #[test]
    fn test_tensor_buffer_ptr() -> Result<(), TensorAllocatorError> {
        let size = 8;
        let allocator = CpuAllocator;
        let layout = Layout::array::<u8>(size).map_err(TensorAllocatorError::LayoutError)?;
        let ptr =
            NonNull::new(allocator.alloc(layout)?).ok_or(TensorAllocatorError::NullPointer)?;

        // check alignment
        let ptr_raw = ptr.as_ptr() as usize;
        let alignment = std::mem::align_of::<u8>();
        assert_eq!(ptr_raw % alignment, 0);

        // No `TensorStorage` owns this raw allocation, so release it explicitly.
        allocator.dealloc(ptr.as_ptr(), layout);

        Ok(())
    }

    #[test]
    fn test_tensor_buffer_create_f32() -> Result<(), TensorAllocatorError> {
        let size = 8;
        let allocator = CpuAllocator;
        let layout = Layout::array::<f32>(size).map_err(TensorAllocatorError::LayoutError)?;
        let ptr =
            NonNull::new(allocator.alloc(layout)?).ok_or(TensorAllocatorError::NullPointer)?;

        // `len` is a byte count, not an element count.
        let buffer = TensorStorage {
            alloc: allocator,
            len: size * std::mem::size_of::<f32>(),
            layout,
            ptr: ptr.cast::<f32>(),
        };

        assert_eq!(buffer.as_ptr(), ptr.as_ptr() as *const f32);
        assert_eq!(buffer.layout, layout);
        assert_eq!(buffer.len(), size * std::mem::size_of::<f32>());
        assert_eq!(buffer.len(), layout.size());

        Ok(())
    }

    #[test]
    fn test_tensor_buffer_lifecycle() -> Result<(), TensorAllocatorError> {
        /// A simple allocator that counts the number of bytes allocated and deallocated.
        #[derive(Clone)]
        struct TestAllocator {
            bytes_allocated: Rc<RefCell<i32>>,
        }

        impl TensorAllocator for TestAllocator {
            fn alloc(&self, layout: Layout) -> Result<*mut u8, TensorAllocatorError> {
                *self.bytes_allocated.borrow_mut() += layout.size() as i32;
                CpuAllocator.alloc(layout)
            }
            fn dealloc(&self, ptr: *mut u8, layout: Layout) {
                *self.bytes_allocated.borrow_mut() -= layout.size() as i32;
                CpuAllocator.dealloc(ptr, layout)
            }
        }

        let allocator = TestAllocator {
            bytes_allocated: Rc::new(RefCell::new(0)),
        };
        assert_eq!(*allocator.bytes_allocated.borrow(), 0);

        let size = 1024;

        // TensorStorage::from_vec() -> TensorStorage::into_vec()
        // TensorStorage::from_vec() currently does not use the custom allocator, so the
        // bytes_allocated value should not change.
        {
            let vec = Vec::<u8>::with_capacity(size);
            let vec_ptr = vec.as_ptr();
            let vec_capacity = vec.capacity();

            let buffer = TensorStorage::from_vec(vec, allocator.clone());
            assert_eq!(*allocator.bytes_allocated.borrow(), 0);

            let result_vec = buffer.into_vec();
            assert_eq!(*allocator.bytes_allocated.borrow(), 0);

            assert_eq!(result_vec.capacity(), vec_capacity);
            assert!(std::ptr::eq(result_vec.as_ptr(), vec_ptr));
        }
        assert_eq!(*allocator.bytes_allocated.borrow(), 0);

        Ok(())
    }

    #[test]
    fn test_tensor_buffer_from_vec() -> Result<(), TensorAllocatorError> {
        let vec: Vec<i32> = vec![1, 2, 3, 4, 5];
        let vec_ptr = vec.as_ptr();
        let vec_len = vec.len();

        let buffer = TensorStorage::<_, CpuAllocator>::from_vec(vec, CpuAllocator);

        // check NO copy
        let buffer_ptr = buffer.as_ptr();
        assert!(std::ptr::eq(buffer_ptr, vec_ptr));

        // check alignment
        let buffer_ptr = buffer.as_ptr() as usize;
        let alignment = std::mem::align_of::<i32>();
        assert_eq!(buffer_ptr % alignment, 0);

        // check accessors
        let data = buffer.as_slice();
        assert_eq!(data.len(), vec_len);
        assert_eq!(data[0], 1);
        assert_eq!(data[1], 2);
        assert_eq!(data[2], 3);
        assert_eq!(data[3], 4);
        assert_eq!(data[4], 5);

        assert_eq!(data.first(), Some(&1));
        assert_eq!(data.get(1), Some(&2));
        assert_eq!(data.get(2), Some(&3));
        assert_eq!(data.get(3), Some(&4));
        assert_eq!(data.get(4), Some(&5));
        assert_eq!(data.get(5), None);

        unsafe {
            assert_eq!(data.get_unchecked(0), &1);
            assert_eq!(data.get_unchecked(1), &2);
            assert_eq!(data.get_unchecked(2), &3);
            assert_eq!(data.get_unchecked(3), &4);
            assert_eq!(data.get_unchecked(4), &5);
        }

        Ok(())
    }

    #[test]
    fn test_tensor_buffer_into_vec() -> Result<(), TensorAllocatorError> {
        let vec: Vec<i32> = vec![1, 2, 3, 4, 5];
        let vec_ptr = vec.as_ptr();
        let vec_cap = vec.capacity();

        let buffer = TensorStorage::<_, CpuAllocator>::from_vec(vec, CpuAllocator);

        // convert back to vec
        let result_vec = buffer.into_vec();

        // check NO copy
        assert_eq!(result_vec.capacity(), vec_cap);
        assert!(std::ptr::eq(result_vec.as_ptr(), vec_ptr));

        Ok(())
    }

    #[test]
    fn test_tensor_mutability() -> Result<(), TensorAllocatorError> {
        let vec: Vec<i32> = vec![1, 2, 3, 4, 5];
        let mut buffer = TensorStorage::<_, CpuAllocator>::from_vec(vec, CpuAllocator);
        let ptr_mut = buffer.as_mut_ptr();
        unsafe {
            *ptr_mut.add(0) = 10;
        }
        assert_eq!(buffer.into_vec(), vec![10, 2, 3, 4, 5]);
        Ok(())
    }
}
360}