use core::ops::{Deref, DerefMut};
use std::sync::Arc;

use parking_lot::RwLock;

use crate::device::Device;
use crate::dtype::Scalar;
use crate::error::{Error, Result};
/// A reference-counted view (offset and length) into a shared, lock-protected
/// buffer of scalars living on a [`Device`]. Cloning a `Storage` shares the
/// underlying allocation.
#[derive(Debug)]
pub struct Storage<T: Scalar> {
    /// Shared buffer plus device metadata, behind a read-write lock.
    inner: Arc<RwLock<StorageInner<T>>>,
    /// Start of this view within the shared buffer.
    offset: usize,
    /// Number of elements visible through this view.
    len: usize,
}

#[derive(Debug)]
struct StorageInner<T: Scalar> {
    /// The buffer's elements.
    data: Vec<T>,
    /// Device the buffer is allocated on.
    device: Device,
}
impl<T: Scalar> Storage<T> {
    /// Creates a storage of `len` zero-initialized elements on `device`.
    #[must_use]
    pub fn zeros(len: usize, device: Device) -> Self {
        let data = vec![T::zeroed(); len];
        Self::from_vec(data, device)
    }

    /// Wraps an existing vector as a storage on `device`.
    #[must_use]
    pub fn from_vec(data: Vec<T>, device: Device) -> Self {
        let len = data.len();
        Self {
            inner: Arc::new(RwLock::new(StorageInner { data, device })),
            offset: 0,
            len,
        }
    }

    /// Copies a slice into a new storage on `device`.
    #[must_use]
    pub fn from_slice(data: &[T], device: Device) -> Self {
        Self::from_vec(data.to_vec(), device)
    }

    /// Returns the number of elements visible through this view.
    #[must_use]
    pub const fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if this view contains no elements.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Returns this view's offset into the shared buffer.
    #[must_use]
    pub const fn offset(&self) -> usize {
        self.offset
    }

    /// Returns the device the underlying buffer lives on.
    #[must_use]
    pub fn device(&self) -> Device {
        self.inner.read().device
    }

    /// Returns the size of this view in bytes.
    #[must_use]
    pub fn size_bytes(&self) -> usize {
        self.len * core::mem::size_of::<T>()
    }

    /// Returns a sub-view of `len` elements starting at `offset`, sharing the
    /// same underlying buffer as `self`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::IndexOutOfBounds`] if `offset + len` exceeds this
    /// view's length.
    pub fn slice(&self, offset: usize, len: usize) -> Result<Self> {
        if offset + len > self.len {
            return Err(Error::IndexOutOfBounds {
                index: offset + len,
                size: self.len,
            });
        }

        Ok(Self {
            inner: Arc::clone(&self.inner),
            offset: self.offset + offset,
            len,
        })
    }

    /// Returns `true` if no other `Storage` handle shares the underlying buffer.
    #[must_use]
    pub fn is_unique(&self) -> bool {
        Arc::strong_count(&self.inner) == 1
    }

    /// Acquires a read lock and returns a guard that dereferences to this
    /// view's elements.
    #[must_use]
    pub fn as_slice(&self) -> StorageReadGuard<'_, T> {
        StorageReadGuard {
            guard: self.inner.read(),
            offset: self.offset,
            len: self.len,
        }
    }

    /// Acquires a write lock and returns a guard that dereferences mutably to
    /// this view's elements.
    #[must_use]
    pub fn as_slice_mut(&self) -> StorageWriteGuard<'_, T> {
        StorageWriteGuard {
            guard: self.inner.write(),
            offset: self.offset,
            len: self.len,
        }
    }

    /// Copies the elements of `other` into this view; errors if the lengths differ.
    pub fn copy_from(&self, other: &Self) -> Result<()> {
        if self.len != other.len {
            return Err(Error::shape_mismatch(&[self.len], &[other.len]));
        }

        // Taking a read lock on `other` and then a write lock on `self` would
        // deadlock when both views share the same buffer, so that case is
        // handled under a single write lock.
        if Arc::ptr_eq(&self.inner, &other.inner) {
            let mut inner = self.inner.write();
            inner
                .data
                .copy_within(other.offset..other.offset + other.len, self.offset);
            return Ok(());
        }

        let src = other.as_slice();
        let mut dst = self.as_slice_mut();
        dst.copy_from_slice(&src);
        Ok(())
    }

    /// Returns a new storage with its own copy of this view's elements.
    #[must_use]
    pub fn deep_copy(&self) -> Self {
        let data = self.as_slice().to_vec();
        Self::from_vec(data, self.device())
    }

    /// Moves this storage to `device`: clones the handle if it is already
    /// there, deep-copies otherwise, and errors if `device` is unavailable.
    pub fn to_device(&self, device: Device) -> Result<Self> {
        if self.device() == device {
            return Ok(self.clone());
        }

        if !device.is_cpu() {
            return Err(Error::DeviceNotAvailable { device });
        }

        Ok(self.deep_copy())
    }
}

/// Cloning a `Storage` is shallow: the clone shares the underlying buffer.
/// Use [`Storage::deep_copy`] for an independent copy.
impl<T: Scalar> Clone for Storage<T> {
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
            offset: self.offset,
            len: self.len,
        }
    }
}

/// RAII read guard returned by [`Storage::as_slice`]; dereferences to the
/// elements of one storage view while the read lock is held.
pub struct StorageReadGuard<'a, T: Scalar> {
    guard: parking_lot::RwLockReadGuard<'a, StorageInner<T>>,
    offset: usize,
    len: usize,
}

impl<T: Scalar> Deref for StorageReadGuard<'_, T> {
    type Target = [T];

    fn deref(&self) -> &Self::Target {
        &self.guard.data[self.offset..self.offset + self.len]
    }
}

/// RAII write guard returned by [`Storage::as_slice_mut`]; dereferences
/// mutably to the elements of one storage view while the write lock is held.
pub struct StorageWriteGuard<'a, T: Scalar> {
    guard: parking_lot::RwLockWriteGuard<'a, StorageInner<T>>,
    offset: usize,
    len: usize,
}

impl<T: Scalar> Deref for StorageWriteGuard<'_, T> {
    type Target = [T];

    fn deref(&self) -> &Self::Target {
        &self.guard.data[self.offset..self.offset + self.len]
    }
}

impl<T: Scalar> DerefMut for StorageWriteGuard<'_, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard.data[self.offset..self.offset + self.len]
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_storage_zeros() {
        let storage = Storage::<f32>::zeros(10, Device::Cpu);
        assert_eq!(storage.len(), 10);
        assert!(!storage.is_empty());

        let data = storage.as_slice();
        for &val in data.iter() {
            assert_eq!(val, 0.0);
        }
    }

    #[test]
    fn test_storage_from_vec() {
        let vec = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let storage = Storage::from_vec(vec.clone(), Device::Cpu);

        let data = storage.as_slice();
        assert_eq!(&*data, &vec[..]);
    }

    #[test]
    fn test_storage_slice() {
        let vec = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let storage = Storage::from_vec(vec, Device::Cpu);
        let slice = storage.slice(1, 3).unwrap();

        assert_eq!(slice.len(), 3);
        let data = slice.as_slice();
        assert_eq!(&*data, &[2.0, 3.0, 4.0]);
    }
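
    // A slice shares its parent's buffer (`slice` clones the inner `Arc`), so
    // writes through a slice are visible through the parent, and nested
    // slices accumulate offsets into the same buffer.
    #[test]
    fn test_storage_slice_shares_buffer() {
        let storage = Storage::from_vec(vec![1.0_f32, 2.0, 3.0, 4.0, 5.0], Device::Cpu);
        let slice = storage.slice(1, 3).unwrap();

        slice.as_slice_mut()[0] = 42.0;
        assert_eq!(storage.as_slice()[1], 42.0);

        let nested = slice.slice(1, 2).unwrap();
        assert_eq!(nested.offset(), 2);
        assert_eq!(&*nested.as_slice(), &[3.0, 4.0]);
    }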

    #[test]
    fn test_storage_clone_shares() {
        let storage1 = Storage::<f32>::zeros(10, Device::Cpu);
        let storage2 = storage1.clone();

        assert!(!storage1.is_unique());
        assert!(!storage2.is_unique());
    }

    #[test]
    fn test_storage_deep_copy() {
        let storage1 = Storage::from_vec(vec![1.0_f32, 2.0, 3.0], Device::Cpu);
        let storage2 = storage1.deep_copy();

        assert!(storage1.is_unique());
        assert!(storage2.is_unique());

        // Mutating the copy must not affect the original.
        storage2.as_slice_mut()[0] = 99.0;

        assert_eq!(storage1.as_slice()[0], 1.0);
    }

    #[test]
    fn test_storage_copy_from() {
        let src = Storage::from_vec(vec![1.0_f32, 2.0, 3.0], Device::Cpu);
        let dst = Storage::<f32>::zeros(3, Device::Cpu);

        dst.copy_from(&src).unwrap();

        let data = dst.as_slice();
        assert_eq!(&*data, &[1.0, 2.0, 3.0]);
    }
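
    // `copy_from` requires the two views to have the same length and reports
    // a shape mismatch otherwise.
    #[test]
    fn test_storage_copy_from_len_mismatch() {
        let src = Storage::from_vec(vec![1.0_f32, 2.0, 3.0], Device::Cpu);
        let dst = Storage::<f32>::zeros(4, Device::Cpu);

        assert!(dst.copy_from(&src).is_err());
    }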

    #[test]
    fn test_storage_slice_out_of_bounds() {
        let storage = Storage::<f32>::zeros(10, Device::Cpu);
        let result = storage.slice(5, 10);
        assert!(result.is_err());
    }
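
    // `to_device` onto the device the storage already lives on clones the
    // handle, so the buffer stays shared rather than being copied.
    #[test]
    fn test_storage_to_device_same_device() {
        let storage = Storage::from_vec(vec![1.0_f32, 2.0], Device::Cpu);
        let moved = storage.to_device(Device::Cpu).unwrap();

        assert!(!storage.is_unique());
        assert!(!moved.is_unique());
        assert_eq!(&*moved.as_slice(), &[1.0, 2.0]);
    }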
}