1use alloc::{format, string::String, vec, vec::Vec};
2use derive_new::new;
3
4use super::Variable;
5use crate::{OperationCode, OperationReflect};
6use crate::{StorageType, TypeHash};
7use core::fmt::Display;
8
/// Identifies which operand of the matrix multiply a fragment represents.
///
/// The operand decides the fragment's logical extent (see
/// [`Matrix::num_elems`]): `A` spans `m x k`, `B` spans `k x n`, and
/// `Accumulator` spans `m x n`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum MatrixIdent {
    /// The `A` operand, with logical extent `m x k`.
    A,
    /// The `B` operand, with logical extent `k x n`.
    B,
    /// The accumulator operand, with logical extent `m x n`.
    Accumulator,
}
17
/// Element ordering of a matrix fragment in memory.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum MatrixLayout {
    /// Column-major ordering.
    ColMajor,
    /// Row-major ordering.
    RowMajor,
    /// No explicit layout.
    /// NOTE(review): whether this means "backend decides" or "not applicable"
    /// is not visible from this file — confirm before relying on it.
    Undefined,
}
26
/// Static description of a cooperative matrix fragment: which operand it is,
/// the `m`/`n`/`k` problem dimensions, its element storage type and its layout.
///
/// Field declaration order is significant: it determines the derived `Ord`
/// ordering and the argument order of the `derive(new)` constructor.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(new, Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub struct Matrix {
    /// Which operand (`A`, `B` or accumulator) this fragment holds.
    pub ident: MatrixIdent,
    /// The `m` dimension, shared by the `A` operand and the accumulator.
    pub m: u32,
    /// The `n` dimension, shared by the `B` operand and the accumulator.
    pub n: u32,
    /// The `k` dimension, shared by the `A` and `B` operands.
    pub k: u32,
    /// Element storage type; its packing factor scales the element count
    /// reported by `num_elems`.
    pub storage: StorageType,
    /// Memory layout of the fragment.
    pub layout: MatrixLayout,
}
38
39impl Matrix {
40 pub fn num_elems(&self) -> u32 {
42 let elems = match self.ident {
43 MatrixIdent::A => self.m * self.k,
44 MatrixIdent::B => self.k * self.n,
45 MatrixIdent::Accumulator => self.m * self.n,
46 };
47 elems / self.storage.packing_factor()
48 }
49}
50
/// Cooperative matrix multiply (CMMA) operations of the IR.
///
/// Only [`CoopMma::Fill`] and [`CoopMma::Cast`] reflect into plain positional
/// arguments through `OperationReflect`. Every other variant carries
/// structured data and reports no generic argument list.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationCode)]
#[operation(opcode_name = CmmaOpCode)]
#[allow(missing_docs)]
pub enum CoopMma {
    /// Fill operation with a single `value` operand.
    Fill { value: Variable },
    /// Matrix load from `value` using `stride` and `offset`.
    ///
    /// `layout` is optional here (unlike [`CoopMma::Store`]) and is only
    /// printed by `Display` when it is `Some`.
    Load {
        value: Variable,
        stride: Variable,
        offset: Variable,
        layout: Option<MatrixLayout>,
    },
    /// Matrix multiply over the fragments `mat_a`, `mat_b` and `mat_c`
    /// (rendered as `execute_cmma`).
    /// NOTE(review): which operand receives the result is not visible here.
    Execute {
        mat_a: Variable,
        mat_b: Variable,
        mat_c: Variable,
    },
    /// Matrix store of `mat` with the given `stride`, `offset` and a
    /// mandatory `layout`.
    Store {
        mat: Variable,
        stride: Variable,
        offset: Variable,
        layout: MatrixLayout,
    },
    /// Matrix cast of `input`.
    /// NOTE(review): the cast target is presumably implied by the output
    /// variable — confirm.
    Cast { input: Variable },

    /// Row index for element `i` of lane `lane_id` in a fragment described
    /// by `matrix` (rendered as `row_idx`).
    RowIndex {
        lane_id: Variable,
        i: Variable,
        matrix: Matrix,
    },
    /// Column index for element `i` of lane `lane_id` in a fragment
    /// described by `matrix` (rendered as `col_idx`).
    ColIndex {
        lane_id: Variable,
        i: Variable,
        matrix: Matrix,
    },
    /// Matrix multiply over explicit register lists for each operand, with
    /// the fragment shape given by `matrix`.
    ExecuteManual {
        matrix: Matrix,
        registers_a: Vec<Variable>,
        registers_b: Vec<Variable>,
        registers_c: Vec<Variable>,
    },
    /// Like [`CoopMma::ExecuteManual`], plus `scales_a` / `scales_b` operands
    /// and a `scales_factor` that appears in the rendered op name
    /// (`execute_scaled_mma_{scales_factor}x`).
    ExecuteScaled {
        matrix: Matrix,
        registers_a: Vec<Variable>,
        registers_b: Vec<Variable>,
        registers_c: Vec<Variable>,
        scales_a: Variable,
        scales_b: Variable,
        scales_factor: u32,
    },
}
114
115impl OperationReflect for CoopMma {
116 type OpCode = CmmaOpCode;
117
118 fn op_code(&self) -> Self::OpCode {
119 self.__match_opcode()
120 }
121
122 fn args(&self) -> Option<Vec<Variable>> {
123 match self {
124 CoopMma::Fill { value } => Some(vec![*value]),
125 CoopMma::Load { .. }
126 | CoopMma::Execute { .. }
127 | CoopMma::ExecuteManual { .. }
128 | CoopMma::ExecuteScaled { .. }
129 | CoopMma::Store { .. }
130 | CoopMma::RowIndex { .. }
131 | CoopMma::ColIndex { .. } => None,
132 CoopMma::Cast { input } => Some(vec![*input]),
133 }
134 }
135
136 fn from_code_and_args(op_code: Self::OpCode, args: &[Variable]) -> Option<Self> {
137 match op_code {
138 CmmaOpCode::Fill => Some(CoopMma::Fill { value: args[0] }),
139 CmmaOpCode::Load
140 | CmmaOpCode::Execute
141 | CmmaOpCode::ExecuteManual
142 | CmmaOpCode::ExecuteScaled
143 | CmmaOpCode::Store
144 | CmmaOpCode::RowIndex
145 | CmmaOpCode::ColIndex => None,
146 CmmaOpCode::Cast => Some(CoopMma::Cast { input: args[0] }),
147 }
148 }
149}
150
151impl Display for CoopMma {
152 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
153 match self {
154 CoopMma::Fill { value } => write!(f, "{value}"),
155 CoopMma::Load {
156 value,
157 stride,
158 offset,
159 layout,
160 } => {
161 let layout = layout
162 .map(|it| format!(", layout: {it:?}"))
163 .unwrap_or(String::new());
164 write!(
165 f,
166 "matrix_load({value}, stride: {stride}{layout}, offset: {offset})"
167 )
168 }
169 CoopMma::Execute {
170 mat_a,
171 mat_b,
172 mat_c,
173 } => write!(f, "execute_cmma({mat_a}, {mat_b}, {mat_c})"),
174 CoopMma::ExecuteManual {
175 matrix,
176 registers_a,
177 registers_b,
178 registers_c,
179 } => {
180 let frag_a = comma_separated(registers_a.iter().map(|it| format!("{it}")));
181 let frag_b = comma_separated(registers_b.iter().map(|it| format!("{it}")));
182 let frag_c = comma_separated(registers_c.iter().map(|it| format!("{it}")));
183 write!(
184 f,
185 "execute_manual_mma(
186 matrix: {matrix:?},
187 frag_a: [{frag_a}],
188 frag_b: [{frag_b}],
189 frag_c: [{frag_c}],
190 )"
191 )
192 }
193 CoopMma::ExecuteScaled {
194 matrix,
195 registers_a,
196 registers_b,
197 registers_c,
198 scales_a,
199 scales_b,
200 scales_factor,
201 } => {
202 let frag_a = comma_separated(registers_a.iter().map(|it| format!("{it}")));
203 let frag_b = comma_separated(registers_b.iter().map(|it| format!("{it}")));
204 let frag_c = comma_separated(registers_c.iter().map(|it| format!("{it}")));
205 write!(
206 f,
207 "execute_scaled_mma_{scales_factor}x(
208 matrix: {matrix:?},
209 frag_a: [{frag_a}],
210 frag_b: [{frag_b}],
211 frag_c: [{frag_c}],
212 scales_a: {scales_a},
213 scales_b: {scales_b}
214 )"
215 )
216 }
217 CoopMma::Store {
218 mat,
219 stride,
220 offset,
221 layout,
222 } => write!(
223 f,
224 "matrix_store({mat}, stride: {stride}, layout: {layout:?}, offset: {offset:?})"
225 ),
226 CoopMma::Cast { input } => {
227 write!(f, "matrix_cast(input: {input})")
228 }
229 CoopMma::RowIndex { lane_id, i, matrix } => {
230 write!(f, "row_idx(lane_id: {lane_id}, i: {i}, matrix: {matrix:?})",)
231 }
232 CoopMma::ColIndex { lane_id, i, matrix } => {
233 write!(f, "col_idx(lane_id: {lane_id}, i: {i}, matrix: {matrix:?})",)
234 }
235 }
236 }
237}
238
/// Concatenates the strings yielded by `it`, placing `", "` between
/// consecutive items. An empty input produces an empty string.
fn comma_separated(it: impl IntoIterator<Item = String>) -> String {
    let mut joined = String::new();
    for (position, piece) in it.into_iter().enumerate() {
        if position != 0 {
            joined.push_str(", ");
        }
        joined.push_str(&piece);
    }
    joined
}