//! Reference-counted raw memory management for tensors.
//!
//! `Storage<T>` wraps either a CPU `Vec<T>` or a GPU `PooledCudaSlice` behind
//! an `Arc<RwLock<StorageInner<T>>>`, enabling zero-copy views via offset+len
//! slicing. GPU storage uses the CUDA memory pool (`cuda_pool.rs`) so freed
//! allocations are returned to a size-bucketed free list instead of calling
//! cudaFree. Supports `to_device()` for CPU<->GPU transfer, deep copy,
//! `as_slice()` / `as_slice_mut()` with RAII guards, and `as_cuda_slice()`
//! for direct GPU kernel access.
//!
//! # File
//! `crates/axonml-core/src/storage.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

26use core::ops::{Deref, DerefMut};
27use std::sync::Arc;
28
29use parking_lot::RwLock;
30
31use crate::device::Device;
32use crate::dtype::Scalar;
33use crate::error::{Error, Result};
34
35#[cfg(feature = "cuda")]
36use cudarc::driver::CudaSlice;
37
38// =============================================================================
39// Storage Data Enum
40// =============================================================================
41
/// Wrapper around CudaSlice that returns memory to the pool on drop
/// instead of calling cudaFree.
///
/// The slice is held in an `Option` so `Drop` can move it out with `take()`
/// and route it either back to the pool or into normal CUDA deallocation.
#[cfg(feature = "cuda")]
pub struct PooledCudaSlice {
    /// The wrapped device buffer; `None` only after `Drop` has taken it.
    slice: Option<CudaSlice<f32>>,
    /// When true, drop returns the buffer to the size-bucketed pool;
    /// when false, the normal `CudaSlice` drop (cudaFree) runs instead.
    pool_managed: bool,
}
49
#[cfg(feature = "cuda")]
impl std::fmt::Debug for PooledCudaSlice {
    // Hand-written Debug: only the pool flag and the element count are
    // printed (presumably `CudaSlice` has no suitable derivable Debug —
    // TODO confirm against cudarc).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PooledCudaSlice")
            .field("pool_managed", &self.pool_managed)
            // `len` reads as `None` once Drop has already taken the slice.
            .field("len", &self.slice.as_ref().map(|s| s.len()))
            .finish()
    }
}
59
#[cfg(feature = "cuda")]
impl Drop for PooledCudaSlice {
    /// Routes the buffer back to the CUDA pool when pool-managed; otherwise
    /// lets the `CudaSlice` drop normally, which performs the cudaFree.
    fn drop(&mut self) {
        match (self.slice.take(), self.pool_managed) {
            (Some(owned), true) => crate::backends::cuda_pool::pool_free(owned),
            // `owned` falls out of scope here: normal CudaSlice drop → cudaFree.
            (Some(_owned), false) => {}
            // Already taken — nothing left to release.
            (None, _) => {}
        }
    }
}
71
#[cfg(feature = "cuda")]
impl PooledCudaSlice {
    /// Wraps a CUDA slice, recording whether the pool owns its lifetime.
    pub fn new(slice: CudaSlice<f32>, pool_managed: bool) -> Self {
        Self {
            slice: Some(slice),
            pool_managed,
        }
    }

    /// Immutable access to the wrapped CudaSlice.
    ///
    /// # Panics
    /// Panics if the slice has already been taken (only happens during drop).
    pub fn slice(&self) -> &CudaSlice<f32> {
        match self.slice.as_ref() {
            Some(inner) => inner,
            None => panic!("CudaSlice already taken"),
        }
    }

    /// Mutable access to the wrapped CudaSlice.
    ///
    /// # Panics
    /// Panics if the slice has already been taken (only happens during drop).
    pub fn slice_mut(&mut self) -> &mut CudaSlice<f32> {
        match self.slice.as_mut() {
            Some(inner) => inner,
            None => panic!("CudaSlice already taken"),
        }
    }
}
92
/// Holds either CPU or GPU data.
///
/// NOTE(review): the GPU variant always holds `f32` data regardless of `T`;
/// the generic accessors (`Storage::to_vec`, the slice guards) panic on it,
/// and only the `Storage<f32>` specializations can read it back.
#[derive(Debug)]
enum StorageData<T: Scalar> {
    /// CPU data stored as a Vec.
    Cpu(Vec<T>),
    /// GPU data stored as a PooledCudaSlice (f32 only on GPU).
    /// Returns to memory pool on drop instead of calling cudaFree.
    #[cfg(feature = "cuda")]
    Cuda(PooledCudaSlice),
}
103
104// =============================================================================
105// Storage Struct
106// =============================================================================
107
/// Raw memory storage for tensor data.
///
/// Storage is the fundamental building block for tensors. It manages a contiguous
/// block of memory on a specific device and is reference-counted to allow
/// efficient sharing between tensor views.
///
/// Cloning a `Storage` is cheap: it shares the same `inner` buffer and only
/// copies the `offset`/`len` view metadata.
#[derive(Debug)]
pub struct Storage<T: Scalar> {
    /// The underlying data buffer, shared between all views of it.
    inner: Arc<RwLock<StorageInner<T>>>,
    /// Offset in elements into the storage (for views).
    offset: usize,
    /// Number of elements in this view.
    len: usize,
}
122
/// Inner storage data that can be shared between views.
///
/// Guarded by the `RwLock` in `Storage::inner`; views differ only in their
/// `offset`/`len` and all point at a single `StorageInner`.
#[derive(Debug)]
struct StorageInner<T: Scalar> {
    /// Raw data (CPU or GPU).
    data: StorageData<T>,
    /// The device this storage resides on.
    device: Device,
}
131
132impl<T: Scalar> Storage<T> {
133    /// Creates new storage with the given capacity, initialized to zero.
134    #[must_use]
135    pub fn zeros(len: usize, device: Device) -> Self {
136        let data = vec![T::zeroed(); len];
137        Self::from_vec(data, device)
138    }
139
140    /// Creates storage from an existing vector (always on CPU).
141    #[must_use]
142    pub fn from_vec(data: Vec<T>, device: Device) -> Self {
143        let len = data.len();
144        Self {
145            inner: Arc::new(RwLock::new(StorageInner {
146                data: StorageData::Cpu(data),
147                device,
148            })),
149            offset: 0,
150            len,
151        }
152    }
153
154    /// Creates storage from a slice by copying the data.
155    #[must_use]
156    pub fn from_slice(data: &[T], device: Device) -> Self {
157        Self::from_vec(data.to_vec(), device)
158    }
159
160    /// Returns the number of elements in this storage view.
161    #[must_use]
162    pub const fn len(&self) -> usize {
163        self.len
164    }
165
166    /// Returns true if the storage is empty.
167    #[must_use]
168    pub const fn is_empty(&self) -> bool {
169        self.len == 0
170    }
171
172    /// Returns the offset into the underlying buffer.
173    #[must_use]
174    pub const fn offset(&self) -> usize {
175        self.offset
176    }
177
178    /// Returns the device this storage is on.
179    #[must_use]
180    pub fn device(&self) -> Device {
181        self.inner.read().device
182    }
183
184    /// Returns true if data is on CPU.
185    #[must_use]
186    pub fn is_cpu(&self) -> bool {
187        matches!(self.inner.read().data, StorageData::Cpu(_))
188    }
189
190    /// Returns true if data is on GPU.
191    #[must_use]
192    pub fn is_gpu(&self) -> bool {
193        !self.is_cpu()
194    }
195
196    /// Returns the size in bytes of this storage.
197    #[must_use]
198    pub fn size_bytes(&self) -> usize {
199        self.len * core::mem::size_of::<T>()
200    }
201
202    /// Creates a view into a portion of this storage.
203    pub fn slice(&self, offset: usize, len: usize) -> Result<Self> {
204        if offset + len > self.len {
205            return Err(Error::IndexOutOfBounds {
206                index: offset + len,
207                size: self.len,
208            });
209        }
210
211        Ok(Self {
212            inner: Arc::clone(&self.inner),
213            offset: self.offset + offset,
214            len,
215        })
216    }
217
218    /// Returns true if this storage is uniquely owned (not shared).
219    #[must_use]
220    pub fn is_unique(&self) -> bool {
221        Arc::strong_count(&self.inner) == 1
222    }
223
224    /// Returns an immutable reference to the CPU data.
225    ///
226    /// # Panics
227    /// Panics if the storage is on GPU. Use `to_vec()` for device-safe access.
228    #[must_use]
229    pub fn as_slice(&self) -> StorageReadGuard<'_, T> {
230        StorageReadGuard {
231            guard: self.inner.read(),
232            offset: self.offset,
233            len: self.len,
234        }
235    }
236
237    /// Returns a mutable reference to the CPU data.
238    ///
239    /// # Panics
240    /// Panics if the storage is on GPU.
241    #[must_use]
242    pub fn as_slice_mut(&self) -> StorageWriteGuard<'_, T> {
243        StorageWriteGuard {
244            guard: self.inner.write(),
245            offset: self.offset,
246            len: self.len,
247        }
248    }
249
250    /// Copies data from another storage into this one.
251    pub fn copy_from(&self, other: &Self) -> Result<()> {
252        if self.len != other.len {
253            return Err(Error::shape_mismatch(&[self.len], &[other.len]));
254        }
255
256        let src = other.as_slice();
257        let mut dst = self.as_slice_mut();
258        dst.copy_from_slice(&src);
259        Ok(())
260    }
261
262    /// Makes a deep copy of this storage.
263    ///
264    /// For GPU storage, this only works for `Storage<f32>`.
265    /// Other types will panic on GPU storage.
266    #[must_use]
267    pub fn deep_copy(&self) -> Self {
268        let inner = self.inner.read();
269        match &inner.data {
270            StorageData::Cpu(cpu_data) => {
271                let data = cpu_data[self.offset..self.offset + self.len].to_vec();
272                Self::from_vec(data, inner.device)
273            }
274            #[cfg(feature = "cuda")]
275            StorageData::Cuda(_) => {
276                panic!("deep_copy() on GPU storage requires Storage<f32>. Use deep_copy_f32().");
277            }
278        }
279    }
280
281    /// Returns the data as a Vec from CPU storage.
282    ///
283    /// For GPU-resident f32 tensors, use the `Tensor::to_vec()` method which
284    /// handles device-aware copies via `Storage<f32>::to_vec_f32()`.
285    ///
286    /// # Panics
287    /// Panics if storage is on GPU (use `to_vec_f32()` on `Storage<f32>` instead).
288    pub fn to_vec(&self) -> Vec<T> {
289        let inner = self.inner.read();
290        match &inner.data {
291            StorageData::Cpu(cpu_data) => cpu_data[self.offset..self.offset + self.len].to_vec(),
292            #[cfg(feature = "cuda")]
293            StorageData::Cuda(_) => {
294                panic!(
295                    "Cannot call to_vec() on GPU storage for generic T. Use to_vec_f32() on Storage<f32>."
296                );
297            }
298        }
299    }
300
301    /// Transfers this storage to a different device.
302    ///
303    /// For GPU transfers, only `Storage<f32>` is supported. Use the
304    /// `Storage<f32>::to_device()` specialization for CPU↔GPU transfers.
305    pub fn to_device(&self, device: Device) -> Result<Self> {
306        if self.device() == device {
307            return Ok(self.clone());
308        }
309
310        // Generic path: only CPU→CPU is supported
311        if device.is_cpu() && self.device().is_cpu() {
312            return Ok(self.deep_copy());
313        }
314
315        Err(Error::DeviceNotAvailable { device })
316    }
317}
318
319// =============================================================================
320// f32-specific Storage for GPU transfers
321// =============================================================================
322
#[cfg(feature = "cuda")]
impl Storage<f32> {
    /// Transfers f32 storage between CPU and GPU.
    ///
    /// Covers all four direction combinations: CPU→CPU (deep copy),
    /// CPU→GPU (`htod_copy`), GPU→CPU (`dtoh_copy`), and GPU→GPU
    /// (staged through the host).
    ///
    /// # Errors
    /// Returns `Error::DeviceNotAvailable` if the CUDA backend is missing
    /// or a host/device copy fails.
    pub fn to_device_f32(&self, device: Device) -> Result<Self> {
        // Already on the target device: share the buffer, no copy.
        if self.device() == device {
            return Ok(self.clone());
        }

        let inner = self.inner.read();

        match (&inner.data, device) {
            // CPU → CPU: just deep copy
            (StorageData::Cpu(_), Device::Cpu) => {
                // Release the read lock before deep_copy() re-acquires it.
                drop(inner);
                Ok(self.deep_copy())
            }
            // CPU → GPU: htod_copy of just this view's window
            (StorageData::Cpu(cpu_data), Device::Cuda(_idx)) => {
                let backend = crate::backends::cuda::get_cuda_backend()
                    .ok_or(Error::DeviceNotAvailable { device })?;
                let slice = &cpu_data[self.offset..self.offset + self.len];
                let cuda_slice = backend
                    .htod_copy(slice)
                    .map_err(|_| Error::DeviceNotAvailable { device })?;
                let len = self.len;
                Ok(Self {
                    inner: Arc::new(RwLock::new(StorageInner {
                        // htod_copy allocations are not pool-bucketed, so the
                        // slice is unmanaged (freed via cudaFree, not pooled).
                        data: StorageData::Cuda(PooledCudaSlice::new(cuda_slice, false)),
                        device,
                    })),
                    offset: 0,
                    len,
                })
            }
            // GPU → CPU: dtoh_copy
            (StorageData::Cuda(pooled), Device::Cpu) => {
                // Use `inner.device` instead of `self.device()`: the read lock
                // is already held here, and `self.device()` re-locks it.
                // parking_lot read locks are not reentrant, so that recursive
                // read can deadlock when a writer is queued — and `ok_or`
                // evaluates its argument eagerly, so the re-lock happened on
                // every call, not just on the error path.
                let backend =
                    crate::backends::cuda::get_cuda_backend().ok_or(Error::DeviceNotAvailable {
                        device: inner.device,
                    })?;
                let full_vec = backend
                    .dtoh_copy(pooled.slice())
                    .map_err(|_| Error::DeviceNotAvailable { device })?;
                let end = self.offset + self.len;
                let sliced: Vec<f32> = if self.offset == 0 && self.len == full_vec.len() {
                    full_vec
                } else if end <= full_vec.len() {
                    full_vec[self.offset..end].to_vec()
                } else {
                    // CudaSlice is smaller than Storage.len — this indicates a bug
                    // but handle gracefully: copy what we have, zero-pad the rest
                    eprintln!(
                        "[storage] WARNING: CudaSlice len={} < Storage offset+len={} (offset={}, len={})",
                        full_vec.len(),
                        end,
                        self.offset,
                        self.len
                    );
                    let available = if self.offset < full_vec.len() {
                        full_vec.len() - self.offset
                    } else {
                        0
                    };
                    let mut result = vec![0.0f32; self.len];
                    if available > 0 {
                        result[..available]
                            .copy_from_slice(&full_vec[self.offset..self.offset + available]);
                    }
                    result
                };
                Ok(Self::from_vec(sliced, Device::Cpu))
            }
            // GPU → GPU: D2H then H2D (simple path)
            (StorageData::Cuda(_), Device::Cuda(_)) => {
                // Release the read lock before recursing (both calls re-lock).
                drop(inner);
                let cpu_storage = self.to_device_f32(Device::Cpu)?;
                cpu_storage.to_device_f32(device)
            }
        }
    }
}
404
405// =============================================================================
406// CUDA-specific Storage methods
407// =============================================================================
408
#[cfg(feature = "cuda")]
impl Storage<f32> {
    /// Returns data as a `Vec<f32>`, performing a D2H copy if on GPU.
    ///
    /// Best-effort on GPU: if the CUDA backend is unavailable or the copy
    /// fails, a zero-filled vector of the correct length is returned.
    pub fn to_vec_f32(&self) -> Vec<f32> {
        let inner = self.inner.read();
        match &inner.data {
            StorageData::Cpu(cpu_data) => cpu_data[self.offset..self.offset + self.len].to_vec(),
            StorageData::Cuda(pooled) => {
                if let Some(backend) = crate::backends::cuda::get_cuda_backend() {
                    if let Ok(full_vec) = backend.dtoh_copy(pooled.slice()) {
                        if self.offset == 0 && self.len == full_vec.len() {
                            return full_vec;
                        }
                        let end = self.offset + self.len;
                        if end <= full_vec.len() {
                            return full_vec[self.offset..end].to_vec();
                        }
                        // Device buffer smaller than offset+len: mirror the
                        // graceful zero-pad path in `to_device_f32` instead of
                        // panicking on an out-of-range slice index.
                        let available = full_vec.len().saturating_sub(self.offset);
                        let mut result = vec![0.0f32; self.len];
                        if available > 0 {
                            result[..available].copy_from_slice(&full_vec[self.offset..]);
                        }
                        return result;
                    }
                }
                vec![0.0f32; self.len]
            }
        }
    }

    /// Deep copy that works for both CPU and GPU f32 storage.
    ///
    /// GPU storage is staged through the host (D2H then H2D). If the
    /// re-upload fails, the copy silently falls back to a CPU-resident
    /// buffer still tagged with the original device.
    pub fn deep_copy_f32(&self) -> Self {
        let device = self.device();
        let vec = self.to_vec_f32();
        if device.is_gpu() {
            if let Some(backend) = crate::backends::cuda::get_cuda_backend() {
                if let Ok(new_slice) = backend.htod_copy(&vec) {
                    // htod_copy allocations are not pool-bucketed → unmanaged.
                    return Self::from_cuda_slice_unmanaged(new_slice, self.len, device);
                }
            }
        }
        Self::from_vec(vec, device)
    }

    /// Creates storage from a pool-allocated CudaSlice.
    ///
    /// The slice will be returned to the CUDA memory pool on drop.
    /// Only use this for CudaSlices from `pool_alloc()` — the slice must be
    /// bucket-sized for correct pool reuse.
    pub fn from_cuda_slice(slice: CudaSlice<f32>, len: usize, device: Device) -> Self {
        Self {
            inner: Arc::new(RwLock::new(StorageInner {
                data: StorageData::Cuda(PooledCudaSlice::new(slice, true)),
                device,
            })),
            offset: 0,
            len,
        }
    }

    /// Creates storage from a non-pool CudaSlice (e.g., from htod_copy).
    ///
    /// The slice will be freed via cudaFree on drop (normal CUDA deallocation),
    /// NOT returned to the memory pool.
    pub fn from_cuda_slice_unmanaged(slice: CudaSlice<f32>, len: usize, device: Device) -> Self {
        Self {
            inner: Arc::new(RwLock::new(StorageInner {
                data: StorageData::Cuda(PooledCudaSlice::new(slice, false)),
                device,
            })),
            offset: 0,
            len,
        }
    }

    /// Returns a read guard exposing the underlying `CudaSlice<f32>`.
    ///
    /// # Panics
    /// The guard's `slice()` panics if storage is not on GPU.
    pub fn as_cuda_slice(&self) -> CudaSliceReadGuard<'_> {
        CudaSliceReadGuard {
            guard: self.inner.read(),
        }
    }

    /// Returns a write guard providing mutable access to the underlying `CudaSlice<f32>`.
    ///
    /// # Panics
    /// The guard's `slice_mut()` panics if storage is not on GPU.
    pub fn as_cuda_slice_mut(&self) -> CudaSliceWriteGuard<'_> {
        CudaSliceWriteGuard {
            guard: self.inner.write(),
        }
    }
}
495
/// Read guard that provides access to the CudaSlice.
///
/// Holds the storage's read lock for its lifetime, keeping the device
/// buffer borrowed (and un-replaced) while GPU code uses it.
#[cfg(feature = "cuda")]
pub struct CudaSliceReadGuard<'a> {
    /// Shared read lock on the storage's inner state.
    guard: parking_lot::RwLockReadGuard<'a, StorageInner<f32>>,
}
501
#[cfg(feature = "cuda")]
impl<'a> CudaSliceReadGuard<'a> {
    /// Returns a reference to the CudaSlice.
    ///
    /// # Panics
    /// Panics if storage is CPU.
    pub fn slice(&self) -> &CudaSlice<f32> {
        if let StorageData::Cuda(pooled) = &self.guard.data {
            pooled.slice()
        } else {
            panic!("Storage is on CPU, not GPU")
        }
    }
}
515
/// Write guard that provides mutable access to the CudaSlice.
///
/// Holds the storage's exclusive write lock for its lifetime, so no other
/// view can read or replace the device buffer while it is mutated.
#[cfg(feature = "cuda")]
pub struct CudaSliceWriteGuard<'a> {
    /// Exclusive write lock on the storage's inner state.
    guard: parking_lot::RwLockWriteGuard<'a, StorageInner<f32>>,
}
521
#[cfg(feature = "cuda")]
impl<'a> CudaSliceWriteGuard<'a> {
    /// Returns a mutable reference to the CudaSlice.
    ///
    /// # Panics
    /// Panics if storage is CPU.
    pub fn slice_mut(&mut self) -> &mut CudaSlice<f32> {
        if let StorageData::Cuda(pooled) = &mut self.guard.data {
            pooled.slice_mut()
        } else {
            panic!("Storage is on CPU, not GPU")
        }
    }
}
535
536impl<T: Scalar> Clone for Storage<T> {
537    fn clone(&self) -> Self {
538        Self {
539            inner: Arc::clone(&self.inner),
540            offset: self.offset,
541            len: self.len,
542        }
543    }
544}
545
546// =============================================================================
547// Guard Types for Safe Access
548// =============================================================================
549
/// Read guard for storage data.
///
/// RAII guard returned by `Storage::as_slice`; holds the inner read lock
/// and derefs to the `[T]` window selected by `offset`/`len`.
pub struct StorageReadGuard<'a, T: Scalar> {
    /// Shared read lock on the buffer, held for the guard's lifetime.
    guard: parking_lot::RwLockReadGuard<'a, StorageInner<T>>,
    /// Element offset of the view into the underlying buffer.
    offset: usize,
    /// Number of elements exposed by the view.
    len: usize,
}
556
impl<T: Scalar> Deref for StorageReadGuard<'_, T> {
    type Target = [T];

    /// Exposes the viewed window of the CPU buffer.
    ///
    /// # Panics
    /// Panics if the underlying storage is GPU-resident.
    fn deref(&self) -> &Self::Target {
        match &self.guard.data {
            StorageData::Cpu(data) => &data[self.offset..self.offset + self.len],
            #[cfg(feature = "cuda")]
            StorageData::Cuda(_) => panic!(
                "Cannot access GPU storage as CPU slice. Use to_vec() for device-safe access."
            ),
        }
    }
}
570
/// Write guard for storage data.
///
/// RAII guard returned by `Storage::as_slice_mut`; holds the inner write
/// lock and derefs (mutably) to the `[T]` window selected by `offset`/`len`.
pub struct StorageWriteGuard<'a, T: Scalar> {
    /// Exclusive write lock on the buffer, held for the guard's lifetime.
    guard: parking_lot::RwLockWriteGuard<'a, StorageInner<T>>,
    /// Element offset of the view into the underlying buffer.
    offset: usize,
    /// Number of elements exposed by the view.
    len: usize,
}
577
impl<T: Scalar> Deref for StorageWriteGuard<'_, T> {
    type Target = [T];

    /// Read access to the viewed window of the CPU buffer.
    ///
    /// # Panics
    /// Panics if the underlying storage is GPU-resident.
    fn deref(&self) -> &Self::Target {
        match &self.guard.data {
            StorageData::Cpu(data) => &data[self.offset..self.offset + self.len],
            #[cfg(feature = "cuda")]
            StorageData::Cuda(_) => panic!("Cannot access GPU storage as CPU slice."),
        }
    }
}
589
impl<T: Scalar> DerefMut for StorageWriteGuard<'_, T> {
    /// Mutable access to the viewed window of the CPU buffer.
    ///
    /// # Panics
    /// Panics if the underlying storage is GPU-resident.
    fn deref_mut(&mut self) -> &mut Self::Target {
        match &mut self.guard.data {
            StorageData::Cpu(data) => &mut data[self.offset..self.offset + self.len],
            #[cfg(feature = "cuda")]
            StorageData::Cuda(_) => panic!("Cannot access GPU storage as mutable CPU slice."),
        }
    }
}
599
600// =============================================================================
601// Tests
602// =============================================================================
603
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_storage_zeros() {
        let zeroed = Storage::<f32>::zeros(10, Device::Cpu);
        assert_eq!(zeroed.len(), 10);
        assert!(!zeroed.is_empty());
        // Every element must be zero-initialized.
        assert!(zeroed.as_slice().iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_storage_from_vec() {
        let values = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let built = Storage::from_vec(values.clone(), Device::Cpu);
        assert_eq!(built.as_slice().to_vec(), values);
    }

    #[test]
    fn test_storage_slice() {
        let built = Storage::from_vec(vec![1.0_f32, 2.0, 3.0, 4.0, 5.0], Device::Cpu);
        let view = built.slice(1, 3).unwrap();
        assert_eq!(view.len(), 3);
        assert_eq!(view.as_slice().to_vec(), vec![2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_storage_clone_shares() {
        let first = Storage::<f32>::zeros(10, Device::Cpu);
        let second = first.clone();
        // Both handles point at the same buffer, so neither is unique.
        assert!(!first.is_unique());
        assert!(!second.is_unique());
    }

    #[test]
    fn test_storage_deep_copy() {
        let original = Storage::from_vec(vec![1.0_f32, 2.0, 3.0], Device::Cpu);
        let copied = original.deep_copy();

        assert!(original.is_unique());
        assert!(copied.is_unique());

        // Mutating the copy must not touch the original.
        copied.as_slice_mut()[0] = 99.0;
        assert_eq!(original.as_slice()[0], 1.0);
    }

    #[test]
    fn test_storage_copy_from() {
        let source = Storage::from_vec(vec![1.0_f32, 2.0, 3.0], Device::Cpu);
        let target = Storage::<f32>::zeros(3, Device::Cpu);
        target.copy_from(&source).unwrap();
        assert_eq!(target.as_slice().to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_storage_slice_out_of_bounds() {
        let storage = Storage::<f32>::zeros(10, Device::Cpu);
        assert!(storage.slice(5, 10).is_err());
    }

    #[test]
    fn test_storage_to_vec_cpu() {
        let storage = Storage::from_vec(vec![1.0_f32, 2.0, 3.0], Device::Cpu);
        assert_eq!(storage.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_storage_is_cpu() {
        let storage = Storage::from_vec(vec![1.0_f32], Device::Cpu);
        assert!(storage.is_cpu());
        assert!(!storage.is_gpu());
    }
}