// cubecl_convolution/loader/bias.rs
use cubecl_core as cubecl;
use cubecl_core::prelude::*;
use cubecl_std::{CubeOption, CubeOptionExpand, tensor::r#virtual::VirtualTensor};

use crate::homogeneous::simple::ConvTilingLayout;
use crate::reader::bias::BiasReader;
use cubecl_matmul::components::{
    Ident, MatmulPrecision,
    global::{AccumulatorLoader, GlobalConfig},
    stage::{StageConfig, StageMemory},
    tile::{Tile, TileConfig, TileMatmul},
};

// Loader for the optional bias input of a convolution.
//
// `Some` pairs a global-memory reader for the bias tensor (stored at the
// output element type `EO`) with a stage buffer holding the values at the
// accumulator type `EA`. `None` means no bias: no stage is allocated and the
// accumulator is zero-filled instead (see `AccumulatorLoader::load`).
#[derive(CubeType)]
pub enum BiasLoader<MP: MatmulPrecision> {
    Some {
        // Reader over the bias tensor in global memory.
        tensor_view: BiasReader<MP::EO>,
        // Stage memory the bias is copied into before filling accumulators.
        stage: StageMemory<MP::EA, ConvTilingLayout>,
    },
    None,
}
23
#[cube]
impl<MP: MatmulPrecision> AccumulatorLoader<MP> for BiasLoader<MP> {
    // Copies the bias values from global memory into the stage.
    //
    // Each unit loads at most one line of `line_size` elements: unit `i`
    // handles stage elements starting at `i * line_size`, and units whose
    // base position falls past `elements_in_stage_n` do nothing. No-op for
    // `BiasLoader::None`.
    fn fill_stage<G: GlobalConfig>(this: &mut Self, #[comptime] config: G) {
        match this {
            BiasLoader::Some { tensor_view, stage } => {
                let line_size = config.global_line_size(Ident::Out);
                let num_stage_elements = config.tiling_scheme().elements_in_stage_n();

                // Flat unit index within the cube (plane row * plane width + lane).
                let unit_id = UNIT_POS_Y * config.plane_dim() + UNIT_POS_X;
                let unit_position_base = unit_id * line_size;

                let mut slice = stage.as_slice_mut(line_size);

                if unit_position_base < num_stage_elements {
                    let read_line = tensor_view.load_simple::<G>(unit_position_base, line_size);
                    // The stage slice is lined with `line_size`, so line index
                    // `unit_id` corresponds to element offset `unit_position_base`.
                    // Cast converts from the global type EO to the accumulator type EA.
                    slice[unit_id] = Line::cast_from(read_line);
                }
            }
            BiasLoader::None => {}
        }
    }

    // Fills the tile-matmul accumulator for tile column `tile_n` from the
    // staged bias, or zeroes it when there is no bias.
    fn load<TMM: TileMatmul<MP>>(
        this: &mut Self,
        acc: &mut TMM::Accumulator,
        tile_n: u32,
        #[comptime] config: TMM::Config,
    ) {
        match this {
            BiasLoader::Some { stage, .. } => {
                let line_size = config.stage_line_size(Ident::Out);
                // Number of stage lines covering one tile along n.
                let tile_elems = config.tile_size().n() / line_size;
                let start = tile_n * tile_elems;
                let slice = stage
                    .as_slice_mut(line_size)
                    .slice(start, start + tile_elems);
                // Stride 0 — presumably broadcasts the single bias row across
                // the tile's rows; confirm against Tile::new_strided semantics.
                let tile = Tile::new_strided(slice, 0, config.matrix_layout(Ident::Out));
                TMM::fill_accumulator(&tile, acc, config);
            }
            BiasLoader::None => {
                TMM::zero_accumulator(acc, config);
            }
        }
    }
}
70
#[cube]
impl<MP: MatmulPrecision> BiasLoader<MP> {
    // Creates a bias loader.
    //
    // When `tensor` is `Some`, allocates the stage (at accumulator precision
    // EA) and wraps the tensor in a `BiasReader` offset by `n_offset` along n.
    // When `tensor` is `None`, returns the `None` variant without allocating
    // any shared memory.
    pub fn new<G: GlobalConfig>(
        tensor: CubeOption<VirtualTensor<MP::EO>>,
        n_offset: u32,
        #[comptime] config: G,
    ) -> Self {
        match tensor {
            CubeOption::Some(tensor) => {
                let stage = init_stage::<MP::EA, G>(config);
                // shape(0) is taken as the extent along n — assumes the bias
                // tensor is rank-1 over output channels; TODO confirm with callers.
                let shape_n = tensor.shape(0);
                let tensor_view = BiasReader::<MP::EO>::new(tensor, n_offset, shape_n);

                BiasLoader::<MP>::new_Some(tensor_view, stage)
            }
            CubeOption::None => BiasLoader::new_None(),
        }
    }
}
90
// Allocates a single-buffer stage sized to hold the full bias row:
// `elements_in_stage_n` elements packed into lines of the stage line size.
#[cube]
fn init_stage<ES: Numeric, G: GlobalConfig>(
    #[comptime] config: G,
) -> StageMemory<ES, ConvTilingLayout> {
    let line_size = config.stage_config().stage_line_size(Ident::Out);

    // Shared-memory length is in lines, hence the division; resolved at comptime.
    let smem = SharedMemory::new_lined(
        comptime!(config.tiling_scheme().elements_in_stage_n() / line_size),
        line_size,
    );

    // 1 = number of buffers in the stage (no double buffering for bias).
    StageMemory::<ES, ConvTilingLayout>::new_with_smem(smem, 1u32)
}
104
// Creates a placeholder stage backed by a 1-element shared-memory buffer.
//
// NOTE(review): not referenced anywhere in this file — either kept for use
// elsewhere or dead code left over from a refactor; verify before removing.
#[cube]
fn init_empty_stage<ES: Numeric>() -> StageMemory<ES, ConvTilingLayout> {
    StageMemory::<ES, ConvTilingLayout>::new_with_smem(SharedMemory::new(1), 1u32)
}