use core::marker::PhantomData;
use cubecl_ir::AddressType;
use cubecl_runtime::{runtime::Runtime, server::CopyDescriptor};
use cubecl_zspace::{Shape, Strides};
use serde::{Deserialize, Serialize};
use crate::{
compute::{KernelBuilder, KernelLauncher},
ir::Id,
prelude::{ArrayArg, ArrayBinding, CubePrimitive, LaunchArg, NativeExpand},
};
use super::Tensor;
#[derive(Debug)]
pub enum TensorArg<R: Runtime> {
Handle {
handle: TensorBinding<R>,
},
Alias {
input_pos: usize,
strides: Strides,
shape: Shape,
},
}
pub struct TensorBinding<R: Runtime> {
pub handle: cubecl_runtime::server::Binding,
pub strides: Strides,
pub shape: Shape,
pub runtime: PhantomData<R>,
}
impl<R: Runtime> Clone for TensorBinding<R> {
fn clone(&self) -> Self {
Self {
handle: self.handle.clone(),
strides: self.strides.clone(),
shape: self.shape.clone(),
runtime: PhantomData,
}
}
}
impl<R: Runtime> TensorBinding<R> {
pub fn size(&self) -> usize {
self.shape.iter().product()
}
pub fn required_address_type(&self, elem_size: usize) -> AddressType {
AddressType::from_len(self.handle.size() as usize / elem_size)
}
}
impl<R: Runtime> core::fmt::Debug for TensorBinding<R> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
writeln!(
f,
"TensorHandleRef {{ strides: {:?}, shape: {:?} }}",
self.strides, self.shape
)
}
}
#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)]
pub struct TensorCompilationArg {
pub inplace: Option<Id>,
}
impl<C: CubePrimitive> LaunchArg for Tensor<C> {
type RuntimeArg<R: Runtime> = TensorArg<R>;
type CompilationArg = TensorCompilationArg;
fn register<R: Runtime>(
arg: Self::RuntimeArg<R>,
launcher: &mut KernelLauncher<R>,
) -> Self::CompilationArg {
let ty = launcher.with_scope(|scope| C::as_type(scope));
let compilation_arg = match &arg {
TensorArg::Handle { .. } => TensorCompilationArg { inplace: None },
TensorArg::Alias { input_pos, .. } => TensorCompilationArg {
inplace: Some(*input_pos as Id),
},
};
launcher.register_tensor(arg, ty);
compilation_arg
}
fn expand(_arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> NativeExpand<Tensor<C>> {
builder.input_tensor(C::as_type(&builder.scope)).into()
}
fn expand_output(
arg: &Self::CompilationArg,
builder: &mut KernelBuilder,
) -> NativeExpand<Tensor<C>> {
match arg.inplace {
Some(id) => builder.inplace_output(id).into(),
None => builder.output_tensor(C::as_type(&builder.scope)).into(),
}
}
}
impl<R: Runtime> TensorArg<R> {
pub unsafe fn from_raw_parts(
handle: cubecl_runtime::server::Handle,
strides: Strides,
shape: Shape,
) -> Self {
unsafe { Self::from_raw_parts_binding(handle.binding(), strides, shape) }
}
pub(crate) unsafe fn from_raw_parts_binding(
handle: cubecl_runtime::server::Binding,
strides: Strides,
shape: Shape,
) -> Self {
unsafe {
Self::Handle {
handle: TensorBinding::from_raw_parts_binding(handle, strides, shape),
}
}
}
pub fn into_alias(self, position: usize) -> Self {
match self {
TensorArg::Handle { handle } => handle.into_alias(position),
alias @ TensorArg::Alias { .. } => alias,
}
}
pub fn size(&self) -> usize {
match self {
TensorArg::Handle { handle } => handle.size(),
TensorArg::Alias { shape, .. } => shape.iter().product(),
}
}
pub fn shape(&self) -> &[usize] {
match self {
TensorArg::Handle { handle } => &handle.shape,
TensorArg::Alias { shape, .. } => shape,
}
}
pub fn strides(&self) -> &[usize] {
match self {
TensorArg::Handle { handle } => &handle.strides,
TensorArg::Alias { strides, .. } => strides,
}
}
}
impl<R: Runtime> TensorArg<R> {
pub fn into_array_arg(self) -> ArrayArg<R> {
match self {
TensorArg::Handle { handle } => {
let handle = unsafe {
let size = handle.size();
ArrayBinding::from_raw_parts_binding(handle.handle, size)
};
ArrayArg::Handle { handle }
}
TensorArg::Alias {
input_pos, shape, ..
} => ArrayArg::Alias {
input_pos,
length: [shape.iter().product()],
},
}
}
}
impl<R: Runtime> TensorBinding<R> {
pub fn into_tensor_arg(self) -> TensorArg<R> {
unsafe { TensorArg::from_raw_parts_binding(self.handle, self.strides, self.shape) }
}
pub fn into_alias(self, index: usize) -> TensorArg<R> {
TensorArg::Alias {
input_pos: index,
strides: self.strides,
shape: self.shape,
}
}
pub fn as_alias(&self, index: usize) -> TensorArg<R> {
TensorArg::Alias {
input_pos: index,
strides: self.strides.clone(),
shape: self.shape.clone(),
}
}
pub fn into_array_arg(self) -> ArrayArg<R> {
let length = self.shape.iter().product();
unsafe { ArrayArg::from_raw_parts_binding(self.handle, length) }
}
pub unsafe fn from_raw_parts(
handle: cubecl_runtime::server::Handle,
strides: Strides,
shape: Shape,
) -> Self {
unsafe { Self::from_raw_parts_binding(handle.binding(), strides, shape) }
}
pub unsafe fn from_raw_parts_binding(
handle: cubecl_runtime::server::Binding,
strides: Strides,
shape: Shape,
) -> Self {
Self {
handle,
strides,
shape,
runtime: PhantomData,
}
}
pub fn into_copy_descriptor(self, elem_size: usize) -> CopyDescriptor {
CopyDescriptor {
handle: self.handle,
shape: self.shape,
strides: self.strides,
elem_size,
}
}
}