cubecl_convolution/components/
problem.rs1use cubecl_matmul::components::{MatmulProblem, MatrixLayout};
2
3#[derive(Clone, Debug)]
4pub struct ConvolutionProblem {
6 pub m: usize,
7 pub n: usize,
8 pub k: usize,
9
10 pub lhs_strides: Vec<usize>,
11 pub rhs_strides: Vec<usize>,
12
13 pub lhs_layout: MatrixLayout,
14 pub rhs_layout: MatrixLayout,
15
16 pub kernel_size: Vec<u32>,
17 pub stride: Vec<u32>,
18 pub padding: Vec<i32>,
19 pub dilation: Vec<u32>,
20
21 pub batches: usize,
22 pub channels: usize,
23 pub shape: Vec<usize>,
24 pub out_shape: Vec<usize>,
25
26 pub dimensionality: Dimensionality,
27}
28
29impl ConvolutionProblem {
30 pub fn as_matmul_problem(&self) -> MatmulProblem {
31 let rank = self.lhs_strides.len();
36 let mut rhs_strides = self.rhs_strides[1..rank].to_vec();
38 rhs_strides.push(self.rhs_strides[0]);
39
40 MatmulProblem {
41 m: self.m,
42 n: self.n,
43 k: self.k,
44 lhs_batches: vec![],
45 rhs_batches: vec![],
46 out_batches: vec![],
47 lhs_strides: self.lhs_strides.clone(),
48 rhs_strides,
49 lhs_layout: self.lhs_layout,
50 rhs_layout: self.rhs_layout,
51 }
52 }
53}
54
/// Dimensionality of a convolution: 1D, 2D or 3D.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` rank
/// `Dim1 < Dim2 < Dim3`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Dimensionality {
    Dim1,
    Dim2,
    Dim3,
}

impl Dimensionality {
    /// Returns the number of dimensions this variant stands for
    /// (`Dim1` → 1, `Dim2` → 2, `Dim3` → 3).
    pub fn num_dims(&self) -> u32 {
        match *self {
            Self::Dim1 => 1,
            Self::Dim2 => 2,
            Self::Dim3 => 3,
        }
    }
}