
cubek_convolution/components/problem.rs

use cubecl::{
    ir::AddressType,
    zspace::{Shape, Strides, shape},
};
use cubek_matmul::definition::{MatmulGlobalElems, MatmulProblem};
use cubek_std::MatrixLayout;

#[derive(Clone, Debug, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ConvolutionOperation {
    /// Standard forward convolution.
    Forward,
    /// Backward pass computing the gradient with respect to the input data.
    BackwardData,
    /// Backward pass computing the gradient with respect to the weights.
    BackwardWeight,
    /// Forward pass of a transposed convolution.
    ForwardTransposed,
}

#[derive(Clone, Debug)]
/// Description of a convolution problem to solve, regardless of actual data
pub struct ConvolutionProblem {
    pub m: usize,
    pub n: usize,
    pub k: usize,

    pub lhs_strides: Strides,
    pub rhs_strides: Strides,

    pub lhs_layout: MatrixLayout,
    pub rhs_layout: MatrixLayout,

    pub kernel_size: Vec<u32>,
    pub stride: Vec<u32>,
    pub padding: Vec<i32>,
    pub dilation: Vec<u32>,

    pub batches: usize,
    pub channels: usize,
    pub out_channels: usize,
    pub in_shape: Shape,
    pub out_shape: Shape,

    /// Channels after applying loader-specific padding
    pub padded_channels: usize,
    pub operation: ConvolutionOperation,

    pub dimensionality: Dimensionality,

    pub global_dtypes: MatmulGlobalElems,
    /// Address type, defined as the max of each handle's `required_address_type`
    pub address_type: AddressType,
}
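
// Illustrative sketch, not part of the original file: how the (m, n, k) dims
// are conventionally derived for `ConvolutionOperation::Forward` under an
// implicit-GEMM (im2col) formulation. The actual mapping used when building a
// `ConvolutionProblem` may differ; `forward_gemm_dims` and its parameters are
// hypothetical.
#[allow(dead_code)]
fn forward_gemm_dims(
    batches: usize,
    out_spatial: &[usize],
    out_channels: usize,
    channels: usize,
    kernel_size: &[u32],
) -> (usize, usize, usize) {
    let m = batches * out_spatial.iter().product::<usize>(); // one GEMM row per output position
    let n = out_channels; // one GEMM column per filter
    let k = channels * kernel_size.iter().product::<u32>() as usize; // reduction over the im2col window
    (m, n, k)
}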

impl ConvolutionProblem {
    pub fn as_matmul_problem(&self) -> MatmulProblem {
        let rank = self.lhs_strides.len();

        // The matmul checks expect strides in row-major (m, n) format, so we
        // need to convert ours to match, with all other dims treated as batch
        // dims so they're still checked.
        // lhs already has the right format, but rhs, laid out as
        // (h, w, c, n), needs special handling.
        let lhs_strides = match self.lhs_layout {
            MatrixLayout::RowMajor => self.lhs_strides.clone(),
            MatrixLayout::ColMajor => {
                let mut lhs_strides: Strides = self.lhs_strides[1..rank].into();
                lhs_strides.push(self.lhs_strides[0]);
                lhs_strides
            }
        };
        let rhs_strides = match self.rhs_layout {
            MatrixLayout::RowMajor => self.rhs_strides.clone(),
            MatrixLayout::ColMajor => {
                let mut rhs_strides: Strides = self.rhs_strides[1..rank].into();
                rhs_strides.push(self.rhs_strides[0]);
                rhs_strides
            }
        };
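        // e.g. a rank-3 col-major stride set [s0, s1, s2] is rotated above to
        // [s1, s2, s0]: the leading col-major stride becomes the innermost
        // row-major stride, and the remaining dims keep their relative order
        // for the batch-dim checks.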

        MatmulProblem {
            m: self.m,
            n: self.n,
            k: self.k,
            lhs_batches: shape![],
            rhs_batches: shape![],
            out_batches: shape![],
            lhs_strides,
            rhs_strides,
            lhs_layout: self.lhs_layout,
            rhs_layout: self.rhs_layout,
            lhs_shape: shape![self.m, self.k],
            rhs_shape: shape![self.k, self.n],
            out_shape: shape![self.m, self.n],
            out_strides: MatrixLayout::RowMajor.to_strides(&[self.m, self.n]),
            out_layout: MatrixLayout::RowMajor,
            lhs_scheme: None,
            rhs_scheme: None,
            global_dtypes: self.global_dtypes.clone(),
            address_type: self.address_type,
        }
    }
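    // Hypothetical usage (no caller appears in this file): downstream code
    // can reuse matmul shape checks and heuristics by converting first.
    // `problem` is assumed to be a fully populated `ConvolutionProblem`.
    //
    //     let matmul = problem.as_matmul_problem();
    //     assert_eq!(matmul.out_shape, shape![problem.m, problem.n]);
    //     assert_eq!(matmul.out_layout, MatrixLayout::RowMajor);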

    /// `true` when loader-specific padding changed the channel count, so
    /// reads along the channel dim must be bounds-checked.
    pub fn should_check_channel(&self) -> bool {
        self.channels != self.padded_channels
    }

    /// `true` when any spatial padding is nonzero, so input positions may
    /// fall outside the tensor and must be bounds-checked.
    pub fn should_check_spatial_bounds(&self) -> bool {
        self.padding.iter().any(|&pad| pad != 0)
    }
}

/// Spatial dimensionality of an operation
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub enum Dimensionality {
    Dim1,
    Dim2,
    Dim3,
}

impl Dimensionality {
    pub fn num_dims(&self) -> usize {
        match self {
            Dimensionality::Dim1 => 1,
            Dimensionality::Dim2 => 2,
            Dimensionality::Dim3 => 3,
        }
    }
}
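
// Usage note (an assumption, not stated in this file): the per-dimension
// parameters on `ConvolutionProblem` (`kernel_size`, `stride`, `padding`,
// `dilation`) presumably each hold `dimensionality.num_dims()` entries.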