// cubek_convolution/components/global/layout/im2col.rs
use cubecl::prelude::*;
2use cubecl::std::{
3 FastDivmod, FastDivmodArgs,
4 tensor::layout::{Layout, LayoutExpand},
5};
6use cubek_matmul::{
7 components::global::{GlobalConfig, memory::GlobalMemoryConfig},
8 launch::BatchedCoords,
9};
10
11use crate::components::{
12 ConvGemmConfig, ConvolutionConfig, ConvolutionOperation, ConvolutionParams, ConvolutionProblem,
13 global::layout::{NhwcCoords, div_mod_seq},
14};
15
/// Layout mapping batched-matmul `(batch, m, k)` coordinates of the implicit-GEMM
/// "im2col" view of the LHS onto NHWC input-tensor coordinates.
#[derive(CubeType, CubeLaunch, Clone)]
pub struct Im2colLayout {
    /// Fast divmod helper per output spatial dimension, used to split the `m`
    /// coordinate into a tensor batch plus per-dimension output positions.
    pub shape_out: Sequence<FastDivmod<u32>>,
    /// Fast divmod helper for the (padded) channel count, used to split the `k`
    /// coordinate into a flattened kernel-position index and a channel.
    pub padded_channels: FastDivmod<u32>,

    /// Row extent (`m`) of the im2col view.
    pub rows: u32,
    /// Column extent (`k`) of the im2col view.
    pub cols: u32,

    /// Comptime convolution parameters (kernel size, stride, dilation, padding,
    /// dimensionality, operation kind).
    #[cube(comptime)]
    pub params: ConvolutionParams,
    /// Comptime LHS global-memory config; only its bounds-check flags are read here.
    #[cube(comptime)]
    pub config: GlobalMemoryConfig,
}
38
#[cube]
impl Im2colLayout {
    /// Creates an [`Im2colLayout`] from precomputed divmod helpers and the
    /// row/column extents of the im2col view.
    ///
    /// `params` and the LHS global-memory config are pulled out of the
    /// convolution config at comptime.
    pub fn new<G: GlobalConfig>(
        rows: u32,
        cols: u32,
        padded_channels: FastDivmod<u32>,
        shape_out: Sequence<FastDivmod<u32>>,
        #[comptime] config: ConvolutionConfig<G>,
    ) -> Im2colLayout {
        Im2colLayout {
            shape_out,
            padded_channels,
            rows,
            cols,
            params: config.params,
            config: config.lhs_global_memory_config(),
        }
    }
}
58
#[cube]
impl Layout for Im2colLayout {
    /// Batched matmul coordinates: `(batch, m, k)`.
    type Coordinates = BatchedCoords;
    type SourceCoordinates = NhwcCoords;

    /// Maps an im2col `(batch, m, k)` position to NHWC input coordinates.
    ///
    /// The matmul batch component is ignored (see `shape`, which reports a
    /// batch of 1); the tensor batch is recovered from the `m` coordinate.
    fn to_source_pos(&self, pos: Self::Coordinates) -> NhwcCoords {
        let params = self.params.comptime();
        let (_, view_m, view_k) = pos;

        // m -> (tensor batch, per-dimension output spatial positions).
        let (batch, out_offs) = div_mod_seq(view_m, &self.shape_out);

        // k -> (flattened kernel-position index `rem`, input channel).
        let (mut rem, channel) = self.padded_channels.div_mod(view_k);

        let spatial_dims = params.dimensionality.num_dims();
        let mut in_pos = Sequence::<i32>::new();

        // Walk spatial dims innermost-first, peeling one kernel offset per
        // dimension out of `rem`; positions are therefore pushed in reverse
        // order and un-reversed below.
        #[unroll]
        for i in 0..spatial_dims {
            let dim = spatial_dims - i - 1;
            let ksize = params.kernel_size[dim];
            let k_pos = (rem % ksize) as i32;
            rem /= ksize;

            let out_pos = out_offs[dim];
            let stride = params.stride[dim] as i32;
            let dilate = params.dilation[dim] as i32;
            let pad = params.padding[dim];

            let pos = match params.operation {
                // Standard im2col: in = out * stride + k * dilation - pad.
                // Can be negative or past the input edge; presumably masked by
                // the NHWC-side bounds handling (not visible here).
                ConvolutionOperation::Forward | ConvolutionOperation::BackwardWeight => {
                    (out_pos as i32 * stride + k_pos * dilate) - pad
                }
                // Inverse mapping for dgrad / transposed conv:
                // in = (out + pad - k * dilation) / stride.
                // NOTE(review): integer division truncates; positions where the
                // numerator is not a multiple of `stride` are presumably
                // rejected elsewhere — confirm.
                ConvolutionOperation::ForwardTransposed | ConvolutionOperation::BackwardData => {
                    (out_pos as i32 + pad - k_pos * dilate) / stride
                }
            };
            in_pos.push(pos);
        }

        // Restore natural dimension order (the loop pushed dims reversed).
        let in_pos = in_pos.rev();

        NhwcCoords {
            batch,
            spatial: in_pos,
            channel,
        }
    }

    /// Shape of the im2col view: a single batch of `rows x cols`.
    fn shape(&self) -> Self::Coordinates {
        (1, self.rows, self.cols)
    }

    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (NhwcCoords, bool) {
        (self.to_source_pos(pos), self.is_in_bounds(pos))
    }

    /// Bounds check against the view's m/k extents only, each gated by its
    /// comptime config flag; spatial (padding) bounds are not checked here.
    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        let (_, view_m, view_k) = pos;
        let m_in_bounds = !self.config.check_row_bounds || view_m < self.rows;
        let k_in_bounds = !self.config.check_col_bounds || view_k < self.cols;
        m_in_bounds && k_in_bounds
    }
}
123
124impl<'a, R: Runtime> Im2colLayoutLaunch<'a, R> {
125 pub fn from_args(
126 client: &ComputeClient<R>,
127 problem: &ConvolutionProblem,
128 params: ConvolutionParams,
129 config: GlobalMemoryConfig,
130 ) -> Self {
131 match problem.operation {
132 ConvolutionOperation::Forward => Self::from_args_fprop(client, problem, params, config),
133 ConvolutionOperation::ForwardTransposed | ConvolutionOperation::BackwardData => {
134 Self::from_args_dgrad(client, problem, params, config)
135 }
136 ConvolutionOperation::BackwardWeight => {
137 Self::from_args_wgrad(client, problem, params, config)
138 }
139 }
140 }
141
142 fn from_args_fprop(
143 client: &ComputeClient<R>,
144 problem: &ConvolutionProblem,
145 params: ConvolutionParams,
146 config: GlobalMemoryConfig,
147 ) -> Self {
148 let shape_out = problem
149 .out_shape
150 .iter()
151 .map(|s| FastDivmodArgs::<u32>::new(client, *s as u32))
152 .collect();
153
154 let padded_channels = problem.padded_channels as u32;
155 let padded_channels = FastDivmodArgs::<u32>::new(client, padded_channels);
156
157 let shape_m = ScalarArg::new(problem.m as u32);
158 let shape_k = ScalarArg::new(problem.k as u32);
159
160 Im2colLayoutLaunch::new(shape_out, padded_channels, shape_m, shape_k, params, config)
161 }
162
163 fn from_args_dgrad(
164 client: &ComputeClient<R>,
165 problem: &ConvolutionProblem,
166 params: ConvolutionParams,
167 config: GlobalMemoryConfig,
168 ) -> Self {
169 let shape = problem
170 .in_shape
171 .iter()
172 .map(|s| FastDivmodArgs::<u32>::new(client, *s as u32))
173 .collect();
174
175 let padded_channels = problem.padded_channels as u32;
176 let padded_channels = FastDivmodArgs::<u32>::new(client, padded_channels);
177
178 let shape_m = ScalarArg::new(problem.m as u32);
179 let shape_k = ScalarArg::new(problem.k as u32);
180
181 Im2colLayoutLaunch::new(shape, padded_channels, shape_m, shape_k, params, config)
182 }
183
184 fn from_args_wgrad(
185 client: &ComputeClient<R>,
186 problem: &ConvolutionProblem,
187 params: ConvolutionParams,
188 config: GlobalMemoryConfig,
189 ) -> Self {
190 let shape_out = problem
191 .out_shape
192 .iter()
193 .map(|s| FastDivmodArgs::<u32>::new(client, *s as u32))
194 .collect();
195
196 let padded_channels = problem.padded_channels as u32;
197 let padded_channels = FastDivmodArgs::<u32>::new(client, padded_channels);
198
199 let shape_k = ScalarArg::new(problem.k as u32);
200 let shape_n = ScalarArg::new(problem.n as u32);
201
202 Im2colLayoutLaunch::new(shape_out, padded_channels, shape_k, shape_n, params, config)
203 }
204}