cubecl_convolution/components/
config.rs1use std::ops::Deref;
2
3use cubecl_core::CubeDim;
4use cubecl_matmul::components::{
5 MatmulIdent, MatmulLineSizes, MatmulSetupError, MatrixLayout, TilingScheme,
6 global::{
7 GlobalConfig, PlaneRoleConfig, SpecializedLoadingSides, multi_stage::EventLoadingMode,
8 read::ReaderMode,
9 },
10 stage::{StageConfig, StageMemoryConfig, SwizzleMode, TilingLayoutEnum},
11};
12
13use super::*;
14
/// Configuration of a convolution executed through the matmul pipeline.
///
/// Extends the matmul [`GlobalConfig`] with the convolution-specific information the
/// convolution kernels need on top of the plain matmul configuration.
pub trait ConvGemmConfig: GlobalConfig {
    /// Convolution hyper-parameters: kernel size, stride, dilation, padding and
    /// dimensionality.
    fn convolution_params(&self) -> ConvolutionParams;
    /// Line sizes used for the lhs, rhs and output operands.
    fn line_sizes(&self) -> MatmulLineSizes;
    /// Whether bounds checks along the spatial dimensions are required.
    ///
    /// The `ConvolutionConfig` implementation returns `true` exactly when some padding
    /// value of an active spatial dimension is non-zero.
    fn check_spatial_bounds(&self) -> bool;
}
22
/// Convolution configuration wrapping a matmul global config `M`.
///
/// Adds the convolution hyper-parameters and its own stage count. Most `GlobalConfig`
/// queries are forwarded to `matmul`. The exceptions are `num_stages`, which is answered
/// from the `num_stages` field, and `cube_dim`, which is computed locally.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct ConvolutionConfig<M: GlobalConfig> {
    /// Wrapped matmul global configuration.
    matmul: M,
    /// Convolution hyper-parameters (kernel size, stride, dilation, padding, dimensionality).
    params: ConvolutionParams,
    /// Number of stages, reported by `GlobalConfig::num_stages` for every ident.
    num_stages: u32,
}
29
/// Convolution hyper-parameters stored in fixed-size, 3-entry arrays.
///
/// Only the leading entries (one per spatial dimension) are meaningful.
/// `ConvolutionConfig::new` zero-fills the unused trailing entries.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct ConvolutionParams {
    /// Kernel extent per spatial dimension.
    pub kernel_size: [u32; 3],
    /// Stride per spatial dimension.
    pub stride: [u32; 3],
    /// Dilation per spatial dimension.
    pub dilation: [u32; 3],
    /// Padding per spatial dimension.
    // NOTE(review): signed unlike the other fields — presumably to allow negative
    // padding; confirm against the problem definition.
    pub padding: [i32; 3],
    /// Spatial dimensionality; its `num_dims()` tells how many array entries are in use.
    pub dimensionality: Dimensionality,
}
38
/// Exposes the wrapped matmul config, so methods of `M` can be called directly on a
/// `ConvolutionConfig<M>`.
///
/// Methods that `ConvolutionConfig` itself implements (notably its own `GlobalConfig`
/// impl, including the `num_stages` and `cube_dim` overrides) are found before
/// auto-deref during method resolution. Those calls therefore still return the
/// convolution values, not the matmul ones.
impl<M: GlobalConfig> Deref for ConvolutionConfig<M> {
    type Target = M;

    /// Borrows the wrapped matmul config.
    fn deref(&self) -> &Self::Target {
        &self.matmul
    }
}
46
47impl<M: GlobalConfig> GlobalConfig for ConvolutionConfig<M> {
48 type StageConfig = M::StageConfig;
49
50 fn stage_memory_config(&self, ident: MatmulIdent) -> StageMemoryConfig {
51 self.stage_config().stage_memory_config(ident.into_stage())
52 }
53
54 fn stage_config(&self) -> Self::StageConfig {
55 self.matmul.stage_config()
56 }
57
58 fn global_line_size(&self, ident: MatmulIdent) -> u32 {
59 self.matmul.global_line_size(ident)
60 }
61
62 fn matrix_layout(&self, ident: MatmulIdent) -> MatrixLayout {
63 self.matmul.matrix_layout(ident)
64 }
65
66 fn swizzle_mode(&self, ident: MatmulIdent) -> SwizzleMode {
67 self.matmul.swizzle_mode(ident)
68 }
69
70 fn tiling_layout(&self, ident: MatmulIdent) -> TilingLayoutEnum {
71 self.matmul.tiling_layout(ident)
72 }
73
74 fn num_loading_planes(&self, ident: MatmulIdent) -> u32 {
75 self.matmul.num_loading_planes(ident)
76 }
77
78 fn plane_dim(&self) -> u32 {
79 self.matmul.plane_dim()
80 }
81
82 fn check_row_bounds(&self, ident: MatmulIdent) -> bool {
83 self.matmul.check_row_bounds(ident)
84 }
85
86 fn check_col_bounds(&self, ident: MatmulIdent) -> bool {
87 self.matmul.check_col_bounds(ident)
88 }
89
90 fn check_k_bounds(&self) -> bool {
91 self.matmul.check_k_bounds()
92 }
93
94 fn precompute_job(&self) -> bool {
95 self.matmul.precompute_job()
96 }
97
98 fn num_stages(&self, _ident: MatmulIdent) -> u32 {
99 self.num_stages
100 }
101
102 fn reader_mode(&self) -> ReaderMode {
103 self.matmul.reader_mode()
104 }
105
106 fn tiling_scheme(&self) -> TilingScheme {
107 self.matmul.tiling_scheme()
108 }
109
110 fn event_loading_mode(&self, ident: MatmulIdent) -> EventLoadingMode {
111 self.matmul.event_loading_mode(ident)
112 }
113
114 fn plane_role_config(&self) -> PlaneRoleConfig {
115 self.matmul.plane_role_config()
116 }
117
118 fn specialized_loading_sides(&self) -> SpecializedLoadingSides {
119 self.matmul.specialized_loading_sides()
120 }
121
122 fn cube_dim(&self) -> CubeDim {
123 CubeDim::new(self.plane_dim(), self.tiling_scheme().tiles_in_stage_m(), 1)
124 }
125}
126
127impl<M: GlobalConfig> ConvGemmConfig for ConvolutionConfig<M> {
128 fn convolution_params(&self) -> ConvolutionParams {
129 self.params
130 }
131
132 fn line_sizes(&self) -> cubecl_matmul::components::MatmulLineSizes {
133 MatmulLineSizes {
134 lhs: self.global_line_size(MatmulIdent::Lhs) as u8,
135 rhs: self.global_line_size(MatmulIdent::Rhs) as u8,
136 out: self.global_line_size(MatmulIdent::Out) as u8,
137 }
138 }
139
140 fn check_spatial_bounds(&self) -> bool {
141 let spatial_dims = self.params.dimensionality.num_dims();
142 let mut has_padding = false;
143 for i in 0..spatial_dims {
144 has_padding |= self.params.padding[i as usize] != 0;
145 }
146 has_padding
147 }
148}
149
150impl<M: GlobalConfig> ConvolutionConfig<M> {
151 #[allow(clippy::too_many_arguments)]
152 pub fn new(
153 matmul: M,
154 kernel_size: &[u32],
155 stride: &[u32],
156 dilation: &[u32],
157 padding: &[i32],
158 dim: Dimensionality,
159 num_stages: u32,
160 ) -> Result<Self, MatmulSetupError> {
161 let dims = kernel_size.len();
162
163 let mut params = ConvolutionParams {
164 kernel_size: [0; 3],
165 stride: [0; 3],
166 dilation: [0; 3],
167 padding: [0; 3],
168 dimensionality: dim,
169 };
170 params.kernel_size[0..dims].copy_from_slice(kernel_size);
171 params.stride[0..dims].copy_from_slice(stride);
172 params.dilation[0..dims].copy_from_slice(dilation);
173 params.padding[0..dims].copy_from_slice(padding);
174 Ok(Self {
175 matmul,
176 params,
177 num_stages,
178 })
179 }
180
181 pub fn to_matmul_config(self) -> M {
182 self.matmul
183 }
184}