cubek_convolution/components/global/layout/
im2col.rs

1use cubecl::prelude::*;
2use cubecl::std::{
3    FastDivmod, FastDivmodArgs,
4    tensor::layout::{Coords3d, Layout, LayoutExpand},
5};
6use cubek_matmul::components::{
7    MatmulElems,
8    global::{GlobalConfig, memory::GlobalMemoryConfig},
9};
10
11use crate::components::{
12    ConvGemmConfig, ConvolutionConfig, ConvolutionParams, ConvolutionProblem,
13    global::{args::RuntimeArgs, layout::NhwcCoords, read::im2col_tma::div_mod_seq},
14};
15
16/// Maps a 4D NHWC tensor to a 2D column matrix using the im2col transformation
17/// It first decomposes the `(m, k)` matrix into `((n, out_h, out_w), (k_h, k_w, c))`, then applies
18/// the convolution parameters to calculate the position in the input tensor for that kernel element.
19#[derive(CubeType, CubeLaunch, Clone)]
20pub struct Im2colLayout {
21    /// Shape of output DHW
22    pub shape_out: Sequence<FastDivmod>,
23    /// Shape of channel, for decomposing k
24    pub padded_channels: FastDivmod,
25
26    /// Shape of the combined `m` dimension, including padding
27    pub shape_m: u32,
28    /// Shape of the combined `k` dimension, including padding
29    pub shape_k: u32,
30
31    /// Comptime parameters for the convolution
32    #[cube(comptime)]
33    pub params: ConvolutionParams,
34    /// Global memory config for the backing tensor
35    #[cube(comptime)]
36    pub config: GlobalMemoryConfig,
37}
38
39#[cube]
40impl Im2colLayout {
41    pub fn new<G: GlobalConfig>(
42        args: &RuntimeArgs,
43        shape_out: Sequence<FastDivmod>,
44        #[comptime] config: ConvolutionConfig<G>,
45    ) -> Im2colLayout {
46        Im2colLayout {
47            shape_out,
48            padded_channels: args.padded_channels,
49            shape_m: args.shape_m,
50            shape_k: args.shape_k,
51            params: config.convolution_params,
52            config: config.lhs_global_memory_config(),
53        }
54    }
55}
56
57#[cube]
58impl Layout for Im2colLayout {
59    type Coordinates = Coords3d;
60    type SourceCoordinates = NhwcCoords;
61
62    fn to_source_pos(&self, pos: Self::Coordinates) -> NhwcCoords {
63        let params = comptime![self.params];
64        let (_, view_m, view_k) = pos;
65
66        let (batch, out_offs) = div_mod_seq(view_m, &self.shape_out);
67
68        let (mut rem, channel) = self.padded_channels.div_mod(view_k);
69
70        let spatial_dims = comptime![self.shape_out.len()];
71        let mut in_pos = Sequence::<i32>::new();
72
73        #[unroll]
74        for i in 0..spatial_dims {
75            let dim = comptime![spatial_dims - i - 1];
76            let ksize = comptime![params.kernel_size[dim as usize]];
77            let k_pos = rem % ksize;
78            rem /= ksize;
79
80            let out_pos = *out_offs.index(dim);
81            let stride = comptime![params.stride[dim as usize]];
82            let dilate = comptime![params.dilation[dim as usize]];
83            let pad = comptime![params.padding[dim as usize]];
84
85            let pos = (out_pos * stride + k_pos * dilate) as i32 - pad;
86            in_pos.push(pos);
87        }
88
89        let in_pos = in_pos.rev();
90
91        NhwcCoords {
92            batch,
93            spatial: in_pos,
94            channel,
95        }
96    }
97
98    fn shape(&self) -> Self::Coordinates {
99        (1, self.shape_m, self.shape_k)
100    }
101
102    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (NhwcCoords, bool) {
103        (self.to_source_pos(pos), self.is_in_bounds(pos))
104    }
105
106    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
107        let (_, view_m, view_k) = pos;
108        // Shouldn't be relied on because it doesn't check spatial
109        let m_in_bounds = comptime!(!self.config.check_row_bounds) || view_m < self.shape_m;
110        let k_in_bounds = comptime!(!self.config.check_col_bounds) || view_k < self.shape_k;
111        m_in_bounds && k_in_bounds
112    }
113}
114
115impl<'a, R: Runtime> Im2colLayoutLaunch<'a, R> {
116    pub fn from_args(
117        client: &ComputeClient<R>,
118        problem: &ConvolutionProblem,
119        params: ConvolutionParams,
120        config: GlobalMemoryConfig,
121        dtypes: &MatmulElems,
122    ) -> Self {
123        let shape_out = problem
124            .out_shape
125            .iter()
126            .map(|s| FastDivmodArgs::new(client, *s as u32))
127            .collect();
128
129        let load_width = client.properties().hardware.load_width;
130        let channel_align = load_width / dtypes.lhs_global.size_bits() as u32;
131        let padded_channels = (problem.channels as u32).next_multiple_of(channel_align);
132
133        let size_k = problem.kernel_size.iter().product::<u32>() * padded_channels;
134        let padded_channels = FastDivmodArgs::new(client, padded_channels);
135
136        let shape_m = ScalarArg::new(problem.m as u32);
137        let shape_k = ScalarArg::new(size_k);
138
139        Im2colLayoutLaunch::new(shape_out, padded_channels, shape_m, shape_k, params, config)
140    }
141}