use core::marker::PhantomData;
use cubecl_runtime::runtime::Runtime;
use serde::{Deserialize, Serialize};
use crate::{
compute::{KernelBuilder, KernelLauncher},
ir::Id,
prelude::{CubePrimitive, LaunchArg, NativeExpand, TensorBinding},
};
use super::Array;
/// Compile-time description of an [`Array`] kernel argument.
///
/// Produced by `register` for `Array<C>` and later consumed by `expand_output`
/// to decide whether the output array is a fresh buffer or reuses an input.
#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)]
pub struct ArrayCompilationArg {
    /// `Some(pos)` when the runtime argument was an [`ArrayArg::Alias`]; `pos` is
    /// the alias's `input_pos` converted to an [`Id`]. `None` for
    /// [`ArrayArg::Handle`].
    pub inplace: Option<Id>,
}
/// A server buffer binding paired with the element count of the array it backs.
pub struct ArrayBinding<R: Runtime> {
    /// The runtime server binding for the underlying buffer.
    pub handle: cubecl_runtime::server::Binding,
    /// Number of elements, stored as a one-element array so it can be exposed
    /// directly as a 1-D shape slice.
    pub(crate) length: [usize; 1],
    // Ties the binding to runtime `R` without storing a value of that type.
    runtime: PhantomData<R>,
}
/// Runtime argument used to launch a kernel that takes an [`Array`].
pub enum ArrayArg<R: Runtime> {
    /// A concrete buffer supplied for this argument.
    Handle {
        /// The buffer binding and its element count.
        handle: ArrayBinding<R>,
    },
    /// No new buffer; this argument refers to the input registered at
    /// `input_pos`. When used as an output, `expand_output` emits it through
    /// `inplace_output`.
    Alias {
        /// Position of the input this argument aliases.
        input_pos: usize,
        /// Element count reported for this argument, stored as a 1-D shape.
        length: [usize; 1],
    },
}
impl<R: Runtime> ArrayArg<R> {
    /// Builds a [`ArrayArg::Handle`] from a server handle and an element count.
    ///
    /// # Safety
    ///
    /// Same contract as [`ArrayBinding::from_raw_parts`], which this forwards to.
    /// NOTE(review): presumably `length` must not exceed the number of elements the
    /// buffer behind `handle` actually holds; confirm against kernel indexing.
    pub unsafe fn from_raw_parts(handle: cubecl_runtime::server::Handle, length: usize) -> Self {
        // SAFETY: the caller upholds this function's contract, which is exactly
        // the contract of `ArrayBinding::from_raw_parts`.
        let binding = unsafe { ArrayBinding::from_raw_parts(handle, length) };
        ArrayArg::Handle { handle: binding }
    }

    /// Builds a [`ArrayArg::Handle`] from an existing server binding and an element count.
    ///
    /// # Safety
    ///
    /// Same contract as [`ArrayBinding::from_raw_parts_binding`], which this forwards to.
    /// NOTE(review): presumably `length` must not exceed the number of elements
    /// addressable through `binding`; confirm.
    pub unsafe fn from_raw_parts_binding(
        binding: cubecl_runtime::server::Binding,
        length: usize,
    ) -> Self {
        // SAFETY: the caller upholds this function's contract, which is exactly
        // the contract of `ArrayBinding::from_raw_parts_binding`.
        let array_binding = unsafe { ArrayBinding::from_raw_parts_binding(binding, length) };
        ArrayArg::Handle {
            handle: array_binding,
        }
    }

    /// Returns the number of elements this argument describes.
    pub fn size(&self) -> usize {
        // The shape is always exactly one dimension long.
        self.shape()[0]
    }

    /// Returns the argument's shape as a one-element slice holding its length.
    pub fn shape(&self) -> &[usize] {
        let dims: &[usize; 1] = match self {
            ArrayArg::Handle { handle } => &handle.length,
            ArrayArg::Alias { length, .. } => length,
        };
        dims
    }
}
impl<R: Runtime> ArrayBinding<R> {
    /// Creates a binding from a server handle by taking its `binding()` and
    /// recording `length` elements.
    ///
    /// # Safety
    ///
    /// Same contract as [`ArrayBinding::from_raw_parts_binding`].
    /// NOTE(review): presumably `length` must not exceed the buffer's element
    /// capacity; confirm.
    pub unsafe fn from_raw_parts(handle: cubecl_runtime::server::Handle, length: usize) -> Self {
        let binding = handle.binding();
        // SAFETY: the caller upholds this function's contract, which is exactly
        // the contract of `from_raw_parts_binding`.
        unsafe { Self::from_raw_parts_binding(binding, length) }
    }

    /// Wraps an existing server binding together with its element count.
    ///
    /// # Safety
    ///
    /// No check is performed that `length` matches the buffer behind `handle`.
    /// NOTE(review): presumably the caller must guarantee it does not exceed
    /// the buffer's element capacity; confirm.
    pub unsafe fn from_raw_parts_binding(
        handle: cubecl_runtime::server::Binding,
        length: usize,
    ) -> Self {
        let runtime = PhantomData;
        Self {
            length: [length],
            handle,
            runtime,
        }
    }

    /// Converts this binding into a 1-D [`TensorBinding`] with unit stride and
    /// a shape equal to the array length.
    pub fn into_tensor(self) -> TensorBinding<R> {
        let Self { handle, length, .. } = self;
        let shape = length.into();
        TensorBinding {
            handle,
            strides: [1].into(),
            shape,
            runtime: PhantomData,
        }
    }
}
impl<C: CubePrimitive> LaunchArg for Array<C> {
    type RuntimeArg<R: Runtime> = ArrayArg<R>;
    type CompilationArg = ArrayCompilationArg;

    /// Registers `arg` with the launcher and returns its compile-time description.
    ///
    /// An [`ArrayArg::Alias`] yields `inplace: Some(input_pos)`; an
    /// [`ArrayArg::Handle`] yields `inplace: None`.
    fn register<R: Runtime>(
        arg: Self::RuntimeArg<R>,
        launcher: &mut KernelLauncher<R>,
    ) -> Self::CompilationArg {
        let elem_ty = launcher.with_scope(|scope| C::as_type(scope));
        let inplace = match &arg {
            ArrayArg::Handle { .. } => None,
            // NOTE(review): `as` truncates if `input_pos` exceeds `Id`'s range;
            // consider `Id::try_from` if `Id` is narrower than `usize`.
            ArrayArg::Alias { input_pos, .. } => Some(*input_pos as Id),
        };
        launcher.register_array(arg, elem_ty);
        ArrayCompilationArg { inplace }
    }

    /// Declares this argument as an input array of element type `C`.
    fn expand(_arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> NativeExpand<Array<C>> {
        let elem_ty = C::as_type(&builder.scope);
        let input = builder.input_array(elem_ty);
        input.into()
    }

    /// Declares this argument as an output: it reuses input `id` when
    /// `arg.inplace` is `Some(id)`, and otherwise declares a new output array
    /// of element type `C`.
    fn expand_output(
        arg: &Self::CompilationArg,
        builder: &mut KernelBuilder,
    ) -> NativeExpand<Array<C>> {
        if let Some(id) = arg.inplace {
            return builder.inplace_output(id).into();
        }
        let elem_ty = C::as_type(&builder.scope);
        builder.output_array(elem_ty).into()
    }
}