cubecl_convolution/components/global/read/reader/
im2col_tma.rs1use cubecl_core::prelude::*;
2use cubecl_core::{self as cubecl, prelude::barrier::Barrier};
3
4use cubecl_matmul::components::{MatrixPrecision, stage::StageMemoryConfig};
5use cubecl_std::FastDivmod;
6
7use crate::{
8 components::{ConvolutionParams, Dimensionality, global::memory::Im2colTmaReader},
9 kernels::layered::selector::RuntimeArgs,
10};
11use cubecl_matmul::components::stage::{
12 ColMajorTilingOrder, ContiguousTilingLayout, StridedStageMemory,
13};
14
15pub type TmaIm2colTiling = ContiguousTilingLayout<ColMajorTilingOrder>;
16pub type TmaIm2colStage<IP> = StridedStageMemory<<IP as MatrixPrecision>::Stage, TmaIm2colTiling>;
17
18#[derive(CubeType)]
20pub struct TmaIm2colGlobalReader<IP: MatrixPrecision> {
21 pub map: Im2colTmaReader<IP::Global>,
22 pub stages: Sequence<StridedStageMemory<IP::Stage, TmaIm2colTiling>>,
23 padded_channels: FastDivmod,
24 #[cube(comptime)]
25 params: ConvolutionParams,
26 #[cube(comptime)]
27 config: StageMemoryConfig,
28}
29
30#[cube]
31impl<IP: MatrixPrecision> TmaIm2colGlobalReader<IP> {
32 pub fn new(
33 tensor: TensorMap<Line<IP::Global>>,
34 x_offset: u32,
35 y_offset: u32,
36 runtime_args: &RuntimeArgs,
37 #[comptime] num_stages: u32,
38 #[comptime] params: ConvolutionParams,
39 #[comptime] config: StageMemoryConfig,
40 ) -> Self {
41 let mut stages = Sequence::new();
42
43 #[unroll]
44 for _ in 0..num_stages {
45 stages.push(StridedStageMemory::new_aligned(128u32, config))
46 }
47
48 let (n_offs, spatial_offsets) = div_mod_seq(x_offset, &runtime_args.shape_out);
49
50 let map = Im2colTmaReader::<IP::Global>::new(tensor, n_offs, spatial_offsets, y_offset);
51
52 TmaIm2colGlobalReader::<IP> {
53 map,
54 stages,
55 padded_channels: runtime_args.padded_channels,
56 params,
57 config,
58 }
59 }
60
61 pub fn fill_stage(&mut self, bar: &Barrier, #[comptime] stage_idx: u32) {
62 let stage = self.stages.index_mut(stage_idx);
63 let params = comptime![self.params];
64 let config = comptime![self.config];
65
66 if UNIT_POS == 0 {
67 let m_size = config.elements_per_stage_along_row();
68 let k_size = config.elements_per_tile_along_col;
69 let slice_size = m_size * k_size;
70 let mut full_stage = stage.as_slice_mut(1u32);
71 let tensor = self.map.tensor.try_cast_unchecked();
72
73 let spatial_dims = comptime![self.map.spatial_offsets.len()];
74 let mut in_offs = Sequence::<i32>::new();
75
76 #[unroll]
77 for dim in 0..spatial_dims {
78 let offs =
79 self.map.spatial_offsets.index(dim) * comptime![params.stride[dim as usize]];
80 let offs = offs as i32 - comptime![params.padding[dim as usize]];
81 in_offs.push(offs);
82 }
83
84 #[unroll]
85 for tile_k in 0..config.tiles_per_stage_along_col() {
86 let k = self.map.k_offset + tile_k * k_size;
87 let (k_idx, channel_start) = self.padded_channels.div_mod(k);
88 let slice_start = tile_k * slice_size;
89 let mut stage = full_stage.slice_mut(slice_start, slice_start + slice_size);
90
91 match params.dimensionality {
92 Dimensionality::Dim1 => {
93 let offset = k_idx * comptime![params.dilation[0]];
94
95 bar.tma_load_im2col_3d(
96 &tensor,
97 &mut stage,
98 self.map.n_offset as i32,
99 *in_offs.index(0),
100 channel_start as i32,
101 offset as u16,
102 );
103 }
104 Dimensionality::Dim2 => {
105 let (k_x, k_y) = (
106 k_idx % comptime![params.kernel_size[1]],
107 k_idx / comptime![params.kernel_size[1]],
108 );
109
110 let offset_y = k_y * comptime![params.dilation[0]];
111 let offset_x = k_x * comptime![params.dilation[1]];
112
113 bar.tma_load_im2col_4d(
114 &tensor,
115 &mut stage,
116 self.map.n_offset as i32,
117 *in_offs.index(0),
118 *in_offs.index(1),
119 channel_start as i32,
120 offset_y as u16,
121 offset_x as u16,
122 );
123 }
124 Dimensionality::Dim3 => {
125 let (k_x, rem) = (
126 k_idx % comptime![params.kernel_size[2]],
127 k_idx / comptime![params.kernel_size[2]],
128 );
129 let (k_y, k_z) = (
130 rem % comptime![params.kernel_size[1]],
131 rem / comptime![params.kernel_size[1]],
132 );
133
134 let offset_z = k_z * comptime![params.dilation[0]];
135 let offset_y = k_y * comptime![params.dilation[1]];
136 let offset_x = k_x * comptime![params.dilation[2]];
137
138 bar.tma_load_im2col_5d(
139 &tensor,
140 &mut stage,
141 self.map.n_offset as i32,
142 *in_offs.index(0),
143 *in_offs.index(1),
144 *in_offs.index(2),
145 channel_start as i32,
146 offset_z as u16,
147 offset_y as u16,
148 offset_x as u16,
149 );
150 }
151 }
152 }
153 }
154 }
155
156 pub fn advance_view(&mut self, k_offset: u32) {
157 self.map.update_view(k_offset);
158 }
159
160 pub fn stage(&self, #[comptime] stage_idx: u32) -> TmaIm2colStage<IP> {
161 *self.stages.index(stage_idx)
162 }
163}
164
165#[cube]
168pub(crate) fn div_mod_seq(pos: u32, shape: &Sequence<FastDivmod>) -> (u32, Sequence<u32>) {
169 let rank = comptime![shape.len()];
170 let mut offs = pos;
171 let mut out = Sequence::new();
172
173 #[unroll]
174 for i in 0..rank {
175 let dim = comptime![rank - i - 1];
176 let (rem, offs_local) = shape.index(dim).div_mod(offs);
177 out.push(offs_local);
178 offs = rem;
179 }
180
181 (offs, out.rev())
182}