Skip to main content

ptx_parser/parser/instruction/
mma_sp.rs

1//! Original PTX specification:
2//!
3//! // Half precision floating point type:
4//! mma.spvariant.sync.aligned.m16n8k16.row.col.dtype.f16.f16.ctype  d, a, b, c, e, f;
5//! mma.spvariant.sync.aligned.m16n8k32.row.col.dtype.f16.f16.ctype  d, a, b, c, e, f;
6//! .ctype     = {.f16, .f32};
7//! .dtype     = {.f16, .f32};
8//! .spvariant = {.sp, .sp::ordered_metadata};
9//! ----------------------------------------------------
10//! // Alternate floating point type:
11//! mma.spvariant.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32     d, a, b, c, e, f;
12//! mma.spvariant.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32     d, a, b, c, e, f;
13//! mma.spvariant.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32      d, a, b, c, e, f;
14//! mma.spvariant.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32     d, a, b, c, e, f;
15//! mma.spvariant.sync.aligned.m16n8k64.row.col.f32.f8type.f8type.f32 d, a, b, c, e, f;
16//! mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.kind.dtype.f8f6f4type.f8f6f4type.ctype d, a, b, c, e, f;
17//! .f8type     = {.e4m3, .e5m2};
18//! .spvariant  = {.sp, .sp::ordered_metadata};
19//! .f8f6f4type = {.e4m3, .e5m2, .e3m2, .e2m3, .e2m1};
20//! .kind       = {.kind::f8f6f4};
21//! .ctype      = {.f16, .f32};
22//! .dtype      = {.f16, .f32};
23//! ----------------------------------------------------
//! // Alternate floating point type with block scaling:
26//! mma.spvariant.sync.aligned.m16n8k128.row.col.kind.block_scale{.scale_vec_size}.f32.e2m1.e2m1.f32.stype d, a, b, c, e, f, scale-a-data, {byte-id-a, thread-id-a}, scale-b-data, {byte-id-b, thread-id-b};
27//! .spvariant      = {.sp::ordered_metadata};
28//! .kind           = {.kind::mxf4};
29//! .scale_vec_size = {.scale_vec::2X};
30//! .stype          = {.ue8m0};
31//! ----------------------------------------------------
//! // Alternate floating point type with block scaling:
33//! mma.spvariant.sync.aligned.m16n8k128.row.col.kind.block_scale.scale_vec_size.f32.e2m1.e2m1.f32.stype d, a, b, c, e, f, scale-a-data, {byte-id-a, thread-id-a}, scale-b-data, {byte-id-b, thread-id-b};
34//! .spvariant      = {.sp::ordered_metadata};
35//! .kind           = {.kind::mxf4nvf4};
36//! .scale_vec_size = {.scale_vec::2X, .scale_vec::4X};
37//! .stype          = {.ue8m0, .ue4m3};
38//! ----------------------------------------------------
//! // Alternate floating point type with block scaling:
40//! mma.spvariant.sync.aligned.m16n8k64.row.col.kind.block_scale{.scale_vec_size}.f32.f8f6f4type.f8f6f4type.f32.stype d, a, b, c, e, f, scale-a-data, {byte-id-a, thread-id-a}, scale-b-data, {byte-id-b, thread-id-b};
41//! .spvariant      = {.sp::ordered_metadata};
42//! .kind           = {.kind::mxf8f6f4};
43//! .scale_vec_size = {.scale_vec::1X};
44//! .f8f6f4type     = {.e4m3, .e5m2, .e3m2, .e2m3, .e2m1};
45//! .stype          = {.ue8m0};
46//! ----------------------------------------------------
//! // Integer type:
49//! mma.spvariant.sync.aligned.shape.row.col{.satfinite}.s32.atype.btype.s32 d, a, b, c, e, f;
50//! .shape     = {.m16n8k32, .m16n8k64};
51//! .atype     = {.u8, .s8};
52//! .btype     = {.u8, .s8};
53//! .spvariant = {.sp, .sp::ordered_metadata};
54//! ----------------------------------------------------
//! // Integer type:
56//! mma.spvariant.sync.aligned.shape.row.col{.satfinite}.s32.atype.btype.s32 d, a, b, c, e, f;
57//! .shape     = {.m16n8k64, .m16n8k128};
58//! .atype     = {.u4, .s4};
59//! .btype     = {.u4, .s4};
60//! .spvariant = {.sp, .sp::ordered_metadata};
61
62#![allow(unused)]
63
64use crate::parser::{
65    PtxParseError, PtxParser, PtxTokenStream, Span,
66    util::{
67        between, comma_p, directive_p, exclamation_p, lbracket_p, lparen_p, map, minus_p, optional,
68        pipe_p, rbracket_p, rparen_p, semicolon_p, sep_by, string_p, try_map,
69    },
70};
71use crate::r#type::common::*;
72use crate::{alt, ok, seq_n};
73
74pub mod section_0 {
75    use super::*;
76    use crate::r#type::instruction::mma_sp::section_0::*;
77
78    // ============================================================================
79    // Generated enum parsers
80    // ============================================================================
81
82    impl PtxParser for Ctype {
83        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
84            alt!(
85                map(string_p(".f16"), |_, _span| Ctype::F16),
86                map(string_p(".f32"), |_, _span| Ctype::F32)
87            )
88        }
89    }
90
91    impl PtxParser for Dtype {
92        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
93            alt!(
94                map(string_p(".f16"), |_, _span| Dtype::F16),
95                map(string_p(".f32"), |_, _span| Dtype::F32)
96            )
97        }
98    }
99
100    impl PtxParser for Spvariant {
101        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
102            alt!(
103                map(string_p(".sp::ordered_metadata"), |_, _span| {
104                    Spvariant::SpOrderedMetadata
105                }),
106                map(string_p(".sp"), |_, _span| Spvariant::Sp)
107            )
108        }
109    }
110
111    impl PtxParser for MmaSpvariantSyncAlignedM16n8k16RowColDtypeF16F16Ctype {
112        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
113            try_map(
114                seq_n!(
115                    string_p("mma"),
116                    Spvariant::parse(),
117                    string_p(".sync"),
118                    string_p(".aligned"),
119                    string_p(".m16n8k16"),
120                    string_p(".row"),
121                    string_p(".col"),
122                    Dtype::parse(),
123                    string_p(".f16"),
124                    string_p(".f16"),
125                    Ctype::parse(),
126                    GeneralOperand::parse(),
127                    comma_p(),
128                    GeneralOperand::parse(),
129                    comma_p(),
130                    GeneralOperand::parse(),
131                    comma_p(),
132                    GeneralOperand::parse(),
133                    comma_p(),
134                    GeneralOperand::parse(),
135                    comma_p(),
136                    GeneralOperand::parse(),
137                    semicolon_p()
138                ),
139                |(
140                    _,
141                    spvariant,
142                    sync,
143                    aligned,
144                    m16n8k16,
145                    row,
146                    col,
147                    dtype,
148                    f16,
149                    f162,
150                    ctype,
151                    d,
152                    _,
153                    a,
154                    _,
155                    b,
156                    _,
157                    c,
158                    _,
159                    e,
160                    _,
161                    f,
162                    _,
163                ),
164                 span| {
165                    ok!(MmaSpvariantSyncAlignedM16n8k16RowColDtypeF16F16Ctype {
166                        spvariant = spvariant,
167                        sync = sync,
168                        aligned = aligned,
169                        m16n8k16 = m16n8k16,
170                        row = row,
171                        col = col,
172                        dtype = dtype,
173                        f16 = f16,
174                        f162 = f162,
175                        ctype = ctype,
176                        d = d,
177                        a = a,
178                        b = b,
179                        c = c,
180                        e = e,
181                        f = f,
182
183                    })
184                },
185            )
186        }
187    }
188
189    impl PtxParser for MmaSpvariantSyncAlignedM16n8k32RowColDtypeF16F16Ctype {
190        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
191            try_map(
192                seq_n!(
193                    string_p("mma"),
194                    Spvariant::parse(),
195                    string_p(".sync"),
196                    string_p(".aligned"),
197                    string_p(".m16n8k32"),
198                    string_p(".row"),
199                    string_p(".col"),
200                    Dtype::parse(),
201                    string_p(".f16"),
202                    string_p(".f16"),
203                    Ctype::parse(),
204                    GeneralOperand::parse(),
205                    comma_p(),
206                    GeneralOperand::parse(),
207                    comma_p(),
208                    GeneralOperand::parse(),
209                    comma_p(),
210                    GeneralOperand::parse(),
211                    comma_p(),
212                    GeneralOperand::parse(),
213                    comma_p(),
214                    GeneralOperand::parse(),
215                    semicolon_p()
216                ),
217                |(
218                    _,
219                    spvariant,
220                    sync,
221                    aligned,
222                    m16n8k32,
223                    row,
224                    col,
225                    dtype,
226                    f16,
227                    f162,
228                    ctype,
229                    d,
230                    _,
231                    a,
232                    _,
233                    b,
234                    _,
235                    c,
236                    _,
237                    e,
238                    _,
239                    f,
240                    _,
241                ),
242                 span| {
243                    ok!(MmaSpvariantSyncAlignedM16n8k32RowColDtypeF16F16Ctype {
244                        spvariant = spvariant,
245                        sync = sync,
246                        aligned = aligned,
247                        m16n8k32 = m16n8k32,
248                        row = row,
249                        col = col,
250                        dtype = dtype,
251                        f16 = f16,
252                        f162 = f162,
253                        ctype = ctype,
254                        d = d,
255                        a = a,
256                        b = b,
257                        c = c,
258                        e = e,
259                        f = f,
260
261                    })
262                },
263            )
264        }
265    }
266}
267
268pub mod section_1 {
269    use super::*;
270    use crate::r#type::instruction::mma_sp::section_1::*;
271
272    // ============================================================================
273    // Generated enum parsers
274    // ============================================================================
275
276    impl PtxParser for Ctype {
277        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
278            alt!(
279                map(string_p(".f16"), |_, _span| Ctype::F16),
280                map(string_p(".f32"), |_, _span| Ctype::F32)
281            )
282        }
283    }
284
285    impl PtxParser for Dtype {
286        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
287            alt!(
288                map(string_p(".f16"), |_, _span| Dtype::F16),
289                map(string_p(".f32"), |_, _span| Dtype::F32)
290            )
291        }
292    }
293
294    impl PtxParser for F8f6f4type {
295        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
296            alt!(
297                map(string_p(".e4m3"), |_, _span| F8f6f4type::E4m3),
298                map(string_p(".e5m2"), |_, _span| F8f6f4type::E5m2),
299                map(string_p(".e3m2"), |_, _span| F8f6f4type::E3m2),
300                map(string_p(".e2m3"), |_, _span| F8f6f4type::E2m3),
301                map(string_p(".e2m1"), |_, _span| F8f6f4type::E2m1)
302            )
303        }
304    }
305
306    impl PtxParser for F8type {
307        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
308            alt!(
309                map(string_p(".e4m3"), |_, _span| F8type::E4m3),
310                map(string_p(".e5m2"), |_, _span| F8type::E5m2)
311            )
312        }
313    }
314
315    impl PtxParser for Kind {
316        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
317            alt!(map(string_p(".kind::f8f6f4"), |_, _span| Kind::KindF8f6f4))
318        }
319    }
320
321    impl PtxParser for Spvariant {
322        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
323            alt!(
324                map(string_p(".sp::ordered_metadata"), |_, _span| {
325                    Spvariant::SpOrderedMetadata
326                }),
327                map(string_p(".sp"), |_, _span| Spvariant::Sp)
328            )
329        }
330    }
331
332    impl PtxParser for MmaSpvariantSyncAlignedM16n8k16RowColF32Bf16Bf16F32 {
333        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
334            try_map(
335                seq_n!(
336                    string_p("mma"),
337                    Spvariant::parse(),
338                    string_p(".sync"),
339                    string_p(".aligned"),
340                    string_p(".m16n8k16"),
341                    string_p(".row"),
342                    string_p(".col"),
343                    string_p(".f32"),
344                    string_p(".bf16"),
345                    string_p(".bf16"),
346                    string_p(".f32"),
347                    GeneralOperand::parse(),
348                    comma_p(),
349                    GeneralOperand::parse(),
350                    comma_p(),
351                    GeneralOperand::parse(),
352                    comma_p(),
353                    GeneralOperand::parse(),
354                    comma_p(),
355                    GeneralOperand::parse(),
356                    comma_p(),
357                    GeneralOperand::parse(),
358                    semicolon_p()
359                ),
360                |(
361                    _,
362                    spvariant,
363                    sync,
364                    aligned,
365                    m16n8k16,
366                    row,
367                    col,
368                    f32,
369                    bf16,
370                    bf162,
371                    f322,
372                    d,
373                    _,
374                    a,
375                    _,
376                    b,
377                    _,
378                    c,
379                    _,
380                    e,
381                    _,
382                    f,
383                    _,
384                ),
385                 span| {
386                    ok!(MmaSpvariantSyncAlignedM16n8k16RowColF32Bf16Bf16F32 {
387                        spvariant = spvariant,
388                        sync = sync,
389                        aligned = aligned,
390                        m16n8k16 = m16n8k16,
391                        row = row,
392                        col = col,
393                        f32 = f32,
394                        bf16 = bf16,
395                        bf162 = bf162,
396                        f322 = f322,
397                        d = d,
398                        a = a,
399                        b = b,
400                        c = c,
401                        e = e,
402                        f = f,
403
404                    })
405                },
406            )
407        }
408    }
409
410    impl PtxParser for MmaSpvariantSyncAlignedM16n8k32RowColF32Bf16Bf16F32 {
411        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
412            try_map(
413                seq_n!(
414                    string_p("mma"),
415                    Spvariant::parse(),
416                    string_p(".sync"),
417                    string_p(".aligned"),
418                    string_p(".m16n8k32"),
419                    string_p(".row"),
420                    string_p(".col"),
421                    string_p(".f32"),
422                    string_p(".bf16"),
423                    string_p(".bf16"),
424                    string_p(".f32"),
425                    GeneralOperand::parse(),
426                    comma_p(),
427                    GeneralOperand::parse(),
428                    comma_p(),
429                    GeneralOperand::parse(),
430                    comma_p(),
431                    GeneralOperand::parse(),
432                    comma_p(),
433                    GeneralOperand::parse(),
434                    comma_p(),
435                    GeneralOperand::parse(),
436                    semicolon_p()
437                ),
438                |(
439                    _,
440                    spvariant,
441                    sync,
442                    aligned,
443                    m16n8k32,
444                    row,
445                    col,
446                    f32,
447                    bf16,
448                    bf162,
449                    f322,
450                    d,
451                    _,
452                    a,
453                    _,
454                    b,
455                    _,
456                    c,
457                    _,
458                    e,
459                    _,
460                    f,
461                    _,
462                ),
463                 span| {
464                    ok!(MmaSpvariantSyncAlignedM16n8k32RowColF32Bf16Bf16F32 {
465                        spvariant = spvariant,
466                        sync = sync,
467                        aligned = aligned,
468                        m16n8k32 = m16n8k32,
469                        row = row,
470                        col = col,
471                        f32 = f32,
472                        bf16 = bf16,
473                        bf162 = bf162,
474                        f322 = f322,
475                        d = d,
476                        a = a,
477                        b = b,
478                        c = c,
479                        e = e,
480                        f = f,
481
482                    })
483                },
484            )
485        }
486    }
487
488    impl PtxParser for MmaSpvariantSyncAlignedM16n8k8RowColF32Tf32Tf32F32 {
489        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
490            try_map(
491                seq_n!(
492                    string_p("mma"),
493                    Spvariant::parse(),
494                    string_p(".sync"),
495                    string_p(".aligned"),
496                    string_p(".m16n8k8"),
497                    string_p(".row"),
498                    string_p(".col"),
499                    string_p(".f32"),
500                    string_p(".tf32"),
501                    string_p(".tf32"),
502                    string_p(".f32"),
503                    GeneralOperand::parse(),
504                    comma_p(),
505                    GeneralOperand::parse(),
506                    comma_p(),
507                    GeneralOperand::parse(),
508                    comma_p(),
509                    GeneralOperand::parse(),
510                    comma_p(),
511                    GeneralOperand::parse(),
512                    comma_p(),
513                    GeneralOperand::parse(),
514                    semicolon_p()
515                ),
516                |(
517                    _,
518                    spvariant,
519                    sync,
520                    aligned,
521                    m16n8k8,
522                    row,
523                    col,
524                    f32,
525                    tf32,
526                    tf322,
527                    f322,
528                    d,
529                    _,
530                    a,
531                    _,
532                    b,
533                    _,
534                    c,
535                    _,
536                    e,
537                    _,
538                    f,
539                    _,
540                ),
541                 span| {
542                    ok!(MmaSpvariantSyncAlignedM16n8k8RowColF32Tf32Tf32F32 {
543                        spvariant = spvariant,
544                        sync = sync,
545                        aligned = aligned,
546                        m16n8k8 = m16n8k8,
547                        row = row,
548                        col = col,
549                        f32 = f32,
550                        tf32 = tf32,
551                        tf322 = tf322,
552                        f322 = f322,
553                        d = d,
554                        a = a,
555                        b = b,
556                        c = c,
557                        e = e,
558                        f = f,
559
560                    })
561                },
562            )
563        }
564    }
565
566    impl PtxParser for MmaSpvariantSyncAlignedM16n8k16RowColF32Tf32Tf32F32 {
567        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
568            try_map(
569                seq_n!(
570                    string_p("mma"),
571                    Spvariant::parse(),
572                    string_p(".sync"),
573                    string_p(".aligned"),
574                    string_p(".m16n8k16"),
575                    string_p(".row"),
576                    string_p(".col"),
577                    string_p(".f32"),
578                    string_p(".tf32"),
579                    string_p(".tf32"),
580                    string_p(".f32"),
581                    GeneralOperand::parse(),
582                    comma_p(),
583                    GeneralOperand::parse(),
584                    comma_p(),
585                    GeneralOperand::parse(),
586                    comma_p(),
587                    GeneralOperand::parse(),
588                    comma_p(),
589                    GeneralOperand::parse(),
590                    comma_p(),
591                    GeneralOperand::parse(),
592                    semicolon_p()
593                ),
594                |(
595                    _,
596                    spvariant,
597                    sync,
598                    aligned,
599                    m16n8k16,
600                    row,
601                    col,
602                    f32,
603                    tf32,
604                    tf322,
605                    f322,
606                    d,
607                    _,
608                    a,
609                    _,
610                    b,
611                    _,
612                    c,
613                    _,
614                    e,
615                    _,
616                    f,
617                    _,
618                ),
619                 span| {
620                    ok!(MmaSpvariantSyncAlignedM16n8k16RowColF32Tf32Tf32F32 {
621                        spvariant = spvariant,
622                        sync = sync,
623                        aligned = aligned,
624                        m16n8k16 = m16n8k16,
625                        row = row,
626                        col = col,
627                        f32 = f32,
628                        tf32 = tf32,
629                        tf322 = tf322,
630                        f322 = f322,
631                        d = d,
632                        a = a,
633                        b = b,
634                        c = c,
635                        e = e,
636                        f = f,
637
638                    })
639                },
640            )
641        }
642    }
643
644    impl PtxParser for MmaSpvariantSyncAlignedM16n8k64RowColF32F8typeF8typeF32 {
645        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
646            try_map(
647                seq_n!(
648                    string_p("mma"),
649                    Spvariant::parse(),
650                    string_p(".sync"),
651                    string_p(".aligned"),
652                    string_p(".m16n8k64"),
653                    string_p(".row"),
654                    string_p(".col"),
655                    string_p(".f32"),
656                    F8type::parse(),
657                    F8type::parse(),
658                    string_p(".f32"),
659                    GeneralOperand::parse(),
660                    comma_p(),
661                    GeneralOperand::parse(),
662                    comma_p(),
663                    GeneralOperand::parse(),
664                    comma_p(),
665                    GeneralOperand::parse(),
666                    comma_p(),
667                    GeneralOperand::parse(),
668                    comma_p(),
669                    GeneralOperand::parse(),
670                    semicolon_p()
671                ),
672                |(
673                    _,
674                    spvariant,
675                    sync,
676                    aligned,
677                    m16n8k64,
678                    row,
679                    col,
680                    f32,
681                    f8type,
682                    f8type1,
683                    f322,
684                    d,
685                    _,
686                    a,
687                    _,
688                    b,
689                    _,
690                    c,
691                    _,
692                    e,
693                    _,
694                    f,
695                    _,
696                ),
697                 span| {
698                    ok!(MmaSpvariantSyncAlignedM16n8k64RowColF32F8typeF8typeF32 {
699                        spvariant = spvariant,
700                        sync = sync,
701                        aligned = aligned,
702                        m16n8k64 = m16n8k64,
703                        row = row,
704                        col = col,
705                        f32 = f32,
706                        f8type = f8type,
707                        f8type1 = f8type1,
708                        f322 = f322,
709                        d = d,
710                        a = a,
711                        b = b,
712                        c = c,
713                        e = e,
714                        f = f,
715
716                    })
717                },
718            )
719        }
720    }
721
722    impl PtxParser for MmaSpOrderedMetadataSyncAlignedM16n8k64RowColKindDtypeF8f6f4typeF8f6f4typeCtype {
723        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
724            try_map(
725                seq_n!(
726                    string_p("mma"),
727                    string_p(".sp::ordered_metadata"),
728                    string_p(".sync"),
729                    string_p(".aligned"),
730                    string_p(".m16n8k64"),
731                    string_p(".row"),
732                    string_p(".col"),
733                    Kind::parse(),
734                    Dtype::parse(),
735                    F8f6f4type::parse(),
736                    F8f6f4type::parse(),
737                    Ctype::parse(),
738                    GeneralOperand::parse(),
739                    comma_p(),
740                    GeneralOperand::parse(),
741                    comma_p(),
742                    GeneralOperand::parse(),
743                    comma_p(),
744                    GeneralOperand::parse(),
745                    comma_p(),
746                    GeneralOperand::parse(),
747                    comma_p(),
748                    GeneralOperand::parse(),
749                    semicolon_p()
750                ),
751                |(
752                    _,
753                    sp_ordered_metadata,
754                    sync,
755                    aligned,
756                    m16n8k64,
757                    row,
758                    col,
759                    kind,
760                    dtype,
761                    f8f6f4type,
762                    f8f6f4type1,
763                    ctype,
764                    d,
765                    _,
766                    a,
767                    _,
768                    b,
769                    _,
770                    c,
771                    _,
772                    e,
773                    _,
774                    f,
775                    _,
776                ),
777                 span| {
778                    ok!(MmaSpOrderedMetadataSyncAlignedM16n8k64RowColKindDtypeF8f6f4typeF8f6f4typeCtype {
779                        sp_ordered_metadata = sp_ordered_metadata,
780                        sync = sync,
781                        aligned = aligned,
782                        m16n8k64 = m16n8k64,
783                        row = row,
784                        col = col,
785                        kind = kind,
786                        dtype = dtype,
787                        f8f6f4type = f8f6f4type,
788                        f8f6f4type1 = f8f6f4type1,
789                        ctype = ctype,
790                        d = d,
791                        a = a,
792                        b = b,
793                        c = c,
794                        e = e,
795                        f = f,
796
797                    })
798                },
799            )
800        }
801    }
802}
803
804pub mod section_2 {
805    use super::*;
806    use crate::r#type::instruction::mma_sp::section_2::*;
807
808    // ============================================================================
809    // Generated enum parsers
810    // ============================================================================
811
812    impl PtxParser for Kind {
813        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
814            alt!(map(string_p(".kind::mxf4"), |_, _span| Kind::KindMxf4))
815        }
816    }
817
818    impl PtxParser for ScaleVecSize {
819        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
820            alt!(map(string_p(".scale_vec::2X"), |_, _span| {
821                ScaleVecSize::ScaleVec2x
822            }))
823        }
824    }
825
826    impl PtxParser for Spvariant {
827        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
828            alt!(map(string_p(".sp::ordered_metadata"), |_, _span| {
829                Spvariant::SpOrderedMetadata
830            }))
831        }
832    }
833
834    impl PtxParser for Stype {
835        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
836            alt!(map(string_p(".ue8m0"), |_, _span| Stype::Ue8m0))
837        }
838    }
839
840    impl PtxParser
841        for MmaSpvariantSyncAlignedM16n8k128RowColKindBlockScaleScaleVecSizeF32E2m1E2m1F32Stype
842    {
843        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
844            try_map(
845                seq_n!(
846                    string_p("mma"),
847                    Spvariant::parse(),
848                    string_p(".sync"),
849                    string_p(".aligned"),
850                    string_p(".m16n8k128"),
851                    string_p(".row"),
852                    string_p(".col"),
853                    Kind::parse(),
854                    string_p(".block_scale"),
855                    optional(ScaleVecSize::parse()),
856                    string_p(".f32"),
857                    string_p(".e2m1"),
858                    string_p(".e2m1"),
859                    string_p(".f32"),
860                    Stype::parse(),
861                    GeneralOperand::parse(),
862                    comma_p(),
863                    GeneralOperand::parse(),
864                    comma_p(),
865                    GeneralOperand::parse(),
866                    comma_p(),
867                    GeneralOperand::parse(),
868                    comma_p(),
869                    GeneralOperand::parse(),
870                    comma_p(),
871                    GeneralOperand::parse(),
872                    comma_p(),
873                    GeneralOperand::parse(),
874                    comma_p(),
875                    VectorOperand::parse(),
876                    comma_p(),
877                    GeneralOperand::parse(),
878                    comma_p(),
879                    VectorOperand::parse(),
880                    semicolon_p()
881                ),
882                |(
883                    _,
884                    spvariant,
885                    sync,
886                    aligned,
887                    m16n8k128,
888                    row,
889                    col,
890                    kind,
891                    block_scale,
892                    scale_vec_size,
893                    f32,
894                    e2m1,
895                    e2m12,
896                    f322,
897                    stype,
898                    d,
899                    _,
900                    a,
901                    _,
902                    b,
903                    _,
904                    c,
905                    _,
906                    e,
907                    _,
908                    f,
909                    _,
910                    scale_a_data,
911                    _,
912                    byte_id_a,
913                    _,
914                    scale_b_data,
915                    _,
916                    byte_id_b,
917                    _,
918                ),
919                 span| {
920                    ok!(MmaSpvariantSyncAlignedM16n8k128RowColKindBlockScaleScaleVecSizeF32E2m1E2m1F32Stype {
921                        spvariant = spvariant,
922                        sync = sync,
923                        aligned = aligned,
924                        m16n8k128 = m16n8k128,
925                        row = row,
926                        col = col,
927                        kind = kind,
928                        block_scale = block_scale,
929                        scale_vec_size = scale_vec_size,
930                        f32 = f32,
931                        e2m1 = e2m1,
932                        e2m12 = e2m12,
933                        f322 = f322,
934                        stype = stype,
935                        d = d,
936                        a = a,
937                        b = b,
938                        c = c,
939                        e = e,
940                        f = f,
941                        scale_a_data = scale_a_data,
942                        byte_id_a = byte_id_a,
943                        scale_b_data = scale_b_data,
944                        byte_id_b = byte_id_b,
945
946                    })
947                },
948            )
949        }
950    }
951}
952
953pub mod section_3 {
954    use super::*;
955    use crate::r#type::instruction::mma_sp::section_3::*;
956
957    // ============================================================================
958    // Generated enum parsers
959    // ============================================================================
960
961    impl PtxParser for Kind {
962        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
963            alt!(map(string_p(".kind::mxf4nvf4"), |_, _span| {
964                Kind::KindMxf4nvf4
965            }))
966        }
967    }
968
969    impl PtxParser for ScaleVecSize {
970        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
971            alt!(
972                map(string_p(".scale_vec::2X"), |_, _span| {
973                    ScaleVecSize::ScaleVec2x
974                }),
975                map(string_p(".scale_vec::4X"), |_, _span| {
976                    ScaleVecSize::ScaleVec4x
977                })
978            )
979        }
980    }
981
982    impl PtxParser for Spvariant {
983        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
984            alt!(map(string_p(".sp::ordered_metadata"), |_, _span| {
985                Spvariant::SpOrderedMetadata
986            }))
987        }
988    }
989
990    impl PtxParser for Stype {
991        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
992            alt!(
993                map(string_p(".ue8m0"), |_, _span| Stype::Ue8m0),
994                map(string_p(".ue4m3"), |_, _span| Stype::Ue4m3)
995            )
996        }
997    }
998
999    impl PtxParser
1000        for MmaSpvariantSyncAlignedM16n8k128RowColKindBlockScaleScaleVecSizeF32E2m1E2m1F32Stype1
1001    {
1002        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1003            try_map(
1004                seq_n!(
1005                    string_p("mma"),
1006                    Spvariant::parse(),
1007                    string_p(".sync"),
1008                    string_p(".aligned"),
1009                    string_p(".m16n8k128"),
1010                    string_p(".row"),
1011                    string_p(".col"),
1012                    Kind::parse(),
1013                    string_p(".block_scale"),
1014                    ScaleVecSize::parse(),
1015                    string_p(".f32"),
1016                    string_p(".e2m1"),
1017                    string_p(".e2m1"),
1018                    string_p(".f32"),
1019                    Stype::parse(),
1020                    GeneralOperand::parse(),
1021                    comma_p(),
1022                    GeneralOperand::parse(),
1023                    comma_p(),
1024                    GeneralOperand::parse(),
1025                    comma_p(),
1026                    GeneralOperand::parse(),
1027                    comma_p(),
1028                    GeneralOperand::parse(),
1029                    comma_p(),
1030                    GeneralOperand::parse(),
1031                    comma_p(),
1032                    GeneralOperand::parse(),
1033                    comma_p(),
1034                    VectorOperand::parse(),
1035                    comma_p(),
1036                    GeneralOperand::parse(),
1037                    comma_p(),
1038                    VectorOperand::parse(),
1039                    semicolon_p()
1040                ),
1041                |(
1042                    _,
1043                    spvariant,
1044                    sync,
1045                    aligned,
1046                    m16n8k128,
1047                    row,
1048                    col,
1049                    kind,
1050                    block_scale,
1051                    scale_vec_size,
1052                    f32,
1053                    e2m1,
1054                    e2m12,
1055                    f322,
1056                    stype,
1057                    d,
1058                    _,
1059                    a,
1060                    _,
1061                    b,
1062                    _,
1063                    c,
1064                    _,
1065                    e,
1066                    _,
1067                    f,
1068                    _,
1069                    scale_a_data,
1070                    _,
1071                    byte_id_a,
1072                    _,
1073                    scale_b_data,
1074                    _,
1075                    byte_id_b,
1076                    _,
1077                ),
1078                 span| {
1079                    ok!(MmaSpvariantSyncAlignedM16n8k128RowColKindBlockScaleScaleVecSizeF32E2m1E2m1F32Stype1 {
1080                        spvariant = spvariant,
1081                        sync = sync,
1082                        aligned = aligned,
1083                        m16n8k128 = m16n8k128,
1084                        row = row,
1085                        col = col,
1086                        kind = kind,
1087                        block_scale = block_scale,
1088                        scale_vec_size = scale_vec_size,
1089                        f32 = f32,
1090                        e2m1 = e2m1,
1091                        e2m12 = e2m12,
1092                        f322 = f322,
1093                        stype = stype,
1094                        d = d,
1095                        a = a,
1096                        b = b,
1097                        c = c,
1098                        e = e,
1099                        f = f,
1100                        scale_a_data = scale_a_data,
1101                        byte_id_a = byte_id_a,
1102                        scale_b_data = scale_b_data,
1103                        byte_id_b = byte_id_b,
1104
1105                    })
1106                },
1107            )
1108        }
1109    }
1110}
1111
1112pub mod section_4 {
1113    use super::*;
1114    use crate::r#type::instruction::mma_sp::section_4::*;
1115
1116    // ============================================================================
1117    // Generated enum parsers
1118    // ============================================================================
1119
1120    impl PtxParser for F8f6f4type {
1121        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1122            alt!(
1123                map(string_p(".e4m3"), |_, _span| F8f6f4type::E4m3),
1124                map(string_p(".e5m2"), |_, _span| F8f6f4type::E5m2),
1125                map(string_p(".e3m2"), |_, _span| F8f6f4type::E3m2),
1126                map(string_p(".e2m3"), |_, _span| F8f6f4type::E2m3),
1127                map(string_p(".e2m1"), |_, _span| F8f6f4type::E2m1)
1128            )
1129        }
1130    }
1131
1132    impl PtxParser for Kind {
1133        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1134            alt!(map(string_p(".kind::mxf8f6f4"), |_, _span| {
1135                Kind::KindMxf8f6f4
1136            }))
1137        }
1138    }
1139
1140    impl PtxParser for ScaleVecSize {
1141        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1142            alt!(map(string_p(".scale_vec::1X"), |_, _span| {
1143                ScaleVecSize::ScaleVec1x
1144            }))
1145        }
1146    }
1147
1148    impl PtxParser for Spvariant {
1149        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1150            alt!(map(string_p(".sp::ordered_metadata"), |_, _span| {
1151                Spvariant::SpOrderedMetadata
1152            }))
1153        }
1154    }
1155
1156    impl PtxParser for Stype {
1157        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1158            alt!(map(string_p(".ue8m0"), |_, _span| Stype::Ue8m0))
1159        }
1160    }
1161
1162    impl PtxParser for MmaSpvariantSyncAlignedM16n8k64RowColKindBlockScaleScaleVecSizeF32F8f6f4typeF8f6f4typeF32Stype {
1163        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1164            try_map(
1165                seq_n!(
1166                    string_p("mma"),
1167                    Spvariant::parse(),
1168                    string_p(".sync"),
1169                    string_p(".aligned"),
1170                    string_p(".m16n8k64"),
1171                    string_p(".row"),
1172                    string_p(".col"),
1173                    Kind::parse(),
1174                    string_p(".block_scale"),
1175                    optional(ScaleVecSize::parse()),
1176                    string_p(".f32"),
1177                    F8f6f4type::parse(),
1178                    F8f6f4type::parse(),
1179                    string_p(".f32"),
1180                    Stype::parse(),
1181                    GeneralOperand::parse(),
1182                    comma_p(),
1183                    GeneralOperand::parse(),
1184                    comma_p(),
1185                    GeneralOperand::parse(),
1186                    comma_p(),
1187                    GeneralOperand::parse(),
1188                    comma_p(),
1189                    GeneralOperand::parse(),
1190                    comma_p(),
1191                    GeneralOperand::parse(),
1192                    comma_p(),
1193                    GeneralOperand::parse(),
1194                    comma_p(),
1195                    VectorOperand::parse(),
1196                    comma_p(),
1197                    GeneralOperand::parse(),
1198                    comma_p(),
1199                    VectorOperand::parse(),
1200                    semicolon_p()
1201                ),
1202                |(_, spvariant, sync, aligned, m16n8k64, row, col, kind, block_scale, scale_vec_size, f32, f8f6f4type, f8f6f4type1, f322, stype, d, _, a, _, b, _, c, _, e, _, f, _, scale_a_data, _, byte_id_a, _, scale_b_data, _, byte_id_b, _), span| {
1203                    ok!(MmaSpvariantSyncAlignedM16n8k64RowColKindBlockScaleScaleVecSizeF32F8f6f4typeF8f6f4typeF32Stype {
1204                        spvariant = spvariant,
1205                        sync = sync,
1206                        aligned = aligned,
1207                        m16n8k64 = m16n8k64,
1208                        row = row,
1209                        col = col,
1210                        kind = kind,
1211                        block_scale = block_scale,
1212                        scale_vec_size = scale_vec_size,
1213                        f32 = f32,
1214                        f8f6f4type = f8f6f4type,
1215                        f8f6f4type1 = f8f6f4type1,
1216                        f322 = f322,
1217                        stype = stype,
1218                        d = d,
1219                        a = a,
1220                        b = b,
1221                        c = c,
1222                        e = e,
1223                        f = f,
1224                        scale_a_data = scale_a_data,
1225                        byte_id_a = byte_id_a,
1226                        scale_b_data = scale_b_data,
1227                        byte_id_b = byte_id_b,
1228
1229                    })
1230                },
1231            )
1232        }
1233    }
1234}
1235
1236pub mod section_5 {
1237    use super::*;
1238    use crate::r#type::instruction::mma_sp::section_5::*;
1239
1240    // ============================================================================
1241    // Generated enum parsers
1242    // ============================================================================
1243
1244    impl PtxParser for Atype {
1245        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1246            alt!(
1247                map(string_p(".u8"), |_, _span| Atype::U8),
1248                map(string_p(".s8"), |_, _span| Atype::S8)
1249            )
1250        }
1251    }
1252
1253    impl PtxParser for Btype {
1254        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1255            alt!(
1256                map(string_p(".u8"), |_, _span| Btype::U8),
1257                map(string_p(".s8"), |_, _span| Btype::S8)
1258            )
1259        }
1260    }
1261
1262    impl PtxParser for Shape {
1263        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1264            alt!(
1265                map(string_p(".m16n8k32"), |_, _span| Shape::M16n8k32),
1266                map(string_p(".m16n8k64"), |_, _span| Shape::M16n8k64)
1267            )
1268        }
1269    }
1270
1271    impl PtxParser for Spvariant {
1272        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1273            alt!(
1274                map(string_p(".sp::ordered_metadata"), |_, _span| {
1275                    Spvariant::SpOrderedMetadata
1276                }),
1277                map(string_p(".sp"), |_, _span| Spvariant::Sp)
1278            )
1279        }
1280    }
1281
1282    impl PtxParser for MmaSpvariantSyncAlignedShapeRowColSatfiniteS32AtypeBtypeS32 {
1283        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1284            try_map(
1285                seq_n!(
1286                    string_p("mma"),
1287                    Spvariant::parse(),
1288                    string_p(".sync"),
1289                    string_p(".aligned"),
1290                    Shape::parse(),
1291                    string_p(".row"),
1292                    string_p(".col"),
1293                    map(optional(string_p(".satfinite")), |value, _| value.is_some()),
1294                    string_p(".s32"),
1295                    Atype::parse(),
1296                    Btype::parse(),
1297                    string_p(".s32"),
1298                    GeneralOperand::parse(),
1299                    comma_p(),
1300                    GeneralOperand::parse(),
1301                    comma_p(),
1302                    GeneralOperand::parse(),
1303                    comma_p(),
1304                    GeneralOperand::parse(),
1305                    comma_p(),
1306                    GeneralOperand::parse(),
1307                    comma_p(),
1308                    GeneralOperand::parse(),
1309                    semicolon_p()
1310                ),
1311                |(
1312                    _,
1313                    spvariant,
1314                    sync,
1315                    aligned,
1316                    shape,
1317                    row,
1318                    col,
1319                    satfinite,
1320                    s32,
1321                    atype,
1322                    btype,
1323                    s322,
1324                    d,
1325                    _,
1326                    a,
1327                    _,
1328                    b,
1329                    _,
1330                    c,
1331                    _,
1332                    e,
1333                    _,
1334                    f,
1335                    _,
1336                ),
1337                 span| {
1338                    ok!(MmaSpvariantSyncAlignedShapeRowColSatfiniteS32AtypeBtypeS32 {
1339                        spvariant = spvariant,
1340                        sync = sync,
1341                        aligned = aligned,
1342                        shape = shape,
1343                        row = row,
1344                        col = col,
1345                        satfinite = satfinite,
1346                        s32 = s32,
1347                        atype = atype,
1348                        btype = btype,
1349                        s322 = s322,
1350                        d = d,
1351                        a = a,
1352                        b = b,
1353                        c = c,
1354                        e = e,
1355                        f = f,
1356
1357                    })
1358                },
1359            )
1360        }
1361    }
1362}
1363
1364pub mod section_6 {
1365    use super::*;
1366    use crate::r#type::instruction::mma_sp::section_6::*;
1367
1368    // ============================================================================
1369    // Generated enum parsers
1370    // ============================================================================
1371
1372    impl PtxParser for Atype {
1373        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1374            alt!(
1375                map(string_p(".u4"), |_, _span| Atype::U4),
1376                map(string_p(".s4"), |_, _span| Atype::S4)
1377            )
1378        }
1379    }
1380
1381    impl PtxParser for Btype {
1382        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1383            alt!(
1384                map(string_p(".u4"), |_, _span| Btype::U4),
1385                map(string_p(".s4"), |_, _span| Btype::S4)
1386            )
1387        }
1388    }
1389
1390    impl PtxParser for Shape {
1391        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1392            alt!(
1393                map(string_p(".m16n8k128"), |_, _span| Shape::M16n8k128),
1394                map(string_p(".m16n8k64"), |_, _span| Shape::M16n8k64)
1395            )
1396        }
1397    }
1398
1399    impl PtxParser for Spvariant {
1400        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1401            alt!(
1402                map(string_p(".sp::ordered_metadata"), |_, _span| {
1403                    Spvariant::SpOrderedMetadata
1404                }),
1405                map(string_p(".sp"), |_, _span| Spvariant::Sp)
1406            )
1407        }
1408    }
1409
1410    impl PtxParser for MmaSpvariantSyncAlignedShapeRowColSatfiniteS32AtypeBtypeS321 {
1411        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1412            try_map(
1413                seq_n!(
1414                    string_p("mma"),
1415                    Spvariant::parse(),
1416                    string_p(".sync"),
1417                    string_p(".aligned"),
1418                    Shape::parse(),
1419                    string_p(".row"),
1420                    string_p(".col"),
1421                    map(optional(string_p(".satfinite")), |value, _| value.is_some()),
1422                    string_p(".s32"),
1423                    Atype::parse(),
1424                    Btype::parse(),
1425                    string_p(".s32"),
1426                    GeneralOperand::parse(),
1427                    comma_p(),
1428                    GeneralOperand::parse(),
1429                    comma_p(),
1430                    GeneralOperand::parse(),
1431                    comma_p(),
1432                    GeneralOperand::parse(),
1433                    comma_p(),
1434                    GeneralOperand::parse(),
1435                    comma_p(),
1436                    GeneralOperand::parse(),
1437                    semicolon_p()
1438                ),
1439                |(
1440                    _,
1441                    spvariant,
1442                    sync,
1443                    aligned,
1444                    shape,
1445                    row,
1446                    col,
1447                    satfinite,
1448                    s32,
1449                    atype,
1450                    btype,
1451                    s322,
1452                    d,
1453                    _,
1454                    a,
1455                    _,
1456                    b,
1457                    _,
1458                    c,
1459                    _,
1460                    e,
1461                    _,
1462                    f,
1463                    _,
1464                ),
1465                 span| {
1466                    ok!(MmaSpvariantSyncAlignedShapeRowColSatfiniteS32AtypeBtypeS321 {
1467                        spvariant = spvariant,
1468                        sync = sync,
1469                        aligned = aligned,
1470                        shape = shape,
1471                        row = row,
1472                        col = col,
1473                        satfinite = satfinite,
1474                        s32 = s32,
1475                        atype = atype,
1476                        btype = btype,
1477                        s322 = s322,
1478                        d = d,
1479                        a = a,
1480                        b = b,
1481                        c = c,
1482                        e = e,
1483                        f = f,
1484
1485                    })
1486                },
1487            )
1488        }
1489    }
1490}