Skip to main content

cubek_convolution/components/stage/reader.rs

1use cubecl;
2use cubecl::prelude::*;
3use cubecl::std::tensor::layout::Coords2d;
4use cubek_matmul::{
5    components::{
6        stage::{StageMemoryConfig, TilingValidation},
7        tile::StridedTile,
8    },
9    definition::{InvalidConfigError, MatrixLayout},
10};
11
12use crate::components::stage::bias_stage::BiasStageMemory;
13
/// Tiling layout for the bias stage.
///
/// Bias is one-dimensional, so tiles are addressed by their column index only
/// and are read with a stride of 0.
#[derive(Clone, Copy)]
pub struct BiasTilingLayout {}
17
18#[cube]
19impl BiasTilingLayout {
20    pub fn get_tile<ES: Numeric>(
21        stage: &BiasStageMemory<ES>,
22        tile: Coords2d,
23        #[comptime] config: StageMemoryConfig,
24    ) -> StridedTile<ES> {
25        if config.num_stages > 1 {
26            unimplemented!()
27        }
28
29        let (_, col) = tile;
30
31        let stage_line_size = config.line_size;
32        let tile_size_col = config.elements_per_tile_along_col / stage_line_size;
33
34        let length = tile_size_col;
35        let start = col * tile_size_col;
36
37        StridedTile::new_strided(
38            stage.as_slice(stage_line_size as usize),
39            start,
40            start + length,
41            0,
42            stage.swizzle,
43            MatrixLayout::RowMajor,
44            stage_line_size,
45        )
46    }
47}
48
49impl TilingValidation for BiasTilingLayout {
50    fn check(config: StageMemoryConfig) -> Result<(), InvalidConfigError> {
51        let stage_width = config.elements_per_stage_along_col();
52        if config.line_size > stage_width {
53            return Err(Box::new(format!(
54                "Invalid line size. Got {:?} which should not be >{:?}",
55                config.line_size, stage_width,
56            )));
57        }
58        Ok(())
59    }
60}