use core::ops::{Deref, DerefMut};
use std::sync::Arc;

use parking_lot::RwLock;

use crate::device::Device;
use crate::dtype::Scalar;
use crate::error::{Error, Result};

#[cfg(feature = "cuda")]
use cudarc::driver::CudaSlice;
#[cfg(feature = "cuda")]
use cudarc::driver::safe::DeviceSlice;

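/// Owned CUDA buffer that remembers whether it came from the allocation pool.
///
/// Pool-managed buffers are handed back to `cuda_pool` on drop so they can be
/// reused; unmanaged buffers are freed by `CudaSlice`'s own `Drop`.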
#[cfg(feature = "cuda")]
pub struct PooledCudaSlice {
    slice: Option<CudaSlice<f32>>,
    pool_managed: bool,
}

#[cfg(feature = "cuda")]
impl std::fmt::Debug for PooledCudaSlice {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PooledCudaSlice")
            .field("pool_managed", &self.pool_managed)
            .field("len", &self.slice.as_ref().map(|s| s.len()))
            .finish()
    }
}

#[cfg(feature = "cuda")]
impl Drop for PooledCudaSlice {
    fn drop(&mut self) {
        if let Some(slice) = self.slice.take() {
            if self.pool_managed {
                crate::backends::cuda_pool::pool_free(slice);
            }
            // Unmanaged slices simply fall out of scope here and are freed
            // by `CudaSlice`'s own `Drop`.
        }
    }
}

#[cfg(feature = "cuda")]
impl PooledCudaSlice {
    pub fn new(slice: CudaSlice<f32>, pool_managed: bool) -> Self {
        Self { slice: Some(slice), pool_managed }
    }

    /// Borrows the underlying CUDA slice.
    ///
    /// # Panics
    /// Panics if the slice has already been taken (only possible during drop).
    pub fn slice(&self) -> &CudaSlice<f32> {
        self.slice.as_ref().expect("CudaSlice already taken")
    }

    /// Mutably borrows the underlying CUDA slice.
    ///
    /// # Panics
    /// Panics if the slice has already been taken (only possible during drop).
    pub fn slice_mut(&mut self) -> &mut CudaSlice<f32> {
        self.slice.as_mut().expect("CudaSlice already taken")
    }
}

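/// Where the elements live: host memory (`Vec<T>`) or a CUDA allocation.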
#[derive(Debug)]
enum StorageData<T: Scalar> {
    Cpu(Vec<T>),
    #[cfg(feature = "cuda")]
    Cuda(PooledCudaSlice),
}

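/// Shared, device-aware buffer of `T` elements with view semantics.
///
/// Clones are cheap and share the allocation; `offset` and `len` describe
/// which window of the shared buffer this handle exposes. Use
/// [`Storage::deep_copy`] to detach from the shared buffer.
///
/// A minimal usage sketch (crate-internal, so marked `ignore`):
/// ```ignore
/// let storage = Storage::from_vec(vec![1.0_f32, 2.0, 3.0], Device::Cpu);
/// let view = storage.slice(1, 2)?; // shares the same allocation
/// assert_eq!(&*view.as_slice(), &[2.0, 3.0]);
/// ```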
#[derive(Debug)]
pub struct Storage<T: Scalar> {
    inner: Arc<RwLock<StorageInner<T>>>,
    offset: usize,
    len: usize,
}

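/// State shared by all clones of a `Storage`: the buffer and its device.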
#[derive(Debug)]
struct StorageInner<T: Scalar> {
    data: StorageData<T>,
    device: Device,
}

impl<T: Scalar> Storage<T> {
    /// Creates storage of `len` elements, all set to `T::zeroed()`.
    #[must_use]
    pub fn zeros(len: usize, device: Device) -> Self {
        let data = vec![T::zeroed(); len];
        Self::from_vec(data, device)
    }

    /// Wraps a vector as storage. The data starts out in host memory;
    /// `device` is recorded as the logical device.
    #[must_use]
    pub fn from_vec(data: Vec<T>, device: Device) -> Self {
        let len = data.len();
        Self {
            inner: Arc::new(RwLock::new(StorageInner {
                data: StorageData::Cpu(data),
                device,
            })),
            offset: 0,
            len,
        }
    }

    /// Copies a slice into new storage.
    #[must_use]
    pub fn from_slice(data: &[T], device: Device) -> Self {
        Self::from_vec(data.to_vec(), device)
    }

    /// Number of elements visible through this view.
    #[must_use]
    pub const fn len(&self) -> usize {
        self.len
    }

    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Element offset of this view into the shared allocation.
    #[must_use]
    pub const fn offset(&self) -> usize {
        self.offset
    }

    /// The logical device this storage lives on.
    #[must_use]
    pub fn device(&self) -> Device {
        self.inner.read().device
    }

    #[must_use]
    pub fn is_cpu(&self) -> bool {
        matches!(self.inner.read().data, StorageData::Cpu(_))
    }

    #[must_use]
    pub fn is_gpu(&self) -> bool {
        !self.is_cpu()
    }

    /// Size of this view in bytes.
    #[must_use]
    pub fn size_bytes(&self) -> usize {
        self.len * core::mem::size_of::<T>()
    }

    /// Returns a view of `len` elements starting at `offset` within this
    /// view, sharing the same allocation.
    pub fn slice(&self, offset: usize, len: usize) -> Result<Self> {
        // `saturating_add` keeps the bounds check correct even if
        // `offset + len` would wrap around in release builds.
        let end = offset.saturating_add(len);
        if end > self.len {
            return Err(Error::IndexOutOfBounds {
                index: end,
                size: self.len,
            });
        }

        Ok(Self {
            inner: Arc::clone(&self.inner),
            offset: self.offset + offset,
            len,
        })
    }

    /// True if no other `Storage` handle shares this allocation.
    #[must_use]
    pub fn is_unique(&self) -> bool {
        Arc::strong_count(&self.inner) == 1
    }

    /// Locks the storage for reading and exposes the CPU data for this view.
    ///
    /// # Panics
    /// Dereferencing the guard panics if the data is GPU-resident.
    #[must_use]
    pub fn as_slice(&self) -> StorageReadGuard<'_, T> {
        StorageReadGuard {
            guard: self.inner.read(),
            offset: self.offset,
            len: self.len,
        }
    }

    /// Locks the storage for writing and exposes the CPU data for this view.
    /// Note that this takes `&self`: mutation goes through the interior
    /// `RwLock`, not `&mut`.
    #[must_use]
    pub fn as_slice_mut(&self) -> StorageWriteGuard<'_, T> {
        StorageWriteGuard {
            guard: self.inner.write(),
            offset: self.offset,
            len: self.len,
        }
    }

    /// Copies all elements of `other` into this view.
    pub fn copy_from(&self, other: &Self) -> Result<()> {
        if self.len != other.len {
            return Err(Error::shape_mismatch(&[self.len], &[other.len]));
        }

        // Taking a read lock and then a write lock on the same RwLock from
        // one thread deadlocks, so stage through a Vec when the two views
        // share an allocation.
        if Arc::ptr_eq(&self.inner, &other.inner) {
            let tmp = other.to_vec();
            self.as_slice_mut().copy_from_slice(&tmp);
            return Ok(());
        }

        let src = other.as_slice();
        let mut dst = self.as_slice_mut();
        dst.copy_from_slice(&src);
        Ok(())
    }

    /// Copies this view's elements into a fresh, unshared allocation.
    ///
    /// # Panics
    /// Panics for GPU storage; use `deep_copy_f32` on `Storage<f32>` instead.
    #[must_use]
    pub fn deep_copy(&self) -> Self {
        let inner = self.inner.read();
        match &inner.data {
            StorageData::Cpu(cpu_data) => {
                let data = cpu_data[self.offset..self.offset + self.len].to_vec();
                Self::from_vec(data, inner.device)
            }
            #[cfg(feature = "cuda")]
            StorageData::Cuda(_) => {
                panic!("deep_copy() on GPU storage requires Storage<f32>. Use deep_copy_f32().");
            }
        }
    }

    /// Copies this view's elements into a `Vec`.
    ///
    /// # Panics
    /// Panics for GPU storage; use `to_vec_f32` on `Storage<f32>` instead.
    pub fn to_vec(&self) -> Vec<T> {
        let inner = self.inner.read();
        match &inner.data {
            StorageData::Cpu(cpu_data) => {
                cpu_data[self.offset..self.offset + self.len].to_vec()
            }
            #[cfg(feature = "cuda")]
            StorageData::Cuda(_) => {
                panic!("Cannot call to_vec() on GPU storage for generic T. Use to_vec_f32() on Storage<f32>.");
            }
        }
    }

    /// Moves the data to `device`, or returns a cheap clone if it is already
    /// there.
    ///
    /// The generic version only handles CPU targets; GPU transfers live on
    /// `Storage<f32>` as `to_device_f32`.
    pub fn to_device(&self, device: Device) -> Result<Self> {
        if self.device() == device {
            return Ok(self.clone());
        }

        if device.is_cpu() && self.device().is_cpu() {
            return Ok(self.deep_copy());
        }

        Err(Error::DeviceNotAvailable { device })
    }
}

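// Device transfers are implemented only for `Storage<f32>`, the sole element
// type the CUDA storage path supports (`PooledCudaSlice` holds f32).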
#[cfg(feature = "cuda")]
impl Storage<f32> {
    /// Moves f32 data between host and device.
    ///
    /// Transfers between two different CUDA devices are staged through the
    /// host.
    pub fn to_device_f32(&self, device: Device) -> Result<Self> {
        if self.device() == device {
            return Ok(self.clone());
        }

        let inner = self.inner.read();

        match (&inner.data, device) {
            (StorageData::Cpu(_), Device::Cpu) => {
                drop(inner);
                Ok(self.deep_copy())
            }
            (StorageData::Cpu(cpu_data), Device::Cuda(_idx)) => {
                let backend = crate::backends::cuda::get_cuda_backend()
                    .ok_or(Error::DeviceNotAvailable { device })?;
                let slice = &cpu_data[self.offset..self.offset + self.len];
                let cuda_slice = backend
                    .htod_copy(slice)
                    .map_err(|_| Error::DeviceNotAvailable { device })?;
                let len = self.len;
                Ok(Self {
                    inner: Arc::new(RwLock::new(StorageInner {
                        data: StorageData::Cuda(PooledCudaSlice::new(cuda_slice, false)),
                        device,
                    })),
                    offset: 0,
                    len,
                })
            }
            (StorageData::Cuda(pooled), Device::Cpu) => {
                // Use `inner.device` rather than `self.device()`: the read
                // lock is already held, and a recursive `read()` can deadlock
                // when a writer is queued.
                let backend = crate::backends::cuda::get_cuda_backend()
                    .ok_or(Error::DeviceNotAvailable { device: inner.device })?;
                let full_vec = backend
                    .dtoh_copy(pooled.slice())
                    .map_err(|_| Error::DeviceNotAvailable { device })?;
                let end = self.offset + self.len;
                let sliced: Vec<f32> = if self.offset == 0 && self.len == full_vec.len() {
                    full_vec
                } else if end <= full_vec.len() {
                    full_vec[self.offset..end].to_vec()
                } else {
                    // The device buffer is shorter than this view claims;
                    // warn and zero-fill the missing tail instead of panicking.
                    eprintln!(
                        "[storage] WARNING: CudaSlice len={} < Storage offset+len={} (offset={}, len={})",
                        full_vec.len(), end, self.offset, self.len
                    );
                    let available = if self.offset < full_vec.len() {
                        full_vec.len() - self.offset
                    } else {
                        0
                    };
                    let mut result = vec![0.0f32; self.len];
                    if available > 0 {
                        result[..available]
                            .copy_from_slice(&full_vec[self.offset..self.offset + available]);
                    }
                    result
                };
                Ok(Self::from_vec(sliced, Device::Cpu))
            }
            (StorageData::Cuda(_), Device::Cuda(_)) => {
                // Different CUDA device: round-trip through the host.
                drop(inner);
                let cpu_storage = self.to_device_f32(Device::Cpu)?;
                cpu_storage.to_device_f32(device)
            }
            _ => Err(Error::DeviceNotAvailable { device }),
        }
    }
}

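// Host-facing helpers for f32 storage: device-to-host copies with a
// zero-filled fallback, plus constructors that wrap raw `CudaSlice` buffers.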
#[cfg(feature = "cuda")]
impl Storage<f32> {
    /// Copies this view to a host `Vec`, regardless of where it lives.
    ///
    /// Falls back to zeros if the CUDA backend is unavailable or the copy
    /// fails.
    pub fn to_vec_f32(&self) -> Vec<f32> {
        let inner = self.inner.read();
        match &inner.data {
            StorageData::Cpu(cpu_data) => {
                cpu_data[self.offset..self.offset + self.len].to_vec()
            }
            StorageData::Cuda(pooled) => {
                if let Some(backend) = crate::backends::cuda::get_cuda_backend() {
                    if let Ok(full_vec) = backend.dtoh_copy(pooled.slice()) {
                        if self.offset == 0 && self.len == full_vec.len() {
                            return full_vec;
                        }
                        return full_vec[self.offset..self.offset + self.len].to_vec();
                    }
                }
                vec![0.0f32; self.len]
            }
        }
    }

    /// Like [`Storage::deep_copy`], but handles GPU storage by copying
    /// through the host. Falls back to a CPU copy if the round-trip fails.
    pub fn deep_copy_f32(&self) -> Self {
        let device = self.device();
        let vec = self.to_vec_f32();
        if device.is_gpu() {
            if let Some(backend) = crate::backends::cuda::get_cuda_backend() {
                if let Ok(new_slice) = backend.htod_copy(&vec) {
                    return Self::from_cuda_slice_unmanaged(new_slice, self.len, device);
                }
            }
        }
        Self::from_vec(vec, device)
    }

    /// Wraps a pool-allocated `CudaSlice`; it is returned to the pool on drop.
    pub fn from_cuda_slice(slice: CudaSlice<f32>, len: usize, device: Device) -> Self {
        Self {
            inner: Arc::new(RwLock::new(StorageInner {
                data: StorageData::Cuda(PooledCudaSlice::new(slice, true)),
                device,
            })),
            offset: 0,
            len,
        }
    }

    /// Wraps a `CudaSlice` that the pool does not own; it is freed normally
    /// on drop.
    pub fn from_cuda_slice_unmanaged(slice: CudaSlice<f32>, len: usize, device: Device) -> Self {
        Self {
            inner: Arc::new(RwLock::new(StorageInner {
                data: StorageData::Cuda(PooledCudaSlice::new(slice, false)),
                device,
            })),
            offset: 0,
            len,
        }
    }

    /// Locks the storage for reading and exposes the raw `CudaSlice`.
    pub fn as_cuda_slice(&self) -> CudaSliceReadGuard<'_> {
        CudaSliceReadGuard {
            guard: self.inner.read(),
        }
    }

    /// Locks the storage for writing and exposes the raw `CudaSlice`.
    pub fn as_cuda_slice_mut(&self) -> CudaSliceWriteGuard<'_> {
        CudaSliceWriteGuard {
            guard: self.inner.write(),
        }
    }
}

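/// Read guard over the raw device buffer.
///
/// Unlike `StorageReadGuard`, this exposes the whole underlying `CudaSlice`;
/// the view's `offset`/`len` are not applied.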
#[cfg(feature = "cuda")]
pub struct CudaSliceReadGuard<'a> {
    guard: parking_lot::RwLockReadGuard<'a, StorageInner<f32>>,
}

#[cfg(feature = "cuda")]
impl<'a> CudaSliceReadGuard<'a> {
    /// # Panics
    /// Panics if the storage is on the CPU.
    pub fn slice(&self) -> &CudaSlice<f32> {
        match &self.guard.data {
            StorageData::Cuda(pooled) => pooled.slice(),
            StorageData::Cpu(_) => panic!("Storage is on CPU, not GPU"),
        }
    }
}

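/// Write guard over the raw device buffer; see `CudaSliceReadGuard`.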
#[cfg(feature = "cuda")]
pub struct CudaSliceWriteGuard<'a> {
    guard: parking_lot::RwLockWriteGuard<'a, StorageInner<f32>>,
}

#[cfg(feature = "cuda")]
impl<'a> CudaSliceWriteGuard<'a> {
    /// # Panics
    /// Panics if the storage is on the CPU.
    pub fn slice_mut(&mut self) -> &mut CudaSlice<f32> {
        match &mut self.guard.data {
            StorageData::Cuda(pooled) => pooled.slice_mut(),
            StorageData::Cpu(_) => panic!("Storage is on CPU, not GPU"),
        }
    }
}

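// `Clone` is shallow: the new handle shares the same `StorageInner` and keeps
// the same `offset`/`len` window.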
impl<T: Scalar> Clone for Storage<T> {
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
            offset: self.offset,
            len: self.len,
        }
    }
}

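/// Read guard returned by [`Storage::as_slice`]; derefs to the view's `[T]`
/// window.
///
/// # Panics
/// Dereferencing panics if the data is GPU-resident.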
pub struct StorageReadGuard<'a, T: Scalar> {
    guard: parking_lot::RwLockReadGuard<'a, StorageInner<T>>,
    offset: usize,
    len: usize,
}

impl<T: Scalar> Deref for StorageReadGuard<'_, T> {
    type Target = [T];

    fn deref(&self) -> &Self::Target {
        match &self.guard.data {
            StorageData::Cpu(data) => &data[self.offset..self.offset + self.len],
            #[cfg(feature = "cuda")]
            StorageData::Cuda(_) => panic!("Cannot access GPU storage as CPU slice. Use to_vec() for device-safe access."),
        }
    }
}

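/// Write guard returned by [`Storage::as_slice_mut`]; derefs to the view's
/// mutable `[T]` window. Panics on deref for GPU-resident data.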
pub struct StorageWriteGuard<'a, T: Scalar> {
    guard: parking_lot::RwLockWriteGuard<'a, StorageInner<T>>,
    offset: usize,
    len: usize,
}

impl<T: Scalar> Deref for StorageWriteGuard<'_, T> {
    type Target = [T];

    fn deref(&self) -> &Self::Target {
        match &self.guard.data {
            StorageData::Cpu(data) => &data[self.offset..self.offset + self.len],
            #[cfg(feature = "cuda")]
            StorageData::Cuda(_) => panic!("Cannot access GPU storage as CPU slice."),
        }
    }
}

impl<T: Scalar> DerefMut for StorageWriteGuard<'_, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        match &mut self.guard.data {
            StorageData::Cpu(data) => &mut data[self.offset..self.offset + self.len],
            #[cfg(feature = "cuda")]
            StorageData::Cuda(_) => panic!("Cannot access GPU storage as mutable CPU slice."),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_storage_zeros() {
        let storage = Storage::<f32>::zeros(10, Device::Cpu);
        assert_eq!(storage.len(), 10);
        assert!(!storage.is_empty());

        let data = storage.as_slice();
        for &val in data.iter() {
            assert_eq!(val, 0.0);
        }
    }

    #[test]
    fn test_storage_from_vec() {
        let vec = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let storage = Storage::from_vec(vec.clone(), Device::Cpu);

        let data = storage.as_slice();
        assert_eq!(&*data, &vec[..]);
    }

    #[test]
    fn test_storage_slice() {
        let vec = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let storage = Storage::from_vec(vec, Device::Cpu);
        let slice = storage.slice(1, 3).unwrap();

        assert_eq!(slice.len(), 3);
        let data = slice.as_slice();
        assert_eq!(&*data, &[2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_storage_clone_shares() {
        let storage1 = Storage::<f32>::zeros(10, Device::Cpu);
        let storage2 = storage1.clone();

        assert!(!storage1.is_unique());
        assert!(!storage2.is_unique());
    }

    #[test]
    fn test_storage_deep_copy() {
        let storage1 = Storage::from_vec(vec![1.0_f32, 2.0, 3.0], Device::Cpu);
        let storage2 = storage1.deep_copy();

        assert!(storage1.is_unique());
        assert!(storage2.is_unique());

        storage2.as_slice_mut()[0] = 99.0;

        assert_eq!(storage1.as_slice()[0], 1.0);
    }

    #[test]
    fn test_storage_copy_from() {
        let src = Storage::from_vec(vec![1.0_f32, 2.0, 3.0], Device::Cpu);
        let dst = Storage::<f32>::zeros(3, Device::Cpu);

        dst.copy_from(&src).unwrap();

        let data = dst.as_slice();
        assert_eq!(&*data, &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_storage_slice_out_of_bounds() {
        let storage = Storage::<f32>::zeros(10, Device::Cpu);
        let result = storage.slice(5, 10);
        assert!(result.is_err());
    }

    #[test]
    fn test_storage_to_vec_cpu() {
        let storage = Storage::from_vec(vec![1.0_f32, 2.0, 3.0], Device::Cpu);
        assert_eq!(storage.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_storage_is_cpu() {
        let storage = Storage::from_vec(vec![1.0_f32], Device::Cpu);
        assert!(storage.is_cpu());
        assert!(!storage.is_gpu());
    }
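
    // Added check (not in the original suite): slicing a slice composes
    // offsets against the shared buffer.
    #[test]
    fn test_storage_nested_slice_offsets() {
        let storage = Storage::from_vec(vec![1.0_f32, 2.0, 3.0, 4.0, 5.0], Device::Cpu);
        let outer = storage.slice(1, 4).unwrap();
        let inner = outer.slice(1, 2).unwrap();

        assert_eq!(inner.offset(), 2);
        assert_eq!(&*inner.as_slice(), &[3.0, 4.0]);
    }

    // Added check: copying between two views of the same allocation must not
    // deadlock (see the shared-allocation path in `copy_from`).
    #[test]
    fn test_storage_copy_from_shared_buffer() {
        let storage = Storage::from_vec(vec![1.0_f32, 2.0, 3.0, 4.0], Device::Cpu);
        let src = storage.slice(0, 2).unwrap();
        let dst = storage.slice(2, 2).unwrap();

        dst.copy_from(&src).unwrap();
        assert_eq!(storage.to_vec(), vec![1.0, 2.0, 1.0, 2.0]);
    }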
}