//! cubecl_convolution/loader/bias.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::{CubeOption, CubeOptionExpand, tensor::r#virtual::VirtualTensor};
4
5use crate::homogeneous::simple::ConvTilingLayout;
6use crate::reader::bias::BiasReader;
7use cubecl_matmul::components::{
8    Ident, MatmulPrecision,
9    global::{AccumulatorLoader, GlobalConfig},
10    stage::{StageConfig, StageMemory},
11    tile::{Tile, TileConfig, TileMatmul},
12};
13
/// Special loader to broadcast the 1D bias to the 2D accumulator matrix
#[derive(CubeType)]
pub enum BiasLoader<MP: MatmulPrecision> {
    /// A bias tensor is present: it is read through `tensor_view` and staged
    /// into shared memory (`stage`, in the accumulator precision `MP::EA`)
    /// before being loaded into the accumulator.
    Some {
        // Reader over the 1D bias tensor in global memory (element type `MP::EO`).
        tensor_view: BiasReader<MP::EO>,
        // Stage memory holding the bias converted to accumulator precision.
        stage: StageMemory<MP::EA, ConvTilingLayout>,
    },
    /// No bias tensor: the accumulator is zero-initialized instead.
    None,
}
23
#[cube]
impl<MP: MatmulPrecision> AccumulatorLoader<MP> for BiasLoader<MP> {
    /// Cooperatively copy the 1D bias from global memory into the stage.
    ///
    /// Each unit loads at most one line of `line_size` elements, converting it
    /// from the output precision `MP::EO` to the accumulator precision
    /// `MP::EA`. A `None` loader does nothing.
    fn fill_stage<G: GlobalConfig>(this: &mut Self, #[comptime] config: G) {
        match this {
            BiasLoader::Some { tensor_view, stage } => {
                // Vectorization width of global reads for the output ident.
                let line_size = config.global_line_size(Ident::Out);
                // Total number of bias elements the stage holds (stage width in n).
                let num_stage_elements = config.tiling_scheme().elements_in_stage_n();

                // Flat unit index within the cube: plane index * plane width + lane.
                let unit_id = UNIT_POS_Y * config.plane_dim() + UNIT_POS_X;
                let unit_position_base = unit_id * line_size;

                let mut slice = stage.as_slice_mut(line_size);

                // Units whose line starts past the stage width do not load anything
                // (more units than lines may be present).
                if unit_position_base < num_stage_elements {
                    let read_line = tensor_view.load_simple::<G>(unit_position_base, line_size);
                    // Cast EO -> EA while writing the line into the stage.
                    slice[unit_id] = Line::cast_from(read_line);
                }
            }
            BiasLoader::None => {}
        }
    }

    /// Load accumulator
    ///
    /// Fills `acc` for tile `tile_n` from the staged bias (`Some`), or zeroes
    /// it when there is no bias (`None`).
    fn load<TMM: TileMatmul<MP>>(
        this: &mut Self,
        acc: &mut TMM::Accumulator,
        tile_n: u32,
        #[comptime] config: TMM::Config,
    ) {
        match this {
            BiasLoader::Some { stage, .. } => {
                let line_size = config.stage_line_size(Ident::Out);
                // Lines per tile along n.
                let tile_elems = config.tile_size().n() / line_size;
                let start = tile_n * tile_elems;
                // Sub-slice of the stage covering this tile's bias lines.
                let slice = stage
                    .as_slice_mut(line_size)
                    .slice(start, start + tile_elems);
                // Stride 0: every row of the tile reads the same bias line,
                // broadcasting the 1D bias across the 2D accumulator.
                let tile = Tile::new_strided(slice, 0, config.matrix_layout(Ident::Out));
                TMM::fill_accumulator(&tile, acc, config);
            }
            BiasLoader::None => {
                TMM::zero_accumulator(acc, config);
            }
        }
    }
}
70
#[cube]
impl<MP: MatmulPrecision> BiasLoader<MP> {
    /// Create a `BiasLoader` from an optional bias tensor.
    ///
    /// When a tensor is provided, allocates the shared-memory stage and wraps
    /// the tensor in a [`BiasReader`] offset by `n_offset`; otherwise returns
    /// the `None` variant (accumulators will be zero-filled).
    pub fn new<G: GlobalConfig>(
        tensor: CubeOption<VirtualTensor<MP::EO>>,
        n_offset: u32,
        #[comptime] config: G,
    ) -> Self {
        match tensor {
            CubeOption::Some(tensor) => {
                let stage = init_stage::<MP::EA, G>(config);
                // The bias is 1D; its only dimension is n.
                let shape_n = tensor.shape(0);
                let tensor_view = BiasReader::<MP::EO>::new(tensor, n_offset, shape_n);

                // `new_Some` / `new_None` are variant constructors generated
                // by `#[derive(CubeType)]`.
                BiasLoader::<MP>::new_Some(tensor_view, stage)
            }
            CubeOption::None => BiasLoader::new_None(),
        }
    }
}
90
#[cube]
/// Allocate the shared-memory stage holding one row of `elements_in_stage_n()`
/// bias elements, vectorized by the stage line size for `Ident::Out`.
fn init_stage<ES: Numeric, G: GlobalConfig>(
    #[comptime] config: G,
) -> StageMemory<ES, ConvTilingLayout> {
    let line_size = config.stage_config().stage_line_size(Ident::Out);

    // Number of lines = total stage elements / elements per line (comptime).
    let smem = SharedMemory::new_lined(
        comptime!(config.tiling_scheme().elements_in_stage_n() / line_size),
        line_size,
    );

    // Second argument is 1 — presumably a single stage/buffer; confirm
    // against the `StageMemory::new_with_smem` API.
    StageMemory::<ES, ConvTilingLayout>::new_with_smem(smem, 1u32)
}
104
#[cube]
/// Build a placeholder stage backed by a 1-element shared memory allocation.
///
/// NOTE(review): not referenced anywhere in this file — verify it is used by
/// callers elsewhere (or by the `None` path) before relying on or removing it.
fn init_empty_stage<ES: Numeric>() -> StageMemory<ES, ConvTilingLayout> {
    StageMemory::<ES, ConvTilingLayout>::new_with_smem(SharedMemory::new(1), 1u32)
}