cubecl_convolution/loader/im2col_tma.rs

1use cubecl_core::{self as cubecl, prelude::barrier::Barrier};
2use cubecl_core::{intrinsic, prelude::*};
3
4use cubecl_std::{FastDivmod, tensor::r#virtual::VirtualTensor};
5use std::marker::PhantomData;
6
7use crate::{
8    ConvGemmConfig,
9    base::{Dimensionality, RuntimeArgs},
10    reader::tma::Im2colTmaReader,
11};
12use cubecl_matmul::components::{
13    Ident, InputIdent, MatmulPrecision,
14    stage::{ColMajorTilingOrder, ContiguousTilingLayout, FullStageToTileReader, StageMemory},
15};
16
/// Tiling layout of the loader's stage memory: a `ContiguousTilingLayout` with
/// `ColMajorTilingOrder` (presumably contiguous tiles ordered column-major — confirm against
/// `cubecl_matmul`'s stage docs).
pub type TmaIm2colTiling = ContiguousTilingLayout<ColMajorTilingOrder>;
/// Full-stage tile reader over a stage filled by [`TmaIm2colLoader`], typed by the stage
/// element type `MP::ES` and laid out with [`TmaIm2colTiling`].
pub type TmaIm2colReader<MP> = FullStageToTileReader<<MP as MatmulPrecision>::ES, TmaIm2colTiling>;
19
/// Loader that translates matrix coordinates to input coordinates using the `im2col` algorithm.
///
/// Stage buffers are filled with `Barrier::tma_load_im2col_{3d,4d,5d}` calls, one per `k` tile
/// of a stage, and are exposed to the stage consumer through [`TmaIm2colReader`].
#[derive(CubeType)]
pub struct TmaIm2colLoader<MP: MatmulPrecision, G: ConvGemmConfig> {
    /// View over the input tensor. `fill_stage` reads its tensor, its `n_offset`, its
    /// per-spatial-dimension offsets and its current `k_offset`.
    pub map: Im2colTmaReader<MP::EI>,
    /// One stage buffer per stage (`new` allocates `num_stages` of them), indexed by the
    /// comptime `stage_idx`.
    pub stages: Sequence<StageMemory<MP::ES, TmaIm2colTiling>>,
    /// Divisor that splits a flattened `k` index into a kernel-position index (quotient)
    /// and a starting channel (remainder).
    padded_channels: FastDivmod,
    /// Comptime-only marker for the convolution config type; it carries no runtime data.
    #[cube(comptime)]
    _config: PhantomData<G>,
}
29
#[cube]
impl<MP: MatmulPrecision, G: ConvGemmConfig> TmaIm2colLoader<MP, G> {
    /// Creates a loader over `tensor` with `num_stages` LHS stage buffers.
    ///
    /// `x_offset` is decomposed by `div_mod_seq` over `runtime_args.out_shape` into a leading
    /// value plus one offset per spatial dimension. Those, together with `y_offset`, are handed
    /// to `Im2colTmaReader::new`. NOTE(review): `y_offset` presumably seeds the reader's `k`
    /// offset — confirm against `Im2colTmaReader::new`.
    pub fn new(
        tensor: VirtualTensor<MP::EI>,
        x_offset: u32,
        y_offset: u32,
        runtime_args: &RuntimeArgs,
        #[comptime] num_stages: u32,
        #[comptime] config: G,
    ) -> Self {
        let mut stages = Sequence::new();

        // One aligned LHS stage per requested stage. The `128` alignment is presumably what
        // the TMA destination requires; its unit is not visible here — confirm.
        #[unroll]
        for _ in 0..num_stages {
            stages.push(StageMemory::new_aligned::<G::StageConfig>(
                Ident::Lhs,
                128u32,
                config.stage_config(),
            ))
        }

        // The first value (what remains after dividing through the spatial dims) becomes the
        // `n` offset; the sequence holds the per-spatial-dimension output offsets.
        let (n_offs, spatial_offsets) = div_mod_seq(x_offset, &runtime_args.out_shape);

        let map = Im2colTmaReader::<MP::EI>::new(tensor, n_offs, spatial_offsets, y_offset);

        TmaIm2colLoader::<MP, G> {
            map,
            stages,
            padded_channels: runtime_args.padded_channels,
            _config: PhantomData::<G>,
        }
    }

    /// Fills stage `stage_idx` with im2col TMA loads for the reader's current `k` offset.
    ///
    /// Only the unit with `UNIT_POS == 0` issues loads, and it issues them through `bar`;
    /// every other unit returns without doing anything. The stage is written as
    /// `tiles_in_stage_k` consecutive slices of `elements_in_stage_m * elements_in_tile_k`
    /// elements, one TMA load per `k` tile. The 3D, 4D or 5D im2col variant is chosen
    /// according to `config.dimensionality()`.
    pub fn fill_stage(
        this: &mut Self,
        bar: &Barrier<MP::ES>,
        #[comptime] stage_idx: u32,
        #[comptime] config: G,
    ) {
        let stage = this.stages.index_mut(stage_idx);

        // A single unit issues the bulk copies for the whole stage.
        if UNIT_POS == 0 {
            let m_size = config.tiling_scheme().elements_in_stage_m();
            let k_size = config.tiling_scheme().elements_in_tile_k();
            // Number of elements one `k` tile occupies across the full stage height.
            let slice_size = m_size * k_size;
            let mut full_stage = stage.as_slice_mut(1u32);
            let tensor = this.map.tensor.try_cast_unchecked();

            let spatial_dims = comptime![this.map.spatial_offsets.len()];
            let mut in_offs = Sequence::<i32>::new();

            // Input-space start coordinate per spatial dim: `out_offset * stride - padding`.
            // It is kept signed because subtracting the padding can take it below zero.
            #[unroll]
            for dim in 0..spatial_dims {
                let dim = unwrap(dim);
                let offs = this.map.spatial_offsets.index(dim) * comptime![config.stride(dim)];
                let offs = offs as i32 - comptime![config.padding(dim)];
                in_offs.push(offs);
            }

            #[unroll]
            for tile_k in 0..config.tiling_scheme().tiles_in_stage_k() {
                // Flattened `k` for this tile. It splits into a kernel-position index
                // (quotient) and a starting channel (remainder) over the padded channel count.
                let k = this.map.k_offset + tile_k * k_size;
                let (k_idx, channel_start) = this.padded_channels.div_mod(k);
                let slice_start = tile_k * slice_size;
                let mut stage = full_stage.slice_mut(slice_start, slice_start + slice_size);

                // The kernel offsets passed to TMA are kernel positions scaled by dilation.
                // NOTE(review): the `as u16` casts below truncate silently if
                // `kernel_pos * dilation` ever exceeds `u16::MAX`.
                match config.dimensionality() {
                    Dimensionality::Dim1 => {
                        let offset = k_idx * config.dilation(0);

                        bar.tma_load_im2col_3d(
                            &tensor,
                            &mut stage,
                            this.map.n_offset as i32,
                            *in_offs.index(0),
                            channel_start as i32,
                            offset as u16,
                        );
                    }
                    Dimensionality::Dim2 => {
                        // `k_idx` is row-major over (y, x): x varies fastest.
                        let (k_x, k_y) =
                            (k_idx % config.kernel_size(1), k_idx / config.kernel_size(1));

                        let offset_y = k_y * config.dilation(0);
                        let offset_x = k_x * config.dilation(1);

                        bar.tma_load_im2col_4d(
                            &tensor,
                            &mut stage,
                            this.map.n_offset as i32,
                            *in_offs.index(0),
                            *in_offs.index(1),
                            channel_start as i32,
                            offset_y as u16,
                            offset_x as u16,
                        );
                    }
                    Dimensionality::Dim3 => {
                        // `k_idx` is row-major over (z, y, x): x varies fastest, then y.
                        let (k_x, rem) =
                            (k_idx % config.kernel_size(2), k_idx / config.kernel_size(2));
                        let (k_y, k_z) = (rem % config.kernel_size(1), rem / config.kernel_size(1));

                        let offset_z = k_z * config.dilation(0);
                        let offset_y = k_y * config.dilation(1);
                        let offset_x = k_x * config.dilation(2);

                        bar.tma_load_im2col_5d(
                            &tensor,
                            &mut stage,
                            this.map.n_offset as i32,
                            *in_offs.index(0),
                            *in_offs.index(1),
                            *in_offs.index(2),
                            channel_start as i32,
                            offset_z as u16,
                            offset_y as u16,
                            offset_x as u16,
                        );
                    }
                }
            }
        }
    }

    /// Forwards `k_offset` to `Im2colTmaReader::update_view`. This presumably moves the `k`
    /// window read by the next `fill_stage` — confirm its add-vs-set semantics.
    pub fn advance_view(this: &mut Self, k_offset: u32) {
        this.map.update_view(k_offset);
    }

    /// Returns a [`TmaIm2colReader`] over a copy of stage `stage_idx`, tagged as the LHS input.
    pub fn reader(this: &Self, #[comptime] stage_idx: u32) -> TmaIm2colReader<MP> {
        TmaIm2colReader::<MP>::new(*this.stages.index(stage_idx), InputIdent::Lhs)
    }
}
162
163/// Decompose a linear index into local positions along each dimension in `shape`. Also returns the
164/// left over remainder.
165#[cube]
166pub(crate) fn div_mod_seq(pos: u32, shape: &Sequence<FastDivmod>) -> (u32, Sequence<u32>) {
167    let rank = comptime![shape.len()];
168    let mut offs = pos;
169    let mut out = Sequence::new();
170
171    #[unroll]
172    for i in 0..rank {
173        let i = unwrap(i);
174        let dim = comptime![rank - i - 1];
175        let (rem, offs_local) = shape.index(dim).div_mod(offs);
176        out.push(offs_local);
177        offs = rem;
178    }
179
180    (offs, out.rev())
181}
182
/// Extracts the compile-time `u32` value of `v` inside `#[cube]` code.
///
/// It is used here to turn `#[unroll]` loop indices into comptime values, so they can be
/// used with `Sequence::index` and `comptime!` arithmetic.
///
/// # Panics
/// Panics with "Must be constant" if `v` is not a constant when the `intrinsic!` closure runs
/// (presumably at kernel expansion time — confirm).
// NOTE(review): `v` is only consumed inside the `intrinsic!` closure. The allow presumably
// silences an unused-variable warning from the macro-generated non-expanded function — confirm.
#[allow(unused_variables)]
#[cube]
fn unwrap(v: u32) -> comptime_type!(u32) {
    intrinsic!(|_| v.constant().expect("Must be constant").as_u32())
}