1use core::ops::{Deref, DerefMut};
18use std::sync::Arc;
19
20use parking_lot::RwLock;
21
22use crate::device::Device;
23use crate::dtype::Scalar;
24use crate::error::{Error, Result};
25
26#[cfg(feature = "cuda")]
27use cudarc::driver::CudaSlice;
28#[cfg(feature = "cuda")]
29use cudarc::driver::DeviceSlice;
30
/// Owns a device buffer and remembers whether it was handed out by the
/// CUDA memory pool, so `Drop` can route it back correctly.
#[cfg(feature = "cuda")]
pub struct PooledCudaSlice {
    // `Option` so `Drop` can move the slice out with `take()`.
    slice: Option<CudaSlice<f32>>,
    // `true` → return to `cuda_pool` on drop; `false` → let cudarc free it.
    pool_managed: bool,
}
42
#[cfg(feature = "cuda")]
impl std::fmt::Debug for PooledCudaSlice {
    /// Manual `Debug`: `CudaSlice` itself is not `Debug`, so only report
    /// pool ownership and the element count (None once the slice is taken).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let len = self.slice.as_ref().map(|slice| slice.len());
        f.debug_struct("PooledCudaSlice")
            .field("pool_managed", &self.pool_managed)
            .field("len", &len)
            .finish()
    }
}
52
#[cfg(feature = "cuda")]
impl Drop for PooledCudaSlice {
    /// Routes the buffer back to the CUDA pool when pool-managed; otherwise
    /// lets the slice fall out of scope so cudarc frees it directly.
    fn drop(&mut self) {
        let Some(slice) = self.slice.take() else {
            return;
        };
        if self.pool_managed {
            crate::backends::cuda_pool::pool_free(slice);
        }
        // Not pool-managed: `slice` drops here and cudarc releases the memory.
    }
}
64
#[cfg(feature = "cuda")]
impl PooledCudaSlice {
    /// Wraps `slice`; `pool_managed` decides whether `Drop` returns it to
    /// the CUDA pool or frees it directly.
    pub fn new(slice: CudaSlice<f32>, pool_managed: bool) -> Self {
        Self {
            pool_managed,
            slice: Some(slice),
        }
    }

    /// Borrows the wrapped device buffer.
    ///
    /// # Panics
    /// Panics if the slice was already moved out (only possible mid-drop).
    pub fn slice(&self) -> &CudaSlice<f32> {
        match self.slice.as_ref() {
            Some(inner) => inner,
            None => panic!("CudaSlice already taken"),
        }
    }

    /// Mutably borrows the wrapped device buffer.
    ///
    /// # Panics
    /// Panics if the slice was already moved out (only possible mid-drop).
    pub fn slice_mut(&mut self) -> &mut CudaSlice<f32> {
        match self.slice.as_mut() {
            Some(inner) => inner,
            None => panic!("CudaSlice already taken"),
        }
    }
}
85
/// Where the raw elements actually live.
#[derive(Debug)]
enum StorageData<T: Scalar> {
    /// Host-memory buffer.
    Cpu(Vec<T>),
    /// Device-memory buffer. `PooledCudaSlice` holds `CudaSlice<f32>`, so
    /// GPU storage is effectively f32-only (see the `Storage<f32>` impls).
    #[cfg(feature = "cuda")]
    Cuda(PooledCudaSlice),
}
96
/// A reference-counted, lock-protected buffer plus an `(offset, len)` view
/// into it. `clone()` shares the buffer; `slice()` narrows the view;
/// `deep_copy()` detaches it.
#[derive(Debug)]
pub struct Storage<T: Scalar> {
    // Shared buffer + device tag, guarded by a reader-writer lock.
    inner: Arc<RwLock<StorageInner<T>>>,
    // First element of this view within the shared buffer.
    offset: usize,
    // Number of elements visible through this view.
    len: usize,
}
115
/// Lock-protected payload shared by all views of one buffer.
#[derive(Debug)]
struct StorageInner<T: Scalar> {
    // CPU or GPU backing data.
    data: StorageData<T>,
    // Device the data is tagged as living on.
    device: Device,
}
124
125impl<T: Scalar> Storage<T> {
126 #[must_use]
128 pub fn zeros(len: usize, device: Device) -> Self {
129 let data = vec![T::zeroed(); len];
130 Self::from_vec(data, device)
131 }
132
133 #[must_use]
135 pub fn from_vec(data: Vec<T>, device: Device) -> Self {
136 let len = data.len();
137 Self {
138 inner: Arc::new(RwLock::new(StorageInner {
139 data: StorageData::Cpu(data),
140 device,
141 })),
142 offset: 0,
143 len,
144 }
145 }
146
147 #[must_use]
149 pub fn from_slice(data: &[T], device: Device) -> Self {
150 Self::from_vec(data.to_vec(), device)
151 }
152
153 #[must_use]
155 pub const fn len(&self) -> usize {
156 self.len
157 }
158
159 #[must_use]
161 pub const fn is_empty(&self) -> bool {
162 self.len == 0
163 }
164
165 #[must_use]
167 pub const fn offset(&self) -> usize {
168 self.offset
169 }
170
171 #[must_use]
173 pub fn device(&self) -> Device {
174 self.inner.read().device
175 }
176
177 #[must_use]
179 pub fn is_cpu(&self) -> bool {
180 matches!(self.inner.read().data, StorageData::Cpu(_))
181 }
182
183 #[must_use]
185 pub fn is_gpu(&self) -> bool {
186 !self.is_cpu()
187 }
188
189 #[must_use]
191 pub fn size_bytes(&self) -> usize {
192 self.len * core::mem::size_of::<T>()
193 }
194
195 pub fn slice(&self, offset: usize, len: usize) -> Result<Self> {
197 if offset + len > self.len {
198 return Err(Error::IndexOutOfBounds {
199 index: offset + len,
200 size: self.len,
201 });
202 }
203
204 Ok(Self {
205 inner: Arc::clone(&self.inner),
206 offset: self.offset + offset,
207 len,
208 })
209 }
210
211 #[must_use]
213 pub fn is_unique(&self) -> bool {
214 Arc::strong_count(&self.inner) == 1
215 }
216
217 #[must_use]
222 pub fn as_slice(&self) -> StorageReadGuard<'_, T> {
223 StorageReadGuard {
224 guard: self.inner.read(),
225 offset: self.offset,
226 len: self.len,
227 }
228 }
229
230 #[must_use]
235 pub fn as_slice_mut(&self) -> StorageWriteGuard<'_, T> {
236 StorageWriteGuard {
237 guard: self.inner.write(),
238 offset: self.offset,
239 len: self.len,
240 }
241 }
242
243 pub fn copy_from(&self, other: &Self) -> Result<()> {
245 if self.len != other.len {
246 return Err(Error::shape_mismatch(&[self.len], &[other.len]));
247 }
248
249 let src = other.as_slice();
250 let mut dst = self.as_slice_mut();
251 dst.copy_from_slice(&src);
252 Ok(())
253 }
254
255 #[must_use]
260 pub fn deep_copy(&self) -> Self {
261 let inner = self.inner.read();
262 match &inner.data {
263 StorageData::Cpu(cpu_data) => {
264 let data = cpu_data[self.offset..self.offset + self.len].to_vec();
265 Self::from_vec(data, inner.device)
266 }
267 #[cfg(feature = "cuda")]
268 StorageData::Cuda(_) => {
269 panic!("deep_copy() on GPU storage requires Storage<f32>. Use deep_copy_f32().");
270 }
271 }
272 }
273
274 pub fn to_vec(&self) -> Vec<T> {
282 let inner = self.inner.read();
283 match &inner.data {
284 StorageData::Cpu(cpu_data) => cpu_data[self.offset..self.offset + self.len].to_vec(),
285 #[cfg(feature = "cuda")]
286 StorageData::Cuda(_) => {
287 panic!(
288 "Cannot call to_vec() on GPU storage for generic T. Use to_vec_f32() on Storage<f32>."
289 );
290 }
291 }
292 }
293
294 pub fn to_device(&self, device: Device) -> Result<Self> {
299 if self.device() == device {
300 return Ok(self.clone());
301 }
302
303 if device.is_cpu() && self.device().is_cpu() {
305 return Ok(self.deep_copy());
306 }
307
308 Err(Error::DeviceNotAvailable { device })
309 }
310}
311
#[cfg(feature = "cuda")]
impl Storage<f32> {
    /// Moves this `f32` storage to `device`, copying data as required.
    ///
    /// Transfer matrix:
    /// - same device → cheap shared clone
    /// - CPU → CPU → deep copy
    /// - CPU → CUDA → host-to-device upload of just this view
    /// - CUDA → CPU → device-to-host download, then slice to the view
    /// - CUDA → CUDA → round-trip through the CPU (no device-to-device path)
    ///
    /// # Errors
    /// Returns `Error::DeviceNotAvailable` when the CUDA backend is missing
    /// or a copy fails.
    pub fn to_device_f32(&self, device: Device) -> Result<Self> {
        if self.device() == device {
            return Ok(self.clone());
        }

        let inner = self.inner.read();

        match (&inner.data, device) {
            (StorageData::Cpu(_), Device::Cpu) => {
                // deep_copy() re-acquires the read lock; drop ours first to
                // avoid holding two guards on the same RwLock.
                drop(inner);
                Ok(self.deep_copy())
            }
            (StorageData::Cpu(cpu_data), Device::Cuda(_idx)) => {
                let backend = crate::backends::cuda::get_cuda_backend()
                    .ok_or(Error::DeviceNotAvailable { device })?;
                // Upload only this view's window, so the new storage starts
                // at offset 0.
                let slice = &cpu_data[self.offset..self.offset + self.len];
                let cuda_slice = backend
                    .htod_copy(slice)
                    .map_err(|_| Error::DeviceNotAvailable { device })?;
                let len = self.len;
                Ok(Self {
                    inner: Arc::new(RwLock::new(StorageInner {
                        // pool_managed = false: this allocation came from
                        // htod_copy, not the pool, so cudarc frees it.
                        data: StorageData::Cuda(PooledCudaSlice::new(cuda_slice, false)),
                        device,
                    })),
                    offset: 0,
                    len,
                })
            }
            (StorageData::Cuda(pooled), Device::Cpu) => {
                let backend =
                    crate::backends::cuda::get_cuda_backend().ok_or(Error::DeviceNotAvailable {
                        device: self.device(),
                    })?;
                // dtoh_copy downloads the whole underlying buffer; the view
                // window is applied afterwards.
                let full_vec = backend
                    .dtoh_copy(pooled.slice())
                    .map_err(|_| Error::DeviceNotAvailable { device })?;
                let end = self.offset + self.len;
                let sliced: Vec<f32> = if self.offset == 0 && self.len == full_vec.len() {
                    full_vec
                } else if end <= full_vec.len() {
                    full_vec[self.offset..end].to_vec()
                } else {
                    // Defensive path: the view claims more elements than the
                    // device buffer holds (shouldn't happen for a consistent
                    // view — presumably indicates pool/view bookkeeping skew).
                    // Copy what exists and zero-pad the rest rather than panic.
                    eprintln!(
                        "[storage] WARNING: CudaSlice len={} < Storage offset+len={} (offset={}, len={})",
                        full_vec.len(),
                        end,
                        self.offset,
                        self.len
                    );
                    let available = if self.offset < full_vec.len() {
                        full_vec.len() - self.offset
                    } else {
                        0
                    };
                    let mut result = vec![0.0f32; self.len];
                    if available > 0 {
                        result[..available]
                            .copy_from_slice(&full_vec[self.offset..self.offset + available]);
                    }
                    result
                };
                Ok(Self::from_vec(sliced, Device::Cpu))
            }
            (StorageData::Cuda(_), Device::Cuda(_)) => {
                // No direct device-to-device copy: release our lock (the
                // recursive calls re-lock), bounce through the CPU.
                drop(inner);
                let cpu_storage = self.to_device_f32(Device::Cpu)?;
                cpu_storage.to_device_f32(device)
            }
            _ => Err(Error::DeviceNotAvailable { device }),
        }
    }
}
398
#[cfg(feature = "cuda")]
impl Storage<f32> {
    /// Copies this view's elements into a host `Vec<f32>`, regardless of
    /// which device the data lives on.
    ///
    /// GPU reads are best-effort: if the CUDA backend is missing or the
    /// device-to-host copy fails, a zero-filled vector of the right length
    /// is returned. A view extending past the device buffer is zero-padded
    /// (consistent with `to_device_f32`) instead of panicking.
    pub fn to_vec_f32(&self) -> Vec<f32> {
        let inner = self.inner.read();
        match &inner.data {
            StorageData::Cpu(cpu_data) => cpu_data[self.offset..self.offset + self.len].to_vec(),
            StorageData::Cuda(pooled) => {
                if let Some(backend) = crate::backends::cuda::get_cuda_backend() {
                    if let Ok(full_vec) = backend.dtoh_copy(pooled.slice()) {
                        if self.offset == 0 && self.len == full_vec.len() {
                            return full_vec;
                        }
                        let end = self.offset + self.len;
                        if end <= full_vec.len() {
                            return full_vec[self.offset..end].to_vec();
                        }
                        // View runs past the device buffer: copy what exists
                        // and zero-pad the rest, mirroring to_device_f32's
                        // handling of the same condition (previously this
                        // slicing panicked).
                        let mut result = vec![0.0f32; self.len];
                        let available = full_vec.len().saturating_sub(self.offset);
                        if available > 0 {
                            result[..available].copy_from_slice(
                                &full_vec[self.offset..self.offset + available],
                            );
                        }
                        return result;
                    }
                }
                // Backend unavailable or copy failed: best-effort zeros.
                vec![0.0f32; self.len]
            }
        }
    }

    /// Independent copy of this view that stays on the current device.
    ///
    /// GPU data is round-tripped through the host; if the re-upload fails,
    /// the result silently falls back to CPU storage with the same values.
    pub fn deep_copy_f32(&self) -> Self {
        let device = self.device();
        let vec = self.to_vec_f32();
        if device.is_gpu() {
            if let Some(backend) = crate::backends::cuda::get_cuda_backend() {
                if let Ok(new_slice) = backend.htod_copy(&vec) {
                    return Self::from_cuda_slice_unmanaged(new_slice, self.len, device);
                }
            }
        }
        Self::from_vec(vec, device)
    }

    /// Wraps a pool-managed `CudaSlice`: the allocation is returned to the
    /// CUDA pool when the last view drops.
    pub fn from_cuda_slice(slice: CudaSlice<f32>, len: usize, device: Device) -> Self {
        Self {
            inner: Arc::new(RwLock::new(StorageInner {
                data: StorageData::Cuda(PooledCudaSlice::new(slice, true)),
                device,
            })),
            offset: 0,
            len,
        }
    }

    /// Wraps a `CudaSlice` owned directly by cudarc: freed on drop, never
    /// returned to the pool.
    pub fn from_cuda_slice_unmanaged(slice: CudaSlice<f32>, len: usize, device: Device) -> Self {
        Self {
            inner: Arc::new(RwLock::new(StorageInner {
                data: StorageData::Cuda(PooledCudaSlice::new(slice, false)),
                device,
            })),
            offset: 0,
            len,
        }
    }

    /// Read-locks the storage and exposes the raw device buffer.
    /// The guard panics on access if the data is actually on the CPU.
    pub fn as_cuda_slice(&self) -> CudaSliceReadGuard<'_> {
        CudaSliceReadGuard {
            guard: self.inner.read(),
        }
    }

    /// Write-locks the storage and exposes the raw device buffer mutably.
    /// The guard panics on access if the data is actually on the CPU.
    pub fn as_cuda_slice_mut(&self) -> CudaSliceWriteGuard<'_> {
        CudaSliceWriteGuard {
            guard: self.inner.write(),
        }
    }
}
489
/// Holds a read lock on `f32` storage and hands out the raw device buffer.
#[cfg(feature = "cuda")]
pub struct CudaSliceReadGuard<'a> {
    guard: parking_lot::RwLockReadGuard<'a, StorageInner<f32>>,
}
495
#[cfg(feature = "cuda")]
impl<'a> CudaSliceReadGuard<'a> {
    /// Borrows the underlying device buffer.
    ///
    /// # Panics
    /// Panics when the storage actually lives on the CPU.
    pub fn slice(&self) -> &CudaSlice<f32> {
        if let StorageData::Cuda(pooled) = &self.guard.data {
            pooled.slice()
        } else {
            panic!("Storage is on CPU, not GPU")
        }
    }
}
509
/// Holds a write lock on `f32` storage and hands out the raw device buffer
/// mutably.
#[cfg(feature = "cuda")]
pub struct CudaSliceWriteGuard<'a> {
    guard: parking_lot::RwLockWriteGuard<'a, StorageInner<f32>>,
}
515
#[cfg(feature = "cuda")]
impl<'a> CudaSliceWriteGuard<'a> {
    /// Mutably borrows the underlying device buffer.
    ///
    /// # Panics
    /// Panics when the storage actually lives on the CPU.
    pub fn slice_mut(&mut self) -> &mut CudaSlice<f32> {
        if let StorageData::Cuda(pooled) = &mut self.guard.data {
            pooled.slice_mut()
        } else {
            panic!("Storage is on CPU, not GPU")
        }
    }
}
529
530impl<T: Scalar> Clone for Storage<T> {
531 fn clone(&self) -> Self {
532 Self {
533 inner: Arc::clone(&self.inner),
534 offset: self.offset,
535 len: self.len,
536 }
537 }
538}
539
/// Holds a read lock on the shared buffer and exposes one view of it as
/// `&[T]` through `Deref`. The lock is released when the guard drops.
pub struct StorageReadGuard<'a, T: Scalar> {
    guard: parking_lot::RwLockReadGuard<'a, StorageInner<T>>,
    // Start of the view inside the locked buffer.
    offset: usize,
    // Element count of the view.
    len: usize,
}
550
551impl<T: Scalar> Deref for StorageReadGuard<'_, T> {
552 type Target = [T];
553
554 fn deref(&self) -> &Self::Target {
555 match &self.guard.data {
556 StorageData::Cpu(data) => &data[self.offset..self.offset + self.len],
557 #[cfg(feature = "cuda")]
558 StorageData::Cuda(_) => panic!(
559 "Cannot access GPU storage as CPU slice. Use to_vec() for device-safe access."
560 ),
561 }
562 }
563}
564
/// Holds a write lock on the shared buffer and exposes one view of it as
/// `&mut [T]` through `DerefMut`. The lock is released when the guard drops.
pub struct StorageWriteGuard<'a, T: Scalar> {
    guard: parking_lot::RwLockWriteGuard<'a, StorageInner<T>>,
    // Start of the view inside the locked buffer.
    offset: usize,
    // Element count of the view.
    len: usize,
}
571
572impl<T: Scalar> Deref for StorageWriteGuard<'_, T> {
573 type Target = [T];
574
575 fn deref(&self) -> &Self::Target {
576 match &self.guard.data {
577 StorageData::Cpu(data) => &data[self.offset..self.offset + self.len],
578 #[cfg(feature = "cuda")]
579 StorageData::Cuda(_) => panic!("Cannot access GPU storage as CPU slice."),
580 }
581 }
582}
583
584impl<T: Scalar> DerefMut for StorageWriteGuard<'_, T> {
585 fn deref_mut(&mut self) -> &mut Self::Target {
586 match &mut self.guard.data {
587 StorageData::Cpu(data) => &mut data[self.offset..self.offset + self.len],
588 #[cfg(feature = "cuda")]
589 StorageData::Cuda(_) => panic!("Cannot access GPU storage as mutable CPU slice."),
590 }
591 }
592}
593
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_storage_zeros() {
        let s = Storage::<f32>::zeros(10, Device::Cpu);
        assert_eq!(s.len(), 10);
        assert!(!s.is_empty());

        // Every element must be zero-initialized.
        assert!(s.as_slice().iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_storage_from_vec() {
        let source = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let s = Storage::from_vec(source.clone(), Device::Cpu);
        assert_eq!(&*s.as_slice(), source.as_slice());
    }

    #[test]
    fn test_storage_slice() {
        let s = Storage::from_vec(vec![1.0_f32, 2.0, 3.0, 4.0, 5.0], Device::Cpu);
        let view = s.slice(1, 3).unwrap();

        assert_eq!(view.len(), 3);
        assert_eq!(&*view.as_slice(), &[2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_storage_clone_shares() {
        let a = Storage::<f32>::zeros(10, Device::Cpu);
        let b = a.clone();

        // Clones share the buffer, so neither handle is unique.
        assert!(!a.is_unique());
        assert!(!b.is_unique());
    }

    #[test]
    fn test_storage_deep_copy() {
        let original = Storage::from_vec(vec![1.0_f32, 2.0, 3.0], Device::Cpu);
        let copy = original.deep_copy();

        // A deep copy shares nothing with its source.
        assert!(original.is_unique());
        assert!(copy.is_unique());

        // Writing through the copy must not leak into the original.
        copy.as_slice_mut()[0] = 99.0;
        assert_eq!(original.as_slice()[0], 1.0);
    }

    #[test]
    fn test_storage_copy_from() {
        let src = Storage::from_vec(vec![1.0_f32, 2.0, 3.0], Device::Cpu);
        let dst = Storage::<f32>::zeros(3, Device::Cpu);

        dst.copy_from(&src).unwrap();
        assert_eq!(&*dst.as_slice(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_storage_slice_out_of_bounds() {
        let s = Storage::<f32>::zeros(10, Device::Cpu);
        // 5 + 10 exceeds the 10-element view.
        assert!(s.slice(5, 10).is_err());
    }

    #[test]
    fn test_storage_to_vec_cpu() {
        let s = Storage::from_vec(vec![1.0_f32, 2.0, 3.0], Device::Cpu);
        assert_eq!(s.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_storage_is_cpu() {
        let s = Storage::from_vec(vec![1.0_f32], Device::Cpu);
        assert!(s.is_cpu());
        assert!(!s.is_gpu());
    }
}