// cubek_convolution/components/global/layout/tma_im2col.rs
use cubecl::{
    prelude::*,
    std::{
        FastDivmod, FastDivmodArgs,
        tensor::layout::{CoordsDyn, Layout, LayoutExpand},
    },
};
use cubek_matmul::launch::BatchedCoords;

use crate::components::{
    ConvolutionOperation, ConvolutionParams, ConvolutionProblem, global::layout::NhwcCoords,
};
13
/// Layout describing an implicit im2col ("image to column") view of a
/// convolution tensor for TMA-based loads.
///
/// Maps batched matmul coordinates `(batch, row, col)` onto NHWC tensor
/// coordinates plus per-dimension kernel offsets (see the [`Layout`] impl
/// below).
#[derive(CubeType, CubeLaunch)]
pub struct TmaIm2colLayout {
    /// Fast integer div/mod helpers, one per output spatial dimension,
    /// used to split the matmul row index into batch + spatial position.
    shape_out: Sequence<FastDivmod<u32>>,
    /// Fast div/mod over the (padded) channel count, used to split the
    /// matmul `k` index into (kernel position, channel).
    padded_channels: FastDivmod<u32>,
    /// Row extent of the virtual im2col matrix.
    rows: u32,
    /// Column extent of the virtual im2col matrix.
    cols: u32,
    /// Compile-time convolution parameters (stride, padding, dilation,
    /// kernel size, dimensionality, operation kind).
    #[cube(comptime)]
    params: ConvolutionParams,
    /// When `true`, `to_source_pos` masks `k` indices that fall past the
    /// real kernel volume by forcing the channel to a large sentinel.
    #[cube(comptime)]
    check_kernel: bool,
}
26
27#[cube]
28impl TmaIm2colLayout {
29 pub fn new(
30 shape_out: Sequence<FastDivmod<u32>>,
31 padded_channels: FastDivmod<u32>,
32 rows: u32,
33 cols: u32,
34 #[comptime] params: ConvolutionParams,
35 #[comptime] check_kernel: bool,
36 ) -> Self {
37 TmaIm2colLayout {
38 shape_out,
39 padded_channels,
40 params,
41 check_kernel,
42 rows,
43 cols,
44 }
45 }
46}
47
#[cube]
impl Layout for TmaIm2colLayout {
    // Input side: batched matmul coordinates (batch, m, k).
    type Coordinates = BatchedCoords;
    // Output side: an NHWC tensor position plus per-dimension kernel offsets.
    type SourceCoordinates = (NhwcCoords, CoordsDyn);

    /// Translates a matmul coordinate into an NHWC base position and a set
    /// of kernel offsets for the TMA im2col load.
    ///
    /// `m` encodes (batch, output spatial position); `k` encodes
    /// (kernel position, channel). The batch component of `pos` is ignored
    /// here (the layout itself reports a batch extent of 1 in `shape`).
    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
        let (_, m, k) = pos;
        let params = self.params.comptime();

        // Split `m` into the batch index (the leftover) and one output
        // coordinate per spatial dimension.
        let (n_offs, spatial_offsets) = div_mod_seq(m, &self.shape_out);
        let spatial_dims = spatial_offsets.len();

        let mut in_offs = Sequence::<i32>::new();

        // Map each output spatial coordinate to the corresponding input
        // base coordinate. Signed arithmetic: padding can make this negative
        // (presumably resolved by out-of-bounds handling downstream — the
        // TMA load path; confirm).
        #[unroll]
        for dim in 0..spatial_dims {
            let stride = params.stride[dim] as i32;
            let pad = params.padding[dim];
            let out_pos = spatial_offsets[dim] as i32;
            let offs = match params.operation {
                // Direct convolution: in = out * stride - pad.
                ConvolutionOperation::Forward | ConvolutionOperation::BackwardWeight => {
                    out_pos * stride - pad
                }
                // Transposed/data-gradient: shift by the full dilated kernel
                // extent before dividing by stride.
                ConvolutionOperation::ForwardTransposed | ConvolutionOperation::BackwardData => {
                    let ksize = params.kernel_size[dim] as i32;
                    (out_pos + pad - ((ksize - 1) * params.dilation[dim] as i32)) / stride
                }
            };
            in_offs.push(offs);
        }

        // Split `k` into the kernel-position index (quotient, consumed digit
        // by digit below) and the starting channel (remainder).
        let (mut k_idx, channel_start) = self.padded_channels.div_mod(k);

        let mut pos = NhwcCoords {
            batch: n_offs,
            spatial: in_offs,
            channel: channel_start,
        };

        let mut k_offs = Sequence::new();
        let k_rank = params.dimensionality.num_dims();

        // Decode `k_idx` as a mixed-radix number whose digits are the kernel
        // positions per dimension, innermost (fastest-varying) first.
        #[unroll]
        for i in 0..k_rank {
            let dim = k_rank - i - 1;
            let k_size = params.kernel_size[dim];
            let k_pos = k_idx % k_size;

            let k_pos = match params.operation {
                ConvolutionOperation::Forward | ConvolutionOperation::BackwardWeight => k_pos,
                // Transposed/data-gradient traverses the kernel mirrored.
                ConvolutionOperation::ForwardTransposed | ConvolutionOperation::BackwardData => {
                    k_size - k_pos - 1
                }
            };
            // Scale by dilation to get the actual element offset.
            k_offs.push(k_pos * params.dilation[dim]);
            k_idx /= k_size;
        }

        if self.check_kernel.comptime() {
            // If any quotient is left over, `k` pointed past the real kernel
            // volume (padded k range). Force the channel to a huge sentinel
            // so the access lands out of bounds — presumably relying on the
            // TMA loader to substitute zeros for OOB reads (confirm).
            let kernel_mask = (k_idx > 0) as u32 * 0x7FFFFF00u32;
            pos.channel = pos.channel.max(kernel_mask);
        }

        // Digits were pushed innermost-first; reverse to match the tensor's
        // dimension order.
        (pos, k_offs.rev())
    }

    /// Always in bounds from this layout's perspective; bounds are handled
    /// at the source-coordinate level (see the sentinel masking above).
    fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
        true.runtime()
    }

    /// Extent of the virtual im2col matrix: a single batch of rows x cols.
    fn shape(&self) -> Self::Coordinates {
        (1, self.rows, self.cols)
    }

    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
        (self.to_source_pos(pos), self.is_in_bounds(pos))
    }
}
133
134#[cube]
137pub(crate) fn div_mod_seq(pos: u32, shape: &Sequence<FastDivmod<u32>>) -> (u32, Sequence<u32>) {
138 let rank = shape.len().comptime();
139 let mut offs = pos;
140 let mut out = Sequence::new();
141
142 #[unroll]
143 for i in 0..rank {
144 let dim = rank - i - 1;
145 let (rem, offs_local) = shape[dim].div_mod(offs);
146 out.push(offs_local);
147 offs = rem;
148 }
149
150 (offs, out.rev())
151}
152
153impl<'a, R: Runtime> TmaIm2colLayoutLaunch<'a, R> {
154 pub fn from_args(
155 client: &ComputeClient<R>,
156 problem: &ConvolutionProblem,
157 check_kernel: bool,
158 ) -> Self {
159 let shape_out = problem
160 .out_shape
161 .iter()
162 .map(|it| FastDivmodArgs::<u32>::new(client, *it as u32))
163 .collect();
164
165 let padded_channels = problem.padded_channels as u32;
166 let padded_channels = FastDivmodArgs::<u32>::new(client, padded_channels);
167 let params = ConvolutionParams::from_problem(problem);
168
169 match problem.operation {
170 ConvolutionOperation::Forward
171 | ConvolutionOperation::ForwardTransposed
172 | ConvolutionOperation::BackwardData => {
173 Self::from_args_lhs(problem, shape_out, padded_channels, params, check_kernel)
174 }
175 ConvolutionOperation::BackwardWeight => {
176 Self::from_args_rhs(problem, shape_out, padded_channels, params, check_kernel)
177 }
178 }
179 }
180
181 fn from_args_lhs(
182 problem: &ConvolutionProblem,
183 shape_out: SequenceArg<'a, R, FastDivmod<u32>>,
184 padded_channels: FastDivmodArgs<u32>,
185 params: ConvolutionParams,
186 check_kernel: bool,
187 ) -> Self {
188 let shape_m = ScalarArg::new(problem.m as u32);
189 let shape_k = ScalarArg::new(problem.k as u32);
190
191 TmaIm2colLayoutLaunch::new(
192 shape_out,
193 padded_channels,
194 shape_m,
195 shape_k,
196 params,
197 check_kernel,
198 )
199 }
200
201 fn from_args_rhs(
202 problem: &ConvolutionProblem,
203 shape_out: SequenceArg<'a, R, FastDivmod<u32>>,
204 padded_channels: FastDivmodArgs<u32>,
205 params: ConvolutionParams,
206 check_kernel: bool,
207 ) -> Self {
208 let shape_k = ScalarArg::new(problem.k as u32);
209 let shape_n = ScalarArg::new(problem.n as u32);
210
211 TmaIm2colLayoutLaunch::new(
212 shape_out,
213 padded_channels,
214 shape_k,
215 shape_n,
216 params,
217 check_kernel,
218 )
219 }
220}