ptx_parser/unparser/instruction/
wmma_mma.rs

1//! Original PTX specification:
2//!
3//! // Floating point (.f16 multiplicands) wmma.mma
4//! wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype d, a, b, c;
5//! ----------------------------------------------------------------
6//! // Integer (.u8/.s8 multiplicands) wmma.mma
7//! wmma.mma.sync.aligned.alayout.blayout.shape.s32.atype.btype.s32{.satfinite} d, a, b, c;
8//! .alayout = {.row, .col};
9//! .blayout = {.row, .col};
10//! .shape  =  {.m16n16k16, .m8n32k16, .m32n8k16};
11//! .dtype   = {.f16, .f32};
12//! .atype   = {.s8, .u8};
13//! .btype   = {.s8, .u8};
14//! .ctype   = {.f16, .f32};
15//! ----------------------------------------------------------------
16//! // Floating point format .bf16 wmma.mma:
17//! wmma.mma.sync.aligned.alayout.blayout.shape.f32.atype.btype.f32 d, a, b, c;
18//! .alayout = {.row, .col};
19//! .blayout = {.row, .col};
20//! .shape   = {.m16n16k16, .m8n32k16, .m32n8k16};
21//! .atype   = {.bf16 };
22//! .btype   = {.bf16};
23//! ----------------------------------------------------------------
24//! // Floating point format .tf32 wmma.mma:
25//! wmma.mma.sync.aligned.alayout.blayout.shape.f32.atype.btype.f32 d, a, b, c;
26//! .alayout = {.row, .col};
27//! .blayout = {.row, .col};
28//! .shape   = {.m16n16k8 };
29//! .atype   = {.tf32 };
30//! .btype   = {.tf32};
31//! ----------------------------------------------------------------
32//! // Floating point Double precision wmma.mma:
33//! wmma.mma.sync.aligned.alayout.blayout.shape{.rnd}.f64.f64.f64.f64 d, a, b, c;
34//! .alayout = {.row, .col};
35//! .blayout = {.row, .col};
36//! .shape   = {.m8n8k4 };
37//! .rnd = { .rn, .rz, .rm, .rp };
38//! ----------------------------------------------------------------
39//! // Sub-byte (.u4/.s4 multiplicands) wmma.mma:
40//! wmma.mma.sync.aligned.row.col.shape.s32.atype.btype.s32{.satfinite} d, a, b, c;
41//! .shape  = {.m8n8k32};
42//! .atype  = {.s4, .u4};
43//! .btype  = {.s4, .u4};
44//! ----------------------------------------------------------------
45//! // Single-bit (.b1 multiplicands) wmma.mma:
46//! wmma.mma.op.popc.sync.aligned.row.col.shape.s32.atype.btype.s32 d, a, b, c;
47//! .shape  = {.m8n8k128};
48//! .atype  = {.b1};
49//! .btype  = {.b1};
50//! .op     = {.xor, .and};
51
52#![allow(unused)]
53
54use crate::lexer::PtxToken;
55use crate::unparser::{PtxUnparser, common::*};
56
57pub mod section_0 {
58    use super::*;
59    use crate::r#type::instruction::wmma_mma::section_0::*;
60
61    impl PtxUnparser for WmmaMmaSyncAlignedAlayoutBlayoutShapeDtypeCtype {
62        fn unparse_tokens(&self, tokens: &mut ::std::vec::Vec<PtxToken>) {
63            push_opcode(tokens, "wmma");
64                    push_directive(tokens, "mma");
65                    push_directive(tokens, "sync");
66                    push_directive(tokens, "aligned");
67                    push_directive(tokens, "alayout");
68                    push_directive(tokens, "blayout");
69                    push_directive(tokens, "shape");
70                    push_directive(tokens, "dtype");
71                    push_directive(tokens, "ctype");
72                    self.d.unparse_tokens(tokens);
73            tokens.push(PtxToken::Comma);
74                    self.a.unparse_tokens(tokens);
75            tokens.push(PtxToken::Comma);
76                    self.b.unparse_tokens(tokens);
77            tokens.push(PtxToken::Comma);
78                    self.c.unparse_tokens(tokens);
79            tokens.push(PtxToken::Semicolon);
80        }
81    }
82
83}
84
85pub mod section_1 {
86    use super::*;
87    use crate::r#type::instruction::wmma_mma::section_1::*;
88
89    impl PtxUnparser for WmmaMmaSyncAlignedAlayoutBlayoutShapeS32AtypeBtypeS32Satfinite {
90        fn unparse_tokens(&self, tokens: &mut ::std::vec::Vec<PtxToken>) {
91            push_opcode(tokens, "wmma");
92                    push_directive(tokens, "mma");
93                    push_directive(tokens, "sync");
94                    push_directive(tokens, "aligned");
95                    match &self.alayout {
96                            Alayout::Row => {
97                                    push_directive(tokens, "row");
98                            }
99                            Alayout::Col => {
100                                    push_directive(tokens, "col");
101                            }
102                    }
103                    match &self.blayout {
104                            Blayout::Row => {
105                                    push_directive(tokens, "row");
106                            }
107                            Blayout::Col => {
108                                    push_directive(tokens, "col");
109                            }
110                    }
111                    match &self.shape {
112                            Shape::M16n16k16 => {
113                                    push_directive(tokens, "m16n16k16");
114                            }
115                            Shape::M8n32k16 => {
116                                    push_directive(tokens, "m8n32k16");
117                            }
118                            Shape::M32n8k16 => {
119                                    push_directive(tokens, "m32n8k16");
120                            }
121                    }
122                    push_directive(tokens, "s32");
123                    match &self.atype {
124                            Atype::S8 => {
125                                    push_directive(tokens, "s8");
126                            }
127                            Atype::U8 => {
128                                    push_directive(tokens, "u8");
129                            }
130                    }
131                    match &self.btype {
132                            Btype::S8 => {
133                                    push_directive(tokens, "s8");
134                            }
135                            Btype::U8 => {
136                                    push_directive(tokens, "u8");
137                            }
138                    }
139                    push_directive(tokens, "s32");
140                    if self.satfinite {
141                            push_directive(tokens, "satfinite");
142                    }
143                    self.d.unparse_tokens(tokens);
144            tokens.push(PtxToken::Comma);
145                    self.a.unparse_tokens(tokens);
146            tokens.push(PtxToken::Comma);
147                    self.b.unparse_tokens(tokens);
148            tokens.push(PtxToken::Comma);
149                    self.c.unparse_tokens(tokens);
150            tokens.push(PtxToken::Semicolon);
151        }
152    }
153
154}
155
156pub mod section_2 {
157    use super::*;
158    use crate::r#type::instruction::wmma_mma::section_2::*;
159
160    impl PtxUnparser for WmmaMmaSyncAlignedAlayoutBlayoutShapeF32AtypeBtypeF32 {
161        fn unparse_tokens(&self, tokens: &mut ::std::vec::Vec<PtxToken>) {
162            push_opcode(tokens, "wmma");
163                    push_directive(tokens, "mma");
164                    push_directive(tokens, "sync");
165                    push_directive(tokens, "aligned");
166                    match &self.alayout {
167                            Alayout::Row => {
168                                    push_directive(tokens, "row");
169                            }
170                            Alayout::Col => {
171                                    push_directive(tokens, "col");
172                            }
173                    }
174                    match &self.blayout {
175                            Blayout::Row => {
176                                    push_directive(tokens, "row");
177                            }
178                            Blayout::Col => {
179                                    push_directive(tokens, "col");
180                            }
181                    }
182                    match &self.shape {
183                            Shape::M16n16k16 => {
184                                    push_directive(tokens, "m16n16k16");
185                            }
186                            Shape::M8n32k16 => {
187                                    push_directive(tokens, "m8n32k16");
188                            }
189                            Shape::M32n8k16 => {
190                                    push_directive(tokens, "m32n8k16");
191                            }
192                    }
193                    push_directive(tokens, "f32");
194                    match &self.atype {
195                            Atype::Bf16 => {
196                                    push_directive(tokens, "bf16");
197                            }
198                    }
199                    match &self.btype {
200                            Btype::Bf16 => {
201                                    push_directive(tokens, "bf16");
202                            }
203                    }
204                    push_directive(tokens, "f32");
205                    self.d.unparse_tokens(tokens);
206            tokens.push(PtxToken::Comma);
207                    self.a.unparse_tokens(tokens);
208            tokens.push(PtxToken::Comma);
209                    self.b.unparse_tokens(tokens);
210            tokens.push(PtxToken::Comma);
211                    self.c.unparse_tokens(tokens);
212            tokens.push(PtxToken::Semicolon);
213        }
214    }
215
216}
217
218pub mod section_3 {
219    use super::*;
220    use crate::r#type::instruction::wmma_mma::section_3::*;
221
222    impl PtxUnparser for WmmaMmaSyncAlignedAlayoutBlayoutShapeF32AtypeBtypeF321 {
223        fn unparse_tokens(&self, tokens: &mut ::std::vec::Vec<PtxToken>) {
224            push_opcode(tokens, "wmma");
225                    push_directive(tokens, "mma");
226                    push_directive(tokens, "sync");
227                    push_directive(tokens, "aligned");
228                    match &self.alayout {
229                            Alayout::Row => {
230                                    push_directive(tokens, "row");
231                            }
232                            Alayout::Col => {
233                                    push_directive(tokens, "col");
234                            }
235                    }
236                    match &self.blayout {
237                            Blayout::Row => {
238                                    push_directive(tokens, "row");
239                            }
240                            Blayout::Col => {
241                                    push_directive(tokens, "col");
242                            }
243                    }
244                    match &self.shape {
245                            Shape::M16n16k8 => {
246                                    push_directive(tokens, "m16n16k8");
247                            }
248                    }
249                    push_directive(tokens, "f32");
250                    match &self.atype {
251                            Atype::Tf32 => {
252                                    push_directive(tokens, "tf32");
253                            }
254                    }
255                    match &self.btype {
256                            Btype::Tf32 => {
257                                    push_directive(tokens, "tf32");
258                            }
259                    }
260                    push_directive(tokens, "f32");
261                    self.d.unparse_tokens(tokens);
262            tokens.push(PtxToken::Comma);
263                    self.a.unparse_tokens(tokens);
264            tokens.push(PtxToken::Comma);
265                    self.b.unparse_tokens(tokens);
266            tokens.push(PtxToken::Comma);
267                    self.c.unparse_tokens(tokens);
268            tokens.push(PtxToken::Semicolon);
269        }
270    }
271
272}
273
274pub mod section_4 {
275    use super::*;
276    use crate::r#type::instruction::wmma_mma::section_4::*;
277
278    impl PtxUnparser for WmmaMmaSyncAlignedAlayoutBlayoutShapeRndF64F64F64F64 {
279        fn unparse_tokens(&self, tokens: &mut ::std::vec::Vec<PtxToken>) {
280            push_opcode(tokens, "wmma");
281                    push_directive(tokens, "mma");
282                    push_directive(tokens, "sync");
283                    push_directive(tokens, "aligned");
284                    match &self.alayout {
285                            Alayout::Row => {
286                                    push_directive(tokens, "row");
287                            }
288                            Alayout::Col => {
289                                    push_directive(tokens, "col");
290                            }
291                    }
292                    match &self.blayout {
293                            Blayout::Row => {
294                                    push_directive(tokens, "row");
295                            }
296                            Blayout::Col => {
297                                    push_directive(tokens, "col");
298                            }
299                    }
300                    match &self.shape {
301                            Shape::M8n8k4 => {
302                                    push_directive(tokens, "m8n8k4");
303                            }
304                    }
305                    if let Some(rnd_0) = self.rnd.as_ref() {
306                            match rnd_0 {
307                                    Rnd::Rn => {
308                                            push_directive(tokens, "rn");
309                                    }
310                                    Rnd::Rz => {
311                                            push_directive(tokens, "rz");
312                                    }
313                                    Rnd::Rm => {
314                                            push_directive(tokens, "rm");
315                                    }
316                                    Rnd::Rp => {
317                                            push_directive(tokens, "rp");
318                                    }
319                            }
320                    }
321                    push_directive(tokens, "f64");
322                    push_directive(tokens, "f64");
323                    push_directive(tokens, "f64");
324                    push_directive(tokens, "f64");
325                    self.d.unparse_tokens(tokens);
326            tokens.push(PtxToken::Comma);
327                    self.a.unparse_tokens(tokens);
328            tokens.push(PtxToken::Comma);
329                    self.b.unparse_tokens(tokens);
330            tokens.push(PtxToken::Comma);
331                    self.c.unparse_tokens(tokens);
332            tokens.push(PtxToken::Semicolon);
333        }
334    }
335
336}
337
338pub mod section_5 {
339    use super::*;
340    use crate::r#type::instruction::wmma_mma::section_5::*;
341
342    impl PtxUnparser for WmmaMmaSyncAlignedRowColShapeS32AtypeBtypeS32Satfinite {
343        fn unparse_tokens(&self, tokens: &mut ::std::vec::Vec<PtxToken>) {
344            push_opcode(tokens, "wmma");
345                    push_directive(tokens, "mma");
346                    push_directive(tokens, "sync");
347                    push_directive(tokens, "aligned");
348                    push_directive(tokens, "row");
349                    push_directive(tokens, "col");
350                    match &self.shape {
351                            Shape::M8n8k32 => {
352                                    push_directive(tokens, "m8n8k32");
353                            }
354                    }
355                    push_directive(tokens, "s32");
356                    match &self.atype {
357                            Atype::S4 => {
358                                    push_directive(tokens, "s4");
359                            }
360                            Atype::U4 => {
361                                    push_directive(tokens, "u4");
362                            }
363                    }
364                    match &self.btype {
365                            Btype::S4 => {
366                                    push_directive(tokens, "s4");
367                            }
368                            Btype::U4 => {
369                                    push_directive(tokens, "u4");
370                            }
371                    }
372                    push_directive(tokens, "s32");
373                    if self.satfinite {
374                            push_directive(tokens, "satfinite");
375                    }
376                    self.d.unparse_tokens(tokens);
377            tokens.push(PtxToken::Comma);
378                    self.a.unparse_tokens(tokens);
379            tokens.push(PtxToken::Comma);
380                    self.b.unparse_tokens(tokens);
381            tokens.push(PtxToken::Comma);
382                    self.c.unparse_tokens(tokens);
383            tokens.push(PtxToken::Semicolon);
384        }
385    }
386
387}
388
389pub mod section_6 {
390    use super::*;
391    use crate::r#type::instruction::wmma_mma::section_6::*;
392
393    impl PtxUnparser for WmmaMmaOpPopcSyncAlignedRowColShapeS32AtypeBtypeS32 {
394        fn unparse_tokens(&self, tokens: &mut ::std::vec::Vec<PtxToken>) {
395            push_opcode(tokens, "wmma");
396                    push_directive(tokens, "mma");
397                    match &self.op {
398                            Op::Xor => {
399                                    push_directive(tokens, "xor");
400                            }
401                            Op::And => {
402                                    push_directive(tokens, "and");
403                            }
404                    }
405                    push_directive(tokens, "popc");
406                    push_directive(tokens, "sync");
407                    push_directive(tokens, "aligned");
408                    push_directive(tokens, "row");
409                    push_directive(tokens, "col");
410                    match &self.shape {
411                            Shape::M8n8k128 => {
412                                    push_directive(tokens, "m8n8k128");
413                            }
414                    }
415                    push_directive(tokens, "s32");
416                    match &self.atype {
417                            Atype::B1 => {
418                                    push_directive(tokens, "b1");
419                            }
420                    }
421                    match &self.btype {
422                            Btype::B1 => {
423                                    push_directive(tokens, "b1");
424                            }
425                    }
426                    push_directive(tokens, "s32");
427                    self.d.unparse_tokens(tokens);
428            tokens.push(PtxToken::Comma);
429                    self.a.unparse_tokens(tokens);
430            tokens.push(PtxToken::Comma);
431                    self.b.unparse_tokens(tokens);
432            tokens.push(PtxToken::Comma);
433                    self.c.unparse_tokens(tokens);
434            tokens.push(PtxToken::Semicolon);
435        }
436    }
437
438}
439