cubek_convolution/components/global/layout/
bias.rs

1use cubecl::prelude::*;
2use cubecl::std::tensor::layout::*;
3
/// Layout mapping 3D `(_, _, n)` coordinates onto a 1D source index, using only
/// the last coordinate `n`.
///
/// `shape()` reports `(1, 1, shape)`, and the first two coordinates are ignored by
/// both `to_source_pos` and `is_in_bounds`.
/// NOTE(review): presumably this describes a bias vector broadcast over the two
/// leading dimensions — confirm against callers.
#[derive(CubeType, CubeLaunch)]
pub struct BiasLayout {
    // Exclusive upper bound on the last coordinate (checked in `is_in_bounds`).
    // Presumably the bias length in elements — TODO confirm.
    shape: u32,
    // Compile-time divisor applied to the last coordinate in `to_source_pos`.
    // Presumably the line (vectorization) width of the source buffer — confirm.
    #[cube(comptime)]
    line_size: u32,
}
10
11#[cube]
12impl Layout for BiasLayout {
13    type Coordinates = Coords3d;
14    type SourceCoordinates = Coords1d;
15
16    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
17        let (_, _, n) = pos;
18        n / self.line_size
19    }
20
21    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
22        let (_, _, n) = pos;
23        n < self.shape
24    }
25
26    fn shape(&self) -> Self::Coordinates {
27        (1, 1, self.shape)
28    }
29
30    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
31        (self.to_source_pos(pos), self.is_in_bounds(pos))
32    }
33}