// cubecl_convolution/components/stage/reader.rs
use cubecl::prelude::*;
use cubecl_core as cubecl;
use cubecl_matmul::components::{
    InvalidConfigError, MatrixLayout, StageIdent,
    global::memory::GlobalMemoryConfig,
    stage::{StageMemoryConfig, StridedStage, TilingLayout, TilingValidation},
    tile::StridedTile,
};
use cubecl_std::tensor::layout::Coords2d;

/// Tiling layout for a bias stage.
///
/// This is a zero-sized marker type: it holds no data, and all of its
/// behavior is provided through its `TilingLayout` and `TilingValidation`
/// implementations.
#[derive(Clone, Copy, Debug)]
pub struct BiasTilingLayout {}

15#[cube]
16impl TilingLayout for BiasTilingLayout {
17 fn get_tile<ES: Numeric>(
18 stage: &StridedStage<ES, Self>,
19 tile: Coords2d,
20 _buffer_index: u32,
21 #[comptime] _ident: StageIdent,
22 #[comptime] config: StageMemoryConfig,
23 ) -> StridedTile<ES> {
24 if comptime!(config.num_stages > 1) {
25 unimplemented!()
26 }
27
28 let (_, col) = tile;
29
30 let stage_line_size = config.stage_line_size;
31 let tile_size_col = config.elements_in_tile_col / stage_line_size;
32
33 let length = tile_size_col;
34 let start = col * tile_size_col;
35
36 StridedTile::new_strided(
37 stage.as_slice(stage_line_size).slice(start, start + length),
38 0,
39 MatrixLayout::RowMajor,
40 )
41 }
42}
43
44impl TilingValidation for BiasTilingLayout {
45 fn check(config: GlobalMemoryConfig) -> Result<(), InvalidConfigError> {
46 let stage_width = config.elements_in_stage_col;
47 if config.global_line_size > stage_width {
48 return Err(Box::new("Invalid line size"));
49 }
50 Ok(())
51 }
52}