cubecl_convolution/components/
config.rs1use std::ops::Deref;
2
3use cubecl_core::CubeDim;
4use cubecl_matmul::components::{
5 MatmulIdent, MatmulLineSizes, MatmulSetupError, MatrixLayout, TilingScheme,
6 global::{
7 GlobalConfig, PlaneRoleConfig, SpecializedLoadingSides, multi_stage::EventLoadingMode,
8 read::ReaderMode,
9 },
10 stage::{StageConfig, StageMemoryConfig},
11};
12
13use super::*;
14
15pub trait ConvGemmConfig: GlobalConfig {
17 fn convolution_params(&self) -> ConvolutionParams;
19 fn line_sizes(&self) -> MatmulLineSizes;
20 fn check_spatial_bounds(&self) -> bool;
21}
22
23#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
24pub struct ConvolutionConfig<M: GlobalConfig> {
25 matmul: M,
26 params: ConvolutionParams,
27 num_stages: u32,
28}
29
30#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
31pub struct ConvolutionParams {
32 pub kernel_size: [u32; 3],
33 pub stride: [u32; 3],
34 pub dilation: [u32; 3],
35 pub padding: [i32; 3],
36 pub dimensionality: Dimensionality,
37}
38
39impl<M: GlobalConfig> Deref for ConvolutionConfig<M> {
40 type Target = M;
41
42 fn deref(&self) -> &Self::Target {
43 &self.matmul
44 }
45}
46
47impl<M: GlobalConfig> GlobalConfig for ConvolutionConfig<M> {
48 type StageConfig = M::StageConfig;
49
50 fn stage_memory_config(&self, ident: MatmulIdent) -> StageMemoryConfig {
51 self.stage_config().stage_memory_config(ident.into_stage())
52 }
53
54 fn stage_config(&self) -> Self::StageConfig {
55 self.matmul.stage_config()
56 }
57
58 fn global_line_size(&self, ident: MatmulIdent) -> u32 {
59 self.matmul.global_line_size(ident)
60 }
61
62 fn matrix_layout(&self, ident: MatmulIdent) -> MatrixLayout {
63 self.matmul.matrix_layout(ident)
64 }
65
66 fn num_loading_planes(&self, ident: MatmulIdent) -> u32 {
67 self.matmul.num_loading_planes(ident)
68 }
69
70 fn plane_dim(&self) -> u32 {
71 self.matmul.plane_dim()
72 }
73
74 fn check_row_bounds(&self, ident: MatmulIdent) -> bool {
75 self.matmul.check_row_bounds(ident)
76 }
77
78 fn check_col_bounds(&self, ident: MatmulIdent) -> bool {
79 self.matmul.check_col_bounds(ident)
80 }
81
82 fn check_k_bounds(&self) -> bool {
83 self.matmul.check_k_bounds()
84 }
85
86 fn precompute_job(&self) -> bool {
87 self.matmul.precompute_job()
88 }
89
90 fn num_stages(&self, _ident: MatmulIdent) -> u32 {
91 self.num_stages
92 }
93
94 fn reader_mode(&self) -> ReaderMode {
95 self.matmul.reader_mode()
96 }
97
98 fn tiling_scheme(&self) -> TilingScheme {
99 self.matmul.tiling_scheme()
100 }
101
102 fn event_loading_mode(&self, ident: MatmulIdent) -> EventLoadingMode {
103 self.matmul.event_loading_mode(ident)
104 }
105
106 fn plane_role_config(&self) -> PlaneRoleConfig {
107 self.matmul.plane_role_config()
108 }
109
110 fn specialized_loading_sides(&self) -> SpecializedLoadingSides {
111 self.matmul.specialized_loading_sides()
112 }
113
114 fn cube_dim(&self) -> CubeDim {
115 CubeDim::new(self.plane_dim(), self.tiling_scheme().tiles_in_stage_m(), 1)
116 }
117}
118
119impl<M: GlobalConfig> ConvGemmConfig for ConvolutionConfig<M> {
120 fn convolution_params(&self) -> ConvolutionParams {
121 self.params
122 }
123
124 fn line_sizes(&self) -> cubecl_matmul::components::MatmulLineSizes {
125 MatmulLineSizes {
126 lhs: self.global_line_size(MatmulIdent::Lhs) as u8,
127 rhs: self.global_line_size(MatmulIdent::Rhs) as u8,
128 out: self.global_line_size(MatmulIdent::Out) as u8,
129 }
130 }
131
132 fn check_spatial_bounds(&self) -> bool {
133 let spatial_dims = self.params.dimensionality.num_dims();
134 let mut has_padding = false;
135 for i in 0..spatial_dims {
136 has_padding |= self.params.padding[i as usize] != 0;
137 }
138 has_padding
139 }
140}
141
142impl<M: GlobalConfig> ConvolutionConfig<M> {
143 #[allow(clippy::too_many_arguments)]
144 pub fn new(
145 matmul: M,
146 kernel_size: &[u32],
147 stride: &[u32],
148 dilation: &[u32],
149 padding: &[i32],
150 dim: Dimensionality,
151 num_stages: u32,
152 ) -> Result<Self, MatmulSetupError> {
153 let dims = kernel_size.len();
154
155 let mut params = ConvolutionParams {
156 kernel_size: [0; 3],
157 stride: [0; 3],
158 dilation: [0; 3],
159 padding: [0; 3],
160 dimensionality: dim,
161 };
162 params.kernel_size[0..dims].copy_from_slice(kernel_size);
163 params.stride[0..dims].copy_from_slice(stride);
164 params.dilation[0..dims].copy_from_slice(dilation);
165 params.padding[0..dims].copy_from_slice(padding);
166 Ok(Self {
167 matmul,
168 params,
169 num_stages,
170 })
171 }
172
173 pub fn to_matmul_config(self) -> M {
174 self.matmul
175 }
176}