use alloc::rc::Rc;
use cubecl::prelude::*;
use cubecl_core::{self as cubecl, ir::UIntKind, unexpanded, zspace::Shape};
use crate::tensor::{
View, is_contiguous, is_contiguous_pitched,
launch::{BufferArg, ConcreteLayout, ConcreteLayoutLaunch, ViewArg, ViewLayoutLaunchArg},
layout::{
Coords1d, Layout, LayoutExpand, VirtualLayoutOperationsExpand,
permuted::{PermutedLayout, PermutedLayoutCompilationArg, PermutedLayoutLaunch},
plain::PlainLayout,
strided::{StridedLayout, StridedLayoutCompilationArg},
},
};
/// Layout for 1-D (linear) views whose concrete variant is chosen at launch time.
///
/// The variant is selected on the host by
/// `<LinearViewLayout as ViewLayoutLaunchArg>::register`, based on the buffer's shape and
/// strides and an optional reference shape.
#[derive(CubeType, Clone)]
pub enum LinearViewLayout {
    /// Selected when `is_contiguous(shape, strides)` holds for the buffer.
    Plain(PlainLayout),
    /// Selected when the buffer is not contiguous but `is_contiguous_pitched` holds.
    Strided(StridedLayout),
    /// Selected when the buffer has any other stride pattern, or when a reference shape
    /// that differs from the buffer's shape was supplied.
    Permuted(PermutedLayout),
}
impl LinearViewLayout {
    /// Host-side stub for the active inner layout; it is never meant to run (`unexpanded!()`).
    ///
    /// At expansion time the real work is done by
    /// `LinearViewLayoutExpand::__expand_inner_method`, which returns whichever variant is
    /// active behind a trait object. NOTE(review): the `&PlainLayout` return type appears to
    /// exist only so `#[cube]` code calling `self.inner()` type-checks — confirm.
    fn inner(&self) -> &PlainLayout {
        unexpanded!()
    }
}
impl LinearViewLayoutExpand {
    /// Expand-time counterpart of `LinearViewLayout::inner`.
    ///
    /// Consumes the expanded layout and erases its concrete variant behind a shared
    /// `VirtualLayoutOperationsExpand` trait object. All `Layout` queries can then be
    /// forwarded uniformly, without matching on the variant at every call site.
    fn __expand_inner_method(
        self,
        _scope: &mut Scope,
    ) -> Rc<dyn VirtualLayoutOperationsExpand<Coords1d, Coords1d>> {
        // Each arm produces a different `Rc<T>`; the declared return type drives the
        // unsized coercion of every arm to the common trait object.
        match self {
            Self::Plain(plain) => Rc::new(plain),
            Self::Strided(strided) => Rc::new(strided),
            Self::Permuted(permuted) => Rc::new(permuted),
        }
    }
}
/// Host-side runtime argument used to register a [`LinearViewLayout`].
///
/// `Default` yields no reference shape.
#[derive(Default)]
pub struct LinearViewLayoutLaunch {
    // Optional shape to view the buffer as. When it is set and differs from the buffer's
    // own shape, `register` builds a permuted layout from it. Otherwise the layout is
    // picked from the buffer's shape and strides alone.
    reference_shape: Option<Shape>,
}
impl ViewLayoutLaunchArg for LinearViewLayout {
type RuntimeArg<R: Runtime> = LinearViewLayoutLaunch;
type CompilationArg = LinearLayoutCompilationArg;
fn register<R: Runtime, B: BufferArg>(
runtime_arg: Self::RuntimeArg<R>,
buffer: &B,
ty: Type,
launcher: &mut KernelLauncher<R>,
) -> Self::CompilationArg {
let shape = buffer.shape();
match runtime_arg.reference_shape {
Some(reference_shape) if reference_shape.as_slice() != shape => {
let arg = PermutedLayoutLaunch::from_reference_shape(reference_shape);
let comp_arg = PermutedLayout::register(arg, buffer, ty, launcher);
LinearLayoutCompilationArg::Permuted(comp_arg)
}
_ => {
let strides = buffer.strides();
if is_contiguous(shape, strides) {
PlainLayout::register((), buffer, ty, launcher);
LinearLayoutCompilationArg::Plain
} else if is_contiguous_pitched(shape, strides) {
let comp_arg = StridedLayout::register((), buffer, ty, launcher);
LinearLayoutCompilationArg::Strided(comp_arg)
} else {
let comp_arg =
PermutedLayout::register(Default::default(), buffer, ty, launcher);
LinearLayoutCompilationArg::Permuted(comp_arg)
}
}
}
}
fn expand(
compilation_arg: &Self::CompilationArg,
ty: Type,
builder: &mut cubecl::prelude::KernelBuilder,
) -> <Self as cubecl::prelude::CubeType>::ExpandType {
match compilation_arg {
LinearLayoutCompilationArg::Plain => {
LinearViewLayoutExpand::Plain(PlainLayout::expand(&(), ty, builder))
}
LinearLayoutCompilationArg::Strided(arg) => {
LinearViewLayoutExpand::Strided(StridedLayout::expand(arg, ty, builder))
}
LinearLayoutCompilationArg::Permuted(arg) => {
LinearViewLayoutExpand::Permuted(PermutedLayout::expand(arg, ty, builder))
}
}
}
}
/// Compilation argument recording which [`LinearViewLayout`] variant `register` selected.
///
/// `expand` maps each variant back to the corresponding `LinearViewLayoutExpand` variant.
#[derive(Debug, Hash, PartialEq, Eq, Clone)]
pub enum LinearLayoutCompilationArg {
    /// Contiguous buffer; the plain layout needs no extra compilation data.
    Plain,
    /// Pitched-contiguous buffer, carrying the strided layout's compilation argument.
    Strided(StridedLayoutCompilationArg),
    /// Permuted layout, used for a mismatched reference shape or a non-contiguous buffer.
    Permuted(PermutedLayoutCompilationArg),
}
impl LinearViewLayoutLaunch {
    /// Creates a launch argument with no reference shape.
    ///
    /// The layout is then chosen from the buffer's own shape and strides.
    pub fn new() -> Self {
        Self {
            reference_shape: None,
        }
    }

    /// Creates a launch argument that views the buffer as `reference_shape`.
    ///
    /// If the shape differs from the buffer's shape at registration, a permuted layout is used.
    pub fn from_reference_shape(reference_shape: Shape) -> Self {
        let reference_shape = Some(reference_shape);
        Self { reference_shape }
    }

    /// Like [`Self::from_reference_shape`], taking the shape from `reference`.
    ///
    /// Only `reference.shape` is kept; the rest of the binding is dropped here.
    pub fn from_reference_handle<R: Runtime>(reference: TensorBinding<R>) -> Self {
        let reference_shape = reference.shape;
        Self::from_reference_shape(reference_shape)
    }
}
// Kernel-side `Layout` implementation: a 1-D to 1-D mapping in which every query is
// forwarded to the active variant through `self.inner()`. At expansion time that call is
// presumably served by `LinearViewLayoutExpand::__expand_inner_method`, which yields the
// plain, strided or permuted layout behind a trait object.
#[cube]
impl Layout for LinearViewLayout {
    type Coordinates = Coords1d;
    type SourceCoordinates = Coords1d;

    // Maps a linear view position to its source position via the active inner layout.
    fn to_source_pos(&self, pos: Self::Coordinates) -> usize {
        self.inner().to_source_pos(pos)
    }

    // Returns the mapped source position together with the bounds check for `pos`.
    // The position is computed even when `pos` is out of bounds; callers must honor the flag.
    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (usize, bool) {
        (self.to_source_pos(pos), self.is_in_bounds(pos))
    }

    // Logical 1-D extent of the view, as reported by the active inner layout.
    fn shape(&self) -> Self::Coordinates {
        self.inner().shape()
    }

    // Bounds check delegated to the active inner layout.
    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        self.inner().is_in_bounds(pos)
    }
}
/// [`LinearViewLayout`] wrapped as a `ConcreteLayout`.
pub type LinearLayout = ConcreteLayout<LinearViewLayout>;
/// Launch argument for [`LinearLayout`].
pub type LinearLayoutLaunch<R> = ConcreteLayoutLaunch<LinearViewLayout, R>;
/// A `View` over elements `E` addressed by 1-D coordinates; access mode `IO` defaults to
/// `ReadOnly`.
pub type LinearView<E, IO = ReadOnly> = View<E, Coords1d, IO>;
/// Launch argument for a [`LinearView`].
pub type LinearViewLaunch<R> = ViewArg<Coords1d, R>;
/// Builds a [`LinearLayoutLaunch`] for `handle` with no reference shape.
///
/// The type handed to the layout is a `u32` scalar with `vector_size` lanes.
/// NOTE(review): the `u32` element kind looks like a placeholder (the layout presumably
/// only needs the vector size) — confirm.
pub fn linear_layout<R: Runtime>(
    handle: &TensorBinding<R>,
    vector_size: VectorSize,
) -> LinearLayoutLaunch<R> {
    let layout_ty = Type::new(UIntKind::U32.into()).with_vector_size(vector_size);
    let runtime_arg = LinearViewLayoutLaunch::new();
    LinearLayoutLaunch::from_handle(handle, layout_ty, runtime_arg)
}
/// Creates a linear view launch argument over `handle`, consuming the binding.
///
/// No reference shape is used, so the layout follows the tensor's own shape and strides.
pub fn linear_view<R: Runtime>(handle: TensorBinding<R>) -> LinearViewLaunch<R> {
    let tensor = handle.into_tensor_arg();
    LinearViewLaunch::new_tensor::<LinearViewLayout>(tensor, LinearViewLayoutLaunch::new())
}
/// Creates a linear view over `handle` that is indexed using `reference`'s shape.
///
/// When the two shapes differ, registration picks a permuted layout built from the
/// reference shape. Only the shape of `reference` is used.
pub fn linear_view_with_reference<R: Runtime>(
    handle: TensorBinding<R>,
    reference: TensorBinding<R>,
) -> LinearViewLaunch<R> {
    let tensor = handle.into_tensor_arg();
    let runtime_arg = LinearViewLayoutLaunch::from_reference_handle(reference);
    LinearViewLaunch::new_tensor::<LinearViewLayout>(tensor, runtime_arg)
}
/// Creates a linear view whose tensor argument is `handle.as_alias(pos)`.
///
/// Presumably this aliases the kernel input at position `pos` rather than binding a new
/// buffer — confirm against `TensorBinding::as_alias`. No reference shape is used.
pub fn linear_view_alias<R: Runtime>(handle: &TensorBinding<R>, pos: usize) -> LinearViewLaunch<R> {
    let aliased = handle.as_alias(pos);
    LinearViewLaunch::new_tensor::<LinearViewLayout>(aliased, LinearViewLayoutLaunch::new())
}