Skip to main content

cubek_convolution/components/global/layout/
im2col.rs

1use cubecl::prelude::*;
2use cubecl::std::{
3    FastDivmod,
4    tensor::layout::{Layout, LayoutExpand},
5};
6use cubek_matmul::{
7    components::global::{GlobalConfig, memory::GlobalLayoutConfig},
8    launch::BatchedCoords,
9};
10
11use crate::components::{
12    ConvolutionOperation, ConvolutionParams, ConvolutionProblem,
13    global::layout::{NhwcCoords, div_mod_seq},
14};
15
16/// Maps a 4D NHWC tensor to a 2D column matrix using the im2col transformation
17/// It first decomposes the `(m, k)` matrix into `((n, out_h, out_w), (k_h, k_w, c))`, then applies
18/// the convolution parameters to calculate the position in the input tensor for that kernel element.
19#[derive(CubeType, CubeLaunch, Clone)]
20pub struct Im2colLayout {
21    /// Shape of output DHW
22    pub shape_out: Sequence<FastDivmod<u32>>,
23    /// Shape of channel, for decomposing k
24    pub padded_channels: FastDivmod<u32>,
25
26    /// Shape of the combined `m` dimension, including padding
27    pub rows: u32,
28    /// Shape of the combined `k` dimension, including padding
29    pub cols: u32,
30
31    /// Comptime parameters for the convolution
32    #[cube(comptime)]
33    pub params: ConvolutionParams,
34    /// Global memory config for the backing tensor
35    #[cube(comptime)]
36    pub config: GlobalLayoutConfig,
37}
38
39#[cube]
40impl Im2colLayout {
41    pub fn new<G: GlobalConfig>(
42        rows: u32,
43        cols: u32,
44        padded_channels: FastDivmod<u32>,
45        shape_out: Sequence<FastDivmod<u32>>,
46        #[comptime] config: GlobalLayoutConfig,
47        #[comptime] params: ConvolutionParams,
48    ) -> Im2colLayout {
49        Im2colLayout {
50            shape_out,
51            padded_channels,
52            rows,
53            cols,
54            params,
55            config,
56        }
57    }
58}
59
60#[cube]
61impl Layout for Im2colLayout {
62    type Coordinates = BatchedCoords;
63    type SourceCoordinates = NhwcCoords;
64
65    fn to_source_pos(&self, pos: Self::Coordinates) -> NhwcCoords {
66        let params = self.params.comptime();
67        let (_, view_m, view_k) = pos;
68
69        let (batch, out_offs) = div_mod_seq(view_m, &self.shape_out);
70
71        let (mut rem, channel) = self.padded_channels.div_mod(view_k);
72
73        let spatial_dims = params.dimensionality.num_dims();
74        let mut in_pos = Sequence::<i32>::new();
75
76        #[unroll]
77        for i in 0..spatial_dims {
78            let dim = spatial_dims - i - 1;
79            let ksize = params.kernel_size[dim];
80            let k_pos = (rem % ksize) as i32;
81            rem /= ksize;
82
83            let out_pos = out_offs[dim];
84            let stride = params.stride[dim] as i32;
85            let dilate = params.dilation[dim] as i32;
86            let pad = params.padding[dim];
87
88            let pos = match params.operation {
89                ConvolutionOperation::Forward | ConvolutionOperation::BackwardWeight => {
90                    (out_pos as i32 * stride + k_pos * dilate) - pad
91                }
92                ConvolutionOperation::ForwardTransposed | ConvolutionOperation::BackwardData => {
93                    (out_pos as i32 + pad - k_pos * dilate) / stride
94                }
95            };
96            in_pos.push(pos);
97        }
98
99        let in_pos = in_pos.rev();
100
101        NhwcCoords {
102            batch,
103            spatial: in_pos,
104            channel,
105        }
106    }
107
108    fn shape(&self) -> Self::Coordinates {
109        (1, self.rows, self.cols)
110    }
111
112    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (NhwcCoords, bool) {
113        (self.to_source_pos(pos), self.is_in_bounds(pos))
114    }
115
116    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
117        let (_, view_m, view_k) = pos;
118        // Shouldn't be relied on because it doesn't check spatial
119        let m_in_bounds = !self.config.check_row_bounds || view_m < self.rows;
120        let k_in_bounds = !self.config.check_col_bounds || view_k < self.cols;
121        m_in_bounds && k_in_bounds
122    }
123}
124
125impl<R: Runtime> Im2colLayoutLaunch<R> {
126    pub fn from_args(
127        problem: &ConvolutionProblem,
128        params: ConvolutionParams,
129        config: GlobalLayoutConfig,
130    ) -> Self {
131        match problem.operation {
132            ConvolutionOperation::Forward => Self::from_args_fprop(problem, params, config),
133            ConvolutionOperation::ForwardTransposed | ConvolutionOperation::BackwardData => {
134                Self::from_args_dgrad(problem, params, config)
135            }
136            ConvolutionOperation::BackwardWeight => Self::from_args_wgrad(problem, params, config),
137        }
138    }
139
140    fn from_args_fprop(
141        problem: &ConvolutionProblem,
142        params: ConvolutionParams,
143        config: GlobalLayoutConfig,
144    ) -> Self {
145        let shape_out = problem.out_shape.iter().map(|s| *s as u32).collect();
146
147        let padded_channels = problem.padded_channels as u32;
148
149        let shape_m = problem.m as u32;
150        let shape_k = problem.k as u32;
151
152        Im2colLayoutLaunch::new(shape_out, padded_channels, shape_m, shape_k, params, config)
153    }
154
155    fn from_args_dgrad(
156        problem: &ConvolutionProblem,
157        params: ConvolutionParams,
158        config: GlobalLayoutConfig,
159    ) -> Self {
160        let shape = problem.in_shape.iter().map(|s| *s as u32).collect();
161
162        let padded_channels = problem.padded_channels as u32;
163
164        let shape_m = problem.m as u32;
165        let shape_k = problem.k as u32;
166
167        Im2colLayoutLaunch::new(shape, padded_channels, shape_m, shape_k, params, config)
168    }
169
170    fn from_args_wgrad(
171        problem: &ConvolutionProblem,
172        params: ConvolutionParams,
173        config: GlobalLayoutConfig,
174    ) -> Self {
175        let shape_out = problem.out_shape.iter().map(|s| *s as u32).collect();
176
177        let padded_channels = problem.padded_channels as u32;
178
179        let shape_k = problem.k as u32;
180        let shape_n = problem.n as u32;
181
182        Im2colLayoutLaunch::new(shape_out, padded_channels, shape_k, shape_n, params, config)
183    }
184}