cubecl_convolution/components/
problem.rs

1use cubecl_matmul::components::{MatmulProblem, MatrixLayout};
2
3#[derive(Clone, Debug)]
4/// Description of a matmul problem to solve, regardless of actual data
5pub struct ConvolutionProblem {
6    pub m: usize,
7    pub n: usize,
8    pub k: usize,
9
10    pub lhs_strides: Vec<usize>,
11    pub rhs_strides: Vec<usize>,
12
13    pub lhs_layout: MatrixLayout,
14    pub rhs_layout: MatrixLayout,
15
16    pub kernel_size: Vec<u32>,
17    pub stride: Vec<u32>,
18    pub padding: Vec<i32>,
19    pub dilation: Vec<u32>,
20
21    pub batches: usize,
22    pub channels: usize,
23    pub shape: Vec<usize>,
24    pub out_shape: Vec<usize>,
25
26    pub dimensionality: Dimensionality,
27}
28
29impl ConvolutionProblem {
30    pub fn as_matmul_problem(&self) -> MatmulProblem {
31        // Strides are expected to be in row major (m, n) format so for matmul checks we need to
32        // convert them to that format, with all other dims treated as batch dims so they're still
33        // checked.
34        // lhs already has the right format, but rhs needs special handling.
35        let rank = self.lhs_strides.len();
36        // (h, w, c, n)
37        let mut rhs_strides = self.rhs_strides[1..rank].to_vec();
38        rhs_strides.push(self.rhs_strides[0]);
39
40        MatmulProblem {
41            m: self.m,
42            n: self.n,
43            k: self.k,
44            lhs_batches: vec![],
45            rhs_batches: vec![],
46            out_batches: vec![],
47            lhs_strides: self.lhs_strides.clone(),
48            rhs_strides,
49            lhs_layout: self.lhs_layout,
50            rhs_layout: self.rhs_layout,
51        }
52    }
53}
54
/// Number of spatial dimensions an operation works over.
///
/// Variants are declared in increasing order, so the derived ordering ranks
/// `Dim1 < Dim2 < Dim3`.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub enum Dimensionality {
    /// One spatial dimension.
    Dim1,
    /// Two spatial dimensions.
    Dim2,
    /// Three spatial dimensions.
    Dim3,
}

impl Dimensionality {
    /// Returns how many spatial dimensions this variant stands for (1, 2 or 3).
    pub fn num_dims(&self) -> u32 {
        match *self {
            Self::Dim1 => 1,
            Self::Dim2 => 2,
            Self::Dim3 => 3,
        }
    }
}