cubecl_matmul/components/
problem.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use serde::{Deserialize, Serialize};
4
5use super::{MatmulIdent, MatmulProblemSize};
6
/// Description of a (possibly batched) matrix multiplication problem.
///
/// The left-hand side (lhs) is an `m x k` matrix and the right-hand side (rhs) a `k x n`
/// matrix, each preceded by its own leading batch dimensions (see `MatmulProblem::shape`).
#[derive(Clone, Debug)]
pub struct MatmulProblem {
    /// Number of rows of the lhs and of the output.
    pub m: usize,
    /// Number of columns of the rhs and of the output.
    pub n: usize,
    /// Inner dimension shared by the lhs (columns) and the rhs (rows).
    pub k: usize,

    /// Leading batch dimensions of the lhs; its full shape is `lhs_batches` followed by `[m, k]`.
    pub lhs_batches: Vec<usize>,
    /// Leading batch dimensions of the rhs; its full shape is `rhs_batches` followed by `[k, n]`.
    pub rhs_batches: Vec<usize>,
    /// Batch dimensions of the output.
    // NOTE(review): `num_batches` and `shape(Out)` derive the output batches from
    // `lhs_batches`/`rhs_batches` instead of reading this field — confirm how they relate.
    pub out_batches: Vec<usize>,

    /// Strides of the lhs tensor.
    // NOTE(review): presumably one entry per lhs axis (batch axes, then m, k) — TODO confirm.
    pub lhs_strides: Vec<usize>,
    /// Strides of the rhs tensor.
    // NOTE(review): presumably one entry per rhs axis (batch axes, then k, n) — TODO confirm.
    pub rhs_strides: Vec<usize>,

    /// Memory layout (row- or column-major) of the lhs matrix.
    pub lhs_layout: MatrixLayout,
    /// Memory layout (row- or column-major) of the rhs matrix.
    pub rhs_layout: MatrixLayout,
}
34
35impl MatmulProblem {
36 fn output_batch_dims(&self) -> Vec<usize> {
38 self.lhs_batches
39 .iter()
40 .rev()
41 .zip(self.rhs_batches.iter().rev())
42 .map(|(&dim_lhs, &dim_rhs)| std::cmp::max(dim_lhs, dim_rhs))
43 .collect()
44 }
45
46 pub(crate) fn num_batches(&self) -> usize {
48 self.output_batch_dims().iter().product()
49 }
50
51 #[allow(unused)]
53 pub(crate) fn shape(&self, ident: MatmulIdent) -> Vec<usize> {
54 match ident {
55 MatmulIdent::Lhs => self
56 .lhs_batches
57 .iter()
58 .cloned()
59 .chain(vec![self.m, self.k])
60 .collect(),
61 MatmulIdent::Rhs => self
62 .rhs_batches
63 .iter()
64 .cloned()
65 .chain(vec![self.k, self.n])
66 .collect(),
67 MatmulIdent::Out => self
68 .output_batch_dims()
69 .iter()
70 .cloned()
71 .chain(vec![self.m, self.n])
72 .collect(),
73 }
74 }
75}
76
/// Shape category of a matmul problem.
///
/// The category depends only on which of `m`, `n`, `k` are exactly `1`. It is obtained
/// via `From<MatmulProblemSize>`. In the variant docs below, "is 1" means exactly `1`, and
/// any other value (including `0`) counts as "not 1".
// NOTE(review): serde traits are derived; reordering variants may change the serialized form
// in index-based formats, so append new variants rather than inserting them.
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum MatmulKind {
    /// None of `m`, `n`, `k` is 1.
    General,

    /// `n` is 1; `m` and `k` are not.
    MatVec,

    /// `m` is 1; `n` and `k` are not.
    VecMat,

    /// `m` and `k` are 1; `n` is not.
    ScalarVec,

    /// `n` and `k` are 1; `m` is not.
    VecScalar,

    /// `m` and `n` are 1; `k` is not.
    InnerProduct,

    /// `k` is 1; `m` and `n` are not.
    OuterProduct,

    /// `m`, `n` and `k` are all 1.
    ScalarProduct,
}
104
105impl From<MatmulProblemSize> for MatmulKind {
106 fn from(matmul_size: MatmulProblemSize) -> Self {
107 enum DimKind {
108 Scalar,
109 Vector,
110 }
111
112 impl From<u32> for DimKind {
113 fn from(x: u32) -> Self {
114 match x {
115 1 => DimKind::Scalar,
116 _ => DimKind::Vector,
117 }
118 }
119 }
120
121 use DimKind::*;
122 match (
123 matmul_size.m().into(),
124 matmul_size.n().into(),
125 matmul_size.k().into(),
126 ) {
127 (Scalar, Scalar, Scalar) => MatmulKind::ScalarProduct,
128 (Scalar, Scalar, Vector) => MatmulKind::InnerProduct,
129 (Scalar, Vector, Scalar) => MatmulKind::ScalarVec,
130 (Scalar, Vector, Vector) => MatmulKind::VecMat,
131 (Vector, Scalar, Scalar) => MatmulKind::VecScalar,
132 (Vector, Scalar, Vector) => MatmulKind::MatVec,
133 (Vector, Vector, Scalar) => MatmulKind::OuterProduct,
134 (Vector, Vector, Vector) => MatmulKind::General,
135 }
136 }
137}
138
139impl From<MatmulProblem> for MatmulProblemSize {
140 fn from(problem: MatmulProblem) -> Self {
141 MatmulProblemSize::new(problem.m as u32, problem.n as u32, problem.k as u32)
142 }
143}
144
145impl From<MatmulProblem> for MatmulKind {
146 fn from(problem: MatmulProblem) -> Self {
147 MatmulProblemSize::new(problem.m as u32, problem.n as u32, problem.k as u32).into()
148 }
149}
150
151impl From<&MatmulProblem> for MatmulKind {
152 fn from(problem: &MatmulProblem) -> Self {
153 MatmulProblemSize::new(problem.m as u32, problem.n as u32, problem.k as u32).into()
154 }
155}
156
/// In-memory ordering of a matrix operand's elements.
///
/// Converted to and from `cmma::MatrixLayout` by `as_cmma_layout` and `from_cmma_layout`.
#[derive(CubeType, Copy, Clone, PartialEq, Eq, Hash, Debug, Default)]
pub enum MatrixLayout {
    /// Row-major ordering; this is the `Default` variant.
    #[default]
    RowMajor,
    /// Column-major ordering.
    ColMajor,
}
165
/// Maps a [`MatrixLayout`] to the equivalent `cmma::MatrixLayout`.
///
/// `layout` is marked `#[comptime]`, i.e. it is a compile-time parameter of the cube function.
#[cube]
pub fn as_cmma_layout(#[comptime] layout: MatrixLayout) -> cmma::MatrixLayout {
    match layout {
        MatrixLayout::RowMajor => cmma::MatrixLayout::RowMajor,
        MatrixLayout::ColMajor => cmma::MatrixLayout::ColMajor,
    }
}
174
/// Maps a `cmma::MatrixLayout` back to a [`MatrixLayout`], returned as a comptime value.
///
/// `cmma::MatrixLayout::Undefined` has no counterpart and is mapped to `RowMajor`, which is
/// also `MatrixLayout`'s `Default` variant.
// NOTE(review): the `Undefined` fallback is silent; confirm callers never rely on detecting it.
#[cube]
pub fn from_cmma_layout(#[comptime] layout: cmma::MatrixLayout) -> comptime_type!(MatrixLayout) {
    match layout {
        cmma::MatrixLayout::RowMajor => MatrixLayout::RowMajor,
        cmma::MatrixLayout::ColMajor => MatrixLayout::ColMajor,
        cmma::MatrixLayout::Undefined => MatrixLayout::RowMajor,
    }
}