use crate::device::Device;
use crate::dtype::Element;
use crate::gpu_dispatch::GpuBufferHandle;
/// Type-erased handle to a buffer held by a CubeCL runtime.
///
/// This is the payload of `StorageBuffer::Cubecl`, which backs `Device::Xpu`
/// storage. Implementors must be `Debug + Send + Sync` so storage holding the
/// handle can be formatted and shared across threads.
pub trait CubeStorageHandle: std::fmt::Debug + Send + Sync {
    /// Number of elements in the buffer (reported as `TensorStorage::len`).
    fn len(&self) -> usize;

    /// Returns `true` when the buffer holds no elements.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Device ordinal the buffer belongs to (shown by the storage `Debug` output).
    fn ordinal(&self) -> usize;

    /// Copies the buffer contents back to the host.
    ///
    /// The result is always `f32`, regardless of the tensor's element type.
    fn read_to_host(&self) -> crate::error::FerrotorchResult<Vec<f32>>;

    /// Produces a new boxed handle; used by `TensorStorage::try_clone`.
    ///
    /// NOTE(review): whether this deep-copies device memory or shares it is up
    /// to the implementor — confirm before relying on clone independence.
    fn clone_handle(&self) -> Box<dyn CubeStorageHandle>;

    /// Exposes the handle as `Any`, presumably so backends can downcast to
    /// their concrete handle type.
    fn as_any(&self) -> &dyn std::any::Any;
}
/// Backing storage for a tensor: an element buffer plus the device it lives on.
///
/// The constructors on this type keep `device` in agreement with the `data`
/// variant (`Cpu` -> `Device::Cpu`, `Gpu` -> `Device::Cuda(ordinal)`,
/// `Cubecl` -> `Device::Xpu(ordinal)`, `Meta` -> `Device::Meta`).
#[derive(Debug)]
pub struct TensorStorage<T: Element> {
    /// The element buffer (host vector, CUDA buffer, CubeCL handle, or data-less meta).
    pub(crate) data: StorageBuffer<T>,
    /// Device tag returned by `TensorStorage::device`.
    pub(crate) device: Device,
}
/// Physical location of a tensor's elements.
pub enum StorageBuffer<T: Element> {
    /// Host memory owned as a `Vec<T>`.
    Cpu(Vec<T>),
    /// Buffer owned by the GPU backend in `crate::gpu_dispatch` (paired with `Device::Cuda`).
    Gpu(GpuBufferHandle),
    /// Type-erased CubeCL buffer (paired with `Device::Xpu`).
    Cubecl(Box<dyn CubeStorageHandle>),
    /// Data-less placeholder: only an element count and an optional fill value.
    Meta { numel: usize, fill_value: Option<T> },
}
impl<T: Element> TensorStorage<T> {
pub fn cpu(data: Vec<T>) -> Self {
Self {
data: StorageBuffer::Cpu(data),
device: Device::Cpu,
}
}
pub fn meta(numel: usize) -> Self {
Self {
data: StorageBuffer::Meta {
numel,
fill_value: None,
},
device: Device::Meta,
}
}
pub fn meta_filled(numel: usize, value: T) -> Self {
Self {
data: StorageBuffer::Meta {
numel,
fill_value: Some(value),
},
device: Device::Meta,
}
}
pub fn meta_fill_value(&self) -> Option<&T> {
match &self.data {
StorageBuffer::Meta { fill_value, .. } => fill_value.as_ref(),
_ => None,
}
}
pub fn on_device(data: Vec<T>, target_device: Device) -> crate::error::FerrotorchResult<Self> {
match target_device {
Device::Cpu => Ok(Self::cpu(data)),
Device::Cuda(ordinal) => {
let backend = crate::gpu_dispatch::gpu_backend()
.ok_or(crate::error::FerrotorchError::DeviceUnavailable)?;
let bytes: &[u8] = unsafe {
std::slice::from_raw_parts(
data.as_ptr().cast::<u8>(),
data.len() * std::mem::size_of::<T>(),
)
};
let handle = backend.cpu_to_gpu(bytes, T::dtype(), ordinal)?;
Ok(Self::gpu(handle))
}
Device::Xpu(_) => Err(crate::error::FerrotorchError::InvalidArgument {
message: "XPU storage requires a CubeRuntime; use Tensor::to(Device::Xpu(n)) \
via ferrotorch-xpu instead of TensorStorage::on_device. Issue #673."
.into(),
}),
Device::Mps(_) => Err(crate::error::FerrotorchError::InvalidArgument {
message: "MPS storage requires the ferrotorch-mps backend; not yet wired into TensorStorage".into(),
}),
Device::Meta => {
Ok(Self::meta(data.len()))
}
}
}
pub fn on_device_pinned(
data: Vec<T>,
target_device: Device,
) -> crate::error::FerrotorchResult<Self> {
match target_device {
Device::Cpu => Ok(Self::cpu(data)),
Device::Cuda(ordinal) => {
let backend = crate::gpu_dispatch::gpu_backend()
.ok_or(crate::error::FerrotorchError::DeviceUnavailable)?;
let bytes: &[u8] = unsafe {
std::slice::from_raw_parts(
data.as_ptr().cast::<u8>(),
data.len() * std::mem::size_of::<T>(),
)
};
let handle = backend.cpu_to_gpu_pinned(bytes, T::dtype(), ordinal)?;
Ok(Self::gpu(handle))
}
Device::Xpu(_) => Err(crate::error::FerrotorchError::InvalidArgument {
message: "XPU storage requires a CubeRuntime; use Tensor::to(Device::Xpu(n)) \
via ferrotorch-xpu instead of TensorStorage::on_device_pinned. Issue #673."
.into(),
}),
Device::Mps(_) => Err(crate::error::FerrotorchError::InvalidArgument {
message: "MPS storage requires the ferrotorch-mps backend; not yet wired into TensorStorage".into(),
}),
Device::Meta => Ok(Self::meta(data.len())),
}
}
pub fn xpu_from_handle(handle: Box<dyn CubeStorageHandle>, ordinal: usize) -> Self {
Self {
data: StorageBuffer::Cubecl(handle),
device: Device::Xpu(ordinal),
}
}
pub fn gpu(handle: GpuBufferHandle) -> Self {
let device = Device::Cuda(handle.device_ordinal());
Self {
data: StorageBuffer::Gpu(handle),
device,
}
}
#[inline]
pub fn device(&self) -> Device {
self.device
}
pub fn len(&self) -> usize {
match &self.data {
StorageBuffer::Cpu(v) => v.len(),
StorageBuffer::Gpu(h) => h.len(),
StorageBuffer::Cubecl(h) => h.len(),
StorageBuffer::Meta { numel, .. } => *numel,
}
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
#[deprecated(
since = "0.4.5",
note = "use try_as_slice() instead; this version panics on non-CPU storage"
)]
pub fn as_slice(&self) -> &[T] {
match &self.data {
StorageBuffer::Cpu(v) => v.as_slice(),
StorageBuffer::Gpu(_) => {
panic!("cannot access GPU tensor as CPU slice -- call .cpu() first")
}
StorageBuffer::Cubecl(_) => {
panic!("cannot access XPU tensor as CPU slice -- call .cpu() first")
}
StorageBuffer::Meta { .. } => {
panic!("cannot access meta tensor as a slice -- meta tensors carry no data")
}
}
}
#[deprecated(
since = "0.4.5",
note = "use try_as_mut_slice() instead; this version panics on non-CPU storage"
)]
pub fn as_mut_slice(&mut self) -> &mut [T] {
match &mut self.data {
StorageBuffer::Cpu(v) => v.as_mut_slice(),
StorageBuffer::Gpu(_) => {
panic!("cannot mutate GPU tensor as CPU slice -- call .cpu() first")
}
StorageBuffer::Cubecl(_) => {
panic!("cannot mutate XPU tensor as CPU slice -- call .cpu() first")
}
StorageBuffer::Meta { .. } => {
panic!("cannot mutate meta tensor as a slice -- meta tensors carry no data")
}
}
}
pub fn try_as_slice(&self) -> crate::error::FerrotorchResult<&[T]> {
match &self.data {
StorageBuffer::Cpu(v) => Ok(v.as_slice()),
StorageBuffer::Gpu(_) | StorageBuffer::Cubecl(_) | StorageBuffer::Meta { .. } => {
Err(crate::error::FerrotorchError::GpuTensorNotAccessible)
}
}
}
pub fn try_as_mut_slice(&mut self) -> crate::error::FerrotorchResult<&mut [T]> {
match &mut self.data {
StorageBuffer::Cpu(v) => Ok(v.as_mut_slice()),
StorageBuffer::Gpu(_) | StorageBuffer::Cubecl(_) | StorageBuffer::Meta { .. } => {
Err(crate::error::FerrotorchError::GpuTensorNotAccessible)
}
}
}
#[inline]
pub fn is_cpu(&self) -> bool {
matches!(&self.data, StorageBuffer::Cpu(_))
}
#[inline]
pub fn is_gpu(&self) -> bool {
matches!(&self.data, StorageBuffer::Gpu(_))
}
#[inline]
pub fn is_cubecl(&self) -> bool {
matches!(&self.data, StorageBuffer::Cubecl(_))
}
#[inline]
pub fn is_meta(&self) -> bool {
matches!(&self.data, StorageBuffer::Meta { .. })
}
pub fn gpu_handle(&self) -> Option<&GpuBufferHandle> {
match &self.data {
StorageBuffer::Gpu(h) => Some(h),
StorageBuffer::Cpu(_) | StorageBuffer::Cubecl(_) | StorageBuffer::Meta { .. } => None,
}
}
pub fn gpu_handle_mut(&mut self) -> Option<&mut GpuBufferHandle> {
match &mut self.data {
StorageBuffer::Gpu(h) => Some(h),
StorageBuffer::Cpu(_) | StorageBuffer::Cubecl(_) | StorageBuffer::Meta { .. } => None,
}
}
pub fn cubecl_handle(&self) -> Option<&dyn CubeStorageHandle> {
match &self.data {
StorageBuffer::Cubecl(h) => Some(h.as_ref()),
_ => None,
}
}
pub fn try_clone(&self) -> crate::error::FerrotorchResult<Self> {
match &self.data {
StorageBuffer::Cpu(v) => Ok(Self {
data: StorageBuffer::Cpu(v.clone()),
device: self.device,
}),
StorageBuffer::Gpu(h) => {
let backend = crate::gpu_dispatch::gpu_backend()
.ok_or(crate::error::FerrotorchError::DeviceUnavailable)?;
let cloned = backend.clone_buffer(h)?;
Ok(Self {
data: StorageBuffer::Gpu(cloned),
device: self.device,
})
}
StorageBuffer::Cubecl(h) => {
let cloned = h.clone_handle();
Ok(Self {
data: StorageBuffer::Cubecl(cloned),
device: self.device,
})
}
StorageBuffer::Meta { numel, fill_value } => Ok(Self {
data: StorageBuffer::Meta {
numel: *numel,
fill_value: fill_value.clone(),
},
device: self.device,
}),
}
}
pub fn try_clone_subregion(
&self,
offset: usize,
numel: usize,
) -> crate::error::FerrotorchResult<Self> {
if offset == 0 && numel == self.len() {
return self.try_clone();
}
match &self.data {
StorageBuffer::Cpu(v) => {
let slice = &v[offset..offset + numel];
Ok(Self {
data: StorageBuffer::Cpu(slice.to_vec()),
device: self.device,
})
}
StorageBuffer::Gpu(h) => {
let backend = crate::gpu_dispatch::gpu_backend()
.ok_or(crate::error::FerrotorchError::DeviceUnavailable)?;
let bytes = backend.gpu_to_cpu(h)?;
let elem_size = std::mem::size_of::<T>();
let start = offset * elem_size;
let end = (offset + numel) * elem_size;
let handle =
backend.cpu_to_gpu(&bytes[start..end], h.dtype(), h.device_ordinal())?;
Ok(Self {
data: StorageBuffer::Gpu(handle),
device: self.device,
})
}
StorageBuffer::Cubecl(h) => {
let all = h.read_to_host()?;
let slice = all[offset..offset + numel].to_vec();
let _ = slice;
Err(crate::error::FerrotorchError::InvalidArgument {
message: format!(
"try_clone_subregion on XPU storage is not yet supported \
(offset={offset}, numel={numel}); call .cpu() first. Issue #673."
),
})
}
StorageBuffer::Meta { .. } => Ok(Self::meta(numel)),
}
}
}
impl<T: Element> Clone for TensorStorage<T> {
    /// Infallible wrapper around `TensorStorage::try_clone`.
    ///
    /// # Panics
    ///
    /// Panics with the underlying error when `try_clone` fails. For example, GPU
    /// storage is cloned while no GPU backend is registered.
    fn clone(&self) -> Self {
        self.try_clone().unwrap_or_else(|e| {
            panic!(
                "TensorStorage::clone failed: {e}. \
                 Use TensorStorage::try_clone() to handle this case explicitly."
            )
        })
    }
}
impl<T: Element> Drop for TensorStorage<T> {
    /// Hands a non-empty host vector to `crate::cpu_pool::pool_return_cpu`, which is
    /// presumably a recycling pool (confirm), instead of freeing it here.
    ///
    /// The now-empty `Vec` left behind, and every other variant, drop normally
    /// with the fields afterwards.
    fn drop(&mut self) {
        match &mut self.data {
            StorageBuffer::Cpu(buffer) if !buffer.is_empty() => {
                crate::cpu_pool::pool_return_cpu(std::mem::take(buffer));
            }
            StorageBuffer::Cpu(_)
            | StorageBuffer::Gpu(_)
            | StorageBuffer::Cubecl(_)
            | StorageBuffer::Meta { .. } => {}
        }
    }
}
impl<T: Element> std::fmt::Debug for StorageBuffer<T> {
    /// Writes a compact summary without printing element values.
    ///
    /// * Host and meta buffers show only their element count.
    /// * GPU buffers defer to the handle's `Debug` output.
    /// * CubeCL buffers show their ordinal and length.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Cpu(elems) => {
                let count = elems.len();
                write!(f, "Cpu({} elements)", count)
            }
            Self::Gpu(h) => write!(f, "Gpu({h:?})"),
            Self::Cubecl(handle) => {
                let ordinal = handle.ordinal();
                let len = handle.len();
                write!(f, "Cubecl(ordinal={}, len={})", ordinal, len)
            }
            Self::Meta { numel, fill_value: _ } => write!(f, "Meta({numel} elements)"),
        }
    }
}