cubecl_convolution/components/stage/
reader.rs1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use cubecl_matmul::components::{
4 InvalidConfigError, MatrixLayout,
5 stage::{
6 StageMemoryConfig, StridedStageMemory, TilingLayout, TilingLayoutEnum, TilingValidation,
7 },
8 tile::StridedTile,
9};
10use cubecl_std::tensor::layout::Coords2d;
11
/// Data-less marker type selecting the bias tiling layout.
///
/// It is used as the layout parameter of `StridedStageMemory` and carries no
/// fields of its own, so copying it is free.
#[derive(Copy, Clone)]
pub struct BiasTilingLayout {}

16#[cube]
17impl TilingLayout for BiasTilingLayout {
18 fn get_tile<ES: Numeric>(
19 stage: &StridedStageMemory<ES, Self>,
20 tile: Coords2d,
21 #[comptime] config: StageMemoryConfig,
22 ) -> StridedTile<ES> {
23 if comptime!(config.num_stages > 1) {
24 unimplemented!()
25 }
26
27 let (_, col) = tile;
28
29 let stage_line_size = config.line_size;
30 let tile_size_col = config.elements_per_tile_along_col / stage_line_size;
31
32 let length = tile_size_col;
33 let start = col * tile_size_col;
34
35 StridedTile::new_strided(
36 stage.as_slice(stage_line_size),
37 start,
38 start + length,
39 0,
40 stage.swizzle,
41 MatrixLayout::RowMajor,
42 stage_line_size,
43 )
44 }
45
46 fn to_enum() -> comptime_type!(TilingLayoutEnum) {
47 comptime![TilingLayoutEnum::Other]
48 }
49}
50
51impl TilingValidation for BiasTilingLayout {
52 fn check(config: StageMemoryConfig) -> Result<(), InvalidConfigError> {
53 let stage_width = config.elements_per_stage_along_col();
54 if config.line_size > stage_width {
55 return Err(Box::new("Invalid line size"));
56 }
57 Ok(())
58 }
59}