cubecl_convolution/components/global/layout/
im2col.rs

1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl};
3use cubecl_matmul::components::{
4    MatmulElems,
5    global::{GlobalConfig, memory::GlobalMemoryConfig},
6};
7use cubecl_std::{
8    FastDivmod, FastDivmodArgs,
9    tensor::layout::{Coords3d, Layout, LayoutExpand},
10};
11
12use crate::components::{
13    ConvGemmConfig, ConvolutionConfig, ConvolutionParams, ConvolutionProblem,
14    global::{args::RuntimeArgs, layout::NhwcCoords, read::im2col_tma::div_mod_seq},
15};
16
/// Maps a channels-last (NHWC-style) tensor to a 2D column matrix using the im2col transformation.
///
/// It first decomposes the `(m, k)` matrix into `((n, out_h, out_w), (k_h, k_w, c))`, then applies
/// the convolution parameters to calculate the position in the input tensor for that kernel element.
/// The decomposition above is written for two spatial dimensions; the actual number of spatial
/// dimensions is the length of `shape_out`.
#[derive(CubeType, CubeLaunch, Clone)]
pub struct Im2colLayout {
    /// Shape of output DHW, one fast divider per spatial dimension.
    /// Used to split the `m` coordinate into a batch index and per-dimension output offsets.
    pub shape_out: Sequence<FastDivmod>,
    /// Shape of channel, for decomposing k into `(kernel element index, channel)`.
    /// NOTE(review): the name suggests the channel count is already padded by the caller — confirm
    /// against whoever fills `RuntimeArgs::padded_channels`.
    pub padded_channels: FastDivmod,

    /// Shape of the combined `m` dimension, including padding.
    /// Reported by `shape()` and used as the row bound in `is_in_bounds`.
    pub shape_m: u32,
    /// Shape of the combined `k` dimension, including padding.
    /// Reported by `shape()` and used as the column bound in `is_in_bounds`.
    pub shape_k: u32,

    /// Comptime parameters for the convolution (kernel size, stride, dilation and padding are
    /// read per spatial dimension when mapping coordinates).
    #[cube(comptime)]
    pub params: ConvolutionParams,
    /// Global memory config for the backing tensor; its `check_row_bounds` / `check_col_bounds`
    /// flags decide which bounds checks are emitted.
    #[cube(comptime)]
    pub config: GlobalMemoryConfig,
}
39
#[cube]
impl Im2colLayout {
    /// Builds the layout from the kernel's runtime arguments and the comptime convolution config.
    ///
    /// - `args`: runtime arguments; `padded_channels`, `shape_m` and `shape_k` are copied from it.
    /// - `shape_out`: one fast divider per spatial output dimension, stored as-is.
    /// - `config`: comptime convolution config; supplies the convolution parameters and the
    ///   global memory config of the LHS (the im2col view is the LHS operand of the matmul).
    pub fn new<G: GlobalConfig>(
        args: &RuntimeArgs,
        shape_out: Sequence<FastDivmod>,
        #[comptime] config: ConvolutionConfig<G>,
    ) -> Im2colLayout {
        Im2colLayout {
            shape_out,
            padded_channels: args.padded_channels,
            shape_m: args.shape_m,
            shape_k: args.shape_k,
            params: config.convolution_params,
            config: config.lhs_global_memory_config(),
        }
    }
}
57
58#[cube]
59impl Layout for Im2colLayout {
60    type Coordinates = Coords3d;
61    type SourceCoordinates = NhwcCoords;
62
63    fn to_source_pos(&self, pos: Self::Coordinates) -> NhwcCoords {
64        let params = comptime![self.params];
65        let (_, view_m, view_k) = pos;
66
67        let (batch, out_offs) = div_mod_seq(view_m, &self.shape_out);
68
69        let (mut rem, channel) = self.padded_channels.div_mod(view_k);
70
71        let spatial_dims = comptime![self.shape_out.len()];
72        let mut in_pos = Sequence::<i32>::new();
73
74        #[unroll]
75        for i in 0..spatial_dims {
76            let dim = comptime![spatial_dims - i - 1];
77            let ksize = comptime![params.kernel_size[dim as usize]];
78            let k_pos = rem % ksize;
79            rem /= ksize;
80
81            let out_pos = *out_offs.index(dim);
82            let stride = comptime![params.stride[dim as usize]];
83            let dilate = comptime![params.dilation[dim as usize]];
84            let pad = comptime![params.padding[dim as usize]];
85
86            let pos = (out_pos * stride + k_pos * dilate) as i32 - pad;
87            in_pos.push(pos);
88        }
89
90        let in_pos = in_pos.rev();
91
92        NhwcCoords {
93            batch,
94            spatial: in_pos,
95            channel,
96        }
97    }
98
99    fn shape(&self) -> Self::Coordinates {
100        (1, self.shape_m, self.shape_k)
101    }
102
103    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (NhwcCoords, bool) {
104        (self.to_source_pos(pos), self.is_in_bounds(pos))
105    }
106
107    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
108        let (_, view_m, view_k) = pos;
109        // Shouldn't be relied on because it doesn't check spatial
110        let m_in_bounds = comptime!(!self.config.check_row_bounds) || view_m < self.shape_m;
111        let k_in_bounds = comptime!(!self.config.check_col_bounds) || view_k < self.shape_k;
112        m_in_bounds && k_in_bounds
113    }
114}
115
116impl<'a, R: Runtime> Im2colLayoutLaunch<'a, R> {
117    pub fn from_args(
118        client: &ComputeClient<R>,
119        problem: &ConvolutionProblem,
120        params: ConvolutionParams,
121        config: GlobalMemoryConfig,
122        dtypes: &MatmulElems,
123    ) -> Self {
124        let shape_out = problem
125            .out_shape
126            .iter()
127            .map(|s| FastDivmodArgs::new(client, *s as u32))
128            .collect();
129
130        let load_width = client.properties().hardware.load_width;
131        let channel_align = load_width / dtypes.lhs_global.size_bits() as u32;
132        let padded_channels = (problem.channels as u32).next_multiple_of(channel_align);
133
134        let size_k = problem.kernel_size.iter().product::<u32>() * padded_channels;
135        let padded_channels = FastDivmodArgs::new(client, padded_channels);
136
137        let shape_m = ScalarArg::new(problem.m as u32);
138        let shape_k = ScalarArg::new(size_k);
139
140        Im2colLayoutLaunch::new(shape_out, padded_channels, shape_m, shape_k, params, config)
141    }
142}