cubecl_convolution/components/global/read/reader/bias.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::{
4    CubeOption, CubeOptionExpand,
5    tensor::{View, layout::Coords2d},
6};
7
8use cubecl_matmul::components::{
9    MatrixPrecision, StageIdent,
10    global::GlobalConfig,
11    stage::{StageMemoryConfig, StridedStage},
12};
13
14use crate::components::stage::reader::BiasTilingLayout;
15
/// Special reader to broadcast the 1D bias to the 2D accumulator matrix.
///
/// `Some` holds the bias view together with the stage it is copied into; `None` is produced by
/// `new` when no bias view is provided, in which case `load_stage` does nothing and `stage()`
/// returns `CubeOption::None`.
#[derive(CubeType)]
pub enum BiasGlobalReader<IP: MatrixPrecision> {
    /// A bias view is present.
    Some {
        /// View over the bias (element type `IP::Global`); `load_stage` reads it at row `0`.
        view: View<Line<IP::Global>, Coords2d>,
        /// Stage (element type `IP::Stage`) that `load_stage` fills with the bias values.
        stage: StridedStage<IP::Stage, BiasTilingLayout>,
    },
    /// No bias view was provided.
    None,
}
25
26/// Type of the stage reader for the bias reader
27pub type BiasStage<E> = CubeOption<StridedStage<E, BiasTilingLayout>>;
28
29#[cube]
30impl<IP: MatrixPrecision> BiasGlobalReader<IP> {
31    /// Reads all bias tiles into the stage. Unlike normal readers, bias only reads a 1D vector along
32    /// the `n` dimension.
33    pub fn load_stage<G: GlobalConfig>(&mut self, #[comptime] config: G) {
34        match self {
35            BiasGlobalReader::Some { view, stage } => {
36                let line_size = view.line_size();
37                let num_stage_elements = config.tiling_scheme().elements_in_stage_n();
38
39                let unit_id = UNIT_POS_Y * config.plane_dim() + UNIT_POS_X;
40                let unit_pos = unit_id * line_size;
41
42                let mut slice = stage.as_slice_mut(line_size);
43
44                if unit_pos < num_stage_elements {
45                    let read_line = view.read_checked((0, unit_pos));
46                    slice[unit_id] = Line::cast_from(read_line);
47                }
48            }
49            BiasGlobalReader::None => {}
50        }
51    }
52
53    /// Return the stage contained in this global reader. It will use custom tiling with
54    /// a stride of `0`.
55    pub fn stage(&self) -> BiasStage<IP::Stage> {
56        match self {
57            BiasGlobalReader::Some { stage, .. } => CubeOption::new_Some(*stage),
58            BiasGlobalReader::None => CubeOption::new_None(),
59        }
60    }
61}
62
63#[cube]
64impl<IP: MatrixPrecision> BiasGlobalReader<IP> {
65    /// Create a new bias reader from the bias tensor and a global offset `n_offset`.
66    pub fn new(
67        view: CubeOption<View<Line<IP::Global>, Coords2d>>,
68        #[comptime] config: StageMemoryConfig,
69    ) -> Self {
70        match view {
71            CubeOption::Some(view) => {
72                let stage = init_stage::<IP::Stage>(config);
73
74                BiasGlobalReader::<IP>::new_Some(view, stage)
75            }
76            CubeOption::None => BiasGlobalReader::new_None(),
77        }
78    }
79}
80
81/// Create a new 1D bias stage of size `stage_size_n`.
82#[cube]
83fn init_stage<ES: Numeric>(
84    #[comptime] config: StageMemoryConfig,
85) -> StridedStage<ES, BiasTilingLayout> {
86    let line_size = config.stage_line_size;
87
88    let smem = SharedMemory::new_lined(
89        comptime!(config.elements_in_stage_col() / line_size),
90        line_size,
91    );
92
93    StridedStage::<ES, BiasTilingLayout>::new_with_smem(smem, StageIdent::Acc, config)
94}