cubecl_convolution/reader/
bias.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::tensor::r#virtual::VirtualTensor;
4
5use cubecl_matmul::components::global::GlobalConfig;
6
/// Handle for reading lines out of a tensor relative to a fixed offset.
///
/// `load_simple` adds `n_offset` to the calling unit's id and returns a zero-filled
/// line whenever the resulting position is not below `shape_n`.
#[derive(CubeType)]
pub struct BiasReader<E: Numeric> {
    /// Tensor the lines are read from.
    pub tensor: VirtualTensor<E>,
    /// Offset added to the unit id to form the position compared against `shape_n`.
    pub n_offset: u32,
    /// Exclusive upper bound for positions; positions at or beyond it produce zeros.
    pub shape_n: u32,
}
16
// Manually marks `BiasReader` as `Sync` and `Send` for every `E: Numeric`.
//
// SAFETY: NOTE(review): nothing in this file states why these impls are sound; the
// justification presumably depends on how `VirtualTensor` is used — confirm.
unsafe impl<E: Numeric> Sync for BiasReader<E> {}
unsafe impl<E: Numeric> Send for BiasReader<E> {}
19
20#[cube]
21impl<E: Numeric> BiasReader<E> {
22 pub fn new(tensor: VirtualTensor<E>, n_offset: u32, shape_n: u32) -> BiasReader<E> {
24 BiasReader::<E> {
25 tensor,
26 n_offset,
27 shape_n,
28 }
29 }
30
31 pub fn load_simple<G: GlobalConfig>(
33 &self,
34 unit_id: u32,
35 #[comptime] line_size: u32,
36 ) -> Line<E> {
37 let view_n = self.n_offset + unit_id;
38 let read_pos = view_n / line_size;
39
40 select(
41 view_n < self.shape_n,
42 self.read(read_pos),
43 Line::empty(line_size).fill(E::from_int(0)),
44 )
45 }
46
47 fn read(&self, position: u32) -> Line<E> {
48 self.tensor.read(position)
49 }
50}