cubecl_convolution/components/
problem.rs1use cubecl_matmul::components::{MatmulProblem, MatrixLayout};
2
3#[derive(Clone, Debug)]
4pub struct ConvolutionProblem {
6 pub m: usize,
7 pub n: usize,
8 pub k: usize,
9 pub lhs_layout: MatrixLayout,
10 pub rhs_layout: MatrixLayout,
11
12 pub kernel_size: Vec<u32>,
13 pub stride: Vec<u32>,
14 pub padding: Vec<i32>,
15 pub dilation: Vec<u32>,
16
17 pub batches: usize,
18 pub channels: usize,
19 pub shape: Vec<usize>,
20 pub out_shape: Vec<usize>,
21
22 pub dimensionality: Dimensionality,
23}
24
25impl ConvolutionProblem {
26 pub fn as_matmul_problem(&self) -> MatmulProblem {
27 MatmulProblem {
28 m: self.m,
29 n: self.n,
30 k: self.k,
31 lhs_batches: vec![],
32 rhs_batches: vec![],
33 out_batches: vec![],
34 lhs_layout: self.lhs_layout,
35 rhs_layout: self.rhs_layout,
36 }
37 }
38}
39
/// Number of dimensions of an operation (used by `ConvolutionProblem::dimensionality`).
///
/// Variants are declared in increasing order, so the derived ordering ranks
/// `Dim1 < Dim2 < Dim3`. Keep this order if variants are ever added.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum Dimensionality {
    /// One dimension.
    Dim1,
    /// Two dimensions.
    Dim2,
    /// Three dimensions.
    Dim3,
}
47
48impl Dimensionality {
49 pub fn num_dims(&self) -> u32 {
50 match self {
51 Dimensionality::Dim1 => 1,
52 Dimensionality::Dim2 => 2,
53 Dimensionality::Dim3 => 3,
54 }
55 }
56}