1use std::{alloc::Layout, ptr::NonNull};
2
3use crate::allocator::TensorAllocator;
4
/// Owned, allocator-aware backing buffer for tensor data.
///
/// The storage owns a raw allocation described by `layout` and frees it
/// through `alloc` when dropped.
pub struct TensorStorage<T, A: TensorAllocator> {
    /// Pointer to the first element of the allocation; never null.
    pub(crate) ptr: NonNull<T>,
    /// Buffer length in **bytes** (`from_vec` stores
    /// `element_count * size_of::<T>()`, and `as_slice`/`into_vec` divide it
    /// back out). NOTE(review): one test constructs this field with an
    /// element count instead — confirm the intended unit.
    pub(crate) len: usize,
    /// Layout the buffer was allocated with; handed back to `alloc` on drop.
    /// May cover more bytes than `len` (a `Vec`'s spare capacity).
    pub(crate) layout: Layout,
    /// Allocator responsible for releasing `ptr` when the storage drops.
    pub(crate) alloc: A,
}
16
17impl<T, A: TensorAllocator> TensorStorage<T, A> {
18 #[inline]
20 pub fn as_ptr(&self) -> *const T {
21 self.ptr.as_ptr()
22 }
23
24 #[inline]
26 pub fn as_mut_ptr(&mut self) -> *mut T {
27 self.ptr.as_ptr()
28 }
29
30 pub fn as_slice(&self) -> &[T] {
32 unsafe { std::slice::from_raw_parts(self.as_ptr(), self.len / std::mem::size_of::<T>()) }
33 }
34
35 pub fn as_mut_slice(&mut self) -> &mut [T] {
37 unsafe {
38 std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.len / std::mem::size_of::<T>())
39 }
40 }
41
42 #[inline]
44 pub fn len(&self) -> usize {
45 self.len
46 }
47
48 #[inline]
50 pub fn is_empty(&self) -> bool {
51 self.len == 0
52 }
53
54 #[inline]
56 pub fn layout(&self) -> Layout {
57 self.layout
58 }
59
60 #[inline]
62 pub fn alloc(&self) -> &A {
63 &self.alloc
64 }
65
66 pub fn from_vec(value: Vec<T>, alloc: A) -> Self {
69 let ptr = unsafe { NonNull::new_unchecked(value.as_ptr() as _) };
73 let len = value.len() * std::mem::size_of::<T>();
74 let layout = unsafe { Layout::array::<T>(value.capacity()).unwrap_unchecked() };
78 std::mem::forget(value);
79
80 Self {
81 ptr,
82 len,
83 layout,
84 alloc,
85 }
86 }
87
88 pub unsafe fn from_raw_parts(data: *const T, len: usize, alloc: A) -> Self {
94 let ptr = NonNull::new_unchecked(data as _);
95 let layout = Layout::from_size_align_unchecked(len, std::mem::size_of::<T>());
96 Self {
97 ptr,
98 len,
99 layout,
100 alloc,
101 }
102 }
103
104 pub fn into_vec(self) -> Vec<T> {
108 let _layout = &self.layout;
110
111 let vec_capacity = self.layout.size() / std::mem::size_of::<T>();
112 let length = self.len;
118 let ptr = self.ptr;
119 let vec_len = length / std::mem::size_of::<T>();
120
121 std::mem::forget(self);
123 unsafe { Vec::from_raw_parts(ptr.as_ptr(), vec_len, vec_capacity) }
124 }
125}
126
// SAFETY: NOTE(review): this blanket impl asserts `Send` for *every* `T` and
// `A`; soundness normally requires `T: Send` (and `A: Send`) bounds — as
// written, e.g. `TensorStorage<Rc<u8>, _>` would be `Send`. Confirm this is
// intentional or add the bounds.
unsafe impl<T, A: TensorAllocator> Send for TensorStorage<T, A> {}
// SAFETY: NOTE(review): unconditional `Sync` is only sound when `T: Sync`
// (and `A: Sync`); consider adding those bounds.
unsafe impl<T, A: TensorAllocator> Sync for TensorStorage<T, A> {}
131
132impl<T, A: TensorAllocator> Drop for TensorStorage<T, A> {
133 fn drop(&mut self) {
134 self.alloc
135 .dealloc(self.ptr.as_ptr() as *mut u8, self.layout);
136 }
137}
138impl<T, A> Clone for TensorStorage<T, A>
140where
141 T: Clone,
142 A: TensorAllocator,
143{
144 fn clone(&self) -> Self {
145 Self::from_vec(self.as_slice().to_vec(), self.alloc.clone())
146 }
147}
148
#[cfg(test)]
mod tests {

    use super::TensorStorage;
    use crate::allocator::{CpuAllocator, TensorAllocatorError};
    use crate::TensorAllocator;
    use std::alloc::Layout;
    use std::cell::RefCell;
    use std::ptr::NonNull;
    use std::rc::Rc;

    // Builds a storage by hand from a raw allocation and checks the
    // field-level invariants: pointer identity, layout, and byte length.
    #[test]
    fn test_tensor_buffer_create_raw() -> Result<(), TensorAllocatorError> {
        let size = 8;
        let allocator = CpuAllocator;
        let layout = Layout::array::<u8>(size).map_err(TensorAllocatorError::LayoutError)?;
        let ptr =
            NonNull::new(allocator.alloc(layout)?).ok_or(TensorAllocatorError::NullPointer)?;
        let ptr_raw = ptr.as_ptr();

        // Construct directly via the pub(crate) fields; `buffer` now owns the
        // allocation and will dealloc it through `allocator` on drop.
        let buffer = TensorStorage {
            alloc: allocator,
            len: size * std::mem::size_of::<u8>(),
            layout,
            ptr,
        };

        assert_eq!(buffer.ptr.as_ptr(), ptr_raw);
        assert!(!ptr_raw.is_null());
        assert_eq!(buffer.layout, layout);
        // For u8, bytes == elements, so both assertions check the same value.
        assert_eq!(buffer.len(), size);
        assert!(!buffer.is_empty());
        assert_eq!(buffer.len(), size * std::mem::size_of::<u8>());

        Ok(())
    }

    // Checks that a fresh allocation is aligned for the element type.
    // NOTE(review): this allocation is never deallocated or wrapped in a
    // TensorStorage, so the test leaks `size` bytes — harmless for the test
    // run, but worth confirming it is intentional.
    #[test]
    fn test_tensor_buffer_ptr() -> Result<(), TensorAllocatorError> {
        let size = 8;
        let allocator = CpuAllocator;
        let layout = Layout::array::<u8>(size).map_err(TensorAllocatorError::LayoutError)?;
        let ptr =
            NonNull::new(allocator.alloc(layout)?).ok_or(TensorAllocatorError::NullPointer)?;

        // Alignment check: the address must be a multiple of align_of::<u8>().
        let ptr_raw = ptr.as_ptr() as usize;
        let alignment = std::mem::align_of::<u8>();
        assert_eq!(ptr_raw % alignment, 0);

        Ok(())
    }

    // Same hand-construction as the raw test, but with a wider element type
    // (f32) and a cast pointer.
    #[test]
    fn test_tensor_buffer_create_f32() -> Result<(), TensorAllocatorError> {
        let size = 8;
        let allocator = CpuAllocator;
        let layout = Layout::array::<f32>(size).map_err(TensorAllocatorError::LayoutError)?;
        let ptr =
            NonNull::new(allocator.alloc(layout)?).ok_or(TensorAllocatorError::NullPointer)?;

        // NOTE(review): `len` is set to the element count here, while
        // `from_vec` stores a byte count — the assertions below only read the
        // field back, so the test passes either way; confirm the intended unit.
        let buffer = TensorStorage {
            alloc: allocator,
            len: size,
            layout,
            ptr: ptr.cast::<f32>(),
        };

        assert_eq!(buffer.as_ptr(), ptr.as_ptr() as *const f32);
        assert_eq!(buffer.layout, layout);
        assert_eq!(buffer.len(), size);

        Ok(())
    }

    // Verifies allocator bookkeeping across the from_vec -> into_vec round
    // trip: memory owned by a Vec never goes through the custom allocator,
    // so the byte counter must stay at zero throughout.
    #[test]
    fn test_tensor_buffer_lifecycle() -> Result<(), TensorAllocatorError> {
        // Counting wrapper around CpuAllocator; the shared Rc<RefCell> lets
        // the test observe alloc/dealloc traffic from outside.
        #[derive(Clone)]
        struct TestAllocator {
            bytes_allocated: Rc<RefCell<i32>>,
        }

        impl TensorAllocator for TestAllocator {
            fn alloc(&self, layout: Layout) -> Result<*mut u8, TensorAllocatorError> {
                *self.bytes_allocated.borrow_mut() += layout.size() as i32;
                CpuAllocator.alloc(layout)
            }
            fn dealloc(&self, ptr: *mut u8, layout: Layout) {
                *self.bytes_allocated.borrow_mut() -= layout.size() as i32;
                CpuAllocator.dealloc(ptr, layout)
            }
        }

        let allocator = TestAllocator {
            bytes_allocated: Rc::new(RefCell::new(0)),
        };
        assert_eq!(*allocator.bytes_allocated.borrow(), 0);

        let size = 1024;

        {
            // The Vec allocates through the global allocator, not through
            // TestAllocator, so the counter never moves.
            let vec = Vec::<u8>::with_capacity(size);
            let vec_ptr = vec.as_ptr();
            let vec_capacity = vec.capacity();

            let buffer = TensorStorage::from_vec(vec, allocator.clone());
            assert_eq!(*allocator.bytes_allocated.borrow(), 0);

            let result_vec = buffer.into_vec();
            assert_eq!(*allocator.bytes_allocated.borrow(), 0);

            // Round trip must hand back the exact same allocation.
            assert_eq!(result_vec.capacity(), vec_capacity);
            assert!(std::ptr::eq(result_vec.as_ptr(), vec_ptr));
        }
        assert_eq!(*allocator.bytes_allocated.borrow(), 0);

        Ok(())
    }

    // Checks that from_vec adopts the Vec's allocation in place (no copy)
    // and that all slice accessors see the original elements.
    #[test]
    fn test_tensor_buffer_from_vec() -> Result<(), TensorAllocatorError> {
        let vec: Vec<i32> = vec![1, 2, 3, 4, 5];
        let vec_ptr = vec.as_ptr();
        let vec_len = vec.len();

        let buffer = TensorStorage::<_, CpuAllocator>::from_vec(vec, CpuAllocator);

        // Zero-copy: the buffer points at the Vec's original allocation.
        let buffer_ptr = buffer.as_ptr();
        assert!(std::ptr::eq(buffer_ptr, vec_ptr));

        // The adopted pointer must still be aligned for the element type.
        let buffer_ptr = buffer.as_ptr() as usize;
        let alignment = std::mem::align_of::<i32>();
        assert_eq!(buffer_ptr % alignment, 0);

        // Element-level access through the slice view.
        let data = buffer.as_slice();
        assert_eq!(data.len(), vec_len);
        assert_eq!(data[0], 1);
        assert_eq!(data[1], 2);
        assert_eq!(data[2], 3);
        assert_eq!(data[3], 4);
        assert_eq!(data[4], 5);

        assert_eq!(data.first(), Some(&1));
        assert_eq!(data.get(1), Some(&2));
        assert_eq!(data.get(2), Some(&3));
        assert_eq!(data.get(3), Some(&4));
        assert_eq!(data.get(4), Some(&5));
        assert_eq!(data.get(5), None);

        unsafe {
            assert_eq!(data.get_unchecked(0), &1);
            assert_eq!(data.get_unchecked(1), &2);
            assert_eq!(data.get_unchecked(2), &3);
            assert_eq!(data.get_unchecked(3), &4);
            assert_eq!(data.get_unchecked(4), &5);
        }

        Ok(())
    }

    // Checks that into_vec reconstructs the original Vec over the same
    // allocation (same pointer, same capacity — i.e. zero-copy both ways).
    #[test]
    fn test_tensor_buffer_into_vec() -> Result<(), TensorAllocatorError> {
        let vec: Vec<i32> = vec![1, 2, 3, 4, 5];
        let vec_ptr = vec.as_ptr();
        let vec_cap = vec.capacity();

        let buffer = TensorStorage::<_, CpuAllocator>::from_vec(vec, CpuAllocator);

        let result_vec = buffer.into_vec();

        assert_eq!(result_vec.capacity(), vec_cap);
        assert!(std::ptr::eq(result_vec.as_ptr(), vec_ptr));

        Ok(())
    }

    // Checks that writes through as_mut_ptr are visible in the data the
    // storage later hands back.
    #[test]
    fn test_tensor_mutability() -> Result<(), TensorAllocatorError> {
        let vec: Vec<i32> = vec![1, 2, 3, 4, 5];
        let mut buffer = TensorStorage::<_, CpuAllocator>::from_vec(vec, CpuAllocator);
        let ptr_mut = buffer.as_mut_ptr();
        unsafe {
            *ptr_mut.add(0) = 10;
        }
        assert_eq!(buffer.into_vec(), vec![10, 2, 3, 4, 5]);
        Ok(())
    }
}