cubek_convolution/components/global/layout/
bias.rs1use cubecl::{
2 std::tensor::{launch::BufferArg, layout::*},
3 {prelude::*, std::tensor::launch::ViewLayoutLaunchArg},
4};
5use cubek_matmul::launch::BatchedCoords;
6
7#[derive(CubeType)]
8pub struct BiasLayout {
9 shape: u32,
10 #[cube(comptime)]
11 vector_size: u32,
12}
13
14#[cube]
15impl Layout for BiasLayout {
16 type Coordinates = BatchedCoords;
17 type SourceCoordinates = Coords1d;
18
19 fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
20 let (_, _, n) = pos;
21 (n / self.vector_size) as usize
22 }
23
24 fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
25 let (_, _, n) = pos;
26 n < self.shape
27 }
28
29 fn shape(&self) -> Self::Coordinates {
30 (1, 1, self.shape)
31 }
32
33 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
34 (self.to_source_pos(pos), self.is_in_bounds(pos))
35 }
36}
37
38impl ViewLayoutLaunchArg for BiasLayout {
39 type RuntimeArg<R: Runtime> = ();
40 type CompilationArg = ();
41
42 fn register<R: Runtime, B: BufferArg>(
43 _: Self::RuntimeArg<R>,
44 buffer: &B,
45 _: Type,
46 launcher: &mut KernelLauncher<R>,
47 ) {
48 let shape = buffer.len();
49 <u32 as LaunchArg>::register(shape as u32, launcher);
50 }
51
52 fn expand(
53 _: &Self::CompilationArg,
54 ty: Type,
55 builder: &mut KernelBuilder,
56 ) -> <Self as CubeType>::ExpandType {
57 BiasLayoutExpand {
58 shape: <u32 as LaunchArg>::expand(&(), builder),
59 vector_size: ty.vector_size() as u32,
60 }
61 }
62}