cubek_convolution/components/global/layout/tma_im2col.rs

use cubecl::{
    prelude::*,
    std::{
        FastDivmod,
        tensor::layout::{CoordsDyn, Layout, LayoutExpand},
    },
};
use cubek_matmul::launch::BatchedCoords;

use crate::components::{ConvolutionOperation, ConvolutionParams, global::layout::NhwcCoords};

/// Im2col layout for TMA loads, mapping matmul coordinates to both the base input position and
/// the per-dimension kernel offsets.
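///
/// For example (illustrative values): take a 2D forward convolution with a 3x3 kernel, stride 1,
/// no padding, dilation 1, a 16x16 output, and 64 padded channels. The row index decomposes as
/// `m = ((n * 16) + oh) * 16 + ow` and the column index as `k = ((kh * 3) + kw) * 64 + c`, so
/// `m = 300`, `k = 150` yields `n = 1`, `(oh, ow) = (2, 12)`, `(kh, kw) = (0, 2)`, `c = 22`:
/// the base position is `NhwcCoords { batch: 1, spatial: [2, 12], channel: 22 }` and the kernel
/// offsets are `(0, 2)`.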
#[derive(CubeType, CubeLaunch)]
pub struct TmaIm2colLayout {
    /// Output spatial shape, one divisor per spatial dimension, used to decompose the `m` index.
    shape_out: Sequence<FastDivmod<u32>>,
    /// Padded channel count, used to decompose the `k` index into kernel index and channel offset.
    padded_channels: FastDivmod<u32>,
    /// Comptime convolution parameters (operation, strides, padding, dilation, kernel size).
    #[cube(comptime)]
    params: ConvolutionParams,
    /// Whether to guard against out-of-range kernel indices by clamping the channel index.
    #[cube(comptime)]
    check_kernel: bool,
}

#[cube]
impl TmaIm2colLayout {
    pub fn new(
        shape_out: Sequence<FastDivmod<u32>>,
        padded_channels: FastDivmod<u32>,
        #[comptime] params: ConvolutionParams,
        #[comptime] check_kernel: bool,
    ) -> Self {
        TmaIm2colLayout {
            shape_out,
            padded_channels,
            params,
            check_kernel,
        }
    }
}

#[cube]
impl Layout for TmaIm2colLayout {
    type Coordinates = BatchedCoords;
    type SourceCoordinates = (NhwcCoords, CoordsDyn);

    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
        let (_, m, k) = pos;
        let params = self.params.comptime();

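        // Decompose the row index `m` into a batch index and per-dimension output positions.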
        let (n_offs, spatial_offsets) = div_mod_seq(m, &self.shape_out);
        let spatial_dims = spatial_offsets.len();

        let mut in_offs = Sequence::<i32>::new();

        #[unroll]
        for dim in 0..spatial_dims {
            let stride = params.stride[dim] as i32;
            let pad = params.padding[dim];
            let out_pos = spatial_offsets[dim] as i32;
            let offs = match params.operation {
                ConvolutionOperation::Forward | ConvolutionOperation::BackwardWeight => {
                    out_pos * stride - pad
                }
                ConvolutionOperation::ForwardTransposed | ConvolutionOperation::BackwardData => {
                    // Anchor at the bottom-right kernel tap; the inverted kernel offsets computed
                    // below are added on top of this position.
                    let ksize = params.kernel_size[dim] as i32;
                    (out_pos + pad - ((ksize - 1) * params.dilation[dim] as i32)) / stride
                }
            };
            in_offs.push(offs);
        }

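        // Split the column index `k` into a linear kernel index (quotient) and a channel offset
        // (remainder) within the padded channel count.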
        let (mut k_idx, channel_start) = self.padded_channels.div_mod(k);

        let mut pos = NhwcCoords {
            batch: n_offs,
            spatial: in_offs,
            channel: channel_start,
        };

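        // Decompose the linear kernel index into per-dimension kernel offsets, starting from the
        // innermost dimension, then reverse to restore dimension order.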
        let mut k_offs = Sequence::new();
        let k_rank = params.dimensionality.num_dims();

        #[unroll]
        for i in 0..k_rank {
            let dim = k_rank - i - 1;
            let k_size = params.kernel_size[dim];
            let k_pos = k_idx % k_size;

            let k_pos = match params.operation {
                ConvolutionOperation::Forward | ConvolutionOperation::BackwardWeight => k_pos,
                ConvolutionOperation::ForwardTransposed | ConvolutionOperation::BackwardData => {
                    // Since kernels are always positive, we need to subtract the bottom right
                    // corner (see position above), then add the inverted index to it,
                    // e.g. for `k_size == 3`, positions 0, 1, 2 map to 2, 1, 0.
                    k_size - k_pos - 1
                }
            };
            k_offs.push(k_pos * params.dilation[dim]);
            k_idx /= k_size;
        }

        if self.check_kernel.comptime() {
            // This is the largest index that's aligned to the channel count in all cases.
            // Alignment is 256, and that's the largest tile size possible with TMA.
            // Could alternatively solve this by only loading if in bounds, and adjusting the
            // awaited bytes by the in-bounds tiles, but that's more complicated than just trying
            // to load a very large channel index and letting bounds checks handle it.
            let kernel_mask = (k_idx > 0) as u32 * 0x7FFFFF00u32;
            pos.channel = pos.channel.max(kernel_mask);
        }

        (pos, k_offs.rev())
    }

    fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
        true.runtime()
    }

    fn shape(&self) -> Self::Coordinates {
        (u32::MAX as usize, u32::MAX, u32::MAX).runtime()
    }

    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
        (self.to_source_pos(pos), self.is_in_bounds(pos))
    }
}

/// Decompose a linear index into local positions along each dimension in `shape`. Also returns
/// the leftover remainder.
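///
/// For example (illustrative values), with `shape = [4, 5]` and `pos = 37`, the result is
/// `(1, [3, 2])`, since `37 == ((1 * 4 + 3) * 5) + 2`.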
#[cube]
pub(crate) fn div_mod_seq(pos: u32, shape: &Sequence<FastDivmod<u32>>) -> (u32, Sequence<u32>) {
    let rank = shape.len().comptime();
    let mut offs = pos;
    let mut out = Sequence::new();

    #[unroll]
    for i in 0..rank {
        let dim = rank - i - 1;
        let (rem, offs_local) = shape[dim].div_mod(offs);
        out.push(offs_local);
        offs = rem;
    }

    (offs, out.rev())
}