use cubecl::prelude::*;
use cubecl::server::Handle;
use cubecl::std::tensor::compact_strides;
use std::marker::PhantomData;
/// A typed tensor whose elements live in a buffer managed by the cubecl runtime `R`.
///
/// The runtime `R` and element type `F` are tracked only at the type level (via the
/// `PhantomData` fields); at runtime the tensor is a buffer handle plus its layout.
pub struct GpuTensor<R: Runtime, F: CubeElement + Numeric> {
    /// Handle to the runtime buffer holding the element bytes.
    data: Handle,
    /// Extent of each dimension; the element count is the product of these values.
    shape: Vec<usize>,
    /// Per-dimension strides measured in elements. The constructors in this file set them to
    /// `compact_strides(&shape)`, e.g. `[4, 1]` for shape `[3, 4]`.
    strides: Vec<usize>,
    _r: PhantomData<R>,
    _f: PhantomData<F>,
}
// Implemented by hand rather than with `#[derive(Clone)]`: the derive would add `R: Clone` and
// `F: Clone` bounds because of the generic parameters, even though only `PhantomData` mentions them.
impl<R: Runtime, F: CubeElement + Numeric> Clone for GpuTensor<R, F> {
    /// Returns a new `GpuTensor` with a cloned buffer handle and copies of the layout vectors.
    ///
    /// NOTE(review): whether cloning the `Handle` duplicates device memory or shares the same
    /// allocation is decided by cubecl's `Handle::clone` — presumably shared; confirm.
    fn clone(&self) -> Self {
        let Self {
            data: handle,
            shape: dims,
            strides: steps,
            ..
        } = self;
        Self {
            data: handle.clone(),
            shape: dims.clone(),
            strides: steps.clone(),
            _r: PhantomData,
            _f: PhantomData,
        }
    }
}
impl<R: Runtime, F: Numeric + CubeElement> GpuTensor<R, F> {
    /// Uploads `data` to the runtime as a contiguous tensor with the given `shape`.
    ///
    /// Strides are set to `compact_strides(&shape)` (row-major, last dimension contiguous).
    ///
    /// # Panics
    ///
    /// Panics if `data.len()` differs from the number of elements implied by `shape`, or if
    /// that element count overflows `usize`.
    pub fn from_slice(data: &[F], shape: Vec<usize>, client: &ComputeClient<R>) -> Self {
        let expected = Self::num_elements(&shape);
        // A mismatch would leave `shape`/`strides` describing memory the buffer does not have,
        // and `into_tensor_arg` forwards that layout to kernels without further checks.
        assert_eq!(
            data.len(),
            expected,
            "GpuTensor::from_slice: data has {} elements but shape {:?} requires {}",
            data.len(),
            shape,
            expected
        );
        let handle = client.create_from_slice(F::as_bytes(data));
        let strides = compact_strides(&shape);
        Self {
            data: handle,
            shape,
            strides,
            _r: PhantomData,
            _f: PhantomData,
        }
    }

    /// Allocates a buffer sized for `shape` elements of `F`, with compact strides.
    ///
    /// The buffer comes from `client.empty`; nothing here initializes its contents.
    ///
    /// # Panics
    ///
    /// Panics if the element count or byte size implied by `shape` overflows `usize`.
    pub fn empty(shape: Vec<usize>, client: &ComputeClient<R>) -> Self {
        let handle = client.empty(Self::byte_len(&shape));
        let strides = compact_strides(&shape);
        Self {
            data: handle,
            shape,
            strides,
            _r: PhantomData,
            _f: PhantomData,
        }
    }

    /// Borrows this tensor as a kernel launch argument using the given `line_size`.
    ///
    /// Despite the `into_` prefix this takes `&self`; the returned argument borrows the
    /// tensor's handle, shape and strides for its lifetime.
    pub fn into_tensor_arg(&self, line_size: usize) -> TensorArg<'_, R> {
        // SAFETY: `shape` and `strides` are produced together by `from_slice`/`empty` (strides
        // are `compact_strides(&shape)`), and the buffer behind `data` was created for exactly
        // `num_elements(shape)` values of `F` (enforced by the length check in `from_slice` and
        // by `byte_len` in `empty`), so the layout stays within the allocation.
        // NOTE(review): `line_size` is forwarded unchecked — confirm callers only pass values
        // compatible with the innermost dimension.
        unsafe { TensorArg::from_raw_parts::<F>(&self.data, &self.strides, &self.shape, line_size) }
    }

    /// Copies the buffer back to the host and reinterprets it as elements of `F`,
    /// consuming the tensor.
    ///
    /// Values come back in buffer order, which is row-major for tensors built by this type's
    /// constructors. Contents of a tensor from [`Self::empty`] that was never written are not
    /// specified here.
    pub fn read(self, client: &ComputeClient<R>) -> Vec<F> {
        let bytes = client.read_one(self.data);
        F::from_bytes(&bytes).to_vec()
    }

    /// Size in bytes of the tensor's elements: element count times `size_of::<F>()`.
    ///
    /// # Panics
    ///
    /// Panics if the size overflows `usize` (the constructors would have panicked first).
    pub fn vram_bytes(&self) -> usize {
        Self::byte_len(&self.shape)
    }

    /// Number of elements implied by `shape` (1 for an empty shape); panics on `usize` overflow.
    fn num_elements(shape: &[usize]) -> usize {
        shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
            .unwrap_or_else(|| panic!("GpuTensor: element count of shape {shape:?} overflows usize"))
    }

    /// Byte size of a contiguous buffer of `F` holding `shape`; panics on `usize` overflow.
    fn byte_len(shape: &[usize]) -> usize {
        Self::num_elements(shape)
            .checked_mul(std::mem::size_of::<F>())
            .unwrap_or_else(|| panic!("GpuTensor: byte size of shape {shape:?} overflows usize"))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use cubecl::cpu::{CpuDevice, CpuRuntime};

    /// Uploading a 2x3 buffer and reading it back must return the same six values.
    #[test]
    fn test_tensor_from_slice_and_read() {
        let device = CpuDevice;
        let client = CpuRuntime::client(&device);
        let expected: Vec<f32> = (1..=6).map(|v| v as f32).collect();
        let uploaded = GpuTensor::<CpuRuntime, f32>::from_slice(&expected, vec![2, 3], &client);
        assert_eq!(uploaded.read(&client), expected);
    }

    /// `empty` keeps the requested shape and assigns compact row-major strides.
    #[test]
    fn test_tensor_empty() {
        let device = CpuDevice;
        let client = CpuRuntime::client(&device);
        let dims = vec![3, 4];
        let blank = GpuTensor::<CpuRuntime, f32>::empty(dims.clone(), &client);
        assert_eq!(blank.shape, dims);
        assert_eq!(blank.strides, vec![4, 1]);
    }
}