cubecl_ir/
cmma.rs

1use alloc::{format, string::String, vec, vec::Vec};
2
3use super::{Elem, Variable};
4use crate::TypeHash;
5use crate::{OperationCode, OperationReflect};
6use core::fmt::Display;
7
8#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
9#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
10#[allow(missing_docs)]
11pub enum MatrixIdent {
12    A,
13    B,
14    Accumulator,
15}
16
17#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
18#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
19#[allow(missing_docs)]
20pub enum MatrixLayout {
21    ColMajor,
22    RowMajor,
23    Undefined,
24}
25
26#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
27#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
28#[allow(missing_docs)]
29pub struct Matrix {
30    pub ident: MatrixIdent,
31    pub m: u8,
32    pub n: u8,
33    pub k: u8,
34    pub elem: Elem,
35    pub layout: MatrixLayout,
36}
37
38/// Cooperative Matrix-Multiply and Accumulate Instruction.
39#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
40#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationCode)]
41#[operation(opcode_name = CmmaOpCode)]
42#[allow(missing_docs)]
43pub enum CoopMma {
44    /// Fill the matrix with the value.
45    Fill { value: Variable },
46    /// Load the value into the matrix given the stride.
47    Load {
48        value: Variable,
49        stride: Variable,
50        offset: Variable,
51        layout: Option<MatrixLayout>,
52    },
53    /// Executes D=A*B+C;
54    ///
55    /// For implementing a matmul, `D=C` : `C+=A*B`
56    Execute {
57        mat_a: Variable,
58        mat_b: Variable,
59        mat_c: Variable,
60    },
61    /// Store the matrix in an output variable following the stride and the layout.
62    Store {
63        mat: Variable,
64        stride: Variable,
65        offset: Variable,
66        layout: MatrixLayout,
67    },
68    /// Cast a fragment to another type.
69    Cast { input: Variable },
70}
71
72impl OperationReflect for CoopMma {
73    type OpCode = CmmaOpCode;
74
75    fn op_code(&self) -> Self::OpCode {
76        self.__match_opcode()
77    }
78
79    fn args(&self) -> Option<Vec<Variable>> {
80        match self {
81            CoopMma::Fill { value } => Some(vec![*value]),
82            CoopMma::Load { .. } | CoopMma::Execute { .. } | CoopMma::Store { .. } => None,
83            CoopMma::Cast { input } => Some(vec![*input]),
84        }
85    }
86
87    fn from_code_and_args(op_code: Self::OpCode, args: &[Variable]) -> Option<Self> {
88        match op_code {
89            CmmaOpCode::Fill => Some(CoopMma::Fill { value: args[0] }),
90            CmmaOpCode::Load | CmmaOpCode::Execute | CmmaOpCode::Store => None,
91            CmmaOpCode::Cast => Some(CoopMma::Cast { input: args[0] }),
92        }
93    }
94}
95
96impl Display for CoopMma {
97    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
98        match self {
99            CoopMma::Fill { value } => write!(f, "{value}"),
100            CoopMma::Load {
101                value,
102                stride,
103                offset,
104                layout,
105            } => {
106                let layout = layout
107                    .map(|it| format!(", layout: {it:?}"))
108                    .unwrap_or(String::new());
109                write!(
110                    f,
111                    "matrix_load({value}, stride: {stride}{layout}, offset: {offset})"
112                )
113            }
114            CoopMma::Execute {
115                mat_a,
116                mat_b,
117                mat_c,
118            } => write!(f, "execute_cmma({mat_a}, {mat_b}, {mat_c})"),
119            CoopMma::Store {
120                mat,
121                stride,
122                offset,
123                layout,
124            } => write!(
125                f,
126                "matrix_store({mat}, stride: {stride}, layout: {layout:?}, offset: {offset:?})"
127            ),
128            CoopMma::Cast { input } => {
129                write!(f, "matrix_cast(input: {input})")
130            }
131        }
132    }
133}