// cubek_convolution/components/global/layout/tma_im2col.rs

1use cubecl::{
2    prelude::*,
3    std::{
4        FastDivmod,
5        tensor::layout::{CoordsDyn, Layout, LayoutExpand},
6    },
7};
8use cubek_matmul::launch::BatchedCoords;
9
10use crate::components::{
11    ConvolutionOperation, ConvolutionParams, ConvolutionProblem, global::layout::NhwcCoords,
12};
13
/// Im2col layout for TMA loads, producing both the source position and the kernel offsets.
///
/// Maps a matmul coordinate (see the `Layout` impl) to a pair of `NhwcCoords` (the NHWC
/// position of the receptive-field origin) and `CoordsDyn` (the per-dimension kernel offsets,
/// scaled by dilation). Launch arguments are built with `TmaIm2colLayoutLaunch::from_args`.
#[derive(CubeType, CubeLaunch)]
pub struct TmaIm2colLayout {
    // NOTE(review): `from_args_lhs`/`from_args_rhs` pass arguments to the derived
    // `TmaIm2colLayoutLaunch::new` positionally in this field order; reordering fields
    // presumably reorders that constructor too — keep them in sync.

    // Output spatial shape (one divisor per spatial dimension). Used to split the `m`
    // coordinate into a batch index and one output position per spatial dimension.
    shape_out: Sequence<FastDivmod<u32>>,
    // Padded channel count. Used to split the `k` coordinate into a linear kernel index
    // and a channel index.
    padded_channels: FastDivmod<u32>,
    // Extent reported by `Layout::shape` as `(1, rows, cols)`.
    rows: u32,
    cols: u32,
    // Compile-time convolution parameters (stride, padding, dilation, kernel size,
    // operation, dimensionality).
    #[cube(comptime)]
    params: ConvolutionParams,
    // When set, `k` positions beyond the kernel volume are redirected to a very large
    // channel index so the load's bounds checks treat them as out of range.
    #[cube(comptime)]
    check_kernel: bool,
}
26
27#[cube]
28impl TmaIm2colLayout {
29    pub fn new(
30        shape_out: Sequence<FastDivmod<u32>>,
31        padded_channels: FastDivmod<u32>,
32        rows: u32,
33        cols: u32,
34        #[comptime] params: ConvolutionParams,
35        #[comptime] check_kernel: bool,
36    ) -> Self {
37        TmaIm2colLayout {
38            shape_out,
39            padded_channels,
40            params,
41            check_kernel,
42            rows,
43            cols,
44        }
45    }
46}
47
// Im2col `Layout` for TMA loads. It maps a matmul coordinate `(batch, row, col)` to
// `(NhwcCoords, CoordsDyn)`: the NHWC position of the receptive-field origin plus the
// per-dimension kernel offsets. No bounds checking happens here (`is_in_bounds` always
// returns true). Out-of-kernel `k` positions are handled via `check_kernel` in
// `to_source_pos`.
#[cube]
impl Layout for TmaIm2colLayout {
    type Coordinates = BatchedCoords;
    type SourceCoordinates = (NhwcCoords, CoordsDyn);

    // Maps `(_, m, k)` to the NHWC source position and the kernel offsets.
    //
    // `m` is decomposed by `shape_out` into a batch index and one output position per
    // spatial dimension; the batch component of `pos` itself is ignored. `k` is decomposed
    // by `padded_channels` into a linear kernel index and a channel. The kernel index is
    // then decomposed into one position per kernel dimension.
    //
    // The spatial part of the returned `NhwcCoords` does NOT include the kernel offsets.
    // Those are returned separately, already multiplied by the dilation.
    // NOTE(review): presumably the TMA im2col load adds them to the spatial position —
    // confirm at the call site.
    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
        let (_, m, k) = pos;
        let params = self.params.comptime();

        // `n_offs` is the batch index; `spatial_offsets[d]` is the output position along
        // spatial dimension `d`.
        let (n_offs, spatial_offsets) = div_mod_seq(m, &self.shape_out);
        let spatial_dims = spatial_offsets.len();

        // Receptive-field origin in input space for each spatial dimension. It can be
        // negative because padding is subtracted, hence `i32`.
        let mut in_offs = Sequence::<i32>::new();

        #[unroll]
        for dim in 0..spatial_dims {
            let stride = params.stride[dim] as i32;
            let pad = params.padding[dim];
            let out_pos = spatial_offsets[dim] as i32;
            let offs = match params.operation {
                // Direct direction: origin = out_pos * stride - padding.
                ConvolutionOperation::Forward | ConvolutionOperation::BackwardWeight => {
                    out_pos * stride - pad
                }
                // Transposed direction: start from the position of the last (bottom-right)
                // kernel tap, so the flipped kernel offsets computed below are all
                // non-negative.
                // NOTE(review): `/` on i32 truncates toward zero, so a negative numerator
                // with `stride > 1` rounds up rather than down. The kernel offsets below are
                // also not divided by `stride`. Confirm this path only runs with stride 1.
                ConvolutionOperation::ForwardTransposed | ConvolutionOperation::BackwardData => {
                    let ksize = params.kernel_size[dim] as i32;
                    (out_pos + pad - ((ksize - 1) * params.dilation[dim] as i32)) / stride
                }
            };
            in_offs.push(offs);
        }

        // `k = k_idx * padded_channels + channel_start`: the quotient indexes the kernel
        // position, and the remainder is the starting channel.
        let (mut k_idx, channel_start) = self.padded_channels.div_mod(k);

        let mut pos = NhwcCoords {
            batch: n_offs,
            spatial: in_offs,
            channel: channel_start,
        };

        // Per-dimension kernel offsets (scaled by dilation). They are pushed innermost
        // dimension first and reversed on return.
        let mut k_offs = Sequence::new();
        let k_rank = params.dimensionality.num_dims();

        #[unroll]
        for i in 0..k_rank {
            let dim = k_rank - i - 1;
            let k_size = params.kernel_size[dim];
            let k_pos = k_idx % k_size;

            let k_pos = match params.operation {
                ConvolutionOperation::Forward | ConvolutionOperation::BackwardWeight => k_pos,
                ConvolutionOperation::ForwardTransposed | ConvolutionOperation::BackwardData => {
                    // Since kernels are always positive, we need to subtract the bottom right
                    // corner (see position above), then add the inverted index to it.
                    k_size - k_pos - 1
                }
            };
            k_offs.push(k_pos * params.dilation[dim]);
            k_idx /= k_size;
        }

        // A non-zero quotient left in `k_idx` means `k` lies past the full kernel volume
        // (presumably a tile overhanging the end of the K dimension). When `k_idx == 0` the
        // mask is 0 and `max` leaves the channel unchanged.
        if self.check_kernel.comptime() {
            // This is the largest index that's aligned to the channel count in all cases.
            // Alignment is 256, and that's the largest tile size possible with TMA.
            // Could alternatively solve this by only loading if in bounds, and adjusting the awaited
            // bytes by the in-bounds tiles but that's more complicated than just trying to load a very
            // large channel index and letting bounds checks handle it.
            let kernel_mask = (k_idx > 0) as u32 * 0x7FFFFF00u32;
            pos.channel = pos.channel.max(kernel_mask);
        }

        // Reverse so that `k_offs[d]` corresponds to kernel dimension `d`.
        (pos, k_offs.rev())
    }

    // Always reports in-bounds. Out-of-range reads are left to the load's own bounds checks
    // (presumably TMA hardware; see the `check_kernel` comment in `to_source_pos`).
    fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
        true.runtime()
    }

    // Logical extent `(1, rows, cols)`; the batch extent is always 1.
    fn shape(&self) -> Self::Coordinates {
        (1, self.rows, self.cols)
    }

    // Source position paired with `is_in_bounds`, which is always true for this layout.
    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
        (self.to_source_pos(pos), self.is_in_bounds(pos))
    }
}
133
134/// Decompose a linear index into local positions along each dimension in `shape`. Also returns the
135/// left over remainder.
136#[cube]
137pub(crate) fn div_mod_seq(pos: u32, shape: &Sequence<FastDivmod<u32>>) -> (u32, Sequence<u32>) {
138    let rank = shape.len().comptime();
139    let mut offs = pos;
140    let mut out = Sequence::new();
141
142    #[unroll]
143    for i in 0..rank {
144        let dim = rank - i - 1;
145        let (rem, offs_local) = shape[dim].div_mod(offs);
146        out.push(offs_local);
147        offs = rem;
148    }
149
150    (offs, out.rev())
151}
152
153impl<R: Runtime> TmaIm2colLayoutLaunch<R> {
154    pub fn from_args(problem: &ConvolutionProblem, check_kernel: bool) -> Self {
155        let shape_out = problem.out_shape.iter().map(|it| *it as u32).collect();
156
157        let padded_channels = problem.padded_channels as u32;
158        let params = ConvolutionParams::from_problem(problem);
159
160        match problem.operation {
161            ConvolutionOperation::Forward
162            | ConvolutionOperation::ForwardTransposed
163            | ConvolutionOperation::BackwardData => {
164                Self::from_args_lhs(problem, shape_out, padded_channels, params, check_kernel)
165            }
166            ConvolutionOperation::BackwardWeight => {
167                Self::from_args_rhs(problem, shape_out, padded_channels, params, check_kernel)
168            }
169        }
170    }
171
172    fn from_args_lhs(
173        problem: &ConvolutionProblem,
174        shape_out: SequenceArg<R, FastDivmod<u32>>,
175        padded_channels: u32,
176        params: ConvolutionParams,
177        check_kernel: bool,
178    ) -> Self {
179        let shape_m = problem.m as u32;
180        let shape_k = problem.k as u32;
181
182        TmaIm2colLayoutLaunch::new(
183            shape_out,
184            padded_channels,
185            shape_m,
186            shape_k,
187            params,
188            check_kernel,
189        )
190    }
191
192    fn from_args_rhs(
193        problem: &ConvolutionProblem,
194        shape_out: SequenceArg<R, FastDivmod<u32>>,
195        padded_channels: u32,
196        params: ConvolutionParams,
197        check_kernel: bool,
198    ) -> Self {
199        let shape_k = problem.k as u32;
200        let shape_n = problem.n as u32;
201
202        TmaIm2colLayoutLaunch::new(
203            shape_out,
204            padded_channels,
205            shape_k,
206            shape_n,
207            params,
208            check_kernel,
209        )
210    }
211}