cubecl_convolution/components/global/read/reader/
im2col_tma.rs1use cubecl_core::prelude::*;
2use cubecl_core::{self as cubecl, prelude::barrier::Barrier};
3
4use cubecl_matmul::components::{MatrixPrecision, StageIdent, stage::StageMemoryConfig};
5use cubecl_std::FastDivmod;
6
7use crate::{
8    components::{ConvolutionParams, Dimensionality, global::memory::Im2colTmaReader},
9    kernels::layered::selector::RuntimeArgs,
10};
11use cubecl_matmul::components::stage::{ColMajorTilingOrder, ContiguousTilingLayout, StridedStage};
12
13pub type TmaIm2colTiling = ContiguousTilingLayout<ColMajorTilingOrder>;
14pub type TmaIm2colStage<IP> = StridedStage<<IP as MatrixPrecision>::Stage, TmaIm2colTiling>;
15
16#[derive(CubeType)]
18pub struct TmaIm2colGlobalReader<IP: MatrixPrecision> {
19    pub map: Im2colTmaReader<IP::Global>,
20    pub stages: Sequence<StridedStage<IP::Stage, TmaIm2colTiling>>,
21    padded_channels: FastDivmod,
22    #[cube(comptime)]
23    params: ConvolutionParams,
24    #[cube(comptime)]
25    config: StageMemoryConfig,
26}
27
28#[cube]
29impl<IP: MatrixPrecision> TmaIm2colGlobalReader<IP> {
30    pub fn new(
31        tensor: TensorMap<Line<IP::Global>>,
32        x_offset: u32,
33        y_offset: u32,
34        runtime_args: &RuntimeArgs,
35        #[comptime] num_stages: u32,
36        #[comptime] params: ConvolutionParams,
37        #[comptime] config: StageMemoryConfig,
38    ) -> Self {
39        let mut stages = Sequence::new();
40
41        #[unroll]
42        for _ in 0..num_stages {
43            stages.push(StridedStage::new_aligned(StageIdent::Lhs, 128u32, config))
44        }
45
46        let (n_offs, spatial_offsets) = div_mod_seq(x_offset, &runtime_args.shape_out);
47
48        let map = Im2colTmaReader::<IP::Global>::new(tensor, n_offs, spatial_offsets, y_offset);
49
50        TmaIm2colGlobalReader::<IP> {
51            map,
52            stages,
53            padded_channels: runtime_args.padded_channels,
54            params,
55            config,
56        }
57    }
58
59    pub fn fill_stage(&mut self, bar: &Barrier, #[comptime] stage_idx: u32) {
60        let stage = self.stages.index_mut(stage_idx);
61        let params = comptime![self.params];
62        let config = comptime![self.config];
63
64        if UNIT_POS == 0 {
65            let m_size = config.elements_in_stage_row();
66            let k_size = config.elements_in_tile_col;
67            let slice_size = m_size * k_size;
68            let mut full_stage = stage.as_slice_mut(1u32);
69            let tensor = self.map.tensor.try_cast_unchecked();
70
71            let spatial_dims = comptime![self.map.spatial_offsets.len()];
72            let mut in_offs = Sequence::<i32>::new();
73
74            #[unroll]
75            for dim in 0..spatial_dims {
76                let offs =
77                    self.map.spatial_offsets.index(dim) * comptime![params.stride[dim as usize]];
78                let offs = offs as i32 - comptime![params.padding[dim as usize]];
79                in_offs.push(offs);
80            }
81
82            #[unroll]
83            for tile_k in 0..config.tiles_in_stage_col {
84                let k = self.map.k_offset + tile_k * k_size;
85                let (k_idx, channel_start) = self.padded_channels.div_mod(k);
86                let slice_start = tile_k * slice_size;
87                let mut stage = full_stage.slice_mut(slice_start, slice_start + slice_size);
88
89                match params.dimensionality {
90                    Dimensionality::Dim1 => {
91                        let offset = k_idx * comptime![params.dilation[0]];
92
93                        bar.tma_load_im2col_3d(
94                            &tensor,
95                            &mut stage,
96                            self.map.n_offset as i32,
97                            *in_offs.index(0),
98                            channel_start as i32,
99                            offset as u16,
100                        );
101                    }
102                    Dimensionality::Dim2 => {
103                        let (k_x, k_y) = (
104                            k_idx % comptime![params.kernel_size[1]],
105                            k_idx / comptime![params.kernel_size[1]],
106                        );
107
108                        let offset_y = k_y * comptime![params.dilation[0]];
109                        let offset_x = k_x * comptime![params.dilation[1]];
110
111                        bar.tma_load_im2col_4d(
112                            &tensor,
113                            &mut stage,
114                            self.map.n_offset as i32,
115                            *in_offs.index(0),
116                            *in_offs.index(1),
117                            channel_start as i32,
118                            offset_y as u16,
119                            offset_x as u16,
120                        );
121                    }
122                    Dimensionality::Dim3 => {
123                        let (k_x, rem) = (
124                            k_idx % comptime![params.kernel_size[2]],
125                            k_idx / comptime![params.kernel_size[2]],
126                        );
127                        let (k_y, k_z) = (
128                            rem % comptime![params.kernel_size[1]],
129                            rem / comptime![params.kernel_size[1]],
130                        );
131
132                        let offset_z = k_z * comptime![params.dilation[0]];
133                        let offset_y = k_y * comptime![params.dilation[1]];
134                        let offset_x = k_x * comptime![params.dilation[2]];
135
136                        bar.tma_load_im2col_5d(
137                            &tensor,
138                            &mut stage,
139                            self.map.n_offset as i32,
140                            *in_offs.index(0),
141                            *in_offs.index(1),
142                            *in_offs.index(2),
143                            channel_start as i32,
144                            offset_z as u16,
145                            offset_y as u16,
146                            offset_x as u16,
147                        );
148                    }
149                }
150            }
151        }
152    }
153
154    pub fn advance_view(&mut self, k_offset: u32) {
155        self.map.update_view(k_offset);
156    }
157
158    pub fn stage(&self, #[comptime] stage_idx: u32) -> TmaIm2colStage<IP> {
159        *self.stages.index(stage_idx)
160    }
161}
162
163#[cube]
166pub(crate) fn div_mod_seq(pos: u32, shape: &Sequence<FastDivmod>) -> (u32, Sequence<u32>) {
167    let rank = comptime![shape.len()];
168    let mut offs = pos;
169    let mut out = Sequence::new();
170
171    #[unroll]
172    for i in 0..rank {
173        let dim = comptime![rank - i - 1];
174        let (rem, offs_local) = shape.index(dim).div_mod(offs);
175        out.push(offs_local);
176        offs = rem;
177    }
178
179    (offs, out.rev())
180}