use core::ops::{Deref, DerefMut};
use std::sync::Arc;
use parking_lot::RwLock;
use crate::device::Device;
use crate::dtype::Scalar;
use crate::error::{Error, Result};
#[cfg(feature = "cuda")]
use cudarc::driver::CudaSlice;
#[cfg(feature = "cuda")]
/// Owns a `CudaSlice<f32>` and decides, on drop, whether the buffer is
/// handed back to the device memory pool or simply freed by cudarc.
pub struct PooledCudaSlice {
    // `None` only transiently, after `Drop` has taken the slice out.
    slice: Option<CudaSlice<f32>>,
    // true => buffer came from `cuda_pool` and must be returned on drop.
    pool_managed: bool,
}
#[cfg(feature = "cuda")]
impl std::fmt::Debug for PooledCudaSlice {
    /// Reports pool ownership and current length without touching device memory.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let len = self.slice.as_ref().map(|s| s.len());
        let mut dbg = f.debug_struct("PooledCudaSlice");
        dbg.field("pool_managed", &self.pool_managed);
        dbg.field("len", &len);
        dbg.finish()
    }
}
#[cfg(feature = "cuda")]
impl Drop for PooledCudaSlice {
    fn drop(&mut self) {
        // Pool-managed buffers go back to the allocator; unmanaged slices are
        // simply dropped here and freed by CudaSlice's own Drop.
        match self.slice.take() {
            Some(slice) if self.pool_managed => {
                crate::backends::cuda_pool::pool_free(slice);
            }
            _ => {}
        }
    }
}
#[cfg(feature = "cuda")]
impl PooledCudaSlice {
    /// Wraps `slice`; `pool_managed` selects the drop strategy.
    pub fn new(slice: CudaSlice<f32>, pool_managed: bool) -> Self {
        let slice = Some(slice);
        Self { slice, pool_managed }
    }

    /// Borrows the underlying device slice.
    ///
    /// # Panics
    /// If the slice was already taken (only possible mid-drop).
    pub fn slice(&self) -> &CudaSlice<f32> {
        match &self.slice {
            Some(s) => s,
            None => panic!("CudaSlice already taken"),
        }
    }

    /// Mutably borrows the underlying device slice.
    ///
    /// # Panics
    /// If the slice was already taken (only possible mid-drop).
    pub fn slice_mut(&mut self) -> &mut CudaSlice<f32> {
        match &mut self.slice {
            Some(s) => s,
            None => panic!("CudaSlice already taken"),
        }
    }
}
/// Backing buffer: host vector or device slice.
#[derive(Debug)]
enum StorageData<T: Scalar> {
    /// Host-resident elements.
    Cpu(Vec<T>),
    /// Device-resident elements. Note the payload is always `f32`, regardless
    /// of `T`; only the `Storage<f32>` impls construct this variant.
    #[cfg(feature = "cuda")]
    Cuda(PooledCudaSlice),
}
/// Shared, reference-counted buffer with an `(offset, len)` view.
///
/// `clone()` shares the buffer (cheap), `slice()` makes sub-views over the
/// same buffer, and `deep_copy()` materializes an independent buffer.
#[derive(Debug)]
pub struct Storage<T: Scalar> {
    // Buffer shared by all views; the lock provides interior mutability.
    inner: Arc<RwLock<StorageInner<T>>>,
    // Index of this view's first element within the shared buffer.
    offset: usize,
    // Number of elements visible through this view.
    len: usize,
}
/// Lock-protected buffer plus the device it is recorded on.
#[derive(Debug)]
struct StorageInner<T: Scalar> {
    data: StorageData<T>,
    device: Device,
}
impl<T: Scalar> Storage<T> {
#[must_use]
pub fn zeros(len: usize, device: Device) -> Self {
let data = vec![T::zeroed(); len];
Self::from_vec(data, device)
}
#[must_use]
pub fn from_vec(data: Vec<T>, device: Device) -> Self {
let len = data.len();
Self {
inner: Arc::new(RwLock::new(StorageInner {
data: StorageData::Cpu(data),
device,
})),
offset: 0,
len,
}
}
#[must_use]
pub fn from_slice(data: &[T], device: Device) -> Self {
Self::from_vec(data.to_vec(), device)
}
#[must_use]
pub const fn len(&self) -> usize {
self.len
}
#[must_use]
pub const fn is_empty(&self) -> bool {
self.len == 0
}
#[must_use]
pub const fn offset(&self) -> usize {
self.offset
}
#[must_use]
pub fn device(&self) -> Device {
self.inner.read().device
}
#[must_use]
pub fn is_cpu(&self) -> bool {
matches!(self.inner.read().data, StorageData::Cpu(_))
}
#[must_use]
pub fn is_gpu(&self) -> bool {
!self.is_cpu()
}
#[must_use]
pub fn size_bytes(&self) -> usize {
self.len * core::mem::size_of::<T>()
}
pub fn slice(&self, offset: usize, len: usize) -> Result<Self> {
if offset + len > self.len {
return Err(Error::IndexOutOfBounds {
index: offset + len,
size: self.len,
});
}
Ok(Self {
inner: Arc::clone(&self.inner),
offset: self.offset + offset,
len,
})
}
#[must_use]
pub fn is_unique(&self) -> bool {
Arc::strong_count(&self.inner) == 1
}
#[must_use]
pub fn as_slice(&self) -> StorageReadGuard<'_, T> {
StorageReadGuard {
guard: self.inner.read(),
offset: self.offset,
len: self.len,
}
}
#[must_use]
pub fn as_slice_mut(&self) -> StorageWriteGuard<'_, T> {
StorageWriteGuard {
guard: self.inner.write(),
offset: self.offset,
len: self.len,
}
}
pub fn copy_from(&self, other: &Self) -> Result<()> {
if self.len != other.len {
return Err(Error::shape_mismatch(&[self.len], &[other.len]));
}
let src = other.as_slice();
let mut dst = self.as_slice_mut();
dst.copy_from_slice(&src);
Ok(())
}
#[must_use]
pub fn deep_copy(&self) -> Self {
let inner = self.inner.read();
match &inner.data {
StorageData::Cpu(cpu_data) => {
let data = cpu_data[self.offset..self.offset + self.len].to_vec();
Self::from_vec(data, inner.device)
}
#[cfg(feature = "cuda")]
StorageData::Cuda(_) => {
panic!("deep_copy() on GPU storage requires Storage<f32>. Use deep_copy_f32().");
}
}
}
pub fn to_vec(&self) -> Vec<T> {
let inner = self.inner.read();
match &inner.data {
StorageData::Cpu(cpu_data) => cpu_data[self.offset..self.offset + self.len].to_vec(),
#[cfg(feature = "cuda")]
StorageData::Cuda(_) => {
panic!(
"Cannot call to_vec() on GPU storage for generic T. Use to_vec_f32() on Storage<f32>."
);
}
}
}
pub fn to_device(&self, device: Device) -> Result<Self> {
if self.device() == device {
return Ok(self.clone());
}
if device.is_cpu() && self.device().is_cpu() {
return Ok(self.deep_copy());
}
Err(Error::DeviceNotAvailable { device })
}
}
#[cfg(feature = "cuda")]
impl Storage<f32> {
    /// Transfers this storage to `device`, returning a new `Storage`.
    ///
    /// Already-resident data is shared (cheap clone). CPU->GPU uploads only
    /// this view's window; GPU->CPU downloads and re-slices; GPU->GPU
    /// round-trips through host memory.
    ///
    /// # Errors
    /// `DeviceNotAvailable` when the CUDA backend is missing or a copy fails.
    pub fn to_device_f32(&self, device: Device) -> Result<Self> {
        if self.device() == device {
            return Ok(self.clone());
        }
        let inner = self.inner.read();
        match (&inner.data, device) {
            (StorageData::Cpu(_), Device::Cpu) => {
                // Release the read lock before deep_copy re-acquires it.
                drop(inner);
                Ok(self.deep_copy())
            }
            (StorageData::Cpu(cpu_data), Device::Cuda(_idx)) => {
                let backend = crate::backends::cuda::get_cuda_backend()
                    .ok_or(Error::DeviceNotAvailable { device })?;
                // Upload only the window this view covers.
                let slice = &cpu_data[self.offset..self.offset + self.len];
                let cuda_slice = backend
                    .htod_copy(slice)
                    .map_err(|_| Error::DeviceNotAvailable { device })?;
                let len = self.len;
                Ok(Self {
                    inner: Arc::new(RwLock::new(StorageInner {
                        // Unmanaged: cudarc frees the slice on drop.
                        data: StorageData::Cuda(PooledCudaSlice::new(cuda_slice, false)),
                        device,
                    })),
                    offset: 0,
                    len,
                })
            }
            (StorageData::Cuda(pooled), Device::Cpu) => {
                // FIX: use `inner.device`, not `self.device()`. The original
                // eagerly evaluated `self.device()` inside `ok_or`, which
                // re-acquires `inner.read()` while the `inner` read guard is
                // still held; parking_lot read locks are not recursive-safe
                // and can deadlock if a writer is queued in between.
                let backend = crate::backends::cuda::get_cuda_backend()
                    .ok_or(Error::DeviceNotAvailable {
                        device: inner.device,
                    })?;
                let full_vec = backend
                    .dtoh_copy(pooled.slice())
                    .map_err(|_| Error::DeviceNotAvailable { device })?;
                let end = self.offset + self.len;
                let sliced: Vec<f32> = if self.offset == 0 && self.len == full_vec.len() {
                    full_vec
                } else if end <= full_vec.len() {
                    full_vec[self.offset..end].to_vec()
                } else {
                    // Device buffer shorter than the logical view: warn and
                    // zero-pad the missing tail instead of panicking.
                    eprintln!(
                        "[storage] WARNING: CudaSlice len={} < Storage offset+len={} (offset={}, len={})",
                        full_vec.len(),
                        end,
                        self.offset,
                        self.len
                    );
                    let available = if self.offset < full_vec.len() {
                        full_vec.len() - self.offset
                    } else {
                        0
                    };
                    let mut result = vec![0.0f32; self.len];
                    if available > 0 {
                        result[..available]
                            .copy_from_slice(&full_vec[self.offset..self.offset + available]);
                    }
                    result
                };
                Ok(Self::from_vec(sliced, Device::Cpu))
            }
            (StorageData::Cuda(_), Device::Cuda(_)) => {
                // Cross-device GPU copy: stage through host memory.
                drop(inner);
                let cpu_storage = self.to_device_f32(Device::Cpu)?;
                cpu_storage.to_device_f32(device)
            }
        }
    }
}
#[cfg(feature = "cuda")]
impl Storage<f32> {
    /// Copies this view's elements into a host `Vec<f32>`.
    ///
    /// Best-effort on GPU storage: returns zeros when the CUDA backend is
    /// unavailable or the device-to-host copy fails, and zero-pads when the
    /// device buffer is shorter than the logical view (consistent with
    /// `to_device_f32`, which warns and pads — previously this path panicked
    /// on the out-of-range slice).
    pub fn to_vec_f32(&self) -> Vec<f32> {
        let inner = self.inner.read();
        match &inner.data {
            StorageData::Cpu(cpu_data) => cpu_data[self.offset..self.offset + self.len].to_vec(),
            StorageData::Cuda(pooled) => {
                if let Some(backend) = crate::backends::cuda::get_cuda_backend() {
                    if let Ok(full_vec) = backend.dtoh_copy(pooled.slice()) {
                        if self.offset == 0 && self.len == full_vec.len() {
                            return full_vec;
                        }
                        let end = self.offset + self.len;
                        if end <= full_vec.len() {
                            return full_vec[self.offset..end].to_vec();
                        }
                        // Device buffer shorter than the view: copy what is
                        // available and zero-fill the rest.
                        let available = full_vec.len().saturating_sub(self.offset);
                        let mut result = vec![0.0f32; self.len];
                        if available > 0 {
                            result[..available]
                                .copy_from_slice(&full_vec[self.offset..self.offset + available]);
                        }
                        return result;
                    }
                }
                vec![0.0f32; self.len]
            }
        }
    }

    /// Materializes an independent copy of this view on its current device.
    ///
    /// GPU copies round-trip through host memory; if the re-upload fails the
    /// result silently falls back to a CPU storage with the same values.
    pub fn deep_copy_f32(&self) -> Self {
        let device = self.device();
        let vec = self.to_vec_f32();
        if device.is_gpu() {
            if let Some(backend) = crate::backends::cuda::get_cuda_backend() {
                if let Ok(new_slice) = backend.htod_copy(&vec) {
                    return Self::from_cuda_slice_unmanaged(new_slice, self.len, device);
                }
            }
        }
        Self::from_vec(vec, device)
    }

    /// Wraps a pool-allocated device slice; drop returns it to the pool.
    pub fn from_cuda_slice(slice: CudaSlice<f32>, len: usize, device: Device) -> Self {
        Self {
            inner: Arc::new(RwLock::new(StorageInner {
                data: StorageData::Cuda(PooledCudaSlice::new(slice, true)),
                device,
            })),
            offset: 0,
            len,
        }
    }

    /// Wraps a device slice owned by cudarc; drop frees it directly.
    pub fn from_cuda_slice_unmanaged(slice: CudaSlice<f32>, len: usize, device: Device) -> Self {
        Self {
            inner: Arc::new(RwLock::new(StorageInner {
                data: StorageData::Cuda(PooledCudaSlice::new(slice, false)),
                device,
            })),
            offset: 0,
            len,
        }
    }

    /// Read-locks the buffer and exposes the raw device slice.
    pub fn as_cuda_slice(&self) -> CudaSliceReadGuard<'_> {
        CudaSliceReadGuard {
            guard: self.inner.read(),
        }
    }

    /// Write-locks the buffer and exposes the raw device slice mutably.
    pub fn as_cuda_slice_mut(&self) -> CudaSliceWriteGuard<'_> {
        CudaSliceWriteGuard {
            guard: self.inner.write(),
        }
    }
}
#[cfg(feature = "cuda")]
/// RAII read guard over a GPU storage's underlying `CudaSlice`.
pub struct CudaSliceReadGuard<'a> {
    guard: parking_lot::RwLockReadGuard<'a, StorageInner<f32>>,
}

#[cfg(feature = "cuda")]
impl<'a> CudaSliceReadGuard<'a> {
    /// Borrows the device slice.
    ///
    /// # Panics
    /// When the storage is CPU-resident.
    pub fn slice(&self) -> &CudaSlice<f32> {
        let StorageData::Cuda(pooled) = &self.guard.data else {
            panic!("Storage is on CPU, not GPU")
        };
        pooled.slice()
    }
}
#[cfg(feature = "cuda")]
/// RAII write guard over a GPU storage's underlying `CudaSlice`.
pub struct CudaSliceWriteGuard<'a> {
    guard: parking_lot::RwLockWriteGuard<'a, StorageInner<f32>>,
}

#[cfg(feature = "cuda")]
impl<'a> CudaSliceWriteGuard<'a> {
    /// Mutably borrows the device slice.
    ///
    /// # Panics
    /// When the storage is CPU-resident.
    pub fn slice_mut(&mut self) -> &mut CudaSlice<f32> {
        let StorageData::Cuda(pooled) = &mut self.guard.data else {
            panic!("Storage is on CPU, not GPU")
        };
        pooled.slice_mut()
    }
}
impl<T: Scalar> Clone for Storage<T> {
    /// Cheap clone: bumps the Arc refcount and copies the view window.
    /// The underlying buffer is shared, not duplicated.
    fn clone(&self) -> Self {
        let Self { inner, offset, len } = self;
        Self {
            inner: Arc::clone(inner),
            offset: *offset,
            len: *len,
        }
    }
}
/// RAII read guard dereferencing to this view's `&[T]` (CPU storage only;
/// deref panics for GPU-resident data).
pub struct StorageReadGuard<'a, T: Scalar> {
    guard: parking_lot::RwLockReadGuard<'a, StorageInner<T>>,
    // View window applied on deref.
    offset: usize,
    len: usize,
}
impl<T: Scalar> Deref for StorageReadGuard<'_, T> {
    type Target = [T];

    /// Borrows the viewed window of the host buffer; panics for GPU storage.
    fn deref(&self) -> &Self::Target {
        let (start, end) = (self.offset, self.offset + self.len);
        match &self.guard.data {
            StorageData::Cpu(data) => &data[start..end],
            #[cfg(feature = "cuda")]
            StorageData::Cuda(_) => panic!(
                "Cannot access GPU storage as CPU slice. Use to_vec() for device-safe access."
            ),
        }
    }
}
/// RAII write guard dereferencing to this view's `&mut [T]` (CPU storage
/// only; deref panics for GPU-resident data).
pub struct StorageWriteGuard<'a, T: Scalar> {
    guard: parking_lot::RwLockWriteGuard<'a, StorageInner<T>>,
    // View window applied on deref.
    offset: usize,
    len: usize,
}
impl<T: Scalar> Deref for StorageWriteGuard<'_, T> {
    type Target = [T];

    /// Borrows the viewed window of the host buffer; panics for GPU storage.
    fn deref(&self) -> &Self::Target {
        let (start, end) = (self.offset, self.offset + self.len);
        match &self.guard.data {
            StorageData::Cpu(data) => &data[start..end],
            #[cfg(feature = "cuda")]
            StorageData::Cuda(_) => panic!("Cannot access GPU storage as CPU slice."),
        }
    }
}
impl<T: Scalar> DerefMut for StorageWriteGuard<'_, T> {
    /// Mutably borrows the viewed window; panics for GPU storage.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let range = self.offset..self.offset + self.len;
        match &mut self.guard.data {
            StorageData::Cpu(data) => &mut data[range],
            #[cfg(feature = "cuda")]
            StorageData::Cuda(_) => panic!("Cannot access GPU storage as mutable CPU slice."),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_storage_zeros() {
        let storage = Storage::<f32>::zeros(10, Device::Cpu);
        assert_eq!(storage.len(), 10);
        assert!(!storage.is_empty());
        let guard = storage.as_slice();
        assert!(guard.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_storage_from_vec() {
        let values = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let storage = Storage::from_vec(values.clone(), Device::Cpu);
        assert_eq!(&*storage.as_slice(), values.as_slice());
    }

    #[test]
    fn test_storage_slice() {
        let storage = Storage::from_vec(vec![1.0_f32, 2.0, 3.0, 4.0, 5.0], Device::Cpu);
        let view = storage.slice(1, 3).unwrap();
        assert_eq!(view.len(), 3);
        assert_eq!(&*view.as_slice(), &[2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_storage_clone_shares() {
        let first = Storage::<f32>::zeros(10, Device::Cpu);
        let second = first.clone();
        assert!(!first.is_unique());
        assert!(!second.is_unique());
    }

    #[test]
    fn test_storage_deep_copy() {
        let original = Storage::from_vec(vec![1.0_f32, 2.0, 3.0], Device::Cpu);
        let copy = original.deep_copy();
        assert!(original.is_unique());
        assert!(copy.is_unique());
        // Mutating the copy must not be visible through the original.
        copy.as_slice_mut()[0] = 99.0;
        assert_eq!(original.as_slice()[0], 1.0);
    }

    #[test]
    fn test_storage_copy_from() {
        let source = Storage::from_vec(vec![1.0_f32, 2.0, 3.0], Device::Cpu);
        let target = Storage::<f32>::zeros(3, Device::Cpu);
        target.copy_from(&source).unwrap();
        assert_eq!(&*target.as_slice(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_storage_slice_out_of_bounds() {
        let storage = Storage::<f32>::zeros(10, Device::Cpu);
        assert!(storage.slice(5, 10).is_err());
    }

    #[test]
    fn test_storage_to_vec_cpu() {
        let storage = Storage::from_vec(vec![1.0_f32, 2.0, 3.0], Device::Cpu);
        assert_eq!(storage.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_storage_is_cpu() {
        let storage = Storage::from_vec(vec![1.0_f32], Device::Cpu);
        assert!(storage.is_cpu());
        assert!(!storage.is_gpu());
    }
}