cubecl_core/frontend/container/array/
launch.rsuse std::{marker::PhantomData, num::NonZero};
use crate::{
compute::{KernelBuilder, KernelLauncher},
ir::{Item, Vectorization},
prelude::{
ArgSettings, CubePrimitive, ExpandElementTyped, LaunchArg, LaunchArgExpand, TensorHandleRef,
},
Runtime,
};
use super::Array;
/// Compile-time description of an [`Array`] kernel argument.
///
/// Built by `compilation_arg` (the `LaunchArg` impl for `Array<C>`) and consumed
/// by `expand` / `expand_output` (the `LaunchArgExpand` impl for `Array<C>`).
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct ArrayCompilationArg {
    /// `Some(id)`: `expand_output` reuses input `id` in place instead of
    /// declaring a new output array. `None`: a fresh array is declared.
    /// Ignored by `expand` (input declaration).
    inplace: Option<u16>,
    /// Vectorization applied to the element item when the array is declared.
    vectorisation: Vectorization,
}
/// Borrowed reference to a runtime buffer, viewed as a one-dimensional array.
pub struct ArrayHandleRef<'a, R: Runtime> {
    /// Server handle of the underlying buffer.
    pub handle: &'a cubecl_runtime::server::Handle,
    /// Length of the array. It is stored as a one-element array so that
    /// `as_tensor` can lend it out directly as a rank-1 `shape` slice
    /// (presumably an element count — confirm against callers).
    pub(crate) length: [usize; 1],
    /// Ties the reference to runtime `R` without storing a value of it.
    runtime: PhantomData<R>,
}
impl<C: CubePrimitive> LaunchArgExpand for Array<C> {
    type CompilationArg = ArrayCompilationArg;

    /// Declares a new input array on `builder`.
    ///
    /// The array's item is `C`'s element type, vectorized as `arg` requests.
    fn expand(
        arg: &Self::CompilationArg,
        builder: &mut KernelBuilder,
    ) -> ExpandElementTyped<Array<C>> {
        let item = Item::vectorized(C::as_elem(), arg.vectorisation);
        builder.input_array(item).into()
    }

    /// Declares an output array on `builder`.
    ///
    /// When `arg` carries an in-place input id, that input is reused as the
    /// output. Otherwise a new output array is declared, whose item is `C`'s
    /// element type vectorized as `arg` requests.
    fn expand_output(
        arg: &Self::CompilationArg,
        builder: &mut KernelBuilder,
    ) -> ExpandElementTyped<Array<C>> {
        if let Some(input_id) = arg.inplace {
            return builder.inplace_output(input_id).into();
        }

        let item = Item::vectorized(C::as_elem(), arg.vectorisation);
        builder.output_array(item).into()
    }
}
/// Runtime argument for an [`Array`] kernel parameter.
pub enum ArrayArg<'a, R: Runtime> {
    /// A concrete buffer handed to the kernel; registered with the launcher.
    Handle {
        /// Reference to the buffer and its length.
        handle: ArrayHandleRef<'a, R>,
        /// Vectorization factor. `compilation_arg` converts it to a non-zero
        /// vectorization and panics if it is `0`.
        vectorization_factor: u8,
    },
    /// No buffer of its own. Nothing is registered with the launcher, and
    /// `compilation_arg` marks the argument as an in-place output.
    Alias {
        /// Position of the input to reuse; recorded as the in-place input id.
        input_pos: usize,
    },
}
impl<'a, R: Runtime> ArgSettings<R> for ArrayArg<'a, R> {
    /// Registers the referenced buffer with `launcher`.
    ///
    /// `Alias` arguments have no buffer of their own, so nothing is registered
    /// for them here.
    fn register(&self, launcher: &mut KernelLauncher<R>) {
        let Self::Handle { handle, .. } = self else {
            return;
        };
        launcher.register_array(handle);
    }
}
impl<'a, R: Runtime> ArrayArg<'a, R> {
    /// Builds an [`ArrayArg::Handle`] from a raw buffer handle, a length and a
    /// vectorization factor.
    ///
    /// `vectorization_factor` should be non-zero: `compilation_arg` panics
    /// when it encounters a factor of `0`.
    ///
    /// # Safety
    ///
    /// `length` is not checked against `handle`. The caller must uphold the
    /// same contract as [`ArrayHandleRef::from_raw_parts`]: presumably, that
    /// `length` does not exceed what the buffer actually holds (TODO confirm
    /// the exact contract).
    pub unsafe fn from_raw_parts(
        handle: &'a cubecl_runtime::server::Handle,
        length: usize,
        vectorization_factor: u8,
    ) -> Self {
        // SAFETY: the caller upholds this function's `# Safety` contract,
        // which is exactly the one required by `ArrayHandleRef::from_raw_parts`.
        let handle = unsafe { ArrayHandleRef::from_raw_parts(handle, length) };
        ArrayArg::Handle {
            handle,
            vectorization_factor,
        }
    }
}
impl<'a, R: Runtime> ArrayHandleRef<'a, R> {
    /// Wraps a raw buffer handle together with its length.
    ///
    /// # Safety
    ///
    /// Nothing here validates `length` against `handle`. The caller must make
    /// sure `length` matches the buffer, presumably that it does not exceed
    /// the number of elements the buffer holds (TODO confirm). The value is
    /// used as the array's shape by [`ArrayHandleRef::as_tensor`].
    pub unsafe fn from_raw_parts(
        handle: &'a cubecl_runtime::server::Handle,
        length: usize,
    ) -> Self {
        Self {
            handle,
            length: [length],
            runtime: PhantomData,
        }
    }

    /// Views this array as a rank-1 tensor over the same handle.
    ///
    /// The tensor has shape `[length]` and stride `[1]`.
    pub fn as_tensor(&self) -> TensorHandleRef<'_, R> {
        TensorHandleRef {
            handle: self.handle,
            strides: &[1],
            shape: &self.length,
            runtime: PhantomData,
        }
    }
}
impl<C: CubePrimitive> LaunchArg for Array<C> {
    type RuntimeArg<'a, R: Runtime> = ArrayArg<'a, R>;

    /// Derives the compile-time description of an array argument from its
    /// runtime value.
    ///
    /// * `Handle`: a fresh (not in-place) array, vectorized by
    ///   `vectorization_factor`.
    /// * `Alias`: an in-place output that reuses input `input_pos`, with no
    ///   vectorization recorded.
    ///
    /// # Panics
    ///
    /// Panics in either of these cases:
    /// * a `Handle`'s `vectorization_factor` is `0`;
    /// * an `Alias`'s `input_pos` does not fit in a `u16`.
    fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
        match runtime_arg {
            ArrayArg::Handle {
                handle: _,
                vectorization_factor,
            } => ArrayCompilationArg {
                inplace: None,
                vectorisation: Vectorization::Some(
                    NonZero::new(*vectorization_factor)
                        .expect("array vectorization factor must be non-zero"),
                ),
            },
            ArrayArg::Alias { input_pos } => ArrayCompilationArg {
                // A plain `as u16` cast would silently wrap large positions and
                // alias the wrong input; fail loudly instead.
                inplace: Some(
                    u16::try_from(*input_pos)
                        .expect("aliased input position must fit in a u16"),
                ),
                vectorisation: Vectorization::None,
            },
        }
    }
}