cubecl_convolution/components/
config.rs

1use std::ops::Deref;
2
3use cubecl_core::CubeDim;
4use cubecl_matmul::components::{
5    MatmulIdent, MatmulLineSizes, MatmulSetupError, MatrixLayout, TilingScheme,
6    global::{
7        GlobalConfig, PlaneRoleConfig, SpecializedLoadingSides, multi_stage::EventLoadingMode,
8        read::ReaderMode,
9    },
10    stage::{StageConfig, StageMemoryConfig, SwizzleMode, TilingLayoutEnum},
11};
12
13use super::*;
14
/// Convolution specific config, extends the regular matmul [`GlobalConfig`]
pub trait ConvGemmConfig: GlobalConfig {
    /// The convolution parameters (kernel size, stride, dilation, padding and dimensionality)
    fn convolution_params(&self) -> ConvolutionParams;
    /// The line sizes used for the `lhs`, `rhs` and `out` operands
    fn line_sizes(&self) -> MatmulLineSizes;
    /// Whether spatial bounds checks are needed.
    ///
    /// The [`ConvolutionConfig`] implementation reports `true` when any spatial padding is
    /// non-zero.
    fn check_spatial_bounds(&self) -> bool;
}
22
/// Global config for convolutions: wraps a matmul [`GlobalConfig`] and adds the convolution
/// parameters on top of it.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct ConvolutionConfig<M: GlobalConfig> {
    /// The wrapped matmul config; most [`GlobalConfig`] methods are forwarded to it.
    matmul: M,
    /// Kernel size, stride, dilation, padding and dimensionality of the convolution.
    params: ConvolutionParams,
    /// Value returned by `num_stages`, regardless of the requested ident.
    num_stages: u32,
}
29
/// Parameters of a convolution, stored in fixed-size arrays with room for up to three entries.
///
/// [`ConvolutionConfig::new`] fills the leading entries from its slice arguments and leaves the
/// remaining entries at `0`.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct ConvolutionParams {
    /// Kernel size, one entry per dimension.
    pub kernel_size: [u32; 3],
    /// Stride, one entry per dimension.
    pub stride: [u32; 3],
    /// Dilation, one entry per dimension.
    pub dilation: [u32; 3],
    /// Padding, one entry per dimension (signed).
    pub padding: [i32; 3],
    /// Dimensionality of the convolution; its `num_dims()` decides how many leading `padding`
    /// entries `check_spatial_bounds` inspects.
    pub dimensionality: Dimensionality,
}
38
39impl<M: GlobalConfig> Deref for ConvolutionConfig<M> {
40    type Target = M;
41
42    fn deref(&self) -> &Self::Target {
43        &self.matmul
44    }
45}
46
47impl<M: GlobalConfig> GlobalConfig for ConvolutionConfig<M> {
48    type StageConfig = M::StageConfig;
49
50    fn stage_memory_config(&self, ident: MatmulIdent) -> StageMemoryConfig {
51        self.stage_config().stage_memory_config(ident.into_stage())
52    }
53
54    fn stage_config(&self) -> Self::StageConfig {
55        self.matmul.stage_config()
56    }
57
58    fn global_line_size(&self, ident: MatmulIdent) -> u32 {
59        self.matmul.global_line_size(ident)
60    }
61
62    fn matrix_layout(&self, ident: MatmulIdent) -> MatrixLayout {
63        self.matmul.matrix_layout(ident)
64    }
65
66    fn swizzle_mode(&self, ident: MatmulIdent) -> SwizzleMode {
67        self.matmul.swizzle_mode(ident)
68    }
69
70    fn tiling_layout(&self, ident: MatmulIdent) -> TilingLayoutEnum {
71        self.matmul.tiling_layout(ident)
72    }
73
74    fn num_loading_planes(&self, ident: MatmulIdent) -> u32 {
75        self.matmul.num_loading_planes(ident)
76    }
77
78    fn plane_dim(&self) -> u32 {
79        self.matmul.plane_dim()
80    }
81
82    fn check_row_bounds(&self, ident: MatmulIdent) -> bool {
83        self.matmul.check_row_bounds(ident)
84    }
85
86    fn check_col_bounds(&self, ident: MatmulIdent) -> bool {
87        self.matmul.check_col_bounds(ident)
88    }
89
90    fn check_k_bounds(&self) -> bool {
91        self.matmul.check_k_bounds()
92    }
93
94    fn precompute_job(&self) -> bool {
95        self.matmul.precompute_job()
96    }
97
98    fn num_stages(&self, _ident: MatmulIdent) -> u32 {
99        self.num_stages
100    }
101
102    fn reader_mode(&self) -> ReaderMode {
103        self.matmul.reader_mode()
104    }
105
106    fn tiling_scheme(&self) -> TilingScheme {
107        self.matmul.tiling_scheme()
108    }
109
110    fn event_loading_mode(&self, ident: MatmulIdent) -> EventLoadingMode {
111        self.matmul.event_loading_mode(ident)
112    }
113
114    fn plane_role_config(&self) -> PlaneRoleConfig {
115        self.matmul.plane_role_config()
116    }
117
118    fn specialized_loading_sides(&self) -> SpecializedLoadingSides {
119        self.matmul.specialized_loading_sides()
120    }
121
122    fn cube_dim(&self) -> CubeDim {
123        CubeDim::new(self.plane_dim(), self.tiling_scheme().tiles_in_stage_m(), 1)
124    }
125}
126
127impl<M: GlobalConfig> ConvGemmConfig for ConvolutionConfig<M> {
128    fn convolution_params(&self) -> ConvolutionParams {
129        self.params
130    }
131
132    fn line_sizes(&self) -> cubecl_matmul::components::MatmulLineSizes {
133        MatmulLineSizes {
134            lhs: self.global_line_size(MatmulIdent::Lhs) as u8,
135            rhs: self.global_line_size(MatmulIdent::Rhs) as u8,
136            out: self.global_line_size(MatmulIdent::Out) as u8,
137        }
138    }
139
140    fn check_spatial_bounds(&self) -> bool {
141        let spatial_dims = self.params.dimensionality.num_dims();
142        let mut has_padding = false;
143        for i in 0..spatial_dims {
144            has_padding |= self.params.padding[i as usize] != 0;
145        }
146        has_padding
147    }
148}
149
150impl<M: GlobalConfig> ConvolutionConfig<M> {
151    #[allow(clippy::too_many_arguments)]
152    pub fn new(
153        matmul: M,
154        kernel_size: &[u32],
155        stride: &[u32],
156        dilation: &[u32],
157        padding: &[i32],
158        dim: Dimensionality,
159        num_stages: u32,
160    ) -> Result<Self, MatmulSetupError> {
161        let dims = kernel_size.len();
162
163        let mut params = ConvolutionParams {
164            kernel_size: [0; 3],
165            stride: [0; 3],
166            dilation: [0; 3],
167            padding: [0; 3],
168            dimensionality: dim,
169        };
170        params.kernel_size[0..dims].copy_from_slice(kernel_size);
171        params.stride[0..dims].copy_from_slice(stride);
172        params.dilation[0..dims].copy_from_slice(dilation);
173        params.padding[0..dims].copy_from_slice(padding);
174        Ok(Self {
175            matmul,
176            params,
177            num_stages,
178        })
179    }
180
181    pub fn to_matmul_config(self) -> M {
182        self.matmul
183    }
184}