// cubecl_matmul/components/problem.rs
use cubecl_core as cubecl;
use cubecl_core::prelude::*;
use serde::{Deserialize, Serialize};

use super::{MatmulIdent, MatmulProblemSize};
6
/// Description of a (possibly batched) matrix multiplication `out = lhs · rhs`.
///
/// As built by `MatmulProblem::shape`, lhs has shape `[lhs_batches.., m, k]`,
/// rhs has shape `[rhs_batches.., k, n]` and the output `[batches.., m, n]`.
#[derive(Clone, Debug)]
pub struct MatmulProblem {
    /// Number of rows of lhs and of the output.
    pub m: usize,
    /// Number of columns of rhs and of the output.
    pub n: usize,
    /// Reduction dimension: columns of lhs and rows of rhs.
    pub k: usize,
    /// Leading (batch) dimensions of the lhs shape, in shape order.
    pub lhs_batches: Vec<usize>,
    /// Leading (batch) dimensions of the rhs shape, in shape order.
    pub rhs_batches: Vec<usize>,
    /// Batch dimensions of the output.
    // NOTE(review): `num_batches` and `shape` do not read this field; they derive the
    // output batches from `lhs_batches`/`rhs_batches`. Confirm callers keep them in sync.
    pub out_batches: Vec<usize>,
    /// Memory layout of the lhs matrix.
    pub lhs_layout: MatrixLayout,
    /// Memory layout of the rhs matrix.
    pub rhs_layout: MatrixLayout,
}
27
28impl MatmulProblem {
29 fn output_batch_dims(&self) -> Vec<usize> {
31 self.lhs_batches
32 .iter()
33 .rev()
34 .zip(self.rhs_batches.iter().rev())
35 .map(|(&dim_lhs, &dim_rhs)| std::cmp::max(dim_lhs, dim_rhs))
36 .collect()
37 }
38
39 pub(crate) fn num_batches(&self) -> usize {
41 self.output_batch_dims().iter().product()
42 }
43
44 #[allow(unused)]
46 pub(crate) fn shape(&self, ident: MatmulIdent) -> Vec<usize> {
47 match ident {
48 MatmulIdent::Lhs => self
49 .lhs_batches
50 .iter()
51 .cloned()
52 .chain(vec![self.m, self.k])
53 .collect(),
54 MatmulIdent::Rhs => self
55 .rhs_batches
56 .iter()
57 .cloned()
58 .chain(vec![self.k, self.n])
59 .collect(),
60 MatmulIdent::Out => self
61 .output_batch_dims()
62 .iter()
63 .cloned()
64 .chain(vec![self.m, self.n])
65 .collect(),
66 }
67 }
68}
69
/// Shape category of a matmul, determined by which of `m`, `n` and `k` equal 1.
///
/// The exact mapping is implemented by `From<MatmulProblemSize> for MatmulKind`.
/// A dimension equal to 1 counts as "scalar"; any other value (0 included) counts
/// as "vector".
// Variant order fixes the discriminants used by the derived `Hash` and by
// index-based serde formats, so do not reorder variants.
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum MatmulKind {
    /// `m != 1`, `n != 1`, `k != 1`.
    General,

    /// `m != 1`, `n == 1`, `k != 1`.
    MatVec,

    /// `m == 1`, `n != 1`, `k != 1`.
    VecMat,

    /// `m == 1`, `n != 1`, `k == 1`.
    ScalarVec,

    /// `m != 1`, `n == 1`, `k == 1`.
    VecScalar,

    /// `m == 1`, `n == 1`, `k != 1`.
    InnerProduct,

    /// `m != 1`, `n != 1`, `k == 1`.
    OuterProduct,

    /// `m == 1`, `n == 1`, `k == 1`.
    ScalarProduct,
}
97
98impl From<MatmulProblemSize> for MatmulKind {
99 fn from(matmul_size: MatmulProblemSize) -> Self {
100 enum DimKind {
101 Scalar,
102 Vector,
103 }
104
105 impl From<u32> for DimKind {
106 fn from(x: u32) -> Self {
107 match x {
108 1 => DimKind::Scalar,
109 _ => DimKind::Vector,
110 }
111 }
112 }
113
114 use DimKind::*;
115 match (
116 matmul_size.m().into(),
117 matmul_size.n().into(),
118 matmul_size.k().into(),
119 ) {
120 (Scalar, Scalar, Scalar) => MatmulKind::ScalarProduct,
121 (Scalar, Scalar, Vector) => MatmulKind::InnerProduct,
122 (Scalar, Vector, Scalar) => MatmulKind::ScalarVec,
123 (Scalar, Vector, Vector) => MatmulKind::VecMat,
124 (Vector, Scalar, Scalar) => MatmulKind::VecScalar,
125 (Vector, Scalar, Vector) => MatmulKind::MatVec,
126 (Vector, Vector, Scalar) => MatmulKind::OuterProduct,
127 (Vector, Vector, Vector) => MatmulKind::General,
128 }
129 }
130}
131
132impl From<MatmulProblem> for MatmulProblemSize {
133 fn from(problem: MatmulProblem) -> Self {
134 MatmulProblemSize::new(problem.m as u32, problem.n as u32, problem.k as u32)
135 }
136}
137
138impl From<MatmulProblem> for MatmulKind {
139 fn from(problem: MatmulProblem) -> Self {
140 MatmulProblemSize::new(problem.m as u32, problem.n as u32, problem.k as u32).into()
141 }
142}
143
144impl From<&MatmulProblem> for MatmulKind {
145 fn from(problem: &MatmulProblem) -> Self {
146 MatmulProblemSize::new(problem.m as u32, problem.n as u32, problem.k as u32).into()
147 }
148}
149
/// Memory layout of a matrix operand.
///
/// Can be converted to the CMMA layout type with `as_cmma_layout`.
// Variant order fixes the derived discriminants, so do not reorder variants.
#[derive(CubeType, Copy, Clone, PartialEq, Eq, Hash, Debug, Default)]
pub enum MatrixLayout {
    /// Row-major storage. This is the `Default` layout.
    #[default]
    RowMajor,
    /// Column-major storage.
    ColMajor,
}
158
/// Converts a [`MatrixLayout`] into the equivalent `cmma::MatrixLayout` variant.
///
/// `layout` is marked `#[comptime]`. Presumably this means the match is resolved
/// when the kernel is expanded rather than at kernel runtime (TODO confirm against
/// cubecl's `#[cube]` semantics).
#[cube]
pub fn as_cmma_layout(#[comptime] layout: MatrixLayout) -> cmma::MatrixLayout {
    match layout {
        MatrixLayout::RowMajor => cmma::MatrixLayout::RowMajor,
        MatrixLayout::ColMajor => cmma::MatrixLayout::ColMajor,
    }
}