cubecl_convolution/components/
config.rs

use std::ops::Deref;

use cubecl_core::CubeDim;
use cubecl_matmul::components::{
    MatmulLineSizes, MatmulSetupError,
    global::{GlobalConfig, memory::GlobalMemoryConfig},
};
use std::fmt::Debug;
use std::hash::Hash;

use super::*;
12
/// Convolution specific config, extends regular matmul [`Config`](global::Config)
pub trait ConvGemmConfig:
    Copy + Clone + Eq + PartialEq + Hash + Debug + Send + Sync + 'static
{
    /// The underlying global matmul configuration this convolution config wraps.
    type GlobalMatmulConfig: GlobalConfig;

    /// Returns the wrapped global matmul configuration.
    fn matmul_config(&self) -> Self::GlobalMatmulConfig;

    /// The convolution-specific parameters (kernel size, stride, dilation,
    /// padding and spatial dimensionality).
    fn convolution_params(&self) -> ConvolutionParams;
    /// Line sizes used by the underlying matmul.
    fn line_sizes(&self) -> MatmulLineSizes;
    /// Whether spatial bounds checks are required, i.e. whether any active
    /// spatial dimension has non-zero padding.
    fn check_spatial_bounds(&self) -> bool;
    /// The cube dimensions used for the launch.
    fn cube_dim(&self) -> CubeDim;
    /// Global memory configuration of the lhs tensor.
    fn lhs_global_memory_config(&self) -> GlobalMemoryConfig;
    /// Global memory configuration of the rhs tensor.
    fn rhs_global_memory_config(&self) -> GlobalMemoryConfig;
    /// Global memory configuration of the output tensor.
    fn out_global_memory_config(&self) -> GlobalMemoryConfig;
}
30
/// Pairs a global matmul config with convolution-specific parameters.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct ConvolutionConfig<M: GlobalConfig> {
    // The wrapped global matmul configuration; also reachable through `Deref`.
    pub matmul: M,
    // Kernel size, stride, dilation, padding and dimensionality.
    pub convolution_params: ConvolutionParams,
    // Number of stages, fixed at construction time (see `ConvolutionConfig::new`).
    // NOTE(review): presumably the global-loading pipeline depth — confirm against consumers.
    pub num_stages: u32,
}
37
/// Geometric parameters of the convolution.
///
/// The per-dimension arrays are fixed at 3 entries; only the first
/// `dimensionality.num_dims()` entries are meaningful, trailing entries are
/// left at zero (see `ConvolutionConfig::new`).
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct ConvolutionParams {
    // Kernel size per spatial dimension.
    pub kernel_size: [u32; 3],
    // Stride per spatial dimension.
    pub stride: [u32; 3],
    // Dilation per spatial dimension.
    pub dilation: [u32; 3],
    // Padding per spatial dimension; signed, so it may be negative.
    pub padding: [i32; 3],
    // Spatial dimensionality of the convolution.
    pub dimensionality: Dimensionality,
}
46
/// Exposes the wrapped matmul config's methods directly on
/// `ConvolutionConfig` by dereferencing to `self.matmul`.
///
/// NOTE(review): `Deref` on a non-pointer type is usually discouraged
/// (deref polymorphism); kept here since call-sites appear to rely on it.
impl<M: GlobalConfig> Deref for ConvolutionConfig<M> {
    type Target = M;

    fn deref(&self) -> &Self::Target {
        &self.matmul
    }
}
54
55impl<M: GlobalConfig> ConvGemmConfig for ConvolutionConfig<M> {
56    type GlobalMatmulConfig = M;
57
58    fn matmul_config(&self) -> Self::GlobalMatmulConfig {
59        self.matmul
60    }
61
62    fn line_sizes(&self) -> cubecl_matmul::components::MatmulLineSizes {
63        self.matmul.global_line_sizes()
64    }
65
66    fn cube_dim(&self) -> CubeDim {
67        self.matmul.cube_dim()
68    }
69
70    fn check_spatial_bounds(&self) -> bool {
71        let spatial_dims = self.convolution_params.dimensionality.num_dims();
72        let mut has_padding = false;
73        for i in 0..spatial_dims {
74            has_padding |= self.convolution_params.padding[i as usize] != 0;
75        }
76        has_padding
77    }
78
79    fn convolution_params(&self) -> ConvolutionParams {
80        self.convolution_params
81    }
82
83    fn lhs_global_memory_config(&self) -> GlobalMemoryConfig {
84        self.matmul.lhs_reader_config().gmem_config
85    }
86
87    fn rhs_global_memory_config(&self) -> GlobalMemoryConfig {
88        self.matmul.rhs_reader_config().gmem_config
89    }
90
91    fn out_global_memory_config(&self) -> GlobalMemoryConfig {
92        self.matmul.writer_config().gmem_config
93    }
94}
95
96impl<M: GlobalConfig> ConvolutionConfig<M> {
97    #[allow(clippy::too_many_arguments)]
98    pub fn new(
99        matmul: M,
100        kernel_size: &[u32],
101        stride: &[u32],
102        dilation: &[u32],
103        padding: &[i32],
104        dim: Dimensionality,
105        num_stages: u32,
106    ) -> Result<Self, MatmulSetupError> {
107        let dims = kernel_size.len();
108
109        let mut params = ConvolutionParams {
110            kernel_size: [0; 3],
111            stride: [0; 3],
112            dilation: [0; 3],
113            padding: [0; 3],
114            dimensionality: dim,
115        };
116        params.kernel_size[0..dims].copy_from_slice(kernel_size);
117        params.stride[0..dims].copy_from_slice(stride);
118        params.dilation[0..dims].copy_from_slice(dilation);
119        params.padding[0..dims].copy_from_slice(padding);
120        Ok(Self {
121            matmul,
122            convolution_params: params,
123            num_stages,
124        })
125    }
126}