1#![allow(unused)]
70
71use crate::parser::{
72 PtxParseError, PtxParser, PtxTokenStream, Span,
73 util::{
74 between, comma_p, directive_p, exclamation_p, lbracket_p, lparen_p, map, minus_p, optional,
75 pipe_p, rbracket_p, rparen_p, semicolon_p, sep_by, string_p, try_map,
76 },
77};
78use crate::r#type::common::*;
79use crate::{alt, ok, seq_n};
80
81pub mod section_0 {
82 use super::*;
83 use crate::r#type::instruction::wgmma_mma_async_sp::section_0::*;
84
85 impl PtxParser for Dtype {
90 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
91 alt!(
92 map(string_p(".f16"), |_, _span| Dtype::F16),
93 map(string_p(".f32"), |_, _span| Dtype::F32)
94 )
95 }
96 }
97
98 impl PtxParser for Shape {
99 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
100 alt!(
101 map(string_p(".m64n104k32"), |_, _span| Shape::M64n104k32),
102 map(string_p(".m64n112k32"), |_, _span| Shape::M64n112k32),
103 map(string_p(".m64n120k32"), |_, _span| Shape::M64n120k32),
104 map(string_p(".m64n128k32"), |_, _span| Shape::M64n128k32),
105 map(string_p(".m64n136k32"), |_, _span| Shape::M64n136k32),
106 map(string_p(".m64n144k32"), |_, _span| Shape::M64n144k32),
107 map(string_p(".m64n152k32"), |_, _span| Shape::M64n152k32),
108 map(string_p(".m64n160k32"), |_, _span| Shape::M64n160k32),
109 map(string_p(".m64n168k32"), |_, _span| Shape::M64n168k32),
110 map(string_p(".m64n176k32"), |_, _span| Shape::M64n176k32),
111 map(string_p(".m64n184k32"), |_, _span| Shape::M64n184k32),
112 map(string_p(".m64n192k32"), |_, _span| Shape::M64n192k32),
113 map(string_p(".m64n200k32"), |_, _span| Shape::M64n200k32),
114 map(string_p(".m64n208k32"), |_, _span| Shape::M64n208k32),
115 map(string_p(".m64n216k32"), |_, _span| Shape::M64n216k32),
116 map(string_p(".m64n224k32"), |_, _span| Shape::M64n224k32),
117 map(string_p(".m64n232k32"), |_, _span| Shape::M64n232k32),
118 map(string_p(".m64n240k32"), |_, _span| Shape::M64n240k32),
119 map(string_p(".m64n248k32"), |_, _span| Shape::M64n248k32),
120 map(string_p(".m64n256k32"), |_, _span| Shape::M64n256k32),
121 map(string_p(".m64n16k32"), |_, _span| Shape::M64n16k32),
122 map(string_p(".m64n24k32"), |_, _span| Shape::M64n24k32),
123 map(string_p(".m64n32k32"), |_, _span| Shape::M64n32k32),
124 map(string_p(".m64n40k32"), |_, _span| Shape::M64n40k32),
125 map(string_p(".m64n48k32"), |_, _span| Shape::M64n48k32),
126 map(string_p(".m64n56k32"), |_, _span| Shape::M64n56k32),
127 map(string_p(".m64n64k32"), |_, _span| Shape::M64n64k32),
128 map(string_p(".m64n72k32"), |_, _span| Shape::M64n72k32),
129 map(string_p(".m64n80k32"), |_, _span| Shape::M64n80k32),
130 map(string_p(".m64n88k32"), |_, _span| Shape::M64n88k32),
131 map(string_p(".m64n96k32"), |_, _span| Shape::M64n96k32),
132 map(string_p(".m64n8k32"), |_, _span| Shape::M64n8k32)
133 )
134 }
135 }
136
137 impl PtxParser for WgmmaMmaAsyncSpSyncAlignedShapeDtypeF16F16 {
138 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
139 try_map(
140 seq_n!(
141 string_p("wgmma"),
142 string_p(".mma_async"),
143 string_p(".sp"),
144 string_p(".sync"),
145 string_p(".aligned"),
146 Shape::parse(),
147 Dtype::parse(),
148 string_p(".f16"),
149 string_p(".f16"),
150 GeneralOperand::parse(),
151 comma_p(),
152 GeneralOperand::parse(),
153 comma_p(),
154 GeneralOperand::parse(),
155 comma_p(),
156 GeneralOperand::parse(),
157 comma_p(),
158 GeneralOperand::parse(),
159 comma_p(),
160 GeneralOperand::parse(),
161 comma_p(),
162 GeneralOperand::parse(),
163 comma_p(),
164 GeneralOperand::parse(),
165 comma_p(),
166 GeneralOperand::parse(),
167 comma_p(),
168 GeneralOperand::parse(),
169 semicolon_p()
170 ),
171 |(
172 _,
173 mma_async,
174 sp,
175 sync,
176 aligned,
177 shape,
178 dtype,
179 f16,
180 f162,
181 d,
182 _,
183 a_desc,
184 _,
185 b_desc,
186 _,
187 sp_meta,
188 _,
189 sp_sel,
190 _,
191 scale_d,
192 _,
193 imm_scale_a,
194 _,
195 imm_scale_b,
196 _,
197 imm_trans_a,
198 _,
199 imm_trans_b,
200 _,
201 ),
202 span| {
203 ok!(WgmmaMmaAsyncSpSyncAlignedShapeDtypeF16F16 {
204 mma_async = mma_async,
205 sp = sp,
206 sync = sync,
207 aligned = aligned,
208 shape = shape,
209 dtype = dtype,
210 f16 = f16,
211 f162 = f162,
212 d = d,
213 a_desc = a_desc,
214 b_desc = b_desc,
215 sp_meta = sp_meta,
216 sp_sel = sp_sel,
217 scale_d = scale_d,
218 imm_scale_a = imm_scale_a,
219 imm_scale_b = imm_scale_b,
220 imm_trans_a = imm_trans_a,
221 imm_trans_b = imm_trans_b,
222
223 })
224 },
225 )
226 }
227 }
228
229 impl PtxParser for WgmmaMmaAsyncSpSyncAlignedShapeDtypeF16F161 {
230 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
231 try_map(
232 seq_n!(
233 string_p("wgmma"),
234 string_p(".mma_async"),
235 string_p(".sp"),
236 string_p(".sync"),
237 string_p(".aligned"),
238 Shape::parse(),
239 Dtype::parse(),
240 string_p(".f16"),
241 string_p(".f16"),
242 GeneralOperand::parse(),
243 comma_p(),
244 GeneralOperand::parse(),
245 comma_p(),
246 GeneralOperand::parse(),
247 comma_p(),
248 GeneralOperand::parse(),
249 comma_p(),
250 GeneralOperand::parse(),
251 comma_p(),
252 GeneralOperand::parse(),
253 comma_p(),
254 GeneralOperand::parse(),
255 comma_p(),
256 GeneralOperand::parse(),
257 comma_p(),
258 GeneralOperand::parse(),
259 semicolon_p()
260 ),
261 |(
262 _,
263 mma_async,
264 sp,
265 sync,
266 aligned,
267 shape,
268 dtype,
269 f16,
270 f162,
271 d,
272 _,
273 a,
274 _,
275 b_desc,
276 _,
277 sp_meta,
278 _,
279 sp_sel,
280 _,
281 scale_d,
282 _,
283 imm_scale_a,
284 _,
285 imm_scale_b,
286 _,
287 imm_trans_b,
288 _,
289 ),
290 span| {
291 ok!(WgmmaMmaAsyncSpSyncAlignedShapeDtypeF16F161 {
292 mma_async = mma_async,
293 sp = sp,
294 sync = sync,
295 aligned = aligned,
296 shape = shape,
297 dtype = dtype,
298 f16 = f16,
299 f162 = f162,
300 d = d,
301 a = a,
302 b_desc = b_desc,
303 sp_meta = sp_meta,
304 sp_sel = sp_sel,
305 scale_d = scale_d,
306 imm_scale_a = imm_scale_a,
307 imm_scale_b = imm_scale_b,
308 imm_trans_b = imm_trans_b,
309
310 })
311 },
312 )
313 }
314 }
315}
316
317pub mod section_1 {
318 use super::*;
319 use crate::r#type::instruction::wgmma_mma_async_sp::section_1::*;
320
321 impl PtxParser for Dtype {
326 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
327 alt!(map(string_p(".f32"), |_, _span| Dtype::F32))
328 }
329 }
330
331 impl PtxParser for Shape {
332 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
333 alt!(
334 map(string_p(".m64n104k32"), |_, _span| Shape::M64n104k32),
335 map(string_p(".m64n112k32"), |_, _span| Shape::M64n112k32),
336 map(string_p(".m64n120k32"), |_, _span| Shape::M64n120k32),
337 map(string_p(".m64n128k32"), |_, _span| Shape::M64n128k32),
338 map(string_p(".m64n136k32"), |_, _span| Shape::M64n136k32),
339 map(string_p(".m64n144k32"), |_, _span| Shape::M64n144k32),
340 map(string_p(".m64n152k32"), |_, _span| Shape::M64n152k32),
341 map(string_p(".m64n160k32"), |_, _span| Shape::M64n160k32),
342 map(string_p(".m64n168k32"), |_, _span| Shape::M64n168k32),
343 map(string_p(".m64n176k32"), |_, _span| Shape::M64n176k32),
344 map(string_p(".m64n184k32"), |_, _span| Shape::M64n184k32),
345 map(string_p(".m64n192k32"), |_, _span| Shape::M64n192k32),
346 map(string_p(".m64n200k32"), |_, _span| Shape::M64n200k32),
347 map(string_p(".m64n208k32"), |_, _span| Shape::M64n208k32),
348 map(string_p(".m64n216k32"), |_, _span| Shape::M64n216k32),
349 map(string_p(".m64n224k32"), |_, _span| Shape::M64n224k32),
350 map(string_p(".m64n232k32"), |_, _span| Shape::M64n232k32),
351 map(string_p(".m64n240k32"), |_, _span| Shape::M64n240k32),
352 map(string_p(".m64n248k32"), |_, _span| Shape::M64n248k32),
353 map(string_p(".m64n256k32"), |_, _span| Shape::M64n256k32),
354 map(string_p(".m64n16k32"), |_, _span| Shape::M64n16k32),
355 map(string_p(".m64n24k32"), |_, _span| Shape::M64n24k32),
356 map(string_p(".m64n32k32"), |_, _span| Shape::M64n32k32),
357 map(string_p(".m64n40k32"), |_, _span| Shape::M64n40k32),
358 map(string_p(".m64n48k32"), |_, _span| Shape::M64n48k32),
359 map(string_p(".m64n56k32"), |_, _span| Shape::M64n56k32),
360 map(string_p(".m64n64k32"), |_, _span| Shape::M64n64k32),
361 map(string_p(".m64n72k32"), |_, _span| Shape::M64n72k32),
362 map(string_p(".m64n80k32"), |_, _span| Shape::M64n80k32),
363 map(string_p(".m64n88k32"), |_, _span| Shape::M64n88k32),
364 map(string_p(".m64n96k32"), |_, _span| Shape::M64n96k32),
365 map(string_p(".m64n8k32"), |_, _span| Shape::M64n8k32)
366 )
367 }
368 }
369
370 impl PtxParser for WgmmaMmaAsyncSpSyncAlignedShapeDtypeBf16Bf16 {
371 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
372 try_map(
373 seq_n!(
374 string_p("wgmma"),
375 string_p(".mma_async"),
376 string_p(".sp"),
377 string_p(".sync"),
378 string_p(".aligned"),
379 Shape::parse(),
380 Dtype::parse(),
381 string_p(".bf16"),
382 string_p(".bf16"),
383 GeneralOperand::parse(),
384 comma_p(),
385 GeneralOperand::parse(),
386 comma_p(),
387 GeneralOperand::parse(),
388 comma_p(),
389 GeneralOperand::parse(),
390 comma_p(),
391 GeneralOperand::parse(),
392 comma_p(),
393 GeneralOperand::parse(),
394 comma_p(),
395 GeneralOperand::parse(),
396 comma_p(),
397 GeneralOperand::parse(),
398 comma_p(),
399 GeneralOperand::parse(),
400 comma_p(),
401 GeneralOperand::parse(),
402 semicolon_p()
403 ),
404 |(
405 _,
406 mma_async,
407 sp,
408 sync,
409 aligned,
410 shape,
411 dtype,
412 bf16,
413 bf162,
414 d,
415 _,
416 a_desc,
417 _,
418 b_desc,
419 _,
420 sp_meta,
421 _,
422 sp_sel,
423 _,
424 scale_d,
425 _,
426 imm_scale_a,
427 _,
428 imm_scale_b,
429 _,
430 imm_trans_a,
431 _,
432 imm_trans_b,
433 _,
434 ),
435 span| {
436 ok!(WgmmaMmaAsyncSpSyncAlignedShapeDtypeBf16Bf16 {
437 mma_async = mma_async,
438 sp = sp,
439 sync = sync,
440 aligned = aligned,
441 shape = shape,
442 dtype = dtype,
443 bf16 = bf16,
444 bf162 = bf162,
445 d = d,
446 a_desc = a_desc,
447 b_desc = b_desc,
448 sp_meta = sp_meta,
449 sp_sel = sp_sel,
450 scale_d = scale_d,
451 imm_scale_a = imm_scale_a,
452 imm_scale_b = imm_scale_b,
453 imm_trans_a = imm_trans_a,
454 imm_trans_b = imm_trans_b,
455
456 })
457 },
458 )
459 }
460 }
461
462 impl PtxParser for WgmmaMmaAsyncSpSyncAlignedShapeDtypeBf16Bf161 {
463 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
464 try_map(
465 seq_n!(
466 string_p("wgmma"),
467 string_p(".mma_async"),
468 string_p(".sp"),
469 string_p(".sync"),
470 string_p(".aligned"),
471 Shape::parse(),
472 Dtype::parse(),
473 string_p(".bf16"),
474 string_p(".bf16"),
475 GeneralOperand::parse(),
476 comma_p(),
477 GeneralOperand::parse(),
478 comma_p(),
479 GeneralOperand::parse(),
480 comma_p(),
481 GeneralOperand::parse(),
482 comma_p(),
483 GeneralOperand::parse(),
484 comma_p(),
485 GeneralOperand::parse(),
486 comma_p(),
487 GeneralOperand::parse(),
488 comma_p(),
489 GeneralOperand::parse(),
490 comma_p(),
491 GeneralOperand::parse(),
492 semicolon_p()
493 ),
494 |(
495 _,
496 mma_async,
497 sp,
498 sync,
499 aligned,
500 shape,
501 dtype,
502 bf16,
503 bf162,
504 d,
505 _,
506 a,
507 _,
508 b_desc,
509 _,
510 sp_meta,
511 _,
512 sp_sel,
513 _,
514 scale_d,
515 _,
516 imm_scale_a,
517 _,
518 imm_scale_b,
519 _,
520 imm_trans_b,
521 _,
522 ),
523 span| {
524 ok!(WgmmaMmaAsyncSpSyncAlignedShapeDtypeBf16Bf161 {
525 mma_async = mma_async,
526 sp = sp,
527 sync = sync,
528 aligned = aligned,
529 shape = shape,
530 dtype = dtype,
531 bf16 = bf16,
532 bf162 = bf162,
533 d = d,
534 a = a,
535 b_desc = b_desc,
536 sp_meta = sp_meta,
537 sp_sel = sp_sel,
538 scale_d = scale_d,
539 imm_scale_a = imm_scale_a,
540 imm_scale_b = imm_scale_b,
541 imm_trans_b = imm_trans_b,
542
543 })
544 },
545 )
546 }
547 }
548}
549
550pub mod section_2 {
551 use super::*;
552 use crate::r#type::instruction::wgmma_mma_async_sp::section_2::*;
553
554 impl PtxParser for Dtype {
559 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
560 alt!(map(string_p(".f32"), |_, _span| Dtype::F32))
561 }
562 }
563
564 impl PtxParser for Shape {
565 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
566 alt!(
567 map(string_p(".m64n104k16"), |_, _span| Shape::M64n104k16),
568 map(string_p(".m64n112k16"), |_, _span| Shape::M64n112k16),
569 map(string_p(".m64n120k16"), |_, _span| Shape::M64n120k16),
570 map(string_p(".m64n128k16"), |_, _span| Shape::M64n128k16),
571 map(string_p(".m64n136k16"), |_, _span| Shape::M64n136k16),
572 map(string_p(".m64n144k16"), |_, _span| Shape::M64n144k16),
573 map(string_p(".m64n152k16"), |_, _span| Shape::M64n152k16),
574 map(string_p(".m64n160k16"), |_, _span| Shape::M64n160k16),
575 map(string_p(".m64n168k16"), |_, _span| Shape::M64n168k16),
576 map(string_p(".m64n176k16"), |_, _span| Shape::M64n176k16),
577 map(string_p(".m64n184k16"), |_, _span| Shape::M64n184k16),
578 map(string_p(".m64n192k16"), |_, _span| Shape::M64n192k16),
579 map(string_p(".m64n200k16"), |_, _span| Shape::M64n200k16),
580 map(string_p(".m64n208k16"), |_, _span| Shape::M64n208k16),
581 map(string_p(".m64n216k16"), |_, _span| Shape::M64n216k16),
582 map(string_p(".m64n224k16"), |_, _span| Shape::M64n224k16),
583 map(string_p(".m64n232k16"), |_, _span| Shape::M64n232k16),
584 map(string_p(".m64n240k16"), |_, _span| Shape::M64n240k16),
585 map(string_p(".m64n248k16"), |_, _span| Shape::M64n248k16),
586 map(string_p(".m64n256k16"), |_, _span| Shape::M64n256k16),
587 map(string_p(".m64n16k16"), |_, _span| Shape::M64n16k16),
588 map(string_p(".m64n24k16"), |_, _span| Shape::M64n24k16),
589 map(string_p(".m64n32k16"), |_, _span| Shape::M64n32k16),
590 map(string_p(".m64n40k16"), |_, _span| Shape::M64n40k16),
591 map(string_p(".m64n48k16"), |_, _span| Shape::M64n48k16),
592 map(string_p(".m64n56k16"), |_, _span| Shape::M64n56k16),
593 map(string_p(".m64n64k16"), |_, _span| Shape::M64n64k16),
594 map(string_p(".m64n72k16"), |_, _span| Shape::M64n72k16),
595 map(string_p(".m64n80k16"), |_, _span| Shape::M64n80k16),
596 map(string_p(".m64n88k16"), |_, _span| Shape::M64n88k16),
597 map(string_p(".m64n96k16"), |_, _span| Shape::M64n96k16),
598 map(string_p(".m64n8k16"), |_, _span| Shape::M64n8k16)
599 )
600 }
601 }
602
603 impl PtxParser for WgmmaMmaAsyncSpSyncAlignedShapeDtypeTf32Tf32 {
604 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
605 try_map(
606 seq_n!(
607 string_p("wgmma"),
608 string_p(".mma_async"),
609 string_p(".sp"),
610 string_p(".sync"),
611 string_p(".aligned"),
612 Shape::parse(),
613 Dtype::parse(),
614 string_p(".tf32"),
615 string_p(".tf32"),
616 GeneralOperand::parse(),
617 comma_p(),
618 GeneralOperand::parse(),
619 comma_p(),
620 GeneralOperand::parse(),
621 comma_p(),
622 GeneralOperand::parse(),
623 comma_p(),
624 GeneralOperand::parse(),
625 comma_p(),
626 GeneralOperand::parse(),
627 comma_p(),
628 GeneralOperand::parse(),
629 comma_p(),
630 GeneralOperand::parse(),
631 semicolon_p()
632 ),
633 |(
634 _,
635 mma_async,
636 sp,
637 sync,
638 aligned,
639 shape,
640 dtype,
641 tf32,
642 tf322,
643 d,
644 _,
645 a_desc,
646 _,
647 b_desc,
648 _,
649 sp_meta,
650 _,
651 sp_sel,
652 _,
653 scale_d,
654 _,
655 imm_scale_a,
656 _,
657 imm_scale_b,
658 _,
659 ),
660 span| {
661 ok!(WgmmaMmaAsyncSpSyncAlignedShapeDtypeTf32Tf32 {
662 mma_async = mma_async,
663 sp = sp,
664 sync = sync,
665 aligned = aligned,
666 shape = shape,
667 dtype = dtype,
668 tf32 = tf32,
669 tf322 = tf322,
670 d = d,
671 a_desc = a_desc,
672 b_desc = b_desc,
673 sp_meta = sp_meta,
674 sp_sel = sp_sel,
675 scale_d = scale_d,
676 imm_scale_a = imm_scale_a,
677 imm_scale_b = imm_scale_b,
678
679 })
680 },
681 )
682 }
683 }
684
685 impl PtxParser for WgmmaMmaAsyncSpSyncAlignedShapeDtypeTf32Tf321 {
686 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
687 try_map(
688 seq_n!(
689 string_p("wgmma"),
690 string_p(".mma_async"),
691 string_p(".sp"),
692 string_p(".sync"),
693 string_p(".aligned"),
694 Shape::parse(),
695 Dtype::parse(),
696 string_p(".tf32"),
697 string_p(".tf32"),
698 GeneralOperand::parse(),
699 comma_p(),
700 GeneralOperand::parse(),
701 comma_p(),
702 GeneralOperand::parse(),
703 comma_p(),
704 GeneralOperand::parse(),
705 comma_p(),
706 GeneralOperand::parse(),
707 comma_p(),
708 GeneralOperand::parse(),
709 comma_p(),
710 GeneralOperand::parse(),
711 comma_p(),
712 GeneralOperand::parse(),
713 semicolon_p()
714 ),
715 |(
716 _,
717 mma_async,
718 sp,
719 sync,
720 aligned,
721 shape,
722 dtype,
723 tf32,
724 tf322,
725 d,
726 _,
727 a,
728 _,
729 b_desc,
730 _,
731 sp_meta,
732 _,
733 sp_sel,
734 _,
735 scale_d,
736 _,
737 imm_scale_a,
738 _,
739 imm_scale_b,
740 _,
741 ),
742 span| {
743 ok!(WgmmaMmaAsyncSpSyncAlignedShapeDtypeTf32Tf321 {
744 mma_async = mma_async,
745 sp = sp,
746 sync = sync,
747 aligned = aligned,
748 shape = shape,
749 dtype = dtype,
750 tf32 = tf32,
751 tf322 = tf322,
752 d = d,
753 a = a,
754 b_desc = b_desc,
755 sp_meta = sp_meta,
756 sp_sel = sp_sel,
757 scale_d = scale_d,
758 imm_scale_a = imm_scale_a,
759 imm_scale_b = imm_scale_b,
760
761 })
762 },
763 )
764 }
765 }
766}
767
768pub mod section_3 {
769 use super::*;
770 use crate::r#type::instruction::wgmma_mma_async_sp::section_3::*;
771
772 impl PtxParser for Atype {
777 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
778 alt!(
779 map(string_p(".e4m3"), |_, _span| Atype::E4m3),
780 map(string_p(".e5m2"), |_, _span| Atype::E5m2)
781 )
782 }
783 }
784
785 impl PtxParser for Btype {
786 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
787 alt!(
788 map(string_p(".e4m3"), |_, _span| Btype::E4m3),
789 map(string_p(".e5m2"), |_, _span| Btype::E5m2)
790 )
791 }
792 }
793
794 impl PtxParser for Dtype {
795 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
796 alt!(
797 map(string_p(".f16"), |_, _span| Dtype::F16),
798 map(string_p(".f32"), |_, _span| Dtype::F32)
799 )
800 }
801 }
802
803 impl PtxParser for Shape {
804 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
805 alt!(
806 map(string_p(".m64n104k64"), |_, _span| Shape::M64n104k64),
807 map(string_p(".m64n112k64"), |_, _span| Shape::M64n112k64),
808 map(string_p(".m64n120k64"), |_, _span| Shape::M64n120k64),
809 map(string_p(".m64n128k64"), |_, _span| Shape::M64n128k64),
810 map(string_p(".m64n136k64"), |_, _span| Shape::M64n136k64),
811 map(string_p(".m64n144k64"), |_, _span| Shape::M64n144k64),
812 map(string_p(".m64n152k64"), |_, _span| Shape::M64n152k64),
813 map(string_p(".m64n160k64"), |_, _span| Shape::M64n160k64),
814 map(string_p(".m64n168k64"), |_, _span| Shape::M64n168k64),
815 map(string_p(".m64n176k64"), |_, _span| Shape::M64n176k64),
816 map(string_p(".m64n184k64"), |_, _span| Shape::M64n184k64),
817 map(string_p(".m64n192k64"), |_, _span| Shape::M64n192k64),
818 map(string_p(".m64n200k64"), |_, _span| Shape::M64n200k64),
819 map(string_p(".m64n208k64"), |_, _span| Shape::M64n208k64),
820 map(string_p(".m64n216k64"), |_, _span| Shape::M64n216k64),
821 map(string_p(".m64n224k64"), |_, _span| Shape::M64n224k64),
822 map(string_p(".m64n232k64"), |_, _span| Shape::M64n232k64),
823 map(string_p(".m64n240k64"), |_, _span| Shape::M64n240k64),
824 map(string_p(".m64n248k64"), |_, _span| Shape::M64n248k64),
825 map(string_p(".m64n256k64"), |_, _span| Shape::M64n256k64),
826 map(string_p(".m64n16k64"), |_, _span| Shape::M64n16k64),
827 map(string_p(".m64n24k64"), |_, _span| Shape::M64n24k64),
828 map(string_p(".m64n32k64"), |_, _span| Shape::M64n32k64),
829 map(string_p(".m64n40k64"), |_, _span| Shape::M64n40k64),
830 map(string_p(".m64n48k64"), |_, _span| Shape::M64n48k64),
831 map(string_p(".m64n56k64"), |_, _span| Shape::M64n56k64),
832 map(string_p(".m64n64k64"), |_, _span| Shape::M64n64k64),
833 map(string_p(".m64n72k64"), |_, _span| Shape::M64n72k64),
834 map(string_p(".m64n80k64"), |_, _span| Shape::M64n80k64),
835 map(string_p(".m64n88k64"), |_, _span| Shape::M64n88k64),
836 map(string_p(".m64n96k64"), |_, _span| Shape::M64n96k64),
837 map(string_p(".m64n8k64"), |_, _span| Shape::M64n8k64)
838 )
839 }
840 }
841
842 impl PtxParser for WgmmaMmaAsyncSpSyncAlignedShapeDtypeAtypeBtype {
843 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
844 try_map(
845 seq_n!(
846 string_p("wgmma"),
847 string_p(".mma_async"),
848 string_p(".sp"),
849 string_p(".sync"),
850 string_p(".aligned"),
851 Shape::parse(),
852 Dtype::parse(),
853 Atype::parse(),
854 Btype::parse(),
855 GeneralOperand::parse(),
856 comma_p(),
857 GeneralOperand::parse(),
858 comma_p(),
859 GeneralOperand::parse(),
860 comma_p(),
861 GeneralOperand::parse(),
862 comma_p(),
863 GeneralOperand::parse(),
864 comma_p(),
865 GeneralOperand::parse(),
866 comma_p(),
867 GeneralOperand::parse(),
868 comma_p(),
869 GeneralOperand::parse(),
870 semicolon_p()
871 ),
872 |(
873 _,
874 mma_async,
875 sp,
876 sync,
877 aligned,
878 shape,
879 dtype,
880 atype,
881 btype,
882 d,
883 _,
884 a_desc,
885 _,
886 b_desc,
887 _,
888 sp_meta,
889 _,
890 sp_sel,
891 _,
892 scale_d,
893 _,
894 imm_scale_a,
895 _,
896 imm_scale_b,
897 _,
898 ),
899 span| {
900 ok!(WgmmaMmaAsyncSpSyncAlignedShapeDtypeAtypeBtype {
901 mma_async = mma_async,
902 sp = sp,
903 sync = sync,
904 aligned = aligned,
905 shape = shape,
906 dtype = dtype,
907 atype = atype,
908 btype = btype,
909 d = d,
910 a_desc = a_desc,
911 b_desc = b_desc,
912 sp_meta = sp_meta,
913 sp_sel = sp_sel,
914 scale_d = scale_d,
915 imm_scale_a = imm_scale_a,
916 imm_scale_b = imm_scale_b,
917
918 })
919 },
920 )
921 }
922 }
923
924 impl PtxParser for WgmmaMmaAsyncSpSyncAlignedShapeDtypeAtypeBtype1 {
925 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
926 try_map(
927 seq_n!(
928 string_p("wgmma"),
929 string_p(".mma_async"),
930 string_p(".sp"),
931 string_p(".sync"),
932 string_p(".aligned"),
933 Shape::parse(),
934 Dtype::parse(),
935 Atype::parse(),
936 Btype::parse(),
937 GeneralOperand::parse(),
938 comma_p(),
939 GeneralOperand::parse(),
940 comma_p(),
941 GeneralOperand::parse(),
942 comma_p(),
943 GeneralOperand::parse(),
944 comma_p(),
945 GeneralOperand::parse(),
946 comma_p(),
947 GeneralOperand::parse(),
948 comma_p(),
949 GeneralOperand::parse(),
950 comma_p(),
951 GeneralOperand::parse(),
952 semicolon_p()
953 ),
954 |(
955 _,
956 mma_async,
957 sp,
958 sync,
959 aligned,
960 shape,
961 dtype,
962 atype,
963 btype,
964 d,
965 _,
966 a,
967 _,
968 b_desc,
969 _,
970 sp_meta,
971 _,
972 sp_sel,
973 _,
974 scale_d,
975 _,
976 imm_scale_a,
977 _,
978 imm_scale_b,
979 _,
980 ),
981 span| {
982 ok!(WgmmaMmaAsyncSpSyncAlignedShapeDtypeAtypeBtype1 {
983 mma_async = mma_async,
984 sp = sp,
985 sync = sync,
986 aligned = aligned,
987 shape = shape,
988 dtype = dtype,
989 atype = atype,
990 btype = btype,
991 d = d,
992 a = a,
993 b_desc = b_desc,
994 sp_meta = sp_meta,
995 sp_sel = sp_sel,
996 scale_d = scale_d,
997 imm_scale_a = imm_scale_a,
998 imm_scale_b = imm_scale_b,
999
1000 })
1001 },
1002 )
1003 }
1004 }
1005}
1006
1007pub mod section_4 {
1008 use super::*;
1009 use crate::r#type::instruction::wgmma_mma_async_sp::section_4::*;
1010
1011 impl PtxParser for Atype {
1016 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1017 alt!(
1018 map(string_p(".s8"), |_, _span| Atype::S8),
1019 map(string_p(".u8"), |_, _span| Atype::U8)
1020 )
1021 }
1022 }
1023
1024 impl PtxParser for Btype {
1025 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1026 alt!(
1027 map(string_p(".s8"), |_, _span| Btype::S8),
1028 map(string_p(".u8"), |_, _span| Btype::U8)
1029 )
1030 }
1031 }
1032
1033 impl PtxParser for Shape {
1034 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1035 alt!(
1036 map(string_p(".m64n112k64"), |_, _span| Shape::M64n112k64),
1037 map(string_p(".m64n128k64"), |_, _span| Shape::M64n128k64),
1038 map(string_p(".m64n144k64"), |_, _span| Shape::M64n144k64),
1039 map(string_p(".m64n160k64"), |_, _span| Shape::M64n160k64),
1040 map(string_p(".m64n176k64"), |_, _span| Shape::M64n176k64),
1041 map(string_p(".m64n192k64"), |_, _span| Shape::M64n192k64),
1042 map(string_p(".m64n208k64"), |_, _span| Shape::M64n208k64),
1043 map(string_p(".m64n224k64"), |_, _span| Shape::M64n224k64),
1044 map(string_p(".m64n240k64"), |_, _span| Shape::M64n240k64),
1045 map(string_p(".m64n256k64"), |_, _span| Shape::M64n256k64),
1046 map(string_p(".m64n16k64"), |_, _span| Shape::M64n16k64),
1047 map(string_p(".m64n24k64"), |_, _span| Shape::M64n24k64),
1048 map(string_p(".m64n32k64"), |_, _span| Shape::M64n32k64),
1049 map(string_p(".m64n48k64"), |_, _span| Shape::M64n48k64),
1050 map(string_p(".m64n64k64"), |_, _span| Shape::M64n64k64),
1051 map(string_p(".m64n80k64"), |_, _span| Shape::M64n80k64),
1052 map(string_p(".m64n96k64"), |_, _span| Shape::M64n96k64),
1053 map(string_p(".m64n8k64"), |_, _span| Shape::M64n8k64)
1054 )
1055 }
1056 }
1057
1058 impl PtxParser for WgmmaMmaAsyncSpSyncAlignedShapeSatfiniteS32AtypeBtype {
1059 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1060 try_map(
1061 seq_n!(
1062 string_p("wgmma"),
1063 string_p(".mma_async"),
1064 string_p(".sp"),
1065 string_p(".sync"),
1066 string_p(".aligned"),
1067 Shape::parse(),
1068 map(optional(string_p(".satfinite")), |value, _| value.is_some()),
1069 string_p(".s32"),
1070 Atype::parse(),
1071 Btype::parse(),
1072 GeneralOperand::parse(),
1073 comma_p(),
1074 GeneralOperand::parse(),
1075 comma_p(),
1076 GeneralOperand::parse(),
1077 comma_p(),
1078 GeneralOperand::parse(),
1079 comma_p(),
1080 GeneralOperand::parse(),
1081 comma_p(),
1082 GeneralOperand::parse(),
1083 semicolon_p()
1084 ),
1085 |(
1086 _,
1087 mma_async,
1088 sp,
1089 sync,
1090 aligned,
1091 shape,
1092 satfinite,
1093 s32,
1094 atype,
1095 btype,
1096 d,
1097 _,
1098 a_desc,
1099 _,
1100 b_desc,
1101 _,
1102 sp_meta,
1103 _,
1104 sp_sel,
1105 _,
1106 scale_d,
1107 _,
1108 ),
1109 span| {
1110 ok!(WgmmaMmaAsyncSpSyncAlignedShapeSatfiniteS32AtypeBtype {
1111 mma_async = mma_async,
1112 sp = sp,
1113 sync = sync,
1114 aligned = aligned,
1115 shape = shape,
1116 satfinite = satfinite,
1117 s32 = s32,
1118 atype = atype,
1119 btype = btype,
1120 d = d,
1121 a_desc = a_desc,
1122 b_desc = b_desc,
1123 sp_meta = sp_meta,
1124 sp_sel = sp_sel,
1125 scale_d = scale_d,
1126
1127 })
1128 },
1129 )
1130 }
1131 }
1132
1133 impl PtxParser for WgmmaMmaAsyncSpSyncAlignedShapeSatfiniteS32AtypeBtype1 {
1134 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
1135 try_map(
1136 seq_n!(
1137 string_p("wgmma"),
1138 string_p(".mma_async"),
1139 string_p(".sp"),
1140 string_p(".sync"),
1141 string_p(".aligned"),
1142 Shape::parse(),
1143 map(optional(string_p(".satfinite")), |value, _| value.is_some()),
1144 string_p(".s32"),
1145 Atype::parse(),
1146 Btype::parse(),
1147 GeneralOperand::parse(),
1148 comma_p(),
1149 GeneralOperand::parse(),
1150 comma_p(),
1151 GeneralOperand::parse(),
1152 comma_p(),
1153 GeneralOperand::parse(),
1154 comma_p(),
1155 GeneralOperand::parse(),
1156 comma_p(),
1157 GeneralOperand::parse(),
1158 semicolon_p()
1159 ),
1160 |(
1161 _,
1162 mma_async,
1163 sp,
1164 sync,
1165 aligned,
1166 shape,
1167 satfinite,
1168 s32,
1169 atype,
1170 btype,
1171 d,
1172 _,
1173 a,
1174 _,
1175 b_desc,
1176 _,
1177 sp_meta,
1178 _,
1179 sp_sel,
1180 _,
1181 scale_d,
1182 _,
1183 ),
1184 span| {
1185 ok!(WgmmaMmaAsyncSpSyncAlignedShapeSatfiniteS32AtypeBtype1 {
1186 mma_async = mma_async,
1187 sp = sp,
1188 sync = sync,
1189 aligned = aligned,
1190 shape = shape,
1191 satfinite = satfinite,
1192 s32 = s32,
1193 atype = atype,
1194 btype = btype,
1195 d = d,
1196 a = a,
1197 b_desc = b_desc,
1198 sp_meta = sp_meta,
1199 sp_sel = sp_sel,
1200 scale_d = scale_d,
1201
1202 })
1203 },
1204 )
1205 }
1206 }
1207}