1//! Storage - Raw Memory Management for Tensors
2//!
3//! # File
4//! `crates/axonml-core/src/storage.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use core::ops::{Deref, DerefMut};
18use std::sync::Arc;
19
20use parking_lot::RwLock;
21
22use crate::device::Device;
23use crate::dtype::Scalar;
24use crate::error::{Error, Result};
25
26#[cfg(feature = "cuda")]
27use cudarc::driver::CudaSlice;
28#[cfg(feature = "cuda")]
29use cudarc::driver::DeviceSlice;
30
// =============================================================================
// Storage Data Enum
// =============================================================================

/// Wrapper around CudaSlice that returns memory to the pool on drop
/// instead of calling cudaFree.
#[cfg(feature = "cuda")]
pub struct PooledCudaSlice {
    /// The wrapped device allocation. `None` only transiently, after `Drop`
    /// has taken the slice out (or if a taken accessor ever panicked).
    slice: Option<CudaSlice<f32>>,
    /// When true, `Drop` hands the slice back to `cuda_pool::pool_free`;
    /// when false, the slice drops normally (cudaFree).
    pool_managed: bool,
}
42
#[cfg(feature = "cuda")]
impl std::fmt::Debug for PooledCudaSlice {
    /// Formats the wrapper without touching device memory: reports the
    /// pool-managed flag and the element count (`None` once taken).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let len = self.slice.as_ref().map(|s| s.len());
        let mut dbg = f.debug_struct("PooledCudaSlice");
        dbg.field("pool_managed", &self.pool_managed);
        dbg.field("len", &len);
        dbg.finish()
    }
}
52
#[cfg(feature = "cuda")]
impl Drop for PooledCudaSlice {
    /// Routes the allocation to the right deallocator exactly once.
    fn drop(&mut self) {
        match self.slice.take() {
            // Pool-managed buffers are cached for reuse instead of freed.
            Some(slice) if self.pool_managed => {
                crate::backends::cuda_pool::pool_free(slice);
            }
            // Unmanaged buffers fall out of scope here; the normal
            // CudaSlice drop calls cudaFree.
            Some(slice) => drop(slice),
            // Already taken (shouldn't happen outside Drop itself).
            None => {}
        }
    }
}
64
#[cfg(feature = "cuda")]
impl PooledCudaSlice {
    /// Wraps a CudaSlice, recording whether it should be returned to the
    /// CUDA memory pool on drop (`pool_managed == true`) or freed normally.
    pub fn new(slice: CudaSlice<f32>, pool_managed: bool) -> Self {
        let slice = Some(slice);
        Self { slice, pool_managed }
    }

    /// Borrows the underlying CudaSlice.
    ///
    /// # Panics
    /// Panics if the slice has already been taken (post-drop state).
    pub fn slice(&self) -> &CudaSlice<f32> {
        match &self.slice {
            Some(s) => s,
            None => panic!("CudaSlice already taken"),
        }
    }

    /// Mutably borrows the underlying CudaSlice.
    ///
    /// # Panics
    /// Panics if the slice has already been taken (post-drop state).
    pub fn slice_mut(&mut self) -> &mut CudaSlice<f32> {
        match &mut self.slice {
            Some(s) => s,
            None => panic!("CudaSlice already taken"),
        }
    }
}
85
/// Holds either CPU or GPU data.
#[derive(Debug)]
enum StorageData<T: Scalar> {
    /// CPU data stored as a contiguous Vec in host memory.
    Cpu(Vec<T>),
    /// GPU data stored as a PooledCudaSlice (f32 only on GPU).
    /// Returns to memory pool on drop instead of calling cudaFree.
    #[cfg(feature = "cuda")]
    Cuda(PooledCudaSlice),
}
96
// =============================================================================
// Storage Struct
// =============================================================================

/// Raw memory storage for tensor data.
///
/// Storage is the fundamental building block for tensors. It manages a contiguous
/// block of memory on a specific device and is reference-counted to allow
/// efficient sharing between tensor views.
#[derive(Debug)]
pub struct Storage<T: Scalar> {
    /// The underlying data buffer, shared between views via `Arc` and
    /// guarded by a `parking_lot::RwLock` for interior mutability.
    inner: Arc<RwLock<StorageInner<T>>>,
    /// Offset (in elements) of this view into the shared buffer.
    offset: usize,
    /// Number of elements visible through this view.
    len: usize,
}
115
/// Inner storage data that can be shared between views.
#[derive(Debug)]
struct StorageInner<T: Scalar> {
    /// Raw data (CPU or GPU).
    data: StorageData<T>,
    /// The device this storage resides on.
    device: Device,
}
124
125impl<T: Scalar> Storage<T> {
126    /// Creates new storage with the given capacity, initialized to zero.
127    #[must_use]
128    pub fn zeros(len: usize, device: Device) -> Self {
129        let data = vec![T::zeroed(); len];
130        Self::from_vec(data, device)
131    }
132
133    /// Creates storage from an existing vector (always on CPU).
134    #[must_use]
135    pub fn from_vec(data: Vec<T>, device: Device) -> Self {
136        let len = data.len();
137        Self {
138            inner: Arc::new(RwLock::new(StorageInner {
139                data: StorageData::Cpu(data),
140                device,
141            })),
142            offset: 0,
143            len,
144        }
145    }
146
147    /// Creates storage from a slice by copying the data.
148    #[must_use]
149    pub fn from_slice(data: &[T], device: Device) -> Self {
150        Self::from_vec(data.to_vec(), device)
151    }
152
153    /// Returns the number of elements in this storage view.
154    #[must_use]
155    pub const fn len(&self) -> usize {
156        self.len
157    }
158
159    /// Returns true if the storage is empty.
160    #[must_use]
161    pub const fn is_empty(&self) -> bool {
162        self.len == 0
163    }
164
165    /// Returns the offset into the underlying buffer.
166    #[must_use]
167    pub const fn offset(&self) -> usize {
168        self.offset
169    }
170
171    /// Returns the device this storage is on.
172    #[must_use]
173    pub fn device(&self) -> Device {
174        self.inner.read().device
175    }
176
177    /// Returns true if data is on CPU.
178    #[must_use]
179    pub fn is_cpu(&self) -> bool {
180        matches!(self.inner.read().data, StorageData::Cpu(_))
181    }
182
183    /// Returns true if data is on GPU.
184    #[must_use]
185    pub fn is_gpu(&self) -> bool {
186        !self.is_cpu()
187    }
188
189    /// Returns the size in bytes of this storage.
190    #[must_use]
191    pub fn size_bytes(&self) -> usize {
192        self.len * core::mem::size_of::<T>()
193    }
194
195    /// Creates a view into a portion of this storage.
196    pub fn slice(&self, offset: usize, len: usize) -> Result<Self> {
197        if offset + len > self.len {
198            return Err(Error::IndexOutOfBounds {
199                index: offset + len,
200                size: self.len,
201            });
202        }
203
204        Ok(Self {
205            inner: Arc::clone(&self.inner),
206            offset: self.offset + offset,
207            len,
208        })
209    }
210
211    /// Returns true if this storage is uniquely owned (not shared).
212    #[must_use]
213    pub fn is_unique(&self) -> bool {
214        Arc::strong_count(&self.inner) == 1
215    }
216
217    /// Returns an immutable reference to the CPU data.
218    ///
219    /// # Panics
220    /// Panics if the storage is on GPU. Use `to_vec()` for device-safe access.
221    #[must_use]
222    pub fn as_slice(&self) -> StorageReadGuard<'_, T> {
223        StorageReadGuard {
224            guard: self.inner.read(),
225            offset: self.offset,
226            len: self.len,
227        }
228    }
229
230    /// Returns a mutable reference to the CPU data.
231    ///
232    /// # Panics
233    /// Panics if the storage is on GPU.
234    #[must_use]
235    pub fn as_slice_mut(&self) -> StorageWriteGuard<'_, T> {
236        StorageWriteGuard {
237            guard: self.inner.write(),
238            offset: self.offset,
239            len: self.len,
240        }
241    }
242
243    /// Copies data from another storage into this one.
244    pub fn copy_from(&self, other: &Self) -> Result<()> {
245        if self.len != other.len {
246            return Err(Error::shape_mismatch(&[self.len], &[other.len]));
247        }
248
249        let src = other.as_slice();
250        let mut dst = self.as_slice_mut();
251        dst.copy_from_slice(&src);
252        Ok(())
253    }
254
255    /// Makes a deep copy of this storage.
256    ///
257    /// For GPU storage, this only works for `Storage<f32>`.
258    /// Other types will panic on GPU storage.
259    #[must_use]
260    pub fn deep_copy(&self) -> Self {
261        let inner = self.inner.read();
262        match &inner.data {
263            StorageData::Cpu(cpu_data) => {
264                let data = cpu_data[self.offset..self.offset + self.len].to_vec();
265                Self::from_vec(data, inner.device)
266            }
267            #[cfg(feature = "cuda")]
268            StorageData::Cuda(_) => {
269                panic!("deep_copy() on GPU storage requires Storage<f32>. Use deep_copy_f32().");
270            }
271        }
272    }
273
274    /// Returns the data as a Vec from CPU storage.
275    ///
276    /// For GPU-resident f32 tensors, use the `Tensor::to_vec()` method which
277    /// handles device-aware copies via `Storage<f32>::to_vec_f32()`.
278    ///
279    /// # Panics
280    /// Panics if storage is on GPU (use `to_vec_f32()` on `Storage<f32>` instead).
281    pub fn to_vec(&self) -> Vec<T> {
282        let inner = self.inner.read();
283        match &inner.data {
284            StorageData::Cpu(cpu_data) => cpu_data[self.offset..self.offset + self.len].to_vec(),
285            #[cfg(feature = "cuda")]
286            StorageData::Cuda(_) => {
287                panic!(
288                    "Cannot call to_vec() on GPU storage for generic T. Use to_vec_f32() on Storage<f32>."
289                );
290            }
291        }
292    }
293
294    /// Transfers this storage to a different device.
295    ///
296    /// For GPU transfers, only `Storage<f32>` is supported. Use the
297    /// `Storage<f32>::to_device()` specialization for CPU↔GPU transfers.
298    pub fn to_device(&self, device: Device) -> Result<Self> {
299        if self.device() == device {
300            return Ok(self.clone());
301        }
302
303        // Generic path: only CPU→CPU is supported
304        if device.is_cpu() && self.device().is_cpu() {
305            return Ok(self.deep_copy());
306        }
307
308        Err(Error::DeviceNotAvailable { device })
309    }
310}
311
// =============================================================================
// f32-specific Storage for GPU transfers
// =============================================================================

#[cfg(feature = "cuda")]
impl Storage<f32> {
    /// Transfers f32 storage between CPU and GPU.
    ///
    /// - Same device: returns a shared (shallow) clone.
    /// - CPU→CPU: deep copy of this view's element range.
    /// - CPU→GPU: host-to-device copy of `[offset, offset + len)`; the new
    ///   CUDA allocation is unmanaged (cudaFree on drop, not pooled).
    /// - GPU→CPU: device-to-host copy of the whole underlying slice, then
    ///   narrowed to this view's range.
    /// - GPU→GPU: staged through host memory (D2H then H2D).
    ///
    /// # Errors
    /// Returns `Error::DeviceNotAvailable` when the CUDA backend is absent
    /// or when a copy fails (the underlying copy error detail is discarded
    /// by the `map_err`).
    pub fn to_device_f32(&self, device: Device) -> Result<Self> {
        if self.device() == device {
            return Ok(self.clone());
        }

        let inner = self.inner.read();

        match (&inner.data, device) {
            // CPU → CPU: just deep copy
            (StorageData::Cpu(_), Device::Cpu) => {
                // Release the read lock first: deep_copy() re-acquires it.
                drop(inner);
                Ok(self.deep_copy())
            }
            // CPU → GPU: htod_copy of only this view's element range
            (StorageData::Cpu(cpu_data), Device::Cuda(_idx)) => {
                let backend = crate::backends::cuda::get_cuda_backend()
                    .ok_or(Error::DeviceNotAvailable { device })?;
                let slice = &cpu_data[self.offset..self.offset + self.len];
                let cuda_slice = backend
                    .htod_copy(slice)
                    .map_err(|_| Error::DeviceNotAvailable { device })?;
                let len = self.len;
                // Fresh inner: the GPU copy starts at offset 0, unshared.
                Ok(Self {
                    inner: Arc::new(RwLock::new(StorageInner {
                        data: StorageData::Cuda(PooledCudaSlice::new(cuda_slice, false)),
                        device,
                    })),
                    offset: 0,
                    len,
                })
            }
            // GPU → CPU: dtoh_copy
            (StorageData::Cuda(pooled), Device::Cpu) => {
                let backend =
                    crate::backends::cuda::get_cuda_backend().ok_or(Error::DeviceNotAvailable {
                        device: self.device(),
                    })?;
                // The whole device buffer is copied down, then sliced to
                // this view's [offset, offset + len) window.
                let full_vec = backend
                    .dtoh_copy(pooled.slice())
                    .map_err(|_| Error::DeviceNotAvailable { device })?;
                let end = self.offset + self.len;
                let sliced: Vec<f32> = if self.offset == 0 && self.len == full_vec.len() {
                    // Exact match: reuse the copied buffer without re-allocating.
                    full_vec
                } else if end <= full_vec.len() {
                    full_vec[self.offset..end].to_vec()
                } else {
                    // CudaSlice is smaller than Storage.len — this indicates a bug
                    // but handle gracefully: copy what we have, zero-pad the rest
                    eprintln!(
                        "[storage] WARNING: CudaSlice len={} < Storage offset+len={} (offset={}, len={})",
                        full_vec.len(),
                        end,
                        self.offset,
                        self.len
                    );
                    let available = if self.offset < full_vec.len() {
                        full_vec.len() - self.offset
                    } else {
                        0
                    };
                    let mut result = vec![0.0f32; self.len];
                    if available > 0 {
                        result[..available]
                            .copy_from_slice(&full_vec[self.offset..self.offset + available]);
                    }
                    result
                };
                Ok(Self::from_vec(sliced, Device::Cpu))
            }
            // GPU → GPU: D2H then H2D (simple path)
            (StorageData::Cuda(_), Device::Cuda(_)) => {
                // Drop the read lock before the recursive calls re-lock inner.
                drop(inner);
                let cpu_storage = self.to_device_f32(Device::Cpu)?;
                cpu_storage.to_device_f32(device)
            }
            _ => Err(Error::DeviceNotAvailable { device }),
        }
    }
}
398
// =============================================================================
// CUDA-specific Storage methods
// =============================================================================

#[cfg(feature = "cuda")]
impl Storage<f32> {
    /// Returns data as a Vec<f32>, performing D2H copy if on GPU.
    ///
    /// NOTE(review): if the CUDA backend is unavailable or the D2H copy
    /// fails, this silently returns a zero-filled vector of the right
    /// length — callers cannot distinguish real zeros from a failed copy.
    /// Confirm this best-effort behavior is intended.
    pub fn to_vec_f32(&self) -> Vec<f32> {
        let inner = self.inner.read();
        match &inner.data {
            StorageData::Cpu(cpu_data) => cpu_data[self.offset..self.offset + self.len].to_vec(),
            StorageData::Cuda(pooled) => {
                if let Some(backend) = crate::backends::cuda::get_cuda_backend() {
                    if let Ok(full_vec) = backend.dtoh_copy(pooled.slice()) {
                        if self.offset == 0 && self.len == full_vec.len() {
                            return full_vec;
                        }
                        // NOTE(review): unlike to_device_f32's GPU→CPU path,
                        // this indexing panics if the device buffer is
                        // shorter than offset + len — confirm the invariant.
                        return full_vec[self.offset..self.offset + self.len].to_vec();
                    }
                }
                vec![0.0f32; self.len]
            }
        }
    }

    /// Deep copy that works for both CPU and GPU f32 storage.
    ///
    /// GPU data is round-tripped through host memory (D2H then H2D).
    /// NOTE(review): if the re-upload fails, the copy silently falls back to
    /// CPU-resident storage while still carrying the GPU device tag —
    /// confirm callers tolerate that.
    pub fn deep_copy_f32(&self) -> Self {
        let device = self.device();
        let vec = self.to_vec_f32();
        if device.is_gpu() {
            if let Some(backend) = crate::backends::cuda::get_cuda_backend() {
                if let Ok(new_slice) = backend.htod_copy(&vec) {
                    return Self::from_cuda_slice_unmanaged(new_slice, self.len, device);
                }
            }
        }
        Self::from_vec(vec, device)
    }

    /// Creates storage from a pool-allocated CudaSlice.
    ///
    /// The slice will be returned to the CUDA memory pool on drop.
    /// Only use this for CudaSlices from `pool_alloc()` — the slice must be
    /// bucket-sized for correct pool reuse.
    pub fn from_cuda_slice(slice: CudaSlice<f32>, len: usize, device: Device) -> Self {
        Self {
            inner: Arc::new(RwLock::new(StorageInner {
                data: StorageData::Cuda(PooledCudaSlice::new(slice, true)),
                device,
            })),
            offset: 0,
            len,
        }
    }

    /// Creates storage from a non-pool CudaSlice (e.g., from htod_copy).
    ///
    /// The slice will be freed via cudaFree on drop (normal CUDA deallocation),
    /// NOT returned to the memory pool.
    pub fn from_cuda_slice_unmanaged(slice: CudaSlice<f32>, len: usize, device: Device) -> Self {
        Self {
            inner: Arc::new(RwLock::new(StorageInner {
                data: StorageData::Cuda(PooledCudaSlice::new(slice, false)),
                device,
            })),
            offset: 0,
            len,
        }
    }

    /// Returns a read guard for the underlying `CudaSlice<f32>`.
    ///
    /// # Panics
    /// The guard's `slice()` accessor panics if storage is not on GPU.
    pub fn as_cuda_slice(&self) -> CudaSliceReadGuard<'_> {
        CudaSliceReadGuard {
            guard: self.inner.read(),
        }
    }

    /// Returns a write guard providing mutable access to the underlying `CudaSlice<f32>`.
    ///
    /// # Panics
    /// The guard's `slice_mut()` accessor panics if storage is not on GPU.
    pub fn as_cuda_slice_mut(&self) -> CudaSliceWriteGuard<'_> {
        CudaSliceWriteGuard {
            guard: self.inner.write(),
        }
    }
}
489
/// Read guard that provides access to the CudaSlice.
///
/// Holds the storage's read lock for its lifetime, so the underlying
/// buffer cannot be mutated or swapped while the guard is alive.
#[cfg(feature = "cuda")]
pub struct CudaSliceReadGuard<'a> {
    guard: parking_lot::RwLockReadGuard<'a, StorageInner<f32>>,
}

#[cfg(feature = "cuda")]
impl<'a> CudaSliceReadGuard<'a> {
    /// Returns a reference to the CudaSlice.
    ///
    /// # Panics
    /// Panics if storage is CPU.
    pub fn slice(&self) -> &CudaSlice<f32> {
        match &self.guard.data {
            StorageData::Cuda(pooled) => pooled.slice(),
            StorageData::Cpu(_) => panic!("Storage is on CPU, not GPU"),
        }
    }
}
509
/// Write guard that provides mutable access to the CudaSlice.
///
/// Holds the storage's write lock for its lifetime, excluding all other
/// readers and writers of the same buffer.
#[cfg(feature = "cuda")]
pub struct CudaSliceWriteGuard<'a> {
    guard: parking_lot::RwLockWriteGuard<'a, StorageInner<f32>>,
}

#[cfg(feature = "cuda")]
impl<'a> CudaSliceWriteGuard<'a> {
    /// Returns a mutable reference to the CudaSlice.
    ///
    /// # Panics
    /// Panics if storage is CPU.
    pub fn slice_mut(&mut self) -> &mut CudaSlice<f32> {
        match &mut self.guard.data {
            StorageData::Cuda(pooled) => pooled.slice_mut(),
            StorageData::Cpu(_) => panic!("Storage is on CPU, not GPU"),
        }
    }
}
529
530impl<T: Scalar> Clone for Storage<T> {
531    fn clone(&self) -> Self {
532        Self {
533            inner: Arc::clone(&self.inner),
534            offset: self.offset,
535            len: self.len,
536        }
537    }
538}
539
// =============================================================================
// Guard Types for Safe Access
// =============================================================================

/// Read guard for storage data.
///
/// Holds the storage's read lock; derefs to the `[offset, offset + len)`
/// window of the CPU buffer.
pub struct StorageReadGuard<'a, T: Scalar> {
    guard: parking_lot::RwLockReadGuard<'a, StorageInner<T>>,
    offset: usize,
    len: usize,
}
550
551impl<T: Scalar> Deref for StorageReadGuard<'_, T> {
552    type Target = [T];
553
554    fn deref(&self) -> &Self::Target {
555        match &self.guard.data {
556            StorageData::Cpu(data) => &data[self.offset..self.offset + self.len],
557            #[cfg(feature = "cuda")]
558            StorageData::Cuda(_) => panic!(
559                "Cannot access GPU storage as CPU slice. Use to_vec() for device-safe access."
560            ),
561        }
562    }
563}
564
/// Write guard for storage data.
///
/// Holds the storage's write lock; derefs (mutably) to the
/// `[offset, offset + len)` window of the CPU buffer.
pub struct StorageWriteGuard<'a, T: Scalar> {
    guard: parking_lot::RwLockWriteGuard<'a, StorageInner<T>>,
    offset: usize,
    len: usize,
}
571
572impl<T: Scalar> Deref for StorageWriteGuard<'_, T> {
573    type Target = [T];
574
575    fn deref(&self) -> &Self::Target {
576        match &self.guard.data {
577            StorageData::Cpu(data) => &data[self.offset..self.offset + self.len],
578            #[cfg(feature = "cuda")]
579            StorageData::Cuda(_) => panic!("Cannot access GPU storage as CPU slice."),
580        }
581    }
582}
583
584impl<T: Scalar> DerefMut for StorageWriteGuard<'_, T> {
585    fn deref_mut(&mut self) -> &mut Self::Target {
586        match &mut self.guard.data {
587            StorageData::Cpu(data) => &mut data[self.offset..self.offset + self.len],
588            #[cfg(feature = "cuda")]
589            StorageData::Cuda(_) => panic!("Cannot access GPU storage as mutable CPU slice."),
590        }
591    }
592}
593
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_storage_zeros() {
        let zeroed = Storage::<f32>::zeros(10, Device::Cpu);
        assert_eq!(zeroed.len(), 10);
        assert!(!zeroed.is_empty());
        // Every element must start at exactly zero.
        assert!(zeroed.as_slice().iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_storage_from_vec() {
        let source = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let storage = Storage::from_vec(source.clone(), Device::Cpu);
        assert_eq!(&*storage.as_slice(), source.as_slice());
    }

    #[test]
    fn test_storage_slice() {
        let storage = Storage::from_vec(vec![1.0_f32, 2.0, 3.0, 4.0, 5.0], Device::Cpu);
        let view = storage.slice(1, 3).unwrap();

        assert_eq!(view.len(), 3);
        assert_eq!(&*view.as_slice(), &[2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_storage_clone_shares() {
        let original = Storage::<f32>::zeros(10, Device::Cpu);
        let shared = original.clone();

        // Both handles point at the same buffer, so neither is unique.
        assert!(!original.is_unique());
        assert!(!shared.is_unique());
    }

    #[test]
    fn test_storage_deep_copy() {
        let first = Storage::from_vec(vec![1.0_f32, 2.0, 3.0], Device::Cpu);
        let second = first.deep_copy();

        assert!(first.is_unique());
        assert!(second.is_unique());

        // Writing through the copy must not affect the original.
        second.as_slice_mut()[0] = 99.0;
        assert_eq!(first.as_slice()[0], 1.0);
    }

    #[test]
    fn test_storage_copy_from() {
        let src = Storage::from_vec(vec![1.0_f32, 2.0, 3.0], Device::Cpu);
        let dst = Storage::<f32>::zeros(3, Device::Cpu);

        dst.copy_from(&src).unwrap();
        assert_eq!(&*dst.as_slice(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_storage_slice_out_of_bounds() {
        // 5 + 10 runs past the 10-element buffer.
        let storage = Storage::<f32>::zeros(10, Device::Cpu);
        assert!(storage.slice(5, 10).is_err());
    }

    #[test]
    fn test_storage_to_vec_cpu() {
        let storage = Storage::from_vec(vec![1.0_f32, 2.0, 3.0], Device::Cpu);
        assert_eq!(storage.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_storage_is_cpu() {
        let storage = Storage::from_vec(vec![1.0_f32], Device::Cpu);
        assert!(storage.is_cpu());
        assert!(!storage.is_gpu());
    }
}