use std::marker::PhantomData;
use serde::{Deserialize, Serialize};
use crate::{
Runtime,
compute::{KernelBuilder, KernelLauncher},
ir::{Id, LineSize, Type},
prelude::{
ArgSettings, CompilationArg, CubePrimitive, ExpandElementTyped, LaunchArg, TensorHandleRef,
},
};
use super::Array;
/// Compile-time description of an `Array` kernel argument.
///
/// Produced by `LaunchArg::compilation_arg` for `Array<C>` and read back by
/// `expand` / `expand_output` when the kernel is being built.
#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)]
pub struct ArrayCompilationArg {
/// `Some(id)` when the runtime argument was an `ArrayArg::Alias`; `expand_output`
/// then forwards the id to `KernelBuilder::inplace_output`. `None` for
/// handle-backed arrays.
pub inplace: Option<Id>,
/// Line size applied to the element type through `.line(..)` when the array is
/// declared on the builder. Set to `0` for aliased arguments.
pub line_size: LineSize,
}
// Empty impl: nothing is overridden, so this relies entirely on what `CompilationArg`
// provides by default.
impl CompilationArg for ArrayCompilationArg {}
/// Borrowed, one-dimensional array description built around a server handle.
///
/// The runtime type `R` is only carried as a marker; no value of type `R` is stored.
pub struct ArrayHandleRef<'a, R: Runtime> {
/// The borrowed server handle this description refers to.
pub handle: &'a cubecl_runtime::server::Handle,
/// Length given at construction, kept as a one-element array so that `as_tensor`
/// can lend it out as a shape slice.
/// NOTE(review): presumably counted in elements rather than bytes; confirm.
pub(crate) length: [usize; 1],
/// Element size given at construction; `ArrayArg::from_raw_parts` fills it from
/// `CubePrimitive::size()`.
/// NOTE(review): presumably expressed in bytes; confirm.
pub elem_size: usize,
/// Zero-sized marker tying this value to the runtime `R`.
runtime: PhantomData<R>,
}
/// Runtime argument for an `Array` kernel parameter (the `LaunchArg::RuntimeArg`
/// of `Array<C>`).
pub enum ArrayArg<'a, R: Runtime> {
/// The array is passed with its own handle.
Handle {
/// Handle, length and element size of the array.
handle: ArrayHandleRef<'a, R>,
/// Line size requested for this argument; widened and stored in
/// `ArrayCompilationArg::line_size`.
line_size: u8,
},
/// The array carries no handle of its own and refers to another argument instead;
/// it compiles to `ArrayCompilationArg { inplace: Some(..), .. }`.
Alias {
/// Value cast to `Id` and stored in `ArrayCompilationArg::inplace`;
/// presumably the position of the input being aliased.
input_pos: usize,
},
}
/// Lets an `ArrayArg` register itself with a kernel launcher.
impl<R> ArgSettings<R> for ArrayArg<'_, R>
where
    R: Runtime,
{
    /// Hands this argument over to `KernelLauncher::register_array`.
    fn register(&self, launcher: &mut KernelLauncher<R>) {
        launcher.register_array(self);
    }
}
impl<'a, R: Runtime> ArrayArg<'a, R> {
    /// Builds a handle-backed array argument whose element size is taken from `E`.
    ///
    /// Same result as [`ArrayArg::from_raw_parts_and_size`] called with the size
    /// reported by `E::size()`.
    ///
    /// # Panics
    ///
    /// Panics with `"Element should have a size"` when `E::size()` does not
    /// produce a size.
    ///
    /// # Safety
    ///
    /// `length` is stored exactly as given; nothing here checks it against `handle`.
    /// NOTE(review): presumably a length that does not match the underlying
    /// allocation can lead to out-of-bounds accesses once a kernel is launched;
    /// confirm against the runtime's contract.
    pub unsafe fn from_raw_parts<E: CubePrimitive>(
        handle: &'a cubecl_runtime::server::Handle,
        length: usize,
        line_size: u8,
    ) -> Self {
        let elem_size = E::size().expect("Element should have a size");
        // SAFETY: the caller's obligations are forwarded unchanged to the
        // explicit-size constructor.
        unsafe { Self::from_raw_parts_and_size(handle, length, line_size, elem_size) }
    }

    /// Builds a handle-backed array argument with an explicitly supplied element size.
    ///
    /// # Safety
    ///
    /// `length` and `elem_size` are stored exactly as given; nothing here checks
    /// them against `handle`.
    /// NOTE(review): presumably values that do not describe the underlying
    /// allocation can lead to out-of-bounds accesses once a kernel is launched;
    /// confirm against the runtime's contract.
    pub unsafe fn from_raw_parts_and_size(
        handle: &'a cubecl_runtime::server::Handle,
        length: usize,
        line_size: u8,
        elem_size: usize,
    ) -> Self {
        // SAFETY: the caller's obligations are forwarded unchanged to
        // `ArrayHandleRef::from_raw_parts`.
        let handle = unsafe { ArrayHandleRef::from_raw_parts(handle, length, elem_size) };
        Self::Handle { handle, line_size }
    }
}
impl<'a, R: Runtime> ArrayHandleRef<'a, R> {
pub unsafe fn from_raw_parts(
handle: &'a cubecl_runtime::server::Handle,
length: usize,
elem_size: usize,
) -> Self {
Self {
handle,
length: [length],
elem_size,
runtime: PhantomData,
}
}
pub fn as_tensor(&self) -> TensorHandleRef<'_, R> {
let shape = &self.length;
TensorHandleRef {
handle: self.handle,
strides: &[1],
shape,
elem_size: self.elem_size,
runtime: PhantomData,
}
}
}
/// Launch support for `Array<C>` kernel parameters.
impl<C: CubePrimitive> LaunchArg for Array<C> {
    type RuntimeArg<'a, R: Runtime> = ArrayArg<'a, R>;
    type CompilationArg = ArrayCompilationArg;

    /// Derives the compile-time description from a runtime argument.
    ///
    /// * `ArrayArg::Handle` yields `inplace: None` and the widened line size.
    /// * `ArrayArg::Alias` yields `inplace: Some(input_pos as Id)` and line size `0`.
    fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
        let (inplace, line_size) = match runtime_arg {
            // `u8 -> u32` is a lossless widening.
            ArrayArg::Handle { line_size, .. } => (None, u32::from(*line_size)),
            // Plain `as` cast: would wrap silently if `Id` is narrower than `usize`.
            ArrayArg::Alias { input_pos } => (Some(*input_pos as Id), 0),
        };
        ArrayCompilationArg { inplace, line_size }
    }

    /// Declares the array on `builder` as an input whose item type is `C` with the
    /// line size from `arg`.
    fn expand(
        arg: &Self::CompilationArg,
        builder: &mut KernelBuilder,
    ) -> ExpandElementTyped<Array<C>> {
        let item = Type::new(C::as_type(&builder.scope)).line(arg.line_size);
        builder.input_array(item).into()
    }

    /// Declares the array on `builder` as an output.
    ///
    /// When `arg.inplace` holds an id, that id is passed to
    /// `KernelBuilder::inplace_output` and `arg.line_size` is not consulted;
    /// otherwise a fresh output array of `C` with `arg.line_size` is declared.
    fn expand_output(
        arg: &Self::CompilationArg,
        builder: &mut KernelBuilder,
    ) -> ExpandElementTyped<Array<C>> {
        if let Some(id) = arg.inplace {
            return builder.inplace_output(id).into();
        }
        let item = Type::new(C::as_type(&builder.scope)).line(arg.line_size);
        builder.output_array(item).into()
    }
}