// ghostflow_core/storage.rs
use std::sync::Arc;

use parking_lot::RwLock;

use crate::dtype::{DType, TensorElement};

7#[derive(Debug)]
10pub struct Storage {
11 data: Arc<RwLock<Vec<u8>>>,
13 dtype: DType,
15 len: usize,
17}
18
19impl Storage {
20 pub fn new(dtype: DType, len: usize) -> Self {
22 let byte_len = len * dtype.size_bytes();
23 let data = vec![0u8; byte_len];
24 Storage {
25 data: Arc::new(RwLock::new(data)),
26 dtype,
27 len,
28 }
29 }
30
31 pub fn from_slice<T: TensorElement>(data: &[T]) -> Self {
33 let byte_len = std::mem::size_of_val(data);
34 let mut bytes = vec![0u8; byte_len];
35
36 unsafe {
38 std::ptr::copy_nonoverlapping(
39 data.as_ptr() as *const u8,
40 bytes.as_mut_ptr(),
41 byte_len,
42 );
43 }
44
45 Storage {
46 data: Arc::new(RwLock::new(bytes)),
47 dtype: T::DTYPE,
48 len: data.len(),
49 }
50 }
51
52 pub fn len(&self) -> usize {
54 self.len
55 }
56
57 pub fn is_empty(&self) -> bool {
59 self.len == 0
60 }
61
62 pub fn dtype(&self) -> DType {
64 self.dtype
65 }
66
67 pub fn size_bytes(&self) -> usize {
69 self.len * self.dtype.size_bytes()
70 }
71
72 pub fn as_slice<T: TensorElement>(&self) -> StorageReadGuard<'_, T> {
74 debug_assert_eq!(T::DTYPE, self.dtype);
75 StorageReadGuard {
76 guard: self.data.read(),
77 len: self.len,
78 _marker: std::marker::PhantomData,
79 }
80 }
81
82 pub fn as_slice_mut<T: TensorElement>(&self) -> StorageWriteGuard<'_, T> {
84 debug_assert_eq!(T::DTYPE, self.dtype);
85 StorageWriteGuard {
86 guard: self.data.write(),
87 len: self.len,
88 _marker: std::marker::PhantomData,
89 }
90 }
91
92 pub fn deep_clone(&self) -> Self {
94 let data = self.data.read().clone();
95 Storage {
96 data: Arc::new(RwLock::new(data)),
97 dtype: self.dtype,
98 len: self.len,
99 }
100 }
101
102 pub fn is_shared(&self) -> bool {
104 Arc::strong_count(&self.data) > 1
105 }
106}
107
108impl Clone for Storage {
109 fn clone(&self) -> Self {
111 Storage {
112 data: Arc::clone(&self.data),
113 dtype: self.dtype,
114 len: self.len,
115 }
116 }
117}
118
119pub struct StorageReadGuard<'a, T> {
121 guard: parking_lot::RwLockReadGuard<'a, Vec<u8>>,
122 len: usize,
123 _marker: std::marker::PhantomData<T>,
124}
125
126impl<'a, T: TensorElement> StorageReadGuard<'a, T> {
127 pub fn as_slice(&self) -> &[T] {
128 unsafe {
129 std::slice::from_raw_parts(self.guard.as_ptr() as *const T, self.len)
130 }
131 }
132}
133
134impl<'a, T: TensorElement> std::ops::Deref for StorageReadGuard<'a, T> {
135 type Target = [T];
136
137 fn deref(&self) -> &Self::Target {
138 self.as_slice()
139 }
140}
141
142pub struct StorageWriteGuard<'a, T> {
144 guard: parking_lot::RwLockWriteGuard<'a, Vec<u8>>,
145 len: usize,
146 _marker: std::marker::PhantomData<T>,
147}
148
149impl<'a, T: TensorElement> StorageWriteGuard<'a, T> {
150 pub fn as_slice(&self) -> &[T] {
151 unsafe {
152 std::slice::from_raw_parts(self.guard.as_ptr() as *const T, self.len)
153 }
154 }
155
156 pub fn as_slice_mut(&mut self) -> &mut [T] {
157 unsafe {
158 std::slice::from_raw_parts_mut(self.guard.as_mut_ptr() as *mut T, self.len)
159 }
160 }
161}
162
163impl<'a, T: TensorElement> std::ops::Deref for StorageWriteGuard<'a, T> {
164 type Target = [T];
165
166 fn deref(&self) -> &Self::Target {
167 self.as_slice()
168 }
169}
170
171impl<'a, T: TensorElement> std::ops::DerefMut for StorageWriteGuard<'a, T> {
172 fn deref_mut(&mut self) -> &mut Self::Target {
173 self.as_slice_mut()
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// `from_slice` records the element count and the element dtype.
    #[test]
    fn test_storage_creation() {
        let storage = Storage::from_slice(&[1.0f32, 2.0, 3.0, 4.0]);
        assert_eq!(storage.len(), 4);
        assert_eq!(storage.dtype(), DType::F32);
    }

    /// Data copied in by `from_slice` reads back unchanged.
    #[test]
    fn test_storage_read() {
        let storage = Storage::from_slice(&[1.0f32, 2.0, 3.0]);
        let data = storage.as_slice::<f32>();
        assert_eq!(&*data, &[1.0, 2.0, 3.0]);
    }

    /// A write through a write guard is visible to a later read.
    #[test]
    fn test_storage_write() {
        let storage = Storage::from_slice(&[1.0f32, 2.0, 3.0]);
        {
            let mut data = storage.as_slice_mut::<f32>();
            data[0] = 10.0;
        }
        let data = storage.as_slice::<f32>();
        assert_eq!(data[0], 10.0);
    }

    /// `new` allocates a zero-filled buffer of the requested element count.
    #[test]
    fn test_new_is_zero_filled() {
        let storage = Storage::new(DType::F32, 3);
        assert_eq!(storage.len(), 3);
        assert!(!storage.is_empty());
        assert_eq!(&*storage.as_slice::<f32>(), &[0.0, 0.0, 0.0]);
    }

    /// `clone` shares the buffer: `is_shared` flips and writes are visible through both handles.
    #[test]
    fn test_clone_shares_buffer() {
        let storage = Storage::from_slice(&[1.0f32, 2.0]);
        assert!(!storage.is_shared());

        let alias = storage.clone();
        assert!(storage.is_shared());
        assert!(alias.is_shared());

        alias.as_slice_mut::<f32>()[1] = 5.0;
        assert_eq!(&*storage.as_slice::<f32>(), &[1.0, 5.0]);
    }

    /// `deep_clone` copies the bytes, so the copy and the original evolve independently.
    #[test]
    fn test_deep_clone_is_independent() {
        let storage = Storage::from_slice(&[1.0f32, 2.0]);
        let copy = storage.deep_clone();
        assert!(!storage.is_shared());
        assert!(!copy.is_shared());

        copy.as_slice_mut::<f32>()[0] = 9.0;
        assert_eq!(&*storage.as_slice::<f32>(), &[1.0, 2.0]);
        assert_eq!(&*copy.as_slice::<f32>(), &[9.0, 2.0]);
    }
}