1use alloc::{format, string::String, vec, vec::Vec};
2use derive_new::new;
3
4use super::Variable;
5use crate::{OperationCode, OperationReflect};
6use crate::{StorageType, TypeHash};
7use core::fmt::Display;
8
/// Identifies which operand of a matrix multiply-accumulate a [`Matrix`] describes.
///
/// The variant selects which dimensions [`Matrix::num_elems`] multiplies together:
/// `A` spans `m x k`, `B` spans `k x n`, and `Accumulator` spans `m x n`.
///
/// Variant order is significant: it defines the derived `PartialOrd`/`Ord`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum MatrixIdent {
    /// The `A` operand, shaped `m x k`.
    A,
    /// The `B` operand, shaped `k x n`.
    B,
    /// The accumulator operand, shaped `m x n`.
    Accumulator,
}
17
/// Element ordering of a matrix: column-major, row-major, or undefined.
///
/// NOTE(review): the meaning of `Undefined` is not established in this file; it
/// presumably marks a matrix whose layout is not fixed up front and must be provided
/// elsewhere (e.g. via the optional `layout` of `CoopMma::Load`) — confirm.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum MatrixLayout {
    /// Column-major ordering.
    ColMajor,
    /// Row-major ordering.
    RowMajor,
    /// No fixed layout (see the type-level note).
    Undefined,
}
26
/// Description of a cooperative matrix: which operand it is, the `m`/`n`/`k`
/// dimensions of the multiply, its element storage type and its layout.
///
/// Field order is significant: it is the argument order of the `new` constructor
/// generated by `derive_new`, and the comparison order of the derived `Ord`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(new, Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub struct Matrix {
    /// Which operand (`A`, `B` or accumulator) this matrix is.
    pub ident: MatrixIdent,
    /// `m` dimension; used by `A` (`m x k`) and the accumulator (`m x n`).
    pub m: u32,
    /// `n` dimension; used by `B` (`k x n`) and the accumulator (`m x n`).
    pub n: u32,
    /// `k` dimension; used by `A` (`m x k`) and `B` (`k x n`).
    pub k: u32,
    /// Element storage type; its packing factor divides the count in `num_elems`.
    pub storage: StorageType,
    /// Element ordering of the matrix.
    pub layout: MatrixLayout,
}
38
39impl Matrix {
40 pub fn num_elems(&self) -> u32 {
42 let elems = match self.ident {
43 MatrixIdent::A => self.m * self.k,
44 MatrixIdent::B => self.k * self.n,
45 MatrixIdent::Accumulator => self.m * self.n,
46 };
47 elems / self.storage.packing_factor()
48 }
49}
50
/// Cooperative matrix (CMMA) operations: filling, loading, storing and casting
/// matrices, querying element coordinates, and executing matrix multiply-accumulate.
///
/// The `OperationCode` derive generates the matching `CmmaOpCode` enum from these
/// variants.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationCode)]
#[operation(opcode_name = CmmaOpCode)]
#[allow(missing_docs)]
pub enum CoopMma {
    /// Fill a matrix with `value`.
    Fill { value: Variable },
    /// Load a matrix from `value` using `stride` and `offset`.
    ///
    /// `layout` is optional. NOTE(review): `None` presumably means the layout is
    /// taken from the matrix's own type — confirm.
    Load {
        value: Variable,
        stride: Variable,
        offset: Variable,
        layout: Option<MatrixLayout>,
    },
    /// Execute a matrix multiply-accumulate on `mat_a`, `mat_b` and `mat_c`.
    Execute {
        mat_a: Variable,
        mat_b: Variable,
        mat_c: Variable,
    },
    /// Store the matrix `mat` using `stride`, `offset` and an explicit `layout`.
    Store {
        mat: Variable,
        stride: Variable,
        offset: Variable,
        layout: MatrixLayout,
    },
    /// Cast the matrix `input` (the target type is not part of this variant).
    Cast { input: Variable },

    /// Row coordinate query for element `i` of lane `lane_id` in `matrix`.
    ///
    /// NOTE(review): presumably yields the row of the `i`-th element owned by lane
    /// `lane_id` in a fragment shaped by `matrix` — confirm.
    RowIndex {
        lane_id: Variable,
        i: Variable,
        matrix: Matrix,
    },
    /// Column coordinate query for element `i` of lane `lane_id` in `matrix`.
    ///
    /// NOTE(review): presumably yields the column of the `i`-th element owned by
    /// lane `lane_id` in a fragment shaped by `matrix` — confirm.
    ColIndex {
        lane_id: Variable,
        i: Variable,
        matrix: Matrix,
    },
    /// `ldmatrix`-style load from `buffer` at `offset`, optionally transposed.
    ///
    /// NOTE(review): `factor` presumably is the number of matrices per instruction
    /// (cf. PTX `ldmatrix` `.x1/.x2/.x4`) — confirm. `line_size` is not rendered by
    /// `Display`.
    LoadMatrix {
        buffer: Variable,
        offset: Variable,
        line_size: Option<u32>,
        factor: u32,
        transpose: bool,
    },
    /// `stmatrix`-style store of `registers` at `offset`, optionally transposed.
    ///
    /// NOTE(review): `factor` presumably mirrors `LoadMatrix::factor` — confirm.
    /// `line_size` is not rendered by `Display`.
    StoreMatrix {
        offset: Variable,
        line_size: Option<u32>,
        registers: Variable,
        factor: u32,
        transpose: bool,
    },
    /// Matrix multiply-accumulate on explicit fragments `registers_a`,
    /// `registers_b` and `registers_c`, with the shape described by `matrix`.
    ExecuteManual {
        matrix: Matrix,
        registers_a: Variable,
        registers_b: Variable,
        registers_c: Variable,
    },
    /// Like `ExecuteManual`, with additional scale fragments `scales_a` and
    /// `scales_b`; `scales_factor` appears in the rendered op name
    /// (`execute_scaled_mma_{scales_factor}x`).
    ExecuteScaled {
        matrix: Matrix,
        registers_a: Variable,
        registers_b: Variable,
        registers_c: Variable,
        scales_a: Variable,
        scales_b: Variable,
        scales_factor: u32,
    },
}
130
131impl OperationReflect for CoopMma {
132 type OpCode = CmmaOpCode;
133
134 fn op_code(&self) -> Self::OpCode {
135 self.__match_opcode()
136 }
137
138 fn args(&self) -> Option<Vec<Variable>> {
139 match self {
140 CoopMma::Fill { value } => Some(vec![*value]),
141 CoopMma::Load { .. }
142 | CoopMma::Execute { .. }
143 | CoopMma::ExecuteManual { .. }
144 | CoopMma::ExecuteScaled { .. }
145 | CoopMma::Store { .. }
146 | CoopMma::RowIndex { .. }
147 | CoopMma::ColIndex { .. }
148 | CoopMma::LoadMatrix { .. }
149 | CoopMma::StoreMatrix { .. } => None,
150 CoopMma::Cast { input } => Some(vec![*input]),
151 }
152 }
153
154 fn from_code_and_args(op_code: Self::OpCode, args: &[Variable]) -> Option<Self> {
155 match op_code {
156 CmmaOpCode::Fill => Some(CoopMma::Fill { value: args[0] }),
157 CmmaOpCode::Load
158 | CmmaOpCode::Execute
159 | CmmaOpCode::ExecuteManual
160 | CmmaOpCode::ExecuteScaled
161 | CmmaOpCode::Store
162 | CmmaOpCode::RowIndex
163 | CmmaOpCode::ColIndex
164 | CmmaOpCode::LoadMatrix
165 | CmmaOpCode::StoreMatrix => None,
166 CmmaOpCode::Cast => Some(CoopMma::Cast { input: args[0] }),
167 }
168 }
169}
170
171impl Display for CoopMma {
172 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
173 match self {
174 CoopMma::Fill { value } => write!(f, "{value}"),
175 CoopMma::Load {
176 value,
177 stride,
178 offset,
179 layout,
180 } => {
181 let layout = layout
182 .map(|it| format!(", layout: {it:?}"))
183 .unwrap_or(String::new());
184 write!(
185 f,
186 "matrix_load({value}, stride: {stride}{layout}, offset: {offset})"
187 )
188 }
189 CoopMma::Execute {
190 mat_a,
191 mat_b,
192 mat_c,
193 } => write!(f, "execute_cmma({mat_a}, {mat_b}, {mat_c})"),
194 CoopMma::ExecuteManual {
195 matrix,
196 registers_a,
197 registers_b,
198 registers_c,
199 } => {
200 write!(
201 f,
202 "execute_manual_mma(
203 matrix: {matrix:?},
204 frag_a: {registers_a},
205 frag_b: {registers_b},
206 frag_c: {registers_c},
207 )"
208 )
209 }
210 CoopMma::ExecuteScaled {
211 matrix,
212 registers_a,
213 registers_b,
214 registers_c,
215 scales_a,
216 scales_b,
217 scales_factor,
218 } => {
219 write!(
220 f,
221 "execute_scaled_mma_{scales_factor}x(
222 matrix: {matrix:?},
223 frag_a: {registers_a},
224 frag_b: {registers_b},
225 frag_c: {registers_c},
226 scales_a: {scales_a},
227 scales_b: {scales_b}
228 )"
229 )
230 }
231 CoopMma::Store {
232 mat,
233 stride,
234 offset,
235 layout,
236 } => write!(
237 f,
238 "matrix_store({mat}, stride: {stride}, layout: {layout:?}, offset: {offset:?})"
239 ),
240 CoopMma::Cast { input } => {
241 write!(f, "matrix_cast(input: {input})")
242 }
243 CoopMma::RowIndex { lane_id, i, matrix } => {
244 write!(f, "row_idx(lane_id: {lane_id}, i: {i}, matrix: {matrix:?})",)
245 }
246 CoopMma::ColIndex { lane_id, i, matrix } => {
247 write!(f, "col_idx(lane_id: {lane_id}, i: {i}, matrix: {matrix:?})",)
248 }
249 CoopMma::LoadMatrix {
250 buffer,
251 offset,
252 factor,
253 transpose,
254 ..
255 } => {
256 write!(
257 f,
258 "ldmatrix_{factor}x(&{buffer}[{offset}], transpose: {transpose})"
259 )
260 }
261 CoopMma::StoreMatrix {
262 offset,
263 registers,
264 factor,
265 transpose,
266 ..
267 } => {
268 write!(
269 f,
270 "stmatrix_{factor}x({registers}, offset: {offset}, transpose: {transpose})"
271 )
272 }
273 }
274 }
275}