// ptx_parser/parser/instruction/wgmma_mma_async_sp.rs

//! Original PTX specification:
//!
//! // Half precision floating point type:
//! wgmma.mma_async.sp.sync.aligned.shape.dtype.f16.f16  d, a-desc, b-desc, sp-meta, sp-sel, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b;
//! wgmma.mma_async.sp.sync.aligned.shape.dtype.f16.f16  d, a, b-desc, sp-meta, sp-sel, scale-d, imm-scale-a, imm-scale-b, imm-trans-b;
//! .shape   = {.m64n8k32, .m64n16k32, .m64n24k32, .m64n32k32,
//! .m64n40k32, .m64n48k32, .m64n56k32, .m64n64k32,
//! .m64n72k32, .m64n80k32, .m64n88k32, .m64n96k32,
//! .m64n104k32, .m64n112k32, .m64n120k32, .m64n128k32,
//! .m64n136k32, .m64n144k32, .m64n152k32, .m64n160k32,
//! .m64n168k32, .m64n176k32, .m64n184k32, .m64n192k32,
//! .m64n200k32, .m64n208k32, .m64n216k32, .m64n224k32,
//! .m64n232k32, .m64n240k32, .m64n248k32, .m64n256k32};
//! .dtype   = {.f16, .f32};
//! ------------------------------------------------------------------
//! // Alternate floating point type :
//! // .bf16 floating point type:
//! wgmma.mma_async.sp.sync.aligned.shape.dtype.bf16.bf16  d, a-desc, b-desc, sp-meta, sp-sel, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b;
//! wgmma.mma_async.sp.sync.aligned.shape.dtype.bf16.bf16  d, a, b-desc, sp-meta, sp-sel, scale-d, imm-scale-a, imm-scale-b, imm-trans-b;
//! .shape   = {.m64n8k32, .m64n16k32, .m64n24k32, .m64n32k32,
//! .m64n40k32, .m64n48k32, .m64n56k32, .m64n64k32,
//! .m64n72k32, .m64n80k32, .m64n88k32, .m64n96k32,
//! .m64n104k32, .m64n112k32, .m64n120k32, .m64n128k32,
//! .m64n136k32, .m64n144k32, .m64n152k32, .m64n160k32,
//! .m64n168k32, .m64n176k32, .m64n184k32, .m64n192k32,
//! .m64n200k32, .m64n208k32, .m64n216k32, .m64n224k32,
//! .m64n232k32, .m64n240k32, .m64n248k32, .m64n256k32};
//! .dtype  = {.f32};
//! ------------------------------------------------------------------
//! // .tf32 floating point type:
//! wgmma.mma_async.sp.sync.aligned.shape.dtype.tf32.tf32  d, a-desc, b-desc, sp-meta, sp-sel, scale-d, imm-scale-a, imm-scale-b;
//! wgmma.mma_async.sp.sync.aligned.shape.dtype.tf32.tf32  d, a, b-desc, sp-meta, sp-sel, scale-d, imm-scale-a, imm-scale-b;
//! .shape   = {.m64n8k16, .m64n16k16, .m64n24k16, .m64n32k16,
//! .m64n40k16, .m64n48k16, .m64n56k16, .m64n64k16,
//! .m64n72k16, .m64n80k16, .m64n88k16, .m64n96k16,
//! .m64n104k16, .m64n112k16, .m64n120k16, .m64n128k16,
//! .m64n136k16, .m64n144k16, .m64n152k16, .m64n160k16,
//! .m64n168k16, .m64n176k16, .m64n184k16, .m64n192k16,
//! .m64n200k16, .m64n208k16, .m64n216k16, .m64n224k16,
//! .m64n232k16, .m64n240k16, .m64n248k16, .m64n256k16};
//! .dtype  = {.f32};
//! ------------------------------------------------------------------
//! // FP8 floating point type
//! wgmma.mma_async.sp.sync.aligned.shape.dtype.atype.btype  d, a-desc, b-desc, sp-meta, sp-sel, scale-d, imm-scale-a, imm-scale-b;
//! wgmma.mma_async.sp.sync.aligned.shape.dtype.atype.btype  d, a, b-desc, sp-meta, sp-sel, scale-d, imm-scale-a, imm-scale-b;
//! .shape   = {.m64n8k64, .m64n16k64, .m64n24k64, .m64n32k64,
//! .m64n40k64, .m64n48k64, .m64n56k64, .m64n64k64,
//! .m64n72k64, .m64n80k64, .m64n88k64, .m64n96k64,
//! .m64n104k64, .m64n112k64, .m64n120k64, .m64n128k64,
//! .m64n136k64, .m64n144k64, .m64n152k64, .m64n160k64,
//! .m64n168k64, .m64n176k64, .m64n184k64, .m64n192k64,
//! .m64n200k64, .m64n208k64, .m64n216k64, .m64n224k64,
//! .m64n232k64, .m64n240k64, .m64n248k64, .m64n256k64};
//! .atype  = {.e4m3, .e5m2};
//! .btype  = {.e4m3, .e5m2};
//! .dtype  = {.f16, .f32};
//! ------------------------------------------------------------------
//! // Integer type:
//! wgmma.mma_async.sp.sync.aligned.shape{.satfinite}.s32.atype.btype  d, a-desc, b-desc, sp-meta, sp-sel, scale-d;
//! wgmma.mma_async.sp.sync.aligned.shape{.satfinite}.s32.atype.btype  d, a, b-desc, sp-meta, sp-sel, scale-d;
//! .shape   = {.m64n8k64, .m64n16k64, .m64n24k64, .m64n32k64,
//! .m64n48k64, .m64n64k64, .m64n80k64, .m64n96k64,
//! .m64n112k64, .m64n128k64, .m64n144k64, .m64n160k64,
//! .m64n176k64, .m64n192k64, .m64n208k64, .m64n224k64,
//! .m64n240k64, .m64n256k64};
//! .atype  = {.s8, .u8};
//! .btype  = {.s8, .u8};
68
69#![allow(unused)]
70
71use crate::parser::{
72    PtxParseError, PtxParser, PtxTokenStream, Span,
73    util::{
74        between, comma_p, directive_p, exclamation_p, lbracket_p, lparen_p, map, minus_p, optional,
75        pipe_p, rbracket_p, rparen_p, semicolon_p, sep_by, string_p, try_map,
76    },
77};
78use crate::r#type::common::*;
79use crate::{alt, ok, seq_n};
80
81pub mod section_0 {
82    use super::*;
83    use crate::r#type::instruction::wgmma_mma_async_sp::section_0::*;
84
85    // ============================================================================
86    // Generated enum parsers
87    // ============================================================================
88
89    impl PtxParser for Dtype {
90        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
91            alt!(
92                map(string_p(".f16"), |_, _span| Dtype::F16),
93                map(string_p(".f32"), |_, _span| Dtype::F32)
94            )
95        }
96    }
97
98    impl PtxParser for Shape {
99        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
100            alt!(
101                map(string_p(".m64n104k32"), |_, _span| Shape::M64n104k32),
102                map(string_p(".m64n112k32"), |_, _span| Shape::M64n112k32),
103                map(string_p(".m64n120k32"), |_, _span| Shape::M64n120k32),
104                map(string_p(".m64n128k32"), |_, _span| Shape::M64n128k32),
105                map(string_p(".m64n136k32"), |_, _span| Shape::M64n136k32),
106                map(string_p(".m64n144k32"), |_, _span| Shape::M64n144k32),
107                map(string_p(".m64n152k32"), |_, _span| Shape::M64n152k32),
108                map(string_p(".m64n160k32"), |_, _span| Shape::M64n160k32),
109                map(string_p(".m64n168k32"), |_, _span| Shape::M64n168k32),
110                map(string_p(".m64n176k32"), |_, _span| Shape::M64n176k32),
111                map(string_p(".m64n184k32"), |_, _span| Shape::M64n184k32),
112                map(string_p(".m64n192k32"), |_, _span| Shape::M64n192k32),
113                map(string_p(".m64n200k32"), |_, _span| Shape::M64n200k32),
114                map(string_p(".m64n208k32"), |_, _span| Shape::M64n208k32),
115                map(string_p(".m64n216k32"), |_, _span| Shape::M64n216k32),
116                map(string_p(".m64n224k32"), |_, _span| Shape::M64n224k32),
117                map(string_p(".m64n232k32"), |_, _span| Shape::M64n232k32),
118                map(string_p(".m64n240k32"), |_, _span| Shape::M64n240k32),
119                map(string_p(".m64n248k32"), |_, _span| Shape::M64n248k32),
120                map(string_p(".m64n256k32"), |_, _span| Shape::M64n256k32),
121                map(string_p(".m64n16k32"), |_, _span| Shape::M64n16k32),
122                map(string_p(".m64n24k32"), |_, _span| Shape::M64n24k32),
123                map(string_p(".m64n32k32"), |_, _span| Shape::M64n32k32),
124                map(string_p(".m64n40k32"), |_, _span| Shape::M64n40k32),
125                map(string_p(".m64n48k32"), |_, _span| Shape::M64n48k32),
126                map(string_p(".m64n56k32"), |_, _span| Shape::M64n56k32),
127                map(string_p(".m64n64k32"), |_, _span| Shape::M64n64k32),
128                map(string_p(".m64n72k32"), |_, _span| Shape::M64n72k32),
129                map(string_p(".m64n80k32"), |_, _span| Shape::M64n80k32),
130                map(string_p(".m64n88k32"), |_, _span| Shape::M64n88k32),
131                map(string_p(".m64n96k32"), |_, _span| Shape::M64n96k32),
132                map(string_p(".m64n8k32"), |_, _span| Shape::M64n8k32)
133            )
134        }
135    }
136
137    impl PtxParser for WgmmaMmaAsyncSpSyncAlignedShapeDtypeF16F16 {
138        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
139            try_map(
140                seq_n!(
141                    string_p("wgmma"),
142                    string_p(".mma_async"),
143                    string_p(".sp"),
144                    string_p(".sync"),
145                    string_p(".aligned"),
146                    Shape::parse(),
147                    Dtype::parse(),
148                    string_p(".f16"),
149                    string_p(".f16"),
150                    GeneralOperand::parse(),
151                    comma_p(),
152                    GeneralOperand::parse(),
153                    comma_p(),
154                    GeneralOperand::parse(),
155                    comma_p(),
156                    GeneralOperand::parse(),
157                    comma_p(),
158                    GeneralOperand::parse(),
159                    comma_p(),
160                    GeneralOperand::parse(),
161                    comma_p(),
162                    GeneralOperand::parse(),
163                    comma_p(),
164                    GeneralOperand::parse(),
165                    comma_p(),
166                    GeneralOperand::parse(),
167                    comma_p(),
168                    GeneralOperand::parse(),
169                    semicolon_p()
170                ),
171                |(
172                    _,
173                    mma_async,
174                    sp,
175                    sync,
176                    aligned,
177                    shape,
178                    dtype,
179                    f16,
180                    f162,
181                    d,
182                    _,
183                    a_desc,
184                    _,
185                    b_desc,
186                    _,
187                    sp_meta,
188                    _,
189                    sp_sel,
190                    _,
191                    scale_d,
192                    _,
193                    imm_scale_a,
194                    _,
195                    imm_scale_b,
196                    _,
197                    imm_trans_a,
198                    _,
199                    imm_trans_b,
200                    _,
201                ),
202                 span| {
203                    ok!(WgmmaMmaAsyncSpSyncAlignedShapeDtypeF16F16 {
204                        mma_async = mma_async,
205                        sp = sp,
206                        sync = sync,
207                        aligned = aligned,
208                        shape = shape,
209                        dtype = dtype,
210                        f16 = f16,
211                        f162 = f162,
212                        d = d,
213                        a_desc = a_desc,
214                        b_desc = b_desc,
215                        sp_meta = sp_meta,
216                        sp_sel = sp_sel,
217                        scale_d = scale_d,
218                        imm_scale_a = imm_scale_a,
219                        imm_scale_b = imm_scale_b,
220                        imm_trans_a = imm_trans_a,
221                        imm_trans_b = imm_trans_b,
222
223                    })
224                },
225            )
226        }
227    }
228
229    impl PtxParser for WgmmaMmaAsyncSpSyncAlignedShapeDtypeF16F161 {
230        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
231            try_map(
232                seq_n!(
233                    string_p("wgmma"),
234                    string_p(".mma_async"),
235                    string_p(".sp"),
236                    string_p(".sync"),
237                    string_p(".aligned"),
238                    Shape::parse(),
239                    Dtype::parse(),
240                    string_p(".f16"),
241                    string_p(".f16"),
242                    GeneralOperand::parse(),
243                    comma_p(),
244                    GeneralOperand::parse(),
245                    comma_p(),
246                    GeneralOperand::parse(),
247                    comma_p(),
248                    GeneralOperand::parse(),
249                    comma_p(),
250                    GeneralOperand::parse(),
251                    comma_p(),
252                    GeneralOperand::parse(),
253                    comma_p(),
254                    GeneralOperand::parse(),
255                    comma_p(),
256                    GeneralOperand::parse(),
257                    comma_p(),
258                    GeneralOperand::parse(),
259                    semicolon_p()
260                ),
261                |(
262                    _,
263                    mma_async,
264                    sp,
265                    sync,
266                    aligned,
267                    shape,
268                    dtype,
269                    f16,
270                    f162,
271                    d,
272                    _,
273                    a,
274                    _,
275                    b_desc,
276                    _,
277                    sp_meta,
278                    _,
279                    sp_sel,
280                    _,
281                    scale_d,
282                    _,
283                    imm_scale_a,
284                    _,
285                    imm_scale_b,
286                    _,
287                    imm_trans_b,
288                    _,
289                ),
290                 span| {
291                    ok!(WgmmaMmaAsyncSpSyncAlignedShapeDtypeF16F161 {
292                        mma_async = mma_async,
293                        sp = sp,
294                        sync = sync,
295                        aligned = aligned,
296                        shape = shape,
297                        dtype = dtype,
298                        f16 = f16,
299                        f162 = f162,
300                        d = d,
301                        a = a,
302                        b_desc = b_desc,
303                        sp_meta = sp_meta,
304                        sp_sel = sp_sel,
305                        scale_d = scale_d,
306                        imm_scale_a = imm_scale_a,
307                        imm_scale_b = imm_scale_b,
308                        imm_trans_b = imm_trans_b,
309
310                    })
311                },
312            )
313        }
314    }
315}
316
317pub mod section_1 {
318    use super::*;
319    use crate::r#type::instruction::wgmma_mma_async_sp::section_1::*;
320
321    // ============================================================================
322    // Generated enum parsers
323    // ============================================================================
324
325    impl PtxParser for Dtype {
326        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
327            alt!(map(string_p(".f32"), |_, _span| Dtype::F32))
328        }
329    }
330
331    impl PtxParser for Shape {
332        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
333            alt!(
334                map(string_p(".m64n104k32"), |_, _span| Shape::M64n104k32),
335                map(string_p(".m64n112k32"), |_, _span| Shape::M64n112k32),
336                map(string_p(".m64n120k32"), |_, _span| Shape::M64n120k32),
337                map(string_p(".m64n128k32"), |_, _span| Shape::M64n128k32),
338                map(string_p(".m64n136k32"), |_, _span| Shape::M64n136k32),
339                map(string_p(".m64n144k32"), |_, _span| Shape::M64n144k32),
340                map(string_p(".m64n152k32"), |_, _span| Shape::M64n152k32),
341                map(string_p(".m64n160k32"), |_, _span| Shape::M64n160k32),
342                map(string_p(".m64n168k32"), |_, _span| Shape::M64n168k32),
343                map(string_p(".m64n176k32"), |_, _span| Shape::M64n176k32),
344                map(string_p(".m64n184k32"), |_, _span| Shape::M64n184k32),
345                map(string_p(".m64n192k32"), |_, _span| Shape::M64n192k32),
346                map(string_p(".m64n200k32"), |_, _span| Shape::M64n200k32),
347                map(string_p(".m64n208k32"), |_, _span| Shape::M64n208k32),
348                map(string_p(".m64n216k32"), |_, _span| Shape::M64n216k32),
349                map(string_p(".m64n224k32"), |_, _span| Shape::M64n224k32),
350                map(string_p(".m64n232k32"), |_, _span| Shape::M64n232k32),
351                map(string_p(".m64n240k32"), |_, _span| Shape::M64n240k32),
352                map(string_p(".m64n248k32"), |_, _span| Shape::M64n248k32),
353                map(string_p(".m64n256k32"), |_, _span| Shape::M64n256k32),
354                map(string_p(".m64n16k32"), |_, _span| Shape::M64n16k32),
355                map(string_p(".m64n24k32"), |_, _span| Shape::M64n24k32),
356                map(string_p(".m64n32k32"), |_, _span| Shape::M64n32k32),
357                map(string_p(".m64n40k32"), |_, _span| Shape::M64n40k32),
358                map(string_p(".m64n48k32"), |_, _span| Shape::M64n48k32),
359                map(string_p(".m64n56k32"), |_, _span| Shape::M64n56k32),
360                map(string_p(".m64n64k32"), |_, _span| Shape::M64n64k32),
361                map(string_p(".m64n72k32"), |_, _span| Shape::M64n72k32),
362                map(string_p(".m64n80k32"), |_, _span| Shape::M64n80k32),
363                map(string_p(".m64n88k32"), |_, _span| Shape::M64n88k32),
364                map(string_p(".m64n96k32"), |_, _span| Shape::M64n96k32),
365                map(string_p(".m64n8k32"), |_, _span| Shape::M64n8k32)
366            )
367        }
368    }
369
370    impl PtxParser for WgmmaMmaAsyncSpSyncAlignedShapeDtypeBf16Bf16 {
371        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
372            try_map(
373                seq_n!(
374                    string_p("wgmma"),
375                    string_p(".mma_async"),
376                    string_p(".sp"),
377                    string_p(".sync"),
378                    string_p(".aligned"),
379                    Shape::parse(),
380                    Dtype::parse(),
381                    string_p(".bf16"),
382                    string_p(".bf16"),
383                    GeneralOperand::parse(),
384                    comma_p(),
385                    GeneralOperand::parse(),
386                    comma_p(),
387                    GeneralOperand::parse(),
388                    comma_p(),
389                    GeneralOperand::parse(),
390                    comma_p(),
391                    GeneralOperand::parse(),
392                    comma_p(),
393                    GeneralOperand::parse(),
394                    comma_p(),
395                    GeneralOperand::parse(),
396                    comma_p(),
397                    GeneralOperand::parse(),
398                    comma_p(),
399                    GeneralOperand::parse(),
400                    comma_p(),
401                    GeneralOperand::parse(),
402                    semicolon_p()
403                ),
404                |(
405                    _,
406                    mma_async,
407                    sp,
408                    sync,
409                    aligned,
410                    shape,
411                    dtype,
412                    bf16,
413                    bf162,
414                    d,
415                    _,
416                    a_desc,
417                    _,
418                    b_desc,
419                    _,
420                    sp_meta,
421                    _,
422                    sp_sel,
423                    _,
424                    scale_d,
425                    _,
426                    imm_scale_a,
427                    _,
428                    imm_scale_b,
429                    _,
430                    imm_trans_a,
431                    _,
432                    imm_trans_b,
433                    _,
434                ),
435                 span| {
436                    ok!(WgmmaMmaAsyncSpSyncAlignedShapeDtypeBf16Bf16 {
437                        mma_async = mma_async,
438                        sp = sp,
439                        sync = sync,
440                        aligned = aligned,
441                        shape = shape,
442                        dtype = dtype,
443                        bf16 = bf16,
444                        bf162 = bf162,
445                        d = d,
446                        a_desc = a_desc,
447                        b_desc = b_desc,
448                        sp_meta = sp_meta,
449                        sp_sel = sp_sel,
450                        scale_d = scale_d,
451                        imm_scale_a = imm_scale_a,
452                        imm_scale_b = imm_scale_b,
453                        imm_trans_a = imm_trans_a,
454                        imm_trans_b = imm_trans_b,
455
456                    })
457                },
458            )
459        }
460    }
461
462    impl PtxParser for WgmmaMmaAsyncSpSyncAlignedShapeDtypeBf16Bf161 {
463        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
464            try_map(
465                seq_n!(
466                    string_p("wgmma"),
467                    string_p(".mma_async"),
468                    string_p(".sp"),
469                    string_p(".sync"),
470                    string_p(".aligned"),
471                    Shape::parse(),
472                    Dtype::parse(),
473                    string_p(".bf16"),
474                    string_p(".bf16"),
475                    GeneralOperand::parse(),
476                    comma_p(),
477                    GeneralOperand::parse(),
478                    comma_p(),
479                    GeneralOperand::parse(),
480                    comma_p(),
481                    GeneralOperand::parse(),
482                    comma_p(),
483                    GeneralOperand::parse(),
484                    comma_p(),
485                    GeneralOperand::parse(),
486                    comma_p(),
487                    GeneralOperand::parse(),
488                    comma_p(),
489                    GeneralOperand::parse(),
490                    comma_p(),
491                    GeneralOperand::parse(),
492                    semicolon_p()
493                ),
494                |(
495                    _,
496                    mma_async,
497                    sp,
498                    sync,
499                    aligned,
500                    shape,
501                    dtype,
502                    bf16,
503                    bf162,
504                    d,
505                    _,
506                    a,
507                    _,
508                    b_desc,
509                    _,
510                    sp_meta,
511                    _,
512                    sp_sel,
513                    _,
514                    scale_d,
515                    _,
516                    imm_scale_a,
517                    _,
518                    imm_scale_b,
519                    _,
520                    imm_trans_b,
521                    _,
522                ),
523                 span| {
524                    ok!(WgmmaMmaAsyncSpSyncAlignedShapeDtypeBf16Bf161 {
525                        mma_async = mma_async,
526                        sp = sp,
527                        sync = sync,
528                        aligned = aligned,
529                        shape = shape,
530                        dtype = dtype,
531                        bf16 = bf16,
532                        bf162 = bf162,
533                        d = d,
534                        a = a,
535                        b_desc = b_desc,
536                        sp_meta = sp_meta,
537                        sp_sel = sp_sel,
538                        scale_d = scale_d,
539                        imm_scale_a = imm_scale_a,
540                        imm_scale_b = imm_scale_b,
541                        imm_trans_b = imm_trans_b,
542
543                    })
544                },
545            )
546        }
547    }
548}
549
550pub mod section_2 {
551    use super::*;
552    use crate::r#type::instruction::wgmma_mma_async_sp::section_2::*;
553
554    // ============================================================================
555    // Generated enum parsers
556    // ============================================================================
557
558    impl PtxParser for Dtype {
559        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
560            alt!(map(string_p(".f32"), |_, _span| Dtype::F32))
561        }
562    }
563
564    impl PtxParser for Shape {
565        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
566            alt!(
567                map(string_p(".m64n104k16"), |_, _span| Shape::M64n104k16),
568                map(string_p(".m64n112k16"), |_, _span| Shape::M64n112k16),
569                map(string_p(".m64n120k16"), |_, _span| Shape::M64n120k16),
570                map(string_p(".m64n128k16"), |_, _span| Shape::M64n128k16),
571                map(string_p(".m64n136k16"), |_, _span| Shape::M64n136k16),
572                map(string_p(".m64n144k16"), |_, _span| Shape::M64n144k16),
573                map(string_p(".m64n152k16"), |_, _span| Shape::M64n152k16),
574                map(string_p(".m64n160k16"), |_, _span| Shape::M64n160k16),
575                map(string_p(".m64n168k16"), |_, _span| Shape::M64n168k16),
576                map(string_p(".m64n176k16"), |_, _span| Shape::M64n176k16),
577                map(string_p(".m64n184k16"), |_, _span| Shape::M64n184k16),
578                map(string_p(".m64n192k16"), |_, _span| Shape::M64n192k16),
579                map(string_p(".m64n200k16"), |_, _span| Shape::M64n200k16),
580                map(string_p(".m64n208k16"), |_, _span| Shape::M64n208k16),
581                map(string_p(".m64n216k16"), |_, _span| Shape::M64n216k16),
582                map(string_p(".m64n224k16"), |_, _span| Shape::M64n224k16),
583                map(string_p(".m64n232k16"), |_, _span| Shape::M64n232k16),
584                map(string_p(".m64n240k16"), |_, _span| Shape::M64n240k16),
585                map(string_p(".m64n248k16"), |_, _span| Shape::M64n248k16),
586                map(string_p(".m64n256k16"), |_, _span| Shape::M64n256k16),
587                map(string_p(".m64n16k16"), |_, _span| Shape::M64n16k16),
588                map(string_p(".m64n24k16"), |_, _span| Shape::M64n24k16),
589                map(string_p(".m64n32k16"), |_, _span| Shape::M64n32k16),
590                map(string_p(".m64n40k16"), |_, _span| Shape::M64n40k16),
591                map(string_p(".m64n48k16"), |_, _span| Shape::M64n48k16),
592                map(string_p(".m64n56k16"), |_, _span| Shape::M64n56k16),
593                map(string_p(".m64n64k16"), |_, _span| Shape::M64n64k16),
594                map(string_p(".m64n72k16"), |_, _span| Shape::M64n72k16),
595                map(string_p(".m64n80k16"), |_, _span| Shape::M64n80k16),
596                map(string_p(".m64n88k16"), |_, _span| Shape::M64n88k16),
597                map(string_p(".m64n96k16"), |_, _span| Shape::M64n96k16),
598                map(string_p(".m64n8k16"), |_, _span| Shape::M64n8k16)
599            )
600        }
601    }
602
603    impl PtxParser for WgmmaMmaAsyncSpSyncAlignedShapeDtypeTf32Tf32 {
604        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
605            try_map(
606                seq_n!(
607                    string_p("wgmma"),
608                    string_p(".mma_async"),
609                    string_p(".sp"),
610                    string_p(".sync"),
611                    string_p(".aligned"),
612                    Shape::parse(),
613                    Dtype::parse(),
614                    string_p(".tf32"),
615                    string_p(".tf32"),
616                    GeneralOperand::parse(),
617                    comma_p(),
618                    GeneralOperand::parse(),
619                    comma_p(),
620                    GeneralOperand::parse(),
621                    comma_p(),
622                    GeneralOperand::parse(),
623                    comma_p(),
624                    GeneralOperand::parse(),
625                    comma_p(),
626                    GeneralOperand::parse(),
627                    comma_p(),
628                    GeneralOperand::parse(),
629                    comma_p(),
630                    GeneralOperand::parse(),
631                    semicolon_p()
632                ),
633                |(
634                    _,
635                    mma_async,
636                    sp,
637                    sync,
638                    aligned,
639                    shape,
640                    dtype,
641                    tf32,
642                    tf322,
643                    d,
644                    _,
645                    a_desc,
646                    _,
647                    b_desc,
648                    _,
649                    sp_meta,
650                    _,
651                    sp_sel,
652                    _,
653                    scale_d,
654                    _,
655                    imm_scale_a,
656                    _,
657                    imm_scale_b,
658                    _,
659                ),
660                 span| {
661                    ok!(WgmmaMmaAsyncSpSyncAlignedShapeDtypeTf32Tf32 {
662                        mma_async = mma_async,
663                        sp = sp,
664                        sync = sync,
665                        aligned = aligned,
666                        shape = shape,
667                        dtype = dtype,
668                        tf32 = tf32,
669                        tf322 = tf322,
670                        d = d,
671                        a_desc = a_desc,
672                        b_desc = b_desc,
673                        sp_meta = sp_meta,
674                        sp_sel = sp_sel,
675                        scale_d = scale_d,
676                        imm_scale_a = imm_scale_a,
677                        imm_scale_b = imm_scale_b,
678
679                    })
680                },
681            )
682        }
683    }
684
    /// `PtxParser` for the sparse TF32 `wgmma` instruction whose second operand
    /// is bound as `a` (the `d, a, b-desc, ...` operand list of the PTX
    /// specification) rather than as an `a-desc` operand.
    ///
    /// Accepted token sequence, exactly as encoded by the `seq_n!` below:
    ///
    /// ```text
    /// wgmma .mma_async .sp .sync .aligned <Shape> <Dtype> .tf32 .tf32
    ///     d, a, b-desc, sp-meta, sp-sel, scale-d, imm-scale-a, imm-scale-b ;
    /// ```
    impl PtxParser for WgmmaMmaAsyncSpSyncAlignedShapeDtypeTf32Tf321 {
        /// Returns a parser that consumes one complete instruction, including
        /// the terminating `;`, and builds the instruction struct from the
        /// matched modifiers and the eight operands.
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    // Opcode and fixed modifiers, matched literally in this order.
                    string_p("wgmma"),
                    string_p(".mma_async"),
                    string_p(".sp"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    // `.shape` and `.dtype` modifiers.
                    Shape::parse(),
                    Dtype::parse(),
                    // The two fixed `.tf32` type tokens (bound as `tf32` / `tf322`).
                    string_p(".tf32"),
                    string_p(".tf32"),
                    // Eight comma-separated operands, then the terminating `;`.
                    GeneralOperand::parse(), // d
                    comma_p(),
                    GeneralOperand::parse(), // a
                    comma_p(),
                    GeneralOperand::parse(), // b-desc
                    comma_p(),
                    GeneralOperand::parse(), // sp-meta
                    comma_p(),
                    GeneralOperand::parse(), // sp-sel
                    comma_p(),
                    GeneralOperand::parse(), // scale-d
                    comma_p(),
                    GeneralOperand::parse(), // imm-scale-a
                    comma_p(),
                    GeneralOperand::parse(), // imm-scale-b
                    semicolon_p()
                ),
                // The tuple is destructured positionally: each binding corresponds
                // to the `seq_n!` element at the same index. The `wgmma` opcode,
                // the commas and the semicolon are discarded with `_`.
                |(
                    _,
                    mma_async,
                    sp,
                    sync,
                    aligned,
                    shape,
                    dtype,
                    tf32,
                    tf322,
                    d,
                    _,
                    a,
                    _,
                    b_desc,
                    _,
                    sp_meta,
                    _,
                    sp_sel,
                    _,
                    scale_d,
                    _,
                    imm_scale_a,
                    _,
                    imm_scale_b,
                    _,
                ),
                 span| {
                    // NOTE(review): `span` is not referenced explicitly in this body;
                    // presumably it is consumed by `ok!` or simply unused — confirm.
                    ok!(WgmmaMmaAsyncSpSyncAlignedShapeDtypeTf32Tf321 {
                        mma_async = mma_async,
                        sp = sp,
                        sync = sync,
                        aligned = aligned,
                        shape = shape,
                        dtype = dtype,
                        tf32 = tf32,
                        tf322 = tf322,
                        d = d,
                        a = a,
                        b_desc = b_desc,
                        sp_meta = sp_meta,
                        sp_sel = sp_sel,
                        scale_d = scale_d,
                        imm_scale_a = imm_scale_a,
                        imm_scale_b = imm_scale_b,

                    })
                },
            )
        }
    }
766}
767
/// Parsers for the sparse `wgmma.mma_async.sp` forms whose A/B element types
/// (`.atype` / `.btype`) are `.e4m3` or `.e5m2` and whose `.dtype` is `.f16`
/// or `.f32`.
///
/// The enum and instruction types implemented here are imported from
/// `crate::r#type::instruction::wgmma_mma_async_sp::section_3`.
pub mod section_3 {
    use super::*;
    use crate::r#type::instruction::wgmma_mma_async_sp::section_3::*;

    // ============================================================================
    // Generated enum parsers
    // ============================================================================

    /// Parses the `.atype` modifier: `.e4m3` or `.e5m2`.
    impl PtxParser for Atype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".e4m3"), |_, _span| Atype::E4m3),
                map(string_p(".e5m2"), |_, _span| Atype::E5m2)
            )
        }
    }

    /// Parses the `.btype` modifier: `.e4m3` or `.e5m2`.
    impl PtxParser for Btype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".e4m3"), |_, _span| Btype::E4m3),
                map(string_p(".e5m2"), |_, _span| Btype::E5m2)
            )
        }
    }

    /// Parses the `.dtype` modifier: `.f16` or `.f32`.
    impl PtxParser for Dtype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".f16"), |_, _span| Dtype::F16),
                map(string_p(".f32"), |_, _span| Dtype::F32)
            )
        }
    }

    /// Parses the `.shape` modifier `.m64nNk64`, for every multiple of 8 from
    /// N = 8 to N = 256 (32 alternatives).
    ///
    /// The alternatives are listed with three-digit `N` first, then two-digit,
    /// then `.m64n8k64`. No literal here is a prefix of another. Assuming
    /// `string_p` compares against the literal text, this ordering therefore
    /// does not change which inputs are accepted.
    impl PtxParser for Shape {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".m64n104k64"), |_, _span| Shape::M64n104k64),
                map(string_p(".m64n112k64"), |_, _span| Shape::M64n112k64),
                map(string_p(".m64n120k64"), |_, _span| Shape::M64n120k64),
                map(string_p(".m64n128k64"), |_, _span| Shape::M64n128k64),
                map(string_p(".m64n136k64"), |_, _span| Shape::M64n136k64),
                map(string_p(".m64n144k64"), |_, _span| Shape::M64n144k64),
                map(string_p(".m64n152k64"), |_, _span| Shape::M64n152k64),
                map(string_p(".m64n160k64"), |_, _span| Shape::M64n160k64),
                map(string_p(".m64n168k64"), |_, _span| Shape::M64n168k64),
                map(string_p(".m64n176k64"), |_, _span| Shape::M64n176k64),
                map(string_p(".m64n184k64"), |_, _span| Shape::M64n184k64),
                map(string_p(".m64n192k64"), |_, _span| Shape::M64n192k64),
                map(string_p(".m64n200k64"), |_, _span| Shape::M64n200k64),
                map(string_p(".m64n208k64"), |_, _span| Shape::M64n208k64),
                map(string_p(".m64n216k64"), |_, _span| Shape::M64n216k64),
                map(string_p(".m64n224k64"), |_, _span| Shape::M64n224k64),
                map(string_p(".m64n232k64"), |_, _span| Shape::M64n232k64),
                map(string_p(".m64n240k64"), |_, _span| Shape::M64n240k64),
                map(string_p(".m64n248k64"), |_, _span| Shape::M64n248k64),
                map(string_p(".m64n256k64"), |_, _span| Shape::M64n256k64),
                map(string_p(".m64n16k64"), |_, _span| Shape::M64n16k64),
                map(string_p(".m64n24k64"), |_, _span| Shape::M64n24k64),
                map(string_p(".m64n32k64"), |_, _span| Shape::M64n32k64),
                map(string_p(".m64n40k64"), |_, _span| Shape::M64n40k64),
                map(string_p(".m64n48k64"), |_, _span| Shape::M64n48k64),
                map(string_p(".m64n56k64"), |_, _span| Shape::M64n56k64),
                map(string_p(".m64n64k64"), |_, _span| Shape::M64n64k64),
                map(string_p(".m64n72k64"), |_, _span| Shape::M64n72k64),
                map(string_p(".m64n80k64"), |_, _span| Shape::M64n80k64),
                map(string_p(".m64n88k64"), |_, _span| Shape::M64n88k64),
                map(string_p(".m64n96k64"), |_, _span| Shape::M64n96k64),
                map(string_p(".m64n8k64"), |_, _span| Shape::M64n8k64)
            )
        }
    }

    // ============================================================================
    // Instruction parsers
    // ============================================================================

    /// Parser for the `a-desc` operand form:
    ///
    /// ```text
    /// wgmma .mma_async .sp .sync .aligned <Shape> <Dtype> <Atype> <Btype>
    ///     d, a-desc, b-desc, sp-meta, sp-sel, scale-d, imm-scale-a, imm-scale-b ;
    /// ```
    impl PtxParser for WgmmaMmaAsyncSpSyncAlignedShapeDtypeAtypeBtype {
        /// Returns a parser that consumes one complete instruction, including
        /// the terminating `;`, and builds the instruction struct from the
        /// matched modifiers and the eight operands.
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    // Opcode and fixed modifiers, matched literally in this order.
                    string_p("wgmma"),
                    string_p(".mma_async"),
                    string_p(".sp"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    // Variable modifiers, parsed by the enum parsers in this module.
                    Shape::parse(),
                    Dtype::parse(),
                    Atype::parse(),
                    Btype::parse(),
                    // Eight comma-separated operands, then the terminating `;`.
                    GeneralOperand::parse(), // d
                    comma_p(),
                    GeneralOperand::parse(), // a-desc
                    comma_p(),
                    GeneralOperand::parse(), // b-desc
                    comma_p(),
                    GeneralOperand::parse(), // sp-meta
                    comma_p(),
                    GeneralOperand::parse(), // sp-sel
                    comma_p(),
                    GeneralOperand::parse(), // scale-d
                    comma_p(),
                    GeneralOperand::parse(), // imm-scale-a
                    comma_p(),
                    GeneralOperand::parse(), // imm-scale-b
                    semicolon_p()
                ),
                // The tuple is destructured positionally: each binding corresponds
                // to the `seq_n!` element at the same index. The `wgmma` opcode,
                // the commas and the semicolon are discarded with `_`.
                |(
                    _,
                    mma_async,
                    sp,
                    sync,
                    aligned,
                    shape,
                    dtype,
                    atype,
                    btype,
                    d,
                    _,
                    a_desc,
                    _,
                    b_desc,
                    _,
                    sp_meta,
                    _,
                    sp_sel,
                    _,
                    scale_d,
                    _,
                    imm_scale_a,
                    _,
                    imm_scale_b,
                    _,
                ),
                 span| {
                    // NOTE(review): `span` is not referenced explicitly in this body;
                    // presumably it is consumed by `ok!` or simply unused — confirm.
                    ok!(WgmmaMmaAsyncSpSyncAlignedShapeDtypeAtypeBtype {
                        mma_async = mma_async,
                        sp = sp,
                        sync = sync,
                        aligned = aligned,
                        shape = shape,
                        dtype = dtype,
                        atype = atype,
                        btype = btype,
                        d = d,
                        a_desc = a_desc,
                        b_desc = b_desc,
                        sp_meta = sp_meta,
                        sp_sel = sp_sel,
                        scale_d = scale_d,
                        imm_scale_a = imm_scale_a,
                        imm_scale_b = imm_scale_b,

                    })
                },
            )
        }
    }

    /// Parser for the form whose second operand is `a` instead of `a-desc`:
    ///
    /// ```text
    /// wgmma .mma_async .sp .sync .aligned <Shape> <Dtype> <Atype> <Btype>
    ///     d, a, b-desc, sp-meta, sp-sel, scale-d, imm-scale-a, imm-scale-b ;
    /// ```
    ///
    /// The token sequence is identical to the `a-desc` form. Only the name of
    /// the second operand's field differs.
    impl PtxParser for WgmmaMmaAsyncSpSyncAlignedShapeDtypeAtypeBtype1 {
        /// Returns a parser that consumes one complete instruction, including
        /// the terminating `;`, and builds the instruction struct from the
        /// matched modifiers and the eight operands.
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    // Opcode and fixed modifiers, matched literally in this order.
                    string_p("wgmma"),
                    string_p(".mma_async"),
                    string_p(".sp"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    // Variable modifiers, parsed by the enum parsers in this module.
                    Shape::parse(),
                    Dtype::parse(),
                    Atype::parse(),
                    Btype::parse(),
                    // Eight comma-separated operands, then the terminating `;`.
                    GeneralOperand::parse(), // d
                    comma_p(),
                    GeneralOperand::parse(), // a
                    comma_p(),
                    GeneralOperand::parse(), // b-desc
                    comma_p(),
                    GeneralOperand::parse(), // sp-meta
                    comma_p(),
                    GeneralOperand::parse(), // sp-sel
                    comma_p(),
                    GeneralOperand::parse(), // scale-d
                    comma_p(),
                    GeneralOperand::parse(), // imm-scale-a
                    comma_p(),
                    GeneralOperand::parse(), // imm-scale-b
                    semicolon_p()
                ),
                // The tuple is destructured positionally: each binding corresponds
                // to the `seq_n!` element at the same index. The `wgmma` opcode,
                // the commas and the semicolon are discarded with `_`.
                |(
                    _,
                    mma_async,
                    sp,
                    sync,
                    aligned,
                    shape,
                    dtype,
                    atype,
                    btype,
                    d,
                    _,
                    a,
                    _,
                    b_desc,
                    _,
                    sp_meta,
                    _,
                    sp_sel,
                    _,
                    scale_d,
                    _,
                    imm_scale_a,
                    _,
                    imm_scale_b,
                    _,
                ),
                 span| {
                    // NOTE(review): `span` is not referenced explicitly in this body;
                    // presumably it is consumed by `ok!` or simply unused — confirm.
                    ok!(WgmmaMmaAsyncSpSyncAlignedShapeDtypeAtypeBtype1 {
                        mma_async = mma_async,
                        sp = sp,
                        sync = sync,
                        aligned = aligned,
                        shape = shape,
                        dtype = dtype,
                        atype = atype,
                        btype = btype,
                        d = d,
                        a = a,
                        b_desc = b_desc,
                        sp_meta = sp_meta,
                        sp_sel = sp_sel,
                        scale_d = scale_d,
                        imm_scale_a = imm_scale_a,
                        imm_scale_b = imm_scale_b,

                    })
                },
            )
        }
    }
}
1006
/// Parsers for the sparse integer `wgmma.mma_async.sp` forms. These use
/// `.s8`/`.u8` A/B element types (`.atype` / `.btype`), a fixed `.s32` type
/// token and an optional `.satfinite` modifier.
///
/// The enum and instruction types implemented here are imported from
/// `crate::r#type::instruction::wgmma_mma_async_sp::section_4`.
pub mod section_4 {
    use super::*;
    use crate::r#type::instruction::wgmma_mma_async_sp::section_4::*;

    // ============================================================================
    // Generated enum parsers
    // ============================================================================

    /// Parses the `.atype` modifier: `.s8` or `.u8`.
    impl PtxParser for Atype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".s8"), |_, _span| Atype::S8),
                map(string_p(".u8"), |_, _span| Atype::U8)
            )
        }
    }

    /// Parses the `.btype` modifier: `.s8` or `.u8`.
    impl PtxParser for Btype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".s8"), |_, _span| Btype::S8),
                map(string_p(".u8"), |_, _span| Btype::U8)
            )
        }
    }

    /// Parses the `.shape` modifier `.m64nNk64` for the 18 supported values of
    /// N: 8, 24, and every multiple of 16 from 16 to 256.
    ///
    /// The alternatives are listed with three-digit `N` first, then two-digit,
    /// then `.m64n8k64`. No literal here is a prefix of another. Assuming
    /// `string_p` compares against the literal text, this ordering therefore
    /// does not change which inputs are accepted.
    impl PtxParser for Shape {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".m64n112k64"), |_, _span| Shape::M64n112k64),
                map(string_p(".m64n128k64"), |_, _span| Shape::M64n128k64),
                map(string_p(".m64n144k64"), |_, _span| Shape::M64n144k64),
                map(string_p(".m64n160k64"), |_, _span| Shape::M64n160k64),
                map(string_p(".m64n176k64"), |_, _span| Shape::M64n176k64),
                map(string_p(".m64n192k64"), |_, _span| Shape::M64n192k64),
                map(string_p(".m64n208k64"), |_, _span| Shape::M64n208k64),
                map(string_p(".m64n224k64"), |_, _span| Shape::M64n224k64),
                map(string_p(".m64n240k64"), |_, _span| Shape::M64n240k64),
                map(string_p(".m64n256k64"), |_, _span| Shape::M64n256k64),
                map(string_p(".m64n16k64"), |_, _span| Shape::M64n16k64),
                map(string_p(".m64n24k64"), |_, _span| Shape::M64n24k64),
                map(string_p(".m64n32k64"), |_, _span| Shape::M64n32k64),
                map(string_p(".m64n48k64"), |_, _span| Shape::M64n48k64),
                map(string_p(".m64n64k64"), |_, _span| Shape::M64n64k64),
                map(string_p(".m64n80k64"), |_, _span| Shape::M64n80k64),
                map(string_p(".m64n96k64"), |_, _span| Shape::M64n96k64),
                map(string_p(".m64n8k64"), |_, _span| Shape::M64n8k64)
            )
        }
    }

    // ============================================================================
    // Instruction parsers
    // ============================================================================

    /// Parser for the `a-desc` operand form:
    ///
    /// ```text
    /// wgmma .mma_async .sp .sync .aligned <Shape> [.satfinite] .s32 <Atype> <Btype>
    ///     d, a-desc, b-desc, sp-meta, sp-sel, scale-d ;
    /// ```
    ///
    /// This form takes six operands and has no `imm-scale-a` / `imm-scale-b`.
    impl PtxParser for WgmmaMmaAsyncSpSyncAlignedShapeSatfiniteS32AtypeBtype {
        /// Returns a parser that consumes one complete instruction, including
        /// the terminating `;`, and builds the instruction struct from the
        /// matched modifiers and the six operands.
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    // Opcode and fixed modifiers, matched literally in this order.
                    string_p("wgmma"),
                    string_p(".mma_async"),
                    string_p(".sp"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    Shape::parse(),
                    // Optional `.satfinite`: yields `true` when present, `false` otherwise.
                    map(optional(string_p(".satfinite")), |value, _| value.is_some()),
                    // Fixed `.s32` type token, followed by the A/B element types.
                    string_p(".s32"),
                    Atype::parse(),
                    Btype::parse(),
                    // Six comma-separated operands, then the terminating `;`.
                    GeneralOperand::parse(), // d
                    comma_p(),
                    GeneralOperand::parse(), // a-desc
                    comma_p(),
                    GeneralOperand::parse(), // b-desc
                    comma_p(),
                    GeneralOperand::parse(), // sp-meta
                    comma_p(),
                    GeneralOperand::parse(), // sp-sel
                    comma_p(),
                    GeneralOperand::parse(), // scale-d
                    semicolon_p()
                ),
                // The tuple is destructured positionally: each binding corresponds
                // to the `seq_n!` element at the same index. The `wgmma` opcode,
                // the commas and the semicolon are discarded with `_`.
                |(
                    _,
                    mma_async,
                    sp,
                    sync,
                    aligned,
                    shape,
                    satfinite,
                    s32,
                    atype,
                    btype,
                    d,
                    _,
                    a_desc,
                    _,
                    b_desc,
                    _,
                    sp_meta,
                    _,
                    sp_sel,
                    _,
                    scale_d,
                    _,
                ),
                 span| {
                    // NOTE(review): `span` is not referenced explicitly in this body;
                    // presumably it is consumed by `ok!` or simply unused — confirm.
                    ok!(WgmmaMmaAsyncSpSyncAlignedShapeSatfiniteS32AtypeBtype {
                        mma_async = mma_async,
                        sp = sp,
                        sync = sync,
                        aligned = aligned,
                        shape = shape,
                        satfinite = satfinite,
                        s32 = s32,
                        atype = atype,
                        btype = btype,
                        d = d,
                        a_desc = a_desc,
                        b_desc = b_desc,
                        sp_meta = sp_meta,
                        sp_sel = sp_sel,
                        scale_d = scale_d,

                    })
                },
            )
        }
    }

    /// Parser for the form whose second operand is `a` instead of `a-desc`:
    ///
    /// ```text
    /// wgmma .mma_async .sp .sync .aligned <Shape> [.satfinite] .s32 <Atype> <Btype>
    ///     d, a, b-desc, sp-meta, sp-sel, scale-d ;
    /// ```
    ///
    /// The token sequence is identical to the `a-desc` form. Only the name of
    /// the second operand's field differs.
    impl PtxParser for WgmmaMmaAsyncSpSyncAlignedShapeSatfiniteS32AtypeBtype1 {
        /// Returns a parser that consumes one complete instruction, including
        /// the terminating `;`, and builds the instruction struct from the
        /// matched modifiers and the six operands.
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    // Opcode and fixed modifiers, matched literally in this order.
                    string_p("wgmma"),
                    string_p(".mma_async"),
                    string_p(".sp"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    Shape::parse(),
                    // Optional `.satfinite`: yields `true` when present, `false` otherwise.
                    map(optional(string_p(".satfinite")), |value, _| value.is_some()),
                    // Fixed `.s32` type token, followed by the A/B element types.
                    string_p(".s32"),
                    Atype::parse(),
                    Btype::parse(),
                    // Six comma-separated operands, then the terminating `;`.
                    GeneralOperand::parse(), // d
                    comma_p(),
                    GeneralOperand::parse(), // a
                    comma_p(),
                    GeneralOperand::parse(), // b-desc
                    comma_p(),
                    GeneralOperand::parse(), // sp-meta
                    comma_p(),
                    GeneralOperand::parse(), // sp-sel
                    comma_p(),
                    GeneralOperand::parse(), // scale-d
                    semicolon_p()
                ),
                // The tuple is destructured positionally: each binding corresponds
                // to the `seq_n!` element at the same index. The `wgmma` opcode,
                // the commas and the semicolon are discarded with `_`.
                |(
                    _,
                    mma_async,
                    sp,
                    sync,
                    aligned,
                    shape,
                    satfinite,
                    s32,
                    atype,
                    btype,
                    d,
                    _,
                    a,
                    _,
                    b_desc,
                    _,
                    sp_meta,
                    _,
                    sp_sel,
                    _,
                    scale_d,
                    _,
                ),
                 span| {
                    // NOTE(review): `span` is not referenced explicitly in this body;
                    // presumably it is consumed by `ok!` or simply unused — confirm.
                    ok!(WgmmaMmaAsyncSpSyncAlignedShapeSatfiniteS32AtypeBtype1 {
                        mma_async = mma_async,
                        sp = sp,
                        sync = sync,
                        aligned = aligned,
                        shape = shape,
                        satfinite = satfinite,
                        s32 = s32,
                        atype = atype,
                        btype = btype,
                        d = d,
                        a = a,
                        b_desc = b_desc,
                        sp_meta = sp_meta,
                        sp_sel = sp_sel,
                        scale_d = scale_d,

                    })
                },
            )
        }
    }
}