cubek_convolution/components/
config.rs1use std::ops::Deref;
2
3use cubecl::CubeDim;
4use cubek_matmul::{
5 components::global::{GlobalConfig, memory::GlobalMemoryConfig},
6 definition::{MatmulSetupError, MatmulVectorSizes},
7};
8use std::{fmt::Debug, hash::Hash};
9
10use super::*;
11
/// Configuration contract for a convolution implemented as an implicit GEMM.
///
/// Implementors wrap a matmul [`GlobalConfig`] (exposed through `Deref`) and
/// extend it with the convolution-specific parameters needed by the kernels.
pub trait ConvGemmConfig:
    Deref<Target: GlobalConfig> + Copy + Clone + Eq + PartialEq + Hash + Debug + Send + Sync + 'static
{
    /// The underlying matmul global configuration type.
    type GlobalMatmulConfig: GlobalConfig;

    /// Returns the wrapped matmul global configuration by value.
    fn matmul_config(&self) -> Self::GlobalMatmulConfig;

    /// Returns the convolution parameters (kernel size, stride, dilation, padding, ...).
    fn params(&self) -> ConvolutionParams;
    /// Returns which convolution operation this config describes.
    fn operation(&self) -> ConvolutionOperation;
    /// Returns the vectorization (line) sizes used for global memory accesses.
    fn vector_sizes(&self) -> MatmulVectorSizes;
    /// Whether spatial bounds checks are required (see the impl: true when any
    /// padding component is non-zero).
    fn check_spatial_bounds(&self) -> bool;
    /// Returns the cube dimensions used to launch the kernel.
    fn cube_dim(&self) -> CubeDim;
    /// Global memory configuration for the Lhs (input) tensor.
    fn lhs_global_memory_config(&self) -> GlobalMemoryConfig;
    /// Global memory configuration for the Rhs (weight) tensor.
    fn rhs_global_memory_config(&self) -> GlobalMemoryConfig;
    /// Global memory configuration for the output tensor.
    fn out_global_memory_config(&self) -> GlobalMemoryConfig;
}
30
/// Convolution configuration: a matmul global config paired with the
/// convolution-specific parameters.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct ConvolutionConfig<M: GlobalConfig> {
    /// The wrapped matmul global configuration (also reachable via `Deref`).
    pub matmul: M,
    /// Convolution-specific parameters (kernel size, stride, dilation, padding, ...).
    pub params: ConvolutionParams,
}
36
/// Spatial parameters of a convolution.
///
/// All arrays are fixed at 3 entries (the maximum supported spatial rank);
/// only the first `dimensionality.num_dims()` entries are meaningful, the
/// rest are left at zero by the constructors.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct ConvolutionParams {
    /// Kernel size per spatial dimension.
    pub kernel_size: [u32; 3],
    /// Stride per spatial dimension.
    pub stride: [u32; 3],
    /// Dilation per spatial dimension.
    pub dilation: [u32; 3],
    /// Padding per spatial dimension (signed, may be zero or positive).
    pub padding: [i32; 3],
    /// Number of spatial dimensions of the problem.
    pub dimensionality: Dimensionality,
    /// Which convolution operation is being performed.
    pub operation: ConvolutionOperation,
}
46
47impl ConvolutionParams {
48 pub fn from_problem(problem: &ConvolutionProblem) -> Self {
49 let dims = problem.dimensionality.num_dims();
50
51 let mut params = ConvolutionParams {
52 kernel_size: [0; 3],
53 stride: [0; 3],
54 dilation: [0; 3],
55 padding: [0; 3],
56 dimensionality: problem.dimensionality,
57 operation: problem.operation,
58 };
59 params.kernel_size[0..dims].copy_from_slice(&problem.kernel_size);
60 params.stride[0..dims].copy_from_slice(&problem.stride);
61 params.dilation[0..dims].copy_from_slice(&problem.dilation);
62 params.padding[0..dims].copy_from_slice(&problem.padding);
63 params
64 }
65}
66
// Forward to the wrapped matmul config so `ConvolutionConfig` can be used
// anywhere a `GlobalConfig` method is needed (required by `ConvGemmConfig`'s
// `Deref<Target: GlobalConfig>` bound).
impl<M: GlobalConfig> Deref for ConvolutionConfig<M> {
    type Target = M;

    fn deref(&self) -> &Self::Target {
        &self.matmul
    }
}
74
75impl<M: GlobalConfig> ConvGemmConfig for ConvolutionConfig<M> {
76 type GlobalMatmulConfig = M;
77
78 fn matmul_config(&self) -> Self::GlobalMatmulConfig {
79 self.matmul
80 }
81
82 fn vector_sizes(&self) -> MatmulVectorSizes {
83 self.matmul.global_vector_sizes()
84 }
85
86 fn cube_dim(&self) -> CubeDim {
87 self.matmul.cube_dim()
88 }
89
90 fn check_spatial_bounds(&self) -> bool {
91 let spatial_dims = self.params.dimensionality.num_dims();
92 let mut has_padding = false;
93 for i in 0..spatial_dims {
94 has_padding |= self.params.padding[i] != 0;
95 }
96 has_padding
97 }
98
99 fn params(&self) -> ConvolutionParams {
100 self.params
101 }
102
103 fn operation(&self) -> ConvolutionOperation {
104 self.params.operation
105 }
106
107 fn lhs_global_memory_config(&self) -> GlobalMemoryConfig {
108 self.matmul.lhs_reader_config().gmem_config
109 }
110
111 fn rhs_global_memory_config(&self) -> GlobalMemoryConfig {
112 self.matmul.rhs_reader_config().gmem_config
113 }
114
115 fn out_global_memory_config(&self) -> GlobalMemoryConfig {
116 self.matmul.writer_config().gmem_config
117 }
118}
119
120impl<M: GlobalConfig> ConvolutionConfig<M> {
121 #[allow(clippy::too_many_arguments)]
122 pub fn new(
123 matmul: M,
124 kernel_size: &[u32],
125 stride: &[u32],
126 dilation: &[u32],
127 padding: &[i32],
128 dim: Dimensionality,
129 operation: ConvolutionOperation,
130 ) -> Result<Self, MatmulSetupError> {
131 let dims = kernel_size.len();
132
133 let mut params = ConvolutionParams {
134 kernel_size: [0; 3],
135 stride: [0; 3],
136 dilation: [0; 3],
137 padding: [0; 3],
138 dimensionality: dim,
139 operation,
140 };
141 params.kernel_size[0..dims].copy_from_slice(kernel_size);
142 params.stride[0..dims].copy_from_slice(stride);
143 params.dilation[0..dims].copy_from_slice(dilation);
144 params.padding[0..dims].copy_from_slice(padding);
145 Ok(Self { matmul, params })
146 }
147}