// cubek_convolution/components/global/layout/im2col.rs
use cubecl::prelude::*;
use cubecl::std::{
    FastDivmod, FastDivmodArgs,
    tensor::layout::{Layout, LayoutExpand},
};
use cubek_matmul::{
    components::global::{GlobalConfig, memory::GlobalLayoutConfig},
    launch::BatchedCoords,
};

use crate::components::{
    ConvolutionOperation, ConvolutionParams, ConvolutionProblem,
    global::layout::{NhwcCoords, div_mod_seq},
};
15
16#[derive(CubeType, CubeLaunch, Clone)]
20pub struct Im2colLayout {
21 pub shape_out: Sequence<FastDivmod<u32>>,
23 pub padded_channels: FastDivmod<u32>,
25
26 pub rows: u32,
28 pub cols: u32,
30
31 #[cube(comptime)]
33 pub params: ConvolutionParams,
34 #[cube(comptime)]
36 pub config: GlobalLayoutConfig,
37}
38
39#[cube]
40impl Im2colLayout {
41 pub fn new<G: GlobalConfig>(
42 rows: u32,
43 cols: u32,
44 padded_channels: FastDivmod<u32>,
45 shape_out: Sequence<FastDivmod<u32>>,
46 #[comptime] config: GlobalLayoutConfig,
47 #[comptime] params: ConvolutionParams,
48 ) -> Im2colLayout {
49 Im2colLayout {
50 shape_out,
51 padded_channels,
52 rows,
53 cols,
54 params,
55 config,
56 }
57 }
58}
59
60#[cube]
61impl Layout for Im2colLayout {
62 type Coordinates = BatchedCoords;
63 type SourceCoordinates = NhwcCoords;
64
65 fn to_source_pos(&self, pos: Self::Coordinates) -> NhwcCoords {
66 let params = self.params.comptime();
67 let (_, view_m, view_k) = pos;
68
69 let (batch, out_offs) = div_mod_seq(view_m, &self.shape_out);
70
71 let (mut rem, channel) = self.padded_channels.div_mod(view_k);
72
73 let spatial_dims = params.dimensionality.num_dims();
74 let mut in_pos = Sequence::<i32>::new();
75
76 #[unroll]
77 for i in 0..spatial_dims {
78 let dim = spatial_dims - i - 1;
79 let ksize = params.kernel_size[dim];
80 let k_pos = (rem % ksize) as i32;
81 rem /= ksize;
82
83 let out_pos = out_offs[dim];
84 let stride = params.stride[dim] as i32;
85 let dilate = params.dilation[dim] as i32;
86 let pad = params.padding[dim];
87
88 let pos = match params.operation {
89 ConvolutionOperation::Forward | ConvolutionOperation::BackwardWeight => {
90 (out_pos as i32 * stride + k_pos * dilate) - pad
91 }
92 ConvolutionOperation::ForwardTransposed | ConvolutionOperation::BackwardData => {
93 (out_pos as i32 + pad - k_pos * dilate) / stride
94 }
95 };
96 in_pos.push(pos);
97 }
98
99 let in_pos = in_pos.rev();
100
101 NhwcCoords {
102 batch,
103 spatial: in_pos,
104 channel,
105 }
106 }
107
108 fn shape(&self) -> Self::Coordinates {
109 (1, self.rows, self.cols)
110 }
111
112 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (NhwcCoords, bool) {
113 (self.to_source_pos(pos), self.is_in_bounds(pos))
114 }
115
116 fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
117 let (_, view_m, view_k) = pos;
118 let m_in_bounds = !self.config.check_row_bounds || view_m < self.rows;
120 let k_in_bounds = !self.config.check_col_bounds || view_k < self.cols;
121 m_in_bounds && k_in_bounds
122 }
123}
124
125impl<'a, R: Runtime> Im2colLayoutLaunch<'a, R> {
126 pub fn from_args(
127 client: &ComputeClient<R>,
128 problem: &ConvolutionProblem,
129 params: ConvolutionParams,
130 config: GlobalLayoutConfig,
131 ) -> Self {
132 match problem.operation {
133 ConvolutionOperation::Forward => Self::from_args_fprop(client, problem, params, config),
134 ConvolutionOperation::ForwardTransposed | ConvolutionOperation::BackwardData => {
135 Self::from_args_dgrad(client, problem, params, config)
136 }
137 ConvolutionOperation::BackwardWeight => {
138 Self::from_args_wgrad(client, problem, params, config)
139 }
140 }
141 }
142
143 fn from_args_fprop(
144 client: &ComputeClient<R>,
145 problem: &ConvolutionProblem,
146 params: ConvolutionParams,
147 config: GlobalLayoutConfig,
148 ) -> Self {
149 let shape_out = problem
150 .out_shape
151 .iter()
152 .map(|s| FastDivmodArgs::<u32>::new(client, *s as u32))
153 .collect();
154
155 let padded_channels = problem.padded_channels as u32;
156 let padded_channels = FastDivmodArgs::<u32>::new(client, padded_channels);
157
158 let shape_m = ScalarArg::new(problem.m as u32);
159 let shape_k = ScalarArg::new(problem.k as u32);
160
161 Im2colLayoutLaunch::new(shape_out, padded_channels, shape_m, shape_k, params, config)
162 }
163
164 fn from_args_dgrad(
165 client: &ComputeClient<R>,
166 problem: &ConvolutionProblem,
167 params: ConvolutionParams,
168 config: GlobalLayoutConfig,
169 ) -> Self {
170 let shape = problem
171 .in_shape
172 .iter()
173 .map(|s| FastDivmodArgs::<u32>::new(client, *s as u32))
174 .collect();
175
176 let padded_channels = problem.padded_channels as u32;
177 let padded_channels = FastDivmodArgs::<u32>::new(client, padded_channels);
178
179 let shape_m = ScalarArg::new(problem.m as u32);
180 let shape_k = ScalarArg::new(problem.k as u32);
181
182 Im2colLayoutLaunch::new(shape, padded_channels, shape_m, shape_k, params, config)
183 }
184
185 fn from_args_wgrad(
186 client: &ComputeClient<R>,
187 problem: &ConvolutionProblem,
188 params: ConvolutionParams,
189 config: GlobalLayoutConfig,
190 ) -> Self {
191 let shape_out = problem
192 .out_shape
193 .iter()
194 .map(|s| FastDivmodArgs::<u32>::new(client, *s as u32))
195 .collect();
196
197 let padded_channels = problem.padded_channels as u32;
198 let padded_channels = FastDivmodArgs::<u32>::new(client, padded_channels);
199
200 let shape_k = ScalarArg::new(problem.k as u32);
201 let shape_n = ScalarArg::new(problem.n as u32);
202
203 Im2colLayoutLaunch::new(shape_out, padded_channels, shape_k, shape_n, params, config)
204 }
205}