cubek_convolution/components/
problem.rs1use cubecl::{
2 ir::AddressType,
3 zspace::{Shape, Strides, shape},
4};
5use cubek_matmul::definition::{MatmulGlobalElems, MatmulProblem};
6use cubek_std::MatrixLayout;
7
/// Kind of convolution pass that a `ConvolutionProblem` describes
/// (stored in its `operation` field).
///
/// Variant declaration order is significant: the derived `PartialOrd`/`Ord`
/// compare variants in the order listed below.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ConvolutionOperation {
    Forward,
    BackwardData,
    BackwardWeight,
    ForwardTransposed,
}
15
16#[derive(Clone, Debug)]
17pub struct ConvolutionProblem {
19 pub m: usize,
20 pub n: usize,
21 pub k: usize,
22
23 pub lhs_strides: Strides,
24 pub rhs_strides: Strides,
25
26 pub lhs_layout: MatrixLayout,
27 pub rhs_layout: MatrixLayout,
28
29 pub kernel_size: Vec<u32>,
30 pub stride: Vec<u32>,
31 pub padding: Vec<i32>,
32 pub dilation: Vec<u32>,
33
34 pub batches: usize,
35 pub channels: usize,
36 pub out_channels: usize,
37 pub in_shape: Shape,
38 pub out_shape: Shape,
39
40 pub padded_channels: usize,
42 pub operation: ConvolutionOperation,
43
44 pub dimensionality: Dimensionality,
45
46 pub global_dtypes: MatmulGlobalElems,
47 pub address_type: AddressType,
49}
50
51impl ConvolutionProblem {
52 pub fn as_matmul_problem(&self) -> MatmulProblem {
53 let rank = self.lhs_strides.len();
54
55 let lhs_strides = match self.lhs_layout {
61 MatrixLayout::RowMajor => self.lhs_strides.clone(),
62 MatrixLayout::ColMajor => {
63 let mut lhs_strides: Strides = self.lhs_strides[1..rank].into();
64 lhs_strides.push(self.lhs_strides[0]);
65 lhs_strides
66 }
67 };
68 let rhs_strides = match self.rhs_layout {
69 MatrixLayout::RowMajor => self.rhs_strides.clone(),
70 MatrixLayout::ColMajor => {
71 let mut rhs_strides: Strides = self.rhs_strides[1..rank].into();
72 rhs_strides.push(self.rhs_strides[0]);
73 rhs_strides
74 }
75 };
76
77 MatmulProblem {
78 m: self.m,
79 n: self.n,
80 k: self.k,
81 lhs_batches: shape![],
82 rhs_batches: shape![],
83 out_batches: shape![],
84 lhs_strides,
85 rhs_strides,
86 lhs_layout: self.lhs_layout,
87 rhs_layout: self.rhs_layout,
88 lhs_shape: shape![self.m, self.k],
89 rhs_shape: shape![self.k, self.n],
90 out_shape: shape![self.m, self.n],
91 out_strides: MatrixLayout::RowMajor.to_strides(&[self.m, self.n]),
92 out_layout: MatrixLayout::RowMajor,
93 lhs_scheme: None,
94 rhs_scheme: None,
95 global_dtypes: self.global_dtypes.clone(),
96 address_type: self.address_type,
97 }
98 }
99
100 pub fn should_check_channel(&self) -> bool {
101 self.channels != self.padded_channels
102 }
103
104 pub fn should_check_spatial_bounds(&self) -> bool {
105 self.padding.iter().any(|&pad| pad != 0)
106 }
107}
108
/// Number of dimensions a convolution operates over (stored in
/// `ConvolutionProblem::dimensionality`).
///
/// Variant declaration order is significant for the derived `PartialOrd`/`Ord`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Dimensionality {
    Dim1,
    Dim2,
    Dim3,
}

impl Dimensionality {
    /// Returns the dimension count for this variant: 1, 2 or 3.
    pub fn num_dims(&self) -> usize {
        match *self {
            Self::Dim1 => 1,
            Self::Dim2 => 2,
            Self::Dim3 => 3,
        }
    }
}