Skip to main content

ptx_parser/unparser/instruction/
wmma_mma.rs

1//! Original PTX specification:
2//!
3//! // Floating point (.f16 multiplicands) wmma.mma
4//! wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype d, a, b, c;
5//! ----------------------------------------------------------------
6//! // Integer (.u8/.s8 multiplicands) wmma.mma
7//! wmma.mma.sync.aligned.alayout.blayout.shape.s32.atype.btype.s32{.satfinite} d, a, b, c;
8//! .alayout = {.row, .col};
9//! .blayout = {.row, .col};
10//! .shape  =  {.m16n16k16, .m8n32k16, .m32n8k16};
11//! .dtype   = {.f16, .f32};
12//! .atype   = {.s8, .u8};
13//! .btype   = {.s8, .u8};
14//! .ctype   = {.f16, .f32};
15//! ----------------------------------------------------------------
16//! // Floating point format .bf16 wmma.mma:
17//! wmma.mma.sync.aligned.alayout.blayout.shape.f32.atype.btype.f32 d, a, b, c;
18//! .alayout = {.row, .col};
19//! .blayout = {.row, .col};
20//! .shape   = {.m16n16k16, .m8n32k16, .m32n8k16};
21//! .atype   = {.bf16 };
22//! .btype   = {.bf16};
23//! ----------------------------------------------------------------
24//! // Floating point format .tf32 wmma.mma:
25//! wmma.mma.sync.aligned.alayout.blayout.shape.f32.atype.btype.f32 d, a, b, c;
26//! .alayout = {.row, .col};
27//! .blayout = {.row, .col};
28//! .shape   = {.m16n16k8 };
29//! .atype   = {.tf32 };
30//! .btype   = {.tf32};
31//! ----------------------------------------------------------------
32//! // Floating point Double precision wmma.mma:
33//! wmma.mma.sync.aligned.alayout.blayout.shape{.rnd}.f64.f64.f64.f64 d, a, b, c;
34//! .alayout = {.row, .col};
35//! .blayout = {.row, .col};
36//! .shape   = {.m8n8k4 };
37//! .rnd = { .rn, .rz, .rm, .rp };
38//! ----------------------------------------------------------------
39//! // Sub-byte (.u4/.s4 multiplicands) wmma.mma:
40//! wmma.mma.sync.aligned.row.col.shape.s32.atype.btype.s32{.satfinite} d, a, b, c;
41//! .shape  = {.m8n8k32};
42//! .atype  = {.s4, .u4};
43//! .btype  = {.s4, .u4};
44//! ----------------------------------------------------------------
45//! // Single-bit (.b1 multiplicands) wmma.mma:
46//! wmma.mma.op.popc.sync.aligned.row.col.shape.s32.atype.btype.s32 d, a, b, c;
47//! .shape  = {.m8n8k128};
48//! .atype  = {.b1};
49//! .btype  = {.b1};
50//! .op     = {.xor, .and};
51
52#![allow(unused)]
53
54use crate::lexer::PtxToken;
55use crate::unparser::{PtxUnparser, common::*};
56
57pub mod section_0 {
58    use super::*;
59    use crate::r#type::instruction::wmma_mma::section_0::*;
60
61    impl PtxUnparser for WmmaMmaSyncAlignedAlayoutBlayoutShapeDtypeCtype {
62        fn unparse_tokens(&self, tokens: &mut ::std::vec::Vec<PtxToken>) {
63            self.unparse_tokens_mode(tokens, false);
64        }
65        fn unparse_tokens_mode(&self, tokens: &mut ::std::vec::Vec<PtxToken>, spaced: bool) {
66            push_opcode(tokens, "wmma");
67            push_directive(tokens, "mma");
68            push_directive(tokens, "sync");
69            push_directive(tokens, "aligned");
70            push_directive(tokens, "alayout");
71            push_directive(tokens, "blayout");
72            push_directive(tokens, "shape");
73            push_directive(tokens, "dtype");
74            push_directive(tokens, "ctype");
75            if spaced {
76                tokens.push(PtxToken::Space);
77            }
78            self.d.unparse_tokens_mode(tokens, spaced);
79            tokens.push(PtxToken::Comma);
80            if spaced {
81                tokens.push(PtxToken::Space);
82            }
83            self.a.unparse_tokens_mode(tokens, spaced);
84            tokens.push(PtxToken::Comma);
85            if spaced {
86                tokens.push(PtxToken::Space);
87            }
88            self.b.unparse_tokens_mode(tokens, spaced);
89            tokens.push(PtxToken::Comma);
90            if spaced {
91                tokens.push(PtxToken::Space);
92            }
93            self.c.unparse_tokens_mode(tokens, spaced);
94            tokens.push(PtxToken::Semicolon);
95            if spaced {
96                tokens.push(PtxToken::Newline);
97            }
98        }
99    }
100}
101
102pub mod section_1 {
103    use super::*;
104    use crate::r#type::instruction::wmma_mma::section_1::*;
105
106    impl PtxUnparser for WmmaMmaSyncAlignedAlayoutBlayoutShapeS32AtypeBtypeS32Satfinite {
107        fn unparse_tokens(&self, tokens: &mut ::std::vec::Vec<PtxToken>) {
108            self.unparse_tokens_mode(tokens, false);
109        }
110        fn unparse_tokens_mode(&self, tokens: &mut ::std::vec::Vec<PtxToken>, spaced: bool) {
111            push_opcode(tokens, "wmma");
112            push_directive(tokens, "mma");
113            push_directive(tokens, "sync");
114            push_directive(tokens, "aligned");
115            match &self.alayout {
116                Alayout::Row => {
117                    push_directive(tokens, "row");
118                }
119                Alayout::Col => {
120                    push_directive(tokens, "col");
121                }
122            }
123            match &self.blayout {
124                Blayout::Row => {
125                    push_directive(tokens, "row");
126                }
127                Blayout::Col => {
128                    push_directive(tokens, "col");
129                }
130            }
131            match &self.shape {
132                Shape::M16n16k16 => {
133                    push_directive(tokens, "m16n16k16");
134                }
135                Shape::M8n32k16 => {
136                    push_directive(tokens, "m8n32k16");
137                }
138                Shape::M32n8k16 => {
139                    push_directive(tokens, "m32n8k16");
140                }
141            }
142            push_directive(tokens, "s32");
143            match &self.atype {
144                Atype::S8 => {
145                    push_directive(tokens, "s8");
146                }
147                Atype::U8 => {
148                    push_directive(tokens, "u8");
149                }
150            }
151            match &self.btype {
152                Btype::S8 => {
153                    push_directive(tokens, "s8");
154                }
155                Btype::U8 => {
156                    push_directive(tokens, "u8");
157                }
158            }
159            push_directive(tokens, "s32");
160            if self.satfinite {
161                push_directive(tokens, "satfinite");
162            }
163            if spaced {
164                tokens.push(PtxToken::Space);
165            }
166            self.d.unparse_tokens_mode(tokens, spaced);
167            tokens.push(PtxToken::Comma);
168            if spaced {
169                tokens.push(PtxToken::Space);
170            }
171            self.a.unparse_tokens_mode(tokens, spaced);
172            tokens.push(PtxToken::Comma);
173            if spaced {
174                tokens.push(PtxToken::Space);
175            }
176            self.b.unparse_tokens_mode(tokens, spaced);
177            tokens.push(PtxToken::Comma);
178            if spaced {
179                tokens.push(PtxToken::Space);
180            }
181            self.c.unparse_tokens_mode(tokens, spaced);
182            tokens.push(PtxToken::Semicolon);
183            if spaced {
184                tokens.push(PtxToken::Newline);
185            }
186        }
187    }
188}
189
190pub mod section_2 {
191    use super::*;
192    use crate::r#type::instruction::wmma_mma::section_2::*;
193
194    impl PtxUnparser for WmmaMmaSyncAlignedAlayoutBlayoutShapeF32AtypeBtypeF32 {
195        fn unparse_tokens(&self, tokens: &mut ::std::vec::Vec<PtxToken>) {
196            self.unparse_tokens_mode(tokens, false);
197        }
198        fn unparse_tokens_mode(&self, tokens: &mut ::std::vec::Vec<PtxToken>, spaced: bool) {
199            push_opcode(tokens, "wmma");
200            push_directive(tokens, "mma");
201            push_directive(tokens, "sync");
202            push_directive(tokens, "aligned");
203            match &self.alayout {
204                Alayout::Row => {
205                    push_directive(tokens, "row");
206                }
207                Alayout::Col => {
208                    push_directive(tokens, "col");
209                }
210            }
211            match &self.blayout {
212                Blayout::Row => {
213                    push_directive(tokens, "row");
214                }
215                Blayout::Col => {
216                    push_directive(tokens, "col");
217                }
218            }
219            match &self.shape {
220                Shape::M16n16k16 => {
221                    push_directive(tokens, "m16n16k16");
222                }
223                Shape::M8n32k16 => {
224                    push_directive(tokens, "m8n32k16");
225                }
226                Shape::M32n8k16 => {
227                    push_directive(tokens, "m32n8k16");
228                }
229            }
230            push_directive(tokens, "f32");
231            match &self.atype {
232                Atype::Bf16 => {
233                    push_directive(tokens, "bf16");
234                }
235            }
236            match &self.btype {
237                Btype::Bf16 => {
238                    push_directive(tokens, "bf16");
239                }
240            }
241            push_directive(tokens, "f32");
242            if spaced {
243                tokens.push(PtxToken::Space);
244            }
245            self.d.unparse_tokens_mode(tokens, spaced);
246            tokens.push(PtxToken::Comma);
247            if spaced {
248                tokens.push(PtxToken::Space);
249            }
250            self.a.unparse_tokens_mode(tokens, spaced);
251            tokens.push(PtxToken::Comma);
252            if spaced {
253                tokens.push(PtxToken::Space);
254            }
255            self.b.unparse_tokens_mode(tokens, spaced);
256            tokens.push(PtxToken::Comma);
257            if spaced {
258                tokens.push(PtxToken::Space);
259            }
260            self.c.unparse_tokens_mode(tokens, spaced);
261            tokens.push(PtxToken::Semicolon);
262            if spaced {
263                tokens.push(PtxToken::Newline);
264            }
265        }
266    }
267}
268
269pub mod section_3 {
270    use super::*;
271    use crate::r#type::instruction::wmma_mma::section_3::*;
272
273    impl PtxUnparser for WmmaMmaSyncAlignedAlayoutBlayoutShapeF32AtypeBtypeF321 {
274        fn unparse_tokens(&self, tokens: &mut ::std::vec::Vec<PtxToken>) {
275            self.unparse_tokens_mode(tokens, false);
276        }
277        fn unparse_tokens_mode(&self, tokens: &mut ::std::vec::Vec<PtxToken>, spaced: bool) {
278            push_opcode(tokens, "wmma");
279            push_directive(tokens, "mma");
280            push_directive(tokens, "sync");
281            push_directive(tokens, "aligned");
282            match &self.alayout {
283                Alayout::Row => {
284                    push_directive(tokens, "row");
285                }
286                Alayout::Col => {
287                    push_directive(tokens, "col");
288                }
289            }
290            match &self.blayout {
291                Blayout::Row => {
292                    push_directive(tokens, "row");
293                }
294                Blayout::Col => {
295                    push_directive(tokens, "col");
296                }
297            }
298            match &self.shape {
299                Shape::M16n16k8 => {
300                    push_directive(tokens, "m16n16k8");
301                }
302            }
303            push_directive(tokens, "f32");
304            match &self.atype {
305                Atype::Tf32 => {
306                    push_directive(tokens, "tf32");
307                }
308            }
309            match &self.btype {
310                Btype::Tf32 => {
311                    push_directive(tokens, "tf32");
312                }
313            }
314            push_directive(tokens, "f32");
315            if spaced {
316                tokens.push(PtxToken::Space);
317            }
318            self.d.unparse_tokens_mode(tokens, spaced);
319            tokens.push(PtxToken::Comma);
320            if spaced {
321                tokens.push(PtxToken::Space);
322            }
323            self.a.unparse_tokens_mode(tokens, spaced);
324            tokens.push(PtxToken::Comma);
325            if spaced {
326                tokens.push(PtxToken::Space);
327            }
328            self.b.unparse_tokens_mode(tokens, spaced);
329            tokens.push(PtxToken::Comma);
330            if spaced {
331                tokens.push(PtxToken::Space);
332            }
333            self.c.unparse_tokens_mode(tokens, spaced);
334            tokens.push(PtxToken::Semicolon);
335            if spaced {
336                tokens.push(PtxToken::Newline);
337            }
338        }
339    }
340}
341
342pub mod section_4 {
343    use super::*;
344    use crate::r#type::instruction::wmma_mma::section_4::*;
345
346    impl PtxUnparser for WmmaMmaSyncAlignedAlayoutBlayoutShapeRndF64F64F64F64 {
347        fn unparse_tokens(&self, tokens: &mut ::std::vec::Vec<PtxToken>) {
348            self.unparse_tokens_mode(tokens, false);
349        }
350        fn unparse_tokens_mode(&self, tokens: &mut ::std::vec::Vec<PtxToken>, spaced: bool) {
351            push_opcode(tokens, "wmma");
352            push_directive(tokens, "mma");
353            push_directive(tokens, "sync");
354            push_directive(tokens, "aligned");
355            match &self.alayout {
356                Alayout::Row => {
357                    push_directive(tokens, "row");
358                }
359                Alayout::Col => {
360                    push_directive(tokens, "col");
361                }
362            }
363            match &self.blayout {
364                Blayout::Row => {
365                    push_directive(tokens, "row");
366                }
367                Blayout::Col => {
368                    push_directive(tokens, "col");
369                }
370            }
371            match &self.shape {
372                Shape::M8n8k4 => {
373                    push_directive(tokens, "m8n8k4");
374                }
375            }
376            if let Some(rnd_0) = self.rnd.as_ref() {
377                match rnd_0 {
378                    Rnd::Rn => {
379                        push_directive(tokens, "rn");
380                    }
381                    Rnd::Rz => {
382                        push_directive(tokens, "rz");
383                    }
384                    Rnd::Rm => {
385                        push_directive(tokens, "rm");
386                    }
387                    Rnd::Rp => {
388                        push_directive(tokens, "rp");
389                    }
390                }
391            }
392            push_directive(tokens, "f64");
393            push_directive(tokens, "f64");
394            push_directive(tokens, "f64");
395            push_directive(tokens, "f64");
396            if spaced {
397                tokens.push(PtxToken::Space);
398            }
399            self.d.unparse_tokens_mode(tokens, spaced);
400            tokens.push(PtxToken::Comma);
401            if spaced {
402                tokens.push(PtxToken::Space);
403            }
404            self.a.unparse_tokens_mode(tokens, spaced);
405            tokens.push(PtxToken::Comma);
406            if spaced {
407                tokens.push(PtxToken::Space);
408            }
409            self.b.unparse_tokens_mode(tokens, spaced);
410            tokens.push(PtxToken::Comma);
411            if spaced {
412                tokens.push(PtxToken::Space);
413            }
414            self.c.unparse_tokens_mode(tokens, spaced);
415            tokens.push(PtxToken::Semicolon);
416            if spaced {
417                tokens.push(PtxToken::Newline);
418            }
419        }
420    }
421}
422
423pub mod section_5 {
424    use super::*;
425    use crate::r#type::instruction::wmma_mma::section_5::*;
426
427    impl PtxUnparser for WmmaMmaSyncAlignedRowColShapeS32AtypeBtypeS32Satfinite {
428        fn unparse_tokens(&self, tokens: &mut ::std::vec::Vec<PtxToken>) {
429            self.unparse_tokens_mode(tokens, false);
430        }
431        fn unparse_tokens_mode(&self, tokens: &mut ::std::vec::Vec<PtxToken>, spaced: bool) {
432            push_opcode(tokens, "wmma");
433            push_directive(tokens, "mma");
434            push_directive(tokens, "sync");
435            push_directive(tokens, "aligned");
436            push_directive(tokens, "row");
437            push_directive(tokens, "col");
438            match &self.shape {
439                Shape::M8n8k32 => {
440                    push_directive(tokens, "m8n8k32");
441                }
442            }
443            push_directive(tokens, "s32");
444            match &self.atype {
445                Atype::S4 => {
446                    push_directive(tokens, "s4");
447                }
448                Atype::U4 => {
449                    push_directive(tokens, "u4");
450                }
451            }
452            match &self.btype {
453                Btype::S4 => {
454                    push_directive(tokens, "s4");
455                }
456                Btype::U4 => {
457                    push_directive(tokens, "u4");
458                }
459            }
460            push_directive(tokens, "s32");
461            if self.satfinite {
462                push_directive(tokens, "satfinite");
463            }
464            if spaced {
465                tokens.push(PtxToken::Space);
466            }
467            self.d.unparse_tokens_mode(tokens, spaced);
468            tokens.push(PtxToken::Comma);
469            if spaced {
470                tokens.push(PtxToken::Space);
471            }
472            self.a.unparse_tokens_mode(tokens, spaced);
473            tokens.push(PtxToken::Comma);
474            if spaced {
475                tokens.push(PtxToken::Space);
476            }
477            self.b.unparse_tokens_mode(tokens, spaced);
478            tokens.push(PtxToken::Comma);
479            if spaced {
480                tokens.push(PtxToken::Space);
481            }
482            self.c.unparse_tokens_mode(tokens, spaced);
483            tokens.push(PtxToken::Semicolon);
484            if spaced {
485                tokens.push(PtxToken::Newline);
486            }
487        }
488    }
489}
490
491pub mod section_6 {
492    use super::*;
493    use crate::r#type::instruction::wmma_mma::section_6::*;
494
495    impl PtxUnparser for WmmaMmaOpPopcSyncAlignedRowColShapeS32AtypeBtypeS32 {
496        fn unparse_tokens(&self, tokens: &mut ::std::vec::Vec<PtxToken>) {
497            self.unparse_tokens_mode(tokens, false);
498        }
499        fn unparse_tokens_mode(&self, tokens: &mut ::std::vec::Vec<PtxToken>, spaced: bool) {
500            push_opcode(tokens, "wmma");
501            push_directive(tokens, "mma");
502            match &self.op {
503                Op::Xor => {
504                    push_directive(tokens, "xor");
505                }
506                Op::And => {
507                    push_directive(tokens, "and");
508                }
509            }
510            push_directive(tokens, "popc");
511            push_directive(tokens, "sync");
512            push_directive(tokens, "aligned");
513            push_directive(tokens, "row");
514            push_directive(tokens, "col");
515            match &self.shape {
516                Shape::M8n8k128 => {
517                    push_directive(tokens, "m8n8k128");
518                }
519            }
520            push_directive(tokens, "s32");
521            match &self.atype {
522                Atype::B1 => {
523                    push_directive(tokens, "b1");
524                }
525            }
526            match &self.btype {
527                Btype::B1 => {
528                    push_directive(tokens, "b1");
529                }
530            }
531            push_directive(tokens, "s32");
532            if spaced {
533                tokens.push(PtxToken::Space);
534            }
535            self.d.unparse_tokens_mode(tokens, spaced);
536            tokens.push(PtxToken::Comma);
537            if spaced {
538                tokens.push(PtxToken::Space);
539            }
540            self.a.unparse_tokens_mode(tokens, spaced);
541            tokens.push(PtxToken::Comma);
542            if spaced {
543                tokens.push(PtxToken::Space);
544            }
545            self.b.unparse_tokens_mode(tokens, spaced);
546            tokens.push(PtxToken::Comma);
547            if spaced {
548                tokens.push(PtxToken::Space);
549            }
550            self.c.unparse_tokens_mode(tokens, spaced);
551            tokens.push(PtxToken::Semicolon);
552            if spaced {
553                tokens.push(PtxToken::Newline);
554            }
555        }
556    }
557}