cubecl_ir/
cmma.rs

1use alloc::{format, string::String, vec, vec::Vec};
2use derive_new::new;
3
4use super::Variable;
5use crate::{OperationCode, OperationReflect};
6use crate::{StorageType, TypeHash};
7use core::fmt::Display;
8
/// Role of a cooperative matrix fragment within the multiply-accumulate `D = A * B + C`.
///
/// The role determines which of the `m`, `n`, `k` dimensions of a [`Matrix`] describe the
/// fragment's logical shape (see [`Matrix::num_elems`]).
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum MatrixIdent {
    /// Left-hand multiplicand `A`, with logical shape `m x k`.
    A,
    /// Right-hand multiplicand `B`, with logical shape `k x n`.
    B,
    /// Accumulator operand (`C`, and the result `D`), with logical shape `m x n`.
    Accumulator,
}
17
/// Element ordering of a cooperative matrix in memory.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum MatrixLayout {
    /// Column-major ordering.
    ColMajor,
    /// Row-major ordering.
    RowMajor,
    /// No layout specified.
    /// NOTE(review): presumably the backend picks or infers the layout here; confirm.
    Undefined,
}
26
/// Descriptor of a cooperative matrix fragment: its operand role, the MMA problem
/// dimensions, its element storage type and its layout.
///
/// Field order matters: it defines the parameter order of the derived `new` constructor
/// and the derived `Ord`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(new, Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub struct Matrix {
    /// Which operand (`A`, `B` or accumulator) this fragment represents.
    pub ident: MatrixIdent,
    /// `m` dimension of the `(m x k) * (k x n)` multiply.
    pub m: u32,
    /// `n` dimension of the `(m x k) * (k x n)` multiply.
    pub n: u32,
    /// `k` (reduction) dimension of the `(m x k) * (k x n)` multiply.
    pub k: u32,
    /// Physical storage type of the elements. Its packing factor is used by `num_elems`.
    pub storage: StorageType,
    /// Memory layout of the fragment.
    pub layout: MatrixLayout,
}
38
39impl Matrix {
40    /// Number of elements in terms of the physical storage type, accounting for packed elements
41    pub fn num_elems(&self) -> u32 {
42        let elems = match self.ident {
43            MatrixIdent::A => self.m * self.k,
44            MatrixIdent::B => self.k * self.n,
45            MatrixIdent::Accumulator => self.m * self.n,
46        };
47        elems / self.storage.packing_factor()
48    }
49}
50
/// Cooperative Matrix-Multiply and Accumulate Instruction.
///
/// Each variant is one operation on cooperative matrix fragments. Variant and field order
/// feed the derived op codes (`CmmaOpCode`), serde and `TypeHash`, so keep them stable.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationCode)]
#[operation(opcode_name = CmmaOpCode)]
#[allow(missing_docs)]
pub enum CoopMma {
    /// Fill the matrix with `value`.
    Fill { value: Variable },
    /// Load the value into the matrix given the stride.
    Load {
        /// Source the matrix is loaded from.
        value: Variable,
        /// Stride used when reading `value`.
        stride: Variable,
        /// Offset into `value`.
        /// NOTE(review): presumably the start position of the load; confirm.
        offset: Variable,
        /// Explicit layout of the source, or `None` when not specified.
        layout: Option<MatrixLayout>,
    },
    /// Executes `D = A * B + C`.
    ///
    /// For implementing a matmul, `D = C`, i.e. `C += A * B`.
    /// NOTE(review): `D` is presumably the instruction's output variable; confirm.
    Execute {
        /// Operand `A`.
        mat_a: Variable,
        /// Operand `B`.
        mat_b: Variable,
        /// Accumulator operand `C`.
        mat_c: Variable,
    },
    /// Store the matrix in an output variable following the stride and the layout.
    Store {
        /// Matrix to store.
        mat: Variable,
        /// Stride used when writing the output.
        stride: Variable,
        /// Offset into the output.
        /// NOTE(review): presumably the start position of the store; confirm.
        offset: Variable,
        /// Layout the matrix is written in.
        layout: MatrixLayout,
    },
    /// Cast a fragment (`input`) to another type.
    /// NOTE(review): the target type presumably comes from the output variable; confirm.
    Cast { input: Variable },

    /// Row index of nth element in the lane
    RowIndex {
        /// Lane whose element is indexed.
        lane_id: Variable,
        /// Index of the element within the lane (the `n` in "nth element").
        i: Variable,
        /// Descriptor of the fragment being indexed.
        matrix: Matrix,
    },
    /// Column index of nth element in the lane
    ColIndex {
        /// Lane whose element is indexed.
        lane_id: Variable,
        /// Index of the element within the lane (the `n` in "nth element").
        i: Variable,
        /// Descriptor of the fragment being indexed.
        matrix: Matrix,
    },
    /// Manual execute.
    /// NOTE(review): presumably an MMA over explicit register lists instead of opaque
    /// fragments; confirm.
    ExecuteManual {
        /// Descriptor of the multiply (shape, storage, layout).
        matrix: Matrix,
        /// Registers holding operand `A`.
        registers_a: Vec<Variable>,
        /// Registers holding operand `B`.
        registers_b: Vec<Variable>,
        /// Registers holding accumulator `C`.
        registers_c: Vec<Variable>,
    },
    /// Scaled manual execute.
    ExecuteScaled {
        /// Descriptor of the multiply (shape, storage, layout).
        matrix: Matrix,
        /// Registers holding operand `A`.
        registers_a: Vec<Variable>,
        /// Registers holding operand `B`.
        registers_b: Vec<Variable>,
        /// Registers holding accumulator `C`.
        registers_c: Vec<Variable>,
        /// Scales associated with `A`.
        /// NOTE(review): exact scaling semantics are backend-defined; confirm.
        scales_a: Variable,
        /// Scales associated with `B`.
        /// NOTE(review): exact scaling semantics are backend-defined; confirm.
        scales_b: Variable,
        /// Scale factor, rendered as the `{scales_factor}x` suffix of the mnemonic.
        scales_factor: u32,
    },
}
114
115impl OperationReflect for CoopMma {
116    type OpCode = CmmaOpCode;
117
118    fn op_code(&self) -> Self::OpCode {
119        self.__match_opcode()
120    }
121
122    fn args(&self) -> Option<Vec<Variable>> {
123        match self {
124            CoopMma::Fill { value } => Some(vec![*value]),
125            CoopMma::Load { .. }
126            | CoopMma::Execute { .. }
127            | CoopMma::ExecuteManual { .. }
128            | CoopMma::ExecuteScaled { .. }
129            | CoopMma::Store { .. }
130            | CoopMma::RowIndex { .. }
131            | CoopMma::ColIndex { .. } => None,
132            CoopMma::Cast { input } => Some(vec![*input]),
133        }
134    }
135
136    fn from_code_and_args(op_code: Self::OpCode, args: &[Variable]) -> Option<Self> {
137        match op_code {
138            CmmaOpCode::Fill => Some(CoopMma::Fill { value: args[0] }),
139            CmmaOpCode::Load
140            | CmmaOpCode::Execute
141            | CmmaOpCode::ExecuteManual
142            | CmmaOpCode::ExecuteScaled
143            | CmmaOpCode::Store
144            | CmmaOpCode::RowIndex
145            | CmmaOpCode::ColIndex => None,
146            CmmaOpCode::Cast => Some(CoopMma::Cast { input: args[0] }),
147        }
148    }
149}
150
151impl Display for CoopMma {
152    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
153        match self {
154            CoopMma::Fill { value } => write!(f, "{value}"),
155            CoopMma::Load {
156                value,
157                stride,
158                offset,
159                layout,
160            } => {
161                let layout = layout
162                    .map(|it| format!(", layout: {it:?}"))
163                    .unwrap_or(String::new());
164                write!(
165                    f,
166                    "matrix_load({value}, stride: {stride}{layout}, offset: {offset})"
167                )
168            }
169            CoopMma::Execute {
170                mat_a,
171                mat_b,
172                mat_c,
173            } => write!(f, "execute_cmma({mat_a}, {mat_b}, {mat_c})"),
174            CoopMma::ExecuteManual {
175                matrix,
176                registers_a,
177                registers_b,
178                registers_c,
179            } => {
180                let frag_a = comma_separated(registers_a.iter().map(|it| format!("{it}")));
181                let frag_b = comma_separated(registers_b.iter().map(|it| format!("{it}")));
182                let frag_c = comma_separated(registers_c.iter().map(|it| format!("{it}")));
183                write!(
184                    f,
185                    "execute_manual_mma(
186                    matrix: {matrix:?},
187                    frag_a: [{frag_a}],
188                    frag_b: [{frag_b}],
189                    frag_c: [{frag_c}],
190                )"
191                )
192            }
193            CoopMma::ExecuteScaled {
194                matrix,
195                registers_a,
196                registers_b,
197                registers_c,
198                scales_a,
199                scales_b,
200                scales_factor,
201            } => {
202                let frag_a = comma_separated(registers_a.iter().map(|it| format!("{it}")));
203                let frag_b = comma_separated(registers_b.iter().map(|it| format!("{it}")));
204                let frag_c = comma_separated(registers_c.iter().map(|it| format!("{it}")));
205                write!(
206                    f,
207                    "execute_scaled_mma_{scales_factor}x(
208                    matrix: {matrix:?},
209                    frag_a: [{frag_a}],
210                    frag_b: [{frag_b}],
211                    frag_c: [{frag_c}],
212                    scales_a: {scales_a},
213                    scales_b: {scales_b}
214                )"
215                )
216            }
217            CoopMma::Store {
218                mat,
219                stride,
220                offset,
221                layout,
222            } => write!(
223                f,
224                "matrix_store({mat}, stride: {stride}, layout: {layout:?}, offset: {offset:?})"
225            ),
226            CoopMma::Cast { input } => {
227                write!(f, "matrix_cast(input: {input})")
228            }
229            CoopMma::RowIndex { lane_id, i, matrix } => {
230                write!(f, "row_idx(lane_id: {lane_id}, i: {i}, matrix: {matrix:?})",)
231            }
232            CoopMma::ColIndex { lane_id, i, matrix } => {
233                write!(f, "col_idx(lane_id: {lane_id}, i: {i}, matrix: {matrix:?})",)
234            }
235        }
236    }
237}
238
/// Joins the given strings with `", "` between consecutive items.
///
/// An empty input yields an empty string; a single item is returned unchanged.
fn comma_separated(it: impl IntoIterator<Item = String>) -> String {
    let mut parts = it.into_iter();
    match parts.next() {
        None => String::new(),
        // Reuse the first item's allocation and append the rest in place.
        Some(first) => parts.fold(first, |mut acc, part| {
            acc.push_str(", ");
            acc.push_str(&part);
            acc
        }),
    }
}