use std::sync::Arc;
#[cfg(not(any(feature = "wgpu", feature = "cuda", feature = "rocm")))]
use ferrotorch_core::FerrotorchError;
use ferrotorch_core::FerrotorchResult;
use ferrotorch_core::storage::CubeStorageHandle;
use crate::runtime::CubeRuntime;
/// Storage handle backed by a CubeCL device buffer.
///
/// Bundles the raw [`cubecl::server::Handle`], the shared [`CubeRuntime`] the
/// buffer is read through, the number of `f32` elements it holds, and the
/// device ordinal. Instances are type-erased as `dyn CubeStorageHandle` and can
/// be recovered from a tensor with `cubecl_handle_of`.
#[derive(Debug)]
pub struct CubeclStorageHandle {
    /// Raw CubeCL buffer handle.
    handle: cubecl::server::Handle,
    /// Runtime used to read the buffer back to the host.
    runtime: Arc<CubeRuntime>,
    /// Element count, passed verbatim to the runtime's `read_f32s`.
    len: usize,
    /// Device ordinal reported through `CubeStorageHandle::ordinal`.
    ordinal: usize,
}

impl CubeclStorageHandle {
    /// Private constructor used when a GPU backend feature is compiled in.
    ///
    /// Identical to [`CubeclStorageHandle::from_raw`]; kept as a separate name
    /// for the module's upload path.
    #[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
    fn new(
        handle: cubecl::server::Handle,
        runtime: Arc<CubeRuntime>,
        len: usize,
        ordinal: usize,
    ) -> Self {
        Self::from_raw(handle, runtime, len, ordinal)
    }

    /// Wraps an existing CubeCL buffer handle as-is.
    ///
    /// No validation is performed: `len` is trusted and later handed to the
    /// runtime when reading back to the host, so it should match the number of
    /// `f32` elements actually stored behind `handle`.
    pub fn from_raw(
        handle: cubecl::server::Handle,
        runtime: Arc<CubeRuntime>,
        len: usize,
        ordinal: usize,
    ) -> Self {
        Self {
            handle,
            runtime,
            len,
            ordinal,
        }
    }

    /// Borrows the underlying CubeCL buffer handle.
    pub fn raw_handle(&self) -> &cubecl::server::Handle {
        &self.handle
    }

    /// Borrows the shared runtime associated with this buffer.
    pub fn runtime(&self) -> &Arc<CubeRuntime> {
        &self.runtime
    }
}
impl CubeStorageHandle for CubeclStorageHandle {
    /// Exposes `self` as [`std::any::Any`] so callers can downcast back to
    /// `CubeclStorageHandle`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Number of `f32` elements this handle was created with.
    fn len(&self) -> usize {
        self.len
    }

    /// Device ordinal recorded at construction.
    fn ordinal(&self) -> usize {
        self.ordinal
    }

    /// Copies the buffer back to host memory as `len` `f32` values.
    ///
    /// # Errors
    ///
    /// Without any of the `wgpu`/`cuda`/`rocm` features this always returns
    /// `FerrotorchError::DeviceUnavailable`; otherwise it forwards whatever the
    /// runtime's `read_f32s` returns.
    fn read_to_host(&self) -> FerrotorchResult<Vec<f32>> {
        // Exactly one of the two cfg-gated blocks survives compilation and
        // becomes the function's tail expression.
        #[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
        {
            let handle = self.handle.clone();
            self.runtime.read_f32s(handle, self.len)
        }
        #[cfg(not(any(feature = "wgpu", feature = "cuda", feature = "rocm")))]
        {
            Err(FerrotorchError::DeviceUnavailable)
        }
    }

    /// Returns a boxed copy of this handle.
    ///
    /// The CubeCL handle is `clone()`d and the runtime `Arc` is shared.
    /// NOTE(review): presumably this aliases the same device buffer rather
    /// than copying device memory — confirm against cubecl's `Handle::clone`.
    fn clone_handle(&self) -> Box<dyn CubeStorageHandle> {
        let copy = Self::from_raw(
            self.handle.clone(),
            Arc::clone(&self.runtime),
            self.len,
            self.ordinal,
        );
        Box::new(copy)
    }
}
/// Wraps a kernel's output buffer as a [`CubeclStorageHandle`].
///
/// The element count is the product of `shape`'s dimensions, so an empty
/// `shape` (a scalar) yields a length of 1. The handle itself is not
/// inspected; see `CubeclStorageHandle::from_raw`.
pub fn wrap_kernel_output(
    handle: cubecl::server::Handle,
    shape: &[usize],
    runtime: Arc<CubeRuntime>,
    ordinal: usize,
) -> CubeclStorageHandle {
    CubeclStorageHandle::from_raw(handle, runtime, shape.iter().product::<usize>(), ordinal)
}
/// Uploads a host `f32` slice to the device owned by `runtime`.
///
/// The slice is reinterpreted as raw bytes and handed to the active backend
/// client's `create_from_slice`. The resulting handle records `data.len()` as
/// its element count.
///
/// # Errors
///
/// This implementation currently always returns `Ok`.
#[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
pub fn upload_f32(
    data: &[f32],
    runtime: Arc<CubeRuntime>,
    ordinal: usize,
) -> FerrotorchResult<CubeclStorageHandle> {
    use crate::runtime::CubeClient;
    use cubecl::prelude::*;

    let element_count = data.len();
    let raw_bytes = f32::as_bytes(data);

    // Each compiled-in backend exposes the same allocation entry point; the
    // match arms only differ in which client type they dispatch to.
    let device_handle = match runtime.client() {
        #[cfg(feature = "wgpu")]
        CubeClient::Wgpu(client) => client.create_from_slice(raw_bytes),
        #[cfg(feature = "cuda")]
        CubeClient::Cuda(client) => client.create_from_slice(raw_bytes),
        #[cfg(feature = "rocm")]
        CubeClient::Rocm(client) => client.create_from_slice(raw_bytes),
    };

    Ok(CubeclStorageHandle::new(device_handle, runtime, element_count, ordinal))
}
/// Fallback used when no GPU backend feature (`wgpu`, `cuda`, `rocm`) is
/// compiled in.
///
/// # Errors
///
/// Always returns `FerrotorchError::DeviceUnavailable`; the arguments are
/// ignored.
#[cfg(not(any(feature = "wgpu", feature = "cuda", feature = "rocm")))]
pub fn upload_f32(
    _data: &[f32],
    _runtime: Arc<CubeRuntime>,
    _ordinal: usize,
) -> FerrotorchResult<CubeclStorageHandle> {
    let unavailable = FerrotorchError::DeviceUnavailable;
    Err(unavailable)
}
/// Returns the tensor's storage as a [`CubeclStorageHandle`], if it is one.
///
/// Yields `None` when the tensor's storage has no CubeCL handle, or when that
/// handle is some other `CubeStorageHandle` implementation.
pub fn cubecl_handle_of(t: &ferrotorch_core::Tensor<f32>) -> Option<&CubeclStorageHandle> {
    let storage = t.inner_storage_arc();
    let erased = storage.cubecl_handle()?;
    erased.as_any().downcast_ref::<CubeclStorageHandle>()
}