cubecl_ir/
cmma.rs

1use alloc::{format, string::String, vec, vec::Vec};
2use derive_new::new;
3
4use super::Variable;
5use crate::{OperationCode, OperationReflect};
6use crate::{StorageType, TypeHash};
7use core::fmt::Display;
8
9#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
10#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
11#[allow(missing_docs)]
12pub enum MatrixIdent {
13    A,
14    B,
15    Accumulator,
16}
17
18#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
19#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
20#[allow(missing_docs)]
21pub enum MatrixLayout {
22    ColMajor,
23    RowMajor,
24    Undefined,
25}
26
27#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
28#[derive(new, Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
29#[allow(missing_docs)]
30pub struct Matrix {
31    pub ident: MatrixIdent,
32    pub m: u32,
33    pub n: u32,
34    pub k: u32,
35    pub storage: StorageType,
36    pub layout: MatrixLayout,
37}
38
39impl Matrix {
40    /// Number of elements in terms of the physical storage type, accounting for packed elements
41    pub fn num_elems(&self) -> u32 {
42        let elems = match self.ident {
43            MatrixIdent::A => self.m * self.k,
44            MatrixIdent::B => self.k * self.n,
45            MatrixIdent::Accumulator => self.m * self.n,
46        };
47        elems / self.storage.packing_factor()
48    }
49}
50
51/// Cooperative Matrix-Multiply and Accumulate Instruction.
52#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
53#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationCode)]
54#[operation(opcode_name = CmmaOpCode)]
55#[allow(missing_docs)]
56pub enum CoopMma {
57    /// Fill the matrix with the value.
58    Fill { value: Variable },
59    /// Load the value into the matrix given the stride.
60    Load {
61        value: Variable,
62        stride: Variable,
63        offset: Variable,
64        layout: Option<MatrixLayout>,
65    },
66    /// Executes D=A*B+C;
67    ///
68    /// For implementing a matmul, `D=C` : `C+=A*B`
69    Execute {
70        mat_a: Variable,
71        mat_b: Variable,
72        mat_c: Variable,
73    },
74    /// Store the matrix in an output variable following the stride and the layout.
75    Store {
76        mat: Variable,
77        stride: Variable,
78        offset: Variable,
79        layout: MatrixLayout,
80    },
81    /// Cast a fragment to another type.
82    Cast { input: Variable },
83
84    /// Row index of nth element in the lane
85    RowIndex {
86        lane_id: Variable,
87        i: Variable,
88        matrix: Matrix,
89    },
90    /// Column index of nth element in the lane
91    ColIndex {
92        lane_id: Variable,
93        i: Variable,
94        matrix: Matrix,
95    },
96    /// Execute a CUDA `ldmatrix` instruction
97    LoadMatrix {
98        buffer: Variable,
99        offset: Variable,
100        line_size: Option<u32>,
101        factor: u32,
102        transpose: bool,
103    },
104    /// Execute a CUDA `stmatrix` instruction
105    StoreMatrix {
106        offset: Variable,
107        line_size: Option<u32>,
108        registers: Variable,
109        factor: u32,
110        transpose: bool,
111    },
112    /// Manual execute.
113    ExecuteManual {
114        matrix: Matrix,
115        registers_a: Variable,
116        registers_b: Variable,
117        registers_c: Variable,
118    },
119    /// Scaled manual execute.
120    ExecuteScaled {
121        matrix: Matrix,
122        registers_a: Variable,
123        registers_b: Variable,
124        registers_c: Variable,
125        scales_a: Variable,
126        scales_b: Variable,
127        scales_factor: u32,
128    },
129}
130
131impl OperationReflect for CoopMma {
132    type OpCode = CmmaOpCode;
133
134    fn op_code(&self) -> Self::OpCode {
135        self.__match_opcode()
136    }
137
138    fn args(&self) -> Option<Vec<Variable>> {
139        match self {
140            CoopMma::Fill { value } => Some(vec![*value]),
141            CoopMma::Load { .. }
142            | CoopMma::Execute { .. }
143            | CoopMma::ExecuteManual { .. }
144            | CoopMma::ExecuteScaled { .. }
145            | CoopMma::Store { .. }
146            | CoopMma::RowIndex { .. }
147            | CoopMma::ColIndex { .. }
148            | CoopMma::LoadMatrix { .. }
149            | CoopMma::StoreMatrix { .. } => None,
150            CoopMma::Cast { input } => Some(vec![*input]),
151        }
152    }
153
154    fn from_code_and_args(op_code: Self::OpCode, args: &[Variable]) -> Option<Self> {
155        match op_code {
156            CmmaOpCode::Fill => Some(CoopMma::Fill { value: args[0] }),
157            CmmaOpCode::Load
158            | CmmaOpCode::Execute
159            | CmmaOpCode::ExecuteManual
160            | CmmaOpCode::ExecuteScaled
161            | CmmaOpCode::Store
162            | CmmaOpCode::RowIndex
163            | CmmaOpCode::ColIndex
164            | CmmaOpCode::LoadMatrix
165            | CmmaOpCode::StoreMatrix => None,
166            CmmaOpCode::Cast => Some(CoopMma::Cast { input: args[0] }),
167        }
168    }
169}
170
171impl Display for CoopMma {
172    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
173        match self {
174            CoopMma::Fill { value } => write!(f, "{value}"),
175            CoopMma::Load {
176                value,
177                stride,
178                offset,
179                layout,
180            } => {
181                let layout = layout
182                    .map(|it| format!(", layout: {it:?}"))
183                    .unwrap_or(String::new());
184                write!(
185                    f,
186                    "matrix_load({value}, stride: {stride}{layout}, offset: {offset})"
187                )
188            }
189            CoopMma::Execute {
190                mat_a,
191                mat_b,
192                mat_c,
193            } => write!(f, "execute_cmma({mat_a}, {mat_b}, {mat_c})"),
194            CoopMma::ExecuteManual {
195                matrix,
196                registers_a,
197                registers_b,
198                registers_c,
199            } => {
200                write!(
201                    f,
202                    "execute_manual_mma(
203                    matrix: {matrix:?},
204                    frag_a: {registers_a},
205                    frag_b: {registers_b},
206                    frag_c: {registers_c},
207                )"
208                )
209            }
210            CoopMma::ExecuteScaled {
211                matrix,
212                registers_a,
213                registers_b,
214                registers_c,
215                scales_a,
216                scales_b,
217                scales_factor,
218            } => {
219                write!(
220                    f,
221                    "execute_scaled_mma_{scales_factor}x(
222                    matrix: {matrix:?},
223                    frag_a: {registers_a},
224                    frag_b: {registers_b},
225                    frag_c: {registers_c},
226                    scales_a: {scales_a},
227                    scales_b: {scales_b}
228                )"
229                )
230            }
231            CoopMma::Store {
232                mat,
233                stride,
234                offset,
235                layout,
236            } => write!(
237                f,
238                "matrix_store({mat}, stride: {stride}, layout: {layout:?}, offset: {offset:?})"
239            ),
240            CoopMma::Cast { input } => {
241                write!(f, "matrix_cast(input: {input})")
242            }
243            CoopMma::RowIndex { lane_id, i, matrix } => {
244                write!(f, "row_idx(lane_id: {lane_id}, i: {i}, matrix: {matrix:?})",)
245            }
246            CoopMma::ColIndex { lane_id, i, matrix } => {
247                write!(f, "col_idx(lane_id: {lane_id}, i: {i}, matrix: {matrix:?})",)
248            }
249            CoopMma::LoadMatrix {
250                buffer,
251                offset,
252                factor,
253                transpose,
254                ..
255            } => {
256                write!(
257                    f,
258                    "ldmatrix_{factor}x(&{buffer}[{offset}], transpose: {transpose})"
259                )
260            }
261            CoopMma::StoreMatrix {
262                offset,
263                registers,
264                factor,
265                transpose,
266                ..
267            } => {
268                write!(
269                    f,
270                    "stmatrix_{factor}x({registers}, offset: {offset}, transpose: {transpose})"
271                )
272            }
273        }
274    }
275}