cubek_convolution/components/stage/
reader.rs1use cubecl;
2use cubecl::{prelude::*, std::tensor::layout::Coords2d};
3use cubek_matmul::components::stage::TilingValidation;
4use cubek_std::{
5 stage::StageMemoryConfig,
6 tile::StridedTile,
7 {InvalidConfigError, MatrixLayout},
8};
9
10use crate::components::stage::bias_stage::BiasStageMemory;
11
/// Zero-sized marker type selecting the tiling layout used for the bias stage.
///
/// It carries no data; its behavior lives in the associated `get_tile`
/// function and in its `TilingValidation` implementation.
#[derive(Clone, Copy)]
pub struct BiasTilingLayout;
15
16#[cube]
17impl BiasTilingLayout {
18 pub fn get_tile<ES: Numeric, NS: Size>(
19 stage: &BiasStageMemory<ES, NS>,
20 tile: Coords2d,
21 #[comptime] config: StageMemoryConfig,
22 ) -> StridedTile<ES, NS> {
23 if config.num_stages > 1 {
24 unimplemented!()
25 }
26
27 let (_, col) = tile;
28
29 let stage_vector_size = config.vector_size;
30 let tile_size_col = config.elements_per_tile_along_col / stage_vector_size;
31
32 let length = tile_size_col;
33 let start = col * tile_size_col;
34
35 StridedTile::new_strided(
36 stage.as_slice(),
37 start,
38 start + length,
39 0,
40 stage.swizzle,
41 MatrixLayout::RowMajor,
42 )
43 }
44}
45
46impl TilingValidation for BiasTilingLayout {
47 fn check(config: StageMemoryConfig) -> Result<(), InvalidConfigError> {
48 let stage_width = config.elements_per_stage_along_col();
49 if config.vector_size > stage_width {
50 return Err(Box::new(format!(
51 "Invalid vector size. Got {:?} which should not be >{:?}",
52 config.vector_size, stage_width,
53 )));
54 }
55 Ok(())
56 }
57}