// cubek_convolution/components/global/layout/im2col.rs

1use cubecl::prelude::*;
2use cubecl::std::{
3    FastDivmod, FastDivmodArgs,
4    tensor::layout::{Layout, LayoutExpand},
5};
6use cubek_matmul::{
7    components::global::{GlobalConfig, memory::GlobalMemoryConfig},
8    launch::BatchedCoords,
9};
10
11use crate::components::{
12    ConvGemmConfig, ConvolutionConfig, ConvolutionOperation, ConvolutionParams, ConvolutionProblem,
13    global::layout::{NhwcCoords, div_mod_seq},
14};
15
16/// Maps a 4D NHWC tensor to a 2D column matrix using the im2col transformation
17/// It first decomposes the `(m, k)` matrix into `((n, out_h, out_w), (k_h, k_w, c))`, then applies
18/// the convolution parameters to calculate the position in the input tensor for that kernel element.
19#[derive(CubeType, CubeLaunch, Clone)]
20pub struct Im2colLayout {
21    /// Shape of output DHW
22    pub shape_out: Sequence<FastDivmod<u32>>,
23    /// Shape of channel, for decomposing k
24    pub padded_channels: FastDivmod<u32>,
25
26    /// Shape of the combined `m` dimension, including padding
27    pub rows: u32,
28    /// Shape of the combined `k` dimension, including padding
29    pub cols: u32,
30
31    /// Comptime parameters for the convolution
32    #[cube(comptime)]
33    pub params: ConvolutionParams,
34    /// Global memory config for the backing tensor
35    #[cube(comptime)]
36    pub config: GlobalMemoryConfig,
37}
38
39#[cube]
40impl Im2colLayout {
41    pub fn new<G: GlobalConfig>(
42        rows: u32,
43        cols: u32,
44        padded_channels: FastDivmod<u32>,
45        shape_out: Sequence<FastDivmod<u32>>,
46        #[comptime] config: ConvolutionConfig<G>,
47    ) -> Im2colLayout {
48        Im2colLayout {
49            shape_out,
50            padded_channels,
51            rows,
52            cols,
53            params: config.params,
54            config: config.lhs_global_memory_config(),
55        }
56    }
57}
58
59#[cube]
60impl Layout for Im2colLayout {
61    type Coordinates = BatchedCoords;
62    type SourceCoordinates = NhwcCoords;
63
64    fn to_source_pos(&self, pos: Self::Coordinates) -> NhwcCoords {
65        let params = self.params.comptime();
66        let (_, view_m, view_k) = pos;
67
68        let (batch, out_offs) = div_mod_seq(view_m, &self.shape_out);
69
70        let (mut rem, channel) = self.padded_channels.div_mod(view_k);
71
72        let spatial_dims = params.dimensionality.num_dims();
73        let mut in_pos = Sequence::<i32>::new();
74
75        #[unroll]
76        for i in 0..spatial_dims {
77            let dim = spatial_dims - i - 1;
78            let ksize = params.kernel_size[dim];
79            let k_pos = (rem % ksize) as i32;
80            rem /= ksize;
81
82            let out_pos = out_offs[dim];
83            let stride = params.stride[dim] as i32;
84            let dilate = params.dilation[dim] as i32;
85            let pad = params.padding[dim];
86
87            let pos = match params.operation {
88                ConvolutionOperation::Forward | ConvolutionOperation::BackwardWeight => {
89                    (out_pos as i32 * stride + k_pos * dilate) - pad
90                }
91                ConvolutionOperation::ForwardTransposed | ConvolutionOperation::BackwardData => {
92                    (out_pos as i32 + pad - k_pos * dilate) / stride
93                }
94            };
95            in_pos.push(pos);
96        }
97
98        let in_pos = in_pos.rev();
99
100        NhwcCoords {
101            batch,
102            spatial: in_pos,
103            channel,
104        }
105    }
106
107    fn shape(&self) -> Self::Coordinates {
108        (1, self.rows, self.cols)
109    }
110
111    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (NhwcCoords, bool) {
112        (self.to_source_pos(pos), self.is_in_bounds(pos))
113    }
114
115    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
116        let (_, view_m, view_k) = pos;
117        // Shouldn't be relied on because it doesn't check spatial
118        let m_in_bounds = !self.config.check_row_bounds || view_m < self.rows;
119        let k_in_bounds = !self.config.check_col_bounds || view_k < self.cols;
120        m_in_bounds && k_in_bounds
121    }
122}
123
124impl<'a, R: Runtime> Im2colLayoutLaunch<'a, R> {
125    pub fn from_args(
126        client: &ComputeClient<R>,
127        problem: &ConvolutionProblem,
128        params: ConvolutionParams,
129        config: GlobalMemoryConfig,
130    ) -> Self {
131        match problem.operation {
132            ConvolutionOperation::Forward => Self::from_args_fprop(client, problem, params, config),
133            ConvolutionOperation::ForwardTransposed | ConvolutionOperation::BackwardData => {
134                Self::from_args_dgrad(client, problem, params, config)
135            }
136            ConvolutionOperation::BackwardWeight => {
137                Self::from_args_wgrad(client, problem, params, config)
138            }
139        }
140    }
141
142    fn from_args_fprop(
143        client: &ComputeClient<R>,
144        problem: &ConvolutionProblem,
145        params: ConvolutionParams,
146        config: GlobalMemoryConfig,
147    ) -> Self {
148        let shape_out = problem
149            .out_shape
150            .iter()
151            .map(|s| FastDivmodArgs::<u32>::new(client, *s as u32))
152            .collect();
153
154        let padded_channels = problem.padded_channels as u32;
155        let padded_channels = FastDivmodArgs::<u32>::new(client, padded_channels);
156
157        let shape_m = ScalarArg::new(problem.m as u32);
158        let shape_k = ScalarArg::new(problem.k as u32);
159
160        Im2colLayoutLaunch::new(shape_out, padded_channels, shape_m, shape_k, params, config)
161    }
162
163    fn from_args_dgrad(
164        client: &ComputeClient<R>,
165        problem: &ConvolutionProblem,
166        params: ConvolutionParams,
167        config: GlobalMemoryConfig,
168    ) -> Self {
169        let shape = problem
170            .in_shape
171            .iter()
172            .map(|s| FastDivmodArgs::<u32>::new(client, *s as u32))
173            .collect();
174
175        let padded_channels = problem.padded_channels as u32;
176        let padded_channels = FastDivmodArgs::<u32>::new(client, padded_channels);
177
178        let shape_m = ScalarArg::new(problem.m as u32);
179        let shape_k = ScalarArg::new(problem.k as u32);
180
181        Im2colLayoutLaunch::new(shape, padded_channels, shape_m, shape_k, params, config)
182    }
183
184    fn from_args_wgrad(
185        client: &ComputeClient<R>,
186        problem: &ConvolutionProblem,
187        params: ConvolutionParams,
188        config: GlobalMemoryConfig,
189    ) -> Self {
190        let shape_out = problem
191            .out_shape
192            .iter()
193            .map(|s| FastDivmodArgs::<u32>::new(client, *s as u32))
194            .collect();
195
196        let padded_channels = problem.padded_channels as u32;
197        let padded_channels = FastDivmodArgs::<u32>::new(client, padded_channels);
198
199        let shape_k = ScalarArg::new(problem.k as u32);
200        let shape_n = ScalarArg::new(problem.n as u32);
201
202        Im2colLayoutLaunch::new(shape_out, padded_channels, shape_k, shape_n, params, config)
203    }
204}