cubecl_linalg/convolution/reader/tma.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::tensor::r#virtual::VirtualTensor;
4
5#[derive(CubeType)]
6/// A view of a feature map tensor that starts reading data from a specified offset.
7/// Ensures safe access by preventing out-of-bounds errors.
8/// Includes pre-fetched shapes and strides for optimized performance.
9pub struct Im2colTmaReader<E: Numeric> {
10    pub tensor: TensorMap<E>,
11    pub n_offset: u32,
12    pub h_offset: u32,
13    pub w_offset: u32,
14    pub k_offset: u32,
15}
16
17#[cube]
18impl<E: Numeric> Im2colTmaReader<E> {
19    #[allow(clippy::too_many_arguments)]
20    pub fn new(
21        tensor: VirtualTensor<E>,
22        n_offset: u32,
23        h_offset: u32,
24        w_offset: u32,
25        k_offset: u32,
26    ) -> Im2colTmaReader<E> {
27        let map = tensor.as_tensor_map();
28
29        Im2colTmaReader::<E> {
30            tensor: map,
31            n_offset,
32            h_offset,
33            w_offset,
34            k_offset,
35        }
36    }
37
38    /// Advance the view along the k dimension by a specified offset, `k_offset`.
39    pub fn update_view(&mut self, k_offset: u32) {
40        self.k_offset += k_offset;
41    }
42}
43
44unsafe impl<E: Numeric> Sync for Im2colTmaReader<E> {}
45unsafe impl<E: Numeric> Send for Im2colTmaReader<E> {}