Skip to main content

cubek_convolution/components/
config.rs

1use std::ops::Deref;
2
3use cubecl::CubeDim;
4use cubek_matmul::{
5    components::global::{GlobalConfig, memory::GlobalMemoryConfig},
6    definition::{MatmulSetupError, MatmulVectorSizes},
7};
8use std::{fmt::Debug, hash::Hash};
9
10use super::*;
11
12/// Convolution specific config, extends regular matmul `Config`.
13pub trait ConvGemmConfig:
14    Deref<Target: GlobalConfig> + Copy + Clone + Eq + PartialEq + Hash + Debug + Send + Sync + 'static
15{
16    type GlobalMatmulConfig: GlobalConfig;
17
18    fn matmul_config(&self) -> Self::GlobalMatmulConfig;
19
20    /// The size of the convolution kernel at `dim`
21    fn params(&self) -> ConvolutionParams;
22    fn operation(&self) -> ConvolutionOperation;
23    fn vector_sizes(&self) -> MatmulVectorSizes;
24    fn check_spatial_bounds(&self) -> bool;
25    fn cube_dim(&self) -> CubeDim;
26    fn lhs_global_memory_config(&self) -> GlobalMemoryConfig;
27    fn rhs_global_memory_config(&self) -> GlobalMemoryConfig;
28    fn out_global_memory_config(&self) -> GlobalMemoryConfig;
29}
30
31#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
32pub struct ConvolutionConfig<M: GlobalConfig> {
33    pub matmul: M,
34    pub params: ConvolutionParams,
35}
36
37#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
38pub struct ConvolutionParams {
39    pub kernel_size: [u32; 3],
40    pub stride: [u32; 3],
41    pub dilation: [u32; 3],
42    pub padding: [i32; 3],
43    pub dimensionality: Dimensionality,
44    pub operation: ConvolutionOperation,
45}
46
47impl ConvolutionParams {
48    pub fn from_problem(problem: &ConvolutionProblem) -> Self {
49        let dims = problem.dimensionality.num_dims();
50
51        let mut params = ConvolutionParams {
52            kernel_size: [0; 3],
53            stride: [0; 3],
54            dilation: [0; 3],
55            padding: [0; 3],
56            dimensionality: problem.dimensionality,
57            operation: problem.operation,
58        };
59        params.kernel_size[0..dims].copy_from_slice(&problem.kernel_size);
60        params.stride[0..dims].copy_from_slice(&problem.stride);
61        params.dilation[0..dims].copy_from_slice(&problem.dilation);
62        params.padding[0..dims].copy_from_slice(&problem.padding);
63        params
64    }
65}
66
67impl<M: GlobalConfig> Deref for ConvolutionConfig<M> {
68    type Target = M;
69
70    fn deref(&self) -> &Self::Target {
71        &self.matmul
72    }
73}
74
75impl<M: GlobalConfig> ConvGemmConfig for ConvolutionConfig<M> {
76    type GlobalMatmulConfig = M;
77
78    fn matmul_config(&self) -> Self::GlobalMatmulConfig {
79        self.matmul
80    }
81
82    fn vector_sizes(&self) -> MatmulVectorSizes {
83        self.matmul.global_vector_sizes()
84    }
85
86    fn cube_dim(&self) -> CubeDim {
87        self.matmul.cube_dim()
88    }
89
90    fn check_spatial_bounds(&self) -> bool {
91        let spatial_dims = self.params.dimensionality.num_dims();
92        let mut has_padding = false;
93        for i in 0..spatial_dims {
94            has_padding |= self.params.padding[i] != 0;
95        }
96        has_padding
97    }
98
99    fn params(&self) -> ConvolutionParams {
100        self.params
101    }
102
103    fn operation(&self) -> ConvolutionOperation {
104        self.params.operation
105    }
106
107    fn lhs_global_memory_config(&self) -> GlobalMemoryConfig {
108        self.matmul.lhs_reader_config().gmem_config
109    }
110
111    fn rhs_global_memory_config(&self) -> GlobalMemoryConfig {
112        self.matmul.rhs_reader_config().gmem_config
113    }
114
115    fn out_global_memory_config(&self) -> GlobalMemoryConfig {
116        self.matmul.writer_config().gmem_config
117    }
118}
119
120impl<M: GlobalConfig> ConvolutionConfig<M> {
121    #[allow(clippy::too_many_arguments)]
122    pub fn new(
123        matmul: M,
124        kernel_size: &[u32],
125        stride: &[u32],
126        dilation: &[u32],
127        padding: &[i32],
128        dim: Dimensionality,
129        operation: ConvolutionOperation,
130    ) -> Result<Self, MatmulSetupError> {
131        let dims = kernel_size.len();
132
133        let mut params = ConvolutionParams {
134            kernel_size: [0; 3],
135            stride: [0; 3],
136            dilation: [0; 3],
137            padding: [0; 3],
138            dimensionality: dim,
139            operation,
140        };
141        params.kernel_size[0..dims].copy_from_slice(kernel_size);
142        params.stride[0..dims].copy_from_slice(stride);
143        params.dilation[0..dims].copy_from_slice(dilation);
144        params.padding[0..dims].copy_from_slice(padding);
145        Ok(Self { matmul, params })
146    }
147}