use cubecl::prelude::*;
use cubecl_core as cubecl;
use crate::tensor::layout::{
Layout, LayoutExpand,
as_dyn::{IntoDyn, IntoDynExpand},
};
/// Layout that maps fixed-dimensionality coordinates (`D`) to a linear `usize` position.
///
/// The position is `sum(pos[i] * strides[i]) / vector_size`. When `checked` is set,
/// `is_in_bounds` also compares each coordinate against `shape`.
#[derive(CubeType, CubeLaunch)]
pub struct FixedDimLayout<D: IntoDyn> {
/// Per-dimension extents. Returned by `shape()` and used by the bounds check.
shape: D,
/// Per-dimension strides, multiplied with the coordinates in `to_source_pos`.
/// NOTE(review): presumably in scalar elements, since the summed offset is divided by
/// `vector_size` afterwards. Confirm against the callers.
strides: Sequence<usize>,
/// Comptime divisor applied to the strided offset. Presumably converts a scalar offset
/// into an index in units of vectors.
#[cube(comptime)]
vector_size: VectorSize,
/// Comptime flag. When `false`, `is_in_bounds` always returns `true`.
#[cube(comptime)]
checked: bool,
}
#[cube]
impl<D: IntoDyn> FixedDimLayout<D> {
    /// Assembles a layout from its extents, per-dimension strides and comptime settings.
    ///
    /// * `shape` - per-dimension extents, used by `shape()` and the optional bounds check.
    /// * `strides` - per-dimension strides used to linearize coordinates.
    /// * `vector_size` - comptime divisor applied to the strided offset.
    /// * `checked` - comptime switch that enables coordinate-vs-shape bounds checking.
    pub fn new(
        shape: D,
        strides: Sequence<usize>,
        #[comptime] vector_size: VectorSize,
        #[comptime] checked: bool,
    ) -> Self {
        FixedDimLayout::<D> {
            checked,
            vector_size,
            strides,
            shape,
        }
    }
}
#[cube]
impl<D: IntoDyn> Layout for FixedDimLayout<D> {
    type Coordinates = D;
    type SourceCoordinates = usize;

    /// Linearizes `pos`: sums `pos[dim] * strides[dim]` over every dimension, then
    /// divides the total by the comptime `vector_size`.
    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
        let coords = pos.into_dyn();
        let mut linear = 0;
        #[unroll]
        for dim in 0..coords.len() {
            let term = coords[dim] as usize * self.strides[dim];
            linear += term;
        }
        linear / self.vector_size
    }

    /// Reports whether every coordinate of `pos` is strictly below the matching extent.
    ///
    /// The comparison is emitted only when the layout was built with `checked = true`.
    /// Otherwise the result is always `true`.
    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        let mut inside = true;
        if comptime![self.checked] {
            let coords = pos.into_dyn();
            let extents = self.shape.clone().into_dyn();
            #[unroll]
            for dim in 0..coords.len() {
                inside &= coords[dim] < extents[dim];
            }
        }
        inside
    }

    /// Returns the linear position of `pos` together with its bounds-check result.
    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
        let linear = self.to_source_pos(pos.clone());
        let inside = self.is_in_bounds(pos);
        (linear, inside)
    }

    /// Returns a copy of the layout's per-dimension extents.
    fn shape(&self) -> Self::Coordinates {
        self.shape.clone()
    }
}
impl<D: IntoDyn, R: Runtime> FixedDimLayoutLaunch<D, R> {
    /// Builds launch arguments with bounds checking enabled.
    ///
    /// Inside the kernel, `is_in_bounds` compares every coordinate against `shape`.
    /// The strides are copied from `handle.strides`, and `shape` is forwarded unchanged.
    pub fn from_shape_handle(
        handle: &TensorBinding<R>,
        shape: D::RuntimeArg<R>,
        vector_size: VectorSize,
    ) -> Self {
        Self::from_shape_handle_with_check(handle, shape, vector_size, true)
    }

    /// Builds launch arguments with bounds checking disabled at comptime.
    ///
    /// Inside the kernel, `is_in_bounds` always reports `true`. Only use this when every
    /// accessed coordinate is known to lie inside `shape`.
    pub fn from_shape_handle_unchecked(
        handle: &TensorBinding<R>,
        shape: D::RuntimeArg<R>,
        vector_size: VectorSize,
    ) -> Self {
        Self::from_shape_handle_with_check(handle, shape, vector_size, false)
    }

    /// Shared implementation of the two public constructors: copies the handle's strides
    /// and forwards everything to the derive-generated `new`.
    ///
    /// NOTE(review): `to_source_pos` indexes one stride per coordinate dimension, so this
    /// assumes `handle` has at least as many strides as `D` has dimensions. TODO confirm
    /// at the call sites.
    fn from_shape_handle_with_check(
        handle: &TensorBinding<R>,
        shape: D::RuntimeArg<R>,
        vector_size: VectorSize,
        checked: bool,
    ) -> Self {
        let strides = handle.strides.iter().copied().collect();
        Self::new(shape, strides, vector_size, checked)
    }
}