Skip to main content

axonml_core/
storage.rs

//! Storage - Raw Memory Management for Tensors
//!
//! Provides efficient memory storage that underlies all tensor operations.
//! Storage is reference-counted for efficient sharing between tensor views.
//!
//! # Key Features
//! - Reference-counted memory for efficient views
//! - Device-agnostic storage interface (CPU and GPU)
//! - Zero-copy slicing through offset/length
//! - Automatic memory cleanup
//!
//! @version 0.2.0
//! @author `AutomataNexus` Development Team
14
15use core::ops::{Deref, DerefMut};
16use std::sync::Arc;
17
18use parking_lot::RwLock;
19
20use crate::device::Device;
21use crate::dtype::Scalar;
22use crate::error::{Error, Result};
23
24#[cfg(feature = "cuda")]
25use cudarc::driver::CudaSlice;
26#[cfg(feature = "cuda")]
27use cudarc::driver::safe::DeviceSlice;
28
// =============================================================================
// Storage Data Enum
// =============================================================================

/// Wrapper around CudaSlice that returns memory to the pool on drop
/// instead of calling cudaFree.
#[cfg(feature = "cuda")]
pub struct PooledCudaSlice {
    /// The wrapped device buffer; `None` only after `Drop` has taken it.
    slice: Option<CudaSlice<f32>>,
    /// When true, `Drop` hands the buffer back to the CUDA memory pool
    /// instead of letting the normal `CudaSlice` drop call cudaFree.
    pool_managed: bool,
}
40
#[cfg(feature = "cuda")]
impl std::fmt::Debug for PooledCudaSlice {
    /// Debug output shows pool ownership and the buffer length
    /// (length is `None` once the slice has been taken by `Drop`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let len = self.slice.as_ref().map(|s| s.len());
        f.debug_struct("PooledCudaSlice")
            .field("pool_managed", &self.pool_managed)
            .field("len", &len)
            .finish()
    }
}
50
#[cfg(feature = "cuda")]
impl Drop for PooledCudaSlice {
    fn drop(&mut self) {
        match self.slice.take() {
            // Pool-managed: return the buffer for reuse instead of freeing it.
            Some(taken) if self.pool_managed => {
                crate::backends::cuda_pool::pool_free(taken);
            }
            // `Some(_)` drops here via the regular CudaSlice drop (cudaFree);
            // `None` means the slice was already moved out.
            _ => {}
        }
    }
}
62
#[cfg(feature = "cuda")]
impl PooledCudaSlice {
    /// Wraps a CudaSlice, recording whether the pool owns its lifetime.
    pub fn new(inner: CudaSlice<f32>, pool_managed: bool) -> Self {
        PooledCudaSlice {
            slice: Some(inner),
            pool_managed,
        }
    }

    /// Borrows the underlying CudaSlice.
    ///
    /// # Panics
    /// Panics if the slice was already taken (only possible during drop).
    pub fn slice(&self) -> &CudaSlice<f32> {
        match self.slice.as_ref() {
            Some(s) => s,
            None => panic!("CudaSlice already taken"),
        }
    }

    /// Mutably borrows the underlying CudaSlice.
    ///
    /// # Panics
    /// Panics if the slice was already taken (only possible during drop).
    pub fn slice_mut(&mut self) -> &mut CudaSlice<f32> {
        match self.slice.as_mut() {
            Some(s) => s,
            None => panic!("CudaSlice already taken"),
        }
    }
}
80
/// Holds either CPU or GPU data.
///
/// NOTE(review): the `Cuda` variant always carries f32 data regardless of the
/// generic `T` — only the `Storage<f32>` impls construct it.
#[derive(Debug)]
enum StorageData<T: Scalar> {
    /// CPU data stored as a Vec.
    Cpu(Vec<T>),
    /// GPU data stored as a PooledCudaSlice (f32 only on GPU).
    /// Returns to memory pool on drop instead of calling cudaFree.
    #[cfg(feature = "cuda")]
    Cuda(PooledCudaSlice),
}
91
// =============================================================================
// Storage Struct
// =============================================================================

/// Raw memory storage for tensor data.
///
/// Storage is the fundamental building block for tensors. It manages a contiguous
/// block of memory on a specific device and is reference-counted to allow
/// efficient sharing between tensor views.
#[derive(Debug)]
pub struct Storage<T: Scalar> {
    /// The underlying data buffer, shared between all views of this storage.
    inner: Arc<RwLock<StorageInner<T>>>,
    /// Offset into the storage (for views).
    offset: usize,
    /// Number of elements in this view (not the full buffer length).
    len: usize,
}
110
/// Inner storage data that can be shared between views.
#[derive(Debug)]
struct StorageInner<T: Scalar> {
    /// Raw data (CPU or GPU).
    data: StorageData<T>,
    /// The device this storage resides on.
    device: Device,
}
119
120impl<T: Scalar> Storage<T> {
121    /// Creates new storage with the given capacity, initialized to zero.
122    #[must_use]
123    pub fn zeros(len: usize, device: Device) -> Self {
124        let data = vec![T::zeroed(); len];
125        Self::from_vec(data, device)
126    }
127
128    /// Creates storage from an existing vector (always on CPU).
129    #[must_use]
130    pub fn from_vec(data: Vec<T>, device: Device) -> Self {
131        let len = data.len();
132        Self {
133            inner: Arc::new(RwLock::new(StorageInner {
134                data: StorageData::Cpu(data),
135                device,
136            })),
137            offset: 0,
138            len,
139        }
140    }
141
142    /// Creates storage from a slice by copying the data.
143    #[must_use]
144    pub fn from_slice(data: &[T], device: Device) -> Self {
145        Self::from_vec(data.to_vec(), device)
146    }
147
148    /// Returns the number of elements in this storage view.
149    #[must_use]
150    pub const fn len(&self) -> usize {
151        self.len
152    }
153
154    /// Returns true if the storage is empty.
155    #[must_use]
156    pub const fn is_empty(&self) -> bool {
157        self.len == 0
158    }
159
160    /// Returns the offset into the underlying buffer.
161    #[must_use]
162    pub const fn offset(&self) -> usize {
163        self.offset
164    }
165
166    /// Returns the device this storage is on.
167    #[must_use]
168    pub fn device(&self) -> Device {
169        self.inner.read().device
170    }
171
172    /// Returns true if data is on CPU.
173    #[must_use]
174    pub fn is_cpu(&self) -> bool {
175        matches!(self.inner.read().data, StorageData::Cpu(_))
176    }
177
178    /// Returns true if data is on GPU.
179    #[must_use]
180    pub fn is_gpu(&self) -> bool {
181        !self.is_cpu()
182    }
183
184    /// Returns the size in bytes of this storage.
185    #[must_use]
186    pub fn size_bytes(&self) -> usize {
187        self.len * core::mem::size_of::<T>()
188    }
189
190    /// Creates a view into a portion of this storage.
191    pub fn slice(&self, offset: usize, len: usize) -> Result<Self> {
192        if offset + len > self.len {
193            return Err(Error::IndexOutOfBounds {
194                index: offset + len,
195                size: self.len,
196            });
197        }
198
199        Ok(Self {
200            inner: Arc::clone(&self.inner),
201            offset: self.offset + offset,
202            len,
203        })
204    }
205
206    /// Returns true if this storage is uniquely owned (not shared).
207    #[must_use]
208    pub fn is_unique(&self) -> bool {
209        Arc::strong_count(&self.inner) == 1
210    }
211
212    /// Returns an immutable reference to the CPU data.
213    ///
214    /// # Panics
215    /// Panics if the storage is on GPU. Use `to_vec()` for device-safe access.
216    #[must_use]
217    pub fn as_slice(&self) -> StorageReadGuard<'_, T> {
218        StorageReadGuard {
219            guard: self.inner.read(),
220            offset: self.offset,
221            len: self.len,
222        }
223    }
224
225    /// Returns a mutable reference to the CPU data.
226    ///
227    /// # Panics
228    /// Panics if the storage is on GPU.
229    #[must_use]
230    pub fn as_slice_mut(&self) -> StorageWriteGuard<'_, T> {
231        StorageWriteGuard {
232            guard: self.inner.write(),
233            offset: self.offset,
234            len: self.len,
235        }
236    }
237
238    /// Copies data from another storage into this one.
239    pub fn copy_from(&self, other: &Self) -> Result<()> {
240        if self.len != other.len {
241            return Err(Error::shape_mismatch(&[self.len], &[other.len]));
242        }
243
244        let src = other.as_slice();
245        let mut dst = self.as_slice_mut();
246        dst.copy_from_slice(&src);
247        Ok(())
248    }
249
250    /// Makes a deep copy of this storage.
251    ///
252    /// For GPU storage, this only works for `Storage<f32>`.
253    /// Other types will panic on GPU storage.
254    #[must_use]
255    pub fn deep_copy(&self) -> Self {
256        let inner = self.inner.read();
257        match &inner.data {
258            StorageData::Cpu(cpu_data) => {
259                let data = cpu_data[self.offset..self.offset + self.len].to_vec();
260                Self::from_vec(data, inner.device)
261            }
262            #[cfg(feature = "cuda")]
263            StorageData::Cuda(_) => {
264                panic!("deep_copy() on GPU storage requires Storage<f32>. Use deep_copy_f32().");
265            }
266        }
267    }
268
269    /// Returns the data as a Vec from CPU storage.
270    ///
271    /// For GPU-resident f32 tensors, use the `Tensor::to_vec()` method which
272    /// handles device-aware copies via `Storage<f32>::to_vec_f32()`.
273    ///
274    /// # Panics
275    /// Panics if storage is on GPU (use `to_vec_f32()` on `Storage<f32>` instead).
276    pub fn to_vec(&self) -> Vec<T> {
277        let inner = self.inner.read();
278        match &inner.data {
279            StorageData::Cpu(cpu_data) => {
280                cpu_data[self.offset..self.offset + self.len].to_vec()
281            }
282            #[cfg(feature = "cuda")]
283            StorageData::Cuda(_) => {
284                panic!("Cannot call to_vec() on GPU storage for generic T. Use to_vec_f32() on Storage<f32>.");
285            }
286        }
287    }
288
289    /// Transfers this storage to a different device.
290    ///
291    /// For GPU transfers, only `Storage<f32>` is supported. Use the
292    /// `Storage<f32>::to_device()` specialization for CPU↔GPU transfers.
293    pub fn to_device(&self, device: Device) -> Result<Self> {
294        if self.device() == device {
295            return Ok(self.clone());
296        }
297
298        // Generic path: only CPU→CPU is supported
299        if device.is_cpu() && self.device().is_cpu() {
300            return Ok(self.deep_copy());
301        }
302
303        Err(Error::DeviceNotAvailable { device })
304    }
305}
306
// =============================================================================
// f32-specific Storage for GPU transfers
// =============================================================================

#[cfg(feature = "cuda")]
impl Storage<f32> {
    /// Transfers f32 storage between CPU and GPU.
    ///
    /// Returns a shallow clone when already on the target device.
    ///
    /// # Errors
    /// Returns `Error::DeviceNotAvailable` when no CUDA backend is available
    /// or a host/device copy fails.
    pub fn to_device_f32(&self, device: Device) -> Result<Self> {
        if self.device() == device {
            return Ok(self.clone());
        }

        let inner = self.inner.read();

        match (&inner.data, device) {
            // CPU → CPU: just deep copy
            (StorageData::Cpu(_), Device::Cpu) => {
                // Release the read guard first: deep_copy() re-acquires the
                // lock, and parking_lot's RwLock is not reentrant.
                drop(inner);
                Ok(self.deep_copy())
            }
            // CPU → GPU: htod_copy
            (StorageData::Cpu(cpu_data), Device::Cuda(_idx)) => {
                let backend = crate::backends::cuda::get_cuda_backend()
                    .ok_or(Error::DeviceNotAvailable { device })?;
                // Upload only this view's window.
                let slice = &cpu_data[self.offset..self.offset + self.len];
                let cuda_slice = backend.htod_copy(slice)
                    .map_err(|_| Error::DeviceNotAvailable { device })?;
                let len = self.len;
                // pool_managed = false: this slice came from htod_copy, not the
                // pool, so it must be freed via cudaFree on drop.
                Ok(Self {
                    inner: Arc::new(RwLock::new(StorageInner {
                        data: StorageData::Cuda(PooledCudaSlice::new(cuda_slice, false)),
                        device,
                    })),
                    offset: 0,
                    len,
                })
            }
            // GPU → CPU: dtoh_copy
            (StorageData::Cuda(pooled), Device::Cpu) => {
                // NOTE(review): `self.device()` here re-acquires the inner read
                // lock while `inner` is still held; parking_lot reads are not
                // recursive-safe when a writer is queued — confirm this cannot
                // deadlock under contention.
                let backend = crate::backends::cuda::get_cuda_backend()
                    .ok_or(Error::DeviceNotAvailable { device: self.device() })?;
                let full_vec = backend.dtoh_copy(pooled.slice())
                    .map_err(|_| Error::DeviceNotAvailable { device })?;
                let end = self.offset + self.len;
                let sliced: Vec<f32> = if self.offset == 0 && self.len == full_vec.len() {
                    // Fast path: the view covers the whole buffer exactly.
                    full_vec
                } else if end <= full_vec.len() {
                    full_vec[self.offset..end].to_vec()
                } else {
                    // CudaSlice is smaller than Storage.len — this indicates a bug
                    // but handle gracefully: copy what we have, zero-pad the rest
                    eprintln!(
                        "[storage] WARNING: CudaSlice len={} < Storage offset+len={} (offset={}, len={})",
                        full_vec.len(), end, self.offset, self.len
                    );
                    let available = if self.offset < full_vec.len() {
                        full_vec.len() - self.offset
                    } else {
                        0
                    };
                    let mut result = vec![0.0f32; self.len];
                    if available > 0 {
                        result[..available].copy_from_slice(&full_vec[self.offset..self.offset + available]);
                    }
                    result
                };
                Ok(Self::from_vec(sliced, Device::Cpu))
            }
            // GPU → GPU: D2H then H2D (simple path)
            (StorageData::Cuda(_), Device::Cuda(_)) => {
                // Release the read guard before the recursive calls below,
                // which take the lock again.
                drop(inner);
                let cpu_storage = self.to_device_f32(Device::Cpu)?;
                cpu_storage.to_device_f32(device)
            }
            _ => Err(Error::DeviceNotAvailable { device }),
        }
    }
}
385
// =============================================================================
// CUDA-specific Storage methods
// =============================================================================

#[cfg(feature = "cuda")]
impl Storage<f32> {
    /// Returns data as a Vec<f32>, performing D2H copy if on GPU.
    ///
    /// Best-effort on GPU: if no CUDA backend is available or the copy fails,
    /// a zero-filled vector of the view's length is returned. If the device
    /// buffer is shorter than `offset + len` (indicates a bug elsewhere), the
    /// available prefix is copied and the remainder zero-padded — consistent
    /// with the graceful handling in `to_device_f32`.
    pub fn to_vec_f32(&self) -> Vec<f32> {
        let inner = self.inner.read();
        match &inner.data {
            StorageData::Cpu(cpu_data) => {
                cpu_data[self.offset..self.offset + self.len].to_vec()
            }
            StorageData::Cuda(pooled) => {
                if let Some(backend) = crate::backends::cuda::get_cuda_backend() {
                    if let Ok(full_vec) = backend.dtoh_copy(pooled.slice()) {
                        if self.offset == 0 && self.len == full_vec.len() {
                            return full_vec;
                        }
                        let end = self.offset + self.len;
                        if end <= full_vec.len() {
                            return full_vec[self.offset..end].to_vec();
                        }
                        // Device buffer shorter than the view: copy what exists
                        // and zero-pad instead of panicking on the slice
                        // (mirrors to_device_f32's fallback path).
                        let available = full_vec.len().saturating_sub(self.offset);
                        let mut result = vec![0.0f32; self.len];
                        if available > 0 {
                            result[..available].copy_from_slice(&full_vec[self.offset..]);
                        }
                        return result;
                    }
                }
                vec![0.0f32; self.len]
            }
        }
    }

    /// Deep copy that works for both CPU and GPU f32 storage.
    ///
    /// On GPU, round-trips through host memory (D2H then H2D). Falls back to
    /// a CPU-resident copy if the H2D upload fails.
    pub fn deep_copy_f32(&self) -> Self {
        let device = self.device();
        let vec = self.to_vec_f32();
        if device.is_gpu() {
            if let Some(backend) = crate::backends::cuda::get_cuda_backend() {
                if let Ok(new_slice) = backend.htod_copy(&vec) {
                    // htod_copy allocations are not pool-managed.
                    return Self::from_cuda_slice_unmanaged(new_slice, self.len, device);
                }
            }
        }
        Self::from_vec(vec, device)
    }

    /// Creates storage from a pool-allocated CudaSlice.
    ///
    /// The slice will be returned to the CUDA memory pool on drop.
    /// Only use this for CudaSlices from `pool_alloc()` — the slice must be
    /// bucket-sized for correct pool reuse.
    pub fn from_cuda_slice(slice: CudaSlice<f32>, len: usize, device: Device) -> Self {
        Self {
            inner: Arc::new(RwLock::new(StorageInner {
                data: StorageData::Cuda(PooledCudaSlice::new(slice, true)),
                device,
            })),
            offset: 0,
            len,
        }
    }

    /// Creates storage from a non-pool CudaSlice (e.g., from htod_copy).
    ///
    /// The slice will be freed via cudaFree on drop (normal CUDA deallocation),
    /// NOT returned to the memory pool.
    pub fn from_cuda_slice_unmanaged(slice: CudaSlice<f32>, len: usize, device: Device) -> Self {
        Self {
            inner: Arc::new(RwLock::new(StorageInner {
                data: StorageData::Cuda(PooledCudaSlice::new(slice, false)),
                device,
            })),
            offset: 0,
            len,
        }
    }

    /// Returns a read guard for the underlying `CudaSlice<f32>`.
    ///
    /// # Panics
    /// Calling `slice()` on the guard panics if storage is not on GPU.
    pub fn as_cuda_slice(&self) -> CudaSliceReadGuard<'_> {
        CudaSliceReadGuard {
            guard: self.inner.read(),
        }
    }

    /// Returns a write guard providing mutable access to the underlying `CudaSlice<f32>`.
    ///
    /// # Panics
    /// Calling `slice_mut()` on the guard panics if storage is not on GPU.
    pub fn as_cuda_slice_mut(&self) -> CudaSliceWriteGuard<'_> {
        CudaSliceWriteGuard {
            guard: self.inner.write(),
        }
    }
}
479
/// Read guard that provides access to the CudaSlice.
#[cfg(feature = "cuda")]
pub struct CudaSliceReadGuard<'a> {
    guard: parking_lot::RwLockReadGuard<'a, StorageInner<f32>>,
}

#[cfg(feature = "cuda")]
impl<'a> CudaSliceReadGuard<'a> {
    /// Returns a reference to the CudaSlice.
    ///
    /// # Panics
    /// Panics if storage is CPU.
    pub fn slice(&self) -> &CudaSlice<f32> {
        if let StorageData::Cuda(pooled) = &self.guard.data {
            pooled.slice()
        } else {
            panic!("Storage is on CPU, not GPU")
        }
    }
}
499
/// Write guard that provides mutable access to the CudaSlice.
#[cfg(feature = "cuda")]
pub struct CudaSliceWriteGuard<'a> {
    guard: parking_lot::RwLockWriteGuard<'a, StorageInner<f32>>,
}

#[cfg(feature = "cuda")]
impl<'a> CudaSliceWriteGuard<'a> {
    /// Returns a mutable reference to the CudaSlice.
    ///
    /// # Panics
    /// Panics if storage is CPU.
    pub fn slice_mut(&mut self) -> &mut CudaSlice<f32> {
        if let StorageData::Cuda(pooled) = &mut self.guard.data {
            pooled.slice_mut()
        } else {
            panic!("Storage is on CPU, not GPU")
        }
    }
}
519
520impl<T: Scalar> Clone for Storage<T> {
521    fn clone(&self) -> Self {
522        Self {
523            inner: Arc::clone(&self.inner),
524            offset: self.offset,
525            len: self.len,
526        }
527    }
528}
529
// =============================================================================
// Guard Types for Safe Access
// =============================================================================

/// Read guard for storage data.
pub struct StorageReadGuard<'a, T: Scalar> {
    /// Held read lock on the shared inner storage.
    guard: parking_lot::RwLockReadGuard<'a, StorageInner<T>>,
    /// View window start within the underlying buffer.
    offset: usize,
    /// View window length.
    len: usize,
}
540
541impl<T: Scalar> Deref for StorageReadGuard<'_, T> {
542    type Target = [T];
543
544    fn deref(&self) -> &Self::Target {
545        match &self.guard.data {
546            StorageData::Cpu(data) => &data[self.offset..self.offset + self.len],
547            #[cfg(feature = "cuda")]
548            StorageData::Cuda(_) => panic!("Cannot access GPU storage as CPU slice. Use to_vec() for device-safe access."),
549        }
550    }
551}
552
/// Write guard for storage data.
pub struct StorageWriteGuard<'a, T: Scalar> {
    /// Held write lock on the shared inner storage.
    guard: parking_lot::RwLockWriteGuard<'a, StorageInner<T>>,
    /// View window start within the underlying buffer.
    offset: usize,
    /// View window length.
    len: usize,
}
559
560impl<T: Scalar> Deref for StorageWriteGuard<'_, T> {
561    type Target = [T];
562
563    fn deref(&self) -> &Self::Target {
564        match &self.guard.data {
565            StorageData::Cpu(data) => &data[self.offset..self.offset + self.len],
566            #[cfg(feature = "cuda")]
567            StorageData::Cuda(_) => panic!("Cannot access GPU storage as CPU slice."),
568        }
569    }
570}
571
572impl<T: Scalar> DerefMut for StorageWriteGuard<'_, T> {
573    fn deref_mut(&mut self) -> &mut Self::Target {
574        match &mut self.guard.data {
575            StorageData::Cpu(data) => &mut data[self.offset..self.offset + self.len],
576            #[cfg(feature = "cuda")]
577            StorageData::Cuda(_) => panic!("Cannot access GPU storage as mutable CPU slice."),
578        }
579    }
580}
581
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_storage_zeros() {
        // A freshly zeroed buffer reports its length and holds only 0.0.
        let zeroed = Storage::<f32>::zeros(10, Device::Cpu);
        assert_eq!(zeroed.len(), 10);
        assert!(!zeroed.is_empty());
        assert!(zeroed.as_slice().iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_storage_from_vec() {
        let source = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let stored = Storage::from_vec(source.clone(), Device::Cpu);
        assert_eq!(&*stored.as_slice(), &source[..]);
    }

    #[test]
    fn test_storage_slice() {
        let stored = Storage::from_vec(vec![1.0_f32, 2.0, 3.0, 4.0, 5.0], Device::Cpu);
        let view = stored.slice(1, 3).unwrap();

        assert_eq!(view.len(), 3);
        assert_eq!(&*view.as_slice(), &[2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_storage_clone_shares() {
        // Clones share the same buffer, so neither handle is unique.
        let first = Storage::<f32>::zeros(10, Device::Cpu);
        let second = first.clone();

        assert!(!first.is_unique());
        assert!(!second.is_unique());
    }

    #[test]
    fn test_storage_deep_copy() {
        let original = Storage::from_vec(vec![1.0_f32, 2.0, 3.0], Device::Cpu);
        let copy = original.deep_copy();

        assert!(original.is_unique());
        assert!(copy.is_unique());

        // Writing through the copy must not affect the original.
        copy.as_slice_mut()[0] = 99.0;
        assert_eq!(original.as_slice()[0], 1.0);
    }

    #[test]
    fn test_storage_copy_from() {
        let source = Storage::from_vec(vec![1.0_f32, 2.0, 3.0], Device::Cpu);
        let target = Storage::<f32>::zeros(3, Device::Cpu);

        target.copy_from(&source).unwrap();

        assert_eq!(&*target.as_slice(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_storage_slice_out_of_bounds() {
        // offset 5 + len 10 exceeds the 10-element buffer.
        let stored = Storage::<f32>::zeros(10, Device::Cpu);
        assert!(stored.slice(5, 10).is_err());
    }

    #[test]
    fn test_storage_to_vec_cpu() {
        let stored = Storage::from_vec(vec![1.0_f32, 2.0, 3.0], Device::Cpu);
        assert_eq!(stored.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_storage_is_cpu() {
        let stored = Storage::from_vec(vec![1.0_f32], Device::Cpu);
        assert!(stored.is_cpu());
        assert!(!stored.is_gpu());
    }
}