cubecl_linalg/convolution/loader/im2col_tma.rs
use cubecl_core::prelude::*;
use cubecl_core::{self as cubecl, prelude::barrier::Barrier};

use cubecl_std::{FastDivmod, tensor::r#virtual::VirtualTensor};
use std::marker::PhantomData;

use crate::{
    convolution::reader::tma::Im2colTmaReader,
    matmul::components::{
        Ident, MatmulPrecision,
        stage::{ColMajorTilingOrder, ContiguousTilingLayout, FullReader, StageConfig},
    },
};
use crate::{
    convolution::{ConvGemmConfig, base::RuntimeArgs},
    matmul::components::{InputIdent, stage::Stage},
};

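/// Tiling layout used for TMA im2col stages: contiguous tiles in column-major tile order.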
pub type TmaIm2colTiling = ContiguousTilingLayout<ColMajorTilingOrder>;
pub type TmaIm2colReader<MP> = FullReader<<MP as MatmulPrecision>::ES, TmaIm2colTiling>;

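/// Loader that fills a shared-memory stage with the im2col-transformed Lhs operand
/// using TMA (tensor memory accelerator) im2col loads.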
#[derive(CubeType)]
pub struct TmaIm2colLoader<MP: MatmulPrecision, G: ConvGemmConfig> {
    pub map: Im2colTmaReader<MP::EI>,
    pub stage: Stage<MP::ES, ContiguousTilingLayout<ColMajorTilingOrder>>,
    padded_channels: FastDivmod,
    #[cube(comptime)]
    _config: PhantomData<G>,
}

#[cube]
impl<MP: MatmulPrecision, G: ConvGemmConfig> TmaIm2colLoader<MP, G> {
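    /// Creates a new loader. `x_offset` selects the output position (decomposed into
    /// batch and spatial offsets below) and `y_offset` is the starting offset along k.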
    pub fn new(
        tensor: VirtualTensor<MP::EI>,
        x_offset: u32,
        y_offset: u32,
        runtime_args: &RuntimeArgs,
        #[comptime] config: G,
    ) -> Self {
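        // Stage memory is aligned to 128 bytes to satisfy TMA alignment requirements.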
        let stage = Stage::new_aligned::<G::SmmConfig>(Ident::Lhs, 128u32, config.to_smm_config());

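        // Decompose the linear output offset into (batch, out_y, out_x) coordinates.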
        let (nh_offset, w_offset) = runtime_args.out_w.div_mod(x_offset);
        let (n_offset, h_offset) = runtime_args.out_h.div_mod(nh_offset);

        let map = Im2colTmaReader::<MP::EI>::new(tensor, n_offset, h_offset, w_offset, y_offset);

        TmaIm2colLoader::<MP, G> {
            map,
            stage,
            padded_channels: runtime_args.padded_channels,
            _config: PhantomData::<G>,
        }
    }

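    /// Fills the stage by issuing one TMA im2col copy per k tile through the barrier `bar`.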
    pub fn fill_stage(this: &mut Self, bar: &Barrier<MP::ES>, #[comptime] config: G) {
        let tmm = config.to_smm_config();
        let tiling_dims = tmm.tiling_dimensions(Ident::Lhs);
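        // Only the first unit in the cube issues the TMA copies.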
        if UNIT_POS == 0 {
            let m_size = tiling_dims.total_row();
            let k_size = tiling_dims.tile_shape_col();
            let slice_size = m_size * k_size;
            let mut full_stage = this.stage.as_slice_mut(1u32);
            let tensor = this.map.tensor.try_cast_unchecked();

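            // Top-left input coordinates for this output position, accounting for stride and padding.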
            let in_h = (this.map.h_offset * config.stride(0)) as i32 - config.padding(0);
            let in_w = (this.map.w_offset * config.stride(1)) as i32 - config.padding(1);

            #[unroll]
            for tile_k in 0..tiling_dims.tile_count_col() {
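                // Split the global k index into a kernel spatial position and a starting channel.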
                let k = this.map.k_offset + tile_k * k_size;
                let (k_idx, channel_start) = this.padded_channels.div_mod(k);
                let (k_x, k_y) = (k_idx % config.kernel_size(1), k_idx / config.kernel_size(1));
                let slice_start = tile_k * slice_size;
                let mut stage = full_stage.slice_mut(slice_start, slice_start + slice_size);

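                // Kernel offsets scaled by dilation for the im2col load.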
                let offset_y = k_y * config.dilation(0);
                let offset_x = k_x * config.dilation(1);

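                // Issue the asynchronous TMA im2col copy for this k tile.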
                bar.tma_load_im2col_4d(
                    &tensor,
                    &mut stage,
                    this.map.n_offset as i32,
                    in_h,
                    in_w,
                    channel_start as i32,
                    offset_y as u16,
                    offset_x as u16,
                );
            }
        }
    }

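    /// Advances the loader's view along k by `k_offset`.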
    pub fn advance_view(this: &mut Self, k_offset: u32) {
        this.map.update_view(k_offset);
    }

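    /// Creates a reader over the filled stage for the Lhs input.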
    pub fn reader(this: &Self) -> TmaIm2colReader<MP> {
        TmaIm2colReader::<MP>::new(this.stage, InputIdent::Lhs)
    }
}