cubek_convolution/components/
problem.rs1use cubek_matmul::definition::{MatmulGlobalElems, MatmulProblem, MatrixLayout};
2
/// Kind of convolution operation a [`ConvolutionProblem`] describes.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` compare
/// variants by their declaration order, so do not reorder them casually.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ConvolutionOperation {
    Forward,
    BackwardData,
    BackwardWeight,
    ForwardTransposed,
}
10
11#[derive(Clone, Debug)]
12pub struct ConvolutionProblem {
14 pub m: usize,
15 pub n: usize,
16 pub k: usize,
17
18 pub lhs_strides: Vec<usize>,
19 pub rhs_strides: Vec<usize>,
20
21 pub lhs_layout: MatrixLayout,
22 pub rhs_layout: MatrixLayout,
23
24 pub kernel_size: Vec<u32>,
25 pub stride: Vec<u32>,
26 pub padding: Vec<i32>,
27 pub dilation: Vec<u32>,
28
29 pub batches: usize,
30 pub channels: usize,
31 pub out_channels: usize,
32 pub in_shape: Vec<usize>,
33 pub out_shape: Vec<usize>,
34
35 pub padded_channels: usize,
37 pub operation: ConvolutionOperation,
38
39 pub dimensionality: Dimensionality,
40
41 pub global_dtypes: MatmulGlobalElems,
42}
43
44impl ConvolutionProblem {
45 pub fn as_matmul_problem(&self) -> MatmulProblem {
46 let rank = self.lhs_strides.len();
47
48 let lhs_strides = match self.lhs_layout {
54 MatrixLayout::RowMajor => self.lhs_strides.clone(),
55 MatrixLayout::ColMajor => {
56 let mut lhs_strides = self.lhs_strides[1..rank].to_vec();
57 lhs_strides.push(self.lhs_strides[0]);
58 lhs_strides
59 }
60 };
61 let rhs_strides = match self.rhs_layout {
62 MatrixLayout::RowMajor => self.rhs_strides.clone(),
63 MatrixLayout::ColMajor => {
64 let mut rhs_strides = self.rhs_strides[1..rank].to_vec();
65 rhs_strides.push(self.rhs_strides[0]);
66 rhs_strides
67 }
68 };
69
70 MatmulProblem {
71 m: self.m,
72 n: self.n,
73 k: self.k,
74 lhs_batches: vec![],
75 rhs_batches: vec![],
76 out_batches: vec![],
77 lhs_strides,
78 rhs_strides,
79 lhs_layout: self.lhs_layout,
80 rhs_layout: self.rhs_layout,
81 lhs_shape: vec![self.m, self.k],
82 rhs_shape: vec![self.k, self.n],
83 out_shape: vec![self.m, self.n],
84 out_strides: MatrixLayout::RowMajor.to_strides(&[self.m, self.n]),
85 out_layout: MatrixLayout::RowMajor,
86 lhs_scheme: None,
87 rhs_scheme: None,
88 global_dtypes: self.global_dtypes.clone(),
89 }
90 }
91
92 pub fn should_check_channel(&self) -> bool {
93 self.channels != self.padded_channels
94 }
95
96 pub fn should_check_spatial_bounds(&self) -> bool {
97 self.padding.iter().any(|&pad| pad != 0)
98 }
99}
100
/// Dimensionality of a convolution: one, two or three dimensions.
///
/// The derived ordering follows declaration order (`Dim1 < Dim2 < Dim3`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Dimensionality {
    Dim1,
    Dim2,
    Dim3,
}

impl Dimensionality {
    /// Returns the number of dimensions encoded by the variant (`1`, `2` or `3`).
    pub fn num_dims(&self) -> usize {
        match *self {
            Self::Dim1 => 1,
            Self::Dim2 => 2,
            Self::Dim3 => 3,
        }
    }
}