use alloc::{format, string::String, vec, vec::Vec};
use derive_new::new;
use super::Variable;
use crate::{OperationCode, OperationReflect};
use crate::{StorageType, TypeHash};
use core::fmt::Display;
/// Identifies the role a [`Matrix`] plays: operand `A`, operand `B`, or the accumulator.
///
/// [`Matrix::num_elems`] matches on this value to choose which two of the
/// `m`, `n` and `k` dimensions are multiplied together.
// The derived `PartialOrd`/`Ord` compare variants in declaration order, so
// reordering the variants changes comparison results.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum MatrixIdent {
/// Operand `A`; sized `m * k` by [`Matrix::num_elems`].
A,
/// Operand `B`; sized `k * n` by [`Matrix::num_elems`].
B,
/// Accumulator; sized `m * n` by [`Matrix::num_elems`].
Accumulator,
}
/// Layout tag attached to a [`Matrix`] and to the `Load`/`Store` variants of [`CoopMma`].
// The derived `PartialOrd`/`Ord` compare variants in declaration order, so
// reordering the variants changes comparison results.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum MatrixLayout {
/// Column-major layout.
ColMajor,
/// Row-major layout.
RowMajor,
/// No specific layout.
// NOTE(review): how backends treat `Undefined` is not visible in this file — confirm.
Undefined,
}
/// Descriptor of a matrix: its role, the `m`/`n`/`k` dimensions, its storage
/// type and its layout.
///
/// The `new` derive provides the `Matrix::new` constructor.
// Field order matters: the derived `PartialOrd`/`Ord` compare fields in
// declaration order.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(new, Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub struct Matrix {
/// Role of this matrix; selects the dimension pair used by [`Matrix::num_elems`].
pub ident: MatrixIdent,
/// The `m` dimension (multiplied with `k` for `A`, with `n` for `Accumulator`).
pub m: usize,
/// The `n` dimension (multiplied with `k` for `B`, with `m` for `Accumulator`).
pub n: usize,
/// The `k` dimension (multiplied with `m` for `A`, with `n` for `B`).
pub k: usize,
/// Storage type; its `packing_factor()` divides the element count in
/// [`Matrix::num_elems`].
pub storage: StorageType,
/// Layout recorded for this matrix.
pub layout: MatrixLayout,
}
impl Matrix {
    /// Returns the number of storage elements in this matrix.
    ///
    /// The two dimensions selected by `ident` are multiplied together
    /// (`m * k` for `A`, `k * n` for `B`, `m * n` for `Accumulator`) and the
    /// product is divided, using integer division, by the packing factor of
    /// the storage type.
    pub fn num_elems(&self) -> usize {
        // Pick the pair of dimensions that applies to this matrix's role.
        let (first, second) = match self.ident {
            MatrixIdent::A => (self.m, self.k),
            MatrixIdent::B => (self.k, self.n),
            MatrixIdent::Accumulator => (self.m, self.n),
        };
        let unpacked = first * second;
        unpacked / self.storage.packing_factor()
    }
}
/// Matrix-multiply-accumulate operations on matrix operands and register fragments.
///
/// The `OperationCode` derive, configured through `opcode_name`, supplies the
/// companion `CmmaOpCode` type used by the `OperationReflect` implementation.
/// Only `Fill` and `Cast` expose their operand through `OperationReflect::args`;
/// every other variant reports `None` there.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationCode)]
#[operation(opcode_name = CmmaOpCode)]
#[allow(missing_docs)]
pub enum CoopMma {
/// Fill operation carrying a single `value` operand.
// NOTE(review): presumably fills the output matrix with `value` — confirm against the backends.
Fill { value: Variable },
/// Matrix load from `value` using `stride` and `offset`, with an optional
/// `layout`. Displayed as `matrix_load(..)`; the layout part is omitted when
/// `layout` is `None`.
Load {
value: Variable,
stride: Variable,
offset: Variable,
layout: Option<MatrixLayout>,
},
/// Execution over the three matrix operands `mat_a`, `mat_b` and `mat_c`.
/// Displayed as `execute_cmma(..)`.
Execute {
mat_a: Variable,
mat_b: Variable,
mat_c: Variable,
},
/// Matrix store of `mat` using `stride`, `offset` and a mandatory `layout`.
/// Displayed as `matrix_store(..)`.
Store {
mat: Variable,
stride: Variable,
offset: Variable,
layout: MatrixLayout,
},
/// Cast of the `input` matrix. Displayed as `matrix_cast(..)`.
Cast { input: Variable },
/// Row index lookup for `lane_id` and `i` within `matrix`.
/// Displayed as `row_idx(..)`.
// NOTE(review): presumably `i` is the per-lane element index — confirm.
RowIndex {
lane_id: Variable,
i: Variable,
matrix: Matrix,
},
/// Column index lookup for `lane_id` and `i` within `matrix`.
/// Displayed as `col_idx(..)`.
// NOTE(review): presumably `i` is the per-lane element index — confirm.
ColIndex {
lane_id: Variable,
i: Variable,
matrix: Matrix,
},
/// Matrix load from `buffer` at `offset`, parameterized by `factor` and
/// `transpose`. Displayed as `ldmatrix_{factor}x(..)`; `vector_size` is not
/// part of the displayed form.
LoadMatrix {
buffer: Variable,
offset: Variable,
vector_size: Option<usize>,
factor: usize,
transpose: bool,
},
/// Matrix store of `registers` at `offset`, parameterized by `factor` and
/// `transpose`. Displayed as `stmatrix_{factor}x(..)`; `vector_size` is not
/// part of the displayed form.
StoreMatrix {
offset: Variable,
vector_size: Option<usize>,
registers: Variable,
factor: usize,
transpose: bool,
},
/// Execution described by an explicit `matrix` descriptor and three
/// register operands. Displayed as `execute_manual_mma(..)`, where the
/// registers are labelled `frag_a`, `frag_b` and `frag_c`.
ExecuteManual {
matrix: Matrix,
registers_a: Variable,
registers_b: Variable,
registers_c: Variable,
},
/// Like `ExecuteManual`, with the additional `scales_a`/`scales_b` operands
/// and a `scales_factor`. Displayed as `execute_scaled_mma_{scales_factor}x(..)`.
ExecuteScaled {
matrix: Matrix,
registers_a: Variable,
registers_b: Variable,
registers_c: Variable,
scales_a: Variable,
scales_b: Variable,
scales_factor: usize,
},
}
/// Reflection support for [`CoopMma`].
///
/// Only `Fill` and `Cast` can be decomposed into, and rebuilt from, an opcode
/// plus a flat argument list; every other variant opts out by returning `None`.
impl OperationReflect for CoopMma {
    type OpCode = CmmaOpCode;

    /// Returns the opcode matching this variant, as generated by the
    /// `OperationCode` derive.
    fn op_code(&self) -> Self::OpCode {
        self.__match_opcode()
    }

    /// Returns the single operand of `Fill` and `Cast`, or `None` for every
    /// other variant.
    fn args(&self) -> Option<Vec<Variable>> {
        match self {
            CoopMma::Fill { value } => Some(vec![*value]),
            CoopMma::Cast { input } => Some(vec![*input]),
            CoopMma::Load { .. }
            | CoopMma::Execute { .. }
            | CoopMma::ExecuteManual { .. }
            | CoopMma::ExecuteScaled { .. }
            | CoopMma::Store { .. }
            | CoopMma::RowIndex { .. }
            | CoopMma::ColIndex { .. }
            | CoopMma::LoadMatrix { .. }
            | CoopMma::StoreMatrix { .. } => None,
        }
    }

    /// Rebuilds a `Fill` or `Cast` from its opcode and the first entry of `args`.
    ///
    /// Returns `None` for every other opcode, and also when `args` is empty:
    /// the function already reports failure through `Option`, so a missing
    /// operand is signalled the same way instead of panicking on `args[0]`.
    /// Entries after the first are ignored.
    fn from_code_and_args(op_code: Self::OpCode, args: &[Variable]) -> Option<Self> {
        match op_code {
            CmmaOpCode::Fill => args.first().map(|&value| CoopMma::Fill { value }),
            CmmaOpCode::Cast => args.first().map(|&input| CoopMma::Cast { input }),
            CmmaOpCode::Load
            | CmmaOpCode::Execute
            | CmmaOpCode::ExecuteManual
            | CmmaOpCode::ExecuteScaled
            | CmmaOpCode::Store
            | CmmaOpCode::RowIndex
            | CmmaOpCode::ColIndex
            | CmmaOpCode::LoadMatrix
            | CmmaOpCode::StoreMatrix => None,
        }
    }
}
/// Human-readable, function-call-like rendering of a [`CoopMma`] operation.
///
/// `Variable` operands are written with their `Display` form; `Matrix`
/// descriptors and layouts are written with `Debug`.
impl Display for CoopMma {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            CoopMma::Fill { value } => write!(f, "{value}"),
            CoopMma::Load {
                value,
                stride,
                offset,
                layout,
            } => {
                // The layout segment is only emitted when a layout was given.
                let layout: String = layout
                    .map(|it| format!(", layout: {it:?}"))
                    .unwrap_or_default();
                write!(
                    f,
                    "matrix_load({value}, stride: {stride}{layout}, offset: {offset})"
                )
            }
            CoopMma::Execute {
                mat_a,
                mat_b,
                mat_c,
            } => write!(f, "execute_cmma({mat_a}, {mat_b}, {mat_c})"),
            CoopMma::ExecuteManual {
                matrix,
                registers_a,
                registers_b,
                registers_c,
            } => {
                // Explicit `\n\` continuations keep the rendered text
                // independent of how this source file is indented.
                write!(
                    f,
                    "execute_manual_mma(\n\
                     matrix: {matrix:?},\n\
                     frag_a: {registers_a},\n\
                     frag_b: {registers_b},\n\
                     frag_c: {registers_c},\n\
                     )"
                )
            }
            CoopMma::ExecuteScaled {
                matrix,
                registers_a,
                registers_b,
                registers_c,
                scales_a,
                scales_b,
                scales_factor,
            } => {
                write!(
                    f,
                    "execute_scaled_mma_{scales_factor}x(\n\
                     matrix: {matrix:?},\n\
                     frag_a: {registers_a},\n\
                     frag_b: {registers_b},\n\
                     frag_c: {registers_c},\n\
                     scales_a: {scales_a},\n\
                     scales_b: {scales_b}\n\
                     )"
                )
            }
            CoopMma::Store {
                mat,
                stride,
                offset,
                layout,
            } => {
                // `offset` is a `Variable` like the other operands, so it is
                // written with `Display`, matching `Load` and `StoreMatrix`.
                write!(
                    f,
                    "matrix_store({mat}, stride: {stride}, layout: {layout:?}, offset: {offset})"
                )
            }
            CoopMma::Cast { input } => write!(f, "matrix_cast(input: {input})"),
            CoopMma::RowIndex { lane_id, i, matrix } => {
                write!(f, "row_idx(lane_id: {lane_id}, i: {i}, matrix: {matrix:?})")
            }
            CoopMma::ColIndex { lane_id, i, matrix } => {
                write!(f, "col_idx(lane_id: {lane_id}, i: {i}, matrix: {matrix:?})")
            }
            // `vector_size` is intentionally left out of the rendering, as before.
            CoopMma::LoadMatrix {
                buffer,
                offset,
                factor,
                transpose,
                ..
            } => write!(
                f,
                "ldmatrix_{factor}x(&{buffer}[{offset}], transpose: {transpose})"
            ),
            CoopMma::StoreMatrix {
                offset,
                registers,
                factor,
                transpose,
                ..
            } => write!(
                f,
                "stmatrix_{factor}x({registers}, offset: {offset}, transpose: {transpose})"
            ),
        }
    }
}