// cubecl_core/frontend/container/tensor/launch.rs
use std::{marker::PhantomData, num::NonZero};
use crate::{
compute::{KernelBuilder, KernelLauncher},
ir::{Item, Vectorization},
prelude::{ArgSettings, CubePrimitive, ExpandElementTyped, LaunchArg, LaunchArgExpand},
Runtime,
};
use super::Tensor;
/// Runtime argument describing how a [`Tensor`] is passed to a kernel launch.
#[derive(Debug)]
pub enum TensorArg<'a, R: Runtime> {
    /// The tensor is backed by its own handle.
    Handle {
        /// Borrowed handle together with the strides and shape of the tensor.
        handle: TensorHandleRef<'a, R>,
        /// Vectorization factor requested for this tensor; it is converted to a
        /// `NonZero` value when the compilation argument is built.
        vectorization_factor: u8,
    },
    /// The tensor reuses an input instead of carrying its own handle.
    Alias {
        /// Position of the input being aliased.
        input_pos: usize,
    },
}
/// Borrowed view of a tensor: a server handle plus the strides and shape
/// slices that accompany it.
pub struct TensorHandleRef<'a, R: Runtime> {
    /// Handle of the underlying server resource.
    pub handle: &'a cubecl_runtime::server::Handle,
    /// Strides of the tensor (presumably one entry per dimension; confirm with callers).
    pub strides: &'a [usize],
    /// Shape of the tensor (presumably one entry per dimension; confirm with callers).
    pub shape: &'a [usize],
    /// Ties the reference to the runtime `R` without storing a value of that type.
    pub runtime: PhantomData<R>,
}
/// Debug representation showing only the strides and shape; the server
/// handle and the runtime marker are intentionally left out of the output.
impl<'a, R: Runtime> core::fmt::Debug for TensorHandleRef<'a, R> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // `debug_struct` produces the same single-line text as before for `{:?}`
        // but without a trailing newline, so the value nests cleanly inside the
        // derived `Debug` output of enclosing types and also honours `{:#?}`.
        f.debug_struct("TensorHandleRef")
            .field("strides", &self.strides)
            .field("shape", &self.shape)
            .finish()
    }
}
/// Compilation-time description of a tensor argument.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TensorCompilationArg {
    /// When set, the position of the input that the output reuses in place.
    inplace: Option<u16>,
    /// Vectorization applied to the tensor's item type when it is declared
    /// on the kernel builder.
    vectorisation: Vectorization,
}
impl<C: CubePrimitive> LaunchArgExpand for Tensor<C> {
    type CompilationArg = TensorCompilationArg;

    /// Declares the tensor as an input array on `builder`, using the element
    /// type of `C` and the vectorization stored in `arg`.
    fn expand(
        arg: &Self::CompilationArg,
        builder: &mut KernelBuilder,
    ) -> ExpandElementTyped<Tensor<C>> {
        let item = Item::vectorized(C::as_elem(), arg.vectorisation);
        builder.input_array(item).into()
    }

    /// Declares the tensor as an output on `builder`.
    ///
    /// When `arg.inplace` holds a position, the output is registered through
    /// `inplace_output` with that position; otherwise a new output array is
    /// declared with the element type of `C` and the vectorization from `arg`.
    fn expand_output(
        arg: &Self::CompilationArg,
        builder: &mut KernelBuilder,
    ) -> ExpandElementTyped<Tensor<C>> {
        // In-place outputs do not declare a new array.
        if let Some(position) = arg.inplace {
            return builder.inplace_output(position).into();
        }

        let item = Item::vectorized(C::as_elem(), arg.vectorisation);
        builder.output_array(item).into()
    }
}
impl<C: CubePrimitive> LaunchArg for Tensor<C> {
    type RuntimeArg<'a, R: Runtime> = TensorArg<'a, R>;

    /// Derives the compilation argument from a runtime tensor argument.
    ///
    /// * `TensorArg::Handle` yields a non-inplace argument carrying the
    ///   requested vectorization factor.
    /// * `TensorArg::Alias` yields an inplace argument pointing at the aliased
    ///   input position, with no vectorization.
    ///
    /// # Panics
    ///
    /// Panics if the vectorization factor of a `Handle` is zero, or if the
    /// input position of an `Alias` does not fit in a `u16`.
    fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
        match runtime_arg {
            TensorArg::Handle {
                handle: _,
                vectorization_factor,
            } => {
                let factor = NonZero::new(*vectorization_factor)
                    .expect("tensor vectorization factor must be non-zero");
                TensorCompilationArg {
                    inplace: None,
                    vectorisation: Vectorization::Some(factor),
                }
            }
            TensorArg::Alias { input_pos } => {
                // A plain `as u16` cast would silently wrap large positions and
                // alias the wrong input; fail loudly instead.
                let position = u16::try_from(*input_pos)
                    .expect("aliased input position must fit in a u16");
                TensorCompilationArg {
                    inplace: Some(position),
                    vectorisation: Vectorization::None,
                }
            }
        }
    }
}
impl<'a, R: Runtime> TensorArg<'a, R> {
    /// Builds a [`TensorArg::Handle`] from a server handle, its strides and
    /// shape, and the vectorization `factor`.
    ///
    /// # Safety
    ///
    /// NOTE(review): the exact contract is not visible in this file. Presumably
    /// `strides` and `shape` must accurately describe the buffer behind
    /// `handle`; confirm against the kernel launch code before relying on it.
    pub unsafe fn from_raw_parts(
        handle: &'a cubecl_runtime::server::Handle,
        strides: &'a [usize],
        shape: &'a [usize],
        factor: u8,
    ) -> Self {
        // SAFETY: the arguments are forwarded unchanged, so the caller's
        // obligations for this function cover the inner constructor as well.
        let tensor_ref = unsafe { TensorHandleRef::from_raw_parts(handle, strides, shape) };

        Self::Handle {
            handle: tensor_ref,
            vectorization_factor: factor,
        }
    }

    /// Builds a [`TensorArg::Alias`] referring to the input at `position`.
    pub fn alias(position: usize) -> Self {
        Self::Alias { input_pos: position }
    }
}
impl<'a, R: Runtime> ArgSettings<R> for TensorArg<'a, R> {
    /// Registers the tensor on `launcher`.
    ///
    /// Only the `Handle` variant registers anything; an `Alias` registers
    /// nothing here.
    fn register(&self, launcher: &mut KernelLauncher<R>) {
        match self {
            TensorArg::Handle { handle, .. } => launcher.register_tensor(handle),
            TensorArg::Alias { .. } => {}
        }
    }
}
impl<'a, R: Runtime> TensorHandleRef<'a, R> {
pub fn as_tensor_arg(&'a self, vectorisation: u8) -> TensorArg<'a, R> {
unsafe { TensorArg::from_raw_parts(self.handle, self.strides, self.shape, vectorisation) }
}
pub unsafe fn from_raw_parts(
handle: &'a cubecl_runtime::server::Handle,
strides: &'a [usize],
shape: &'a [usize],
) -> Self {
Self {
handle,
strides,
shape,
runtime: PhantomData,
}
}
}