1use core::ops::{Deref, DerefMut};
27use std::sync::Arc;
28
29use parking_lot::RwLock;
30
31use crate::device::Device;
32use crate::dtype::Scalar;
33use crate::error::{Error, Result};
34
35#[cfg(feature = "cuda")]
36use cudarc::driver::CudaSlice;
37
#[cfg(feature = "cuda")]
/// Owns a `CudaSlice<f32>` and decides, at drop time, whether the buffer is
/// returned to the process-wide CUDA pool or simply freed.
pub struct PooledCudaSlice {
    /// `None` only after `Drop` has taken the buffer; accessors expect `Some`.
    slice: Option<CudaSlice<f32>>,
    /// When `true`, `Drop` hands the buffer to `cuda_pool::pool_free`.
    pool_managed: bool,
}
49
#[cfg(feature = "cuda")]
impl std::fmt::Debug for PooledCudaSlice {
    /// Reports the management flag and current buffer length rather than
    /// formatting the device buffer itself.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let len = self.slice.as_ref().map(|s| s.len());
        let mut dbg = f.debug_struct("PooledCudaSlice");
        dbg.field("pool_managed", &self.pool_managed);
        dbg.field("len", &len);
        dbg.finish()
    }
}
59
#[cfg(feature = "cuda")]
impl Drop for PooledCudaSlice {
    fn drop(&mut self) {
        // Pool-managed buffers go back to the allocator pool; anything else
        // (including an already-taken slot) is left to CudaSlice's own Drop.
        match self.slice.take() {
            Some(buffer) if self.pool_managed => {
                crate::backends::cuda_pool::pool_free(buffer);
            }
            _ => {}
        }
    }
}
71
#[cfg(feature = "cuda")]
impl PooledCudaSlice {
    /// Wraps `slice`. `pool_managed` selects whether `Drop` hands the buffer
    /// back to the CUDA pool (`true`) or lets it free normally (`false`).
    pub fn new(slice: CudaSlice<f32>, pool_managed: bool) -> Self {
        Self {
            pool_managed,
            slice: Some(slice),
        }
    }

    /// Shared access to the wrapped device buffer.
    ///
    /// # Panics
    /// Panics if the buffer was already taken.
    pub fn slice(&self) -> &CudaSlice<f32> {
        self.slice.as_ref().expect("CudaSlice already taken")
    }

    /// Exclusive access to the wrapped device buffer.
    ///
    /// # Panics
    /// Panics if the buffer was already taken.
    pub fn slice_mut(&mut self) -> &mut CudaSlice<f32> {
        self.slice.as_mut().expect("CudaSlice already taken")
    }
}
92
/// Backing buffer for [`Storage`]: host memory, or (with the `cuda` feature)
/// a pooled f32 device buffer.
#[derive(Debug)]
enum StorageData<T: Scalar> {
    /// Host-resident elements.
    Cpu(Vec<T>),
    /// Device-resident buffer (f32 only — see the `Storage<f32>` impls).
    #[cfg(feature = "cuda")]
    Cuda(PooledCudaSlice),
}
103
/// Shared, lock-protected element buffer with view semantics: `offset` and
/// `len` select a window into the buffer, and `clone` shares the buffer.
#[derive(Debug)]
pub struct Storage<T: Scalar> {
    /// Buffer plus device tag, shared by all clones/slices of this storage.
    inner: Arc<RwLock<StorageInner<T>>>,
    /// Start of this view within the buffer, in elements.
    offset: usize,
    /// Number of elements this view exposes.
    len: usize,
}
122
/// Interior of [`Storage`] kept behind the `RwLock`.
#[derive(Debug)]
struct StorageInner<T: Scalar> {
    /// The elements, on host or device.
    data: StorageData<T>,
    /// Device reported by `Storage::device()`.
    device: Device,
}
131
132impl<T: Scalar> Storage<T> {
133 #[must_use]
135 pub fn zeros(len: usize, device: Device) -> Self {
136 let data = vec![T::zeroed(); len];
137 Self::from_vec(data, device)
138 }
139
140 #[must_use]
142 pub fn from_vec(data: Vec<T>, device: Device) -> Self {
143 let len = data.len();
144 Self {
145 inner: Arc::new(RwLock::new(StorageInner {
146 data: StorageData::Cpu(data),
147 device,
148 })),
149 offset: 0,
150 len,
151 }
152 }
153
154 #[must_use]
156 pub fn from_slice(data: &[T], device: Device) -> Self {
157 Self::from_vec(data.to_vec(), device)
158 }
159
160 #[must_use]
162 pub const fn len(&self) -> usize {
163 self.len
164 }
165
166 #[must_use]
168 pub const fn is_empty(&self) -> bool {
169 self.len == 0
170 }
171
172 #[must_use]
174 pub const fn offset(&self) -> usize {
175 self.offset
176 }
177
178 #[must_use]
180 pub fn device(&self) -> Device {
181 self.inner.read().device
182 }
183
184 #[must_use]
186 pub fn is_cpu(&self) -> bool {
187 matches!(self.inner.read().data, StorageData::Cpu(_))
188 }
189
190 #[must_use]
192 pub fn is_gpu(&self) -> bool {
193 !self.is_cpu()
194 }
195
196 #[must_use]
198 pub fn size_bytes(&self) -> usize {
199 self.len * core::mem::size_of::<T>()
200 }
201
202 pub fn slice(&self, offset: usize, len: usize) -> Result<Self> {
204 if offset + len > self.len {
205 return Err(Error::IndexOutOfBounds {
206 index: offset + len,
207 size: self.len,
208 });
209 }
210
211 Ok(Self {
212 inner: Arc::clone(&self.inner),
213 offset: self.offset + offset,
214 len,
215 })
216 }
217
218 #[must_use]
220 pub fn is_unique(&self) -> bool {
221 Arc::strong_count(&self.inner) == 1
222 }
223
224 #[must_use]
229 pub fn as_slice(&self) -> StorageReadGuard<'_, T> {
230 StorageReadGuard {
231 guard: self.inner.read(),
232 offset: self.offset,
233 len: self.len,
234 }
235 }
236
237 #[must_use]
242 pub fn as_slice_mut(&self) -> StorageWriteGuard<'_, T> {
243 StorageWriteGuard {
244 guard: self.inner.write(),
245 offset: self.offset,
246 len: self.len,
247 }
248 }
249
250 pub fn copy_from(&self, other: &Self) -> Result<()> {
252 if self.len != other.len {
253 return Err(Error::shape_mismatch(&[self.len], &[other.len]));
254 }
255
256 let src = other.as_slice();
257 let mut dst = self.as_slice_mut();
258 dst.copy_from_slice(&src);
259 Ok(())
260 }
261
262 #[must_use]
267 pub fn deep_copy(&self) -> Self {
268 let inner = self.inner.read();
269 match &inner.data {
270 StorageData::Cpu(cpu_data) => {
271 let data = cpu_data[self.offset..self.offset + self.len].to_vec();
272 Self::from_vec(data, inner.device)
273 }
274 #[cfg(feature = "cuda")]
275 StorageData::Cuda(_) => {
276 panic!("deep_copy() on GPU storage requires Storage<f32>. Use deep_copy_f32().");
277 }
278 }
279 }
280
281 pub fn to_vec(&self) -> Vec<T> {
289 let inner = self.inner.read();
290 match &inner.data {
291 StorageData::Cpu(cpu_data) => cpu_data[self.offset..self.offset + self.len].to_vec(),
292 #[cfg(feature = "cuda")]
293 StorageData::Cuda(_) => {
294 panic!(
295 "Cannot call to_vec() on GPU storage for generic T. Use to_vec_f32() on Storage<f32>."
296 );
297 }
298 }
299 }
300
301 pub fn to_device(&self, device: Device) -> Result<Self> {
306 if self.device() == device {
307 return Ok(self.clone());
308 }
309
310 if device.is_cpu() && self.device().is_cpu() {
312 return Ok(self.deep_copy());
313 }
314
315 Err(Error::DeviceNotAvailable { device })
316 }
317}
318
#[cfg(feature = "cuda")]
impl Storage<f32> {
    /// Moves f32 storage between devices, covering all CPU/CUDA pairs.
    /// The generic `to_device` cannot perform GPU transfers, so the concrete
    /// paths live in this f32-only specialization.
    ///
    /// # Errors
    /// `Error::DeviceNotAvailable` when the CUDA backend is missing or a
    /// host/device copy fails.
    pub fn to_device_f32(&self, device: Device) -> Result<Self> {
        // Same device: hand out another shared handle, no copy.
        if self.device() == device {
            return Ok(self.clone());
        }

        let inner = self.inner.read();

        match (&inner.data, device) {
            (StorageData::Cpu(_), Device::Cpu) => {
                // deep_copy() re-locks `inner`, so release our guard first.
                drop(inner);
                Ok(self.deep_copy())
            }
            (StorageData::Cpu(cpu_data), Device::Cuda(_idx)) => {
                let backend = crate::backends::cuda::get_cuda_backend()
                    .ok_or(Error::DeviceNotAvailable { device })?;
                // Upload only the viewed window; the new storage starts at
                // offset 0.
                let slice = &cpu_data[self.offset..self.offset + self.len];
                let cuda_slice = backend
                    .htod_copy(slice)
                    .map_err(|_| Error::DeviceNotAvailable { device })?;
                let len = self.len;
                Ok(Self {
                    inner: Arc::new(RwLock::new(StorageInner {
                        // Unmanaged: this buffer was not drawn from the pool.
                        data: StorageData::Cuda(PooledCudaSlice::new(cuda_slice, false)),
                        device,
                    })),
                    offset: 0,
                    len,
                })
            }
            (StorageData::Cuda(pooled), Device::Cpu) => {
                // BUGFIX: read the device tag from the guard we already hold.
                // The previous `self.device()` re-acquired the read lock while
                // this one was held; parking_lot read locks can deadlock on
                // re-entry when a writer is queued between the two reads.
                let backend =
                    crate::backends::cuda::get_cuda_backend().ok_or(Error::DeviceNotAvailable {
                        device: inner.device,
                    })?;
                // dtoh_copy returns the FULL underlying allocation; the view
                // window is applied on the host side below.
                let full_vec = backend
                    .dtoh_copy(pooled.slice())
                    .map_err(|_| Error::DeviceNotAvailable { device })?;
                let end = self.offset + self.len;
                let sliced: Vec<f32> = if self.offset == 0 && self.len == full_vec.len() {
                    full_vec
                } else if end <= full_vec.len() {
                    full_vec[self.offset..end].to_vec()
                } else {
                    // Device buffer shorter than the view claims: warn, copy
                    // what exists, and zero-pad the remainder.
                    eprintln!(
                        "[storage] WARNING: CudaSlice len={} < Storage offset+len={} (offset={}, len={})",
                        full_vec.len(),
                        end,
                        self.offset,
                        self.len
                    );
                    let available = full_vec.len().saturating_sub(self.offset);
                    let mut result = vec![0.0f32; self.len];
                    if available > 0 {
                        result[..available]
                            .copy_from_slice(&full_vec[self.offset..self.offset + available]);
                    }
                    result
                };
                Ok(Self::from_vec(sliced, Device::Cpu))
            }
            (StorageData::Cuda(_), Device::Cuda(_)) => {
                // Cross-GPU move: bounce through the host (no direct
                // device-to-device path here).
                drop(inner);
                let cpu_storage = self.to_device_f32(Device::Cpu)?;
                cpu_storage.to_device_f32(device)
            }
        }
    }
}
404
#[cfg(feature = "cuda")]
impl Storage<f32> {
    /// Copies the viewed elements to a host `Vec<f32>` from either device.
    ///
    /// The GPU path is best-effort: when the backend is unavailable or the
    /// device-to-host copy fails, a zero-filled vector of the view's length
    /// is returned rather than panicking.
    pub fn to_vec_f32(&self) -> Vec<f32> {
        let inner = self.inner.read();
        match &inner.data {
            StorageData::Cpu(cpu_data) => cpu_data[self.offset..self.offset + self.len].to_vec(),
            StorageData::Cuda(pooled) => {
                if let Some(backend) = crate::backends::cuda::get_cuda_backend() {
                    if let Ok(full_vec) = backend.dtoh_copy(pooled.slice()) {
                        if self.offset == 0 && self.len == full_vec.len() {
                            return full_vec;
                        }
                        let end = self.offset + self.len;
                        if end <= full_vec.len() {
                            return full_vec[self.offset..end].to_vec();
                        }
                        // BUGFIX: this previously indexed unconditionally and
                        // panicked when the device buffer was shorter than
                        // offset+len. Mirror to_device_f32's defensive path:
                        // copy what exists and zero-pad the rest.
                        let available = full_vec.len().saturating_sub(self.offset);
                        let mut result = vec![0.0f32; self.len];
                        if available > 0 {
                            result[..available].copy_from_slice(
                                &full_vec[self.offset..self.offset + available],
                            );
                        }
                        return result;
                    }
                }
                // No backend / failed copy: preserve best-effort zeros.
                vec![0.0f32; self.len]
            }
        }
    }

    /// Returns an independent copy, staying on the current device when the
    /// CUDA backend can re-upload; otherwise falls back to a CPU copy.
    pub fn deep_copy_f32(&self) -> Self {
        let device = self.device();
        let vec = self.to_vec_f32();
        if device.is_gpu() {
            if let Some(backend) = crate::backends::cuda::get_cuda_backend() {
                if let Ok(new_slice) = backend.htod_copy(&vec) {
                    return Self::from_cuda_slice_unmanaged(new_slice, self.len, device);
                }
            }
        }
        Self::from_vec(vec, device)
    }

    /// Wraps an existing device buffer; on drop it is returned to the CUDA
    /// pool (`pool_managed = true`).
    pub fn from_cuda_slice(slice: CudaSlice<f32>, len: usize, device: Device) -> Self {
        Self {
            inner: Arc::new(RwLock::new(StorageInner {
                data: StorageData::Cuda(PooledCudaSlice::new(slice, true)),
                device,
            })),
            offset: 0,
            len,
        }
    }

    /// Wraps an existing device buffer that the pool must NOT reclaim; the
    /// buffer is freed by `CudaSlice`'s own Drop instead.
    pub fn from_cuda_slice_unmanaged(slice: CudaSlice<f32>, len: usize, device: Device) -> Self {
        Self {
            inner: Arc::new(RwLock::new(StorageInner {
                data: StorageData::Cuda(PooledCudaSlice::new(slice, false)),
                device,
            })),
            offset: 0,
            len,
        }
    }

    /// Read-locks the storage and exposes the raw device buffer.
    /// The guard panics on access if the storage is CPU-resident.
    pub fn as_cuda_slice(&self) -> CudaSliceReadGuard<'_> {
        CudaSliceReadGuard {
            guard: self.inner.read(),
        }
    }

    /// Write-locks the storage and exposes the raw device buffer mutably.
    /// The guard panics on access if the storage is CPU-resident.
    pub fn as_cuda_slice_mut(&self) -> CudaSliceWriteGuard<'_> {
        CudaSliceWriteGuard {
            guard: self.inner.write(),
        }
    }
}
495
#[cfg(feature = "cuda")]
/// Guard holding the storage read lock and exposing the device buffer.
pub struct CudaSliceReadGuard<'a> {
    guard: parking_lot::RwLockReadGuard<'a, StorageInner<f32>>,
}
501
502#[cfg(feature = "cuda")]
503impl<'a> CudaSliceReadGuard<'a> {
504 pub fn slice(&self) -> &CudaSlice<f32> {
509 match &self.guard.data {
510 StorageData::Cuda(pooled) => pooled.slice(),
511 StorageData::Cpu(_) => panic!("Storage is on CPU, not GPU"),
512 }
513 }
514}
515
#[cfg(feature = "cuda")]
/// Guard holding the storage write lock and exposing the device buffer
/// mutably.
pub struct CudaSliceWriteGuard<'a> {
    guard: parking_lot::RwLockWriteGuard<'a, StorageInner<f32>>,
}
521
522#[cfg(feature = "cuda")]
523impl<'a> CudaSliceWriteGuard<'a> {
524 pub fn slice_mut(&mut self) -> &mut CudaSlice<f32> {
529 match &mut self.guard.data {
530 StorageData::Cuda(pooled) => pooled.slice_mut(),
531 StorageData::Cpu(_) => panic!("Storage is on CPU, not GPU"),
532 }
533 }
534}
535
536impl<T: Scalar> Clone for Storage<T> {
537 fn clone(&self) -> Self {
538 Self {
539 inner: Arc::clone(&self.inner),
540 offset: self.offset,
541 len: self.len,
542 }
543 }
544}
545
/// Guard holding the storage read lock; derefs to the viewed `&[T]` window.
/// Deref panics when the storage is GPU-resident.
pub struct StorageReadGuard<'a, T: Scalar> {
    guard: parking_lot::RwLockReadGuard<'a, StorageInner<T>>,
    /// Window (in elements) exposed through `Deref`.
    offset: usize,
    len: usize,
}
556
557impl<T: Scalar> Deref for StorageReadGuard<'_, T> {
558 type Target = [T];
559
560 fn deref(&self) -> &Self::Target {
561 match &self.guard.data {
562 StorageData::Cpu(data) => &data[self.offset..self.offset + self.len],
563 #[cfg(feature = "cuda")]
564 StorageData::Cuda(_) => panic!(
565 "Cannot access GPU storage as CPU slice. Use to_vec() for device-safe access."
566 ),
567 }
568 }
569}
570
/// Guard holding the storage write lock; derefs to the viewed window,
/// mutably via `DerefMut`. Both derefs panic when the storage is
/// GPU-resident.
pub struct StorageWriteGuard<'a, T: Scalar> {
    guard: parking_lot::RwLockWriteGuard<'a, StorageInner<T>>,
    /// Window (in elements) exposed through `Deref`/`DerefMut`.
    offset: usize,
    len: usize,
}
577
578impl<T: Scalar> Deref for StorageWriteGuard<'_, T> {
579 type Target = [T];
580
581 fn deref(&self) -> &Self::Target {
582 match &self.guard.data {
583 StorageData::Cpu(data) => &data[self.offset..self.offset + self.len],
584 #[cfg(feature = "cuda")]
585 StorageData::Cuda(_) => panic!("Cannot access GPU storage as CPU slice."),
586 }
587 }
588}
589
590impl<T: Scalar> DerefMut for StorageWriteGuard<'_, T> {
591 fn deref_mut(&mut self) -> &mut Self::Target {
592 match &mut self.guard.data {
593 StorageData::Cpu(data) => &mut data[self.offset..self.offset + self.len],
594 #[cfg(feature = "cuda")]
595 StorageData::Cuda(_) => panic!("Cannot access GPU storage as mutable CPU slice."),
596 }
597 }
598}
599
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_storage_zeros() {
        let storage = Storage::<f32>::zeros(10, Device::Cpu);
        assert_eq!(storage.len(), 10);
        assert!(!storage.is_empty());

        // Every element of freshly zeroed storage must read back as 0.0.
        let view = storage.as_slice();
        assert!(view.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_storage_from_vec() {
        let source = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let storage = Storage::from_vec(source.clone(), Device::Cpu);
        assert_eq!(&*storage.as_slice(), source.as_slice());
    }

    #[test]
    fn test_storage_slice() {
        let storage = Storage::from_vec(vec![1.0_f32, 2.0, 3.0, 4.0, 5.0], Device::Cpu);
        let view = storage.slice(1, 3).unwrap();

        assert_eq!(view.len(), 3);
        assert_eq!(&*view.as_slice(), &[2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_storage_clone_shares() {
        let first = Storage::<f32>::zeros(10, Device::Cpu);
        let second = first.clone();

        // A clone is a second handle onto the same buffer, so neither is unique.
        assert!(!first.is_unique());
        assert!(!second.is_unique());
    }

    #[test]
    fn test_storage_deep_copy() {
        let original = Storage::from_vec(vec![1.0_f32, 2.0, 3.0], Device::Cpu);
        let copy = original.deep_copy();

        assert!(original.is_unique());
        assert!(copy.is_unique());

        // Mutating the copy must leave the original untouched.
        copy.as_slice_mut()[0] = 99.0;
        assert_eq!(original.as_slice()[0], 1.0);
    }

    #[test]
    fn test_storage_copy_from() {
        let source = Storage::from_vec(vec![1.0_f32, 2.0, 3.0], Device::Cpu);
        let target = Storage::<f32>::zeros(3, Device::Cpu);

        target.copy_from(&source).unwrap();
        assert_eq!(&*target.as_slice(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_storage_slice_out_of_bounds() {
        let storage = Storage::<f32>::zeros(10, Device::Cpu);
        assert!(storage.slice(5, 10).is_err());
    }

    #[test]
    fn test_storage_to_vec_cpu() {
        let storage = Storage::from_vec(vec![1.0_f32, 2.0, 3.0], Device::Cpu);
        assert_eq!(storage.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_storage_is_cpu() {
        let storage = Storage::from_vec(vec![1.0_f32], Device::Cpu);
        assert!(storage.is_cpu());
        assert!(!storage.is_gpu());
    }
}