// cubecl_linalg/convolution/loader/bias.rs
use cubecl_core as cubecl;
use cubecl_core::prelude::*;
use cubecl_std::{CubeOption, CubeOptionExpand, tensor::r#virtual::VirtualTensor};

use crate::{
    convolution::homogeneous::simple::ConvTilingLayout,
    matmul::components::{
        Ident,
        global::{AccumulatorLoader, GlobalConfig},
        stage::{Stage, StageConfig},
        tile::{Tile, TileConfig, TileMatmul},
    },
};
use crate::{convolution::reader::bias::BiasReader, matmul::components::MatmulPrecision};
15
/// Loader that initializes the matmul accumulators of a convolution from an
/// optional bias tensor.
///
/// `Some` carries a global-memory reader over the bias plus a shared-memory
/// stage the bias is copied into (cast to the accumulator type `MP::EA`);
/// `None` means no bias tensor was supplied, and accumulators are simply
/// zero-initialized instead.
#[derive(CubeType)]
pub enum BiasLoader<MP: MatmulPrecision> {
    Some {
        /// Reader over the bias tensor in global memory (element type `MP::EO`).
        tensor_view: BiasReader<MP::EO>,
        /// Shared-memory stage holding the staged bias values as `MP::EA`.
        stage: Stage<MP::EA, ConvTilingLayout>,
    },
    None,
}
25
#[cube]
impl<MP: MatmulPrecision> AccumulatorLoader<MP> for BiasLoader<MP> {
    /// Cooperatively copies the bias values from global memory into the
    /// shared-memory stage, casting from `MP::EO` to `MP::EA`.
    /// No-op for the `None` variant.
    fn fill_stage<G: GlobalConfig>(this: &mut Self, #[comptime] config: G) {
        match this {
            BiasLoader::Some { tensor_view, stage } => {
                // NOTE(review): element count is taken from the Rhs tiling
                // (its column count), while line sizes use the Out ident —
                // presumably both describe the output N dimension; confirm.
                let stage_tiling = config.tiling_dimensions(Ident::Rhs);
                let line_size = config.global_line_size(Ident::Out);

                let num_stage_elements = stage_tiling.total_col();

                // Flatten the 2D unit position into a single index; each unit
                // is responsible for loading exactly one line.
                let unit_id = UNIT_POS_Y * config.plane_dim() + UNIT_POS_X;
                let unit_position_base = unit_id * line_size;

                let mut slice = stage.as_slice_mut(line_size);

                // Units whose starting element falls past the stage do nothing
                // (there may be more units than lines to load).
                if unit_position_base < num_stage_elements {
                    let read_line = tensor_view.load_simple::<G>(unit_position_base, line_size);
                    // Cast each loaded line from the output element type to
                    // the accumulator element type.
                    slice[unit_id] = Line::cast_from(read_line);
                }
            }
            BiasLoader::None => {}
        }
    }

    /// Initializes the accumulator for tile column `tile_n`: filled from the
    /// staged bias slice when a bias exists, zeroed otherwise.
    fn load<TMM: TileMatmul<MP>>(
        this: &mut Self,
        acc: &mut TMM::Accumulator,
        tile_n: u32,
        #[comptime] config: TMM::Config,
    ) {
        match this {
            BiasLoader::Some { stage, .. } => {
                let line_size = config.stage_line_size(Ident::Out);
                // Number of lines spanning one tile along the N dimension.
                let tile_elems = config.tile_shape().n / line_size;
                let start = tile_n * tile_elems;
                // Window of the stage covering exactly this tile's bias values.
                let slice = stage
                    .as_slice_mut(line_size)
                    .slice(start, start + tile_elems);
                // Stride 0 — presumably broadcasts the single bias row across
                // all accumulator rows; confirm against `Tile::new_strided`.
                let tile = Tile::new_strided(slice, 0);
                TMM::fill_accumulator(&tile, acc, config);
            }
            BiasLoader::None => {
                TMM::zero_accumulator(acc, config);
            }
        }
    }
}
74
#[cube]
impl<MP: MatmulPrecision> BiasLoader<MP> {
    /// Creates a loader from an optional bias tensor.
    ///
    /// With `Some(tensor)`, allocates the shared-memory stage and wraps the
    /// tensor in a `BiasReader` offset by `n_offset` columns; with `None`,
    /// returns the stageless `BiasLoader::None` variant.
    pub fn new<G: GlobalConfig>(
        tensor: CubeOption<VirtualTensor<MP::EO>>,
        n_offset: u32,
        #[comptime] config: G,
    ) -> Self {
        match tensor {
            CubeOption::Some(tensor) => {
                let stage = init_stage::<MP::EA, G>(config);
                // shape(0) — presumably the bias is 1D, indexed by the output
                // (N) dimension; confirm against the kernel's launch code.
                let shape_n = tensor.shape(0);
                let tensor_view = BiasReader::<MP::EO>::new(tensor, n_offset, shape_n);

                // `new_Some` / `new_None` are generated by `#[derive(CubeType)]`.
                BiasLoader::<MP>::new_Some(tensor_view, stage)
            }
            CubeOption::None => BiasLoader::new_None(),
        }
    }
}
94
/// Allocates the shared-memory stage that will hold the bias: enough lines to
/// cover the total column count of the Out tiling, at the stage line size.
#[cube]
fn init_stage<ES: Numeric, G: GlobalConfig>(#[comptime] config: G) -> Stage<ES, ConvTilingLayout> {
    let line_size = config.to_smm_config().stage_line_size(Ident::Out);

    // The shared-memory size must be known at compile time, hence `comptime!`:
    // number of lines = total output columns / elements per line.
    let smem = SharedMemory::new_lined(
        comptime!(config.tiling_dimensions(Ident::Out).total_col() / line_size),
        line_size,
    );

    Stage::<ES, ConvTilingLayout>::new_with_smem(smem)
}
106
/// Creates a minimal one-element placeholder stage (unlined shared memory).
///
/// NOTE(review): not referenced by the code visible in this file — presumably
/// kept for callers elsewhere; confirm before removing.
#[cube]
fn init_empty_stage<ES: Numeric>() -> Stage<ES, ConvTilingLayout> {
    Stage::<ES, ConvTilingLayout>::new_with_smem(SharedMemory::new(1))
}