cubecl_convolution/config.rs
1use cubecl_matmul::components::{MatmulLineSizes, global::GlobalConfig};
2
3use super::base::Dimensionality;
4
5/// Convolution specific config, extends regular matmul [`Config`](global::Config)
6pub trait ConvGemmConfig: GlobalConfig {
7 /// The size of the convolution kernel at `dim`
8 fn kernel_size(&self, dim: u32) -> u32;
9 /// The dilation of the kernel at `dim`
10 fn dilation(&self, dim: u32) -> u32;
11 /// The stride of the kernel at `dim`
12 fn stride(&self, dim: u32) -> u32;
13 /// The padding of the kernel at `dim`
14 fn padding(&self, dim: u32) -> i32;
15 /// The dimensionality of the kernel
16 fn dimensionality(&self) -> Dimensionality;
17
18 fn line_sizes(&self) -> MatmulLineSizes;
19}