use cubecl::prelude::*;
use cubecl_core::{self as cubecl, unexpanded};
use crate::tensor::{
View, is_contiguous, is_contiguous_pitched,
launch::ViewArg,
layout::{
Coords1d, Layout, LayoutExpand, VirtualLayoutOperationsExpand,
permuted::{PermutedLayout, PermutedLayoutLaunch},
plain::{PlainLayout, PlainLayoutLaunch},
strided::{StridedLayout, StridedLayoutLaunch},
},
};
/// Maps a one-dimensional logical position onto a one-dimensional source position, using one of
/// three strategies depending on how the underlying tensor is strided.
///
/// The host-side constructors on `LinearLayoutArgs` pick the variant: `Plain` for contiguous
/// strides, `Strided` for contiguous-pitched strides, and `Permuted` for everything else
/// (including any shape that differs from a supplied reference shape).
#[derive(CubeType, CubeLaunch, Clone)]
pub enum LinearLayout {
    /// Chosen when `is_contiguous(shape, strides)` holds.
    Plain(PlainLayout),
    /// Chosen when the strides are contiguous-pitched but not fully contiguous.
    Strided(StridedLayout),
    /// Fallback for any other stride pattern, and for shapes that differ from a reference shape.
    Permuted(PermutedLayout),
}
impl LinearLayout {
    /// Placeholder that lets `#[cube]` code write `self.inner()`.
    ///
    /// The body is `unexpanded!()`, so this is not meant to be executed directly. Presumably the
    /// `#[cube]` macro routes the call to `LinearLayoutExpand::__expand_inner_method`, which
    /// dispatches on the active variant. The `&PlainLayout` return type looks like a stand-in
    /// for type-checking only — NOTE(review): confirm against the cubecl macro conventions.
    fn inner(&self) -> &PlainLayout {
        unexpanded!()
    }
}
impl LinearLayoutExpand {
    /// Expand-time counterpart of `LinearLayout::inner`.
    ///
    /// Returns the layout wrapped by whichever variant is active, type-erased behind the shared
    /// virtual layout trait. This lets the `Layout` impl forward `to_source_pos`, `shape` and
    /// `is_in_bounds` without matching on the variant itself.
    ///
    /// `_scope` is not used here; presumably it is accepted to fit the expand-method calling
    /// convention.
    fn __expand_inner_method(
        &self,
        _scope: &mut Scope,
    ) -> &dyn VirtualLayoutOperationsExpand<Coords1d, Coords1d> {
        let erased: &dyn VirtualLayoutOperationsExpand<Coords1d, Coords1d> = match self {
            Self::Plain(plain) => plain,
            Self::Strided(strided) => strided,
            Self::Permuted(permuted) => permuted,
        };
        erased
    }
}
impl<'a, R: Runtime> LinearLayoutArgs<'a, R> {
    /// Picks the simplest `LinearLayout` variant that can describe a tensor with the given
    /// `shape` and `strides`:
    ///
    /// * contiguous strides → `Plain`
    /// * contiguous-pitched strides → `Strided`
    /// * anything else → `Permuted`
    ///
    /// `line_size` is passed through unchanged to the chosen layout's launch constructor.
    pub fn from_shape_strides(
        client: &ComputeClient<R::Server>,
        shape: &[usize],
        strides: &[usize],
        line_size: u8,
    ) -> Self {
        if is_contiguous(shape, strides) {
            return Self::Plain(PlainLayoutLaunch::from_shape(shape, line_size));
        }

        if is_contiguous_pitched(shape, strides) {
            let strided =
                StridedLayoutLaunch::from_shape_strides(client, shape, strides, line_size);
            return Self::Strided(strided);
        }

        let permuted = PermutedLayoutLaunch::from_shape_strides(client, shape, strides, line_size);
        Self::Permuted(permuted)
    }

    /// Like [`Self::from_shape_strides`], but for a tensor addressed in the coordinate space of
    /// `reference_shape`.
    ///
    /// If `shape` already equals `reference_shape`, this defers to `from_shape_strides`.
    /// Otherwise it always builds a `Permuted` layout from both shapes.
    /// NOTE(review): presumably this is how broadcasting against a larger reference is
    /// supported — confirm.
    pub fn from_shape_strides_with_reference(
        client: &ComputeClient<R::Server>,
        shape: &[usize],
        reference_shape: &[usize],
        strides: &[usize],
        line_size: u8,
    ) -> Self {
        if shape == reference_shape {
            return Self::from_shape_strides(client, shape, strides, line_size);
        }

        Self::Permuted(PermutedLayoutLaunch::from_shapes_strides_ref(
            client,
            shape,
            reference_shape,
            strides,
            line_size,
        ))
    }

    /// Convenience wrapper around [`Self::from_shape_strides`] that uses the shape and strides
    /// stored on `handle`.
    pub fn from_handle(
        client: &ComputeClient<R::Server>,
        handle: &TensorHandleRef<'a, R>,
        line_size: u8,
    ) -> Self {
        let (shape, strides) = (handle.shape, handle.strides);
        Self::from_shape_strides(client, shape, strides, line_size)
    }

    /// Convenience wrapper around [`Self::from_shape_strides_with_reference`].
    ///
    /// Builds the layout for `handle` in the coordinate space of `reference`. Only
    /// `reference.shape` is read; its strides and buffer are ignored.
    pub fn from_handle_with_reference(
        client: &ComputeClient<R::Server>,
        handle: &TensorHandleRef<'a, R>,
        reference: &TensorHandleRef<'a, R>,
        line_size: u8,
    ) -> Self {
        let reference_shape = reference.shape;
        Self::from_shape_strides_with_reference(
            client,
            handle.shape,
            reference_shape,
            handle.strides,
            line_size,
        )
    }
}
// Device-side `Layout` implementation: every method forwards to the layout wrapped by the
// active variant, obtained through `self.inner()`.
#[cube]
impl Layout for LinearLayout {
    // Logical and source coordinates are both one-dimensional (`Coords1d`).
    type Coordinates = Coords1d;
    type SourceCoordinates = Coords1d;
    // Maps logical position `pos` to a position in the source buffer via the active variant.
    fn to_source_pos(&self, pos: Self::Coordinates) -> u32 {
        self.inner().to_source_pos(pos)
    }
    // Returns the mapped source position together with the in-bounds flag for `pos`.
    // The position is computed unconditionally, even when `pos` is out of bounds; presumably
    // callers consult the flag before using it.
    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (u32, bool) {
        (self.to_source_pos(pos), self.is_in_bounds(pos))
    }
    // Logical shape, as reported by the wrapped layout.
    fn shape(&self) -> Self::Coordinates {
        self.inner().shape()
    }
    // Bounds check, delegated to the wrapped layout.
    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        self.inner().is_in_bounds(pos)
    }
}
/// A view over one-dimensional coordinates; read-only unless another `IO` mode is given.
/// The launch helpers in this module back it with a `LinearLayout`.
pub type LinearView<E, IO = ReadOnly> = View<E, Coords1d, IO>;
/// Launch-time argument for a [`LinearView`].
pub type LinearViewLaunch<'a, R> = ViewArg<'a, Coords1d, R>;
pub fn linear_view<'a, R: Runtime>(
client: &ComputeClient<R::Server>,
handle: &'a TensorHandleRef<'a, R>,
line_size: u8,
) -> LinearViewLaunch<'a, R> {
let len = handle.shape.iter().product::<usize>();
let layout = LinearLayoutArgs::from_handle(client, handle, line_size);
let buffer = unsafe {
ArrayArg::from_raw_parts_and_size(handle.handle, len, line_size, handle.elem_size)
};
LinearViewLaunch::new::<LinearLayout>(buffer, layout)
}
pub fn linear_view_with_reference<'a, R: Runtime>(
client: &ComputeClient<R::Server>,
handle: &'a TensorHandleRef<'a, R>,
reference: &'a TensorHandleRef<'a, R>,
line_size: u8,
) -> LinearViewLaunch<'a, R> {
let len = handle.shape.iter().product::<usize>();
let layout = LinearLayoutArgs::from_handle_with_reference(client, handle, reference, line_size);
let buffer = unsafe {
ArrayArg::from_raw_parts_and_size(handle.handle, len, line_size, handle.elem_size)
};
LinearViewLaunch::new::<LinearLayout>(buffer, layout)
}
/// Create a linear view launch argument whose buffer is an alias (`ArrayArg::Alias`) of the
/// kernel argument at position `pos`, rather than a new binding.
///
/// `handle` is used only to derive the [`LinearLayout`], from its shape and strides.
/// NOTE(review): presumably `pos` indexes the kernel's input arguments — confirm.
pub fn linear_view_alias<'a, R: Runtime>(
    client: &ComputeClient<R::Server>,
    handle: &'a TensorHandleRef<'a, R>,
    line_size: u8,
    pos: usize,
) -> LinearViewLaunch<'a, R> {
    LinearViewLaunch::new::<LinearLayout>(
        ArrayArg::Alias { input_pos: pos },
        LinearLayoutArgs::from_handle(client, handle, line_size),
    )
}