1use alloc::{format, string::String, vec, vec::Vec};
2
3use super::{Elem, Variable};
4use crate::TypeHash;
5use crate::{OperationCode, OperationReflect};
6use core::fmt::Display;
7
/// Role a cooperative-matrix fragment plays in a matrix operation.
///
/// `A` and `B` presumably correspond to the `mat_a` / `mat_b` operands of
/// `CoopMma::Execute`, and `Accumulator` to `mat_c` — confirm against the
/// backends that consume this type.
///
/// Variant order is significant: it defines the derived `PartialOrd`/`Ord`
/// ordering and participates in the derived `TypeHash`/serde representation.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum MatrixIdent {
    A,
    B,
    Accumulator,
}
16
/// Ordering of a matrix fragment's elements in the backing memory.
///
/// `Undefined` presumably means the layout is not fixed by the fragment's type
/// and must be supplied elsewhere (e.g. the optional `layout` of
/// `CoopMma::Load`) — NOTE(review): confirm with the backends.
///
/// Variant order is significant: it defines the derived `PartialOrd`/`Ord`
/// ordering and participates in the derived `TypeHash`/serde representation.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum MatrixLayout {
    ColMajor,
    RowMajor,
    Undefined,
}
25
/// Type descriptor of a cooperative-matrix fragment.
///
/// Field order is significant: it defines the derived lexicographic
/// `PartialOrd`/`Ord` and participates in the derived `TypeHash`/serde form.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub struct Matrix {
    /// Which operand role this fragment has (`A`, `B` or `Accumulator`).
    pub ident: MatrixIdent,
    // NOTE(review): `m`, `n`, `k` look like the M/N/K tile dimensions of the
    // multiply-accumulate shape; they are stored as `u8`, so each is at most 255.
    pub m: u8,
    pub n: u8,
    pub k: u8,
    /// Element type stored in the fragment.
    pub elem: Elem,
    /// Memory layout of the fragment's elements.
    pub layout: MatrixLayout,
}
37
/// Operations on cooperative-matrix fragments.
///
/// The `OperationCode` derive generates the field-less `CmmaOpCode` enum from
/// these variants, so variant order and names are load-bearing.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationCode)]
#[operation(opcode_name = CmmaOpCode)]
#[allow(missing_docs)]
pub enum CoopMma {
    /// Fill a fragment with `value` (printed as just the value).
    Fill { value: Variable },
    /// Load a fragment from `value`, reading with `stride` starting at
    /// `offset`; `layout` is optional (printed as `matrix_load(...)`).
    Load {
        value: Variable,
        stride: Variable,
        offset: Variable,
        layout: Option<MatrixLayout>,
    },
    /// Multiply-accumulate over three fragments (printed as
    /// `execute_cmma(mat_a, mat_b, mat_c)`). NOTE(review): presumably
    /// computes `A * B + C` into the instruction's output — confirm.
    Execute {
        mat_a: Variable,
        mat_b: Variable,
        mat_c: Variable,
    },
    /// Store fragment `mat` using `stride`, `offset` and a mandatory
    /// `layout` (printed as `matrix_store(...)`).
    Store {
        mat: Variable,
        stride: Variable,
        offset: Variable,
        layout: MatrixLayout,
    },
    /// Cast the fragment `input` (printed as `matrix_cast(input: ...)`);
    /// the target type presumably comes from the instruction's output.
    Cast { input: Variable },
}
71
72impl OperationReflect for CoopMma {
73 type OpCode = CmmaOpCode;
74
75 fn op_code(&self) -> Self::OpCode {
76 self.__match_opcode()
77 }
78
79 fn args(&self) -> Option<Vec<Variable>> {
80 match self {
81 CoopMma::Fill { value } => Some(vec![*value]),
82 CoopMma::Load { .. } | CoopMma::Execute { .. } | CoopMma::Store { .. } => None,
83 CoopMma::Cast { input } => Some(vec![*input]),
84 }
85 }
86
87 fn from_code_and_args(op_code: Self::OpCode, args: &[Variable]) -> Option<Self> {
88 match op_code {
89 CmmaOpCode::Fill => Some(CoopMma::Fill { value: args[0] }),
90 CmmaOpCode::Load | CmmaOpCode::Execute | CmmaOpCode::Store => None,
91 CmmaOpCode::Cast => Some(CoopMma::Cast { input: args[0] }),
92 }
93 }
94}
95
96impl Display for CoopMma {
97 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
98 match self {
99 CoopMma::Fill { value } => write!(f, "{value}"),
100 CoopMma::Load {
101 value,
102 stride,
103 offset,
104 layout,
105 } => {
106 let layout = layout
107 .map(|it| format!(", layout: {it:?}"))
108 .unwrap_or(String::new());
109 write!(
110 f,
111 "matrix_load({value}, stride: {stride}{layout}, offset: {offset})"
112 )
113 }
114 CoopMma::Execute {
115 mat_a,
116 mat_b,
117 mat_c,
118 } => write!(f, "execute_cmma({mat_a}, {mat_b}, {mat_c})"),
119 CoopMma::Store {
120 mat,
121 stride,
122 offset,
123 layout,
124 } => write!(
125 f,
126 "matrix_store({mat}, stride: {stride}, layout: {layout:?}, offset: {offset:?})"
127 ),
128 CoopMma::Cast { input } => {
129 write!(f, "matrix_cast(input: {input})")
130 }
131 }
132 }
133}