cubecl_convolution/reader/
tma.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::tensor::r#virtual::VirtualTensor;
4
5#[derive(CubeType)]
6pub struct Im2colTmaReader<E: Numeric> {
10 pub tensor: TensorMap<E>,
11 pub n_offset: u32,
12 pub spatial_offsets: Sequence<u32>,
13 pub k_offset: u32,
14}
15
16#[cube]
17impl<E: Numeric> Im2colTmaReader<E> {
18 #[allow(clippy::too_many_arguments)]
19 pub fn new(
20 tensor: VirtualTensor<E>,
21 n_offset: u32,
22 spatial_offsets: Sequence<u32>,
23 k_offset: u32,
24 ) -> Im2colTmaReader<E> {
25 let map = tensor.as_tensor_map();
26
27 Im2colTmaReader::<E> {
28 tensor: map,
29 n_offset,
30 spatial_offsets,
31 k_offset,
32 }
33 }
34
35 pub fn update_view(&mut self, k_offset: u32) {
37 self.k_offset += k_offset;
38 }
39}
40
41unsafe impl<E: Numeric> Sync for Im2colTmaReader<E> {}
42unsafe impl<E: Numeric> Send for Im2colTmaReader<E> {}