cubecl_matmul/components/batch/
layout.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use cubecl_std::tensor::layout::*;
4
5/// Slice the layout at a specific batch, and reduce its dimensionality
6/// Not general enough to be in cubecl-std
7#[derive(CubeType, Clone, Copy)]
8pub struct SliceIndex {
9    offset: u32,
10    shape: Coords2d,
11}
12
13#[cube]
14impl SliceIndex {
15    pub fn new(offset: u32, shape: Coords3d) -> Self {
16        let (_, rows, cols) = shape;
17        SliceIndex {
18            offset,
19            shape: (rows, cols),
20        }
21    }
22}
23
24#[cube]
25impl Layout for SliceIndex {
26    type Coordinates = Coords2d;
27    type SourceCoordinates = Coords3d;
28
29    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
30        let (row, col) = pos;
31        (self.offset, row, col)
32    }
33
34    fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
35        // we don't check batch
36        true.runtime()
37    }
38
39    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
40        (self.to_source_pos(pos), self.is_in_bounds(pos))
41    }
42
43    fn shape(&self) -> Self::Coordinates {
44        self.shape
45    }
46}