cubek_convolution/components/global/layout/
im2col.rs1use cubecl::prelude::*;
2use cubecl::std::{
3 FastDivmod,
4 tensor::layout::{Layout, LayoutExpand},
5};
6use cubek_matmul::{
7 components::global::{GlobalConfig, memory::GlobalLayoutConfig},
8 launch::BatchedCoords,
9};
10
11use crate::components::{
12 ConvolutionOperation, ConvolutionParams, ConvolutionProblem,
13 global::layout::{NhwcCoords, div_mod_seq},
14};
15
/// Im2col layout: exposes a convolution operand as a `rows x cols` matrix view (addressed with
/// matmul [`BatchedCoords`]) and translates every view coordinate into an [`NhwcCoords`]
/// position in the underlying tensor.
///
/// The row coordinate enumerates batch and spatial positions (decomposed by `shape_out`).
/// The column coordinate enumerates kernel offsets and channels (decomposed by
/// `padded_channels` and the comptime kernel size).
#[derive(CubeType, CubeLaunch, Clone)]
pub struct Im2colLayout {
    /// Spatial extents, as fast divisors, used to split the row coordinate into a batch index
    /// and per-dimension spatial offsets.
    ///
    /// The launch helpers fill this from `problem.out_shape` for forward / weight-gradient
    /// launches and from `problem.in_shape` for transposed / data-gradient launches.
    pub shape_out: Sequence<FastDivmod<u32>>,
    /// Channel count, as a fast divisor, used to split the column coordinate into a
    /// kernel-offset remainder and a channel index. Filled from `problem.padded_channels`.
    pub padded_channels: FastDivmod<u32>,

    /// Number of rows of the matrix view (also the bound for the optional row check).
    pub rows: u32,
    /// Number of columns of the matrix view (also the bound for the optional column check).
    pub cols: u32,

    /// Comptime convolution parameters: dimensionality, kernel size, stride, dilation, padding
    /// and the convolution operation selecting the index formula.
    #[cube(comptime)]
    pub params: ConvolutionParams,
    /// Comptime layout configuration; its `check_row_bounds` / `check_col_bounds` flags gate
    /// the bounds checks in `is_in_bounds`.
    #[cube(comptime)]
    pub config: GlobalLayoutConfig,
}
38
39#[cube]
40impl Im2colLayout {
41 pub fn new<G: GlobalConfig>(
42 rows: u32,
43 cols: u32,
44 padded_channels: FastDivmod<u32>,
45 shape_out: Sequence<FastDivmod<u32>>,
46 #[comptime] config: GlobalLayoutConfig,
47 #[comptime] params: ConvolutionParams,
48 ) -> Im2colLayout {
49 Im2colLayout {
50 shape_out,
51 padded_channels,
52 rows,
53 cols,
54 params,
55 config,
56 }
57 }
58}
59
#[cube]
impl Layout for Im2colLayout {
    type Coordinates = BatchedCoords;
    type SourceCoordinates = NhwcCoords;

    /// Maps a `(batch, m, k)` view coordinate to its NHWC source coordinate.
    ///
    /// The batch component of `pos` is ignored.
    /// - `m` is split by `shape_out` into a batch index and per-dimension spatial offsets.
    /// - `k` is split by `padded_channels`: the remainder is the channel and the quotient
    ///   encodes the kernel offsets, with the last spatial dimension varying fastest.
    ///
    /// No clamping is done. The returned spatial positions can be negative or past the tensor
    /// extent (e.g. in padded regions). Presumably the consumer of `NhwcCoords` masks these;
    /// confirm.
    fn to_source_pos(&self, pos: Self::Coordinates) -> NhwcCoords {
        // `params` is a comptime field, so everything derived from it below is comptime
        // and the loop is unrolled.
        let params = self.params.comptime();
        let (_, view_m, view_k) = pos;

        let (batch, out_offs) = div_mod_seq(view_m, &self.shape_out);

        let (mut rem, channel) = self.padded_channels.div_mod(view_k);

        let spatial_dims = params.dimensionality.num_dims();
        let mut in_pos = Sequence::<i32>::new();

        // Walk the spatial dims from last to first so the fastest-varying kernel offset is
        // peeled off `rem` first.
        #[unroll]
        for i in 0..spatial_dims {
            let dim = spatial_dims - i - 1;
            let ksize = params.kernel_size[dim];
            let k_pos = (rem % ksize) as i32;
            rem /= ksize;

            let out_pos = out_offs[dim];
            let stride = params.stride[dim] as i32;
            let dilate = params.dilation[dim] as i32;
            let pad = params.padding[dim];

            let pos = match params.operation {
                // Direct mapping: out * stride + k * dilation - padding.
                ConvolutionOperation::Forward | ConvolutionOperation::BackwardWeight => {
                    (out_pos as i32 * stride + k_pos * dilate) - pad
                }
                // Inverse mapping: (pos + padding - k * dilation) / stride.
                // NOTE(review): assuming truncating integer division, this yields a
                // plausible-looking index when the numerator is not a multiple of `stride`,
                // or is negative with magnitude < `stride` (e.g. -1 / 2 == 0). It is exact
                // for stride == 1. Confirm larger strides are masked or rejected elsewhere.
                ConvolutionOperation::ForwardTransposed | ConvolutionOperation::BackwardData => {
                    (out_pos as i32 + pad - k_pos * dilate) / stride
                }
            };
            in_pos.push(pos);
        }

        // Positions were pushed last-dim-first; restore natural dimension order.
        let in_pos = in_pos.rev();

        NhwcCoords {
            batch,
            spatial: in_pos,
            channel,
        }
    }

    /// Shape of the matrix view: a batch of 1 with `rows x cols` elements.
    fn shape(&self) -> Self::Coordinates {
        (1, self.rows, self.cols)
    }

    /// Returns the mapped source position together with the view-space bounds flag from
    /// `is_in_bounds`. The position is computed even when the flag is `false`.
    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (NhwcCoords, bool) {
        (self.to_source_pos(pos), self.is_in_bounds(pos))
    }

    /// View-space bounds check: `m < rows` and `k < cols`.
    ///
    /// Each comparison is skipped (treated as in-bounds) when the matching config flag is
    /// disabled. Spatial positions produced by `to_source_pos` are not checked here.
    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        let (_, view_m, view_k) = pos;
        let m_in_bounds = !self.config.check_row_bounds || view_m < self.rows;
        let k_in_bounds = !self.config.check_col_bounds || view_k < self.cols;
        m_in_bounds && k_in_bounds
    }
}
124
125impl<R: Runtime> Im2colLayoutLaunch<R> {
126 pub fn from_args(
127 problem: &ConvolutionProblem,
128 params: ConvolutionParams,
129 config: GlobalLayoutConfig,
130 ) -> Self {
131 match problem.operation {
132 ConvolutionOperation::Forward => Self::from_args_fprop(problem, params, config),
133 ConvolutionOperation::ForwardTransposed | ConvolutionOperation::BackwardData => {
134 Self::from_args_dgrad(problem, params, config)
135 }
136 ConvolutionOperation::BackwardWeight => Self::from_args_wgrad(problem, params, config),
137 }
138 }
139
140 fn from_args_fprop(
141 problem: &ConvolutionProblem,
142 params: ConvolutionParams,
143 config: GlobalLayoutConfig,
144 ) -> Self {
145 let shape_out = problem.out_shape.iter().map(|s| *s as u32).collect();
146
147 let padded_channels = problem.padded_channels as u32;
148
149 let shape_m = problem.m as u32;
150 let shape_k = problem.k as u32;
151
152 Im2colLayoutLaunch::new(shape_out, padded_channels, shape_m, shape_k, params, config)
153 }
154
155 fn from_args_dgrad(
156 problem: &ConvolutionProblem,
157 params: ConvolutionParams,
158 config: GlobalLayoutConfig,
159 ) -> Self {
160 let shape = problem.in_shape.iter().map(|s| *s as u32).collect();
161
162 let padded_channels = problem.padded_channels as u32;
163
164 let shape_m = problem.m as u32;
165 let shape_k = problem.k as u32;
166
167 Im2colLayoutLaunch::new(shape, padded_channels, shape_m, shape_k, params, config)
168 }
169
170 fn from_args_wgrad(
171 problem: &ConvolutionProblem,
172 params: ConvolutionParams,
173 config: GlobalLayoutConfig,
174 ) -> Self {
175 let shape_out = problem.out_shape.iter().map(|s| *s as u32).collect();
176
177 let padded_channels = problem.padded_channels as u32;
178
179 let shape_k = problem.k as u32;
180 let shape_n = problem.n as u32;
181
182 Im2colLayoutLaunch::new(shape_out, padded_channels, shape_k, shape_n, params, config)
183 }
184}