Skip to main content

cubek_convolution/components/stage/reader.rs

1use cubecl;
2use cubecl::{prelude::*, std::tensor::layout::Coords2d};
3use cubek_matmul::components::stage::TilingValidation;
4use cubek_std::{
5    stage::StageMemoryConfig,
6    tile::StridedTile,
7    {InvalidConfigError, MatrixLayout},
8};
9
10use crate::components::stage::bias_stage::BiasStageMemory;
11
/// Tiling layout dedicated to the bias stage.
///
/// The bias is one-dimensional: tiles are selected by their column index
/// alone and are exposed with a stride of 0.
#[derive(Copy, Clone)]
pub struct BiasTilingLayout {}
15
16#[cube]
17impl BiasTilingLayout {
18    pub fn get_tile<ES: Numeric, NS: Size>(
19        stage: &BiasStageMemory<ES, NS>,
20        tile: Coords2d,
21        #[comptime] config: StageMemoryConfig,
22    ) -> StridedTile<ES, NS> {
23        if config.num_stages > 1 {
24            unimplemented!()
25        }
26
27        let (_, col) = tile;
28
29        let stage_vector_size = config.vector_size;
30        let tile_size_col = config.elements_per_tile_along_col / stage_vector_size;
31
32        let length = tile_size_col;
33        let start = col * tile_size_col;
34
35        StridedTile::new_strided(
36            stage.as_slice(),
37            start,
38            start + length,
39            0,
40            stage.swizzle,
41            MatrixLayout::RowMajor,
42        )
43    }
44}
45
46impl TilingValidation for BiasTilingLayout {
47    fn check(config: StageMemoryConfig) -> Result<(), InvalidConfigError> {
48        let stage_width = config.elements_per_stage_along_col();
49        if config.vector_size > stage_width {
50            return Err(Box::new(format!(
51                "Invalid vector size. Got {:?} which should not be >{:?}",
52                config.vector_size, stage_width,
53            )));
54        }
55        Ok(())
56    }
57}