use std::marker::PhantomData;
use cubecl_runtime::runtime::Runtime;
use serde::{Deserialize, Serialize};
use crate::{
compute::{KernelBuilder, KernelLauncher},
ir::{Id, LineSize, Type},
prelude::{
ArgSettings, ArrayArg, CompilationArg, CubePrimitive, ExpandElementTyped, LaunchArg,
},
};
use super::Tensor;
/// Runtime-side argument for a [`Tensor`] kernel parameter.
///
/// This is the `RuntimeArg` type of the `LaunchArg` implementation for
/// `Tensor<C>` in this file.
#[derive(Debug)]
pub enum TensorArg<'a, R: Runtime> {
/// A tensor described by a borrowed handle plus a line size.
Handle {
/// Borrowed handle together with its strides, shape and element size.
handle: TensorHandleRef<'a, R>,
/// Line size; copied into `TensorCompilationArg::line_size` by
/// `LaunchArg::compilation_arg`.
line_size: LineSize,
},
/// A tensor that carries no handle of its own and instead points at an
/// input by position.
Alias {
/// Position of the referenced input; `LaunchArg::compilation_arg` stores
/// it (cast to `Id`) in `TensorCompilationArg::inplace`.
input_pos: usize,
},
}
/// Borrowed description of a tensor: a server handle plus its layout.
///
/// All data is borrowed for `'a`; the runtime `R` is only tracked at the type
/// level through `PhantomData`.
pub struct TensorHandleRef<'a, R: Runtime> {
/// Borrowed server handle for the tensor's buffer.
pub handle: &'a cubecl_runtime::server::Handle,
/// Strides of the tensor (presumably one entry per dimension — confirm with callers).
pub strides: &'a [usize],
/// Shape of the tensor; `size()` multiplies these values together.
pub shape: &'a [usize],
/// Size of one element. `TensorArg::from_raw_parts` fills this from
/// `E::size()`; NOTE(review): the unit (bytes?) is not visible in this file.
pub elem_size: usize,
/// Zero-sized marker tying the reference to the runtime type `R`.
pub runtime: PhantomData<R>,
}
/// Hand-written instead of `#[derive(Clone)]`: a derive would add an
/// `R: Clone` bound, whereas `R` only appears inside `PhantomData` here.
impl<R: Runtime> Clone for TensorHandleRef<'_, R> {
    fn clone(&self) -> Self {
        // Every field is `Copy` (shared references, a `usize`, `PhantomData`),
        // so cloning is just a bitwise copy of the struct.
        *self
    }
}

/// Hand-written for the same reason as `Clone`: no `R: Copy` bound is wanted.
impl<R: Runtime> Copy for TensorHandleRef<'_, R> {}
impl<R: Runtime> TensorHandleRef<'_, R> {
    /// Returns the product of all entries of `shape`.
    ///
    /// An empty shape yields `1` (the empty product); any zero entry yields `0`.
    pub fn size(&self) -> usize {
        let dims = self.shape.iter().copied();
        dims.product()
    }
}
/// Debug output shows only the layout (`strides` and `shape`); the handle,
/// element size and runtime marker are intentionally left out.
impl<R: Runtime> core::fmt::Debug for TensorHandleRef<'_, R> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // `debug_struct` produces `TensorHandleRef { strides: .., shape: .. }`
        // without a trailing newline, so the value nests cleanly inside other
        // `Debug` output (e.g. the derived impl on `TensorArg`) and honours the
        // pretty-printing flag of `{:#?}`.
        f.debug_struct("TensorHandleRef")
            .field("strides", &self.strides)
            .field("shape", &self.shape)
            .finish()
    }
}
/// Compile-time description of a tensor argument, produced from a
/// `TensorArg` by `LaunchArg::compilation_arg`.
#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)]
pub struct TensorCompilationArg {
/// `Some(id)` when the argument was a `TensorArg::Alias` (the id is the
/// aliased input position); `None` for a `TensorArg::Handle`.
/// `expand_output` uses it to choose between `inplace_output` and
/// `output_tensor`.
pub inplace: Option<Id>,
/// Line size applied to the tensor's element type during expansion.
/// Set to `0` for aliased arguments.
pub line_size: LineSize,
}
// Marker impl: no methods are overridden here.
impl CompilationArg for TensorCompilationArg {}
/// Launch-argument plumbing for [`Tensor`]: maps a runtime [`TensorArg`] to a
/// [`TensorCompilationArg`] and declares the tensor on a [`KernelBuilder`].
impl<C: CubePrimitive> LaunchArg for Tensor<C> {
    type RuntimeArg<'a, R: Runtime> = TensorArg<'a, R>;
    type CompilationArg = TensorCompilationArg;

    /// Derives the compile-time description from a runtime argument.
    ///
    /// * `Handle` keeps its line size and has no in-place target.
    /// * `Alias` records the input position as the in-place target and uses a
    ///   line size of `0`.
    fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
        let (inplace, line_size) = match runtime_arg {
            TensorArg::Handle { line_size, .. } => (None, *line_size as LineSize),
            TensorArg::Alias { input_pos } => (Some(*input_pos as Id), 0),
        };

        TensorCompilationArg { inplace, line_size }
    }

    /// Declares the tensor as a kernel input whose type is `C`'s type lined by
    /// `arg.line_size`.
    fn expand(
        arg: &Self::CompilationArg,
        builder: &mut KernelBuilder,
    ) -> ExpandElementTyped<Tensor<C>> {
        let tensor_type = Type::new(C::as_type(&builder.scope)).line(arg.line_size);
        builder.input_tensor(tensor_type).into()
    }

    /// Declares the tensor as a kernel output.
    ///
    /// When `arg.inplace` is set, the output is registered through
    /// `inplace_output` with that id; otherwise a fresh output tensor is
    /// declared with `C`'s type lined by `arg.line_size`.
    fn expand_output(
        arg: &Self::CompilationArg,
        builder: &mut KernelBuilder,
    ) -> ExpandElementTyped<Tensor<C>> {
        if let Some(id) = arg.inplace {
            return builder.inplace_output(id).into();
        }

        let tensor_type = Type::new(C::as_type(&builder.scope)).line(arg.line_size);
        builder.output_tensor(tensor_type).into()
    }
}
impl<'a, R: Runtime> TensorArg<'a, R> {
    /// Creates a [`TensorArg::Handle`] whose element size is taken from
    /// `E::size()`.
    ///
    /// `factor` is stored as the argument's line size.
    ///
    /// # Panics
    ///
    /// Panics with "Element should have a size" when `E::size()` returns `None`.
    ///
    /// # Safety
    ///
    /// This function only stores the given values; nothing is validated here.
    /// NOTE(review): the precise contract is not visible in this file —
    /// presumably `strides`, `shape` and `E` must describe the buffer behind
    /// `handle` correctly; confirm against the launcher.
    pub unsafe fn from_raw_parts<E: CubePrimitive>(
        handle: &'a cubecl_runtime::server::Handle,
        strides: &'a [usize],
        shape: &'a [usize],
        factor: LineSize,
    ) -> Self {
        let elem_size = E::size().expect("Element should have a size");
        // SAFETY: the caller's obligations for this function are forwarded
        // unchanged; only the element size is filled in from `E`.
        unsafe { Self::from_raw_parts_and_size(handle, strides, shape, factor, elem_size) }
    }

    /// Creates a [`TensorArg::Handle`] with an explicitly provided element size.
    ///
    /// `factor` is stored as the argument's line size.
    ///
    /// # Safety
    ///
    /// This function only stores the given values; nothing is validated here.
    /// NOTE(review): the precise contract is not visible in this file —
    /// presumably `strides`, `shape` and `elem_size` must describe the buffer
    /// behind `handle` correctly; confirm against the launcher.
    pub unsafe fn from_raw_parts_and_size(
        handle: &'a cubecl_runtime::server::Handle,
        strides: &'a [usize],
        shape: &'a [usize],
        factor: LineSize,
        elem_size: usize,
    ) -> Self {
        // SAFETY: the caller's obligations for this function are passed on
        // as-is to `TensorHandleRef::from_raw_parts`.
        let handle = unsafe { TensorHandleRef::from_raw_parts(handle, strides, shape, elem_size) };

        Self::Handle {
            handle,
            line_size: factor,
        }
    }

    /// Creates a [`TensorArg::Alias`] referring to the input at `position`.
    pub fn alias(position: usize) -> Self {
        Self::Alias {
            input_pos: position,
        }
    }
}
impl<R: Runtime> ArgSettings<R> for TensorArg<'_, R> {
    /// Registers this argument on `launcher` by handing it to
    /// `KernelLauncher::register_tensor`.
    fn register(&self, launcher: &mut KernelLauncher<R>) {
        launcher.register_tensor(self);
    }
}
impl<'a, R: Runtime> TensorHandleRef<'a, R> {
    /// Converts this reference into a tensor launch argument with the given
    /// `line_size`, reusing the handle, strides, shape and element size.
    pub fn as_tensor_arg(&'a self, line_size: LineSize) -> TensorArg<'a, R> {
        let Self {
            handle,
            strides,
            shape,
            elem_size,
            ..
        } = *self;

        // SAFETY: every part is taken unchanged from `self`, so this relies on
        // `self` having been built from parts that satisfy the constructor's
        // contract.
        unsafe { TensorArg::from_raw_parts_and_size(handle, strides, shape, line_size, elem_size) }
    }

    /// Converts this reference into an array launch argument with the given
    /// `line_size`. The array length is the product of all `shape` entries;
    /// `strides` are not forwarded.
    pub fn as_array_arg(&'a self, line_size: LineSize) -> ArrayArg<'a, R> {
        // SAFETY: the handle and element size come unchanged from `self`, and
        // the length is derived from `self.shape`; this relies on `self` having
        // been built from parts that satisfy the constructor's contract.
        unsafe {
            ArrayArg::from_raw_parts_and_size(
                self.handle,
                self.shape.iter().product(),
                line_size,
                self.elem_size,
            )
        }
    }

    /// Assembles a `TensorHandleRef` from its parts without any validation.
    ///
    /// # Safety
    ///
    /// The values are stored as given. NOTE(review): the precise contract is
    /// not visible in this file — presumably `strides`, `shape` and
    /// `elem_size` must describe the buffer behind `handle` correctly; confirm
    /// against the code that consumes the reference.
    pub unsafe fn from_raw_parts(
        handle: &'a cubecl_runtime::server::Handle,
        strides: &'a [usize],
        shape: &'a [usize],
        elem_size: usize,
    ) -> Self {
        Self {
            handle,
            strides,
            shape,
            elem_size,
            runtime: PhantomData,
        }
    }
}