cubek_convolution/components/stage/
reader.rs1use cubecl;
2use cubecl::prelude::*;
3use cubecl::std::tensor::layout::Coords2d;
4use cubek_matmul::{
5 components::{
6 stage::{
7 StageMemoryConfig, StridedStageMemory, TilingLayout, TilingLayoutEnum, TilingValidation,
8 },
9 tile::StridedTile,
10 },
11 definition::{InvalidConfigError, MatrixLayout},
12};
13
/// Tiling layout used for bias stage memory.
///
/// Carries no data; it only selects the tile-addressing logic implemented
/// by its `TilingLayout` / `TilingValidation` impls.
#[derive(Copy, Clone)]
pub struct BiasTilingLayout {}
17
18#[cube]
19impl TilingLayout for BiasTilingLayout {
20 fn get_tile<ES: Numeric>(
21 stage: &StridedStageMemory<ES, Self>,
22 tile: Coords2d,
23 #[comptime] config: StageMemoryConfig,
24 ) -> StridedTile<ES> {
25 if config.num_stages > 1 {
26 unimplemented!()
27 }
28
29 let (_, col) = tile;
30
31 let stage_line_size = config.line_size;
32 let tile_size_col = config.elements_per_tile_along_col / stage_line_size;
33
34 let length = tile_size_col;
35 let start = col * tile_size_col;
36
37 StridedTile::new_strided(
38 stage.as_slice(stage_line_size as usize),
39 start,
40 start + length,
41 0,
42 stage.swizzle,
43 MatrixLayout::RowMajor,
44 stage_line_size,
45 )
46 }
47
48 fn to_enum() -> comptime_type!(TilingLayoutEnum) {
49 TilingLayoutEnum::Other
50 }
51}
52
53impl TilingValidation for BiasTilingLayout {
54 fn check(config: StageMemoryConfig) -> Result<(), InvalidConfigError> {
55 let stage_width = config.elements_per_stage_along_col();
56 if config.line_size > stage_width {
57 return Err(Box::new(format!(
58 "Invalid line size. Got {:?} which should not be >{:?}",
59 config.line_size, stage_width,
60 )));
61 }
62 Ok(())
63 }
64}