cubecl_linalg/convolution/loader/bias.rs

use cubecl_core as cubecl;
use cubecl_core::prelude::*;
use cubecl_std::{CubeOption, CubeOptionExpand, tensor::r#virtual::VirtualTensor};

use crate::{
    convolution::homogeneous::simple::ConvTilingLayout,
    matmul::components::{
        Ident,
        global::{AccumulatorLoader, GlobalConfig},
        stage::{Stage, StageConfig},
        tile::{Tile, TileConfig, TileMatmul},
    },
};
use crate::{convolution::reader::bias::BiasReader, matmul::components::MatmulPrecision};

/// Special loader to broadcast the 1D bias to the 2D accumulator matrix
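/// `Some` holds a view over the bias tensor plus a shared-memory stage to copy it into;
/// `None` means no bias was provided, in which case the accumulator is simply zero-filled.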
#[derive(CubeType)]
pub enum BiasLoader<MP: MatmulPrecision> {
    Some {
        tensor_view: BiasReader<MP::EO>,
        stage: Stage<MP::EA, ConvTilingLayout>,
    },
    None,
}

#[cube]
impl<MP: MatmulPrecision> AccumulatorLoader<MP> for BiasLoader<MP> {
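    /// Cooperatively copies the 1D bias into the shared-memory stage: each unit reads one
    /// `line_size`-wide chunk from global memory and writes it at its own offset.
    /// Does nothing for the `None` variant.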
    fn fill_stage<G: GlobalConfig>(this: &mut Self, #[comptime] config: G) {
        match this {
            BiasLoader::Some { tensor_view, stage } => {
                let stage_tiling = config.tiling_dimensions(Ident::Rhs);
                let line_size = config.global_line_size(Ident::Out);

                let num_stage_elements = stage_tiling.total_col();

                let unit_id = UNIT_POS_Y * config.plane_dim() + UNIT_POS_X;
                let unit_position_base = unit_id * line_size;

                let mut slice = stage.as_slice_mut(line_size);

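                // Only units whose offset falls inside the stage load a bias line;
                // the remaining units do nothing.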
                if unit_position_base < num_stage_elements {
                    let read_line = tensor_view.load_simple::<G>(unit_position_base, line_size);
                    slice[unit_id] = Line::cast_from(read_line);
                }
            }
            BiasLoader::None => {}
        }
    }

    /// Loads the staged bias into accumulator tile `tile_n`, or zeroes the accumulator when no bias is present.
    fn load<TMM: TileMatmul<MP>>(
        this: &mut Self,
        acc: &mut TMM::Accumulator,
        tile_n: u32,
        #[comptime] config: TMM::Config,
    ) {
        match this {
            BiasLoader::Some { stage, .. } => {
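                // Take the `tile_n`-th chunk of the staged bias and wrap it in a tile with
                // stride 0, so each row of the accumulator tile sees the same bias values
                // (the 1D-to-2D broadcast mentioned in the type-level docs).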
                let line_size = config.stage_line_size(Ident::Out);
                let tile_elems = config.tile_shape().n / line_size;
                let start = tile_n * tile_elems;
                let slice = stage
                    .as_slice_mut(line_size)
                    .slice(start, start + tile_elems);
                let tile = Tile::new_strided(slice, 0);
                TMM::fill_accumulator(&tile, acc, config);
            }
            BiasLoader::None => {
                TMM::zero_accumulator(acc, config);
            }
        }
    }
}

#[cube]
impl<MP: MatmulPrecision> BiasLoader<MP> {
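    /// Creates the loader: `Some` with a freshly initialized shared-memory stage when a
    /// bias tensor is provided, `None` otherwise.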
    pub fn new<G: GlobalConfig>(
        tensor: CubeOption<VirtualTensor<MP::EO>>,
        n_offset: u32,
        #[comptime] config: G,
    ) -> Self {
        match tensor {
            CubeOption::Some(tensor) => {
                let stage = init_stage::<MP::EA, G>(config);
                let shape_n = tensor.shape(0);
                let tensor_view = BiasReader::<MP::EO>::new(tensor, n_offset, shape_n);

                BiasLoader::<MP>::new_Some(tensor_view, stage)
            }
            CubeOption::None => BiasLoader::new_None(),
        }
    }
}
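
// Illustrative call sequence (a sketch, not part of the original file): a global
// convolution kernel constructs the loader once, stages the bias, then fills each
// accumulator tile. `bias`, `conv_config`, `acc` and `tile_config` are placeholder
// names; the real kernel drives this through the `AccumulatorLoader` trait.
//
//     let mut loader = BiasLoader::<MP>::new::<G>(bias, n_offset, conv_config);
//     BiasLoader::fill_stage::<G>(&mut loader, conv_config);
//     // (cube-wide synchronization before the stage is read back)
//     BiasLoader::load::<TMM>(&mut loader, &mut acc, tile_n, tile_config);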
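/// Allocates the shared-memory stage for the bias: one element per column of the output
/// stage tiling, stored as lines of width `line_size`.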
#[cube]
fn init_stage<ES: Numeric, G: GlobalConfig>(#[comptime] config: G) -> Stage<ES, ConvTilingLayout> {
    let line_size = config.to_smm_config().stage_line_size(Ident::Out);

    let smem = SharedMemory::new_lined(
        comptime!(config.tiling_dimensions(Ident::Out).total_col() / line_size),
        line_size,
    );

    Stage::<ES, ConvTilingLayout>::new_with_smem(smem)
}

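/// Creates a placeholder stage backed by a single-element shared memory, for code paths
/// that never need to stage bias data.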
#[cube]
fn init_empty_stage<ES: Numeric>() -> Stage<ES, ConvTilingLayout> {
    Stage::<ES, ConvTilingLayout>::new_with_smem(SharedMemory::new(1))
}