// cubecl_convolution/components/problem.rs

1use cubecl_matmul::components::{MatmulProblem, MatrixLayout};
2
3#[derive(Clone, Debug)]
4/// Description of a matmul problem to solve, regardless of actual data
5pub struct ConvolutionProblem {
6    pub m: usize,
7    pub n: usize,
8    pub k: usize,
9    pub lhs_layout: MatrixLayout,
10    pub rhs_layout: MatrixLayout,
11
12    pub kernel_size: Vec<u32>,
13    pub stride: Vec<u32>,
14    pub padding: Vec<i32>,
15    pub dilation: Vec<u32>,
16
17    pub batches: usize,
18    pub channels: usize,
19    pub shape: Vec<usize>,
20    pub out_shape: Vec<usize>,
21
22    pub dimensionality: Dimensionality,
23}
24
25impl ConvolutionProblem {
26    pub fn as_matmul_problem(&self) -> MatmulProblem {
27        MatmulProblem {
28            m: self.m,
29            n: self.n,
30            k: self.k,
31            lhs_batches: vec![],
32            rhs_batches: vec![],
33            out_batches: vec![],
34            lhs_layout: self.lhs_layout,
35            rhs_layout: self.rhs_layout,
36        }
37    }
38}
39
/// Number of spatial dimensions an operation works over.
///
/// Variants are declared in increasing order, so the derived `Ord` ranks
/// `Dim1 < Dim2 < Dim3`.
#[derive(Debug, Hash, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
pub enum Dimensionality {
    /// One spatial dimension.
    Dim1,
    /// Two spatial dimensions.
    Dim2,
    /// Three spatial dimensions.
    Dim3,
}

impl Dimensionality {
    /// Returns how many spatial dimensions this variant stands for: 1, 2 or 3.
    pub fn num_dims(&self) -> u32 {
        match self {
            Self::Dim1 => 1,
            Self::Dim2 => 2,
            Self::Dim3 => 3,
        }
    }
}