use std::{alloc::Layout, ptr::NonNull};

use crate::allocator::TensorAllocator;

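/// Low-level owned storage for tensor data.
///
/// Holds a raw pointer to the buffer, its length in bytes, the [`Layout`] it was
/// allocated with, and the [`TensorAllocator`] that frees it on drop.
///
/// A round-trip sketch (marked `ignore` because the exact re-export path of
/// `CpuAllocator` is assumed here):
///
/// ```ignore
/// // `CpuAllocator` is assumed to be the CPU allocator from this crate's `allocator` module.
/// let storage = TensorStorage::from_vec(vec![1i32, 2, 3], CpuAllocator);
/// assert_eq!(storage.as_slice(), &[1, 2, 3]);
/// assert_eq!(storage.len(), 3 * std::mem::size_of::<i32>());
///
/// let back: Vec<i32> = storage.into_vec();
/// assert_eq!(back, vec![1, 2, 3]);
/// ```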
pub struct TensorStorage<T, A: TensorAllocator> {
    pub(crate) ptr: NonNull<T>,
    pub(crate) len: usize,
    pub(crate) layout: Layout,
    pub(crate) alloc: A,
}

impl<T, A: TensorAllocator> TensorStorage<T, A> {
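    /// Returns a constant raw pointer to the start of the buffer.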
    #[inline]
    pub fn as_ptr(&self) -> *const T {
        self.ptr.as_ptr()
    }

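    /// Returns a mutable raw pointer to the start of the buffer.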
    #[inline]
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.ptr.as_ptr()
    }

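    /// Returns the buffer contents as an immutable slice of `T`; the element count
    /// is derived from the byte length.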
    pub fn as_slice(&self) -> &[T] {
        unsafe { std::slice::from_raw_parts(self.as_ptr(), self.len / std::mem::size_of::<T>()) }
    }

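    /// Returns the buffer contents as a mutable slice of `T`.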
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        unsafe {
            std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.len / std::mem::size_of::<T>())
        }
    }

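    /// Returns the length of the buffer in bytes, not in elements.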
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

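    /// Returns `true` if the buffer holds no data.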
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

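    /// Returns the [`Layout`] the buffer was allocated with.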
    #[inline]
    pub fn layout(&self) -> Layout {
        self.layout
    }

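    /// Returns a reference to the allocator owning the buffer.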
    #[inline]
    pub fn alloc(&self) -> &A {
        &self.alloc
    }

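    /// Creates storage by taking ownership of the memory behind `value`.
    ///
    /// The byte length and layout are captured from the vector, which is then leaked
    /// with `std::mem::forget` so its buffer is not freed here. On drop the buffer is
    /// released through `alloc`, so the allocator is expected to be compatible with
    /// the global allocator that backed the `Vec`.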
    pub fn from_vec(value: Vec<T>, alloc: A) -> Self {
        let ptr = unsafe { NonNull::new_unchecked(value.as_ptr() as _) };
        let len = value.len() * std::mem::size_of::<T>();
        let layout = unsafe { Layout::array::<T>(value.capacity()).unwrap_unchecked() };
        std::mem::forget(value);

        Self {
            ptr,
            len,
            layout,
            alloc,
        }
    }

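    /// Consumes the storage and rebuilds a `Vec` over the same buffer, transferring
    /// ownership back to the caller.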
    pub fn into_vec(self) -> Vec<T> {
        let ptr = self.ptr;
        let vec_capacity = self.layout.size() / std::mem::size_of::<T>();
        let vec_len = self.len / std::mem::size_of::<T>();

        // Skip `Drop`; ownership of the buffer moves to the returned Vec.
        std::mem::forget(self);

        unsafe { Vec::from_raw_parts(ptr.as_ptr(), vec_len, vec_capacity) }
    }
}

impl<T, A: TensorAllocator> From<Vec<T>> for TensorStorage<T, A>
where
    A: Default,
{
    fn from(value: Vec<T>) -> Self {
        Self::from_vec(value, A::default())
    }
}
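// SAFETY: `TensorStorage` uniquely owns the buffer behind `ptr`; these blanket impls
// assume the stored `T` is itself safe to send and share across threads.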
unsafe impl<T, A: TensorAllocator> Send for TensorStorage<T, A> {}
unsafe impl<T, A: TensorAllocator> Sync for TensorStorage<T, A> {}

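// Release the buffer through the same allocator and layout it was created with.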
impl<T, A: TensorAllocator> Drop for TensorStorage<T, A> {
    fn drop(&mut self) {
        self.alloc
            .dealloc(self.ptr.as_ptr() as *mut u8, self.layout);
    }
}
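// Cloning performs a deep copy: the elements are copied into a new buffer owned by a
// clone of the allocator.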
impl<T, A> Clone for TensorStorage<T, A>
where
    T: Clone,
    A: TensorAllocator + 'static,
{
    fn clone(&self) -> Self {
        // Copy the elements into a fresh Vec (capacity is in elements, not bytes),
        // then hand the new buffer to a clone of the allocator.
        let new_vec = self.as_slice().to_vec();
        Self::from_vec(new_vec, self.alloc.clone())
    }
}

#[cfg(test)]
mod tests {

    use super::TensorStorage;
    use crate::allocator::{CpuAllocator, TensorAllocatorError};
    use crate::TensorAllocator;
    use std::alloc::Layout;
    use std::cell::RefCell;
    use std::ptr::NonNull;
    use std::rc::Rc;

    #[test]
    fn test_tensor_buffer_create_raw() -> Result<(), TensorAllocatorError> {
        let size = 8;
        let allocator = CpuAllocator;
        let layout = Layout::array::<u8>(size).map_err(TensorAllocatorError::LayoutError)?;
        let ptr =
            NonNull::new(allocator.alloc(layout)?).ok_or(TensorAllocatorError::NullPointer)?;
        let ptr_raw = ptr.as_ptr();

        let buffer = TensorStorage {
            alloc: allocator,
            len: size * std::mem::size_of::<u8>(),
            layout,
            ptr,
        };

        assert_eq!(buffer.ptr.as_ptr(), ptr_raw);
        assert!(!ptr_raw.is_null());
        assert_eq!(buffer.layout, layout);
        assert_eq!(buffer.len(), size);
        assert!(!buffer.is_empty());
        assert_eq!(buffer.len(), size * std::mem::size_of::<u8>());

        Ok(())
    }

    #[test]
    fn test_tensor_buffer_ptr() -> Result<(), TensorAllocatorError> {
        let size = 8;
        let allocator = CpuAllocator;
        let layout = Layout::array::<u8>(size).map_err(TensorAllocatorError::LayoutError)?;
        let ptr =
            NonNull::new(allocator.alloc(layout)?).ok_or(TensorAllocatorError::NullPointer)?;

        let ptr_raw = ptr.as_ptr() as usize;
        let alignment = std::mem::align_of::<u8>();
        assert_eq!(ptr_raw % alignment, 0);

        // Release the raw allocation; no TensorStorage took ownership of it in this test.
        allocator.dealloc(ptr.as_ptr(), layout);

        Ok(())
    }

    #[test]
    fn test_tensor_buffer_create_f32() -> Result<(), TensorAllocatorError> {
        let size = 8;
        let allocator = CpuAllocator;
        let layout = Layout::array::<f32>(size).map_err(TensorAllocatorError::LayoutError)?;
        let ptr =
            NonNull::new(allocator.alloc(layout)?).ok_or(TensorAllocatorError::NullPointer)?;

        let buffer = TensorStorage {
            alloc: allocator,
            // `len` is stored in bytes, matching the convention used by `from_vec`.
            len: size * std::mem::size_of::<f32>(),
            layout,
            ptr: ptr.cast::<f32>(),
        };

        assert_eq!(buffer.as_ptr(), ptr.as_ptr() as *const f32);
        assert_eq!(buffer.layout, layout);
        assert_eq!(buffer.len(), size * std::mem::size_of::<f32>());

        Ok(())
    }

    #[test]
    fn test_tensor_buffer_lifecycle() -> Result<(), TensorAllocatorError> {
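        // Allocator that tracks the net number of bytes it has allocated, so the
        // test can observe whether TensorStorage routes (de)allocations through it.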
        #[derive(Clone)]
        struct TestAllocator {
            bytes_allocated: Rc<RefCell<i32>>,
        }

        impl TensorAllocator for TestAllocator {
            fn alloc(&self, layout: Layout) -> Result<*mut u8, TensorAllocatorError> {
                *self.bytes_allocated.borrow_mut() += layout.size() as i32;
                CpuAllocator.alloc(layout)
            }
            fn dealloc(&self, ptr: *mut u8, layout: Layout) {
                *self.bytes_allocated.borrow_mut() -= layout.size() as i32;
                CpuAllocator.dealloc(ptr, layout)
            }
        }

        let allocator = TestAllocator {
            bytes_allocated: Rc::new(RefCell::new(0)),
        };
        assert_eq!(*allocator.bytes_allocated.borrow(), 0);

        let size = 1024;

        {
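            // The storage below wraps memory allocated by `Vec` (the global allocator),
            // so the TestAllocator counter is expected to stay at zero throughout.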
            let vec = Vec::<u8>::with_capacity(size);
            let vec_ptr = vec.as_ptr();
            let vec_capacity = vec.capacity();

            let buffer = TensorStorage::from_vec(vec, allocator.clone());
            assert_eq!(*allocator.bytes_allocated.borrow(), 0);

            let result_vec = buffer.into_vec();
            assert_eq!(*allocator.bytes_allocated.borrow(), 0);

            assert_eq!(result_vec.capacity(), vec_capacity);
            assert!(std::ptr::eq(result_vec.as_ptr(), vec_ptr));
        }
        assert_eq!(*allocator.bytes_allocated.borrow(), 0);

        Ok(())
    }

    #[test]
    fn test_tensor_buffer_from_vec() -> Result<(), TensorAllocatorError> {
        let vec: Vec<i32> = vec![1, 2, 3, 4, 5];
        let vec_ptr = vec.as_ptr();
        let vec_len = vec.len();

        let buffer = TensorStorage::<_, CpuAllocator>::from_vec(vec, CpuAllocator);

        let buffer_ptr = buffer.as_ptr();
        assert!(std::ptr::eq(buffer_ptr, vec_ptr));

        let buffer_ptr = buffer.as_ptr() as usize;
        let alignment = std::mem::align_of::<i32>();
        assert_eq!(buffer_ptr % alignment, 0);

        let data = buffer.as_slice();
        assert_eq!(data.len(), vec_len);
        assert_eq!(data[0], 1);
        assert_eq!(data[1], 2);
        assert_eq!(data[2], 3);
        assert_eq!(data[3], 4);
        assert_eq!(data[4], 5);

        assert_eq!(data.first(), Some(&1));
        assert_eq!(data.get(1), Some(&2));
        assert_eq!(data.get(2), Some(&3));
        assert_eq!(data.get(3), Some(&4));
        assert_eq!(data.get(4), Some(&5));
        assert_eq!(data.get(5), None);

        unsafe {
            assert_eq!(data.get_unchecked(0), &1);
            assert_eq!(data.get_unchecked(1), &2);
            assert_eq!(data.get_unchecked(2), &3);
            assert_eq!(data.get_unchecked(3), &4);
            assert_eq!(data.get_unchecked(4), &5);
        }

        Ok(())
    }

    #[test]
    fn test_tensor_buffer_into_vec() -> Result<(), TensorAllocatorError> {
        let vec: Vec<i32> = vec![1, 2, 3, 4, 5];
        let vec_ptr = vec.as_ptr();
        let vec_cap = vec.capacity();

        let buffer = TensorStorage::<_, CpuAllocator>::from_vec(vec, CpuAllocator);

        let result_vec = buffer.into_vec();

        assert_eq!(result_vec.capacity(), vec_cap);
        assert!(std::ptr::eq(result_vec.as_ptr(), vec_ptr));

        Ok(())
    }

    #[test]
    fn test_tensor_mutability() -> Result<(), TensorAllocatorError> {
        let vec: Vec<i32> = vec![1, 2, 3, 4, 5];
        let mut buffer = TensorStorage::<_, CpuAllocator>::from_vec(vec, CpuAllocator);
        let ptr_mut = buffer.as_mut_ptr();
        unsafe {
            *ptr_mut.add(0) = 10;
        }
        assert_eq!(buffer.into_vec(), vec![10, 2, 3, 4, 5]);
        Ok(())
    }
}