ghostflow_core/
storage.rs

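//! Storage backend for tensor data.
//!
//! [`Storage`] keeps tensor elements in a type-erased, reference-counted byte
//! buffer; [`StorageReadGuard`] and [`StorageWriteGuard`] give typed,
//! lock-protected access to it.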
2
3use std::sync::Arc;
4use parking_lot::RwLock;
5use crate::dtype::{DType, TensorElement};
6
7/// Raw storage for tensor data
8/// Separate from Tensor to enable zero-copy views
9#[derive(Debug)]
10pub struct Storage {
11    /// Raw bytes
12    data: Arc<RwLock<Vec<u8>>>,
13    /// Data type
14    dtype: DType,
15    /// Number of elements
16    len: usize,
17}
18
19impl Storage {
20    /// Create new storage with given capacity
21    pub fn new(dtype: DType, len: usize) -> Self {
22        let byte_len = len * dtype.size_bytes();
23        let data = vec![0u8; byte_len];
24        Storage {
25            data: Arc::new(RwLock::new(data)),
26            dtype,
27            len,
28        }
29    }
30
31    /// Create storage from typed data
32    pub fn from_slice<T: TensorElement>(data: &[T]) -> Self {
33        let byte_len = std::mem::size_of_val(data);
34        let mut bytes = vec![0u8; byte_len];
35        
36        // Safe copy from typed slice to bytes
37        unsafe {
38            std::ptr::copy_nonoverlapping(
39                data.as_ptr() as *const u8,
40                bytes.as_mut_ptr(),
41                byte_len,
42            );
43        }
44        
45        Storage {
46            data: Arc::new(RwLock::new(bytes)),
47            dtype: T::DTYPE,
48            len: data.len(),
49        }
50    }
51
52    /// Number of elements
53    pub fn len(&self) -> usize {
54        self.len
55    }
56
57    /// Check if empty
58    pub fn is_empty(&self) -> bool {
59        self.len == 0
60    }
61
62    /// Data type
63    pub fn dtype(&self) -> DType {
64        self.dtype
65    }
66
67    /// Size in bytes
68    pub fn size_bytes(&self) -> usize {
69        self.len * self.dtype.size_bytes()
70    }
71
72    /// Get read access to data as typed slice
73    pub fn as_slice<T: TensorElement>(&self) -> StorageReadGuard<'_, T> {
74        debug_assert_eq!(T::DTYPE, self.dtype);
75        StorageReadGuard {
76            guard: self.data.read(),
77            len: self.len,
78            _marker: std::marker::PhantomData,
79        }
80    }
81
82    /// Get write access to data as typed slice
83    pub fn as_slice_mut<T: TensorElement>(&self) -> StorageWriteGuard<'_, T> {
84        debug_assert_eq!(T::DTYPE, self.dtype);
85        StorageWriteGuard {
86            guard: self.data.write(),
87            len: self.len,
88            _marker: std::marker::PhantomData,
89        }
90    }
91
92    /// Clone the storage (deep copy)
93    pub fn deep_clone(&self) -> Self {
94        let data = self.data.read().clone();
95        Storage {
96            data: Arc::new(RwLock::new(data)),
97            dtype: self.dtype,
98            len: self.len,
99        }
100    }
101
102    /// Check if this storage is shared (has multiple references)
103    pub fn is_shared(&self) -> bool {
104        Arc::strong_count(&self.data) > 1
105    }
106}
107
108impl Clone for Storage {
109    /// Shallow clone - shares underlying data
110    fn clone(&self) -> Self {
111        Storage {
112            data: Arc::clone(&self.data),
113            dtype: self.dtype,
114            len: self.len,
115        }
116    }
117}
118
119/// Read guard for typed access to storage
120pub struct StorageReadGuard<'a, T> {
121    guard: parking_lot::RwLockReadGuard<'a, Vec<u8>>,
122    len: usize,
123    _marker: std::marker::PhantomData<T>,
124}
125
126impl<'a, T: TensorElement> StorageReadGuard<'a, T> {
127    pub fn as_slice(&self) -> &[T] {
128        unsafe {
129            std::slice::from_raw_parts(self.guard.as_ptr() as *const T, self.len)
130        }
131    }
132}
133
134impl<'a, T: TensorElement> std::ops::Deref for StorageReadGuard<'a, T> {
135    type Target = [T];
136    
137    fn deref(&self) -> &Self::Target {
138        self.as_slice()
139    }
140}
141
142/// Write guard for typed access to storage
143pub struct StorageWriteGuard<'a, T> {
144    guard: parking_lot::RwLockWriteGuard<'a, Vec<u8>>,
145    len: usize,
146    _marker: std::marker::PhantomData<T>,
147}
148
149impl<'a, T: TensorElement> StorageWriteGuard<'a, T> {
150    pub fn as_slice(&self) -> &[T] {
151        unsafe {
152            std::slice::from_raw_parts(self.guard.as_ptr() as *const T, self.len)
153        }
154    }
155    
156    pub fn as_slice_mut(&mut self) -> &mut [T] {
157        unsafe {
158            std::slice::from_raw_parts_mut(self.guard.as_mut_ptr() as *mut T, self.len)
159        }
160    }
161}
162
163impl<'a, T: TensorElement> std::ops::Deref for StorageWriteGuard<'a, T> {
164    type Target = [T];
165    
166    fn deref(&self) -> &Self::Target {
167        self.as_slice()
168    }
169}
170
171impl<'a, T: TensorElement> std::ops::DerefMut for StorageWriteGuard<'a, T> {
172    fn deref_mut(&mut self) -> &mut Self::Target {
173        self.as_slice_mut()
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_storage_creation() {
        let storage = Storage::from_slice(&[1.0f32, 2.0, 3.0, 4.0]);
        assert_eq!(storage.len(), 4);
        assert_eq!(storage.dtype(), DType::F32);
    }

    #[test]
    fn test_storage_read() {
        let storage = Storage::from_slice(&[1.0f32, 2.0, 3.0]);
        let data = storage.as_slice::<f32>();
        assert_eq!(&*data, &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_storage_write() {
        let storage = Storage::from_slice(&[1.0f32, 2.0, 3.0]);
        {
            let mut data = storage.as_slice_mut::<f32>();
            data[0] = 10.0;
        }
        let data = storage.as_slice::<f32>();
        assert_eq!(data[0], 10.0);
    }

    /// `new` allocates `len` zero-valued elements.
    #[test]
    fn test_new_is_zero_filled() {
        let storage = Storage::new(DType::F32, 3);
        assert_eq!(storage.len(), 3);
        assert!(!storage.is_empty());
        assert_eq!(storage.size_bytes(), 3 * std::mem::size_of::<f32>());
        assert_eq!(&*storage.as_slice::<f32>(), &[0.0f32; 3]);
    }

    /// `clone` is shallow: both handles see the same bytes and report sharing.
    #[test]
    fn test_clone_shares_buffer() {
        let original = Storage::from_slice(&[1.0f32, 2.0]);
        assert!(!original.is_shared());

        let alias = original.clone();
        assert!(original.is_shared());
        assert!(alias.is_shared());

        alias.as_slice_mut::<f32>()[1] = 5.0;
        assert_eq!(original.as_slice::<f32>()[1], 5.0);

        drop(alias);
        assert!(!original.is_shared());
    }

    /// `deep_clone` produces an independent, unshared buffer.
    #[test]
    fn test_deep_clone_is_independent() {
        let original = Storage::from_slice(&[1.0f32, 2.0, 3.0]);
        let copy = original.deep_clone();
        assert!(!original.is_shared());
        assert!(!copy.is_shared());
        assert_eq!(copy.dtype(), original.dtype());
        assert_eq!(copy.len(), original.len());

        copy.as_slice_mut::<f32>()[0] = -1.0;
        assert_eq!(original.as_slice::<f32>()[0], 1.0);
        assert_eq!(copy.as_slice::<f32>()[0], -1.0);
    }
}