cubecl_convolution/components/global/read/reader/
bias.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::{
4 CubeOption, CubeOptionExpand,
5 tensor::{View, layout::Coords2d},
6};
7
8use cubecl_matmul::components::{
9 MatrixPrecision,
10 global::GlobalConfig,
11 stage::{StageConfig as _, StageMemoryConfig, StridedStageMemory},
12};
13
14use crate::components::stage::reader::BiasTilingLayout;
15
/// Reader that copies an optional bias from a global view into stage memory.
///
/// `Some` pairs the source `view` (lines of `IP::Global`) with the destination `stage`
/// (elements of `IP::Stage`). `None` means there is no bias: `load_stage` then does nothing
/// and `stage()` returns `CubeOption::None`.
#[derive(CubeType)]
pub enum BiasGlobalReader<IP: MatrixPrecision> {
    Some {
        view: View<Line<IP::Global>, Coords2d>,
        stage: StridedStageMemory<IP::Stage, BiasTilingLayout>,
    },
    None,
}
25
/// Optional bias stage, as returned by `BiasGlobalReader::stage`: `Some` holds the stage
/// memory the bias was loaded into, `None` means there is no bias.
pub type BiasStage<E> = CubeOption<StridedStageMemory<E, BiasTilingLayout>>;
28
#[cube]
impl<IP: MatrixPrecision> BiasGlobalReader<IP> {
    /// Copies the bias from the global view into the stage; does nothing for `None`.
    ///
    /// Each unit handles at most one line. The unit with linear id `u` reads the line that
    /// starts at column `u * line_size` of row 0 of `view` (via `read_checked`). It casts that
    /// line from `IP::Global` to `IP::Stage` and stores it at index `u` of the stage, viewed as
    /// lines of `line_size` elements. Units whose start column is not below
    /// `elements_in_stage_n()` write nothing.
    ///
    /// NOTE(review): because every unit writes at most one line, the whole stage is populated
    /// only if the cube has at least `ceil(elements_in_stage_n / line_size)` units. Confirm
    /// that the config guarantees this.
    /// NOTE(review): no barrier is issued here; presumably the caller synchronizes before the
    /// stage is read. Confirm.
    pub fn load_stage<G: GlobalConfig>(&mut self, #[comptime] config: G) {
        match self {
            BiasGlobalReader::Some { view, stage } => {
                // Line width of the global view; the stage slice below is viewed with the
                // same width so that slice index == unit id.
                let line_size = view.line_size();
                let num_stage_elements = config.stage_config().elements_in_stage_n();

                // Linear unit index within the cube; assumes the cube's x-dimension equals
                // `plane_dim` (TODO confirm).
                let unit_id = UNIT_POS_Y * config.stage_config().plane_dim() + UNIT_POS_X;
                // First element (column) this unit is responsible for.
                let unit_pos = unit_id * line_size;

                let mut slice = stage.as_slice_mut(line_size);

                if unit_pos < num_stage_elements {
                    let read_line = view.read_checked((0, unit_pos));
                    slice[unit_id] = Line::cast_from(read_line);
                }
            }
            BiasGlobalReader::None => {}
        }
    }

    /// Returns the bias stage wrapped in a `CubeOption`: `Some` with a copy of `stage` when
    /// a bias is present, `None` otherwise.
    pub fn stage(&self) -> BiasStage<IP::Stage> {
        match self {
            BiasGlobalReader::Some { stage, .. } => CubeOption::new_Some(*stage),
            BiasGlobalReader::None => CubeOption::new_None(),
        }
    }
}
62
63#[cube]
64impl<IP: MatrixPrecision> BiasGlobalReader<IP> {
65 pub fn new(
67 view: CubeOption<View<Line<IP::Global>, Coords2d>>,
68 #[comptime] config: StageMemoryConfig,
69 ) -> Self {
70 match view {
71 CubeOption::Some(view) => {
72 let stage = init_stage::<IP::Stage>(config);
73
74 BiasGlobalReader::<IP>::new_Some(view, stage)
75 }
76 CubeOption::None => BiasGlobalReader::new_None(),
77 }
78 }
79}
80
81#[cube]
83fn init_stage<ES: Numeric>(
84 #[comptime] config: StageMemoryConfig,
85) -> StridedStageMemory<ES, BiasTilingLayout> {
86 let line_size = config.line_size;
87
88 let stage_len = comptime!(config.elements_per_stage_along_col() / line_size);
89 let smem = SharedMemory::new_lined(stage_len, line_size);
90
91 StridedStageMemory::<ES, BiasTilingLayout>::new_with_smem(smem, stage_len, config)
92}