cubecl_matmul/components/
problem.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use serde::{Deserialize, Serialize};
4
5use super::{Ident, MatmulProblemSize};
6
7#[derive(Clone, Debug)]
8pub struct MatmulProblem {
10 pub m: usize,
12 pub n: usize,
14 pub k: usize,
16 pub lhs_batches: Vec<usize>,
18 pub rhs_batches: Vec<usize>,
20 pub lhs_layout: MatrixLayout,
22 pub rhs_layout: MatrixLayout,
24}
25
26impl MatmulProblem {
27 fn output_batch_dims(&self) -> Vec<usize> {
29 self.lhs_batches
30 .iter()
31 .rev()
32 .zip(self.rhs_batches.iter().rev())
33 .map(|(&dim_lhs, &dim_rhs)| std::cmp::max(dim_lhs, dim_rhs))
34 .collect()
35 }
36
37 pub(crate) fn num_batches(&self) -> usize {
39 self.output_batch_dims().iter().product()
40 }
41
42 #[allow(unused)]
44 pub(crate) fn shape(&self, ident: Ident) -> Vec<usize> {
45 match ident {
46 Ident::Lhs => self
47 .lhs_batches
48 .iter()
49 .cloned()
50 .chain(vec![self.m, self.k])
51 .collect(),
52 Ident::Rhs => self
53 .rhs_batches
54 .iter()
55 .cloned()
56 .chain(vec![self.k, self.n])
57 .collect(),
58 Ident::Out => self
59 .output_batch_dims()
60 .iter()
61 .cloned()
62 .chain(vec![self.m, self.n])
63 .collect(),
64 }
65 }
66}
67
68#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
69pub enum MatmulKind {
71 General,
73
74 MatVec,
76
77 VecMat,
79
80 ScalarVec,
82
83 VecScalar,
85
86 InnerProduct,
88
89 OuterProduct,
91
92 ScalarProduct,
94}
95
96impl From<MatmulProblemSize> for MatmulKind {
97 fn from(matmul_size: MatmulProblemSize) -> Self {
98 enum DimKind {
99 Scalar,
100 Vector,
101 }
102
103 impl From<u32> for DimKind {
104 fn from(x: u32) -> Self {
105 match x {
106 1 => DimKind::Scalar,
107 _ => DimKind::Vector,
108 }
109 }
110 }
111
112 use DimKind::*;
113 match (
114 matmul_size.m().into(),
115 matmul_size.n().into(),
116 matmul_size.k().into(),
117 ) {
118 (Scalar, Scalar, Scalar) => MatmulKind::ScalarProduct,
119 (Scalar, Scalar, Vector) => MatmulKind::InnerProduct,
120 (Scalar, Vector, Scalar) => MatmulKind::ScalarVec,
121 (Scalar, Vector, Vector) => MatmulKind::VecMat,
122 (Vector, Scalar, Scalar) => MatmulKind::VecScalar,
123 (Vector, Scalar, Vector) => MatmulKind::MatVec,
124 (Vector, Vector, Scalar) => MatmulKind::OuterProduct,
125 (Vector, Vector, Vector) => MatmulKind::General,
126 }
127 }
128}
129
130impl From<MatmulProblem> for MatmulProblemSize {
131 fn from(problem: MatmulProblem) -> Self {
132 MatmulProblemSize::new(problem.m as u32, problem.n as u32, problem.k as u32)
133 }
134}
135
136impl From<MatmulProblem> for MatmulKind {
137 fn from(problem: MatmulProblem) -> Self {
138 MatmulProblemSize::new(problem.m as u32, problem.n as u32, problem.k as u32).into()
139 }
140}
141
142impl From<&MatmulProblem> for MatmulKind {
143 fn from(problem: &MatmulProblem) -> Self {
144 MatmulProblemSize::new(problem.m as u32, problem.n as u32, problem.k as u32).into()
145 }
146}
147
148#[derive(CubeType, Copy, Clone, PartialEq, Eq, Hash, Debug)]
149pub enum MatrixLayout {
152 RowMajor,
153 ColMajor,
154}
155
156#[cube]
157pub fn as_cmma_layout(#[comptime] layout: MatrixLayout) -> cmma::MatrixLayout {
159 match layout {
160 MatrixLayout::RowMajor => cmma::MatrixLayout::RowMajor,
161 MatrixLayout::ColMajor => cmma::MatrixLayout::ColMajor,
162 }
163}