ptx_parser/unparser/instruction/
wmma_mma.rs

1//! Original PTX specification:
2//!
3//! // Floating point (.f16 multiplicands) wmma.mma
4//! wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype d, a, b, c;
5//! ----------------------------------------------------------------
6//! // Integer (.u8/.s8 multiplicands) wmma.mma
7//! wmma.mma.sync.aligned.alayout.blayout.shape.s32.atype.btype.s32{.satfinite} d, a, b, c;
8//! .alayout = {.row, .col};
9//! .blayout = {.row, .col};
10//! .shape  =  {.m16n16k16, .m8n32k16, .m32n8k16};
11//! .dtype   = {.f16, .f32};
12//! .atype   = {.s8, .u8};
13//! .btype   = {.s8, .u8};
14//! .ctype   = {.f16, .f32};
15//! ----------------------------------------------------------------
16//! // Floating point format .bf16 wmma.mma:
17//! wmma.mma.sync.aligned.alayout.blayout.shape.f32.atype.btype.f32 d, a, b, c;
18//! .alayout = {.row, .col};
19//! .blayout = {.row, .col};
20//! .shape   = {.m16n16k16, .m8n32k16, .m32n8k16};
21//! .atype   = {.bf16 };
22//! .btype   = {.bf16};
23//! ----------------------------------------------------------------
24//! // Floating point format .tf32 wmma.mma:
25//! wmma.mma.sync.aligned.alayout.blayout.shape.f32.atype.btype.f32 d, a, b, c;
26//! .alayout = {.row, .col};
27//! .blayout = {.row, .col};
28//! .shape   = {.m16n16k8 };
29//! .atype   = {.tf32 };
30//! .btype   = {.tf32};
31//! ----------------------------------------------------------------
32//! // Floating point Double precision wmma.mma:
33//! wmma.mma.sync.aligned.alayout.blayout.shape{.rnd}.f64.f64.f64.f64 d, a, b, c;
34//! .alayout = {.row, .col};
35//! .blayout = {.row, .col};
36//! .shape   = {.m8n8k4 };
37//! .rnd = { .rn, .rz, .rm, .rp };
38//! ----------------------------------------------------------------
39//! // Sub-byte (.u4/.s4 multiplicands) wmma.mma:
40//! wmma.mma.sync.aligned.row.col.shape.s32.atype.btype.s32{.satfinite} d, a, b, c;
41//! .shape  = {.m8n8k32};
42//! .atype  = {.s4, .u4};
43//! .btype  = {.s4, .u4};
44//! ----------------------------------------------------------------
45//! // Single-bit (.b1 multiplicands) wmma.mma:
46//! wmma.mma.op.popc.sync.aligned.row.col.shape.s32.atype.btype.s32 d, a, b, c;
47//! .shape  = {.m8n8k128};
48//! .atype  = {.b1};
49//! .btype  = {.b1};
50//! .op     = {.xor, .and};
51
52#![allow(unused)]
53
54use crate::lexer::PtxToken;
55use crate::unparser::{PtxUnparser, common::*};
56
57pub mod section_0 {
58    use super::*;
59    use crate::r#type::instruction::wmma_mma::section_0::*;
60
61    impl PtxUnparser for WmmaMmaSyncAlignedAlayoutBlayoutShapeDtypeCtype {
62        fn unparse_tokens(&self, tokens: &mut ::std::vec::Vec<PtxToken>) {
63            push_opcode(tokens, "wmma");
64            push_directive(tokens, "mma");
65            push_directive(tokens, "sync");
66            push_directive(tokens, "aligned");
67            push_directive(tokens, "alayout");
68            push_directive(tokens, "blayout");
69            push_directive(tokens, "shape");
70            push_directive(tokens, "dtype");
71            push_directive(tokens, "ctype");
72            self.d.unparse_tokens(tokens);
73            tokens.push(PtxToken::Comma);
74            self.a.unparse_tokens(tokens);
75            tokens.push(PtxToken::Comma);
76            self.b.unparse_tokens(tokens);
77            tokens.push(PtxToken::Comma);
78            self.c.unparse_tokens(tokens);
79            tokens.push(PtxToken::Semicolon);
80        }
81    }
82}
83
84pub mod section_1 {
85    use super::*;
86    use crate::r#type::instruction::wmma_mma::section_1::*;
87
88    impl PtxUnparser for WmmaMmaSyncAlignedAlayoutBlayoutShapeS32AtypeBtypeS32Satfinite {
89        fn unparse_tokens(&self, tokens: &mut ::std::vec::Vec<PtxToken>) {
90            push_opcode(tokens, "wmma");
91            push_directive(tokens, "mma");
92            push_directive(tokens, "sync");
93            push_directive(tokens, "aligned");
94            match &self.alayout {
95                Alayout::Row => {
96                    push_directive(tokens, "row");
97                }
98                Alayout::Col => {
99                    push_directive(tokens, "col");
100                }
101            }
102            match &self.blayout {
103                Blayout::Row => {
104                    push_directive(tokens, "row");
105                }
106                Blayout::Col => {
107                    push_directive(tokens, "col");
108                }
109            }
110            match &self.shape {
111                Shape::M16n16k16 => {
112                    push_directive(tokens, "m16n16k16");
113                }
114                Shape::M8n32k16 => {
115                    push_directive(tokens, "m8n32k16");
116                }
117                Shape::M32n8k16 => {
118                    push_directive(tokens, "m32n8k16");
119                }
120            }
121            push_directive(tokens, "s32");
122            match &self.atype {
123                Atype::S8 => {
124                    push_directive(tokens, "s8");
125                }
126                Atype::U8 => {
127                    push_directive(tokens, "u8");
128                }
129            }
130            match &self.btype {
131                Btype::S8 => {
132                    push_directive(tokens, "s8");
133                }
134                Btype::U8 => {
135                    push_directive(tokens, "u8");
136                }
137            }
138            push_directive(tokens, "s32");
139            if self.satfinite {
140                push_directive(tokens, "satfinite");
141            }
142            self.d.unparse_tokens(tokens);
143            tokens.push(PtxToken::Comma);
144            self.a.unparse_tokens(tokens);
145            tokens.push(PtxToken::Comma);
146            self.b.unparse_tokens(tokens);
147            tokens.push(PtxToken::Comma);
148            self.c.unparse_tokens(tokens);
149            tokens.push(PtxToken::Semicolon);
150        }
151    }
152}
153
154pub mod section_2 {
155    use super::*;
156    use crate::r#type::instruction::wmma_mma::section_2::*;
157
158    impl PtxUnparser for WmmaMmaSyncAlignedAlayoutBlayoutShapeF32AtypeBtypeF32 {
159        fn unparse_tokens(&self, tokens: &mut ::std::vec::Vec<PtxToken>) {
160            push_opcode(tokens, "wmma");
161            push_directive(tokens, "mma");
162            push_directive(tokens, "sync");
163            push_directive(tokens, "aligned");
164            match &self.alayout {
165                Alayout::Row => {
166                    push_directive(tokens, "row");
167                }
168                Alayout::Col => {
169                    push_directive(tokens, "col");
170                }
171            }
172            match &self.blayout {
173                Blayout::Row => {
174                    push_directive(tokens, "row");
175                }
176                Blayout::Col => {
177                    push_directive(tokens, "col");
178                }
179            }
180            match &self.shape {
181                Shape::M16n16k16 => {
182                    push_directive(tokens, "m16n16k16");
183                }
184                Shape::M8n32k16 => {
185                    push_directive(tokens, "m8n32k16");
186                }
187                Shape::M32n8k16 => {
188                    push_directive(tokens, "m32n8k16");
189                }
190            }
191            push_directive(tokens, "f32");
192            match &self.atype {
193                Atype::Bf16 => {
194                    push_directive(tokens, "bf16");
195                }
196            }
197            match &self.btype {
198                Btype::Bf16 => {
199                    push_directive(tokens, "bf16");
200                }
201            }
202            push_directive(tokens, "f32");
203            self.d.unparse_tokens(tokens);
204            tokens.push(PtxToken::Comma);
205            self.a.unparse_tokens(tokens);
206            tokens.push(PtxToken::Comma);
207            self.b.unparse_tokens(tokens);
208            tokens.push(PtxToken::Comma);
209            self.c.unparse_tokens(tokens);
210            tokens.push(PtxToken::Semicolon);
211        }
212    }
213}
214
215pub mod section_3 {
216    use super::*;
217    use crate::r#type::instruction::wmma_mma::section_3::*;
218
219    impl PtxUnparser for WmmaMmaSyncAlignedAlayoutBlayoutShapeF32AtypeBtypeF321 {
220        fn unparse_tokens(&self, tokens: &mut ::std::vec::Vec<PtxToken>) {
221            push_opcode(tokens, "wmma");
222            push_directive(tokens, "mma");
223            push_directive(tokens, "sync");
224            push_directive(tokens, "aligned");
225            match &self.alayout {
226                Alayout::Row => {
227                    push_directive(tokens, "row");
228                }
229                Alayout::Col => {
230                    push_directive(tokens, "col");
231                }
232            }
233            match &self.blayout {
234                Blayout::Row => {
235                    push_directive(tokens, "row");
236                }
237                Blayout::Col => {
238                    push_directive(tokens, "col");
239                }
240            }
241            match &self.shape {
242                Shape::M16n16k8 => {
243                    push_directive(tokens, "m16n16k8");
244                }
245            }
246            push_directive(tokens, "f32");
247            match &self.atype {
248                Atype::Tf32 => {
249                    push_directive(tokens, "tf32");
250                }
251            }
252            match &self.btype {
253                Btype::Tf32 => {
254                    push_directive(tokens, "tf32");
255                }
256            }
257            push_directive(tokens, "f32");
258            self.d.unparse_tokens(tokens);
259            tokens.push(PtxToken::Comma);
260            self.a.unparse_tokens(tokens);
261            tokens.push(PtxToken::Comma);
262            self.b.unparse_tokens(tokens);
263            tokens.push(PtxToken::Comma);
264            self.c.unparse_tokens(tokens);
265            tokens.push(PtxToken::Semicolon);
266        }
267    }
268}
269
270pub mod section_4 {
271    use super::*;
272    use crate::r#type::instruction::wmma_mma::section_4::*;
273
274    impl PtxUnparser for WmmaMmaSyncAlignedAlayoutBlayoutShapeRndF64F64F64F64 {
275        fn unparse_tokens(&self, tokens: &mut ::std::vec::Vec<PtxToken>) {
276            push_opcode(tokens, "wmma");
277            push_directive(tokens, "mma");
278            push_directive(tokens, "sync");
279            push_directive(tokens, "aligned");
280            match &self.alayout {
281                Alayout::Row => {
282                    push_directive(tokens, "row");
283                }
284                Alayout::Col => {
285                    push_directive(tokens, "col");
286                }
287            }
288            match &self.blayout {
289                Blayout::Row => {
290                    push_directive(tokens, "row");
291                }
292                Blayout::Col => {
293                    push_directive(tokens, "col");
294                }
295            }
296            match &self.shape {
297                Shape::M8n8k4 => {
298                    push_directive(tokens, "m8n8k4");
299                }
300            }
301            if let Some(rnd_0) = self.rnd.as_ref() {
302                match rnd_0 {
303                    Rnd::Rn => {
304                        push_directive(tokens, "rn");
305                    }
306                    Rnd::Rz => {
307                        push_directive(tokens, "rz");
308                    }
309                    Rnd::Rm => {
310                        push_directive(tokens, "rm");
311                    }
312                    Rnd::Rp => {
313                        push_directive(tokens, "rp");
314                    }
315                }
316            }
317            push_directive(tokens, "f64");
318            push_directive(tokens, "f64");
319            push_directive(tokens, "f64");
320            push_directive(tokens, "f64");
321            self.d.unparse_tokens(tokens);
322            tokens.push(PtxToken::Comma);
323            self.a.unparse_tokens(tokens);
324            tokens.push(PtxToken::Comma);
325            self.b.unparse_tokens(tokens);
326            tokens.push(PtxToken::Comma);
327            self.c.unparse_tokens(tokens);
328            tokens.push(PtxToken::Semicolon);
329        }
330    }
331}
332
333pub mod section_5 {
334    use super::*;
335    use crate::r#type::instruction::wmma_mma::section_5::*;
336
337    impl PtxUnparser for WmmaMmaSyncAlignedRowColShapeS32AtypeBtypeS32Satfinite {
338        fn unparse_tokens(&self, tokens: &mut ::std::vec::Vec<PtxToken>) {
339            push_opcode(tokens, "wmma");
340            push_directive(tokens, "mma");
341            push_directive(tokens, "sync");
342            push_directive(tokens, "aligned");
343            push_directive(tokens, "row");
344            push_directive(tokens, "col");
345            match &self.shape {
346                Shape::M8n8k32 => {
347                    push_directive(tokens, "m8n8k32");
348                }
349            }
350            push_directive(tokens, "s32");
351            match &self.atype {
352                Atype::S4 => {
353                    push_directive(tokens, "s4");
354                }
355                Atype::U4 => {
356                    push_directive(tokens, "u4");
357                }
358            }
359            match &self.btype {
360                Btype::S4 => {
361                    push_directive(tokens, "s4");
362                }
363                Btype::U4 => {
364                    push_directive(tokens, "u4");
365                }
366            }
367            push_directive(tokens, "s32");
368            if self.satfinite {
369                push_directive(tokens, "satfinite");
370            }
371            self.d.unparse_tokens(tokens);
372            tokens.push(PtxToken::Comma);
373            self.a.unparse_tokens(tokens);
374            tokens.push(PtxToken::Comma);
375            self.b.unparse_tokens(tokens);
376            tokens.push(PtxToken::Comma);
377            self.c.unparse_tokens(tokens);
378            tokens.push(PtxToken::Semicolon);
379        }
380    }
381}
382
383pub mod section_6 {
384    use super::*;
385    use crate::r#type::instruction::wmma_mma::section_6::*;
386
387    impl PtxUnparser for WmmaMmaOpPopcSyncAlignedRowColShapeS32AtypeBtypeS32 {
388        fn unparse_tokens(&self, tokens: &mut ::std::vec::Vec<PtxToken>) {
389            push_opcode(tokens, "wmma");
390            push_directive(tokens, "mma");
391            match &self.op {
392                Op::Xor => {
393                    push_directive(tokens, "xor");
394                }
395                Op::And => {
396                    push_directive(tokens, "and");
397                }
398            }
399            push_directive(tokens, "popc");
400            push_directive(tokens, "sync");
401            push_directive(tokens, "aligned");
402            push_directive(tokens, "row");
403            push_directive(tokens, "col");
404            match &self.shape {
405                Shape::M8n8k128 => {
406                    push_directive(tokens, "m8n8k128");
407                }
408            }
409            push_directive(tokens, "s32");
410            match &self.atype {
411                Atype::B1 => {
412                    push_directive(tokens, "b1");
413                }
414            }
415            match &self.btype {
416                Btype::B1 => {
417                    push_directive(tokens, "b1");
418                }
419            }
420            push_directive(tokens, "s32");
421            self.d.unparse_tokens(tokens);
422            tokens.push(PtxToken::Comma);
423            self.a.unparse_tokens(tokens);
424            tokens.push(PtxToken::Comma);
425            self.b.unparse_tokens(tokens);
426            tokens.push(PtxToken::Comma);
427            self.c.unparse_tokens(tokens);
428            tokens.push(PtxToken::Semicolon);
429        }
430    }
431}