cubecl_convolution/reader/
tma.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::tensor::r#virtual::VirtualTensor;
4
5#[derive(CubeType)]
6/// A view of a feature map tensor that starts reading data from a specified offset.
7/// Ensures safe access by preventing out-of-bounds errors.
8/// Includes pre-fetched shapes and strides for optimized performance.
9pub struct Im2colTmaReader<E: Numeric> {
10    pub tensor: TensorMap<E>,
11    pub n_offset: u32,
12    pub spatial_offsets: Sequence<u32>,
13    pub k_offset: u32,
14}
15
16#[cube]
17impl<E: Numeric> Im2colTmaReader<E> {
18    #[allow(clippy::too_many_arguments)]
19    pub fn new(
20        tensor: VirtualTensor<E>,
21        n_offset: u32,
22        spatial_offsets: Sequence<u32>,
23        k_offset: u32,
24    ) -> Im2colTmaReader<E> {
25        let map = tensor.as_tensor_map();
26
27        Im2colTmaReader::<E> {
28            tensor: map,
29            n_offset,
30            spatial_offsets,
31            k_offset,
32        }
33    }
34
35    /// Advance the view along the k dimension by a specified offset, `k_offset`.
36    pub fn update_view(&mut self, k_offset: u32) {
37        self.k_offset += k_offset;
38    }
39}
40
41unsafe impl<E: Numeric> Sync for Im2colTmaReader<E> {}
42unsafe impl<E: Numeric> Send for Im2colTmaReader<E> {}