cubek_convolution/components/
config.rs1use std::ops::Deref;
2
3use cubecl::CubeDim;
4use cubek_matmul::{
5 components::global::{GlobalConfig, memory::GlobalMemoryConfig},
6 definition::{MatmulLineSizes, MatmulSetupError},
7};
8use std::fmt::Debug;
9use std::hash::Hash;
10
11use super::*;
12
/// Configuration trait for convolution kernels implemented as implicit GEMM.
///
/// Wraps an underlying matmul [`GlobalConfig`] (reachable through `Deref`)
/// and adds the convolution-specific settings — kernel geometry, padding,
/// and the global-memory layouts of the three tensors.
pub trait ConvGemmConfig:
    Deref<Target: GlobalConfig> + Copy + Clone + Eq + PartialEq + Hash + Debug + Send + Sync + 'static
{
    /// The matmul configuration type this convolution config wraps.
    type GlobalMatmulConfig: GlobalConfig;

    /// Returns the wrapped matmul configuration by value.
    fn matmul_config(&self) -> Self::GlobalMatmulConfig;

    /// Convolution geometry (kernel size, stride, dilation, padding) and
    /// operation kind.
    fn params(&self) -> ConvolutionParams;
    /// Which convolution operation this config describes.
    fn operation(&self) -> ConvolutionOperation;
    /// Vectorization line sizes used by the underlying matmul.
    fn line_sizes(&self) -> MatmulLineSizes;
    /// Whether spatial out-of-bounds checks are needed
    /// (true when any spatial padding is non-zero).
    fn check_spatial_bounds(&self) -> bool;
    /// Cube (work-group) dimensions of the kernel launch.
    fn cube_dim(&self) -> CubeDim;
    /// Global-memory configuration of the LHS tensor.
    fn lhs_global_memory_config(&self) -> GlobalMemoryConfig;
    /// Global-memory configuration of the RHS tensor.
    fn rhs_global_memory_config(&self) -> GlobalMemoryConfig;
    /// Global-memory configuration of the output tensor.
    fn out_global_memory_config(&self) -> GlobalMemoryConfig;
}
31
/// Concrete [`ConvGemmConfig`] implementation: an underlying matmul
/// [`GlobalConfig`] paired with convolution-specific parameters.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct ConvolutionConfig<M: GlobalConfig> {
    /// Configuration of the underlying global matmul.
    pub matmul: M,
    /// Convolution geometry and operation kind.
    pub params: ConvolutionParams,
}
37
/// Convolution geometry, stored with a fixed capacity of three spatial
/// dimensions.
///
/// Per-dimension values live in `[_; 3]` arrays so the struct stays `Copy`
/// and hashable; for problems with fewer spatial dims only the first
/// `dimensionality.num_dims()` entries are meaningful (the rest are zero).
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct ConvolutionParams {
    /// Kernel (filter) size per spatial dimension.
    pub kernel_size: [u32; 3],
    /// Stride per spatial dimension.
    pub stride: [u32; 3],
    /// Dilation per spatial dimension.
    pub dilation: [u32; 3],
    /// Padding per spatial dimension (signed).
    pub padding: [i32; 3],
    /// Number of spatial dimensions of the problem.
    pub dimensionality: Dimensionality,
    /// Which convolution operation these parameters describe.
    pub operation: ConvolutionOperation,
}
47
48impl ConvolutionParams {
49 pub fn from_problem(problem: &ConvolutionProblem) -> Self {
50 let dims = problem.dimensionality.num_dims();
51
52 let mut params = ConvolutionParams {
53 kernel_size: [0; 3],
54 stride: [0; 3],
55 dilation: [0; 3],
56 padding: [0; 3],
57 dimensionality: problem.dimensionality,
58 operation: problem.operation,
59 };
60 params.kernel_size[0..dims].copy_from_slice(&problem.kernel_size);
61 params.stride[0..dims].copy_from_slice(&problem.stride);
62 params.dilation[0..dims].copy_from_slice(&problem.dilation);
63 params.padding[0..dims].copy_from_slice(&problem.padding);
64 params
65 }
66}
67
/// Dereferencing a `ConvolutionConfig` yields the wrapped matmul config,
/// making all [`GlobalConfig`] methods directly reachable on it.
impl<M: GlobalConfig> Deref for ConvolutionConfig<M> {
    type Target = M;

    fn deref(&self) -> &Self::Target {
        &self.matmul
    }
}
75
76impl<M: GlobalConfig> ConvGemmConfig for ConvolutionConfig<M> {
77 type GlobalMatmulConfig = M;
78
79 fn matmul_config(&self) -> Self::GlobalMatmulConfig {
80 self.matmul
81 }
82
83 fn line_sizes(&self) -> MatmulLineSizes {
84 self.matmul.global_line_sizes()
85 }
86
87 fn cube_dim(&self) -> CubeDim {
88 self.matmul.cube_dim()
89 }
90
91 fn check_spatial_bounds(&self) -> bool {
92 let spatial_dims = self.params.dimensionality.num_dims();
93 let mut has_padding = false;
94 for i in 0..spatial_dims {
95 has_padding |= self.params.padding[i] != 0;
96 }
97 has_padding
98 }
99
100 fn params(&self) -> ConvolutionParams {
101 self.params
102 }
103
104 fn operation(&self) -> ConvolutionOperation {
105 self.params.operation
106 }
107
108 fn lhs_global_memory_config(&self) -> GlobalMemoryConfig {
109 self.matmul.lhs_reader_config().gmem_config
110 }
111
112 fn rhs_global_memory_config(&self) -> GlobalMemoryConfig {
113 self.matmul.rhs_reader_config().gmem_config
114 }
115
116 fn out_global_memory_config(&self) -> GlobalMemoryConfig {
117 self.matmul.writer_config().gmem_config
118 }
119}
120
121impl<M: GlobalConfig> ConvolutionConfig<M> {
122 #[allow(clippy::too_many_arguments)]
123 pub fn new(
124 matmul: M,
125 kernel_size: &[u32],
126 stride: &[u32],
127 dilation: &[u32],
128 padding: &[i32],
129 dim: Dimensionality,
130 operation: ConvolutionOperation,
131 ) -> Result<Self, MatmulSetupError> {
132 let dims = kernel_size.len();
133
134 let mut params = ConvolutionParams {
135 kernel_size: [0; 3],
136 stride: [0; 3],
137 dilation: [0; 3],
138 padding: [0; 3],
139 dimensionality: dim,
140 operation,
141 };
142 params.kernel_size[0..dims].copy_from_slice(kernel_size);
143 params.stride[0..dims].copy_from_slice(stride);
144 params.dilation[0..dims].copy_from_slice(dilation);
145 params.padding[0..dims].copy_from_slice(padding);
146 Ok(Self { matmul, params })
147 }
148}