Skip to main content

ptx_parser/parser/instruction/
mma.rs

1//! Original PTX specification:
2//!
3//! // Half precision floating point type:
4//! mma.sync.aligned.m8n8k4.alayout.blayout.dtype.f16.f16.ctype  d, a, b, c;
5//! mma.sync.aligned.m16n8k8.row.col.dtype.f16.f16.ctype  d, a, b, c;
6//! mma.sync.aligned.m16n8k16.row.col.dtype.f16.f16.ctype d, a, b, c;
7//! .alayout = {.row, .col};
8//! .blayout = {.row, .col};
9//! .ctype   = {.f16, .f32};
10//! .dtype   = {.f16, .f32};
11//! ----------------------------------------------------
//! // Alternate floating point type:
14//! mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32        d, a, b, c;
15//! mma.sync.aligned.m16n8k8.row.col.f32.atype.btype.f32      d, a, b, c;
16//! mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32       d, a, b, c;
17//! mma.sync.aligned.shape.row.col.dtype.f8type.f8type.ctype  d, a, b, c;
18//! mma.sync.aligned.m16n8k32.row.col.kind.dtype.f8f6f4type.f8f6f4type.ctype d, a, b, c;
19//! .atype      = {.bf16, .tf32};
20//! .btype      = {.bf16, .tf32};
21//! .f8type     = {.e4m3, .e5m2};
22//! .f8f6f4type = {.e4m3, .e5m2, .e3m2, .e2m3, .e2m1};
23//! .ctype      = {.f16, .f32};
24//! .dtype      = {.f16, .f32};
25//! .shape      = {.m16n8k16, .m16n8k32};
26//! .kind       = {.kind::f8f6f4};
27//! ----------------------------------------------------
//! // Alternate floating point type with block scaling:
30//! mma.sync.aligned.m16n8k64.row.col.kind.block_scale{.scale_vec_size}.f32.e2m1.e2m1.f32.stype d, a, b, c, scale-a-data, {byte-id-a, thread-id-a}, scale-b-data, {byte-id-b, thread-id-b};
31//! .kind           = {.kind::mxf4};
32//! .scale_vec_size = {.scale_vec::2X};
33//! .stype          = {.ue8m0};
34//! ----------------------------------------------------
//! // Alternate floating point type with block scaling:
36//! mma.sync.aligned.m16n8k64.row.col.kind.block_scale.scale_vec_size.f32.e2m1.e2m1.f32.stype d, a, b, c, scale-a-data, {byte-id-a, thread-id-a}, scale-b-data, {byte-id-b, thread-id-b};
37//! .kind           = {.kind::mxf4nvf4};
38//! .scale_vec_size = {.scale_vec::2X, .scale_vec::4X};
39//! .stype          = {.ue8m0, .ue4m3};
40//! ----------------------------------------------------
//! // Alternate floating point type with block scaling:
42//! mma.sync.aligned.m16n8k32.row.col.kind.block_scale{.scale_vec_size}.f32.f8f6f4type.f8f6f4type.f32.stype d, a, b, c, scale-a-data, {byte-id-a, thread-id-a}, scale-b-data, {byte-id-b, thread-id-b};
43//! .kind           = {.kind::mxf8f6f4};
44//! .scale_vec_size = {.scale_vec::1X};
45//! .f8f6f4type     = {.e4m3, .e5m2, .e3m2, .e2m3, .e2m1};
46//! .stype          = {.ue8m0};
47//! ----------------------------------------------------
//! // Double precision floating point type:
50//! mma.sync.aligned.shape.row.col.f64.f64.f64.f64 d, a, b, c;
//! .shape   = {.m8n8k4, .m16n8k4, .m16n8k8, .m16n8k16};
52//! ----------------------------------------------------
//! // Integer type:
55//! mma.sync.aligned.shape.row.col{.satfinite}.s32.atype.btype.s32 d, a, b, c;
56//! .shape   = {.m8n8k16, .m16n8k16, .m16n8k32};
57//! .atype   = {.u8, .s8};
58//! .btype   = {.u8, .s8};
59//! ----------------------------------------------------
//! // Integer type (4-bit operands):
61//! mma.sync.aligned.shape.row.col{.satfinite}.s32.atype.btype.s32 d, a, b, c;
62//! .shape   = {.m8n8k32, .m16n8k32, .m16n8k64};
63//! .atype   = {.u4, .s4};
64//! .btype   = {.u4, .s4};
65//! ----------------------------------------------------
//! // Single bit:
68//! mma.sync.aligned.shape.row.col.s32.b1.b1.s32.bitOp.popc d, a, b, c;
69//! .bitOp = {.xor, .and};
70//! .shape = {.m8n8k128, .m16n8k128, .m16n8k256};
71
72#![allow(unused)]
73
74use crate::parser::{
75    PtxParseError, PtxParser, PtxTokenStream, Span,
76    util::{
77        between, comma_p, directive_p, exclamation_p, lbracket_p, lparen_p, map, minus_p, optional,
78        pipe_p, rbracket_p, rparen_p, semicolon_p, sep_by, string_p, try_map,
79    },
80};
81use crate::r#type::common::*;
82use crate::{alt, ok, seq_n};
83
/// Parsers for the half-precision `mma` forms, whose A and B element types
/// are fixed to `.f16`:
///
/// - `mma.sync.aligned.m8n8k4.<alayout>.<blayout>.<dtype>.f16.f16.<ctype> d, a, b, c;`
/// - `mma.sync.aligned.m16n8k8.row.col.<dtype>.f16.f16.<ctype> d, a, b, c;`
/// - `mma.sync.aligned.m16n8k16.row.col.<dtype>.f16.f16.<ctype> d, a, b, c;`
///
/// The modifier enums and instruction structs come from
/// `crate::r#type::instruction::mma::section_0`.
pub mod section_0 {
    use super::*;
    use crate::r#type::instruction::mma::section_0::*;

    // ============================================================================
    // Generated enum parsers
    // ============================================================================
    //
    // Each enum parser tries its modifier spellings in order with `alt!` and
    // maps the matched token to the corresponding variant. `Ctype` and `Dtype`
    // accept the same spellings (`.f16`, `.f32`) but are distinct types.

    impl PtxParser for Alayout {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".row"), |_, _span| Alayout::Row),
                map(string_p(".col"), |_, _span| Alayout::Col)
            )
        }
    }

    impl PtxParser for Blayout {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".row"), |_, _span| Blayout::Row),
                map(string_p(".col"), |_, _span| Blayout::Col)
            )
        }
    }

    impl PtxParser for Ctype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".f16"), |_, _span| Ctype::F16),
                map(string_p(".f32"), |_, _span| Ctype::F32)
            )
        }
    }

    impl PtxParser for Dtype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".f16"), |_, _span| Dtype::F16),
                map(string_p(".f32"), |_, _span| Dtype::F32)
            )
        }
    }

    /// Parses `mma.sync.aligned.m8n8k4.<alayout>.<blayout>.<dtype>.f16.f16.<ctype> d, a, b, c;`.
    ///
    /// The element order inside `seq_n!` *is* the grammar. The mapping closure
    /// destructures the resulting tuple positionally and discards the leading
    /// `mma` token, the three commas and the terminating `;`. The fixed tokens
    /// (`.sync`, `.aligned`, the shape, and the two `.f16`) are still stored on
    /// the struct. The two lists must be kept in the same order.
    ///
    /// NOTE(review): `span` is not referenced in the closure body shown here;
    /// presumably `ok!` attaches it to the result. Confirm against the macro.
    impl PtxParser for MmaSyncAlignedM8n8k4AlayoutBlayoutDtypeF16F16Ctype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    string_p("mma"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    string_p(".m8n8k4"),
                    Alayout::parse(),
                    Blayout::parse(),
                    Dtype::parse(),
                    string_p(".f16"),
                    string_p(".f16"),
                    Ctype::parse(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    semicolon_p()
                ),
                |(
                    _,
                    sync,
                    aligned,
                    m8n8k4,
                    alayout,
                    blayout,
                    dtype,
                    f16,
                    f162,
                    ctype,
                    d,
                    _,
                    a,
                    _,
                    b,
                    _,
                    c,
                    _,
                ),
                 span| {
                    ok!(MmaSyncAlignedM8n8k4AlayoutBlayoutDtypeF16F16Ctype {
                        sync = sync,
                        aligned = aligned,
                        m8n8k4 = m8n8k4,
                        alayout = alayout,
                        blayout = blayout,
                        dtype = dtype,
                        f16 = f16,
                        f162 = f162,
                        ctype = ctype,
                        d = d,
                        a = a,
                        b = b,
                        c = c,

                    })
                },
            )
        }
    }

    /// Parses `mma.sync.aligned.m16n8k8.row.col.<dtype>.f16.f16.<ctype> d, a, b, c;`.
    ///
    /// Same positional `seq_n!`/destructure layout as the m8n8k4 parser, with
    /// `.row` and `.col` as fixed tokens instead of `Alayout`/`Blayout`.
    impl PtxParser for MmaSyncAlignedM16n8k8RowColDtypeF16F16Ctype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    string_p("mma"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    string_p(".m16n8k8"),
                    string_p(".row"),
                    string_p(".col"),
                    Dtype::parse(),
                    string_p(".f16"),
                    string_p(".f16"),
                    Ctype::parse(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    semicolon_p()
                ),
                |(
                    _,
                    sync,
                    aligned,
                    m16n8k8,
                    row,
                    col,
                    dtype,
                    f16,
                    f162,
                    ctype,
                    d,
                    _,
                    a,
                    _,
                    b,
                    _,
                    c,
                    _,
                ),
                 span| {
                    ok!(MmaSyncAlignedM16n8k8RowColDtypeF16F16Ctype {
                        sync = sync,
                        aligned = aligned,
                        m16n8k8 = m16n8k8,
                        row = row,
                        col = col,
                        dtype = dtype,
                        f16 = f16,
                        f162 = f162,
                        ctype = ctype,
                        d = d,
                        a = a,
                        b = b,
                        c = c,

                    })
                },
            )
        }
    }

    /// Parses `mma.sync.aligned.m16n8k16.row.col.<dtype>.f16.f16.<ctype> d, a, b, c;`.
    ///
    /// Identical in structure to the m16n8k8 parser; only the shape token differs.
    impl PtxParser for MmaSyncAlignedM16n8k16RowColDtypeF16F16Ctype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    string_p("mma"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    string_p(".m16n8k16"),
                    string_p(".row"),
                    string_p(".col"),
                    Dtype::parse(),
                    string_p(".f16"),
                    string_p(".f16"),
                    Ctype::parse(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    semicolon_p()
                ),
                |(
                    _,
                    sync,
                    aligned,
                    m16n8k16,
                    row,
                    col,
                    dtype,
                    f16,
                    f162,
                    ctype,
                    d,
                    _,
                    a,
                    _,
                    b,
                    _,
                    c,
                    _,
                ),
                 span| {
                    ok!(MmaSyncAlignedM16n8k16RowColDtypeF16F16Ctype {
                        sync = sync,
                        aligned = aligned,
                        m16n8k16 = m16n8k16,
                        row = row,
                        col = col,
                        dtype = dtype,
                        f16 = f16,
                        f162 = f162,
                        ctype = ctype,
                        d = d,
                        a = a,
                        b = b,
                        c = c,

                    })
                },
            )
        }
    }
}
323
/// Parsers for the alternate floating-point `mma` forms without block scaling:
///
/// - `mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32 d, a, b, c;`
/// - `mma.sync.aligned.m16n8k8.row.col.f32.<atype>.<btype>.f32 d, a, b, c;`
/// - `mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 d, a, b, c;`
/// - `mma.sync.aligned.<shape>.row.col.<dtype>.<f8type>.<f8type>.<ctype> d, a, b, c;`
/// - `mma.sync.aligned.m16n8k32.row.col.<kind>.<dtype>.<f8f6f4type>.<f8f6f4type>.<ctype> d, a, b, c;`
///
/// The modifier enums and instruction structs come from
/// `crate::r#type::instruction::mma::section_1`.
pub mod section_1 {
    use super::*;
    use crate::r#type::instruction::mma::section_1::*;

    // ============================================================================
    // Generated enum parsers
    // ============================================================================
    //
    // Each enum parser tries its modifier spellings in order with `alt!` and
    // maps the matched token to the corresponding variant.
    // - `Atype`/`Btype` share spellings, as do `Ctype`/`Dtype`.
    // - `F8type` accepts a subset of `F8f6f4type`'s spellings.
    // - `Kind` accepts exactly one spelling in this section.

    impl PtxParser for Atype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".bf16"), |_, _span| Atype::Bf16),
                map(string_p(".tf32"), |_, _span| Atype::Tf32)
            )
        }
    }

    impl PtxParser for Btype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".bf16"), |_, _span| Btype::Bf16),
                map(string_p(".tf32"), |_, _span| Btype::Tf32)
            )
        }
    }

    impl PtxParser for Ctype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".f16"), |_, _span| Ctype::F16),
                map(string_p(".f32"), |_, _span| Ctype::F32)
            )
        }
    }

    impl PtxParser for Dtype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".f16"), |_, _span| Dtype::F16),
                map(string_p(".f32"), |_, _span| Dtype::F32)
            )
        }
    }

    impl PtxParser for F8f6f4type {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".e4m3"), |_, _span| F8f6f4type::E4m3),
                map(string_p(".e5m2"), |_, _span| F8f6f4type::E5m2),
                map(string_p(".e3m2"), |_, _span| F8f6f4type::E3m2),
                map(string_p(".e2m3"), |_, _span| F8f6f4type::E2m3),
                map(string_p(".e2m1"), |_, _span| F8f6f4type::E2m1)
            )
        }
    }

    impl PtxParser for F8type {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".e4m3"), |_, _span| F8type::E4m3),
                map(string_p(".e5m2"), |_, _span| F8type::E5m2)
            )
        }
    }

    impl PtxParser for Kind {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(map(string_p(".kind::f8f6f4"), |_, _span| Kind::KindF8f6f4))
        }
    }

    impl PtxParser for Shape {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".m16n8k16"), |_, _span| Shape::M16n8k16),
                map(string_p(".m16n8k32"), |_, _span| Shape::M16n8k32)
            )
        }
    }

    /// Parses `mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32 d, a, b, c;`.
    ///
    /// Every modifier is fixed, but the matched tokens are still stored on the
    /// struct. The element order inside `seq_n!` is the grammar, and the closure
    /// destructures it positionally, so the two lists must stay in sync. The
    /// leading `mma`, the commas and the `;` are discarded.
    ///
    /// NOTE(review): `span` is not referenced in the closure body shown here;
    /// presumably `ok!` attaches it to the result. Confirm against the macro.
    impl PtxParser for MmaSyncAlignedM16n8k4RowColF32Tf32Tf32F32 {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    string_p("mma"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    string_p(".m16n8k4"),
                    string_p(".row"),
                    string_p(".col"),
                    string_p(".f32"),
                    string_p(".tf32"),
                    string_p(".tf32"),
                    string_p(".f32"),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    semicolon_p()
                ),
                |(
                    _,
                    sync,
                    aligned,
                    m16n8k4,
                    row,
                    col,
                    f32,
                    tf32,
                    tf322,
                    f322,
                    d,
                    _,
                    a,
                    _,
                    b,
                    _,
                    c,
                    _,
                ),
                 span| {
                    ok!(MmaSyncAlignedM16n8k4RowColF32Tf32Tf32F32 {
                        sync = sync,
                        aligned = aligned,
                        m16n8k4 = m16n8k4,
                        row = row,
                        col = col,
                        f32 = f32,
                        tf32 = tf32,
                        tf322 = tf322,
                        f322 = f322,
                        d = d,
                        a = a,
                        b = b,
                        c = c,

                    })
                },
            )
        }
    }

    /// Parses `mma.sync.aligned.m16n8k8.row.col.f32.<atype>.<btype>.f32 d, a, b, c;`.
    ///
    /// `Atype` and `Btype` are parsed independently, so this parser accepts any
    /// pairing of `.bf16`/`.tf32`. No cross-operand type check is done here.
    impl PtxParser for MmaSyncAlignedM16n8k8RowColF32AtypeBtypeF32 {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    string_p("mma"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    string_p(".m16n8k8"),
                    string_p(".row"),
                    string_p(".col"),
                    string_p(".f32"),
                    Atype::parse(),
                    Btype::parse(),
                    string_p(".f32"),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    semicolon_p()
                ),
                |(
                    _,
                    sync,
                    aligned,
                    m16n8k8,
                    row,
                    col,
                    f32,
                    atype,
                    btype,
                    f322,
                    d,
                    _,
                    a,
                    _,
                    b,
                    _,
                    c,
                    _,
                ),
                 span| {
                    ok!(MmaSyncAlignedM16n8k8RowColF32AtypeBtypeF32 {
                        sync = sync,
                        aligned = aligned,
                        m16n8k8 = m16n8k8,
                        row = row,
                        col = col,
                        f32 = f32,
                        atype = atype,
                        btype = btype,
                        f322 = f322,
                        d = d,
                        a = a,
                        b = b,
                        c = c,

                    })
                },
            )
        }
    }

    /// Parses `mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 d, a, b, c;`.
    ///
    /// Every modifier is fixed; the layout matches the m16n8k4 `.tf32` parser.
    impl PtxParser for MmaSyncAlignedM16n8k16RowColF32Bf16Bf16F32 {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    string_p("mma"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    string_p(".m16n8k16"),
                    string_p(".row"),
                    string_p(".col"),
                    string_p(".f32"),
                    string_p(".bf16"),
                    string_p(".bf16"),
                    string_p(".f32"),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    semicolon_p()
                ),
                |(
                    _,
                    sync,
                    aligned,
                    m16n8k16,
                    row,
                    col,
                    f32,
                    bf16,
                    bf162,
                    f322,
                    d,
                    _,
                    a,
                    _,
                    b,
                    _,
                    c,
                    _,
                ),
                 span| {
                    ok!(MmaSyncAlignedM16n8k16RowColF32Bf16Bf16F32 {
                        sync = sync,
                        aligned = aligned,
                        m16n8k16 = m16n8k16,
                        row = row,
                        col = col,
                        f32 = f32,
                        bf16 = bf16,
                        bf162 = bf162,
                        f322 = f322,
                        d = d,
                        a = a,
                        b = b,
                        c = c,

                    })
                },
            )
        }
    }

    /// Parses `mma.sync.aligned.<shape>.row.col.<dtype>.<f8type>.<f8type>.<ctype> d, a, b, c;`.
    ///
    /// The A and B element types come from two independent `F8type` parsers,
    /// stored as `f8type` and `f8type1`, so they may differ (e.g. `.e4m3.e5m2`).
    impl PtxParser for MmaSyncAlignedShapeRowColDtypeF8typeF8typeCtype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    string_p("mma"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    Shape::parse(),
                    string_p(".row"),
                    string_p(".col"),
                    Dtype::parse(),
                    F8type::parse(),
                    F8type::parse(),
                    Ctype::parse(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    semicolon_p()
                ),
                |(
                    _,
                    sync,
                    aligned,
                    shape,
                    row,
                    col,
                    dtype,
                    f8type,
                    f8type1,
                    ctype,
                    d,
                    _,
                    a,
                    _,
                    b,
                    _,
                    c,
                    _,
                ),
                 span| {
                    ok!(MmaSyncAlignedShapeRowColDtypeF8typeF8typeCtype {
                        sync = sync,
                        aligned = aligned,
                        shape = shape,
                        row = row,
                        col = col,
                        dtype = dtype,
                        f8type = f8type,
                        f8type1 = f8type1,
                        ctype = ctype,
                        d = d,
                        a = a,
                        b = b,
                        c = c,

                    })
                },
            )
        }
    }

    /// Parses
    /// `mma.sync.aligned.m16n8k32.row.col.<kind>.<dtype>.<f8f6f4type>.<f8f6f4type>.<ctype> d, a, b, c;`.
    ///
    /// In this section `<kind>` only accepts `.kind::f8f6f4`. The A and B element
    /// types are parsed independently (`f8f6f4type`, `f8f6f4type1`).
    impl PtxParser for MmaSyncAlignedM16n8k32RowColKindDtypeF8f6f4typeF8f6f4typeCtype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    string_p("mma"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    string_p(".m16n8k32"),
                    string_p(".row"),
                    string_p(".col"),
                    Kind::parse(),
                    Dtype::parse(),
                    F8f6f4type::parse(),
                    F8f6f4type::parse(),
                    Ctype::parse(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    semicolon_p()
                ),
                |(
                    _,
                    sync,
                    aligned,
                    m16n8k32,
                    row,
                    col,
                    kind,
                    dtype,
                    f8f6f4type,
                    f8f6f4type1,
                    ctype,
                    d,
                    _,
                    a,
                    _,
                    b,
                    _,
                    c,
                    _,
                ),
                 span| {
                    ok!(MmaSyncAlignedM16n8k32RowColKindDtypeF8f6f4typeF8f6f4typeCtype {
                        sync = sync,
                        aligned = aligned,
                        m16n8k32 = m16n8k32,
                        row = row,
                        col = col,
                        kind = kind,
                        dtype = dtype,
                        f8f6f4type = f8f6f4type,
                        f8f6f4type1 = f8f6f4type1,
                        ctype = ctype,
                        d = d,
                        a = a,
                        b = b,
                        c = c,

                    })
                },
            )
        }
    }
}
732
/// Parser for the block-scaled `mma` form with `.kind::mxf4`:
///
/// ```text
/// mma.sync.aligned.m16n8k64.row.col.kind.block_scale{.scale_vec_size}.f32.e2m1.e2m1.f32.stype
///     d, a, b, c, scale-a-data, {byte-id-a, thread-id-a},
///     scale-b-data, {byte-id-b, thread-id-b};
/// ```
///
/// The modifier enums and instruction struct come from
/// `crate::r#type::instruction::mma::section_2`.
pub mod section_2 {
    use super::*;
    use crate::r#type::instruction::mma::section_2::*;

    // ============================================================================
    // Generated enum parsers
    // ============================================================================
    //
    // In this form `Kind`, `ScaleVecSize` and `Stype` each accept exactly one
    // spelling.

    impl PtxParser for Kind {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(map(string_p(".kind::mxf4"), |_, _span| Kind::KindMxf4))
        }
    }

    impl PtxParser for ScaleVecSize {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(map(string_p(".scale_vec::2X"), |_, _span| {
                ScaleVecSize::ScaleVec2x
            }))
        }
    }

    impl PtxParser for Stype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(map(string_p(".ue8m0"), |_, _span| Stype::Ue8m0))
        }
    }

    /// Parses the `.kind::mxf4` block-scaled `mma` instruction described on
    /// this module.
    ///
    /// The element order inside `seq_n!` is the grammar, and the closure
    /// destructures it positionally, so the two lists must stay in sync. The
    /// leading `mma`, every comma and the `;` are discarded.
    ///
    /// NOTE(review): `span` is not referenced in the closure body shown here;
    /// presumably `ok!` attaches it to the result. Confirm against the macro.
    impl PtxParser for MmaSyncAlignedM16n8k64RowColKindBlockScaleScaleVecSizeF32E2m1E2m1F32Stype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    string_p("mma"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    string_p(".m16n8k64"),
                    string_p(".row"),
                    string_p(".col"),
                    Kind::parse(),
                    string_p(".block_scale"),
                    // The spec writes `{.scale_vec_size}` in braces for this
                    // form, so the modifier may be omitted.
                    optional(ScaleVecSize::parse()),
                    string_p(".f32"),
                    string_p(".e2m1"),
                    string_p(".e2m1"),
                    string_p(".f32"),
                    Stype::parse(),
                    GeneralOperand::parse(), // d
                    comma_p(),
                    GeneralOperand::parse(), // a
                    comma_p(),
                    GeneralOperand::parse(), // b
                    comma_p(),
                    GeneralOperand::parse(), // c
                    comma_p(),
                    GeneralOperand::parse(), // scale-a-data
                    comma_p(),
                    VectorOperand::parse(), // {byte-id-a, thread-id-a}
                    comma_p(),
                    GeneralOperand::parse(), // scale-b-data
                    comma_p(),
                    VectorOperand::parse(), // {byte-id-b, thread-id-b}
                    semicolon_p()
                ),
                |(
                    _,
                    sync,
                    aligned,
                    m16n8k64,
                    row,
                    col,
                    kind,
                    block_scale,
                    scale_vec_size,
                    f32,
                    e2m1,
                    e2m12,
                    f322,
                    stype,
                    d,
                    _,
                    a,
                    _,
                    b,
                    _,
                    c,
                    _,
                    scale_a_data,
                    _,
                    // Holds the whole `{byte-id-a, thread-id-a}` vector operand,
                    // not just the byte id.
                    byte_id_a,
                    _,
                    scale_b_data,
                    _,
                    // Holds the whole `{byte-id-b, thread-id-b}` vector operand.
                    byte_id_b,
                    _,
                ),
                 span| {
                    ok!(MmaSyncAlignedM16n8k64RowColKindBlockScaleScaleVecSizeF32E2m1E2m1F32Stype {
                        sync = sync,
                        aligned = aligned,
                        m16n8k64 = m16n8k64,
                        row = row,
                        col = col,
                        kind = kind,
                        block_scale = block_scale,
                        scale_vec_size = scale_vec_size,
                        f32 = f32,
                        e2m1 = e2m1,
                        e2m12 = e2m12,
                        f322 = f322,
                        stype = stype,
                        d = d,
                        a = a,
                        b = b,
                        c = c,
                        scale_a_data = scale_a_data,
                        byte_id_a = byte_id_a,
                        scale_b_data = scale_b_data,
                        byte_id_b = byte_id_b,

                    })
                },
            )
        }
    }
}
858
/// Parsers for the `.kind::mxf4nvf4` block-scaled `mma` form
/// (`.m16n8k64` shape, `.e2m1` input types, `.f32` output/accumulator types).
pub mod section_3 {
    use super::*;
    use crate::r#type::instruction::mma::section_3::*;

    // ============================================================================
    // Generated enum parsers
    // ============================================================================

    /// Parses the `.kind` qualifier; only `.kind::mxf4nvf4` is accepted in this form.
    impl PtxParser for Kind {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(map(string_p(".kind::mxf4nvf4"), |_, _span| {
                Kind::KindMxf4nvf4
            }))
        }
    }

    /// Parses the `.scale_vec_size` qualifier: `.scale_vec::2X` or `.scale_vec::4X`.
    impl PtxParser for ScaleVecSize {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".scale_vec::2X"), |_, _span| {
                    ScaleVecSize::ScaleVec2x
                }),
                map(string_p(".scale_vec::4X"), |_, _span| {
                    ScaleVecSize::ScaleVec4x
                })
            )
        }
    }

    /// Parses the `.stype` qualifier: `.ue8m0` or `.ue4m3`.
    impl PtxParser for Stype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".ue8m0"), |_, _span| Stype::Ue8m0),
                map(string_p(".ue4m3"), |_, _span| Stype::Ue4m3)
            )
        }
    }

    /// Parses the full instruction:
    ///
    /// ```text
    /// mma.sync.aligned.m16n8k64.row.col.kind.block_scale.scale_vec_size.f32.e2m1.e2m1.f32.stype
    ///     d, a, b, c, scale-a-data, {byte-id-a, thread-id-a},
    ///     scale-b-data, {byte-id-b, thread-id-b};
    /// ```
    ///
    /// In this form `.scale_vec_size` is mandatory (it is not wrapped in `optional`).
    /// The sixth and eighth operands are parsed as `VectorOperand`s. All other
    /// operands are `GeneralOperand`s. Commas, the semicolon and the leading `mma`
    /// token are discarded. Every other parsed piece is passed to `ok!` as a field
    /// of the result struct.
    impl PtxParser for MmaSyncAlignedM16n8k64RowColKindBlockScaleScaleVecSizeF32E2m1E2m1F32Stype1 {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    string_p("mma"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    string_p(".m16n8k64"),
                    string_p(".row"),
                    string_p(".col"),
                    Kind::parse(),
                    string_p(".block_scale"),
                    ScaleVecSize::parse(),
                    string_p(".f32"),
                    string_p(".e2m1"),
                    string_p(".e2m1"),
                    string_p(".f32"),
                    Stype::parse(),
                    // Operands d, a, b, c.
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    // Scale data for A, then its brace-enclosed vector operand.
                    GeneralOperand::parse(),
                    comma_p(),
                    VectorOperand::parse(),
                    comma_p(),
                    // Scale data for B, then its brace-enclosed vector operand.
                    GeneralOperand::parse(),
                    comma_p(),
                    VectorOperand::parse(),
                    semicolon_p()
                ),
                // The tuple pattern is positional: it must list one binding per
                // seq_n! element above, in the same order.
                |(
                    _,
                    sync,
                    aligned,
                    m16n8k64,
                    row,
                    col,
                    kind,
                    block_scale,
                    scale_vec_size,
                    f32,
                    e2m1,
                    // Second `.e2m1` token.
                    e2m12,
                    // Second `.f32` token.
                    f322,
                    stype,
                    d,
                    _,
                    a,
                    _,
                    b,
                    _,
                    c,
                    _,
                    scale_a_data,
                    _,
                    byte_id_a,
                    _,
                    scale_b_data,
                    _,
                    byte_id_b,
                    _,
                ),
                 span| {
                    ok!(MmaSyncAlignedM16n8k64RowColKindBlockScaleScaleVecSizeF32E2m1E2m1F32Stype1 {
                        sync = sync,
                        aligned = aligned,
                        m16n8k64 = m16n8k64,
                        row = row,
                        col = col,
                        kind = kind,
                        block_scale = block_scale,
                        scale_vec_size = scale_vec_size,
                        f32 = f32,
                        e2m1 = e2m1,
                        e2m12 = e2m12,
                        f322 = f322,
                        stype = stype,
                        d = d,
                        a = a,
                        b = b,
                        c = c,
                        scale_a_data = scale_a_data,
                        byte_id_a = byte_id_a,
                        scale_b_data = scale_b_data,
                        byte_id_b = byte_id_b,

                    })
                },
            )
        }
    }
}
994
/// Parsers for the `.kind::mxf8f6f4` block-scaled `mma` form
/// (`.m16n8k32` shape, `.f8f6f4type` inputs, `.f32` output/accumulator types).
pub mod section_4 {
    use super::*;
    use crate::r#type::instruction::mma::section_4::*;

    // ============================================================================
    // Generated enum parsers
    // ============================================================================

    /// Parses an `.f8f6f4type` qualifier: `.e4m3`, `.e5m2`, `.e3m2`, `.e2m3` or `.e2m1`.
    impl PtxParser for F8f6f4type {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".e4m3"), |_, _span| F8f6f4type::E4m3),
                map(string_p(".e5m2"), |_, _span| F8f6f4type::E5m2),
                map(string_p(".e3m2"), |_, _span| F8f6f4type::E3m2),
                map(string_p(".e2m3"), |_, _span| F8f6f4type::E2m3),
                map(string_p(".e2m1"), |_, _span| F8f6f4type::E2m1)
            )
        }
    }

    /// Parses the `.kind` qualifier; only `.kind::mxf8f6f4` is accepted in this form.
    impl PtxParser for Kind {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(map(string_p(".kind::mxf8f6f4"), |_, _span| {
                Kind::KindMxf8f6f4
            }))
        }
    }

    /// Parses the `.scale_vec_size` qualifier; only `.scale_vec::1X` is accepted here.
    impl PtxParser for ScaleVecSize {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(map(string_p(".scale_vec::1X"), |_, _span| {
                ScaleVecSize::ScaleVec1x
            }))
        }
    }

    /// Parses the `.stype` qualifier; only `.ue8m0` is accepted in this form.
    impl PtxParser for Stype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(map(string_p(".ue8m0"), |_, _span| Stype::Ue8m0))
        }
    }

    /// Parses the full instruction:
    ///
    /// ```text
    /// mma.sync.aligned.m16n8k32.row.col.kind.block_scale{.scale_vec_size}.f32.f8f6f4type.f8f6f4type.f32.stype
    ///     d, a, b, c, scale-a-data, {byte-id-a, thread-id-a},
    ///     scale-b-data, {byte-id-b, thread-id-b};
    /// ```
    ///
    /// `.scale_vec_size` is optional here (wrapped in `optional`). The two
    /// `F8f6f4type` qualifiers are bound to `f8f6f4type` and `f8f6f4type1` in
    /// source order. Presumably these are the A and B element types, following the
    /// PTX `.dtype.atype.btype.ctype` ordering — confirm. The sixth and eighth
    /// operands are `VectorOperand`s; all others are `GeneralOperand`s.
    impl PtxParser
        for MmaSyncAlignedM16n8k32RowColKindBlockScaleScaleVecSizeF32F8f6f4typeF8f6f4typeF32Stype
    {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    string_p("mma"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    string_p(".m16n8k32"),
                    string_p(".row"),
                    string_p(".col"),
                    Kind::parse(),
                    string_p(".block_scale"),
                    // The scale-vector size qualifier may be omitted in this form.
                    optional(ScaleVecSize::parse()),
                    string_p(".f32"),
                    F8f6f4type::parse(),
                    F8f6f4type::parse(),
                    string_p(".f32"),
                    Stype::parse(),
                    // Operands d, a, b, c.
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    // Scale data for A, then its brace-enclosed vector operand.
                    GeneralOperand::parse(),
                    comma_p(),
                    VectorOperand::parse(),
                    comma_p(),
                    // Scale data for B, then its brace-enclosed vector operand.
                    GeneralOperand::parse(),
                    comma_p(),
                    VectorOperand::parse(),
                    semicolon_p()
                ),
                // The tuple pattern is positional: it must list one binding per
                // seq_n! element above, in the same order.
                |(
                    _,
                    sync,
                    aligned,
                    m16n8k32,
                    row,
                    col,
                    kind,
                    block_scale,
                    scale_vec_size,
                    f32,
                    f8f6f4type,
                    f8f6f4type1,
                    // Second `.f32` token.
                    f322,
                    stype,
                    d,
                    _,
                    a,
                    _,
                    b,
                    _,
                    c,
                    _,
                    scale_a_data,
                    _,
                    byte_id_a,
                    _,
                    scale_b_data,
                    _,
                    byte_id_b,
                    _,
                ),
                 span| {
                    ok!(MmaSyncAlignedM16n8k32RowColKindBlockScaleScaleVecSizeF32F8f6f4typeF8f6f4typeF32Stype {
                        sync = sync,
                        aligned = aligned,
                        m16n8k32 = m16n8k32,
                        row = row,
                        col = col,
                        kind = kind,
                        block_scale = block_scale,
                        scale_vec_size = scale_vec_size,
                        f32 = f32,
                        f8f6f4type = f8f6f4type,
                        f8f6f4type1 = f8f6f4type1,
                        f322 = f322,
                        stype = stype,
                        d = d,
                        a = a,
                        b = b,
                        c = c,
                        scale_a_data = scale_a_data,
                        byte_id_a = byte_id_a,
                        scale_b_data = scale_b_data,
                        byte_id_b = byte_id_b,

                    })
                },
            )
        }
    }
}
1136
1137pub mod section_5 {
1138    use super::*;
1139    use crate::r#type::instruction::mma::section_5::*;
1140
1141    // ============================================================================
1142    // Generated enum parsers
1143    // ============================================================================
1144
1145    impl PtxParser for Shape {
1146        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1147            alt!(
1148                map(string_p(".m16n8k16"), |_, _span| Shape::M16n8k16),
1149                map(string_p(".m16n8k4"), |_, _span| Shape::M16n8k4),
1150                map(string_p(".m16n8k8"), |_, _span| Shape::M16n8k8),
1151                map(string_p(".m8n84"), |_, _span| Shape::M8n84)
1152            )
1153        }
1154    }
1155
1156    impl PtxParser for MmaSyncAlignedShapeRowColF64F64F64F64 {
1157        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1158            try_map(
1159                seq_n!(
1160                    string_p("mma"),
1161                    string_p(".sync"),
1162                    string_p(".aligned"),
1163                    Shape::parse(),
1164                    string_p(".row"),
1165                    string_p(".col"),
1166                    string_p(".f64"),
1167                    string_p(".f64"),
1168                    string_p(".f64"),
1169                    string_p(".f64"),
1170                    GeneralOperand::parse(),
1171                    comma_p(),
1172                    GeneralOperand::parse(),
1173                    comma_p(),
1174                    GeneralOperand::parse(),
1175                    comma_p(),
1176                    GeneralOperand::parse(),
1177                    semicolon_p()
1178                ),
1179                |(
1180                    _,
1181                    sync,
1182                    aligned,
1183                    shape,
1184                    row,
1185                    col,
1186                    f64,
1187                    f642,
1188                    f644,
1189                    f646,
1190                    d,
1191                    _,
1192                    a,
1193                    _,
1194                    b,
1195                    _,
1196                    c,
1197                    _,
1198                ),
1199                 span| {
1200                    ok!(MmaSyncAlignedShapeRowColF64F64F64F64 {
1201                        sync = sync,
1202                        aligned = aligned,
1203                        shape = shape,
1204                        row = row,
1205                        col = col,
1206                        f64 = f64,
1207                        f642 = f642,
1208                        f644 = f644,
1209                        f646 = f646,
1210                        d = d,
1211                        a = a,
1212                        b = b,
1213                        c = c,
1214
1215                    })
1216                },
1217            )
1218        }
1219    }
1220}
1221
/// Parsers for the 8-bit integer `mma` form:
/// `mma.sync.aligned.shape.row.col{.satfinite}.s32.atype.btype.s32 d, a, b, c;`.
pub mod section_6 {
    use super::*;
    use crate::r#type::instruction::mma::section_6::*;

    // ============================================================================
    // Generated enum parsers
    // ============================================================================

    /// Parses the `.atype` qualifier: `.u8` or `.s8`.
    impl PtxParser for Atype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".u8"), |_, _span| Atype::U8),
                map(string_p(".s8"), |_, _span| Atype::S8)
            )
        }
    }

    /// Parses the `.btype` qualifier: `.u8` or `.s8`.
    impl PtxParser for Btype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".u8"), |_, _span| Btype::U8),
                map(string_p(".s8"), |_, _span| Btype::S8)
            )
        }
    }

    /// Parses the `.shape` qualifier: `.m16n8k16`, `.m16n8k32` or `.m8n8k16`.
    impl PtxParser for Shape {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".m16n8k16"), |_, _span| Shape::M16n8k16),
                map(string_p(".m16n8k32"), |_, _span| Shape::M16n8k32),
                map(string_p(".m8n8k16"), |_, _span| Shape::M8n8k16)
            )
        }
    }

    /// Parses the full 8-bit integer `mma` instruction.
    ///
    /// The optional `.satfinite` qualifier is reduced to a `bool` (`true` when
    /// present). All four operands are `GeneralOperand`s. Commas, the semicolon
    /// and the leading `mma` token are discarded.
    impl PtxParser for MmaSyncAlignedShapeRowColSatfiniteS32AtypeBtypeS32 {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    string_p("mma"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    Shape::parse(),
                    string_p(".row"),
                    string_p(".col"),
                    // Present -> true, absent -> false.
                    map(optional(string_p(".satfinite")), |value, _| value.is_some()),
                    string_p(".s32"),
                    Atype::parse(),
                    Btype::parse(),
                    string_p(".s32"),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    semicolon_p()
                ),
                // The tuple pattern is positional: one binding per seq_n! element.
                |(
                    _,
                    sync,
                    aligned,
                    shape,
                    row,
                    col,
                    satfinite,
                    s32,
                    atype,
                    btype,
                    // Second `.s32` token.
                    s322,
                    d,
                    _,
                    a,
                    _,
                    b,
                    _,
                    c,
                    _,
                ),
                 span| {
                    ok!(MmaSyncAlignedShapeRowColSatfiniteS32AtypeBtypeS32 {
                        sync = sync,
                        aligned = aligned,
                        shape = shape,
                        row = row,
                        col = col,
                        satfinite = satfinite,
                        s32 = s32,
                        atype = atype,
                        btype = btype,
                        s322 = s322,
                        d = d,
                        a = a,
                        b = b,
                        c = c,

                    })
                },
            )
        }
    }
}
1326
/// Parsers for the 4-bit integer `mma` form:
/// `mma.sync.aligned.shape.row.col{.satfinite}.s32.atype.btype.s32 d, a, b, c;`.
pub mod section_7 {
    use super::*;
    use crate::r#type::instruction::mma::section_7::*;

    // ============================================================================
    // Generated enum parsers
    // ============================================================================

    /// Parses the `.atype` qualifier: `.u4` or `.s4`.
    impl PtxParser for Atype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".u4"), |_, _span| Atype::U4),
                map(string_p(".s4"), |_, _span| Atype::S4)
            )
        }
    }

    /// Parses the `.btype` qualifier: `.u4` or `.s4`.
    impl PtxParser for Btype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".u4"), |_, _span| Btype::U4),
                map(string_p(".s4"), |_, _span| Btype::S4)
            )
        }
    }

    /// Parses the `.shape` qualifier: `.m16n8k32`, `.m16n8k64` or `.m8n8k32`.
    impl PtxParser for Shape {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".m16n8k32"), |_, _span| Shape::M16n8k32),
                map(string_p(".m16n8k64"), |_, _span| Shape::M16n8k64),
                map(string_p(".m8n8k32"), |_, _span| Shape::M8n8k32)
            )
        }
    }

    /// Parses the full 4-bit integer `mma` instruction.
    ///
    /// The token layout is the same as the 8-bit integer form; only the accepted
    /// shapes and `.atype`/`.btype` values differ. The optional `.satfinite`
    /// qualifier is reduced to a `bool` (`true` when present). All four operands
    /// are `GeneralOperand`s.
    impl PtxParser for MmaSyncAlignedShapeRowColSatfiniteS32AtypeBtypeS321 {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    string_p("mma"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    Shape::parse(),
                    string_p(".row"),
                    string_p(".col"),
                    // Present -> true, absent -> false.
                    map(optional(string_p(".satfinite")), |value, _| value.is_some()),
                    string_p(".s32"),
                    Atype::parse(),
                    Btype::parse(),
                    string_p(".s32"),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    semicolon_p()
                ),
                // The tuple pattern is positional: one binding per seq_n! element.
                |(
                    _,
                    sync,
                    aligned,
                    shape,
                    row,
                    col,
                    satfinite,
                    s32,
                    atype,
                    btype,
                    // Second `.s32` token.
                    s322,
                    d,
                    _,
                    a,
                    _,
                    b,
                    _,
                    c,
                    _,
                ),
                 span| {
                    ok!(MmaSyncAlignedShapeRowColSatfiniteS32AtypeBtypeS321 {
                        sync = sync,
                        aligned = aligned,
                        shape = shape,
                        row = row,
                        col = col,
                        satfinite = satfinite,
                        s32 = s32,
                        atype = atype,
                        btype = btype,
                        s322 = s322,
                        d = d,
                        a = a,
                        b = b,
                        c = c,

                    })
                },
            )
        }
    }
}
1431
/// Parsers for the single-bit `mma` form:
/// `mma.sync.aligned.shape.row.col.s32.b1.b1.s32.bitOp.popc d, a, b, c;`.
pub mod section_8 {
    use super::*;
    use crate::r#type::instruction::mma::section_8::*;

    // ============================================================================
    // Generated enum parsers
    // ============================================================================

    /// Parses the `.bitOp` qualifier: `.xor` or `.and`.
    impl PtxParser for Bitop {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".xor"), |_, _span| Bitop::Xor),
                map(string_p(".and"), |_, _span| Bitop::And)
            )
        }
    }

    /// Parses the `.shape` qualifier: `.m16n8k128`, `.m16n8k256` or `.m8n8k128`.
    impl PtxParser for Shape {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".m16n8k128"), |_, _span| Shape::M16n8k128),
                map(string_p(".m16n8k256"), |_, _span| Shape::M16n8k256),
                map(string_p(".m8n8k128"), |_, _span| Shape::M8n8k128)
            )
        }
    }

    /// Parses the full single-bit `mma` instruction.
    ///
    /// The `.bitOp` qualifier is followed by a mandatory `.popc` token. All four
    /// operands are `GeneralOperand`s. Commas, the semicolon and the leading `mma`
    /// token are discarded.
    impl PtxParser for MmaSyncAlignedShapeRowColS32B1B1S32BitopPopc {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    string_p("mma"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    Shape::parse(),
                    string_p(".row"),
                    string_p(".col"),
                    string_p(".s32"),
                    string_p(".b1"),
                    string_p(".b1"),
                    string_p(".s32"),
                    Bitop::parse(),
                    string_p(".popc"),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    semicolon_p()
                ),
                // The tuple pattern is positional: one binding per seq_n! element.
                // b12 / s322 bind the second `.b1` / `.s32` tokens.
                |(
                    _,
                    sync,
                    aligned,
                    shape,
                    row,
                    col,
                    s32,
                    b1,
                    b12,
                    s322,
                    bitop,
                    popc,
                    d,
                    _,
                    a,
                    _,
                    b,
                    _,
                    c,
                    _,
                ),
                 span| {
                    ok!(MmaSyncAlignedShapeRowColS32B1B1S32BitopPopc {
                        sync = sync,
                        aligned = aligned,
                        shape = shape,
                        row = row,
                        col = col,
                        s32 = s32,
                        b1 = b1,
                        b12 = b12,
                        s322 = s322,
                        bitop = bitop,
                        popc = popc,
                        d = d,
                        a = a,
                        b = b,
                        c = c,

                    })
                },
            )
        }
    }
}