// cubek_convolution/components/global/layout/tma_im2col.rs
use cubecl::{
2 prelude::*,
3 std::{
4 FastDivmod,
5 tensor::layout::{CoordsDyn, Layout, LayoutExpand},
6 },
7};
8use cubek_matmul::launch::BatchedCoords;
9
10use crate::components::{
11 ConvolutionOperation, ConvolutionParams, ConvolutionProblem, global::layout::NhwcCoords,
12};
13
/// Layout mapping batched `(batch, m, k)` matmul coordinates onto an NHWC
/// input tensor plus per-dimension kernel offsets, for TMA-based im2col loads.
#[derive(CubeType, CubeLaunch)]
pub struct TmaIm2colLayout {
    // Fast divisors for each spatial output dimension; used to decompose the
    // row index `m` into batch and per-dimension output positions.
    shape_out: Sequence<FastDivmod<u32>>,
    // Fast divisor for the padded channel count; splits the column index `k`
    // into a flattened kernel-volume index and a channel start.
    padded_channels: FastDivmod<u32>,
    // Row extent of the logical matrix reported by `Layout::shape`.
    rows: u32,
    // Column extent of the logical matrix reported by `Layout::shape`.
    cols: u32,
    // Compile-time convolution parameters (stride, padding, dilation,
    // kernel size, operation kind, dimensionality).
    #[cube(comptime)]
    params: ConvolutionParams,
    // When set, columns addressing kernel positions past the real kernel
    // volume get their channel coordinate forced out of bounds.
    #[cube(comptime)]
    check_kernel: bool,
}
26
#[cube]
impl TmaIm2colLayout {
    /// Creates a new TMA im2col layout.
    ///
    /// `shape_out` holds one fast divisor per spatial output dimension,
    /// `padded_channels` splits the `k` index into kernel position and
    /// channel, and `rows`/`cols` are the logical matrix extents returned
    /// by `Layout::shape`.
    pub fn new(
        shape_out: Sequence<FastDivmod<u32>>,
        padded_channels: FastDivmod<u32>,
        rows: u32,
        cols: u32,
        #[comptime] params: ConvolutionParams,
        #[comptime] check_kernel: bool,
    ) -> Self {
        TmaIm2colLayout {
            shape_out,
            padded_channels,
            params,
            check_kernel,
            rows,
            cols,
        }
    }
}
47
#[cube]
impl Layout for TmaIm2colLayout {
    // Batched matmul coordinates: (batch, row `m`, column `k`).
    type Coordinates = BatchedCoords;
    // Source position: a base NHWC coordinate plus per-dimension kernel
    // offsets (the second element feeds the TMA im2col descriptor).
    type SourceCoordinates = (NhwcCoords, CoordsDyn);

    /// Translates a `(batch, m, k)` matmul coordinate into an NHWC base
    /// position and per-dimension kernel offsets for the im2col gather.
    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
        // The incoming batch component is ignored; the batch index is
        // recovered from `m` via `shape_out` below.
        let (_, m, k) = pos;
        let params = self.params.comptime();

        // Decompose the row index into a batch index and one output
        // position per spatial dimension.
        let (n_offs, spatial_offsets) = div_mod_seq(m, &self.shape_out);
        let spatial_dims = spatial_offsets.len();

        let mut in_offs = Sequence::<i32>::new();

        #[unroll]
        for dim in 0..spatial_dims {
            let stride = params.stride[dim] as i32;
            let pad = params.padding[dim];
            let out_pos = spatial_offsets[dim] as i32;
            let offs = match params.operation {
                // Standard im2col base: input = out * stride - padding
                // (may be negative inside the padded border).
                ConvolutionOperation::Forward | ConvolutionOperation::BackwardWeight => {
                    out_pos * stride - pad
                }
                // Transposed / backward-data: invert the forward mapping so
                // the (mirrored) kernel window lands on the input position.
                // NOTE(review): `/ stride` truncates toward zero for negative
                // numerators, not floor-division — assumed intentional here
                // (masked by the hardware bounds check?); confirm.
                ConvolutionOperation::ForwardTransposed | ConvolutionOperation::BackwardData => {
                    let ksize = params.kernel_size[dim] as i32;
                    (out_pos + pad - ((ksize - 1) * params.dilation[dim] as i32)) / stride
                }
            };
            in_offs.push(offs);
        }

        // Split the column index into a flattened kernel-volume index
        // (quotient) and the starting channel (remainder).
        let (mut k_idx, channel_start) = self.padded_channels.div_mod(k);

        let mut pos = NhwcCoords {
            batch: n_offs,
            spatial: in_offs,
            channel: channel_start,
        };

        // Decode `k_idx` into one kernel offset per dimension, peeling the
        // innermost dimension first; the sequence is reversed at the end so
        // offsets come out in dimension order.
        let mut k_offs = Sequence::new();
        let k_rank = params.dimensionality.num_dims();

        #[unroll]
        for i in 0..k_rank {
            let dim = k_rank - i - 1;
            let k_size = params.kernel_size[dim];
            let k_pos = k_idx % k_size;

            let k_pos = match params.operation {
                ConvolutionOperation::Forward | ConvolutionOperation::BackwardWeight => k_pos,
                // Transposed / backward-data traverse the kernel mirrored.
                ConvolutionOperation::ForwardTransposed | ConvolutionOperation::BackwardData => {
                    k_size - k_pos - 1
                }
            };
            // Scale by dilation to get the actual input-space offset.
            k_offs.push(k_pos * params.dilation[dim]);
            k_idx /= k_size;
        }

        if self.check_kernel.comptime() {
            // Any residue left in `k_idx` means `k` addressed a kernel
            // position beyond the real kernel volume (k-dimension padding).
            // Force the channel far out of bounds (0x7FFFFF00) so the load
            // is rejected — presumably the TMA bounds check then fills
            // zeroes; confirm against the copy implementation.
            let kernel_mask = (k_idx > 0) as u32 * 0x7FFFFF00u32;
            pos.channel = pos.channel.max(kernel_mask);
        }

        (pos, k_offs.rev())
    }

    /// Always reports in-bounds: out-of-range accesses are handled by the
    /// TMA hardware bounds check rather than by this layout.
    fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
        true.runtime()
    }

    /// Logical shape of the layout: a single batch of `rows x cols`.
    fn shape(&self) -> Self::Coordinates {
        (1, self.rows, self.cols)
    }

    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
        (self.to_source_pos(pos), self.is_in_bounds(pos))
    }
}
133
/// Decomposes a flattened index `pos` into per-dimension coordinates using
/// the fast divisors in `shape`.
///
/// Returns the leftover quotient above all given dimensions (e.g. the batch
/// index when `shape` covers the spatial dims) together with the
/// per-dimension offsets, ordered the same as `shape`.
#[cube]
pub(crate) fn div_mod_seq(pos: u32, shape: &Sequence<FastDivmod<u32>>) -> (u32, Sequence<u32>) {
    let rank = shape.len().comptime();
    let mut offs = pos;
    let mut out = Sequence::new();

    // Peel dimensions innermost-first (remainder is the coordinate, quotient
    // carries to the next dimension), then reverse to restore shape order.
    #[unroll]
    for i in 0..rank {
        let dim = rank - i - 1;
        let (rem, offs_local) = shape[dim].div_mod(offs);
        out.push(offs_local);
        offs = rem;
    }

    (offs, out.rev())
}
152
153impl<R: Runtime> TmaIm2colLayoutLaunch<R> {
154 pub fn from_args(problem: &ConvolutionProblem, check_kernel: bool) -> Self {
155 let shape_out = problem.out_shape.iter().map(|it| *it as u32).collect();
156
157 let padded_channels = problem.padded_channels as u32;
158 let params = ConvolutionParams::from_problem(problem);
159
160 match problem.operation {
161 ConvolutionOperation::Forward
162 | ConvolutionOperation::ForwardTransposed
163 | ConvolutionOperation::BackwardData => {
164 Self::from_args_lhs(problem, shape_out, padded_channels, params, check_kernel)
165 }
166 ConvolutionOperation::BackwardWeight => {
167 Self::from_args_rhs(problem, shape_out, padded_channels, params, check_kernel)
168 }
169 }
170 }
171
172 fn from_args_lhs(
173 problem: &ConvolutionProblem,
174 shape_out: SequenceArg<R, FastDivmod<u32>>,
175 padded_channels: u32,
176 params: ConvolutionParams,
177 check_kernel: bool,
178 ) -> Self {
179 let shape_m = problem.m as u32;
180 let shape_k = problem.k as u32;
181
182 TmaIm2colLayoutLaunch::new(
183 shape_out,
184 padded_channels,
185 shape_m,
186 shape_k,
187 params,
188 check_kernel,
189 )
190 }
191
192 fn from_args_rhs(
193 problem: &ConvolutionProblem,
194 shape_out: SequenceArg<R, FastDivmod<u32>>,
195 padded_channels: u32,
196 params: ConvolutionParams,
197 check_kernel: bool,
198 ) -> Self {
199 let shape_k = problem.k as u32;
200 let shape_n = problem.n as u32;
201
202 TmaIm2colLayoutLaunch::new(
203 shape_out,
204 padded_channels,
205 shape_k,
206 shape_n,
207 params,
208 check_kernel,
209 )
210 }
211}