cubecl_matmul/components/
problem.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use serde::{Deserialize, Serialize};
4
5use super::{MatmulIdent, MatmulProblemSize};
6
7#[derive(Clone, Debug)]
8/// Description of a matmul problem to solve, regardless of actual data
9pub struct MatmulProblem {
10    /// Number of rows in the output matrix
11    pub m: usize,
12    /// Number of columns in the output matrix
13    pub n: usize,
14    /// Reduction dimension
15    pub k: usize,
16
17    /// Batch shape for Lhs tensor
18    pub lhs_batches: Vec<usize>,
19    /// Batch shape for Rhs tensor
20    pub rhs_batches: Vec<usize>,
21    /// Batch shape for Out tensor
22    pub out_batches: Vec<usize>,
23
24    /// Strides for the Lhs tensor
25    pub lhs_strides: Vec<usize>,
26    /// Strides for the Rhs tensor
27    pub rhs_strides: Vec<usize>,
28
29    /// Memory layout of the Lhs matrix.
30    pub lhs_layout: MatrixLayout,
31    /// Memory layout of the Rhs matrix.
32    pub rhs_layout: MatrixLayout,
33}
34
35impl MatmulProblem {
36    /// Returns the batch dimensions of the output
37    fn output_batch_dims(&self) -> Vec<usize> {
38        self.lhs_batches
39            .iter()
40            .rev()
41            .zip(self.rhs_batches.iter().rev())
42            .map(|(&dim_lhs, &dim_rhs)| std::cmp::max(dim_lhs, dim_rhs))
43            .collect()
44    }
45
46    /// Returns the total number of batches of the output
47    pub(crate) fn num_batches(&self) -> usize {
48        self.output_batch_dims().iter().product()
49    }
50
51    /// Returns the shape of the identified tensor, inferred by the problem definition
52    #[allow(unused)]
53    pub(crate) fn shape(&self, ident: MatmulIdent) -> Vec<usize> {
54        match ident {
55            MatmulIdent::Lhs => self
56                .lhs_batches
57                .iter()
58                .cloned()
59                .chain(vec![self.m, self.k])
60                .collect(),
61            MatmulIdent::Rhs => self
62                .rhs_batches
63                .iter()
64                .cloned()
65                .chain(vec![self.k, self.n])
66                .collect(),
67            MatmulIdent::Out => self
68                .output_batch_dims()
69                .iter()
70                .cloned()
71                .chain(vec![self.m, self.n])
72                .collect(),
73        }
74    }
75}
76
77#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
78/// Interpretation of matrix multiplication based on input shapes.
79pub enum MatmulKind {
80    /// (M, K) @ (K, N) → (M, N), with M, K, N > 1
81    General,
82
83    /// (M, K) @ (K, 1) → (M, 1)
84    MatVec,
85
86    /// (1, K) @ (K, N) → (1, N)
87    VecMat,
88
89    /// (1, 1) @ (1, N) → (1, N)
90    ScalarVec,
91
92    /// (M, 1) @ (1, 1) → (M, 1)
93    VecScalar,
94
95    /// (1, K) @ (K, 1) → (1, 1)
96    InnerProduct,
97
98    /// (M, 1) @ (1, N) → (M, N)
99    OuterProduct,
100
101    /// (1, 1) @ (1, 1) → (1, 1)
102    ScalarProduct,
103}
104
105impl From<MatmulProblemSize> for MatmulKind {
106    fn from(matmul_size: MatmulProblemSize) -> Self {
107        enum DimKind {
108            Scalar,
109            Vector,
110        }
111
112        impl From<u32> for DimKind {
113            fn from(x: u32) -> Self {
114                match x {
115                    1 => DimKind::Scalar,
116                    _ => DimKind::Vector,
117                }
118            }
119        }
120
121        use DimKind::*;
122        match (
123            matmul_size.m().into(),
124            matmul_size.n().into(),
125            matmul_size.k().into(),
126        ) {
127            (Scalar, Scalar, Scalar) => MatmulKind::ScalarProduct,
128            (Scalar, Scalar, Vector) => MatmulKind::InnerProduct,
129            (Scalar, Vector, Scalar) => MatmulKind::ScalarVec,
130            (Scalar, Vector, Vector) => MatmulKind::VecMat,
131            (Vector, Scalar, Scalar) => MatmulKind::VecScalar,
132            (Vector, Scalar, Vector) => MatmulKind::MatVec,
133            (Vector, Vector, Scalar) => MatmulKind::OuterProduct,
134            (Vector, Vector, Vector) => MatmulKind::General,
135        }
136    }
137}
138
139impl From<MatmulProblem> for MatmulProblemSize {
140    fn from(problem: MatmulProblem) -> Self {
141        MatmulProblemSize::new(problem.m as u32, problem.n as u32, problem.k as u32)
142    }
143}
144
145impl From<MatmulProblem> for MatmulKind {
146    fn from(problem: MatmulProblem) -> Self {
147        MatmulProblemSize::new(problem.m as u32, problem.n as u32, problem.k as u32).into()
148    }
149}
150
151impl From<&MatmulProblem> for MatmulKind {
152    fn from(problem: &MatmulProblem) -> Self {
153        MatmulProblemSize::new(problem.m as u32, problem.n as u32, problem.k as u32).into()
154    }
155}
156
157#[derive(CubeType, Copy, Clone, PartialEq, Eq, Hash, Debug, Default)]
158/// Layout of a 2D structure such as a tensor, shared memory or slice,
159/// used within any matmul kernel level
160pub enum MatrixLayout {
161    #[default]
162    RowMajor,
163    ColMajor,
164}
165
166#[cube]
167/// Maps the matmul MatrixLayout to cmma's MatrixLayout, for use in Cmma API.
168pub fn as_cmma_layout(#[comptime] layout: MatrixLayout) -> cmma::MatrixLayout {
169    match layout {
170        MatrixLayout::RowMajor => cmma::MatrixLayout::RowMajor,
171        MatrixLayout::ColMajor => cmma::MatrixLayout::ColMajor,
172    }
173}
174
175#[cube]
176/// Maps the cmma's MatrixLayout to matmul MatrixLayout.
177pub fn from_cmma_layout(#[comptime] layout: cmma::MatrixLayout) -> comptime_type!(MatrixLayout) {
178    match layout {
179        cmma::MatrixLayout::RowMajor => MatrixLayout::RowMajor,
180        cmma::MatrixLayout::ColMajor => MatrixLayout::ColMajor,
181        cmma::MatrixLayout::Undefined => MatrixLayout::RowMajor,
182    }
183}