cubecl_convolution/components/global/read/reader/im2col_tma.rs

1use cubecl_core::prelude::*;
2use cubecl_core::{self as cubecl, prelude::barrier::Barrier};
3
4use cubecl_matmul::components::{MatrixPrecision, stage::StageMemoryConfig};
5use cubecl_std::FastDivmod;
6
7use crate::{
8    components::{ConvolutionParams, Dimensionality, global::memory::Im2colTmaReader},
9    kernels::layered::selector::RuntimeArgs,
10};
11use cubecl_matmul::components::stage::{
12    ColMajorTilingOrder, ContiguousTilingLayout, StridedStageMemory,
13};
14
15pub type TmaIm2colTiling = ContiguousTilingLayout<ColMajorTilingOrder>;
16pub type TmaIm2colStage<IP> = StridedStageMemory<<IP as MatrixPrecision>::Stage, TmaIm2colTiling>;
17
18/// Reader that translates matrix coordinates to input coordinates using the `im2col` algorithm
19#[derive(CubeType)]
20pub struct TmaIm2colGlobalReader<IP: MatrixPrecision> {
21    pub map: Im2colTmaReader<IP::Global>,
22    pub stages: Sequence<StridedStageMemory<IP::Stage, TmaIm2colTiling>>,
23    padded_channels: FastDivmod,
24    #[cube(comptime)]
25    params: ConvolutionParams,
26    #[cube(comptime)]
27    config: StageMemoryConfig,
28}
29
30#[cube]
31impl<IP: MatrixPrecision> TmaIm2colGlobalReader<IP> {
32    pub fn new(
33        tensor: TensorMap<Line<IP::Global>>,
34        x_offset: u32,
35        y_offset: u32,
36        runtime_args: &RuntimeArgs,
37        #[comptime] num_stages: u32,
38        #[comptime] params: ConvolutionParams,
39        #[comptime] config: StageMemoryConfig,
40    ) -> Self {
41        let mut stages = Sequence::new();
42
43        #[unroll]
44        for _ in 0..num_stages {
45            stages.push(StridedStageMemory::new_aligned(128u32, config))
46        }
47
48        let (n_offs, spatial_offsets) = div_mod_seq(x_offset, &runtime_args.shape_out);
49
50        let map = Im2colTmaReader::<IP::Global>::new(tensor, n_offs, spatial_offsets, y_offset);
51
52        TmaIm2colGlobalReader::<IP> {
53            map,
54            stages,
55            padded_channels: runtime_args.padded_channels,
56            params,
57            config,
58        }
59    }
60
61    pub fn fill_stage(&mut self, bar: &Barrier, #[comptime] stage_idx: u32) {
62        let stage = self.stages.index_mut(stage_idx);
63        let params = comptime![self.params];
64        let config = comptime![self.config];
65
66        if UNIT_POS == 0 {
67            let m_size = config.elements_per_stage_along_row();
68            let k_size = config.elements_per_tile_along_col;
69            let slice_size = m_size * k_size;
70            let mut full_stage = stage.as_slice_mut(1u32);
71            let tensor = self.map.tensor.try_cast_unchecked();
72
73            let spatial_dims = comptime![self.map.spatial_offsets.len()];
74            let mut in_offs = Sequence::<i32>::new();
75
76            #[unroll]
77            for dim in 0..spatial_dims {
78                let offs =
79                    self.map.spatial_offsets.index(dim) * comptime![params.stride[dim as usize]];
80                let offs = offs as i32 - comptime![params.padding[dim as usize]];
81                in_offs.push(offs);
82            }
83
84            #[unroll]
85            for tile_k in 0..config.tiles_per_stage_along_col() {
86                let k = self.map.k_offset + tile_k * k_size;
87                let (k_idx, channel_start) = self.padded_channels.div_mod(k);
88                let slice_start = tile_k * slice_size;
89                let mut stage = full_stage.slice_mut(slice_start, slice_start + slice_size);
90
91                match params.dimensionality {
92                    Dimensionality::Dim1 => {
93                        let offset = k_idx * comptime![params.dilation[0]];
94
95                        bar.tma_load_im2col_3d(
96                            &tensor,
97                            &mut stage,
98                            self.map.n_offset as i32,
99                            *in_offs.index(0),
100                            channel_start as i32,
101                            offset as u16,
102                        );
103                    }
104                    Dimensionality::Dim2 => {
105                        let (k_x, k_y) = (
106                            k_idx % comptime![params.kernel_size[1]],
107                            k_idx / comptime![params.kernel_size[1]],
108                        );
109
110                        let offset_y = k_y * comptime![params.dilation[0]];
111                        let offset_x = k_x * comptime![params.dilation[1]];
112
113                        bar.tma_load_im2col_4d(
114                            &tensor,
115                            &mut stage,
116                            self.map.n_offset as i32,
117                            *in_offs.index(0),
118                            *in_offs.index(1),
119                            channel_start as i32,
120                            offset_y as u16,
121                            offset_x as u16,
122                        );
123                    }
124                    Dimensionality::Dim3 => {
125                        let (k_x, rem) = (
126                            k_idx % comptime![params.kernel_size[2]],
127                            k_idx / comptime![params.kernel_size[2]],
128                        );
129                        let (k_y, k_z) = (
130                            rem % comptime![params.kernel_size[1]],
131                            rem / comptime![params.kernel_size[1]],
132                        );
133
134                        let offset_z = k_z * comptime![params.dilation[0]];
135                        let offset_y = k_y * comptime![params.dilation[1]];
136                        let offset_x = k_x * comptime![params.dilation[2]];
137
138                        bar.tma_load_im2col_5d(
139                            &tensor,
140                            &mut stage,
141                            self.map.n_offset as i32,
142                            *in_offs.index(0),
143                            *in_offs.index(1),
144                            *in_offs.index(2),
145                            channel_start as i32,
146                            offset_z as u16,
147                            offset_y as u16,
148                            offset_x as u16,
149                        );
150                    }
151                }
152            }
153        }
154    }
155
156    pub fn advance_view(&mut self, k_offset: u32) {
157        self.map.update_view(k_offset);
158    }
159
160    pub fn stage(&self, #[comptime] stage_idx: u32) -> TmaIm2colStage<IP> {
161        *self.stages.index(stage_idx)
162    }
163}
164
165/// Decompose a linear index into local positions along each dimension in `shape`. Also returns the
166/// left over remainder.
167#[cube]
168pub(crate) fn div_mod_seq(pos: u32, shape: &Sequence<FastDivmod>) -> (u32, Sequence<u32>) {
169    let rank = comptime![shape.len()];
170    let mut offs = pos;
171    let mut out = Sequence::new();
172
173    #[unroll]
174    for i in 0..rank {
175        let dim = comptime![rank - i - 1];
176        let (rem, offs_local) = shape.index(dim).div_mod(offs);
177        out.push(offs_local);
178        offs = rem;
179    }
180
181    (offs, out.rev())
182}