cubecl_linalg/convolution/loader/im2col_tma.rs

1use cubecl_core::prelude::*;
2use cubecl_core::{self as cubecl, prelude::barrier::Barrier};
3
4use cubecl_std::{FastDivmod, tensor::r#virtual::VirtualTensor};
5use std::marker::PhantomData;
6
7use crate::{
8    convolution::reader::tma::Im2colTmaReader,
9    matmul::components::{
10        Ident, MatmulPrecision,
11        stage::{ColMajorTilingOrder, ContiguousTilingLayout, FullReader, StageConfig},
12    },
13};
14use crate::{
15    convolution::{ConvGemmConfig, base::RuntimeArgs},
16    matmul::components::{InputIdent, stage::Stage},
17};
18
19pub type TmaIm2colTiling = ContiguousTilingLayout<ColMajorTilingOrder>;
20pub type TmaIm2colReader<MP> = FullReader<<MP as MatmulPrecision>::ES, TmaIm2colTiling>;
21
22/// Loader that translates matrix coordinates to input coordinates using the `im2col` algorithm
23#[derive(CubeType)]
24pub struct TmaIm2colLoader<MP: MatmulPrecision, G: ConvGemmConfig> {
25    pub map: Im2colTmaReader<MP::EI>,
26    pub stage: Stage<MP::ES, ContiguousTilingLayout<ColMajorTilingOrder>>,
27    padded_channels: FastDivmod,
28    #[cube(comptime)]
29    _config: PhantomData<G>,
30}
31
32#[cube]
33impl<MP: MatmulPrecision, G: ConvGemmConfig> TmaIm2colLoader<MP, G> {
34    pub fn new(
35        tensor: VirtualTensor<MP::EI>,
36        x_offset: u32,
37        y_offset: u32,
38        runtime_args: &RuntimeArgs,
39        #[comptime] config: G,
40    ) -> Self {
41        let stage = Stage::new_aligned::<G::SmmConfig>(Ident::Lhs, 128u32, config.to_smm_config());
42
43        let (nh_offset, w_offset) = runtime_args.out_w.div_mod(x_offset);
44        let (n_offset, h_offset) = runtime_args.out_h.div_mod(nh_offset);
45
46        let map = Im2colTmaReader::<MP::EI>::new(tensor, n_offset, h_offset, w_offset, y_offset);
47
48        TmaIm2colLoader::<MP, G> {
49            map,
50            stage,
51            padded_channels: runtime_args.padded_channels,
52            _config: PhantomData::<G>,
53        }
54    }
55
56    pub fn fill_stage(this: &mut Self, bar: &Barrier<MP::ES>, #[comptime] config: G) {
57        let tmm = config.to_smm_config();
58        let tiling_dims = tmm.tiling_dimensions(Ident::Lhs);
59        if UNIT_POS == 0 {
60            let m_size = tiling_dims.total_row();
61            let k_size = tiling_dims.tile_shape_col();
62            let slice_size = m_size * k_size;
63            let mut full_stage = this.stage.as_slice_mut(1u32);
64            let tensor = this.map.tensor.try_cast_unchecked();
65
66            let in_h = (this.map.h_offset * config.stride(0)) as i32 - config.padding(0);
67            let in_w = (this.map.w_offset * config.stride(1)) as i32 - config.padding(1);
68
69            #[unroll]
70            for tile_k in 0..tiling_dims.tile_count_col() {
71                let k = this.map.k_offset + tile_k * k_size;
72                let (k_idx, channel_start) = this.padded_channels.div_mod(k);
73                let (k_x, k_y) = (k_idx % config.kernel_size(1), k_idx / config.kernel_size(1));
74                let slice_start = tile_k * slice_size;
75                let mut stage = full_stage.slice_mut(slice_start, slice_start + slice_size);
76
77                let offset_y = k_y * config.dilation(0);
78                let offset_x = k_x * config.dilation(1);
79
80                bar.tma_load_im2col_4d(
81                    &tensor,
82                    &mut stage,
83                    this.map.n_offset as i32,
84                    in_h,
85                    in_w,
86                    channel_start as i32,
87                    offset_y as u16,
88                    offset_x as u16,
89                );
90            }
91        }
92    }
93
94    pub fn advance_view(this: &mut Self, k_offset: u32) {
95        this.map.update_view(k_offset);
96    }
97
98    pub fn reader(this: &Self) -> TmaIm2colReader<MP> {
99        TmaIm2colReader::<MP>::new(this.stage, InputIdent::Lhs)
100    }
101}