cubek_convolution/components/
config.rs

1use std::ops::Deref;
2
3use cubecl::CubeDim;
4use cubek_matmul::{
5    components::global::{GlobalConfig, memory::GlobalMemoryConfig},
6    definition::{MatmulLineSizes, MatmulSetupError},
7};
8use std::fmt::Debug;
9use std::hash::Hash;
10
11use super::*;
12
13/// Convolution specific config, extends regular matmul `Config`.
14pub trait ConvGemmConfig:
15    Deref<Target: GlobalConfig> + Copy + Clone + Eq + PartialEq + Hash + Debug + Send + Sync + 'static
16{
17    type GlobalMatmulConfig: GlobalConfig;
18
19    fn matmul_config(&self) -> Self::GlobalMatmulConfig;
20
21    /// The size of the convolution kernel at `dim`
22    fn params(&self) -> ConvolutionParams;
23    fn operation(&self) -> ConvolutionOperation;
24    fn line_sizes(&self) -> MatmulLineSizes;
25    fn check_spatial_bounds(&self) -> bool;
26    fn cube_dim(&self) -> CubeDim;
27    fn lhs_global_memory_config(&self) -> GlobalMemoryConfig;
28    fn rhs_global_memory_config(&self) -> GlobalMemoryConfig;
29    fn out_global_memory_config(&self) -> GlobalMemoryConfig;
30}
31
32#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
33pub struct ConvolutionConfig<M: GlobalConfig> {
34    pub matmul: M,
35    pub params: ConvolutionParams,
36}
37
38#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
39pub struct ConvolutionParams {
40    pub kernel_size: [u32; 3],
41    pub stride: [u32; 3],
42    pub dilation: [u32; 3],
43    pub padding: [i32; 3],
44    pub dimensionality: Dimensionality,
45    pub operation: ConvolutionOperation,
46}
47
48impl ConvolutionParams {
49    pub fn from_problem(problem: &ConvolutionProblem) -> Self {
50        let dims = problem.dimensionality.num_dims();
51
52        let mut params = ConvolutionParams {
53            kernel_size: [0; 3],
54            stride: [0; 3],
55            dilation: [0; 3],
56            padding: [0; 3],
57            dimensionality: problem.dimensionality,
58            operation: problem.operation,
59        };
60        params.kernel_size[0..dims].copy_from_slice(&problem.kernel_size);
61        params.stride[0..dims].copy_from_slice(&problem.stride);
62        params.dilation[0..dims].copy_from_slice(&problem.dilation);
63        params.padding[0..dims].copy_from_slice(&problem.padding);
64        params
65    }
66}
67
68impl<M: GlobalConfig> Deref for ConvolutionConfig<M> {
69    type Target = M;
70
71    fn deref(&self) -> &Self::Target {
72        &self.matmul
73    }
74}
75
76impl<M: GlobalConfig> ConvGemmConfig for ConvolutionConfig<M> {
77    type GlobalMatmulConfig = M;
78
79    fn matmul_config(&self) -> Self::GlobalMatmulConfig {
80        self.matmul
81    }
82
83    fn line_sizes(&self) -> MatmulLineSizes {
84        self.matmul.global_line_sizes()
85    }
86
87    fn cube_dim(&self) -> CubeDim {
88        self.matmul.cube_dim()
89    }
90
91    fn check_spatial_bounds(&self) -> bool {
92        let spatial_dims = self.params.dimensionality.num_dims();
93        let mut has_padding = false;
94        for i in 0..spatial_dims {
95            has_padding |= self.params.padding[i] != 0;
96        }
97        has_padding
98    }
99
100    fn params(&self) -> ConvolutionParams {
101        self.params
102    }
103
104    fn operation(&self) -> ConvolutionOperation {
105        self.params.operation
106    }
107
108    fn lhs_global_memory_config(&self) -> GlobalMemoryConfig {
109        self.matmul.lhs_reader_config().gmem_config
110    }
111
112    fn rhs_global_memory_config(&self) -> GlobalMemoryConfig {
113        self.matmul.rhs_reader_config().gmem_config
114    }
115
116    fn out_global_memory_config(&self) -> GlobalMemoryConfig {
117        self.matmul.writer_config().gmem_config
118    }
119}
120
121impl<M: GlobalConfig> ConvolutionConfig<M> {
122    #[allow(clippy::too_many_arguments)]
123    pub fn new(
124        matmul: M,
125        kernel_size: &[u32],
126        stride: &[u32],
127        dilation: &[u32],
128        padding: &[i32],
129        dim: Dimensionality,
130        operation: ConvolutionOperation,
131    ) -> Result<Self, MatmulSetupError> {
132        let dims = kernel_size.len();
133
134        let mut params = ConvolutionParams {
135            kernel_size: [0; 3],
136            stride: [0; 3],
137            dilation: [0; 3],
138            padding: [0; 3],
139            dimensionality: dim,
140            operation,
141        };
142        params.kernel_size[0..dims].copy_from_slice(kernel_size);
143        params.stride[0..dims].copy_from_slice(stride);
144        params.dilation[0..dims].copy_from_slice(dilation);
145        params.padding[0..dims].copy_from_slice(padding);
146        Ok(Self { matmul, params })
147    }
148}