cubecl_cpp/shared/
mma.rs

1use cubecl_core::Feature;
2use cubecl_core::ir::{self as gpu};
3use cubecl_runtime::DeviceProperties;
4use std::fmt::Display;
5use std::str::FromStr;
6use std::{fmt::Debug, marker::PhantomData};
7
8use super::{Component, Dialect, Elem, Variable};
9
/// List of supported WMMA configurations.
///
/// Each entry holds three element types followed by the `(m, n, k)` shapes available for
/// that combination. [`register_wmma_features`] maps the three element types, in order, to
/// the `a`, `b` and `c` fields of `Feature::Cmma`.
pub type SupportedWmmaCombinations = Vec<(gpu::Elem, gpu::Elem, gpu::Elem, Vec<(u8, u8, u8)>)>;
11
/// Description of a target architecture that can be parsed from a string.
///
/// Parse failures are reported as a `String` (see the `FromStr<Err = String>` bound).
pub trait Architecture: FromStr<Err = String> {
    /// Returns the warp size of this architecture.
    fn warp_size(&self) -> u32;
    /// Returns `true` when this architecture is WMMA capable.
    fn is_wmma_capable(&self) -> bool;
    /// Returns `true` when this architecture is MFMA capable.
    fn is_mfma_capable(&self) -> bool;
}
17
18pub fn register_wmma_features(
19    supported_combinations: SupportedWmmaCombinations,
20    properties: &mut DeviceProperties<Feature>,
21) {
22    for (i, o, c, tdims) in supported_combinations {
23        for (m, n, k) in tdims {
24            properties.register_feature(Feature::Cmma {
25                a: i,
26                b: o,
27                c,
28                m,
29                n,
30                k,
31            });
32        }
33    }
34}
35
/// Identifies which matrix a fragment holds: `A`, `B` or the accumulator.
///
/// The `_Dialect` variant only carries the dialect type parameter `D`.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum FragmentIdent<D: Dialect> {
    /// The `A` matrix.
    A,
    /// The `B` matrix.
    B,
    /// The accumulator matrix.
    Accumulator,
    /// Marker that ties the enum to the dialect `D`; holds no data.
    _Dialect(PhantomData<D>),
}
43
/// Layout of a fragment: column-major or row-major.
///
/// The `_Dialect` variant only carries the dialect type parameter `D`.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum FragmentLayout<D: Dialect> {
    /// Column-major layout.
    ColMajor,
    /// Row-major layout.
    RowMajor,
    /// Marker that ties the enum to the dialect `D`; holds no data.
    _Dialect(PhantomData<D>),
}
50
/// Description of a matrix fragment: its identifier, shape, element type and optional layout.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub struct Fragment<D: Dialect> {
    /// Which matrix this fragment holds.
    pub ident: FragmentIdent<D>,
    /// The `m` dimension of the fragment shape.
    pub m: u8,
    /// The `n` dimension of the fragment shape.
    pub n: u8,
    /// The `k` dimension of the fragment shape.
    pub k: u8,
    /// Element type stored in the fragment.
    pub elem: Elem<D>,
    /// Layout of the fragment, if one is specified.
    pub layout: Option<FragmentLayout<D>>,
}
60
/// Warp Matrix-Multiply and Accumulate Instruction.
///
/// Formatting an instruction through `Display` delegates to the dialect's
/// `compile_instruction`.
#[derive(Debug, Clone, Copy)]
pub enum WmmaInstruction<D: Dialect> {
    /// Fill the fragment with the value.
    Fill {
        /// Fragment to fill.
        frag: Variable<D>,
        /// Value written into the fragment.
        value: Variable<D>,
    },
    /// Load the value into the fragment given the stride.
    Load {
        /// Destination fragment.
        frag: Variable<D>,
        /// Source the fragment is loaded from.
        value: Variable<D>,
        /// Stride used when loading.
        stride: Variable<D>,
        /// Optional layout of the source; `None` means no layout is passed to the load.
        layout: Option<FragmentLayout<D>>,
    },
    /// Executes D=A*B+C;
    ///
    /// For implementing a matmul, `D=C` : `C+=A*B`
    Execute {
        /// The `A` operand.
        frag_a: Variable<D>,
        /// The `B` operand.
        frag_b: Variable<D>,
        /// The `C` operand (added to the product).
        frag_c: Variable<D>,
        /// The `D` result.
        frag_d: Variable<D>,
        /// Warp size; ignored by `wmma_api_base::compile_instruction`.
        warp_size: u32,
    },
    /// Store the fragment in an output variable following the stride and the layout.
    Store {
        /// Destination of the store.
        output: Variable<D>,
        /// Fragment being stored.
        frag: Variable<D>,
        /// Stride used when storing.
        stride: Variable<D>,
        /// Layout used when storing.
        layout: FragmentLayout<D>,
    },
    /// Cast the elements of the `input` fragment into the `output` fragment.
    Cast {
        /// Fragment whose elements are read.
        input: Variable<D>,
        /// Fragment receiving the converted elements.
        output: Variable<D>,
    },
}
99
/// Formats a layout by delegating to the dialect's `compile_fragment_layout`.
impl<D: Dialect> Display for FragmentLayout<D> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The dialect decides how the layout is spelled in generated code.
        D::compile_fragment_layout(self, f)
    }
}
105
/// Formats a fragment identifier by delegating to the dialect's `compile_fragment_ident`.
impl<D: Dialect> Display for FragmentIdent<D> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The dialect decides how the identifier is spelled in generated code.
        D::compile_fragment_ident(self, f)
    }
}
111
/// Formats a fragment by delegating to the dialect's `compile_fragment`.
impl<D: Dialect> Display for Fragment<D> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The dialect decides how the fragment type is spelled in generated code.
        D::compile_fragment(self, f)
    }
}
117
/// Formats an instruction by delegating to the dialect's `compile_instruction`.
impl<D: Dialect> Display for WmmaInstruction<D> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The dialect decides which code is emitted for the instruction.
        D::compile_instruction(self, f)
    }
}
123
124pub mod wmma_api_base {
125    use super::*;
126
127    pub fn compile_fragment_ident<D: Dialect>(
128        namespace: &str,
129        ident: &FragmentIdent<D>,
130        f: &mut std::fmt::Formatter<'_>,
131    ) -> std::fmt::Result {
132        match ident {
133            FragmentIdent::A => write!(f, "{namespace}::matrix_a"),
134            FragmentIdent::B => write!(f, "{namespace}::matrix_b"),
135            FragmentIdent::Accumulator => write!(f, "{namespace}::accumulator"),
136            FragmentIdent::_Dialect(_) => Ok(()),
137        }
138    }
139
140    pub fn compile_fragment_layout<D: Dialect>(
141        namespace: &str,
142        layout: &FragmentLayout<D>,
143        f: &mut std::fmt::Formatter<'_>,
144    ) -> std::fmt::Result {
145        match layout {
146            FragmentLayout::ColMajor => f.write_str(format!("{namespace}::col_major").as_str()),
147            FragmentLayout::RowMajor => f.write_str(format!("{namespace}::row_major").as_str()),
148            FragmentLayout::_Dialect(_) => Ok(()),
149        }
150    }
151
152    pub fn compile_fragment<D: Dialect>(
153        namespace: &str,
154        fragment: &Fragment<D>,
155        f: &mut std::fmt::Formatter<'_>,
156    ) -> std::fmt::Result {
157        let elem = match fragment.elem {
158            Elem::TF32 => format!("{namespace}::precision::tf32"),
159            Elem::BF16 => {
160                if fragment.ident == FragmentIdent::Accumulator {
161                    format!("{}", Elem::<D>::F16) // Normally not supported except for cast.
162                } else {
163                    format!("{}", fragment.elem)
164                }
165            }
166            elem => format!("{elem}"),
167        };
168        match fragment.layout {
169            Some(layout) => write!(
170                f,
171                "{namespace}::fragment<{}, {}, {}, {}, {}, {}>",
172                fragment.ident, fragment.m, fragment.n, fragment.k, elem, layout
173            ),
174            None => write!(
175                f,
176                "{namespace}::fragment<{}, {}, {}, {}, {}>",
177                fragment.ident, fragment.m, fragment.n, fragment.k, elem,
178            ),
179        }
180    }
181
182    pub fn compile_instruction<D: Dialect>(
183        namespace: &str,
184        instruction: &WmmaInstruction<D>,
185        f: &mut std::fmt::Formatter<'_>,
186    ) -> std::fmt::Result {
187        match instruction {
188            WmmaInstruction::Fill { frag, value } => {
189                writeln!(f, "{namespace}::fill_fragment({frag}, {value});")
190            }
191
192            WmmaInstruction::Load {
193                frag,
194                value,
195                stride,
196                layout: None,
197            } => {
198                let item = value.item();
199                if item.vectorization > 1 {
200                    let elem = item.elem;
201                    writeln!(
202                        f,
203                        "{namespace}::load_matrix_sync({frag}, reinterpret_cast<{elem} *>({value}), {stride});"
204                    )
205                } else {
206                    writeln!(
207                        f,
208                        "{namespace}::load_matrix_sync({frag}, {value}, {stride});"
209                    )
210                }
211            }
212
213            WmmaInstruction::Load {
214                frag,
215                value,
216                stride,
217                layout: Some(layout),
218            } => {
219                let layout = match layout {
220                    FragmentLayout::ColMajor => format!("{namespace}::mem_col_major"),
221                    FragmentLayout::RowMajor => format!("{namespace}::mem_row_major"),
222                    FragmentLayout::_Dialect(_) => "".to_string(),
223                };
224                let item = value.item();
225                if item.vectorization > 1 {
226                    let elem = item.elem;
227                    writeln!(
228                        f,
229                        "{namespace}::load_matrix_sync({frag}, reinterpret_cast<{elem} *>({value}), {stride}, {layout});"
230                    )
231                } else {
232                    writeln!(
233                        f,
234                        "{namespace}::load_matrix_sync({frag}, {value}, {stride}, {layout});"
235                    )
236                }
237            }
238
239            WmmaInstruction::Execute {
240                frag_a,
241                frag_b,
242                frag_c,
243                frag_d,
244                ..
245            } => writeln!(
246                f,
247                "{namespace}::mma_sync({frag_d}, {frag_a}, {frag_b}, {frag_c});"
248            ),
249
250            WmmaInstruction::Store {
251                output,
252                frag,
253                stride,
254                layout,
255            } => {
256                let layout = match layout {
257                    FragmentLayout::ColMajor => format!("{namespace}::mem_col_major"),
258                    FragmentLayout::RowMajor => format!("{namespace}::mem_row_major"),
259                    FragmentLayout::_Dialect(_) => "".to_string(),
260                };
261
262                let item = output.item();
263                let mut reinterpret_cast = item.vectorization > 1;
264                let elem = match item.elem {
265                    Elem::BF16 => {
266                        reinterpret_cast = true;
267                        Elem::F16
268                    }
269                    _ => item.elem,
270                };
271                if reinterpret_cast {
272                    writeln!(
273                        f,
274                        "{namespace}::store_matrix_sync(reinterpret_cast<{elem} *>({output}), {frag}, {stride}, {layout});"
275                    )
276                } else {
277                    writeln!(
278                        f,
279                        "{namespace}::store_matrix_sync({output}, {frag}, {stride}, {layout});"
280                    )
281                }
282            }
283            WmmaInstruction::Cast { input, output } => {
284                let ty = match output {
285                    Variable::WmmaFragment { frag, .. } => frag.elem,
286                    _ => panic!("Should be a fragment"),
287                };
288                match ty {
289                    Elem::BF16 => {
290                        let elem = Elem::<D>::F16;
291                        writeln!(
292                            f,
293                            "for(int t=0; t<{input}.num_elements; t++) {{
294                                {ty} elem = {ty}({input}.x[t]);
295                                {output}.x[t] = *reinterpret_cast<{elem} *>(&elem);
296                            }}"
297                        )
298                    }
299                    _ => {
300                        writeln!(
301                            f,
302                            "for(int t=0; t<{input}.num_elements; t++) {{ {output}.x[t] = {ty}({input}.x[t]); }}"
303                        )
304                    }
305                }
306            }
307        }
308    }
309}