cubek_convolution/components/stage/
reader.rs

1use cubecl;
2use cubecl::prelude::*;
3use cubecl::std::tensor::layout::Coords2d;
4use cubek_matmul::{
5    components::{
6        stage::{
7            StageMemoryConfig, StridedStageMemory, TilingLayout, TilingLayoutEnum, TilingValidation,
8        },
9        tile::StridedTile,
10    },
11    definition::{InvalidConfigError, MatrixLayout},
12};
13
/// Tiling layout dedicated to bias stages.
///
/// Bias data is one-dimensional: tiles are addressed by their column
/// coordinate only and are exposed with a stride of 0.
#[derive(Clone, Copy)]
pub struct BiasTilingLayout {}
17
18#[cube]
19impl TilingLayout for BiasTilingLayout {
20    fn get_tile<ES: Numeric>(
21        stage: &StridedStageMemory<ES, Self>,
22        tile: Coords2d,
23        #[comptime] config: StageMemoryConfig,
24    ) -> StridedTile<ES> {
25        if config.num_stages > 1 {
26            unimplemented!()
27        }
28
29        let (_, col) = tile;
30
31        let stage_line_size = config.line_size;
32        let tile_size_col = config.elements_per_tile_along_col / stage_line_size;
33
34        let length = tile_size_col;
35        let start = col * tile_size_col;
36
37        StridedTile::new_strided(
38            stage.as_slice(stage_line_size as usize),
39            start,
40            start + length,
41            0,
42            stage.swizzle,
43            MatrixLayout::RowMajor,
44            stage_line_size,
45        )
46    }
47
48    fn to_enum() -> comptime_type!(TilingLayoutEnum) {
49        TilingLayoutEnum::Other
50    }
51}
52
53impl TilingValidation for BiasTilingLayout {
54    fn check(config: StageMemoryConfig) -> Result<(), InvalidConfigError> {
55        let stage_width = config.elements_per_stage_along_col();
56        if config.line_size > stage_width {
57            return Err(Box::new(format!(
58                "Invalid line size. Got {:?} which should not be >{:?}",
59                config.line_size, stage_width,
60            )));
61        }
62        Ok(())
63    }
64}