cubek_convolution/components/stage/
reader.rs1use cubecl;
2use cubecl::prelude::*;
3use cubecl::std::tensor::layout::Coords2d;
4use cubek_matmul::{
5 components::{
6 stage::{StageMemoryConfig, TilingValidation},
7 tile::StridedTile,
8 },
9 definition::{InvalidConfigError, MatrixLayout},
10};
11
12use crate::components::stage::bias_stage::BiasStageMemory;
13
/// Marker type describing how bias tiles are located inside a
/// [`BiasStageMemory`].
///
/// The type carries no data; all behaviour lives in its associated
/// functions and in its `TilingValidation` implementation.
#[derive(Clone, Copy)]
pub struct BiasTilingLayout {}
17
18#[cube]
19impl BiasTilingLayout {
20 pub fn get_tile<ES: Numeric>(
21 stage: &BiasStageMemory<ES>,
22 tile: Coords2d,
23 #[comptime] config: StageMemoryConfig,
24 ) -> StridedTile<ES> {
25 if config.num_stages > 1 {
26 unimplemented!()
27 }
28
29 let (_, col) = tile;
30
31 let stage_line_size = config.line_size;
32 let tile_size_col = config.elements_per_tile_along_col / stage_line_size;
33
34 let length = tile_size_col;
35 let start = col * tile_size_col;
36
37 StridedTile::new_strided(
38 stage.as_slice(stage_line_size as usize),
39 start,
40 start + length,
41 0,
42 stage.swizzle,
43 MatrixLayout::RowMajor,
44 stage_line_size,
45 )
46 }
47}
48
49impl TilingValidation for BiasTilingLayout {
50 fn check(config: StageMemoryConfig) -> Result<(), InvalidConfigError> {
51 let stage_width = config.elements_per_stage_along_col();
52 if config.line_size > stage_width {
53 return Err(Box::new(format!(
54 "Invalid line size. Got {:?} which should not be >{:?}",
55 config.line_size, stage_width,
56 )));
57 }
58 Ok(())
59 }
60}