cubecl_matmul/components/batch/
layout.rs1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use cubecl_std::tensor::layout::*;
4
5#[derive(CubeType, Clone, Copy)]
8pub struct SliceIndex {
9 offset: u32,
10 shape: Coords2d,
11}
12
13#[cube]
14impl SliceIndex {
15 pub fn new(offset: u32, shape: Coords3d) -> Self {
16 let (_, rows, cols) = shape;
17 SliceIndex {
18 offset,
19 shape: (rows, cols),
20 }
21 }
22}
23
24#[cube]
25impl Layout for SliceIndex {
26 type Coordinates = Coords2d;
27 type SourceCoordinates = Coords3d;
28
29 fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
30 let (row, col) = pos;
31 (self.offset, row, col)
32 }
33
34 fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
35 true.runtime()
37 }
38
39 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
40 (self.to_source_pos(pos), self.is_in_bounds(pos))
41 }
42
43 fn shape(&self) -> Self::Coordinates {
44 self.shape
45 }
46}