cubecl_convolution/components/global/memory/
tma.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4#[derive(CubeType)]
5pub struct Im2colTmaReader<E: Numeric> {
9    pub tensor: TensorMap<Line<E>>,
10    pub n_offset: u32,
11    pub spatial_offsets: Sequence<u32>,
12    pub k_offset: u32,
13}
14
15#[cube]
16impl<E: Numeric> Im2colTmaReader<E> {
17    #[allow(clippy::too_many_arguments)]
18    pub fn new(
19        tensor: TensorMap<Line<E>>,
20        n_offset: u32,
21        spatial_offsets: Sequence<u32>,
22        k_offset: u32,
23    ) -> Im2colTmaReader<E> {
24        Im2colTmaReader::<E> {
25            tensor,
26            n_offset,
27            spatial_offsets,
28            k_offset,
29        }
30    }
31
32    pub fn update_view(&mut self, k_offset: u32) {
34        self.k_offset += k_offset;
35    }
36}
37
38unsafe impl<E: Numeric> Sync for Im2colTmaReader<E> {}
39unsafe impl<E: Numeric> Send for Im2colTmaReader<E> {}