cubecl_convolution/loader/
im2col_tma.rs1use cubecl_core::{self as cubecl, prelude::barrier::Barrier};
2use cubecl_core::{intrinsic, prelude::*};
3
4use cubecl_std::{FastDivmod, tensor::r#virtual::VirtualTensor};
5use std::marker::PhantomData;
6
7use crate::{
8 ConvGemmConfig,
9 base::{Dimensionality, RuntimeArgs},
10 reader::tma::Im2colTmaReader,
11};
12use cubecl_matmul::components::{
13 Ident, InputIdent, MatmulPrecision,
14 stage::{ColMajorTilingOrder, ContiguousTilingLayout, FullStageToTileReader, StageMemory},
15};
16
/// Stage tiling used by the TMA im2col loader: a `ContiguousTilingLayout` with
/// `ColMajorTilingOrder`.
pub type TmaIm2colTiling = ContiguousTilingLayout<ColMajorTilingOrder>;
/// Full-stage-to-tile reader over the stage element type (`MP::ES`) laid out with
/// [`TmaIm2colTiling`].
pub type TmaIm2colReader<MP> = FullStageToTileReader<<MP as MatmulPrecision>::ES, TmaIm2colTiling>;
19
/// Loads the Lhs stage memory using TMA im2col loads, keeping one stage memory per stage.
#[derive(CubeType)]
pub struct TmaIm2colLoader<MP: MatmulPrecision, G: ConvGemmConfig> {
    // Global-memory view: the input tensor plus the n / spatial / k offsets that
    // `fill_stage` reads and `advance_view` updates.
    pub map: Im2colTmaReader<MP::EI>,
    // One stage memory per stage, selected by a comptime `stage_idx`.
    pub stages: Sequence<StageMemory<MP::ES, TmaIm2colTiling>>,
    // Divisor used to split a k coordinate into (kernel-position index, channel start).
    padded_channels: FastDivmod,
    // Comptime-only marker tying the loader to its config type `G`.
    #[cube(comptime)]
    _config: PhantomData<G>,
}
29
#[cube]
impl<MP: MatmulPrecision, G: ConvGemmConfig> TmaIm2colLoader<MP, G> {
    /// Creates a loader holding `num_stages` Lhs stage memories.
    ///
    /// `x_offset` is decomposed with `div_mod_seq` over `runtime_args.out_shape` into an
    /// outer index (the reader's `n` offset) and per-dimension spatial offsets.
    /// `y_offset` is passed as the last argument of `Im2colTmaReader::new`; presumably it
    /// becomes the starting `k_offset` read in `fill_stage` (TODO confirm in the reader).
    pub fn new(
        tensor: VirtualTensor<MP::EI>,
        x_offset: u32,
        y_offset: u32,
        runtime_args: &RuntimeArgs,
        #[comptime] num_stages: u32,
        #[comptime] config: G,
    ) -> Self {
        let mut stages = Sequence::new();

        // One aligned stage memory per stage.
        // NOTE(review): 128 is the alignment given to `new_aligned`; presumably bytes, to
        // satisfy TMA destination alignment — confirm.
        #[unroll]
        for _ in 0..num_stages {
            stages.push(StageMemory::new_aligned::<G::StageConfig>(
                Ident::Lhs,
                128u32,
                config.stage_config(),
            ))
        }

        let (n_offs, spatial_offsets) = div_mod_seq(x_offset, &runtime_args.out_shape);

        let map = Im2colTmaReader::<MP::EI>::new(tensor, n_offs, spatial_offsets, y_offset);

        TmaIm2colLoader::<MP, G> {
            map,
            stages,
            padded_channels: runtime_args.padded_channels,
            _config: PhantomData::<G>,
        }
    }

    /// Issues the TMA im2col loads that fill stage `stage_idx` from the current view.
    ///
    /// Only the unit with `UNIT_POS == 0` issues loads. The stage is filled one k tile at a
    /// time (`tiles_in_stage_k` iterations). Each iteration loads a contiguous
    /// `elements_in_stage_m * elements_in_tile_k` slice with a 3d, 4d or 5d im2col load,
    /// chosen by `config.dimensionality()`. The loads are issued against `bar`; this
    /// function does not wait on it.
    pub fn fill_stage(
        this: &mut Self,
        bar: &Barrier<MP::ES>,
        #[comptime] stage_idx: u32,
        #[comptime] config: G,
    ) {
        let stage = this.stages.index_mut(stage_idx);

        // Loads are issued by a single unit.
        if UNIT_POS == 0 {
            let m_size = config.tiling_scheme().elements_in_stage_m();
            let k_size = config.tiling_scheme().elements_in_tile_k();
            // Number of elements written by one k-tile load.
            let slice_size = m_size * k_size;
            // NOTE(review): the `1u32` argument is presumably the line size — confirm.
            let mut full_stage = stage.as_slice_mut(1u32);
            // Unchecked cast of the global tensor for the TMA calls below.
            let tensor = this.map.tensor.try_cast_unchecked();

            let spatial_dims = comptime![this.map.spatial_offsets.len()];
            // Per-dimension input start: spatial offset * stride - padding.
            // Signed, because subtracting the padding can make it negative.
            let mut in_offs = Sequence::<i32>::new();

            #[unroll]
            for dim in 0..spatial_dims {
                // Turn the unrolled loop index into a comptime value for indexing/config calls.
                let dim = unwrap(dim);
                let offs = this.map.spatial_offsets.index(dim) * comptime![config.stride(dim)];
                let offs = offs as i32 - comptime![config.padding(dim)];
                in_offs.push(offs);
            }

            #[unroll]
            for tile_k in 0..config.tiling_scheme().tiles_in_stage_k() {
                let k = this.map.k_offset + tile_k * k_size;
                // Split k into a kernel-position index and the channel start within it.
                let (k_idx, channel_start) = this.padded_channels.div_mod(k);
                // Destination: the contiguous `slice_size` chunk for this k tile.
                // This shadows the outer `stage` binding.
                let slice_start = tile_k * slice_size;
                let mut stage = full_stage.slice_mut(slice_start, slice_start + slice_size);

                // `k_idx` is decomposed row-major over the kernel (last spatial dim fastest).
                // Each kernel coordinate times its dilation is passed as a u16 im2col offset.
                match config.dimensionality() {
                    Dimensionality::Dim1 => {
                        let offset = k_idx * config.dilation(0);

                        bar.tma_load_im2col_3d(
                            &tensor,
                            &mut stage,
                            this.map.n_offset as i32,
                            *in_offs.index(0),
                            channel_start as i32,
                            offset as u16,
                        );
                    }
                    Dimensionality::Dim2 => {
                        let (k_x, k_y) =
                            (k_idx % config.kernel_size(1), k_idx / config.kernel_size(1));

                        let offset_y = k_y * config.dilation(0);
                        let offset_x = k_x * config.dilation(1);

                        bar.tma_load_im2col_4d(
                            &tensor,
                            &mut stage,
                            this.map.n_offset as i32,
                            *in_offs.index(0),
                            *in_offs.index(1),
                            channel_start as i32,
                            offset_y as u16,
                            offset_x as u16,
                        );
                    }
                    Dimensionality::Dim3 => {
                        let (k_x, rem) =
                            (k_idx % config.kernel_size(2), k_idx / config.kernel_size(2));
                        let (k_y, k_z) = (rem % config.kernel_size(1), rem / config.kernel_size(1));

                        let offset_z = k_z * config.dilation(0);
                        let offset_y = k_y * config.dilation(1);
                        let offset_x = k_x * config.dilation(2);

                        bar.tma_load_im2col_5d(
                            &tensor,
                            &mut stage,
                            this.map.n_offset as i32,
                            *in_offs.index(0),
                            *in_offs.index(1),
                            *in_offs.index(2),
                            channel_start as i32,
                            offset_z as u16,
                            offset_y as u16,
                            offset_x as u16,
                        );
                    }
                }
            }
        }
    }

    /// Moves the global view by forwarding `k_offset` to the reader's `update_view`.
    pub fn advance_view(this: &mut Self, k_offset: u32) {
        this.map.update_view(k_offset);
    }

    /// Returns a tile reader over stage `stage_idx`, tagged as the Lhs input.
    pub fn reader(this: &Self, #[comptime] stage_idx: u32) -> TmaIm2colReader<MP> {
        TmaIm2colReader::<MP>::new(*this.stages.index(stage_idx), InputIdent::Lhs)
    }
}
162
163#[cube]
166pub(crate) fn div_mod_seq(pos: u32, shape: &Sequence<FastDivmod>) -> (u32, Sequence<u32>) {
167 let rank = comptime![shape.len()];
168 let mut offs = pos;
169 let mut out = Sequence::new();
170
171 #[unroll]
172 for i in 0..rank {
173 let i = unwrap(i);
174 let dim = comptime![rank - i - 1];
175 let (rem, offs_local) = shape.index(dim).div_mod(offs);
176 out.push(offs_local);
177 offs = rem;
178 }
179
180 (offs, out.rev())
181}
182
/// Converts a `u32` that must be a compile-time constant (e.g. an `#[unroll]` loop index)
/// into a comptime `u32`, so it can be used with `Sequence::index` and inside `comptime![...]`.
///
/// # Panics
/// When the kernel is expanded, if `v` is not a constant ("Must be constant").
// `v` is only read inside the `intrinsic!` closure. The allow presumably silences the
// unused-variable warning this causes outside the expanded form — NOTE(review): confirm.
#[allow(unused_variables)]
#[cube]
fn unwrap(v: u32) -> comptime_type!(u32) {
    intrinsic!(|_| v.constant().expect("Must be constant").as_u32())
}