ptx_parser/parser/instruction/
wgmma_mma_async.rs

1//! Original PTX specification:
2//!
3//! // Half precision floating point type:
4//! wgmma.mma_async.sync.aligned.shape.dtype.f16.f16  d, a-desc, b-desc, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b;
5//! wgmma.mma_async.sync.aligned.shape.dtype.f16.f16  d, a, b-desc, scale-d, imm-scale-a, imm-scale-b, imm-trans-b;
6//! .shape   = {.m64n8k16, .m64n16k16, .m64n24k16, .m64n32k16,
7//! .m64n40k16, .m64n48k16, .m64n56k16, .m64n64k16,
8//! .m64n72k16, .m64n80k16, .m64n88k16, .m64n96k16,
9//! .m64n104k16, .m64n112k16, .m64n120k16, .m64n128k16,
10//! .m64n136k16, .m64n144k16, .m64n152k16, .m64n160k16,
11//! .m64n168k16, .m64n176k16, .m64n184k16, .m64n192k16,
12//! .m64n200k16, .m64n208k16, .m64n216k16, .m64n224k16,
13//! .m64n232k16, .m64n240k16, .m64n248k16, .m64n256k16};
14//! .dtype   = {.f16, .f32};
15//! ------------------------------------------------------------------
16//! // Alternate floating point type :
17//! // .bf16 floating point type:
18//! wgmma.mma_async.sync.aligned.shape.dtype.bf16.bf16  d, a-desc, b-desc, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b;
19//! wgmma.mma_async.sync.aligned.shape.dtype.bf16.bf16  d, a, b-desc, scale-d, imm-scale-a, imm-scale-b, imm-trans-b;
20//! .shape   = {.m64n8k16, .m64n16k16, .m64n24k16, .m64n32k16,
21//! .m64n40k16, .m64n48k16, .m64n56k16, .m64n64k16,
22//! .m64n72k16, .m64n80k16, .m64n88k16, .m64n96k16,
23//! .m64n104k16, .m64n112k16, .m64n120k16, .m64n128k16,
24//! .m64n136k16, .m64n144k16, .m64n152k16, .m64n160k16,
25//! .m64n168k16, .m64n176k16, .m64n184k16, .m64n192k16,
26//! .m64n200k16, .m64n208k16, .m64n216k16, .m64n224k16,
27//! .m64n232k16, .m64n240k16, .m64n248k16, .m64n256k16};
28//! .dtype  = {.f32};
29//! ------------------------------------------------------------------
30//! // .tf32 floating point type:
31//! wgmma.mma_async.sync.aligned.shape.dtype.tf32.tf32  d, a-desc, b-desc, scale-d, imm-scale-a, imm-scale-b;
32//! wgmma.mma_async.sync.aligned.shape.dtype.tf32.tf32  d, a, b-desc, scale-d, imm-scale-a, imm-scale-b;
33//! .shape   = {.m64n8k8, .m64n16k8, .m64n24k8, .m64n32k8,
34//! .m64n40k8, .m64n48k8, .m64n56k8, .m64n64k8,
35//! .m64n72k8, .m64n80k8, .m64n88k8, .m64n96k8,
36//! .m64n104k8, .m64n112k8, .m64n120k8, .m64n128k8,
37//! .m64n136k8, .m64n144k8, .m64n152k8, .m64n160k8,
38//! .m64n168k8, .m64n176k8, .m64n184k8, .m64n192k8,
39//! .m64n200k8, .m64n208k8, .m64n216k8, .m64n224k8,
40//! .m64n232k8, .m64n240k8, .m64n248k8, .m64n256k8};
41//! .dtype  = {.f32};
42//! ------------------------------------------------------------------
43//! // FP8 floating point type
44//! wgmma.mma_async.sync.aligned.shape.dtype.atype.btype  d, a-desc, b-desc, scale-d, imm-scale-a, imm-scale-b;
45//! wgmma.mma_async.sync.aligned.shape.dtype.atype.btype  d, a, b-desc, scale-d, imm-scale-a, imm-scale-b;
46//! .shape   = {.m64n8k32, .m64n16k32, .m64n24k32, .m64n32k32,
47//! .m64n40k32, .m64n48k32, .m64n56k32, .m64n64k32,
48//! .m64n72k32, .m64n80k32, .m64n88k32, .m64n96k32,
49//! .m64n104k32, .m64n112k32, .m64n120k32, .m64n128k32,
50//! .m64n136k32, .m64n144k32, .m64n152k32, .m64n160k32,
51//! .m64n168k32, .m64n176k32, .m64n184k32, .m64n192k32,
52//! .m64n200k32, .m64n208k32, .m64n216k32, .m64n224k32,
53//! .m64n232k32, .m64n240k32, .m64n248k32, .m64n256k32};
54//! .atype  = {.e4m3, .e5m2};
55//! .btype  = {.e4m3, .e5m2};
56//! .dtype  = {.f16, .f32};
57//! ------------------------------------------------------------------
58//! // Integer type:
59//! wgmma.mma_async.sync.aligned.shape{.satfinite}.s32.atype.btype  d, a-desc, b-desc, scale-d;
60//! wgmma.mma_async.sync.aligned.shape{.satfinite}.s32.atype.btype  d, a, b-desc, scale-d;
61//! .shape   = {.m64n8k32, .m64n16k32, .m64n24k32, .m64n32k32,
62//! .m64n48k32, .m64n64k32, .m64n80k32, .m64n96k32,
63//! .m64n112k32, .m64n128k32, .m64n144k32, .m64n160k32,
64//! .m64n176k32, .m64n192k32, .m64n208k32, .m64n224k32};
65//! .atype  = {.s8, .u8};
66//! .btype  = {.s8, .u8};
67//! ------------------------------------------------------------------
68//! // Single bit:
69//! wgmma.mma_async.sync.aligned.shape.s32.b1.b1.op.popc  d, a-desc, b-desc, scale-d;
70//! wgmma.mma_async.sync.aligned.shape.s32.b1.b1.op.popc  d, a, b-desc, scale-d;
71//! .shape   = {.m64n8k256, .m64n16k256, .m64n24k256, .m64n32k256,
72//! .m64n48k256, .m64n64k256, .m64n80k256, .m64n96k256,
73//! .m64n112k256, .m64n128k256, .m64n144k256, .m64n160k256,
74//! .m64n176k256, .m64n192k256, .m64n208k256, .m64n224k256,
75//! .m64n240k256, .m64n256k256};
76//! .op  = {.and};
77
78#![allow(unused)]
79
80use crate::parser::{
81    PtxParseError, PtxParser, PtxTokenStream, Span,
82    util::{
83        between, comma_p, directive_p, exclamation_p, lbracket_p, lparen_p, map, minus_p, optional,
84        pipe_p, rbracket_p, rparen_p, semicolon_p, sep_by, string_p, try_map,
85    },
86};
87use crate::r#type::common::*;
88use crate::{alt, ok, seq_n};
89
90pub mod section_0 {
91    use super::*;
92    use crate::r#type::instruction::wgmma_mma_async::section_0::*;
93
94    // ============================================================================
95    // Generated enum parsers
96    // ============================================================================
97
98    impl PtxParser for Dtype {
99        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
100            alt!(
101                map(string_p(".f16"), |_, _span| Dtype::F16),
102                map(string_p(".f32"), |_, _span| Dtype::F32)
103            )
104        }
105    }
106
107    impl PtxParser for Shape {
108        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
109            alt!(
110                map(string_p(".m64n104k16"), |_, _span| Shape::M64n104k16),
111                map(string_p(".m64n112k16"), |_, _span| Shape::M64n112k16),
112                map(string_p(".m64n120k16"), |_, _span| Shape::M64n120k16),
113                map(string_p(".m64n128k16"), |_, _span| Shape::M64n128k16),
114                map(string_p(".m64n136k16"), |_, _span| Shape::M64n136k16),
115                map(string_p(".m64n144k16"), |_, _span| Shape::M64n144k16),
116                map(string_p(".m64n152k16"), |_, _span| Shape::M64n152k16),
117                map(string_p(".m64n160k16"), |_, _span| Shape::M64n160k16),
118                map(string_p(".m64n168k16"), |_, _span| Shape::M64n168k16),
119                map(string_p(".m64n176k16"), |_, _span| Shape::M64n176k16),
120                map(string_p(".m64n184k16"), |_, _span| Shape::M64n184k16),
121                map(string_p(".m64n192k16"), |_, _span| Shape::M64n192k16),
122                map(string_p(".m64n200k16"), |_, _span| Shape::M64n200k16),
123                map(string_p(".m64n208k16"), |_, _span| Shape::M64n208k16),
124                map(string_p(".m64n216k16"), |_, _span| Shape::M64n216k16),
125                map(string_p(".m64n224k16"), |_, _span| Shape::M64n224k16),
126                map(string_p(".m64n232k16"), |_, _span| Shape::M64n232k16),
127                map(string_p(".m64n240k16"), |_, _span| Shape::M64n240k16),
128                map(string_p(".m64n248k16"), |_, _span| Shape::M64n248k16),
129                map(string_p(".m64n256k16"), |_, _span| Shape::M64n256k16),
130                map(string_p(".m64n16k16"), |_, _span| Shape::M64n16k16),
131                map(string_p(".m64n24k16"), |_, _span| Shape::M64n24k16),
132                map(string_p(".m64n32k16"), |_, _span| Shape::M64n32k16),
133                map(string_p(".m64n40k16"), |_, _span| Shape::M64n40k16),
134                map(string_p(".m64n48k16"), |_, _span| Shape::M64n48k16),
135                map(string_p(".m64n56k16"), |_, _span| Shape::M64n56k16),
136                map(string_p(".m64n64k16"), |_, _span| Shape::M64n64k16),
137                map(string_p(".m64n72k16"), |_, _span| Shape::M64n72k16),
138                map(string_p(".m64n80k16"), |_, _span| Shape::M64n80k16),
139                map(string_p(".m64n88k16"), |_, _span| Shape::M64n88k16),
140                map(string_p(".m64n96k16"), |_, _span| Shape::M64n96k16),
141                map(string_p(".m64n8k16"), |_, _span| Shape::M64n8k16)
142            )
143        }
144    }
145
146    impl PtxParser for WgmmaMmaAsyncSyncAlignedShapeDtypeF16F16 {
147        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
148            try_map(
149                seq_n!(
150                    string_p("wgmma"),
151                    string_p(".mma_async"),
152                    string_p(".sync"),
153                    string_p(".aligned"),
154                    Shape::parse(),
155                    Dtype::parse(),
156                    string_p(".f16"),
157                    string_p(".f16"),
158                    GeneralOperand::parse(),
159                    comma_p(),
160                    GeneralOperand::parse(),
161                    comma_p(),
162                    GeneralOperand::parse(),
163                    comma_p(),
164                    GeneralOperand::parse(),
165                    comma_p(),
166                    GeneralOperand::parse(),
167                    comma_p(),
168                    GeneralOperand::parse(),
169                    comma_p(),
170                    GeneralOperand::parse(),
171                    comma_p(),
172                    GeneralOperand::parse(),
173                    semicolon_p()
174                ),
175                |(
176                    _,
177                    mma_async,
178                    sync,
179                    aligned,
180                    shape,
181                    dtype,
182                    f16,
183                    f162,
184                    d,
185                    _,
186                    a_desc,
187                    _,
188                    b_desc,
189                    _,
190                    scale_d,
191                    _,
192                    imm_scale_a,
193                    _,
194                    imm_scale_b,
195                    _,
196                    imm_trans_a,
197                    _,
198                    imm_trans_b,
199                    _,
200                ),
201                 span| {
202                    ok!(WgmmaMmaAsyncSyncAlignedShapeDtypeF16F16 {
203                        mma_async = mma_async,
204                        sync = sync,
205                        aligned = aligned,
206                        shape = shape,
207                        dtype = dtype,
208                        f16 = f16,
209                        f162 = f162,
210                        d = d,
211                        a_desc = a_desc,
212                        b_desc = b_desc,
213                        scale_d = scale_d,
214                        imm_scale_a = imm_scale_a,
215                        imm_scale_b = imm_scale_b,
216                        imm_trans_a = imm_trans_a,
217                        imm_trans_b = imm_trans_b,
218
219                    })
220                },
221            )
222        }
223    }
224
225    impl PtxParser for WgmmaMmaAsyncSyncAlignedShapeDtypeF16F161 {
226        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
227            try_map(
228                seq_n!(
229                    string_p("wgmma"),
230                    string_p(".mma_async"),
231                    string_p(".sync"),
232                    string_p(".aligned"),
233                    Shape::parse(),
234                    Dtype::parse(),
235                    string_p(".f16"),
236                    string_p(".f16"),
237                    GeneralOperand::parse(),
238                    comma_p(),
239                    GeneralOperand::parse(),
240                    comma_p(),
241                    GeneralOperand::parse(),
242                    comma_p(),
243                    GeneralOperand::parse(),
244                    comma_p(),
245                    GeneralOperand::parse(),
246                    comma_p(),
247                    GeneralOperand::parse(),
248                    comma_p(),
249                    GeneralOperand::parse(),
250                    semicolon_p()
251                ),
252                |(
253                    _,
254                    mma_async,
255                    sync,
256                    aligned,
257                    shape,
258                    dtype,
259                    f16,
260                    f162,
261                    d,
262                    _,
263                    a,
264                    _,
265                    b_desc,
266                    _,
267                    scale_d,
268                    _,
269                    imm_scale_a,
270                    _,
271                    imm_scale_b,
272                    _,
273                    imm_trans_b,
274                    _,
275                ),
276                 span| {
277                    ok!(WgmmaMmaAsyncSyncAlignedShapeDtypeF16F161 {
278                        mma_async = mma_async,
279                        sync = sync,
280                        aligned = aligned,
281                        shape = shape,
282                        dtype = dtype,
283                        f16 = f16,
284                        f162 = f162,
285                        d = d,
286                        a = a,
287                        b_desc = b_desc,
288                        scale_d = scale_d,
289                        imm_scale_a = imm_scale_a,
290                        imm_scale_b = imm_scale_b,
291                        imm_trans_b = imm_trans_b,
292
293                    })
294                },
295            )
296        }
297    }
298}
299
300pub mod section_1 {
301    use super::*;
302    use crate::r#type::instruction::wgmma_mma_async::section_1::*;
303
304    // ============================================================================
305    // Generated enum parsers
306    // ============================================================================
307
308    impl PtxParser for Dtype {
309        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
310            alt!(map(string_p(".f32"), |_, _span| Dtype::F32))
311        }
312    }
313
314    impl PtxParser for Shape {
315        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
316            alt!(
317                map(string_p(".m64n104k16"), |_, _span| Shape::M64n104k16),
318                map(string_p(".m64n112k16"), |_, _span| Shape::M64n112k16),
319                map(string_p(".m64n120k16"), |_, _span| Shape::M64n120k16),
320                map(string_p(".m64n128k16"), |_, _span| Shape::M64n128k16),
321                map(string_p(".m64n136k16"), |_, _span| Shape::M64n136k16),
322                map(string_p(".m64n144k16"), |_, _span| Shape::M64n144k16),
323                map(string_p(".m64n152k16"), |_, _span| Shape::M64n152k16),
324                map(string_p(".m64n160k16"), |_, _span| Shape::M64n160k16),
325                map(string_p(".m64n168k16"), |_, _span| Shape::M64n168k16),
326                map(string_p(".m64n176k16"), |_, _span| Shape::M64n176k16),
327                map(string_p(".m64n184k16"), |_, _span| Shape::M64n184k16),
328                map(string_p(".m64n192k16"), |_, _span| Shape::M64n192k16),
329                map(string_p(".m64n200k16"), |_, _span| Shape::M64n200k16),
330                map(string_p(".m64n208k16"), |_, _span| Shape::M64n208k16),
331                map(string_p(".m64n216k16"), |_, _span| Shape::M64n216k16),
332                map(string_p(".m64n224k16"), |_, _span| Shape::M64n224k16),
333                map(string_p(".m64n232k16"), |_, _span| Shape::M64n232k16),
334                map(string_p(".m64n240k16"), |_, _span| Shape::M64n240k16),
335                map(string_p(".m64n248k16"), |_, _span| Shape::M64n248k16),
336                map(string_p(".m64n256k16"), |_, _span| Shape::M64n256k16),
337                map(string_p(".m64n16k16"), |_, _span| Shape::M64n16k16),
338                map(string_p(".m64n24k16"), |_, _span| Shape::M64n24k16),
339                map(string_p(".m64n32k16"), |_, _span| Shape::M64n32k16),
340                map(string_p(".m64n40k16"), |_, _span| Shape::M64n40k16),
341                map(string_p(".m64n48k16"), |_, _span| Shape::M64n48k16),
342                map(string_p(".m64n56k16"), |_, _span| Shape::M64n56k16),
343                map(string_p(".m64n64k16"), |_, _span| Shape::M64n64k16),
344                map(string_p(".m64n72k16"), |_, _span| Shape::M64n72k16),
345                map(string_p(".m64n80k16"), |_, _span| Shape::M64n80k16),
346                map(string_p(".m64n88k16"), |_, _span| Shape::M64n88k16),
347                map(string_p(".m64n96k16"), |_, _span| Shape::M64n96k16),
348                map(string_p(".m64n8k16"), |_, _span| Shape::M64n8k16)
349            )
350        }
351    }
352
353    impl PtxParser for WgmmaMmaAsyncSyncAlignedShapeDtypeBf16Bf16 {
354        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
355            try_map(
356                seq_n!(
357                    string_p("wgmma"),
358                    string_p(".mma_async"),
359                    string_p(".sync"),
360                    string_p(".aligned"),
361                    Shape::parse(),
362                    Dtype::parse(),
363                    string_p(".bf16"),
364                    string_p(".bf16"),
365                    GeneralOperand::parse(),
366                    comma_p(),
367                    GeneralOperand::parse(),
368                    comma_p(),
369                    GeneralOperand::parse(),
370                    comma_p(),
371                    GeneralOperand::parse(),
372                    comma_p(),
373                    GeneralOperand::parse(),
374                    comma_p(),
375                    GeneralOperand::parse(),
376                    comma_p(),
377                    GeneralOperand::parse(),
378                    comma_p(),
379                    GeneralOperand::parse(),
380                    semicolon_p()
381                ),
382                |(
383                    _,
384                    mma_async,
385                    sync,
386                    aligned,
387                    shape,
388                    dtype,
389                    bf16,
390                    bf162,
391                    d,
392                    _,
393                    a_desc,
394                    _,
395                    b_desc,
396                    _,
397                    scale_d,
398                    _,
399                    imm_scale_a,
400                    _,
401                    imm_scale_b,
402                    _,
403                    imm_trans_a,
404                    _,
405                    imm_trans_b,
406                    _,
407                ),
408                 span| {
409                    ok!(WgmmaMmaAsyncSyncAlignedShapeDtypeBf16Bf16 {
410                        mma_async = mma_async,
411                        sync = sync,
412                        aligned = aligned,
413                        shape = shape,
414                        dtype = dtype,
415                        bf16 = bf16,
416                        bf162 = bf162,
417                        d = d,
418                        a_desc = a_desc,
419                        b_desc = b_desc,
420                        scale_d = scale_d,
421                        imm_scale_a = imm_scale_a,
422                        imm_scale_b = imm_scale_b,
423                        imm_trans_a = imm_trans_a,
424                        imm_trans_b = imm_trans_b,
425
426                    })
427                },
428            )
429        }
430    }
431
432    impl PtxParser for WgmmaMmaAsyncSyncAlignedShapeDtypeBf16Bf161 {
433        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
434            try_map(
435                seq_n!(
436                    string_p("wgmma"),
437                    string_p(".mma_async"),
438                    string_p(".sync"),
439                    string_p(".aligned"),
440                    Shape::parse(),
441                    Dtype::parse(),
442                    string_p(".bf16"),
443                    string_p(".bf16"),
444                    GeneralOperand::parse(),
445                    comma_p(),
446                    GeneralOperand::parse(),
447                    comma_p(),
448                    GeneralOperand::parse(),
449                    comma_p(),
450                    GeneralOperand::parse(),
451                    comma_p(),
452                    GeneralOperand::parse(),
453                    comma_p(),
454                    GeneralOperand::parse(),
455                    comma_p(),
456                    GeneralOperand::parse(),
457                    semicolon_p()
458                ),
459                |(
460                    _,
461                    mma_async,
462                    sync,
463                    aligned,
464                    shape,
465                    dtype,
466                    bf16,
467                    bf162,
468                    d,
469                    _,
470                    a,
471                    _,
472                    b_desc,
473                    _,
474                    scale_d,
475                    _,
476                    imm_scale_a,
477                    _,
478                    imm_scale_b,
479                    _,
480                    imm_trans_b,
481                    _,
482                ),
483                 span| {
484                    ok!(WgmmaMmaAsyncSyncAlignedShapeDtypeBf16Bf161 {
485                        mma_async = mma_async,
486                        sync = sync,
487                        aligned = aligned,
488                        shape = shape,
489                        dtype = dtype,
490                        bf16 = bf16,
491                        bf162 = bf162,
492                        d = d,
493                        a = a,
494                        b_desc = b_desc,
495                        scale_d = scale_d,
496                        imm_scale_a = imm_scale_a,
497                        imm_scale_b = imm_scale_b,
498                        imm_trans_b = imm_trans_b,
499
500                    })
501                },
502            )
503        }
504    }
505}
506
507pub mod section_2 {
508    use super::*;
509    use crate::r#type::instruction::wgmma_mma_async::section_2::*;
510
511    // ============================================================================
512    // Generated enum parsers
513    // ============================================================================
514
515    impl PtxParser for Dtype {
516        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
517            alt!(map(string_p(".f32"), |_, _span| Dtype::F32))
518        }
519    }
520
521    impl PtxParser for Shape {
522        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
523            alt!(
524                map(string_p(".m64n104k8"), |_, _span| Shape::M64n104k8),
525                map(string_p(".m64n112k8"), |_, _span| Shape::M64n112k8),
526                map(string_p(".m64n120k8"), |_, _span| Shape::M64n120k8),
527                map(string_p(".m64n128k8"), |_, _span| Shape::M64n128k8),
528                map(string_p(".m64n136k8"), |_, _span| Shape::M64n136k8),
529                map(string_p(".m64n144k8"), |_, _span| Shape::M64n144k8),
530                map(string_p(".m64n152k8"), |_, _span| Shape::M64n152k8),
531                map(string_p(".m64n160k8"), |_, _span| Shape::M64n160k8),
532                map(string_p(".m64n168k8"), |_, _span| Shape::M64n168k8),
533                map(string_p(".m64n176k8"), |_, _span| Shape::M64n176k8),
534                map(string_p(".m64n184k8"), |_, _span| Shape::M64n184k8),
535                map(string_p(".m64n192k8"), |_, _span| Shape::M64n192k8),
536                map(string_p(".m64n200k8"), |_, _span| Shape::M64n200k8),
537                map(string_p(".m64n208k8"), |_, _span| Shape::M64n208k8),
538                map(string_p(".m64n216k8"), |_, _span| Shape::M64n216k8),
539                map(string_p(".m64n224k8"), |_, _span| Shape::M64n224k8),
540                map(string_p(".m64n232k8"), |_, _span| Shape::M64n232k8),
541                map(string_p(".m64n240k8"), |_, _span| Shape::M64n240k8),
542                map(string_p(".m64n248k8"), |_, _span| Shape::M64n248k8),
543                map(string_p(".m64n256k8"), |_, _span| Shape::M64n256k8),
544                map(string_p(".m64n16k8"), |_, _span| Shape::M64n16k8),
545                map(string_p(".m64n24k8"), |_, _span| Shape::M64n24k8),
546                map(string_p(".m64n32k8"), |_, _span| Shape::M64n32k8),
547                map(string_p(".m64n40k8"), |_, _span| Shape::M64n40k8),
548                map(string_p(".m64n48k8"), |_, _span| Shape::M64n48k8),
549                map(string_p(".m64n56k8"), |_, _span| Shape::M64n56k8),
550                map(string_p(".m64n64k8"), |_, _span| Shape::M64n64k8),
551                map(string_p(".m64n72k8"), |_, _span| Shape::M64n72k8),
552                map(string_p(".m64n80k8"), |_, _span| Shape::M64n80k8),
553                map(string_p(".m64n88k8"), |_, _span| Shape::M64n88k8),
554                map(string_p(".m64n96k8"), |_, _span| Shape::M64n96k8),
555                map(string_p(".m64n8k8"), |_, _span| Shape::M64n8k8)
556            )
557        }
558    }
559
560    impl PtxParser for WgmmaMmaAsyncSyncAlignedShapeDtypeTf32Tf32 {
561        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
562            try_map(
563                seq_n!(
564                    string_p("wgmma"),
565                    string_p(".mma_async"),
566                    string_p(".sync"),
567                    string_p(".aligned"),
568                    Shape::parse(),
569                    Dtype::parse(),
570                    string_p(".tf32"),
571                    string_p(".tf32"),
572                    GeneralOperand::parse(),
573                    comma_p(),
574                    GeneralOperand::parse(),
575                    comma_p(),
576                    GeneralOperand::parse(),
577                    comma_p(),
578                    GeneralOperand::parse(),
579                    comma_p(),
580                    GeneralOperand::parse(),
581                    comma_p(),
582                    GeneralOperand::parse(),
583                    semicolon_p()
584                ),
585                |(
586                    _,
587                    mma_async,
588                    sync,
589                    aligned,
590                    shape,
591                    dtype,
592                    tf32,
593                    tf322,
594                    d,
595                    _,
596                    a_desc,
597                    _,
598                    b_desc,
599                    _,
600                    scale_d,
601                    _,
602                    imm_scale_a,
603                    _,
604                    imm_scale_b,
605                    _,
606                ),
607                 span| {
608                    ok!(WgmmaMmaAsyncSyncAlignedShapeDtypeTf32Tf32 {
609                        mma_async = mma_async,
610                        sync = sync,
611                        aligned = aligned,
612                        shape = shape,
613                        dtype = dtype,
614                        tf32 = tf32,
615                        tf322 = tf322,
616                        d = d,
617                        a_desc = a_desc,
618                        b_desc = b_desc,
619                        scale_d = scale_d,
620                        imm_scale_a = imm_scale_a,
621                        imm_scale_b = imm_scale_b,
622
623                    })
624                },
625            )
626        }
627    }
628
629    impl PtxParser for WgmmaMmaAsyncSyncAlignedShapeDtypeTf32Tf321 {
630        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
631            try_map(
632                seq_n!(
633                    string_p("wgmma"),
634                    string_p(".mma_async"),
635                    string_p(".sync"),
636                    string_p(".aligned"),
637                    Shape::parse(),
638                    Dtype::parse(),
639                    string_p(".tf32"),
640                    string_p(".tf32"),
641                    GeneralOperand::parse(),
642                    comma_p(),
643                    GeneralOperand::parse(),
644                    comma_p(),
645                    GeneralOperand::parse(),
646                    comma_p(),
647                    GeneralOperand::parse(),
648                    comma_p(),
649                    GeneralOperand::parse(),
650                    comma_p(),
651                    GeneralOperand::parse(),
652                    semicolon_p()
653                ),
654                |(
655                    _,
656                    mma_async,
657                    sync,
658                    aligned,
659                    shape,
660                    dtype,
661                    tf32,
662                    tf322,
663                    d,
664                    _,
665                    a,
666                    _,
667                    b_desc,
668                    _,
669                    scale_d,
670                    _,
671                    imm_scale_a,
672                    _,
673                    imm_scale_b,
674                    _,
675                ),
676                 span| {
677                    ok!(WgmmaMmaAsyncSyncAlignedShapeDtypeTf32Tf321 {
678                        mma_async = mma_async,
679                        sync = sync,
680                        aligned = aligned,
681                        shape = shape,
682                        dtype = dtype,
683                        tf32 = tf32,
684                        tf322 = tf322,
685                        d = d,
686                        a = a,
687                        b_desc = b_desc,
688                        scale_d = scale_d,
689                        imm_scale_a = imm_scale_a,
690                        imm_scale_b = imm_scale_b,
691
692                    })
693                },
694            )
695        }
696    }
697}
698
699pub mod section_3 {
700    use super::*;
701    use crate::r#type::instruction::wgmma_mma_async::section_3::*;
702
703    // ============================================================================
704    // Generated enum parsers
705    // ============================================================================
706
707    impl PtxParser for Atype {
708        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
709            alt!(
710                map(string_p(".e4m3"), |_, _span| Atype::E4m3),
711                map(string_p(".e5m2"), |_, _span| Atype::E5m2)
712            )
713        }
714    }
715
716    impl PtxParser for Btype {
717        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
718            alt!(
719                map(string_p(".e4m3"), |_, _span| Btype::E4m3),
720                map(string_p(".e5m2"), |_, _span| Btype::E5m2)
721            )
722        }
723    }
724
725    impl PtxParser for Dtype {
726        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
727            alt!(
728                map(string_p(".f16"), |_, _span| Dtype::F16),
729                map(string_p(".f32"), |_, _span| Dtype::F32)
730            )
731        }
732    }
733
734    impl PtxParser for Shape {
735        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
736            alt!(
737                map(string_p(".m64n104k32"), |_, _span| Shape::M64n104k32),
738                map(string_p(".m64n112k32"), |_, _span| Shape::M64n112k32),
739                map(string_p(".m64n120k32"), |_, _span| Shape::M64n120k32),
740                map(string_p(".m64n128k32"), |_, _span| Shape::M64n128k32),
741                map(string_p(".m64n136k32"), |_, _span| Shape::M64n136k32),
742                map(string_p(".m64n144k32"), |_, _span| Shape::M64n144k32),
743                map(string_p(".m64n152k32"), |_, _span| Shape::M64n152k32),
744                map(string_p(".m64n160k32"), |_, _span| Shape::M64n160k32),
745                map(string_p(".m64n168k32"), |_, _span| Shape::M64n168k32),
746                map(string_p(".m64n176k32"), |_, _span| Shape::M64n176k32),
747                map(string_p(".m64n184k32"), |_, _span| Shape::M64n184k32),
748                map(string_p(".m64n192k32"), |_, _span| Shape::M64n192k32),
749                map(string_p(".m64n200k32"), |_, _span| Shape::M64n200k32),
750                map(string_p(".m64n208k32"), |_, _span| Shape::M64n208k32),
751                map(string_p(".m64n216k32"), |_, _span| Shape::M64n216k32),
752                map(string_p(".m64n224k32"), |_, _span| Shape::M64n224k32),
753                map(string_p(".m64n232k32"), |_, _span| Shape::M64n232k32),
754                map(string_p(".m64n240k32"), |_, _span| Shape::M64n240k32),
755                map(string_p(".m64n248k32"), |_, _span| Shape::M64n248k32),
756                map(string_p(".m64n256k32"), |_, _span| Shape::M64n256k32),
757                map(string_p(".m64n16k32"), |_, _span| Shape::M64n16k32),
758                map(string_p(".m64n24k32"), |_, _span| Shape::M64n24k32),
759                map(string_p(".m64n32k32"), |_, _span| Shape::M64n32k32),
760                map(string_p(".m64n40k32"), |_, _span| Shape::M64n40k32),
761                map(string_p(".m64n48k32"), |_, _span| Shape::M64n48k32),
762                map(string_p(".m64n56k32"), |_, _span| Shape::M64n56k32),
763                map(string_p(".m64n64k32"), |_, _span| Shape::M64n64k32),
764                map(string_p(".m64n72k32"), |_, _span| Shape::M64n72k32),
765                map(string_p(".m64n80k32"), |_, _span| Shape::M64n80k32),
766                map(string_p(".m64n88k32"), |_, _span| Shape::M64n88k32),
767                map(string_p(".m64n96k32"), |_, _span| Shape::M64n96k32),
768                map(string_p(".m64n8k32"), |_, _span| Shape::M64n8k32)
769            )
770        }
771    }
772
773    impl PtxParser for WgmmaMmaAsyncSyncAlignedShapeDtypeAtypeBtype {
774        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
775            try_map(
776                seq_n!(
777                    string_p("wgmma"),
778                    string_p(".mma_async"),
779                    string_p(".sync"),
780                    string_p(".aligned"),
781                    Shape::parse(),
782                    Dtype::parse(),
783                    Atype::parse(),
784                    Btype::parse(),
785                    GeneralOperand::parse(),
786                    comma_p(),
787                    GeneralOperand::parse(),
788                    comma_p(),
789                    GeneralOperand::parse(),
790                    comma_p(),
791                    GeneralOperand::parse(),
792                    comma_p(),
793                    GeneralOperand::parse(),
794                    comma_p(),
795                    GeneralOperand::parse(),
796                    semicolon_p()
797                ),
798                |(
799                    _,
800                    mma_async,
801                    sync,
802                    aligned,
803                    shape,
804                    dtype,
805                    atype,
806                    btype,
807                    d,
808                    _,
809                    a_desc,
810                    _,
811                    b_desc,
812                    _,
813                    scale_d,
814                    _,
815                    imm_scale_a,
816                    _,
817                    imm_scale_b,
818                    _,
819                ),
820                 span| {
821                    ok!(WgmmaMmaAsyncSyncAlignedShapeDtypeAtypeBtype {
822                        mma_async = mma_async,
823                        sync = sync,
824                        aligned = aligned,
825                        shape = shape,
826                        dtype = dtype,
827                        atype = atype,
828                        btype = btype,
829                        d = d,
830                        a_desc = a_desc,
831                        b_desc = b_desc,
832                        scale_d = scale_d,
833                        imm_scale_a = imm_scale_a,
834                        imm_scale_b = imm_scale_b,
835
836                    })
837                },
838            )
839        }
840    }
841
842    impl PtxParser for WgmmaMmaAsyncSyncAlignedShapeDtypeAtypeBtype1 {
843        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
844            try_map(
845                seq_n!(
846                    string_p("wgmma"),
847                    string_p(".mma_async"),
848                    string_p(".sync"),
849                    string_p(".aligned"),
850                    Shape::parse(),
851                    Dtype::parse(),
852                    Atype::parse(),
853                    Btype::parse(),
854                    GeneralOperand::parse(),
855                    comma_p(),
856                    GeneralOperand::parse(),
857                    comma_p(),
858                    GeneralOperand::parse(),
859                    comma_p(),
860                    GeneralOperand::parse(),
861                    comma_p(),
862                    GeneralOperand::parse(),
863                    comma_p(),
864                    GeneralOperand::parse(),
865                    semicolon_p()
866                ),
867                |(
868                    _,
869                    mma_async,
870                    sync,
871                    aligned,
872                    shape,
873                    dtype,
874                    atype,
875                    btype,
876                    d,
877                    _,
878                    a,
879                    _,
880                    b_desc,
881                    _,
882                    scale_d,
883                    _,
884                    imm_scale_a,
885                    _,
886                    imm_scale_b,
887                    _,
888                ),
889                 span| {
890                    ok!(WgmmaMmaAsyncSyncAlignedShapeDtypeAtypeBtype1 {
891                        mma_async = mma_async,
892                        sync = sync,
893                        aligned = aligned,
894                        shape = shape,
895                        dtype = dtype,
896                        atype = atype,
897                        btype = btype,
898                        d = d,
899                        a = a,
900                        b_desc = b_desc,
901                        scale_d = scale_d,
902                        imm_scale_a = imm_scale_a,
903                        imm_scale_b = imm_scale_b,
904
905                    })
906                },
907            )
908        }
909    }
910}
911
912pub mod section_4 {
913    use super::*;
914    use crate::r#type::instruction::wgmma_mma_async::section_4::*;
915
916    // ============================================================================
917    // Generated enum parsers
918    // ============================================================================
919
920    impl PtxParser for Atype {
921        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
922            alt!(
923                map(string_p(".s8"), |_, _span| Atype::S8),
924                map(string_p(".u8"), |_, _span| Atype::U8)
925            )
926        }
927    }
928
929    impl PtxParser for Btype {
930        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
931            alt!(
932                map(string_p(".s8"), |_, _span| Btype::S8),
933                map(string_p(".u8"), |_, _span| Btype::U8)
934            )
935        }
936    }
937
938    impl PtxParser for Shape {
939        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
940            alt!(
941                map(string_p(".m64n112k32"), |_, _span| Shape::M64n112k32),
942                map(string_p(".m64n128k32"), |_, _span| Shape::M64n128k32),
943                map(string_p(".m64n144k32"), |_, _span| Shape::M64n144k32),
944                map(string_p(".m64n160k32"), |_, _span| Shape::M64n160k32),
945                map(string_p(".m64n176k32"), |_, _span| Shape::M64n176k32),
946                map(string_p(".m64n192k32"), |_, _span| Shape::M64n192k32),
947                map(string_p(".m64n208k32"), |_, _span| Shape::M64n208k32),
948                map(string_p(".m64n224k32"), |_, _span| Shape::M64n224k32),
949                map(string_p(".m64n16k32"), |_, _span| Shape::M64n16k32),
950                map(string_p(".m64n24k32"), |_, _span| Shape::M64n24k32),
951                map(string_p(".m64n32k32"), |_, _span| Shape::M64n32k32),
952                map(string_p(".m64n48k32"), |_, _span| Shape::M64n48k32),
953                map(string_p(".m64n64k32"), |_, _span| Shape::M64n64k32),
954                map(string_p(".m64n80k32"), |_, _span| Shape::M64n80k32),
955                map(string_p(".m64n96k32"), |_, _span| Shape::M64n96k32),
956                map(string_p(".m64n8k32"), |_, _span| Shape::M64n8k32)
957            )
958        }
959    }
960
961    impl PtxParser for WgmmaMmaAsyncSyncAlignedShapeSatfiniteS32AtypeBtype {
962        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
963            try_map(
964                seq_n!(
965                    string_p("wgmma"),
966                    string_p(".mma_async"),
967                    string_p(".sync"),
968                    string_p(".aligned"),
969                    Shape::parse(),
970                    map(optional(string_p(".satfinite")), |value, _| value.is_some()),
971                    string_p(".s32"),
972                    Atype::parse(),
973                    Btype::parse(),
974                    GeneralOperand::parse(),
975                    comma_p(),
976                    GeneralOperand::parse(),
977                    comma_p(),
978                    GeneralOperand::parse(),
979                    comma_p(),
980                    GeneralOperand::parse(),
981                    semicolon_p()
982                ),
983                |(
984                    _,
985                    mma_async,
986                    sync,
987                    aligned,
988                    shape,
989                    satfinite,
990                    s32,
991                    atype,
992                    btype,
993                    d,
994                    _,
995                    a_desc,
996                    _,
997                    b_desc,
998                    _,
999                    scale_d,
1000                    _,
1001                ),
1002                 span| {
1003                    ok!(WgmmaMmaAsyncSyncAlignedShapeSatfiniteS32AtypeBtype {
1004                        mma_async = mma_async,
1005                        sync = sync,
1006                        aligned = aligned,
1007                        shape = shape,
1008                        satfinite = satfinite,
1009                        s32 = s32,
1010                        atype = atype,
1011                        btype = btype,
1012                        d = d,
1013                        a_desc = a_desc,
1014                        b_desc = b_desc,
1015                        scale_d = scale_d,
1016
1017                    })
1018                },
1019            )
1020        }
1021    }
1022
1023    impl PtxParser for WgmmaMmaAsyncSyncAlignedShapeSatfiniteS32AtypeBtype1 {
1024        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1025            try_map(
1026                seq_n!(
1027                    string_p("wgmma"),
1028                    string_p(".mma_async"),
1029                    string_p(".sync"),
1030                    string_p(".aligned"),
1031                    Shape::parse(),
1032                    map(optional(string_p(".satfinite")), |value, _| value.is_some()),
1033                    string_p(".s32"),
1034                    Atype::parse(),
1035                    Btype::parse(),
1036                    GeneralOperand::parse(),
1037                    comma_p(),
1038                    GeneralOperand::parse(),
1039                    comma_p(),
1040                    GeneralOperand::parse(),
1041                    comma_p(),
1042                    GeneralOperand::parse(),
1043                    semicolon_p()
1044                ),
1045                |(
1046                    _,
1047                    mma_async,
1048                    sync,
1049                    aligned,
1050                    shape,
1051                    satfinite,
1052                    s32,
1053                    atype,
1054                    btype,
1055                    d,
1056                    _,
1057                    a,
1058                    _,
1059                    b_desc,
1060                    _,
1061                    scale_d,
1062                    _,
1063                ),
1064                 span| {
1065                    ok!(WgmmaMmaAsyncSyncAlignedShapeSatfiniteS32AtypeBtype1 {
1066                        mma_async = mma_async,
1067                        sync = sync,
1068                        aligned = aligned,
1069                        shape = shape,
1070                        satfinite = satfinite,
1071                        s32 = s32,
1072                        atype = atype,
1073                        btype = btype,
1074                        d = d,
1075                        a = a,
1076                        b_desc = b_desc,
1077                        scale_d = scale_d,
1078
1079                    })
1080                },
1081            )
1082        }
1083    }
1084}
1085
1086pub mod section_5 {
1087    use super::*;
1088    use crate::r#type::instruction::wgmma_mma_async::section_5::*;
1089
1090    // ============================================================================
1091    // Generated enum parsers
1092    // ============================================================================
1093
1094    impl PtxParser for Op {
1095        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1096            alt!(map(string_p(".and"), |_, _span| Op::And))
1097        }
1098    }
1099
1100    impl PtxParser for Shape {
1101        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1102            alt!(
1103                map(string_p(".m64n112k256"), |_, _span| Shape::M64n112k256),
1104                map(string_p(".m64n128k256"), |_, _span| Shape::M64n128k256),
1105                map(string_p(".m64n144k256"), |_, _span| Shape::M64n144k256),
1106                map(string_p(".m64n160k256"), |_, _span| Shape::M64n160k256),
1107                map(string_p(".m64n176k256"), |_, _span| Shape::M64n176k256),
1108                map(string_p(".m64n192k256"), |_, _span| Shape::M64n192k256),
1109                map(string_p(".m64n208k256"), |_, _span| Shape::M64n208k256),
1110                map(string_p(".m64n224k256"), |_, _span| Shape::M64n224k256),
1111                map(string_p(".m64n240k256"), |_, _span| Shape::M64n240k256),
1112                map(string_p(".m64n256k256"), |_, _span| Shape::M64n256k256),
1113                map(string_p(".m64n16k256"), |_, _span| Shape::M64n16k256),
1114                map(string_p(".m64n24k256"), |_, _span| Shape::M64n24k256),
1115                map(string_p(".m64n32k256"), |_, _span| Shape::M64n32k256),
1116                map(string_p(".m64n48k256"), |_, _span| Shape::M64n48k256),
1117                map(string_p(".m64n64k256"), |_, _span| Shape::M64n64k256),
1118                map(string_p(".m64n80k256"), |_, _span| Shape::M64n80k256),
1119                map(string_p(".m64n96k256"), |_, _span| Shape::M64n96k256),
1120                map(string_p(".m64n8k256"), |_, _span| Shape::M64n8k256)
1121            )
1122        }
1123    }
1124
1125    impl PtxParser for WgmmaMmaAsyncSyncAlignedShapeS32B1B1OpPopc {
1126        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1127            try_map(
1128                seq_n!(
1129                    string_p("wgmma"),
1130                    string_p(".mma_async"),
1131                    string_p(".sync"),
1132                    string_p(".aligned"),
1133                    Shape::parse(),
1134                    string_p(".s32"),
1135                    string_p(".b1"),
1136                    string_p(".b1"),
1137                    Op::parse(),
1138                    string_p(".popc"),
1139                    GeneralOperand::parse(),
1140                    comma_p(),
1141                    GeneralOperand::parse(),
1142                    comma_p(),
1143                    GeneralOperand::parse(),
1144                    comma_p(),
1145                    GeneralOperand::parse(),
1146                    semicolon_p()
1147                ),
1148                |(
1149                    _,
1150                    mma_async,
1151                    sync,
1152                    aligned,
1153                    shape,
1154                    s32,
1155                    b1,
1156                    b12,
1157                    op,
1158                    popc,
1159                    d,
1160                    _,
1161                    a_desc,
1162                    _,
1163                    b_desc,
1164                    _,
1165                    scale_d,
1166                    _,
1167                ),
1168                 span| {
1169                    ok!(WgmmaMmaAsyncSyncAlignedShapeS32B1B1OpPopc {
1170                        mma_async = mma_async,
1171                        sync = sync,
1172                        aligned = aligned,
1173                        shape = shape,
1174                        s32 = s32,
1175                        b1 = b1,
1176                        b12 = b12,
1177                        op = op,
1178                        popc = popc,
1179                        d = d,
1180                        a_desc = a_desc,
1181                        b_desc = b_desc,
1182                        scale_d = scale_d,
1183
1184                    })
1185                },
1186            )
1187        }
1188    }
1189
1190    impl PtxParser for WgmmaMmaAsyncSyncAlignedShapeS32B1B1OpPopc1 {
1191        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1192            try_map(
1193                seq_n!(
1194                    string_p("wgmma"),
1195                    string_p(".mma_async"),
1196                    string_p(".sync"),
1197                    string_p(".aligned"),
1198                    Shape::parse(),
1199                    string_p(".s32"),
1200                    string_p(".b1"),
1201                    string_p(".b1"),
1202                    Op::parse(),
1203                    string_p(".popc"),
1204                    GeneralOperand::parse(),
1205                    comma_p(),
1206                    GeneralOperand::parse(),
1207                    comma_p(),
1208                    GeneralOperand::parse(),
1209                    comma_p(),
1210                    GeneralOperand::parse(),
1211                    semicolon_p()
1212                ),
1213                |(
1214                    _,
1215                    mma_async,
1216                    sync,
1217                    aligned,
1218                    shape,
1219                    s32,
1220                    b1,
1221                    b12,
1222                    op,
1223                    popc,
1224                    d,
1225                    _,
1226                    a,
1227                    _,
1228                    b_desc,
1229                    _,
1230                    scale_d,
1231                    _,
1232                ),
1233                 span| {
1234                    ok!(WgmmaMmaAsyncSyncAlignedShapeS32B1B1OpPopc1 {
1235                        mma_async = mma_async,
1236                        sync = sync,
1237                        aligned = aligned,
1238                        shape = shape,
1239                        s32 = s32,
1240                        b1 = b1,
1241                        b12 = b12,
1242                        op = op,
1243                        popc = popc,
1244                        d = d,
1245                        a = a,
1246                        b_desc = b_desc,
1247                        scale_d = scale_d,
1248
1249                    })
1250                },
1251            )
1252        }
1253    }
1254}