cubecl_convolution/components/global/read/reader/
bias.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::{
4 CubeOption, CubeOptionExpand,
5 tensor::{View, layout::Coords2d},
6};
7
8use cubecl_matmul::components::{
9 MatrixPrecision, StageIdent,
10 global::GlobalConfig,
11 stage::{StageMemoryConfig, StridedStage},
12};
13
14use crate::components::stage::reader::BiasTilingLayout;
15
16#[derive(CubeType)]
18pub enum BiasGlobalReader<IP: MatrixPrecision> {
19 Some {
20 view: View<Line<IP::Global>, Coords2d>,
21 stage: StridedStage<IP::Stage, BiasTilingLayout>,
22 },
23 None,
24}
25
26pub type BiasStage<E> = CubeOption<StridedStage<E, BiasTilingLayout>>;
28
29#[cube]
30impl<IP: MatrixPrecision> BiasGlobalReader<IP> {
31 pub fn load_stage<G: GlobalConfig>(&mut self, #[comptime] config: G) {
34 match self {
35 BiasGlobalReader::Some { view, stage } => {
36 let line_size = view.line_size();
37 let num_stage_elements = config.tiling_scheme().elements_in_stage_n();
38
39 let unit_id = UNIT_POS_Y * config.plane_dim() + UNIT_POS_X;
40 let unit_pos = unit_id * line_size;
41
42 let mut slice = stage.as_slice_mut(line_size);
43
44 if unit_pos < num_stage_elements {
45 let read_line = view.read_checked((0, unit_pos));
46 slice[unit_id] = Line::cast_from(read_line);
47 }
48 }
49 BiasGlobalReader::None => {}
50 }
51 }
52
53 pub fn stage(&self) -> BiasStage<IP::Stage> {
56 match self {
57 BiasGlobalReader::Some { stage, .. } => CubeOption::new_Some(*stage),
58 BiasGlobalReader::None => CubeOption::new_None(),
59 }
60 }
61}
62
63#[cube]
64impl<IP: MatrixPrecision> BiasGlobalReader<IP> {
65 pub fn new(
67 view: CubeOption<View<Line<IP::Global>, Coords2d>>,
68 #[comptime] config: StageMemoryConfig,
69 ) -> Self {
70 match view {
71 CubeOption::Some(view) => {
72 let stage = init_stage::<IP::Stage>(config);
73
74 BiasGlobalReader::<IP>::new_Some(view, stage)
75 }
76 CubeOption::None => BiasGlobalReader::new_None(),
77 }
78 }
79}
80
81#[cube]
83fn init_stage<ES: Numeric>(
84 #[comptime] config: StageMemoryConfig,
85) -> StridedStage<ES, BiasTilingLayout> {
86 let line_size = config.stage_line_size;
87
88 let smem = SharedMemory::new_lined(
89 comptime!(config.elements_in_stage_col() / line_size),
90 line_size,
91 );
92
93 StridedStage::<ES, BiasTilingLayout>::new_with_smem(smem, StageIdent::Acc, config)
94}