use cubecl::{prelude::*, std::tensor::layout::*};
use crate::launch::BatchedCoords;
#[derive(CubeType, Clone, Copy)]
pub struct SliceIndex {
offset: usize,
shape: Coords2d,
}
#[cube]
impl SliceIndex {
pub fn new(offset: usize, shape: BatchedCoords) -> Self {
let (_, rows, cols) = shape;
SliceIndex {
offset,
shape: (rows, cols),
}
}
}
#[cube]
impl Layout for SliceIndex {
type Coordinates = Coords2d;
type SourceCoordinates = (usize, u32, u32);
fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
let (row, col) = pos;
(self.offset, row, col)
}
fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
true.runtime()
}
fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
(self.to_source_pos(pos), self.is_in_bounds(pos))
}
fn shape(&self) -> Self::Coordinates {
self.shape
}
}