cubek_convolution/components/global/layout/
bias.rs1use cubecl::prelude::*;
2use cubecl::std::tensor::layout::*;
3
/// Launchable layout describing a one-dimensional bias buffer.
///
/// NOTE(review): presumably the bias is indexed by the last coordinate of the
/// output and shared across the other coordinates — confirm against the
/// convolution caller.
#[derive(CubeType, CubeLaunch)]
pub struct BiasLayout {
    // Number of bias elements (counted in elements, not lines). Element
    // indices at or past this value are treated as out of bounds.
    shape: u32,
    // Number of elements packed into one line of the backing buffer; used to
    // turn an element index into a line index. Marked `#[cube(comptime)]`, so
    // its value is fixed when the kernel is compiled.
    #[cube(comptime)]
    line_size: u32,
}
10
11#[cube]
12impl Layout for BiasLayout {
13 type Coordinates = Coords3d;
14 type SourceCoordinates = Coords1d;
15
16 fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
17 let (_, _, n) = pos;
18 n / self.line_size
19 }
20
21 fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
22 let (_, _, n) = pos;
23 n < self.shape
24 }
25
26 fn shape(&self) -> Self::Coordinates {
27 (1, 1, self.shape)
28 }
29
30 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
31 (self.to_source_pos(pos), self.is_in_bounds(pos))
32 }
33}