cubecl_convolution/components/global/memory/
tma.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
/// Reader state over a tensor map: the map itself plus the current
/// `n`, spatial and `k` offsets into it.
///
/// NOTE(review): the im2col / TMA interpretation comes from the type name only;
/// confirm the offset semantics against the code that constructs this reader.
#[derive(CubeType)]
pub struct Im2colTmaReader<E: Numeric> {
    /// Tensor map the reader reads from.
    pub tensor: TensorMap<Line<E>>,
    /// Offset along the `n` dimension (presumably the batch index — TODO confirm).
    pub n_offset: u32,
    /// Spatial offsets; the number of entries is not fixed by this type.
    pub spatial_offsets: Sequence<u32>,
    /// Offset along the `k` dimension; `update_view` increments it in place.
    pub k_offset: u32,
}
14
15#[cube]
16impl<E: Numeric> Im2colTmaReader<E> {
17 #[allow(clippy::too_many_arguments)]
18 pub fn new(
19 tensor: TensorMap<Line<E>>,
20 n_offset: u32,
21 spatial_offsets: Sequence<u32>,
22 k_offset: u32,
23 ) -> Im2colTmaReader<E> {
24 Im2colTmaReader::<E> {
25 tensor,
26 n_offset,
27 spatial_offsets,
28 k_offset,
29 }
30 }
31
32 pub fn update_view(&mut self, k_offset: u32) {
34 self.k_offset += k_offset;
35 }
36}
37
// SAFETY: NOTE(review): no justification is visible in this file. These impls
// assert that `Im2colTmaReader<E>` may be shared (`Sync`) and moved (`Send`)
// across threads for every `E: Numeric`, overriding whatever the compiler would
// infer from the fields (`TensorMap<Line<E>>`, `Sequence<u32>`, `u32`).
// Presumably needed because one of those field types does not auto-implement
// `Send`/`Sync`; confirm that type holds no thread-affine or unsynchronized
// shared state before relying on these impls.
unsafe impl<E: Numeric> Sync for Im2colTmaReader<E> {}
unsafe impl<E: Numeric> Send for Im2colTmaReader<E> {}