// cubecl_convolution/components/config.rs
use std::ops::Deref;
2
3use cubecl_core::CubeDim;
4use cubecl_matmul::components::{
5 MatmulLineSizes, MatmulSetupError,
6 global::{GlobalConfig, memory::GlobalMemoryConfig},
7};
8use std::fmt::Debug;
9use std::hash::Hash;
10
11use super::*;
12
/// Configuration for GEMM-based convolution kernels.
///
/// Wraps a matmul [`GlobalConfig`] and extends it with the convolution
/// geometry (kernel size, stride, dilation, padding) that the convolution
/// kernel needs on top of the plain matmul configuration.
pub trait ConvGemmConfig:
    Copy + Clone + Eq + PartialEq + Hash + Debug + Send + Sync + 'static
{
    /// The underlying global matmul configuration this convolution wraps.
    type GlobalMatmulConfig: GlobalConfig;

    /// Returns the wrapped global matmul configuration.
    fn matmul_config(&self) -> Self::GlobalMatmulConfig;

    /// Convolution geometry: kernel size, stride, dilation, padding and
    /// spatial dimensionality.
    fn convolution_params(&self) -> ConvolutionParams;
    /// Line (vectorization) sizes used for the matmul's global memory accesses.
    fn line_sizes(&self) -> MatmulLineSizes;
    /// Whether spatial bounds checks are required; the provided impl returns
    /// `true` when any spatial dimension has non-zero padding.
    fn check_spatial_bounds(&self) -> bool;
    /// Cube dimensions the kernel is launched with.
    fn cube_dim(&self) -> CubeDim;
    /// Global memory configuration for the LHS (input) tensor.
    fn lhs_global_memory_config(&self) -> GlobalMemoryConfig;
    /// Global memory configuration for the RHS (weight) tensor.
    fn rhs_global_memory_config(&self) -> GlobalMemoryConfig;
    /// Global memory configuration for the output tensor.
    fn out_global_memory_config(&self) -> GlobalMemoryConfig;
}
30
/// Concrete [`ConvGemmConfig`]: a matmul global config paired with the
/// convolution-specific parameters.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct ConvolutionConfig<M: GlobalConfig> {
    // The wrapped global matmul configuration; also exposed via `Deref`.
    pub matmul: M,
    // Convolution geometry (kernel size, stride, dilation, padding).
    pub convolution_params: ConvolutionParams,
    // NOTE(review): presumably the number of pipelining/double-buffering
    // stages used by the global loaders — confirm against kernel launch code.
    pub num_stages: u32,
}
37
/// Spatial parameters of a convolution, stored in fixed-size 3-element
/// arrays so the same type covers 1D, 2D and 3D convolutions.
///
/// Only the first `dimensionality.num_dims()` entries of each array are
/// meaningful; trailing entries are zero-filled by `ConvolutionConfig::new`.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct ConvolutionParams {
    // Kernel (filter) extent per spatial dimension.
    pub kernel_size: [u32; 3],
    // Step between successive kernel applications per spatial dimension.
    pub stride: [u32; 3],
    // Spacing between kernel taps per spatial dimension.
    pub dilation: [u32; 3],
    // Padding per spatial dimension; signed — NOTE(review): negative padding
    // (cropping) appears to be representable, confirm it is actually supported.
    pub padding: [i32; 3],
    // Number of spatial dimensions (1D/2D/3D).
    pub dimensionality: Dimensionality,
}
46
// Dereference to the inner matmul config so matmul accessors can be called
// directly on a `ConvolutionConfig`.
//
// NOTE(review): `Deref` used for delegation (not a smart pointer) is
// unconventional, but it is part of this type's public surface; keep it.
impl<M: GlobalConfig> Deref for ConvolutionConfig<M> {
    type Target = M;

    fn deref(&self) -> &Self::Target {
        &self.matmul
    }
}
54
55impl<M: GlobalConfig> ConvGemmConfig for ConvolutionConfig<M> {
56 type GlobalMatmulConfig = M;
57
58 fn matmul_config(&self) -> Self::GlobalMatmulConfig {
59 self.matmul
60 }
61
62 fn line_sizes(&self) -> cubecl_matmul::components::MatmulLineSizes {
63 self.matmul.global_line_sizes()
64 }
65
66 fn cube_dim(&self) -> CubeDim {
67 self.matmul.cube_dim()
68 }
69
70 fn check_spatial_bounds(&self) -> bool {
71 let spatial_dims = self.convolution_params.dimensionality.num_dims();
72 let mut has_padding = false;
73 for i in 0..spatial_dims {
74 has_padding |= self.convolution_params.padding[i as usize] != 0;
75 }
76 has_padding
77 }
78
79 fn convolution_params(&self) -> ConvolutionParams {
80 self.convolution_params
81 }
82
83 fn lhs_global_memory_config(&self) -> GlobalMemoryConfig {
84 self.matmul.lhs_reader_config().gmem_config
85 }
86
87 fn rhs_global_memory_config(&self) -> GlobalMemoryConfig {
88 self.matmul.rhs_reader_config().gmem_config
89 }
90
91 fn out_global_memory_config(&self) -> GlobalMemoryConfig {
92 self.matmul.writer_config().gmem_config
93 }
94}
95
96impl<M: GlobalConfig> ConvolutionConfig<M> {
97 #[allow(clippy::too_many_arguments)]
98 pub fn new(
99 matmul: M,
100 kernel_size: &[u32],
101 stride: &[u32],
102 dilation: &[u32],
103 padding: &[i32],
104 dim: Dimensionality,
105 num_stages: u32,
106 ) -> Result<Self, MatmulSetupError> {
107 let dims = kernel_size.len();
108
109 let mut params = ConvolutionParams {
110 kernel_size: [0; 3],
111 stride: [0; 3],
112 dilation: [0; 3],
113 padding: [0; 3],
114 dimensionality: dim,
115 };
116 params.kernel_size[0..dims].copy_from_slice(kernel_size);
117 params.stride[0..dims].copy_from_slice(stride);
118 params.dilation[0..dims].copy_from_slice(dilation);
119 params.padding[0..dims].copy_from_slice(padding);
120 Ok(Self {
121 matmul,
122 convolution_params: params,
123 num_stages,
124 })
125 }
126}