cubecl_linalg/convolution/reader/
tma.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::tensor::r#virtual::VirtualTensor;
4
5#[derive(CubeType)]
6pub struct Im2colTmaReader<E: Numeric> {
10 pub tensor: TensorMap<E>,
11 pub n_offset: u32,
12 pub h_offset: u32,
13 pub w_offset: u32,
14 pub k_offset: u32,
15}
16
17#[cube]
18impl<E: Numeric> Im2colTmaReader<E> {
19 #[allow(clippy::too_many_arguments)]
20 pub fn new(
21 tensor: VirtualTensor<E>,
22 n_offset: u32,
23 h_offset: u32,
24 w_offset: u32,
25 k_offset: u32,
26 ) -> Im2colTmaReader<E> {
27 let map = tensor.as_tensor_map();
28
29 Im2colTmaReader::<E> {
30 tensor: map,
31 n_offset,
32 h_offset,
33 w_offset,
34 k_offset,
35 }
36 }
37
38 pub fn update_view(&mut self, k_offset: u32) {
40 self.k_offset += k_offset;
41 }
42}
43
44unsafe impl<E: Numeric> Sync for Im2colTmaReader<E> {}
45unsafe impl<E: Numeric> Send for Im2colTmaReader<E> {}