cubecl_matmul/components/problem.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use serde::{Deserialize, Serialize};
4
5use super::{MatmulIdent, MatmulProblemSize};
6
7#[derive(Clone, Debug)]
8/// Description of a matmul problem to solve, regardless of actual data
9pub struct MatmulProblem {
10    /// Number of rows in the output matrix
11    pub m: usize,
12    /// Number of columns in the output matrix
13    pub n: usize,
14    /// Reduction dimension
15    pub k: usize,
16    /// Batch shape for Lhs tensor
17    pub lhs_batches: Vec<usize>,
18    /// Batch shape for Rhs tensor
19    pub rhs_batches: Vec<usize>,
20    /// Batch shape for Out tensor
21    pub out_batches: Vec<usize>,
22    /// Memory layout of the Lhs matrix.
23    pub lhs_layout: MatrixLayout,
24    /// Memory layout of the Rhs matrix.
25    pub rhs_layout: MatrixLayout,
26}
27
28impl MatmulProblem {
29    /// Returns the batch dimensions of the output
30    fn output_batch_dims(&self) -> Vec<usize> {
31        self.lhs_batches
32            .iter()
33            .rev()
34            .zip(self.rhs_batches.iter().rev())
35            .map(|(&dim_lhs, &dim_rhs)| std::cmp::max(dim_lhs, dim_rhs))
36            .collect()
37    }
38
39    /// Returns the total number of batches of the output
40    pub(crate) fn num_batches(&self) -> usize {
41        self.output_batch_dims().iter().product()
42    }
43
44    /// Returns the shape of the identified tensor, inferred by the problem definition
45    #[allow(unused)]
46    pub(crate) fn shape(&self, ident: MatmulIdent) -> Vec<usize> {
47        match ident {
48            MatmulIdent::Lhs => self
49                .lhs_batches
50                .iter()
51                .cloned()
52                .chain(vec![self.m, self.k])
53                .collect(),
54            MatmulIdent::Rhs => self
55                .rhs_batches
56                .iter()
57                .cloned()
58                .chain(vec![self.k, self.n])
59                .collect(),
60            MatmulIdent::Out => self
61                .output_batch_dims()
62                .iter()
63                .cloned()
64                .chain(vec![self.m, self.n])
65                .collect(),
66        }
67    }
68}
69
70#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
71/// Interpretation of matrix multiplication based on input shapes.
72pub enum MatmulKind {
73    /// (M, K) @ (K, N) → (M, N), with M, K, N > 1
74    General,
75
76    /// (M, K) @ (K, 1) → (M, 1)
77    MatVec,
78
79    /// (1, K) @ (K, N) → (1, N)
80    VecMat,
81
82    /// (1, 1) @ (1, N) → (1, N)
83    ScalarVec,
84
85    /// (M, 1) @ (1, 1) → (M, 1)
86    VecScalar,
87
88    /// (1, K) @ (K, 1) → (1, 1)
89    InnerProduct,
90
91    /// (M, 1) @ (1, N) → (M, N)
92    OuterProduct,
93
94    /// (1, 1) @ (1, 1) → (1, 1)
95    ScalarProduct,
96}
97
98impl From<MatmulProblemSize> for MatmulKind {
99    fn from(matmul_size: MatmulProblemSize) -> Self {
100        enum DimKind {
101            Scalar,
102            Vector,
103        }
104
105        impl From<u32> for DimKind {
106            fn from(x: u32) -> Self {
107                match x {
108                    1 => DimKind::Scalar,
109                    _ => DimKind::Vector,
110                }
111            }
112        }
113
114        use DimKind::*;
115        match (
116            matmul_size.m().into(),
117            matmul_size.n().into(),
118            matmul_size.k().into(),
119        ) {
120            (Scalar, Scalar, Scalar) => MatmulKind::ScalarProduct,
121            (Scalar, Scalar, Vector) => MatmulKind::InnerProduct,
122            (Scalar, Vector, Scalar) => MatmulKind::ScalarVec,
123            (Scalar, Vector, Vector) => MatmulKind::VecMat,
124            (Vector, Scalar, Scalar) => MatmulKind::VecScalar,
125            (Vector, Scalar, Vector) => MatmulKind::MatVec,
126            (Vector, Vector, Scalar) => MatmulKind::OuterProduct,
127            (Vector, Vector, Vector) => MatmulKind::General,
128        }
129    }
130}
131
132impl From<MatmulProblem> for MatmulProblemSize {
133    fn from(problem: MatmulProblem) -> Self {
134        MatmulProblemSize::new(problem.m as u32, problem.n as u32, problem.k as u32)
135    }
136}
137
138impl From<MatmulProblem> for MatmulKind {
139    fn from(problem: MatmulProblem) -> Self {
140        MatmulProblemSize::new(problem.m as u32, problem.n as u32, problem.k as u32).into()
141    }
142}
143
144impl From<&MatmulProblem> for MatmulKind {
145    fn from(problem: &MatmulProblem) -> Self {
146        MatmulProblemSize::new(problem.m as u32, problem.n as u32, problem.k as u32).into()
147    }
148}
149
150#[derive(CubeType, Copy, Clone, PartialEq, Eq, Hash, Debug, Default)]
151/// Layout of a 2D structure such as a tensor, shared memory or slice,
152/// used within any matmul kernel level
153pub enum MatrixLayout {
154    #[default]
155    RowMajor,
156    ColMajor,
157}
158
159#[cube]
160/// Maps the matmul MatrixLayout to cmma's MatrixLayout, for use in Cmma API.
161pub fn as_cmma_layout(#[comptime] layout: MatrixLayout) -> cmma::MatrixLayout {
162    match layout {
163        MatrixLayout::RowMajor => cmma::MatrixLayout::RowMajor,
164        MatrixLayout::ColMajor => cmma::MatrixLayout::ColMajor,
165    }
166}