cubecl_convolution/components/global/layout/im2col.rs

1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl};
3use cubecl_matmul::components::global::{GlobalConfig, memory::GlobalMemoryConfig};
4use cubecl_std::{
5    FastDivmod, FastDivmodArgs,
6    tensor::layout::{Coords3d, Layout, LayoutExpand},
7};
8
9use crate::{
10    components::{
11        ConvGemmConfig, ConvolutionConfig, ConvolutionParams, ConvolutionProblem,
12        global::{layout::NhwcCoords, read::im2col_tma::div_mod_seq},
13    },
14    kernels::layered::selector::RuntimeArgs,
15};
16
17/// Maps a 4D NHWC tensor to a 2D column matrix using the im2col transformation
18/// It first decomposes the `(m, k)` matrix into `((n, out_h, out_w), (k_h, k_w, c))`, then applies
19/// the convolution parameters to calculate the position in the input tensor for that kernel element.
20#[derive(CubeType, CubeLaunch, Clone)]
21pub struct Im2colLayout {
22    /// Shape of output DHW
23    pub shape_out: Sequence<FastDivmod>,
24    /// Shape of channel, for decomposing k
25    pub shape_channel: FastDivmod,
26
27    /// Shape of the combined `m` dimension, including padding
28    pub shape_m: u32,
29    /// Shape of the combined `k` dimension, including padding
30    pub shape_k: u32,
31
32    /// Comptime parameters for the convolution
33    #[cube(comptime)]
34    pub params: ConvolutionParams,
35    /// Global memory config for the backing tensor
36    #[cube(comptime)]
37    pub config: GlobalMemoryConfig,
38}
39
40#[cube]
41impl Im2colLayout {
42    pub fn new<G: GlobalConfig>(
43        args: &RuntimeArgs,
44        #[comptime] config: ConvolutionConfig<G>,
45    ) -> Im2colLayout {
46        let shape_out = args.shape_out.clone();
47
48        Im2colLayout {
49            shape_out,
50            shape_channel: args.shape_channel,
51            shape_m: args.shape_m,
52            shape_k: args.shape_k,
53            params: config.convolution_params,
54            config: config.lhs_global_memory_config(),
55        }
56    }
57}
58
59#[cube]
60impl Layout for Im2colLayout {
61    type Coordinates = Coords3d;
62    type SourceCoordinates = NhwcCoords;
63
64    fn to_source_pos(&self, pos: Self::Coordinates) -> NhwcCoords {
65        let params = comptime![self.params];
66        let (_, view_m, view_k) = pos;
67
68        let (batch, out_offs) = div_mod_seq(view_m, &self.shape_out);
69
70        let (mut rem, channel) = self.shape_channel.div_mod(view_k);
71
72        let spatial_dims = comptime![self.shape_out.len()];
73        let mut in_pos = Sequence::<i32>::new();
74
75        #[unroll]
76        for i in 0..spatial_dims {
77            let dim = comptime![spatial_dims - i - 1];
78            let ksize = comptime![params.kernel_size[dim as usize]];
79            let k_pos = rem % ksize;
80            rem /= ksize;
81
82            let out_pos = *out_offs.index(dim);
83            let stride = comptime![params.stride[dim as usize]];
84            let dilate = comptime![params.dilation[dim as usize]];
85            let pad = comptime![params.padding[dim as usize]];
86
87            let pos = (out_pos * stride + k_pos * dilate) as i32 - pad;
88            in_pos.push(pos);
89        }
90
91        let in_pos = in_pos.rev();
92
93        NhwcCoords {
94            batch,
95            spatial: in_pos,
96            channel,
97        }
98    }
99
100    fn shape(&self) -> Self::Coordinates {
101        (1, self.shape_m, self.shape_k)
102    }
103
104    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (NhwcCoords, bool) {
105        (self.to_source_pos(pos), self.is_in_bounds(pos))
106    }
107
108    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
109        let (_, view_m, view_k) = pos;
110        // Shouldn't be relied on because it doesn't check spatial
111        let m_in_bounds = comptime!(!self.config.check_row_bounds) || view_m < self.shape_m;
112        let k_in_bounds = comptime!(!self.config.check_col_bounds) || view_k < self.shape_k;
113        m_in_bounds && k_in_bounds
114    }
115}
116
117impl<'a, R: Runtime> Im2colLayoutLaunch<'a, R> {
118    pub fn from_args(
119        client: &ComputeClient<R>,
120        problem: &ConvolutionProblem,
121        params: ConvolutionParams,
122        config: GlobalMemoryConfig,
123    ) -> Self {
124        let shape_out = problem
125            .out_shape
126            .iter()
127            .map(|s| FastDivmodArgs::new(client, *s as u32))
128            .collect();
129        let shape_channel = FastDivmodArgs::new(client, problem.channels as u32);
130
131        let shape_m = ScalarArg::new(problem.m as u32);
132        let shape_k = ScalarArg::new(problem.k as u32);
133
134        Im2colLayoutLaunch::new(shape_out, shape_channel, shape_m, shape_k, params, config)
135    }
136}