// cubek_convolution/components/global/layout/tma_im2col.rs
use cubecl::{
2 prelude::*,
3 std::{
4 FastDivmod,
5 tensor::layout::{CoordsDyn, Layout, LayoutExpand},
6 },
7};
8use cubek_matmul::launch::BatchedCoords;
9
10use crate::components::{ConvolutionOperation, ConvolutionParams, global::layout::NhwcCoords};
11
/// Layout mapping matmul-style `(batch, m, k)` coordinates onto the virtual
/// im2col view of an NHWC tensor, producing coordinates for a TMA im2col load.
///
/// `m` indexes flattened (batch, spatial-output) positions and `k` indexes
/// (kernel-position, channel) pairs — see `to_source_pos` below.
#[derive(CubeType, CubeLaunch)]
pub struct TmaIm2colLayout {
    // Fast divmod helper per spatial output dimension (e.g. D, H, W),
    // used to decompose `m` into per-dimension output offsets.
    shape_out: Sequence<FastDivmod<u32>>,
    // Fast divmod helper over the padded channel count; splits `k` into a
    // flat kernel-position index (quotient) and a channel offset (remainder).
    padded_channels: FastDivmod<u32>,
    // Compile-time convolution parameters (operation kind, stride, padding,
    // dilation, kernel size, dimensionality).
    #[cube(comptime)]
    params: ConvolutionParams,
    // When `true`, out-of-range kernel indices are masked by forcing the
    // channel offset to a huge sentinel value (see `to_source_pos`).
    #[cube(comptime)]
    check_kernel: bool,
}
22
#[cube]
impl TmaIm2colLayout {
    /// Creates a new [`TmaIm2colLayout`].
    ///
    /// * `shape_out` - per-dimension divmod helpers for the spatial output shape
    /// * `padded_channels` - divmod helper over the padded channel count
    /// * `params` - compile-time convolution parameters
    /// * `check_kernel` - whether out-of-range kernel positions must be masked
    pub fn new(
        shape_out: Sequence<FastDivmod<u32>>,
        padded_channels: FastDivmod<u32>,
        #[comptime] params: ConvolutionParams,
        #[comptime] check_kernel: bool,
    ) -> Self {
        TmaIm2colLayout {
            shape_out,
            padded_channels,
            params,
            check_kernel,
        }
    }
}
39
#[cube]
impl Layout for TmaIm2colLayout {
    // Matmul-style position: (batch, m, k).
    type Coordinates = BatchedCoords;
    // Source position: an NHWC anchor coordinate plus one kernel offset per
    // spatial dimension (consumed by the TMA im2col load).
    type SourceCoordinates = (NhwcCoords, CoordsDyn);

    /// Translates a GEMM `(batch, m, k)` position into im2col source
    /// coordinates: the NHWC anchor of the receptive-field window and the
    /// dilation-scaled kernel offsets to apply per spatial dimension.
    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
        // Batch component of the tuple is unused; the batch index is instead
        // recovered from `m` via `div_mod_seq` below.
        let (_, m, k) = pos;
        let params = self.params.comptime();

        // Decompose `m` into a batch offset plus one offset per spatial dim.
        let (n_offs, spatial_offsets) = div_mod_seq(m, &self.shape_out);
        let spatial_dims = spatial_offsets.len();

        // Signed input-space anchors: may be negative due to padding, which
        // is why these are i32 rather than u32.
        let mut in_offs = Sequence::<i32>::new();

        #[unroll]
        for dim in 0..spatial_dims {
            let stride = params.stride[dim] as i32;
            let pad = params.padding[dim];
            let out_pos = spatial_offsets[dim] as i32;
            let offs = match params.operation {
                // Direct im2col: window start = out * stride - pad.
                ConvolutionOperation::Forward | ConvolutionOperation::BackwardWeight => {
                    out_pos * stride - pad
                }
                // Transposed/backward-data: invert the forward mapping.
                // NOTE(review): `/ stride` truncates toward zero; for negative
                // numerators this differs from floor division — confirm this is
                // intended (or that the masking path covers those positions).
                ConvolutionOperation::ForwardTransposed | ConvolutionOperation::BackwardData => {
                    let ksize = params.kernel_size[dim] as i32;
                    (out_pos + pad - ((ksize - 1) * params.dilation[dim] as i32)) / stride
                }
            };
            in_offs.push(offs);
        }

        // Split `k` into a flat kernel-position index (quotient) and the
        // starting channel within the padded channel range (remainder).
        let (mut k_idx, channel_start) = self.padded_channels.div_mod(k);

        let mut pos = NhwcCoords {
            batch: n_offs,
            spatial: in_offs,
            channel: channel_start,
        };

        // Decompose `k_idx` into one kernel offset per spatial dimension,
        // innermost dimension first (order restored by `.rev()` below).
        let mut k_offs = Sequence::new();
        let k_rank = params.dimensionality.num_dims();

        #[unroll]
        for i in 0..k_rank {
            let dim = k_rank - i - 1;
            let k_size = params.kernel_size[dim];
            let k_pos = k_idx % k_size;

            let k_pos = match params.operation {
                ConvolutionOperation::Forward | ConvolutionOperation::BackwardWeight => k_pos,
                // Transposed/backward-data traverses the kernel flipped.
                ConvolutionOperation::ForwardTransposed | ConvolutionOperation::BackwardData => {
                    k_size - k_pos - 1
                }
            };
            // Scale by dilation to get the actual input-space offset.
            k_offs.push(k_pos * params.dilation[dim]);
            k_idx /= k_size;
        }

        if self.check_kernel.comptime() {
            // If `k` exceeded the valid kernel range, `k_idx` is still > 0
            // after all dims are divided out. Force the channel to a huge
            // sentinel (0x7FFFFF00) so the position lands out of bounds.
            // NOTE(review): presumably the TMA consumer zero-fills OOB
            // coordinates — confirm against the loader using this layout.
            let kernel_mask = (k_idx > 0) as u32 * 0x7FFFFF00u32;
            pos.channel = pos.channel.max(kernel_mask);
        }

        (pos, k_offs.rev())
    }

    // Bounds are enforced via the sentinel-channel masking above (and by the
    // hardware load), so every coordinate is reported as in bounds.
    fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
        true.runtime()
    }

    // Effectively unbounded shape; callers must not rely on it for clamping.
    fn shape(&self) -> Self::Coordinates {
        (u32::MAX as usize, u32::MAX, u32::MAX).runtime()
    }

    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
        (self.to_source_pos(pos), self.is_in_bounds(pos))
    }
}
125
126#[cube]
129pub(crate) fn div_mod_seq(pos: u32, shape: &Sequence<FastDivmod<u32>>) -> (u32, Sequence<u32>) {
130 let rank = shape.len().comptime();
131 let mut offs = pos;
132 let mut out = Sequence::new();
133
134 #[unroll]
135 for i in 0..rank {
136 let dim = rank - i - 1;
137 let (rem, offs_local) = shape[dim].div_mod(offs);
138 out.push(offs_local);
139 offs = rem;
140 }
141
142 (offs, out.rev())
143}