axonml_core/
storage.rs

1//! Storage - Raw Memory Management for Tensors
2//!
3//! Provides efficient memory storage that underlies all tensor operations.
4//! Storage is reference-counted for efficient sharing between tensor views.
5//!
6//! # Key Features
7//! - Reference-counted memory for efficient views
8//! - Device-agnostic storage interface
9//! - Zero-copy slicing through offset/length
10//! - Automatic memory cleanup
11//!
12//! # Example
13//! ```rust
14//! use axonml_core::{Storage, Device};
15//!
16//! // Create storage for 100 f32 values on CPU
17//! let storage = Storage::<f32>::zeros(100, Device::Cpu);
18//! assert_eq!(storage.len(), 100);
19//! ```
20//!
21//! @version 0.1.0
22//! @author `AutomataNexus` Development Team
23
24use core::ops::{Deref, DerefMut};
25use std::sync::Arc;
26
27use parking_lot::RwLock;
28
29use crate::device::Device;
30use crate::dtype::Scalar;
31use crate::error::{Error, Result};
32
33// =============================================================================
34// Storage Struct
35// =============================================================================
36
37/// Raw memory storage for tensor data.
38///
39/// Storage is the fundamental building block for tensors. It manages a contiguous
40/// block of memory on a specific device and is reference-counted to allow
41/// efficient sharing between tensor views.
42#[derive(Debug)]
43pub struct Storage<T: Scalar> {
44    /// The underlying data buffer.
45    inner: Arc<RwLock<StorageInner<T>>>,
46    /// Offset into the storage (for views).
47    offset: usize,
48    /// Number of elements in this view.
49    len: usize,
50}
51
52/// Inner storage data that can be shared between views.
53#[derive(Debug)]
54struct StorageInner<T: Scalar> {
55    /// Raw data pointer (owned).
56    data: Vec<T>,
57    /// The device this storage resides on.
58    device: Device,
59}
60
61impl<T: Scalar> Storage<T> {
62    /// Creates new storage with the given capacity, initialized to zero.
63    ///
64    /// # Arguments
65    /// * `len` - Number of elements to allocate
66    /// * `device` - Device to allocate on
67    ///
68    /// # Returns
69    /// New storage initialized to zeros.
70    #[must_use]
71    pub fn zeros(len: usize, device: Device) -> Self {
72        let data = vec![T::zeroed(); len];
73        Self::from_vec(data, device)
74    }
75
76    /// Creates storage from an existing vector.
77    ///
78    /// # Arguments
79    /// * `data` - Vector of data
80    /// * `device` - Device the storage is on
81    ///
82    /// # Returns
83    /// New storage wrapping the data.
84    #[must_use]
85    pub fn from_vec(data: Vec<T>, device: Device) -> Self {
86        let len = data.len();
87        Self {
88            inner: Arc::new(RwLock::new(StorageInner { data, device })),
89            offset: 0,
90            len,
91        }
92    }
93
94    /// Creates storage from a slice by copying the data.
95    ///
96    /// # Arguments
97    /// * `data` - Slice of data to copy
98    /// * `device` - Device to allocate on
99    ///
100    /// # Returns
101    /// New storage containing a copy of the data.
102    #[must_use]
103    pub fn from_slice(data: &[T], device: Device) -> Self {
104        Self::from_vec(data.to_vec(), device)
105    }
106
107    /// Returns the number of elements in this storage view.
108    #[must_use]
109    pub const fn len(&self) -> usize {
110        self.len
111    }
112
113    /// Returns true if the storage is empty.
114    #[must_use]
115    pub const fn is_empty(&self) -> bool {
116        self.len == 0
117    }
118
119    /// Returns the offset into the underlying buffer.
120    #[must_use]
121    pub const fn offset(&self) -> usize {
122        self.offset
123    }
124
125    /// Returns the device this storage is on.
126    #[must_use]
127    pub fn device(&self) -> Device {
128        self.inner.read().device
129    }
130
131    /// Returns the size in bytes of this storage.
132    #[must_use]
133    pub fn size_bytes(&self) -> usize {
134        self.len * core::mem::size_of::<T>()
135    }
136
137    /// Creates a view into a portion of this storage.
138    ///
139    /// # Arguments
140    /// * `offset` - Starting offset relative to this view
141    /// * `len` - Number of elements in the new view
142    ///
143    /// # Returns
144    /// A new storage view, or error if bounds are invalid.
145    pub fn slice(&self, offset: usize, len: usize) -> Result<Self> {
146        if offset + len > self.len {
147            return Err(Error::IndexOutOfBounds {
148                index: offset + len,
149                size: self.len,
150            });
151        }
152
153        Ok(Self {
154            inner: Arc::clone(&self.inner),
155            offset: self.offset + offset,
156            len,
157        })
158    }
159
160    /// Returns true if this storage is uniquely owned (not shared).
161    #[must_use]
162    pub fn is_unique(&self) -> bool {
163        Arc::strong_count(&self.inner) == 1
164    }
165
166    /// Returns an immutable reference to the data.
167    ///
168    /// # Panics
169    /// Panics if the lock is poisoned.
170    #[must_use]
171    pub fn as_slice(&self) -> StorageReadGuard<'_, T> {
172        StorageReadGuard {
173            guard: self.inner.read(),
174            offset: self.offset,
175            len: self.len,
176        }
177    }
178
179    /// Returns a mutable reference to the data.
180    ///
181    /// # Panics
182    /// Panics if the lock is poisoned.
183    #[must_use]
184    pub fn as_slice_mut(&self) -> StorageWriteGuard<'_, T> {
185        StorageWriteGuard {
186            guard: self.inner.write(),
187            offset: self.offset,
188            len: self.len,
189        }
190    }
191
192    /// Copies data from another storage into this one.
193    ///
194    /// # Arguments
195    /// * `other` - Source storage to copy from
196    ///
197    /// # Returns
198    /// Ok if successful, error if lengths don't match.
199    pub fn copy_from(&self, other: &Self) -> Result<()> {
200        if self.len != other.len {
201            return Err(Error::shape_mismatch(&[self.len], &[other.len]));
202        }
203
204        let src = other.as_slice();
205        let mut dst = self.as_slice_mut();
206        dst.copy_from_slice(&src);
207        Ok(())
208    }
209
210    /// Makes a deep copy of this storage.
211    #[must_use]
212    pub fn deep_copy(&self) -> Self {
213        let data = self.as_slice().to_vec();
214        Self::from_vec(data, self.device())
215    }
216
217    /// Transfers this storage to a different device.
218    ///
219    /// # Arguments
220    /// * `device` - Target device
221    ///
222    /// # Returns
223    /// New storage on the target device.
224    pub fn to_device(&self, device: Device) -> Result<Self> {
225        if self.device() == device {
226            return Ok(self.clone());
227        }
228
229        // For now, only CPU is supported
230        if !device.is_cpu() {
231            return Err(Error::DeviceNotAvailable { device });
232        }
233
234        Ok(self.deep_copy())
235    }
236}
237
238impl<T: Scalar> Clone for Storage<T> {
239    fn clone(&self) -> Self {
240        Self {
241            inner: Arc::clone(&self.inner),
242            offset: self.offset,
243            len: self.len,
244        }
245    }
246}
247
248// =============================================================================
249// Guard Types for Safe Access
250// =============================================================================
251
252/// Read guard for storage data.
253pub struct StorageReadGuard<'a, T: Scalar> {
254    guard: parking_lot::RwLockReadGuard<'a, StorageInner<T>>,
255    offset: usize,
256    len: usize,
257}
258
259impl<T: Scalar> Deref for StorageReadGuard<'_, T> {
260    type Target = [T];
261
262    fn deref(&self) -> &Self::Target {
263        &self.guard.data[self.offset..self.offset + self.len]
264    }
265}
266
267/// Write guard for storage data.
268pub struct StorageWriteGuard<'a, T: Scalar> {
269    guard: parking_lot::RwLockWriteGuard<'a, StorageInner<T>>,
270    offset: usize,
271    len: usize,
272}
273
274impl<T: Scalar> Deref for StorageWriteGuard<'_, T> {
275    type Target = [T];
276
277    fn deref(&self) -> &Self::Target {
278        &self.guard.data[self.offset..self.offset + self.len]
279    }
280}
281
282impl<T: Scalar> DerefMut for StorageWriteGuard<'_, T> {
283    fn deref_mut(&mut self) -> &mut Self::Target {
284        &mut self.guard.data[self.offset..self.offset + self.len]
285    }
286}
287
288// =============================================================================
289// Tests
290// =============================================================================
291
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_storage_zeros() {
        let storage = Storage::<f32>::zeros(10, Device::Cpu);
        assert_eq!(storage.len(), 10);
        assert!(!storage.is_empty());

        let data = storage.as_slice();
        for &val in data.iter() {
            assert_eq!(val, 0.0);
        }
    }

    #[test]
    fn test_storage_from_vec() {
        let vec = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let storage = Storage::from_vec(vec.clone(), Device::Cpu);

        let data = storage.as_slice();
        assert_eq!(&*data, &vec[..]);
    }

    #[test]
    fn test_storage_slice() {
        let vec = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let storage = Storage::from_vec(vec, Device::Cpu);
        let slice = storage.slice(1, 3).unwrap();

        assert_eq!(slice.len(), 3);
        let data = slice.as_slice();
        assert_eq!(&*data, &[2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_storage_nested_slice_composes_offsets() {
        let storage = Storage::from_vec(vec![1.0_f32, 2.0, 3.0, 4.0, 5.0], Device::Cpu);
        let outer = storage.slice(1, 4).unwrap();
        let inner = outer.slice(1, 2).unwrap();

        // Offsets accumulate relative to the underlying buffer.
        assert_eq!(inner.offset(), 2);
        assert_eq!(&*inner.as_slice(), &[3.0, 4.0]);
        // Bounds are checked against the outer view, not the whole buffer.
        assert!(outer.slice(3, 2).is_err());
    }

    #[test]
    fn test_storage_slice_writes_are_shared() {
        let storage = Storage::<f32>::zeros(4, Device::Cpu);
        let view = storage.slice(2, 2).unwrap();
        assert!(!storage.is_unique());

        view.as_slice_mut()[0] = 7.0;
        assert_eq!(&*storage.as_slice(), &[0.0, 0.0, 7.0, 0.0]);
    }

    #[test]
    fn test_storage_empty() {
        let storage = Storage::<f32>::zeros(0, Device::Cpu);
        assert!(storage.is_empty());
        assert_eq!(storage.size_bytes(), 0);
        assert!(storage.slice(0, 0).is_ok());
    }

    #[test]
    fn test_storage_clone_shares() {
        let storage1 = Storage::<f32>::zeros(10, Device::Cpu);
        let storage2 = storage1.clone();

        assert!(!storage1.is_unique());
        assert!(!storage2.is_unique());
    }

    #[test]
    fn test_storage_deep_copy() {
        let storage1 = Storage::from_vec(vec![1.0_f32, 2.0, 3.0], Device::Cpu);
        let storage2 = storage1.deep_copy();

        assert!(storage1.is_unique());
        assert!(storage2.is_unique());

        // Modify storage2
        storage2.as_slice_mut()[0] = 99.0;

        // storage1 should be unchanged
        assert_eq!(storage1.as_slice()[0], 1.0);
    }

    #[test]
    fn test_storage_copy_from() {
        let src = Storage::from_vec(vec![1.0_f32, 2.0, 3.0], Device::Cpu);
        let dst = Storage::<f32>::zeros(3, Device::Cpu);

        dst.copy_from(&src).unwrap();

        let data = dst.as_slice();
        assert_eq!(&*data, &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_storage_to_same_device_shares() {
        let storage = Storage::<f32>::zeros(3, Device::Cpu);
        let moved = storage.to_device(Device::Cpu).unwrap();

        // Same-device transfer is a cheap shared clone, not a copy.
        assert!(!storage.is_unique());
        assert_eq!(moved.len(), 3);
    }

    #[test]
    fn test_storage_slice_out_of_bounds() {
        let storage = Storage::<f32>::zeros(10, Device::Cpu);
        let result = storage.slice(5, 10);
        assert!(result.is_err());
    }
}