Skip to main content

cubek_convolution/components/global/layout/bias.rs

1use cubecl::{
2    std::tensor::{launch::BufferArg, layout::*},
3    {prelude::*, std::tensor::launch::ViewLayoutLaunchArg},
4};
5use cubek_matmul::launch::BatchedCoords;
6
7#[derive(CubeType)]
8pub struct BiasLayout {
9    shape: u32,
10    #[cube(comptime)]
11    vector_size: u32,
12}
13
14#[cube]
15impl Layout for BiasLayout {
16    type Coordinates = BatchedCoords;
17    type SourceCoordinates = Coords1d;
18
19    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
20        let (_, _, n) = pos;
21        (n / self.vector_size) as usize
22    }
23
24    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
25        let (_, _, n) = pos;
26        n < self.shape
27    }
28
29    fn shape(&self) -> Self::Coordinates {
30        (1, 1, self.shape)
31    }
32
33    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
34        (self.to_source_pos(pos), self.is_in_bounds(pos))
35    }
36}
37
38impl ViewLayoutLaunchArg for BiasLayout {
39    type RuntimeArg<R: Runtime> = ();
40    type CompilationArg = ();
41
42    fn register<R: Runtime, B: BufferArg>(
43        _: Self::RuntimeArg<R>,
44        buffer: &B,
45        _: Type,
46        launcher: &mut KernelLauncher<R>,
47    ) {
48        let shape = buffer.len();
49        <u32 as LaunchArg>::register(shape as u32, launcher);
50    }
51
52    fn expand(
53        _: &Self::CompilationArg,
54        ty: Type,
55        builder: &mut KernelBuilder,
56    ) -> <Self as CubeType>::ExpandType {
57        BiasLayoutExpand {
58            shape: <u32 as LaunchArg>::expand(&(), builder),
59            vector_size: ty.vector_size() as u32,
60        }
61    }
62}