// cubecl_convolution/components/config.rs

1use std::ops::Deref;
2
3use cubecl_core::CubeDim;
4use cubecl_matmul::components::{
5    MatmulIdent, MatmulLineSizes, MatmulSetupError, MatrixLayout, TilingScheme,
6    global::{
7        GlobalConfig, PlaneRoleConfig, SpecializedLoadingSides, multi_stage::EventLoadingMode,
8        read::ReaderMode,
9    },
10    stage::{StageConfig, StageMemoryConfig},
11};
12
13use super::*;
14
15/// Convolution specific config, extends regular matmul [`Config`](global::Config)
16pub trait ConvGemmConfig: GlobalConfig {
17    /// The size of the convolution kernel at `dim`
18    fn convolution_params(&self) -> ConvolutionParams;
19    fn line_sizes(&self) -> MatmulLineSizes;
20    fn check_spatial_bounds(&self) -> bool;
21}
22
23#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
24pub struct ConvolutionConfig<M: GlobalConfig> {
25    matmul: M,
26    params: ConvolutionParams,
27    num_stages: u32,
28}
29
30#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
31pub struct ConvolutionParams {
32    pub kernel_size: [u32; 3],
33    pub stride: [u32; 3],
34    pub dilation: [u32; 3],
35    pub padding: [i32; 3],
36    pub dimensionality: Dimensionality,
37}
38
39impl<M: GlobalConfig> Deref for ConvolutionConfig<M> {
40    type Target = M;
41
42    fn deref(&self) -> &Self::Target {
43        &self.matmul
44    }
45}
46
47impl<M: GlobalConfig> GlobalConfig for ConvolutionConfig<M> {
48    type StageConfig = M::StageConfig;
49
50    fn stage_memory_config(&self, ident: MatmulIdent) -> StageMemoryConfig {
51        self.stage_config().stage_memory_config(ident.into_stage())
52    }
53
54    fn stage_config(&self) -> Self::StageConfig {
55        self.matmul.stage_config()
56    }
57
58    fn global_line_size(&self, ident: MatmulIdent) -> u32 {
59        self.matmul.global_line_size(ident)
60    }
61
62    fn matrix_layout(&self, ident: MatmulIdent) -> MatrixLayout {
63        self.matmul.matrix_layout(ident)
64    }
65
66    fn num_loading_planes(&self, ident: MatmulIdent) -> u32 {
67        self.matmul.num_loading_planes(ident)
68    }
69
70    fn plane_dim(&self) -> u32 {
71        self.matmul.plane_dim()
72    }
73
74    fn check_row_bounds(&self, ident: MatmulIdent) -> bool {
75        self.matmul.check_row_bounds(ident)
76    }
77
78    fn check_col_bounds(&self, ident: MatmulIdent) -> bool {
79        self.matmul.check_col_bounds(ident)
80    }
81
82    fn check_k_bounds(&self) -> bool {
83        self.matmul.check_k_bounds()
84    }
85
86    fn precompute_job(&self) -> bool {
87        self.matmul.precompute_job()
88    }
89
90    fn num_stages(&self, _ident: MatmulIdent) -> u32 {
91        self.num_stages
92    }
93
94    fn reader_mode(&self) -> ReaderMode {
95        self.matmul.reader_mode()
96    }
97
98    fn tiling_scheme(&self) -> TilingScheme {
99        self.matmul.tiling_scheme()
100    }
101
102    fn event_loading_mode(&self, ident: MatmulIdent) -> EventLoadingMode {
103        self.matmul.event_loading_mode(ident)
104    }
105
106    fn plane_role_config(&self) -> PlaneRoleConfig {
107        self.matmul.plane_role_config()
108    }
109
110    fn specialized_loading_sides(&self) -> SpecializedLoadingSides {
111        self.matmul.specialized_loading_sides()
112    }
113
114    fn cube_dim(&self) -> CubeDim {
115        CubeDim::new(self.plane_dim(), self.tiling_scheme().tiles_in_stage_m(), 1)
116    }
117}
118
119impl<M: GlobalConfig> ConvGemmConfig for ConvolutionConfig<M> {
120    fn convolution_params(&self) -> ConvolutionParams {
121        self.params
122    }
123
124    fn line_sizes(&self) -> cubecl_matmul::components::MatmulLineSizes {
125        MatmulLineSizes {
126            lhs: self.global_line_size(MatmulIdent::Lhs) as u8,
127            rhs: self.global_line_size(MatmulIdent::Rhs) as u8,
128            out: self.global_line_size(MatmulIdent::Out) as u8,
129        }
130    }
131
132    fn check_spatial_bounds(&self) -> bool {
133        let spatial_dims = self.params.dimensionality.num_dims();
134        let mut has_padding = false;
135        for i in 0..spatial_dims {
136            has_padding |= self.params.padding[i as usize] != 0;
137        }
138        has_padding
139    }
140}
141
142impl<M: GlobalConfig> ConvolutionConfig<M> {
143    #[allow(clippy::too_many_arguments)]
144    pub fn new(
145        matmul: M,
146        kernel_size: &[u32],
147        stride: &[u32],
148        dilation: &[u32],
149        padding: &[i32],
150        dim: Dimensionality,
151        num_stages: u32,
152    ) -> Result<Self, MatmulSetupError> {
153        let dims = kernel_size.len();
154
155        let mut params = ConvolutionParams {
156            kernel_size: [0; 3],
157            stride: [0; 3],
158            dilation: [0; 3],
159            padding: [0; 3],
160            dimensionality: dim,
161        };
162        params.kernel_size[0..dims].copy_from_slice(kernel_size);
163        params.stride[0..dims].copy_from_slice(stride);
164        params.dilation[0..dims].copy_from_slice(dilation);
165        params.padding[0..dims].copy_from_slice(padding);
166        Ok(Self {
167            matmul,
168            params,
169            num_stages,
170        })
171    }
172
173    pub fn to_matmul_config(self) -> M {
174        self.matmul
175    }
176}