cubecl_convolution/components/global/read/reader/
im2col_tma.rs1use cubecl_core::prelude::*;
2use cubecl_core::{self as cubecl, prelude::barrier::Barrier};
3
4use cubecl_matmul::components::{MatrixPrecision, StageIdent, stage::StageMemoryConfig};
5use cubecl_std::FastDivmod;
6
7use crate::{
8 components::{ConvolutionParams, Dimensionality, global::memory::Im2colTmaReader},
9 kernels::layered::selector::RuntimeArgs,
10};
11use cubecl_matmul::components::stage::{ColMajorTilingOrder, ContiguousTilingLayout, StridedStage};
12
13pub type TmaIm2colTiling = ContiguousTilingLayout<ColMajorTilingOrder>;
14pub type TmaIm2colStage<IP> = StridedStage<<IP as MatrixPrecision>::Stage, TmaIm2colTiling>;
15
16#[derive(CubeType)]
18pub struct TmaIm2colGlobalReader<IP: MatrixPrecision> {
19 pub map: Im2colTmaReader<IP::Global>,
20 pub stages: Sequence<StridedStage<IP::Stage, TmaIm2colTiling>>,
21 padded_channels: FastDivmod,
22 #[cube(comptime)]
23 params: ConvolutionParams,
24 #[cube(comptime)]
25 config: StageMemoryConfig,
26}
27
28#[cube]
29impl<IP: MatrixPrecision> TmaIm2colGlobalReader<IP> {
30 pub fn new(
31 tensor: TensorMap<Line<IP::Global>>,
32 x_offset: u32,
33 y_offset: u32,
34 runtime_args: &RuntimeArgs,
35 #[comptime] num_stages: u32,
36 #[comptime] params: ConvolutionParams,
37 #[comptime] config: StageMemoryConfig,
38 ) -> Self {
39 let mut stages = Sequence::new();
40
41 #[unroll]
42 for _ in 0..num_stages {
43 stages.push(StridedStage::new_aligned(StageIdent::Lhs, 128u32, config))
44 }
45
46 let (n_offs, spatial_offsets) = div_mod_seq(x_offset, &runtime_args.shape_out);
47
48 let map = Im2colTmaReader::<IP::Global>::new(tensor, n_offs, spatial_offsets, y_offset);
49
50 TmaIm2colGlobalReader::<IP> {
51 map,
52 stages,
53 padded_channels: runtime_args.padded_channels,
54 params,
55 config,
56 }
57 }
58
59 pub fn fill_stage(&mut self, bar: &Barrier, #[comptime] stage_idx: u32) {
60 let stage = self.stages.index_mut(stage_idx);
61 let params = comptime![self.params];
62 let config = comptime![self.config];
63
64 if UNIT_POS == 0 {
65 let m_size = config.elements_in_stage_row();
66 let k_size = config.elements_in_tile_col;
67 let slice_size = m_size * k_size;
68 let mut full_stage = stage.as_slice_mut(1u32);
69 let tensor = self.map.tensor.try_cast_unchecked();
70
71 let spatial_dims = comptime![self.map.spatial_offsets.len()];
72 let mut in_offs = Sequence::<i32>::new();
73
74 #[unroll]
75 for dim in 0..spatial_dims {
76 let offs =
77 self.map.spatial_offsets.index(dim) * comptime![params.stride[dim as usize]];
78 let offs = offs as i32 - comptime![params.padding[dim as usize]];
79 in_offs.push(offs);
80 }
81
82 #[unroll]
83 for tile_k in 0..config.tiles_in_stage_col {
84 let k = self.map.k_offset + tile_k * k_size;
85 let (k_idx, channel_start) = self.padded_channels.div_mod(k);
86 let slice_start = tile_k * slice_size;
87 let mut stage = full_stage.slice_mut(slice_start, slice_start + slice_size);
88
89 match params.dimensionality {
90 Dimensionality::Dim1 => {
91 let offset = k_idx * comptime![params.dilation[0]];
92
93 bar.tma_load_im2col_3d(
94 &tensor,
95 &mut stage,
96 self.map.n_offset as i32,
97 *in_offs.index(0),
98 channel_start as i32,
99 offset as u16,
100 );
101 }
102 Dimensionality::Dim2 => {
103 let (k_x, k_y) = (
104 k_idx % comptime![params.kernel_size[1]],
105 k_idx / comptime![params.kernel_size[1]],
106 );
107
108 let offset_y = k_y * comptime![params.dilation[0]];
109 let offset_x = k_x * comptime![params.dilation[1]];
110
111 bar.tma_load_im2col_4d(
112 &tensor,
113 &mut stage,
114 self.map.n_offset as i32,
115 *in_offs.index(0),
116 *in_offs.index(1),
117 channel_start as i32,
118 offset_y as u16,
119 offset_x as u16,
120 );
121 }
122 Dimensionality::Dim3 => {
123 let (k_x, rem) = (
124 k_idx % comptime![params.kernel_size[2]],
125 k_idx / comptime![params.kernel_size[2]],
126 );
127 let (k_y, k_z) = (
128 rem % comptime![params.kernel_size[1]],
129 rem / comptime![params.kernel_size[1]],
130 );
131
132 let offset_z = k_z * comptime![params.dilation[0]];
133 let offset_y = k_y * comptime![params.dilation[1]];
134 let offset_x = k_x * comptime![params.dilation[2]];
135
136 bar.tma_load_im2col_5d(
137 &tensor,
138 &mut stage,
139 self.map.n_offset as i32,
140 *in_offs.index(0),
141 *in_offs.index(1),
142 *in_offs.index(2),
143 channel_start as i32,
144 offset_z as u16,
145 offset_y as u16,
146 offset_x as u16,
147 );
148 }
149 }
150 }
151 }
152 }
153
154 pub fn advance_view(&mut self, k_offset: u32) {
155 self.map.update_view(k_offset);
156 }
157
158 pub fn stage(&self, #[comptime] stage_idx: u32) -> TmaIm2colStage<IP> {
159 *self.stages.index(stage_idx)
160 }
161}
162
163#[cube]
166pub(crate) fn div_mod_seq(pos: u32, shape: &Sequence<FastDivmod>) -> (u32, Sequence<u32>) {
167 let rank = comptime![shape.len()];
168 let mut offs = pos;
169 let mut out = Sequence::new();
170
171 #[unroll]
172 for i in 0..rank {
173 let dim = comptime![rank - i - 1];
174 let (rem, offs_local) = shape.index(dim).div_mod(offs);
175 out.push(offs_local);
176 offs = rem;
177 }
178
179 (offs, out.rev())
180}