// cubecl_convolution/components/stage/reader.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use cubecl_matmul::components::{
4    InvalidConfigError, MatrixLayout,
5    stage::{
6        StageMemoryConfig, StridedStageMemory, TilingLayout, TilingLayoutEnum, TilingValidation,
7    },
8    tile::StridedTile,
9};
10use cubecl_std::tensor::layout::Coords2d;
11
/// Tiling layout specific to the bias stage.
///
/// Bias is one-dimensional, so every tile is exposed with a stride of 0.
#[derive(Clone, Copy)]
pub struct BiasTilingLayout {}
15
16#[cube]
17impl TilingLayout for BiasTilingLayout {
18    fn get_tile<ES: Numeric>(
19        stage: &StridedStageMemory<ES, Self>,
20        tile: Coords2d,
21        #[comptime] config: StageMemoryConfig,
22    ) -> StridedTile<ES> {
23        if comptime!(config.num_stages > 1) {
24            unimplemented!()
25        }
26
27        let (_, col) = tile;
28
29        let stage_line_size = config.line_size;
30        let tile_size_col = config.elements_per_tile_along_col / stage_line_size;
31
32        let length = tile_size_col;
33        let start = col * tile_size_col;
34
35        StridedTile::new_strided(
36            stage.as_slice(stage_line_size),
37            start,
38            start + length,
39            0,
40            stage.swizzle,
41            MatrixLayout::RowMajor,
42            stage_line_size,
43        )
44    }
45
46    fn to_enum() -> comptime_type!(TilingLayoutEnum) {
47        comptime![TilingLayoutEnum::Other]
48    }
49}
50
51impl TilingValidation for BiasTilingLayout {
52    fn check(config: StageMemoryConfig) -> Result<(), InvalidConfigError> {
53        let stage_width = config.elements_per_stage_along_col();
54        if config.line_size > stage_width {
55            return Err(Box::new("Invalid line size"));
56        }
57        Ok(())
58    }
59}