cubecl_convolution/reader/bias.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::tensor::r#virtual::VirtualTensor;
4
5use cubecl_matmul::components::global::GlobalConfig;
6
7#[derive(CubeType)]
8/// A view of a tensor that starts reading data from a specified offset.
9/// Ensures safe access by preventing out-of-bounds errors.
10/// Includes pre-fetched shapes and strides for optimized performance.
11pub struct BiasReader<E: Numeric> {
12    pub tensor: VirtualTensor<E>,
13    pub n_offset: u32,
14    pub shape_n: u32,
15}
16
17unsafe impl<E: Numeric> Sync for BiasReader<E> {}
18unsafe impl<E: Numeric> Send for BiasReader<E> {}
19
20#[cube]
21impl<E: Numeric> BiasReader<E> {
22    /// Load the 1D bias into shared memory
23    pub fn new(tensor: VirtualTensor<E>, n_offset: u32, shape_n: u32) -> BiasReader<E> {
24        BiasReader::<E> {
25            tensor,
26            n_offset,
27            shape_n,
28        }
29    }
30
31    /// Load the 1D bias into shared memory
32    pub fn load_simple<G: GlobalConfig>(
33        &self,
34        unit_id: u32,
35        #[comptime] line_size: u32,
36    ) -> Line<E> {
37        let view_n = self.n_offset + unit_id;
38        let read_pos = view_n / line_size;
39
40        select(
41            view_n < self.shape_n,
42            self.read(read_pos),
43            Line::empty(line_size).fill(E::from_int(0)),
44        )
45    }
46
47    fn read(&self, position: u32) -> Line<E> {
48        self.tensor.read(position)
49    }
50}