cubecl_matmul/components/
problem.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use serde::{Deserialize, Serialize};
4
5use super::{Ident, MatmulProblemSize};
6
7#[derive(Clone, Debug)]
8/// Description of a matmul problem to solve, regardless of actual data
9pub struct MatmulProblem {
10    /// Number of rows in the output matrix
11    pub m: usize,
12    /// Number of columns in the output matrix
13    pub n: usize,
14    /// Reduction dimension
15    pub k: usize,
16    /// Batch shape for Lhs tensor
17    pub lhs_batches: Vec<usize>,
18    /// Batch shape for Rhs tensor
19    pub rhs_batches: Vec<usize>,
20    /// Memory layout of the Lhs matrix.
21    pub lhs_layout: MatrixLayout,
22    /// Memory layout of the Rhs matrix.
23    pub rhs_layout: MatrixLayout,
24}
25
26impl MatmulProblem {
27    /// Returns the batch dimensions of the output
28    fn output_batch_dims(&self) -> Vec<usize> {
29        self.lhs_batches
30            .iter()
31            .rev()
32            .zip(self.rhs_batches.iter().rev())
33            .map(|(&dim_lhs, &dim_rhs)| std::cmp::max(dim_lhs, dim_rhs))
34            .collect()
35    }
36
37    /// Returns the total number of batches of the output
38    pub(crate) fn num_batches(&self) -> usize {
39        self.output_batch_dims().iter().product()
40    }
41
42    /// Returns the shape of the identified tensor, inferred by the problem definition
43    #[allow(unused)]
44    pub(crate) fn shape(&self, ident: Ident) -> Vec<usize> {
45        match ident {
46            Ident::Lhs => self
47                .lhs_batches
48                .iter()
49                .cloned()
50                .chain(vec![self.m, self.k])
51                .collect(),
52            Ident::Rhs => self
53                .rhs_batches
54                .iter()
55                .cloned()
56                .chain(vec![self.k, self.n])
57                .collect(),
58            Ident::Out => self
59                .output_batch_dims()
60                .iter()
61                .cloned()
62                .chain(vec![self.m, self.n])
63                .collect(),
64        }
65    }
66}
67
68#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
69/// Interpretation of matrix multiplication based on input shapes.
70pub enum MatmulKind {
71    /// (M, K) @ (K, N) → (M, N), with M, K, N > 1
72    General,
73
74    /// (M, K) @ (K, 1) → (M, 1)
75    MatVec,
76
77    /// (1, K) @ (K, N) → (1, N)
78    VecMat,
79
80    /// (1, 1) @ (1, N) → (1, N)
81    ScalarVec,
82
83    /// (M, 1) @ (1, 1) → (M, 1)
84    VecScalar,
85
86    /// (1, K) @ (K, 1) → (1, 1)
87    InnerProduct,
88
89    /// (M, 1) @ (1, N) → (M, N)
90    OuterProduct,
91
92    /// (1, 1) @ (1, 1) → (1, 1)
93    ScalarProduct,
94}
95
96impl From<MatmulProblemSize> for MatmulKind {
97    fn from(matmul_size: MatmulProblemSize) -> Self {
98        enum DimKind {
99            Scalar,
100            Vector,
101        }
102
103        impl From<u32> for DimKind {
104            fn from(x: u32) -> Self {
105                match x {
106                    1 => DimKind::Scalar,
107                    _ => DimKind::Vector,
108                }
109            }
110        }
111
112        use DimKind::*;
113        match (
114            matmul_size.m().into(),
115            matmul_size.n().into(),
116            matmul_size.k().into(),
117        ) {
118            (Scalar, Scalar, Scalar) => MatmulKind::ScalarProduct,
119            (Scalar, Scalar, Vector) => MatmulKind::InnerProduct,
120            (Scalar, Vector, Scalar) => MatmulKind::ScalarVec,
121            (Scalar, Vector, Vector) => MatmulKind::VecMat,
122            (Vector, Scalar, Scalar) => MatmulKind::VecScalar,
123            (Vector, Scalar, Vector) => MatmulKind::MatVec,
124            (Vector, Vector, Scalar) => MatmulKind::OuterProduct,
125            (Vector, Vector, Vector) => MatmulKind::General,
126        }
127    }
128}
129
130impl From<MatmulProblem> for MatmulProblemSize {
131    fn from(problem: MatmulProblem) -> Self {
132        MatmulProblemSize::new(problem.m as u32, problem.n as u32, problem.k as u32)
133    }
134}
135
136impl From<MatmulProblem> for MatmulKind {
137    fn from(problem: MatmulProblem) -> Self {
138        MatmulProblemSize::new(problem.m as u32, problem.n as u32, problem.k as u32).into()
139    }
140}
141
142impl From<&MatmulProblem> for MatmulKind {
143    fn from(problem: &MatmulProblem) -> Self {
144        MatmulProblemSize::new(problem.m as u32, problem.n as u32, problem.k as u32).into()
145    }
146}
147
148#[derive(CubeType, Copy, Clone, PartialEq, Eq, Hash, Debug)]
149/// Layout of a 2D structure such as a tensor, shared memory or slice,
150/// used within any matmul kernel level
151pub enum MatrixLayout {
152    RowMajor,
153    ColMajor,
154}
155
#[cube]
/// Maps the matmul `MatrixLayout` to cmma's `MatrixLayout`, for use in the Cmma API.
///
/// `layout` is marked `#[comptime]`. Each variant maps to the
/// `cmma::MatrixLayout` variant of the same name.
pub fn as_cmma_layout(#[comptime] layout: MatrixLayout) -> cmma::MatrixLayout {
    match layout {
        MatrixLayout::RowMajor => cmma::MatrixLayout::RowMajor,
        MatrixLayout::ColMajor => cmma::MatrixLayout::ColMajor,
    }
}