1#![allow(unused)]
73
74use crate::parser::{
75 PtxParseError, PtxParser, PtxTokenStream, Span,
76 util::{
77 between, comma_p, directive_p, exclamation_p, lbracket_p, lparen_p, map, minus_p, optional,
78 pipe_p, rbracket_p, rparen_p, semicolon_p, sep_by, string_p, try_map,
79 },
80};
81use crate::r#type::common::*;
82use crate::{alt, ok, seq_n};
83
84pub mod section_0 {
85 use super::*;
86 use crate::r#type::instruction::mma::section_0::*;
87
88 impl PtxParser for Alayout {
93 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
94 alt!(
95 map(string_p(".row"), |_, _span| Alayout::Row),
96 map(string_p(".col"), |_, _span| Alayout::Col)
97 )
98 }
99 }
100
101 impl PtxParser for Blayout {
102 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
103 alt!(
104 map(string_p(".row"), |_, _span| Blayout::Row),
105 map(string_p(".col"), |_, _span| Blayout::Col)
106 )
107 }
108 }
109
110 impl PtxParser for Ctype {
111 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
112 alt!(
113 map(string_p(".f16"), |_, _span| Ctype::F16),
114 map(string_p(".f32"), |_, _span| Ctype::F32)
115 )
116 }
117 }
118
119 impl PtxParser for Dtype {
120 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
121 alt!(
122 map(string_p(".f16"), |_, _span| Dtype::F16),
123 map(string_p(".f32"), |_, _span| Dtype::F32)
124 )
125 }
126 }
127
128 impl PtxParser for MmaSyncAlignedM8n8k4AlayoutBlayoutDtypeF16F16Ctype {
129 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
130 try_map(
131 seq_n!(
132 string_p("mma"),
133 string_p(".sync"),
134 string_p(".aligned"),
135 string_p(".m8n8k4"),
136 Alayout::parse(),
137 Blayout::parse(),
138 Dtype::parse(),
139 string_p(".f16"),
140 string_p(".f16"),
141 Ctype::parse(),
142 GeneralOperand::parse(),
143 comma_p(),
144 GeneralOperand::parse(),
145 comma_p(),
146 GeneralOperand::parse(),
147 comma_p(),
148 GeneralOperand::parse(),
149 semicolon_p()
150 ),
151 |(
152 _,
153 sync,
154 aligned,
155 m8n8k4,
156 alayout,
157 blayout,
158 dtype,
159 f16,
160 f162,
161 ctype,
162 d,
163 _,
164 a,
165 _,
166 b,
167 _,
168 c,
169 _,
170 ),
171 span| {
172 ok!(MmaSyncAlignedM8n8k4AlayoutBlayoutDtypeF16F16Ctype {
173 sync = sync,
174 aligned = aligned,
175 m8n8k4 = m8n8k4,
176 alayout = alayout,
177 blayout = blayout,
178 dtype = dtype,
179 f16 = f16,
180 f162 = f162,
181 ctype = ctype,
182 d = d,
183 a = a,
184 b = b,
185 c = c,
186
187 })
188 },
189 )
190 }
191 }
192
193 impl PtxParser for MmaSyncAlignedM16n8k8RowColDtypeF16F16Ctype {
194 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
195 try_map(
196 seq_n!(
197 string_p("mma"),
198 string_p(".sync"),
199 string_p(".aligned"),
200 string_p(".m16n8k8"),
201 string_p(".row"),
202 string_p(".col"),
203 Dtype::parse(),
204 string_p(".f16"),
205 string_p(".f16"),
206 Ctype::parse(),
207 GeneralOperand::parse(),
208 comma_p(),
209 GeneralOperand::parse(),
210 comma_p(),
211 GeneralOperand::parse(),
212 comma_p(),
213 GeneralOperand::parse(),
214 semicolon_p()
215 ),
216 |(
217 _,
218 sync,
219 aligned,
220 m16n8k8,
221 row,
222 col,
223 dtype,
224 f16,
225 f162,
226 ctype,
227 d,
228 _,
229 a,
230 _,
231 b,
232 _,
233 c,
234 _,
235 ),
236 span| {
237 ok!(MmaSyncAlignedM16n8k8RowColDtypeF16F16Ctype {
238 sync = sync,
239 aligned = aligned,
240 m16n8k8 = m16n8k8,
241 row = row,
242 col = col,
243 dtype = dtype,
244 f16 = f16,
245 f162 = f162,
246 ctype = ctype,
247 d = d,
248 a = a,
249 b = b,
250 c = c,
251
252 })
253 },
254 )
255 }
256 }
257
258 impl PtxParser for MmaSyncAlignedM16n8k16RowColDtypeF16F16Ctype {
259 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
260 try_map(
261 seq_n!(
262 string_p("mma"),
263 string_p(".sync"),
264 string_p(".aligned"),
265 string_p(".m16n8k16"),
266 string_p(".row"),
267 string_p(".col"),
268 Dtype::parse(),
269 string_p(".f16"),
270 string_p(".f16"),
271 Ctype::parse(),
272 GeneralOperand::parse(),
273 comma_p(),
274 GeneralOperand::parse(),
275 comma_p(),
276 GeneralOperand::parse(),
277 comma_p(),
278 GeneralOperand::parse(),
279 semicolon_p()
280 ),
281 |(
282 _,
283 sync,
284 aligned,
285 m16n8k16,
286 row,
287 col,
288 dtype,
289 f16,
290 f162,
291 ctype,
292 d,
293 _,
294 a,
295 _,
296 b,
297 _,
298 c,
299 _,
300 ),
301 span| {
302 ok!(MmaSyncAlignedM16n8k16RowColDtypeF16F16Ctype {
303 sync = sync,
304 aligned = aligned,
305 m16n8k16 = m16n8k16,
306 row = row,
307 col = col,
308 dtype = dtype,
309 f16 = f16,
310 f162 = f162,
311 ctype = ctype,
312 d = d,
313 a = a,
314 b = b,
315 c = c,
316
317 })
318 },
319 )
320 }
321 }
322}
323
324pub mod section_1 {
325 use super::*;
326 use crate::r#type::instruction::mma::section_1::*;
327
328 impl PtxParser for Atype {
333 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
334 alt!(
335 map(string_p(".bf16"), |_, _span| Atype::Bf16),
336 map(string_p(".tf32"), |_, _span| Atype::Tf32)
337 )
338 }
339 }
340
341 impl PtxParser for Btype {
342 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
343 alt!(
344 map(string_p(".bf16"), |_, _span| Btype::Bf16),
345 map(string_p(".tf32"), |_, _span| Btype::Tf32)
346 )
347 }
348 }
349
350 impl PtxParser for Ctype {
351 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
352 alt!(
353 map(string_p(".f16"), |_, _span| Ctype::F16),
354 map(string_p(".f32"), |_, _span| Ctype::F32)
355 )
356 }
357 }
358
359 impl PtxParser for Dtype {
360 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
361 alt!(
362 map(string_p(".f16"), |_, _span| Dtype::F16),
363 map(string_p(".f32"), |_, _span| Dtype::F32)
364 )
365 }
366 }
367
368 impl PtxParser for F8f6f4type {
369 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
370 alt!(
371 map(string_p(".e4m3"), |_, _span| F8f6f4type::E4m3),
372 map(string_p(".e5m2"), |_, _span| F8f6f4type::E5m2),
373 map(string_p(".e3m2"), |_, _span| F8f6f4type::E3m2),
374 map(string_p(".e2m3"), |_, _span| F8f6f4type::E2m3),
375 map(string_p(".e2m1"), |_, _span| F8f6f4type::E2m1)
376 )
377 }
378 }
379
380 impl PtxParser for F8type {
381 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
382 alt!(
383 map(string_p(".e4m3"), |_, _span| F8type::E4m3),
384 map(string_p(".e5m2"), |_, _span| F8type::E5m2)
385 )
386 }
387 }
388
389 impl PtxParser for Kind {
390 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
391 alt!(map(string_p(".kind::f8f6f4"), |_, _span| Kind::KindF8f6f4))
392 }
393 }
394
395 impl PtxParser for Shape {
396 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
397 alt!(
398 map(string_p(".m16n8k16"), |_, _span| Shape::M16n8k16),
399 map(string_p(".m16n8k32"), |_, _span| Shape::M16n8k32)
400 )
401 }
402 }
403
404 impl PtxParser for MmaSyncAlignedM16n8k4RowColF32Tf32Tf32F32 {
405 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
406 try_map(
407 seq_n!(
408 string_p("mma"),
409 string_p(".sync"),
410 string_p(".aligned"),
411 string_p(".m16n8k4"),
412 string_p(".row"),
413 string_p(".col"),
414 string_p(".f32"),
415 string_p(".tf32"),
416 string_p(".tf32"),
417 string_p(".f32"),
418 GeneralOperand::parse(),
419 comma_p(),
420 GeneralOperand::parse(),
421 comma_p(),
422 GeneralOperand::parse(),
423 comma_p(),
424 GeneralOperand::parse(),
425 semicolon_p()
426 ),
427 |(
428 _,
429 sync,
430 aligned,
431 m16n8k4,
432 row,
433 col,
434 f32,
435 tf32,
436 tf322,
437 f322,
438 d,
439 _,
440 a,
441 _,
442 b,
443 _,
444 c,
445 _,
446 ),
447 span| {
448 ok!(MmaSyncAlignedM16n8k4RowColF32Tf32Tf32F32 {
449 sync = sync,
450 aligned = aligned,
451 m16n8k4 = m16n8k4,
452 row = row,
453 col = col,
454 f32 = f32,
455 tf32 = tf32,
456 tf322 = tf322,
457 f322 = f322,
458 d = d,
459 a = a,
460 b = b,
461 c = c,
462
463 })
464 },
465 )
466 }
467 }
468
469 impl PtxParser for MmaSyncAlignedM16n8k8RowColF32AtypeBtypeF32 {
470 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
471 try_map(
472 seq_n!(
473 string_p("mma"),
474 string_p(".sync"),
475 string_p(".aligned"),
476 string_p(".m16n8k8"),
477 string_p(".row"),
478 string_p(".col"),
479 string_p(".f32"),
480 Atype::parse(),
481 Btype::parse(),
482 string_p(".f32"),
483 GeneralOperand::parse(),
484 comma_p(),
485 GeneralOperand::parse(),
486 comma_p(),
487 GeneralOperand::parse(),
488 comma_p(),
489 GeneralOperand::parse(),
490 semicolon_p()
491 ),
492 |(
493 _,
494 sync,
495 aligned,
496 m16n8k8,
497 row,
498 col,
499 f32,
500 atype,
501 btype,
502 f322,
503 d,
504 _,
505 a,
506 _,
507 b,
508 _,
509 c,
510 _,
511 ),
512 span| {
513 ok!(MmaSyncAlignedM16n8k8RowColF32AtypeBtypeF32 {
514 sync = sync,
515 aligned = aligned,
516 m16n8k8 = m16n8k8,
517 row = row,
518 col = col,
519 f32 = f32,
520 atype = atype,
521 btype = btype,
522 f322 = f322,
523 d = d,
524 a = a,
525 b = b,
526 c = c,
527
528 })
529 },
530 )
531 }
532 }
533
534 impl PtxParser for MmaSyncAlignedM16n8k16RowColF32Bf16Bf16F32 {
535 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
536 try_map(
537 seq_n!(
538 string_p("mma"),
539 string_p(".sync"),
540 string_p(".aligned"),
541 string_p(".m16n8k16"),
542 string_p(".row"),
543 string_p(".col"),
544 string_p(".f32"),
545 string_p(".bf16"),
546 string_p(".bf16"),
547 string_p(".f32"),
548 GeneralOperand::parse(),
549 comma_p(),
550 GeneralOperand::parse(),
551 comma_p(),
552 GeneralOperand::parse(),
553 comma_p(),
554 GeneralOperand::parse(),
555 semicolon_p()
556 ),
557 |(
558 _,
559 sync,
560 aligned,
561 m16n8k16,
562 row,
563 col,
564 f32,
565 bf16,
566 bf162,
567 f322,
568 d,
569 _,
570 a,
571 _,
572 b,
573 _,
574 c,
575 _,
576 ),
577 span| {
578 ok!(MmaSyncAlignedM16n8k16RowColF32Bf16Bf16F32 {
579 sync = sync,
580 aligned = aligned,
581 m16n8k16 = m16n8k16,
582 row = row,
583 col = col,
584 f32 = f32,
585 bf16 = bf16,
586 bf162 = bf162,
587 f322 = f322,
588 d = d,
589 a = a,
590 b = b,
591 c = c,
592
593 })
594 },
595 )
596 }
597 }
598
599 impl PtxParser for MmaSyncAlignedShapeRowColDtypeF8typeF8typeCtype {
600 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
601 try_map(
602 seq_n!(
603 string_p("mma"),
604 string_p(".sync"),
605 string_p(".aligned"),
606 Shape::parse(),
607 string_p(".row"),
608 string_p(".col"),
609 Dtype::parse(),
610 F8type::parse(),
611 F8type::parse(),
612 Ctype::parse(),
613 GeneralOperand::parse(),
614 comma_p(),
615 GeneralOperand::parse(),
616 comma_p(),
617 GeneralOperand::parse(),
618 comma_p(),
619 GeneralOperand::parse(),
620 semicolon_p()
621 ),
622 |(
623 _,
624 sync,
625 aligned,
626 shape,
627 row,
628 col,
629 dtype,
630 f8type,
631 f8type1,
632 ctype,
633 d,
634 _,
635 a,
636 _,
637 b,
638 _,
639 c,
640 _,
641 ),
642 span| {
643 ok!(MmaSyncAlignedShapeRowColDtypeF8typeF8typeCtype {
644 sync = sync,
645 aligned = aligned,
646 shape = shape,
647 row = row,
648 col = col,
649 dtype = dtype,
650 f8type = f8type,
651 f8type1 = f8type1,
652 ctype = ctype,
653 d = d,
654 a = a,
655 b = b,
656 c = c,
657
658 })
659 },
660 )
661 }
662 }
663
664 impl PtxParser for MmaSyncAlignedM16n8k32RowColKindDtypeF8f6f4typeF8f6f4typeCtype {
665 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
666 try_map(
667 seq_n!(
668 string_p("mma"),
669 string_p(".sync"),
670 string_p(".aligned"),
671 string_p(".m16n8k32"),
672 string_p(".row"),
673 string_p(".col"),
674 Kind::parse(),
675 Dtype::parse(),
676 F8f6f4type::parse(),
677 F8f6f4type::parse(),
678 Ctype::parse(),
679 GeneralOperand::parse(),
680 comma_p(),
681 GeneralOperand::parse(),
682 comma_p(),
683 GeneralOperand::parse(),
684 comma_p(),
685 GeneralOperand::parse(),
686 semicolon_p()
687 ),
688 |(
689 _,
690 sync,
691 aligned,
692 m16n8k32,
693 row,
694 col,
695 kind,
696 dtype,
697 f8f6f4type,
698 f8f6f4type1,
699 ctype,
700 d,
701 _,
702 a,
703 _,
704 b,
705 _,
706 c,
707 _,
708 ),
709 span| {
710 ok!(MmaSyncAlignedM16n8k32RowColKindDtypeF8f6f4typeF8f6f4typeCtype {
711 sync = sync,
712 aligned = aligned,
713 m16n8k32 = m16n8k32,
714 row = row,
715 col = col,
716 kind = kind,
717 dtype = dtype,
718 f8f6f4type = f8f6f4type,
719 f8f6f4type1 = f8f6f4type1,
720 ctype = ctype,
721 d = d,
722 a = a,
723 b = b,
724 c = c,
725
726 })
727 },
728 )
729 }
730 }
731}
732
733pub mod section_2 {
734 use super::*;
735 use crate::r#type::instruction::mma::section_2::*;
736
737 impl PtxParser for Kind {
742 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
743 alt!(map(string_p(".kind::mxf4"), |_, _span| Kind::KindMxf4))
744 }
745 }
746
747 impl PtxParser for ScaleVecSize {
748 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
749 alt!(map(string_p(".scale_vec::2X"), |_, _span| {
750 ScaleVecSize::ScaleVec2x
751 }))
752 }
753 }
754
755 impl PtxParser for Stype {
756 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
757 alt!(map(string_p(".ue8m0"), |_, _span| Stype::Ue8m0))
758 }
759 }
760
761 impl PtxParser for MmaSyncAlignedM16n8k64RowColKindBlockScaleScaleVecSizeF32E2m1E2m1F32Stype {
762 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
763 try_map(
764 seq_n!(
765 string_p("mma"),
766 string_p(".sync"),
767 string_p(".aligned"),
768 string_p(".m16n8k64"),
769 string_p(".row"),
770 string_p(".col"),
771 Kind::parse(),
772 string_p(".block_scale"),
773 optional(ScaleVecSize::parse()),
774 string_p(".f32"),
775 string_p(".e2m1"),
776 string_p(".e2m1"),
777 string_p(".f32"),
778 Stype::parse(),
779 GeneralOperand::parse(),
780 comma_p(),
781 GeneralOperand::parse(),
782 comma_p(),
783 GeneralOperand::parse(),
784 comma_p(),
785 GeneralOperand::parse(),
786 comma_p(),
787 GeneralOperand::parse(),
788 comma_p(),
789 VectorOperand::parse(),
790 comma_p(),
791 GeneralOperand::parse(),
792 comma_p(),
793 VectorOperand::parse(),
794 semicolon_p()
795 ),
796 |(
797 _,
798 sync,
799 aligned,
800 m16n8k64,
801 row,
802 col,
803 kind,
804 block_scale,
805 scale_vec_size,
806 f32,
807 e2m1,
808 e2m12,
809 f322,
810 stype,
811 d,
812 _,
813 a,
814 _,
815 b,
816 _,
817 c,
818 _,
819 scale_a_data,
820 _,
821 byte_id_a,
822 _,
823 scale_b_data,
824 _,
825 byte_id_b,
826 _,
827 ),
828 span| {
829 ok!(MmaSyncAlignedM16n8k64RowColKindBlockScaleScaleVecSizeF32E2m1E2m1F32Stype {
830 sync = sync,
831 aligned = aligned,
832 m16n8k64 = m16n8k64,
833 row = row,
834 col = col,
835 kind = kind,
836 block_scale = block_scale,
837 scale_vec_size = scale_vec_size,
838 f32 = f32,
839 e2m1 = e2m1,
840 e2m12 = e2m12,
841 f322 = f322,
842 stype = stype,
843 d = d,
844 a = a,
845 b = b,
846 c = c,
847 scale_a_data = scale_a_data,
848 byte_id_a = byte_id_a,
849 scale_b_data = scale_b_data,
850 byte_id_b = byte_id_b,
851
852 })
853 },
854 )
855 }
856 }
857}
858
859pub mod section_3 {
860 use super::*;
861 use crate::r#type::instruction::mma::section_3::*;
862
863 impl PtxParser for Kind {
868 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
869 alt!(map(string_p(".kind::mxf4nvf4"), |_, _span| {
870 Kind::KindMxf4nvf4
871 }))
872 }
873 }
874
875 impl PtxParser for ScaleVecSize {
876 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
877 alt!(
878 map(string_p(".scale_vec::2X"), |_, _span| {
879 ScaleVecSize::ScaleVec2x
880 }),
881 map(string_p(".scale_vec::4X"), |_, _span| {
882 ScaleVecSize::ScaleVec4x
883 })
884 )
885 }
886 }
887
888 impl PtxParser for Stype {
889 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
890 alt!(
891 map(string_p(".ue8m0"), |_, _span| Stype::Ue8m0),
892 map(string_p(".ue4m3"), |_, _span| Stype::Ue4m3)
893 )
894 }
895 }
896
897 impl PtxParser for MmaSyncAlignedM16n8k64RowColKindBlockScaleScaleVecSizeF32E2m1E2m1F32Stype1 {
898 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
899 try_map(
900 seq_n!(
901 string_p("mma"),
902 string_p(".sync"),
903 string_p(".aligned"),
904 string_p(".m16n8k64"),
905 string_p(".row"),
906 string_p(".col"),
907 Kind::parse(),
908 string_p(".block_scale"),
909 ScaleVecSize::parse(),
910 string_p(".f32"),
911 string_p(".e2m1"),
912 string_p(".e2m1"),
913 string_p(".f32"),
914 Stype::parse(),
915 GeneralOperand::parse(),
916 comma_p(),
917 GeneralOperand::parse(),
918 comma_p(),
919 GeneralOperand::parse(),
920 comma_p(),
921 GeneralOperand::parse(),
922 comma_p(),
923 GeneralOperand::parse(),
924 comma_p(),
925 VectorOperand::parse(),
926 comma_p(),
927 GeneralOperand::parse(),
928 comma_p(),
929 VectorOperand::parse(),
930 semicolon_p()
931 ),
932 |(
933 _,
934 sync,
935 aligned,
936 m16n8k64,
937 row,
938 col,
939 kind,
940 block_scale,
941 scale_vec_size,
942 f32,
943 e2m1,
944 e2m12,
945 f322,
946 stype,
947 d,
948 _,
949 a,
950 _,
951 b,
952 _,
953 c,
954 _,
955 scale_a_data,
956 _,
957 byte_id_a,
958 _,
959 scale_b_data,
960 _,
961 byte_id_b,
962 _,
963 ),
964 span| {
965 ok!(MmaSyncAlignedM16n8k64RowColKindBlockScaleScaleVecSizeF32E2m1E2m1F32Stype1 {
966 sync = sync,
967 aligned = aligned,
968 m16n8k64 = m16n8k64,
969 row = row,
970 col = col,
971 kind = kind,
972 block_scale = block_scale,
973 scale_vec_size = scale_vec_size,
974 f32 = f32,
975 e2m1 = e2m1,
976 e2m12 = e2m12,
977 f322 = f322,
978 stype = stype,
979 d = d,
980 a = a,
981 b = b,
982 c = c,
983 scale_a_data = scale_a_data,
984 byte_id_a = byte_id_a,
985 scale_b_data = scale_b_data,
986 byte_id_b = byte_id_b,
987
988 })
989 },
990 )
991 }
992 }
993}
994
995pub mod section_4 {
996 use super::*;
997 use crate::r#type::instruction::mma::section_4::*;
998
999 impl PtxParser for F8f6f4type {
1004 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1005 alt!(
1006 map(string_p(".e4m3"), |_, _span| F8f6f4type::E4m3),
1007 map(string_p(".e5m2"), |_, _span| F8f6f4type::E5m2),
1008 map(string_p(".e3m2"), |_, _span| F8f6f4type::E3m2),
1009 map(string_p(".e2m3"), |_, _span| F8f6f4type::E2m3),
1010 map(string_p(".e2m1"), |_, _span| F8f6f4type::E2m1)
1011 )
1012 }
1013 }
1014
1015 impl PtxParser for Kind {
1016 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1017 alt!(map(string_p(".kind::mxf8f6f4"), |_, _span| {
1018 Kind::KindMxf8f6f4
1019 }))
1020 }
1021 }
1022
1023 impl PtxParser for ScaleVecSize {
1024 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1025 alt!(map(string_p(".scale_vec::1X"), |_, _span| {
1026 ScaleVecSize::ScaleVec1x
1027 }))
1028 }
1029 }
1030
1031 impl PtxParser for Stype {
1032 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1033 alt!(map(string_p(".ue8m0"), |_, _span| Stype::Ue8m0))
1034 }
1035 }
1036
1037 impl PtxParser
1038 for MmaSyncAlignedM16n8k32RowColKindBlockScaleScaleVecSizeF32F8f6f4typeF8f6f4typeF32Stype
1039 {
1040 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1041 try_map(
1042 seq_n!(
1043 string_p("mma"),
1044 string_p(".sync"),
1045 string_p(".aligned"),
1046 string_p(".m16n8k32"),
1047 string_p(".row"),
1048 string_p(".col"),
1049 Kind::parse(),
1050 string_p(".block_scale"),
1051 optional(ScaleVecSize::parse()),
1052 string_p(".f32"),
1053 F8f6f4type::parse(),
1054 F8f6f4type::parse(),
1055 string_p(".f32"),
1056 Stype::parse(),
1057 GeneralOperand::parse(),
1058 comma_p(),
1059 GeneralOperand::parse(),
1060 comma_p(),
1061 GeneralOperand::parse(),
1062 comma_p(),
1063 GeneralOperand::parse(),
1064 comma_p(),
1065 GeneralOperand::parse(),
1066 comma_p(),
1067 VectorOperand::parse(),
1068 comma_p(),
1069 GeneralOperand::parse(),
1070 comma_p(),
1071 VectorOperand::parse(),
1072 semicolon_p()
1073 ),
1074 |(
1075 _,
1076 sync,
1077 aligned,
1078 m16n8k32,
1079 row,
1080 col,
1081 kind,
1082 block_scale,
1083 scale_vec_size,
1084 f32,
1085 f8f6f4type,
1086 f8f6f4type1,
1087 f322,
1088 stype,
1089 d,
1090 _,
1091 a,
1092 _,
1093 b,
1094 _,
1095 c,
1096 _,
1097 scale_a_data,
1098 _,
1099 byte_id_a,
1100 _,
1101 scale_b_data,
1102 _,
1103 byte_id_b,
1104 _,
1105 ),
1106 span| {
1107 ok!(MmaSyncAlignedM16n8k32RowColKindBlockScaleScaleVecSizeF32F8f6f4typeF8f6f4typeF32Stype {
1108 sync = sync,
1109 aligned = aligned,
1110 m16n8k32 = m16n8k32,
1111 row = row,
1112 col = col,
1113 kind = kind,
1114 block_scale = block_scale,
1115 scale_vec_size = scale_vec_size,
1116 f32 = f32,
1117 f8f6f4type = f8f6f4type,
1118 f8f6f4type1 = f8f6f4type1,
1119 f322 = f322,
1120 stype = stype,
1121 d = d,
1122 a = a,
1123 b = b,
1124 c = c,
1125 scale_a_data = scale_a_data,
1126 byte_id_a = byte_id_a,
1127 scale_b_data = scale_b_data,
1128 byte_id_b = byte_id_b,
1129
1130 })
1131 },
1132 )
1133 }
1134 }
1135}
1136
1137pub mod section_5 {
1138 use super::*;
1139 use crate::r#type::instruction::mma::section_5::*;
1140
1141 impl PtxParser for Shape {
1146 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1147 alt!(
1148 map(string_p(".m16n8k16"), |_, _span| Shape::M16n8k16),
1149 map(string_p(".m16n8k4"), |_, _span| Shape::M16n8k4),
1150 map(string_p(".m16n8k8"), |_, _span| Shape::M16n8k8),
1151 map(string_p(".m8n84"), |_, _span| Shape::M8n84)
1152 )
1153 }
1154 }
1155
1156 impl PtxParser for MmaSyncAlignedShapeRowColF64F64F64F64 {
1157 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1158 try_map(
1159 seq_n!(
1160 string_p("mma"),
1161 string_p(".sync"),
1162 string_p(".aligned"),
1163 Shape::parse(),
1164 string_p(".row"),
1165 string_p(".col"),
1166 string_p(".f64"),
1167 string_p(".f64"),
1168 string_p(".f64"),
1169 string_p(".f64"),
1170 GeneralOperand::parse(),
1171 comma_p(),
1172 GeneralOperand::parse(),
1173 comma_p(),
1174 GeneralOperand::parse(),
1175 comma_p(),
1176 GeneralOperand::parse(),
1177 semicolon_p()
1178 ),
1179 |(
1180 _,
1181 sync,
1182 aligned,
1183 shape,
1184 row,
1185 col,
1186 f64,
1187 f642,
1188 f644,
1189 f646,
1190 d,
1191 _,
1192 a,
1193 _,
1194 b,
1195 _,
1196 c,
1197 _,
1198 ),
1199 span| {
1200 ok!(MmaSyncAlignedShapeRowColF64F64F64F64 {
1201 sync = sync,
1202 aligned = aligned,
1203 shape = shape,
1204 row = row,
1205 col = col,
1206 f64 = f64,
1207 f642 = f642,
1208 f644 = f644,
1209 f646 = f646,
1210 d = d,
1211 a = a,
1212 b = b,
1213 c = c,
1214
1215 })
1216 },
1217 )
1218 }
1219 }
1220}
1221
1222pub mod section_6 {
1223 use super::*;
1224 use crate::r#type::instruction::mma::section_6::*;
1225
1226 impl PtxParser for Atype {
1231 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1232 alt!(
1233 map(string_p(".u8"), |_, _span| Atype::U8),
1234 map(string_p(".s8"), |_, _span| Atype::S8)
1235 )
1236 }
1237 }
1238
1239 impl PtxParser for Btype {
1240 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1241 alt!(
1242 map(string_p(".u8"), |_, _span| Btype::U8),
1243 map(string_p(".s8"), |_, _span| Btype::S8)
1244 )
1245 }
1246 }
1247
1248 impl PtxParser for Shape {
1249 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1250 alt!(
1251 map(string_p(".m16n8k16"), |_, _span| Shape::M16n8k16),
1252 map(string_p(".m16n8k32"), |_, _span| Shape::M16n8k32),
1253 map(string_p(".m8n8k16"), |_, _span| Shape::M8n8k16)
1254 )
1255 }
1256 }
1257
1258 impl PtxParser for MmaSyncAlignedShapeRowColSatfiniteS32AtypeBtypeS32 {
1259 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1260 try_map(
1261 seq_n!(
1262 string_p("mma"),
1263 string_p(".sync"),
1264 string_p(".aligned"),
1265 Shape::parse(),
1266 string_p(".row"),
1267 string_p(".col"),
1268 map(optional(string_p(".satfinite")), |value, _| value.is_some()),
1269 string_p(".s32"),
1270 Atype::parse(),
1271 Btype::parse(),
1272 string_p(".s32"),
1273 GeneralOperand::parse(),
1274 comma_p(),
1275 GeneralOperand::parse(),
1276 comma_p(),
1277 GeneralOperand::parse(),
1278 comma_p(),
1279 GeneralOperand::parse(),
1280 semicolon_p()
1281 ),
1282 |(
1283 _,
1284 sync,
1285 aligned,
1286 shape,
1287 row,
1288 col,
1289 satfinite,
1290 s32,
1291 atype,
1292 btype,
1293 s322,
1294 d,
1295 _,
1296 a,
1297 _,
1298 b,
1299 _,
1300 c,
1301 _,
1302 ),
1303 span| {
1304 ok!(MmaSyncAlignedShapeRowColSatfiniteS32AtypeBtypeS32 {
1305 sync = sync,
1306 aligned = aligned,
1307 shape = shape,
1308 row = row,
1309 col = col,
1310 satfinite = satfinite,
1311 s32 = s32,
1312 atype = atype,
1313 btype = btype,
1314 s322 = s322,
1315 d = d,
1316 a = a,
1317 b = b,
1318 c = c,
1319
1320 })
1321 },
1322 )
1323 }
1324 }
1325}
1326
1327pub mod section_7 {
1328 use super::*;
1329 use crate::r#type::instruction::mma::section_7::*;
1330
1331 impl PtxParser for Atype {
1336 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1337 alt!(
1338 map(string_p(".u4"), |_, _span| Atype::U4),
1339 map(string_p(".s4"), |_, _span| Atype::S4)
1340 )
1341 }
1342 }
1343
1344 impl PtxParser for Btype {
1345 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1346 alt!(
1347 map(string_p(".u4"), |_, _span| Btype::U4),
1348 map(string_p(".s4"), |_, _span| Btype::S4)
1349 )
1350 }
1351 }
1352
1353 impl PtxParser for Shape {
1354 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1355 alt!(
1356 map(string_p(".m16n8k32"), |_, _span| Shape::M16n8k32),
1357 map(string_p(".m16n8k64"), |_, _span| Shape::M16n8k64),
1358 map(string_p(".m8n8k32"), |_, _span| Shape::M8n8k32)
1359 )
1360 }
1361 }
1362
1363 impl PtxParser for MmaSyncAlignedShapeRowColSatfiniteS32AtypeBtypeS321 {
1364 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1365 try_map(
1366 seq_n!(
1367 string_p("mma"),
1368 string_p(".sync"),
1369 string_p(".aligned"),
1370 Shape::parse(),
1371 string_p(".row"),
1372 string_p(".col"),
1373 map(optional(string_p(".satfinite")), |value, _| value.is_some()),
1374 string_p(".s32"),
1375 Atype::parse(),
1376 Btype::parse(),
1377 string_p(".s32"),
1378 GeneralOperand::parse(),
1379 comma_p(),
1380 GeneralOperand::parse(),
1381 comma_p(),
1382 GeneralOperand::parse(),
1383 comma_p(),
1384 GeneralOperand::parse(),
1385 semicolon_p()
1386 ),
1387 |(
1388 _,
1389 sync,
1390 aligned,
1391 shape,
1392 row,
1393 col,
1394 satfinite,
1395 s32,
1396 atype,
1397 btype,
1398 s322,
1399 d,
1400 _,
1401 a,
1402 _,
1403 b,
1404 _,
1405 c,
1406 _,
1407 ),
1408 span| {
1409 ok!(MmaSyncAlignedShapeRowColSatfiniteS32AtypeBtypeS321 {
1410 sync = sync,
1411 aligned = aligned,
1412 shape = shape,
1413 row = row,
1414 col = col,
1415 satfinite = satfinite,
1416 s32 = s32,
1417 atype = atype,
1418 btype = btype,
1419 s322 = s322,
1420 d = d,
1421 a = a,
1422 b = b,
1423 c = c,
1424
1425 })
1426 },
1427 )
1428 }
1429 }
1430}
1431
1432pub mod section_8 {
1433 use super::*;
1434 use crate::r#type::instruction::mma::section_8::*;
1435
1436 impl PtxParser for Bitop {
1441 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1442 alt!(
1443 map(string_p(".xor"), |_, _span| Bitop::Xor),
1444 map(string_p(".and"), |_, _span| Bitop::And)
1445 )
1446 }
1447 }
1448
1449 impl PtxParser for Shape {
1450 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1451 alt!(
1452 map(string_p(".m16n8k128"), |_, _span| Shape::M16n8k128),
1453 map(string_p(".m16n8k256"), |_, _span| Shape::M16n8k256),
1454 map(string_p(".m8n8k128"), |_, _span| Shape::M8n8k128)
1455 )
1456 }
1457 }
1458
1459 impl PtxParser for MmaSyncAlignedShapeRowColS32B1B1S32BitopPopc {
1460 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1461 try_map(
1462 seq_n!(
1463 string_p("mma"),
1464 string_p(".sync"),
1465 string_p(".aligned"),
1466 Shape::parse(),
1467 string_p(".row"),
1468 string_p(".col"),
1469 string_p(".s32"),
1470 string_p(".b1"),
1471 string_p(".b1"),
1472 string_p(".s32"),
1473 Bitop::parse(),
1474 string_p(".popc"),
1475 GeneralOperand::parse(),
1476 comma_p(),
1477 GeneralOperand::parse(),
1478 comma_p(),
1479 GeneralOperand::parse(),
1480 comma_p(),
1481 GeneralOperand::parse(),
1482 semicolon_p()
1483 ),
1484 |(
1485 _,
1486 sync,
1487 aligned,
1488 shape,
1489 row,
1490 col,
1491 s32,
1492 b1,
1493 b12,
1494 s322,
1495 bitop,
1496 popc,
1497 d,
1498 _,
1499 a,
1500 _,
1501 b,
1502 _,
1503 c,
1504 _,
1505 ),
1506 span| {
1507 ok!(MmaSyncAlignedShapeRowColS32B1B1S32BitopPopc {
1508 sync = sync,
1509 aligned = aligned,
1510 shape = shape,
1511 row = row,
1512 col = col,
1513 s32 = s32,
1514 b1 = b1,
1515 b12 = b12,
1516 s322 = s322,
1517 bitop = bitop,
1518 popc = popc,
1519 d = d,
1520 a = a,
1521 b = b,
1522 c = c,
1523
1524 })
1525 },
1526 )
1527 }
1528 }
1529}