// cubek_convolution/components/problem.rs

1use cubek_matmul::definition::{MatmulGlobalElems, MatmulProblem, MatrixLayout};
2
/// Which convolution pass a problem describes.
#[derive(Clone, Debug, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ConvolutionOperation {
    /// Standard forward convolution.
    Forward,
    /// Backward pass with respect to the data (input) operand.
    BackwardData,
    /// Backward pass with respect to the weight (filter) operand.
    BackwardWeight,
    /// Forward pass of a transposed convolution — presumably
    /// deconvolution semantics; confirm against kernel implementations.
    ForwardTransposed,
}
10
#[derive(Clone, Debug)]
/// Description of a convolution problem to solve, regardless of actual data.
///
/// Carries both the convolution-specific parameters and the dimensions of
/// the implicit-GEMM view (`m`, `n`, `k`) used by the matmul backend.
pub struct ConvolutionProblem {
    // Dimensions of the implicit matmul this convolution lowers to
    // (see `as_matmul_problem`).
    pub m: usize,
    pub n: usize,
    pub k: usize,

    // Element strides of the lhs/rhs operands; interpreted together with
    // the layouts below when building the matmul view.
    pub lhs_strides: Vec<usize>,
    pub rhs_strides: Vec<usize>,

    pub lhs_layout: MatrixLayout,
    pub rhs_layout: MatrixLayout,

    // Per-spatial-dimension convolution parameters; each Vec is expected
    // to have `dimensionality.num_dims()` entries — TODO confirm with callers.
    pub kernel_size: Vec<u32>,
    pub stride: Vec<u32>,
    // Signed, unlike the others — presumably to allow negative padding
    // offsets (e.g. for backward passes); confirm.
    pub padding: Vec<i32>,
    pub dilation: Vec<u32>,

    // Tensor shape information: batch count, channel counts, and the
    // spatial extents of the input and output.
    pub batches: usize,
    pub channels: usize,
    pub out_channels: usize,
    pub in_shape: Vec<usize>,
    pub out_shape: Vec<usize>,

    /// Channels after applying loader-specific padding
    pub padded_channels: usize,
    /// Which convolution pass (forward/backward-data/...) this describes.
    pub operation: ConvolutionOperation,

    /// Number of spatial dimensions (1D/2D/3D).
    pub dimensionality: Dimensionality,

    /// Global element dtypes forwarded to the matmul problem.
    pub global_dtypes: MatmulGlobalElems,
}
43
44impl ConvolutionProblem {
45    pub fn as_matmul_problem(&self) -> MatmulProblem {
46        let rank = self.lhs_strides.len();
47
48        // Strides are expected to be in row major (m, n) format so for matmul checks we need to
49        // convert them to that format, with all other dims treated as batch dims so they're still
50        // checked.
51        // lhs already has the right format, but rhs needs special handling.
52        // (h, w, c, n)
53        let lhs_strides = match self.lhs_layout {
54            MatrixLayout::RowMajor => self.lhs_strides.clone(),
55            MatrixLayout::ColMajor => {
56                let mut lhs_strides = self.lhs_strides[1..rank].to_vec();
57                lhs_strides.push(self.lhs_strides[0]);
58                lhs_strides
59            }
60        };
61        let rhs_strides = match self.rhs_layout {
62            MatrixLayout::RowMajor => self.rhs_strides.clone(),
63            MatrixLayout::ColMajor => {
64                let mut rhs_strides = self.rhs_strides[1..rank].to_vec();
65                rhs_strides.push(self.rhs_strides[0]);
66                rhs_strides
67            }
68        };
69
70        MatmulProblem {
71            m: self.m,
72            n: self.n,
73            k: self.k,
74            lhs_batches: vec![],
75            rhs_batches: vec![],
76            out_batches: vec![],
77            lhs_strides,
78            rhs_strides,
79            lhs_layout: self.lhs_layout,
80            rhs_layout: self.rhs_layout,
81            lhs_shape: vec![self.m, self.k],
82            rhs_shape: vec![self.k, self.n],
83            out_shape: vec![self.m, self.n],
84            out_strides: MatrixLayout::RowMajor.to_strides(&[self.m, self.n]),
85            out_layout: MatrixLayout::RowMajor,
86            lhs_scheme: None,
87            rhs_scheme: None,
88            global_dtypes: self.global_dtypes.clone(),
89        }
90    }
91
92    pub fn should_check_channel(&self) -> bool {
93        self.channels != self.padded_channels
94    }
95
96    pub fn should_check_spatial_bounds(&self) -> bool {
97        self.padding.iter().any(|&pad| pad != 0)
98    }
99}
100
/// Spatial dimensionality of an operation (1D, 2D or 3D).
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub enum Dimensionality {
    /// One spatial dimension.
    Dim1,
    /// Two spatial dimensions.
    Dim2,
    /// Three spatial dimensions.
    Dim3,
}

impl Dimensionality {
    /// The number of spatial dimensions this variant stands for.
    pub fn num_dims(&self) -> usize {
        match self {
            Self::Dim1 => 1,
            Self::Dim2 => 2,
            Self::Dim3 => 3,
        }
    }
}
117}