cubecl_convolution/
config.rs

1use cubecl_matmul::components::{MatmulLineSizes, global::GlobalConfig};
2
3use super::base::Dimensionality;
4
5/// Convolution specific config, extends regular matmul [`Config`](global::Config)
6pub trait ConvGemmConfig: GlobalConfig {
7    /// The size of the convolution kernel at `dim`
8    fn kernel_size(&self, dim: u32) -> u32;
9    /// The dilation of the kernel at `dim`
10    fn dilation(&self, dim: u32) -> u32;
11    /// The stride of the kernel at `dim`
12    fn stride(&self, dim: u32) -> u32;
13    /// The padding of the kernel at `dim`
14    fn padding(&self, dim: u32) -> i32;
15    /// The dimensionality of the kernel
16    fn dimensionality(&self) -> Dimensionality;
17
18    fn line_sizes(&self) -> MatmulLineSizes;
19}