cubecl_convolution/components/stage/
reader.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use cubecl_matmul::components::{
4    InvalidConfigError, MatrixLayout,
5    global::memory::GlobalMemoryConfig,
6    stage::{StageMemoryConfig, StridedStage, TilingLayout, TilingValidation},
7    tile::StridedTile,
8};
9use cubecl_std::tensor::layout::Coords2d;
10
/// Tiling layout dedicated to the bias operand.
///
/// Bias is one-dimensional, so the tiles this layout produces use a stride of 0.
#[derive(Copy, Clone)]
pub struct BiasTilingLayout {}
#[cube]
impl TilingLayout for BiasTilingLayout {
    // Returns the tile addressed by `tile = (row, col)` from a bias stage.
    //
    // Only the column coordinate is used. The stage is viewed as one contiguous run
    // of lines of `config.stage_line_size` elements. Tile `col` is the `col`-th chunk
    // of `config.elements_in_tile_col / config.stage_line_size` lines of that run.
    // The resulting tile is row-major with stride 0. Per the struct docs, bias is
    // one-dimensional, so presumably every row of the tile aliases the same values.
    // TODO confirm against `StridedTile::new_strided`.
    //
    // Stage configs with `num_stages > 1` are rejected via `unimplemented!()`.
    fn get_tile<ES: Numeric>(
        stage: &StridedStage<ES, Self>,
        tile: Coords2d,
        #[comptime] config: StageMemoryConfig,
    ) -> StridedTile<ES> {
        // Multi-stage bias buffers are not supported; this is a `comptime!`
        // condition on the stage config.
        if comptime!(config.num_stages > 1) {
            unimplemented!()
        }

        // The row coordinate is ignored: bias is one-dimensional.
        let (_, col) = tile;

        // Tile width measured in lines of `stage_line_size` elements rather than
        // in individual elements.
        let stage_line_size = config.stage_line_size;
        let tile_size_col = config.elements_in_tile_col / stage_line_size;

        // Tile `col` starts `col` whole tile widths into the stage and spans one
        // tile width.
        let length = tile_size_col;
        let start = col * tile_size_col;

        StridedTile::new_strided(
            stage.as_slice(stage_line_size).slice(start, start + length),
            // Stride 0: bias has no second dimension to step through.
            0,
            MatrixLayout::RowMajor,
        )
    }
}
41
42impl TilingValidation for BiasTilingLayout {
43    fn check(config: GlobalMemoryConfig) -> Result<(), InvalidConfigError> {
44        let stage_width = config.elements_in_stage_col;
45        if config.global_line_size > stage_width {
46            return Err(Box::new("Invalid line size"));
47        }
48        Ok(())
49    }
50}