1#![allow(unused)]
79
80use crate::parser::{
81 PtxParseError, PtxParser, PtxTokenStream, Span,
82 util::{
83 between, comma_p, directive_p, exclamation_p, lbracket_p, lparen_p, map, minus_p, optional,
84 pipe_p, rbracket_p, rparen_p, semicolon_p, sep_by, string_p, try_map,
85 },
86};
87use crate::r#type::common::*;
88use crate::{alt, ok, seq_n};
89
/// Parsers for the `wgmma.mma_async` forms whose A and B type modifiers are both `.f16`.
///
/// `Dtype`, `Shape` and the two instruction structs parsed here are the ones declared in
/// `crate::r#type::instruction::wgmma_mma_async::section_0`.
pub mod section_0 {
    use super::*;
    use crate::r#type::instruction::wgmma_mma_async::section_0::*;

    /// `.dtype` modifier accepted in this section: `.f16` or `.f32`.
    impl PtxParser for Dtype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".f16"), |_, _span| Dtype::F16),
                map(string_p(".f32"), |_, _span| Dtype::F32)
            )
        }
    }

    /// `.shape` modifier accepted in this section: `.m64nNk16` for every
    /// N = 8, 16, 24, ..., 256 (all multiples of 8).
    impl PtxParser for Shape {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            // Listed with three-digit N first, then two-digit N, then N = 8. None of these
            // literals is a prefix of another, so the order presumably does not change which
            // alternative matches (depends on `alt!` / `string_p` semantics — TODO confirm).
            alt!(
                map(string_p(".m64n104k16"), |_, _span| Shape::M64n104k16),
                map(string_p(".m64n112k16"), |_, _span| Shape::M64n112k16),
                map(string_p(".m64n120k16"), |_, _span| Shape::M64n120k16),
                map(string_p(".m64n128k16"), |_, _span| Shape::M64n128k16),
                map(string_p(".m64n136k16"), |_, _span| Shape::M64n136k16),
                map(string_p(".m64n144k16"), |_, _span| Shape::M64n144k16),
                map(string_p(".m64n152k16"), |_, _span| Shape::M64n152k16),
                map(string_p(".m64n160k16"), |_, _span| Shape::M64n160k16),
                map(string_p(".m64n168k16"), |_, _span| Shape::M64n168k16),
                map(string_p(".m64n176k16"), |_, _span| Shape::M64n176k16),
                map(string_p(".m64n184k16"), |_, _span| Shape::M64n184k16),
                map(string_p(".m64n192k16"), |_, _span| Shape::M64n192k16),
                map(string_p(".m64n200k16"), |_, _span| Shape::M64n200k16),
                map(string_p(".m64n208k16"), |_, _span| Shape::M64n208k16),
                map(string_p(".m64n216k16"), |_, _span| Shape::M64n216k16),
                map(string_p(".m64n224k16"), |_, _span| Shape::M64n224k16),
                map(string_p(".m64n232k16"), |_, _span| Shape::M64n232k16),
                map(string_p(".m64n240k16"), |_, _span| Shape::M64n240k16),
                map(string_p(".m64n248k16"), |_, _span| Shape::M64n248k16),
                map(string_p(".m64n256k16"), |_, _span| Shape::M64n256k16),
                map(string_p(".m64n16k16"), |_, _span| Shape::M64n16k16),
                map(string_p(".m64n24k16"), |_, _span| Shape::M64n24k16),
                map(string_p(".m64n32k16"), |_, _span| Shape::M64n32k16),
                map(string_p(".m64n40k16"), |_, _span| Shape::M64n40k16),
                map(string_p(".m64n48k16"), |_, _span| Shape::M64n48k16),
                map(string_p(".m64n56k16"), |_, _span| Shape::M64n56k16),
                map(string_p(".m64n64k16"), |_, _span| Shape::M64n64k16),
                map(string_p(".m64n72k16"), |_, _span| Shape::M64n72k16),
                map(string_p(".m64n80k16"), |_, _span| Shape::M64n80k16),
                map(string_p(".m64n88k16"), |_, _span| Shape::M64n88k16),
                map(string_p(".m64n96k16"), |_, _span| Shape::M64n96k16),
                map(string_p(".m64n8k16"), |_, _span| Shape::M64n8k16)
            )
        }
    }

    /// Parses the eight-operand form whose second operand is bound to `a_desc`:
    ///
    /// `wgmma .mma_async .sync .aligned <shape> <dtype> .f16 .f16
    ///  d, a_desc, b_desc, scale_d, imm_scale_a, imm_scale_b, imm_trans_a, imm_trans_b;`
    ///
    /// Every operand is parsed with `GeneralOperand::parse()`. The leading `wgmma`, the
    /// commas and the final `;` are matched but not stored in the result.
    impl PtxParser for WgmmaMmaAsyncSyncAlignedShapeDtypeF16F16 {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    string_p("wgmma"),
                    string_p(".mma_async"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    Shape::parse(),
                    Dtype::parse(),
                    // The two fixed `.f16` type modifiers (stored as `f16` and `f162`).
                    string_p(".f16"),
                    string_p(".f16"),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    semicolon_p()
                ),
                // Positional pattern: exactly one binding per `seq_n!` element above, in the
                // same order. `_` slots are the `wgmma` keyword, the commas and the `;`.
                |(
                    _,
                    mma_async,
                    sync,
                    aligned,
                    shape,
                    dtype,
                    f16,
                    f162,
                    d,
                    _,
                    a_desc,
                    _,
                    b_desc,
                    _,
                    scale_d,
                    _,
                    imm_scale_a,
                    _,
                    imm_scale_b,
                    _,
                    imm_trans_a,
                    _,
                    imm_trans_b,
                    _,
                ),
                span| {
                    ok!(WgmmaMmaAsyncSyncAlignedShapeDtypeF16F16 {
                        mma_async = mma_async,
                        sync = sync,
                        aligned = aligned,
                        shape = shape,
                        dtype = dtype,
                        f16 = f16,
                        f162 = f162,
                        d = d,
                        a_desc = a_desc,
                        b_desc = b_desc,
                        scale_d = scale_d,
                        imm_scale_a = imm_scale_a,
                        imm_scale_b = imm_scale_b,
                        imm_trans_a = imm_trans_a,
                        imm_trans_b = imm_trans_b,

                    })
                },
            )
        }
    }

    /// Parses the seven-operand form whose second operand is bound to `a` (not `a_desc`)
    /// and which has no `imm_trans_a` operand:
    ///
    /// `wgmma .mma_async .sync .aligned <shape> <dtype> .f16 .f16
    ///  d, a, b_desc, scale_d, imm_scale_a, imm_scale_b, imm_trans_b;`
    ///
    /// Every operand is parsed with `GeneralOperand::parse()`. The leading `wgmma`, the
    /// commas and the final `;` are matched but not stored in the result.
    impl PtxParser for WgmmaMmaAsyncSyncAlignedShapeDtypeF16F161 {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    string_p("wgmma"),
                    string_p(".mma_async"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    Shape::parse(),
                    Dtype::parse(),
                    // The two fixed `.f16` type modifiers (stored as `f16` and `f162`).
                    string_p(".f16"),
                    string_p(".f16"),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    semicolon_p()
                ),
                // Positional pattern: exactly one binding per `seq_n!` element above, in the
                // same order. `_` slots are the `wgmma` keyword, the commas and the `;`.
                |(
                    _,
                    mma_async,
                    sync,
                    aligned,
                    shape,
                    dtype,
                    f16,
                    f162,
                    d,
                    _,
                    a,
                    _,
                    b_desc,
                    _,
                    scale_d,
                    _,
                    imm_scale_a,
                    _,
                    imm_scale_b,
                    _,
                    imm_trans_b,
                    _,
                ),
                span| {
                    ok!(WgmmaMmaAsyncSyncAlignedShapeDtypeF16F161 {
                        mma_async = mma_async,
                        sync = sync,
                        aligned = aligned,
                        shape = shape,
                        dtype = dtype,
                        f16 = f16,
                        f162 = f162,
                        d = d,
                        a = a,
                        b_desc = b_desc,
                        scale_d = scale_d,
                        imm_scale_a = imm_scale_a,
                        imm_scale_b = imm_scale_b,
                        imm_trans_b = imm_trans_b,

                    })
                },
            )
        }
    }
}
299
/// Parsers for the `wgmma.mma_async` forms whose A and B type modifiers are both `.bf16`.
///
/// `Dtype`, `Shape` and the two instruction structs parsed here are the ones declared in
/// `crate::r#type::instruction::wgmma_mma_async::section_1`.
pub mod section_1 {
    use super::*;
    use crate::r#type::instruction::wgmma_mma_async::section_1::*;

    /// `.dtype` modifier accepted in this section: only `.f32`.
    impl PtxParser for Dtype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            // Single-alternative `alt!`, kept in the same generated shape as the other sections.
            alt!(map(string_p(".f32"), |_, _span| Dtype::F32))
        }
    }

    /// `.shape` modifier accepted in this section: `.m64nNk16` for every
    /// N = 8, 16, 24, ..., 256 (all multiples of 8).
    impl PtxParser for Shape {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            // Listed with three-digit N first, then two-digit N, then N = 8. None of these
            // literals is a prefix of another, so the order presumably does not change which
            // alternative matches (depends on `alt!` / `string_p` semantics — TODO confirm).
            alt!(
                map(string_p(".m64n104k16"), |_, _span| Shape::M64n104k16),
                map(string_p(".m64n112k16"), |_, _span| Shape::M64n112k16),
                map(string_p(".m64n120k16"), |_, _span| Shape::M64n120k16),
                map(string_p(".m64n128k16"), |_, _span| Shape::M64n128k16),
                map(string_p(".m64n136k16"), |_, _span| Shape::M64n136k16),
                map(string_p(".m64n144k16"), |_, _span| Shape::M64n144k16),
                map(string_p(".m64n152k16"), |_, _span| Shape::M64n152k16),
                map(string_p(".m64n160k16"), |_, _span| Shape::M64n160k16),
                map(string_p(".m64n168k16"), |_, _span| Shape::M64n168k16),
                map(string_p(".m64n176k16"), |_, _span| Shape::M64n176k16),
                map(string_p(".m64n184k16"), |_, _span| Shape::M64n184k16),
                map(string_p(".m64n192k16"), |_, _span| Shape::M64n192k16),
                map(string_p(".m64n200k16"), |_, _span| Shape::M64n200k16),
                map(string_p(".m64n208k16"), |_, _span| Shape::M64n208k16),
                map(string_p(".m64n216k16"), |_, _span| Shape::M64n216k16),
                map(string_p(".m64n224k16"), |_, _span| Shape::M64n224k16),
                map(string_p(".m64n232k16"), |_, _span| Shape::M64n232k16),
                map(string_p(".m64n240k16"), |_, _span| Shape::M64n240k16),
                map(string_p(".m64n248k16"), |_, _span| Shape::M64n248k16),
                map(string_p(".m64n256k16"), |_, _span| Shape::M64n256k16),
                map(string_p(".m64n16k16"), |_, _span| Shape::M64n16k16),
                map(string_p(".m64n24k16"), |_, _span| Shape::M64n24k16),
                map(string_p(".m64n32k16"), |_, _span| Shape::M64n32k16),
                map(string_p(".m64n40k16"), |_, _span| Shape::M64n40k16),
                map(string_p(".m64n48k16"), |_, _span| Shape::M64n48k16),
                map(string_p(".m64n56k16"), |_, _span| Shape::M64n56k16),
                map(string_p(".m64n64k16"), |_, _span| Shape::M64n64k16),
                map(string_p(".m64n72k16"), |_, _span| Shape::M64n72k16),
                map(string_p(".m64n80k16"), |_, _span| Shape::M64n80k16),
                map(string_p(".m64n88k16"), |_, _span| Shape::M64n88k16),
                map(string_p(".m64n96k16"), |_, _span| Shape::M64n96k16),
                map(string_p(".m64n8k16"), |_, _span| Shape::M64n8k16)
            )
        }
    }

    /// Parses the eight-operand form whose second operand is bound to `a_desc`:
    ///
    /// `wgmma .mma_async .sync .aligned <shape> <dtype> .bf16 .bf16
    ///  d, a_desc, b_desc, scale_d, imm_scale_a, imm_scale_b, imm_trans_a, imm_trans_b;`
    ///
    /// Every operand is parsed with `GeneralOperand::parse()`. The leading `wgmma`, the
    /// commas and the final `;` are matched but not stored in the result.
    impl PtxParser for WgmmaMmaAsyncSyncAlignedShapeDtypeBf16Bf16 {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    string_p("wgmma"),
                    string_p(".mma_async"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    Shape::parse(),
                    Dtype::parse(),
                    // The two fixed `.bf16` type modifiers (stored as `bf16` and `bf162`).
                    string_p(".bf16"),
                    string_p(".bf16"),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    semicolon_p()
                ),
                // Positional pattern: exactly one binding per `seq_n!` element above, in the
                // same order. `_` slots are the `wgmma` keyword, the commas and the `;`.
                |(
                    _,
                    mma_async,
                    sync,
                    aligned,
                    shape,
                    dtype,
                    bf16,
                    bf162,
                    d,
                    _,
                    a_desc,
                    _,
                    b_desc,
                    _,
                    scale_d,
                    _,
                    imm_scale_a,
                    _,
                    imm_scale_b,
                    _,
                    imm_trans_a,
                    _,
                    imm_trans_b,
                    _,
                ),
                span| {
                    ok!(WgmmaMmaAsyncSyncAlignedShapeDtypeBf16Bf16 {
                        mma_async = mma_async,
                        sync = sync,
                        aligned = aligned,
                        shape = shape,
                        dtype = dtype,
                        bf16 = bf16,
                        bf162 = bf162,
                        d = d,
                        a_desc = a_desc,
                        b_desc = b_desc,
                        scale_d = scale_d,
                        imm_scale_a = imm_scale_a,
                        imm_scale_b = imm_scale_b,
                        imm_trans_a = imm_trans_a,
                        imm_trans_b = imm_trans_b,

                    })
                },
            )
        }
    }

    /// Parses the seven-operand form whose second operand is bound to `a` (not `a_desc`)
    /// and which has no `imm_trans_a` operand:
    ///
    /// `wgmma .mma_async .sync .aligned <shape> <dtype> .bf16 .bf16
    ///  d, a, b_desc, scale_d, imm_scale_a, imm_scale_b, imm_trans_b;`
    ///
    /// Every operand is parsed with `GeneralOperand::parse()`. The leading `wgmma`, the
    /// commas and the final `;` are matched but not stored in the result.
    impl PtxParser for WgmmaMmaAsyncSyncAlignedShapeDtypeBf16Bf161 {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    string_p("wgmma"),
                    string_p(".mma_async"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    Shape::parse(),
                    Dtype::parse(),
                    // The two fixed `.bf16` type modifiers (stored as `bf16` and `bf162`).
                    string_p(".bf16"),
                    string_p(".bf16"),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    semicolon_p()
                ),
                // Positional pattern: exactly one binding per `seq_n!` element above, in the
                // same order. `_` slots are the `wgmma` keyword, the commas and the `;`.
                |(
                    _,
                    mma_async,
                    sync,
                    aligned,
                    shape,
                    dtype,
                    bf16,
                    bf162,
                    d,
                    _,
                    a,
                    _,
                    b_desc,
                    _,
                    scale_d,
                    _,
                    imm_scale_a,
                    _,
                    imm_scale_b,
                    _,
                    imm_trans_b,
                    _,
                ),
                span| {
                    ok!(WgmmaMmaAsyncSyncAlignedShapeDtypeBf16Bf161 {
                        mma_async = mma_async,
                        sync = sync,
                        aligned = aligned,
                        shape = shape,
                        dtype = dtype,
                        bf16 = bf16,
                        bf162 = bf162,
                        d = d,
                        a = a,
                        b_desc = b_desc,
                        scale_d = scale_d,
                        imm_scale_a = imm_scale_a,
                        imm_scale_b = imm_scale_b,
                        imm_trans_b = imm_trans_b,

                    })
                },
            )
        }
    }
}
506
/// Parsers for the `wgmma.mma_async` forms whose A and B type modifiers are both `.tf32`.
///
/// `Dtype`, `Shape` and the two instruction structs parsed here are the ones declared in
/// `crate::r#type::instruction::wgmma_mma_async::section_2`. Unlike the `.f16`/`.bf16`
/// sections, these forms carry no `imm_trans_a` / `imm_trans_b` operands.
pub mod section_2 {
    use super::*;
    use crate::r#type::instruction::wgmma_mma_async::section_2::*;

    /// `.dtype` modifier accepted in this section: only `.f32`.
    impl PtxParser for Dtype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            // Single-alternative `alt!`, kept in the same generated shape as the other sections.
            alt!(map(string_p(".f32"), |_, _span| Dtype::F32))
        }
    }

    /// `.shape` modifier accepted in this section: `.m64nNk8` for every
    /// N = 8, 16, 24, ..., 256 (all multiples of 8).
    impl PtxParser for Shape {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            // Listed with three-digit N first, then two-digit N, then N = 8. None of these
            // literals is a prefix of another, so the order presumably does not change which
            // alternative matches (depends on `alt!` / `string_p` semantics — TODO confirm).
            alt!(
                map(string_p(".m64n104k8"), |_, _span| Shape::M64n104k8),
                map(string_p(".m64n112k8"), |_, _span| Shape::M64n112k8),
                map(string_p(".m64n120k8"), |_, _span| Shape::M64n120k8),
                map(string_p(".m64n128k8"), |_, _span| Shape::M64n128k8),
                map(string_p(".m64n136k8"), |_, _span| Shape::M64n136k8),
                map(string_p(".m64n144k8"), |_, _span| Shape::M64n144k8),
                map(string_p(".m64n152k8"), |_, _span| Shape::M64n152k8),
                map(string_p(".m64n160k8"), |_, _span| Shape::M64n160k8),
                map(string_p(".m64n168k8"), |_, _span| Shape::M64n168k8),
                map(string_p(".m64n176k8"), |_, _span| Shape::M64n176k8),
                map(string_p(".m64n184k8"), |_, _span| Shape::M64n184k8),
                map(string_p(".m64n192k8"), |_, _span| Shape::M64n192k8),
                map(string_p(".m64n200k8"), |_, _span| Shape::M64n200k8),
                map(string_p(".m64n208k8"), |_, _span| Shape::M64n208k8),
                map(string_p(".m64n216k8"), |_, _span| Shape::M64n216k8),
                map(string_p(".m64n224k8"), |_, _span| Shape::M64n224k8),
                map(string_p(".m64n232k8"), |_, _span| Shape::M64n232k8),
                map(string_p(".m64n240k8"), |_, _span| Shape::M64n240k8),
                map(string_p(".m64n248k8"), |_, _span| Shape::M64n248k8),
                map(string_p(".m64n256k8"), |_, _span| Shape::M64n256k8),
                map(string_p(".m64n16k8"), |_, _span| Shape::M64n16k8),
                map(string_p(".m64n24k8"), |_, _span| Shape::M64n24k8),
                map(string_p(".m64n32k8"), |_, _span| Shape::M64n32k8),
                map(string_p(".m64n40k8"), |_, _span| Shape::M64n40k8),
                map(string_p(".m64n48k8"), |_, _span| Shape::M64n48k8),
                map(string_p(".m64n56k8"), |_, _span| Shape::M64n56k8),
                map(string_p(".m64n64k8"), |_, _span| Shape::M64n64k8),
                map(string_p(".m64n72k8"), |_, _span| Shape::M64n72k8),
                map(string_p(".m64n80k8"), |_, _span| Shape::M64n80k8),
                map(string_p(".m64n88k8"), |_, _span| Shape::M64n88k8),
                map(string_p(".m64n96k8"), |_, _span| Shape::M64n96k8),
                map(string_p(".m64n8k8"), |_, _span| Shape::M64n8k8)
            )
        }
    }

    /// Parses the six-operand form whose second operand is bound to `a_desc`:
    ///
    /// `wgmma .mma_async .sync .aligned <shape> <dtype> .tf32 .tf32
    ///  d, a_desc, b_desc, scale_d, imm_scale_a, imm_scale_b;`
    ///
    /// Every operand is parsed with `GeneralOperand::parse()`. The leading `wgmma`, the
    /// commas and the final `;` are matched but not stored in the result.
    impl PtxParser for WgmmaMmaAsyncSyncAlignedShapeDtypeTf32Tf32 {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    string_p("wgmma"),
                    string_p(".mma_async"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    Shape::parse(),
                    Dtype::parse(),
                    // The two fixed `.tf32` type modifiers (stored as `tf32` and `tf322`).
                    string_p(".tf32"),
                    string_p(".tf32"),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    semicolon_p()
                ),
                // Positional pattern: exactly one binding per `seq_n!` element above, in the
                // same order. `_` slots are the `wgmma` keyword, the commas and the `;`.
                |(
                    _,
                    mma_async,
                    sync,
                    aligned,
                    shape,
                    dtype,
                    tf32,
                    tf322,
                    d,
                    _,
                    a_desc,
                    _,
                    b_desc,
                    _,
                    scale_d,
                    _,
                    imm_scale_a,
                    _,
                    imm_scale_b,
                    _,
                ),
                span| {
                    ok!(WgmmaMmaAsyncSyncAlignedShapeDtypeTf32Tf32 {
                        mma_async = mma_async,
                        sync = sync,
                        aligned = aligned,
                        shape = shape,
                        dtype = dtype,
                        tf32 = tf32,
                        tf322 = tf322,
                        d = d,
                        a_desc = a_desc,
                        b_desc = b_desc,
                        scale_d = scale_d,
                        imm_scale_a = imm_scale_a,
                        imm_scale_b = imm_scale_b,

                    })
                },
            )
        }
    }

    /// Parses the six-operand form whose second operand is bound to `a` (not `a_desc`):
    ///
    /// `wgmma .mma_async .sync .aligned <shape> <dtype> .tf32 .tf32
    ///  d, a, b_desc, scale_d, imm_scale_a, imm_scale_b;`
    ///
    /// Every operand is parsed with `GeneralOperand::parse()`. The leading `wgmma`, the
    /// commas and the final `;` are matched but not stored in the result.
    impl PtxParser for WgmmaMmaAsyncSyncAlignedShapeDtypeTf32Tf321 {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    string_p("wgmma"),
                    string_p(".mma_async"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    Shape::parse(),
                    Dtype::parse(),
                    // The two fixed `.tf32` type modifiers (stored as `tf32` and `tf322`).
                    string_p(".tf32"),
                    string_p(".tf32"),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    semicolon_p()
                ),
                // Positional pattern: exactly one binding per `seq_n!` element above, in the
                // same order. `_` slots are the `wgmma` keyword, the commas and the `;`.
                |(
                    _,
                    mma_async,
                    sync,
                    aligned,
                    shape,
                    dtype,
                    tf32,
                    tf322,
                    d,
                    _,
                    a,
                    _,
                    b_desc,
                    _,
                    scale_d,
                    _,
                    imm_scale_a,
                    _,
                    imm_scale_b,
                    _,
                ),
                span| {
                    ok!(WgmmaMmaAsyncSyncAlignedShapeDtypeTf32Tf321 {
                        mma_async = mma_async,
                        sync = sync,
                        aligned = aligned,
                        shape = shape,
                        dtype = dtype,
                        tf32 = tf32,
                        tf322 = tf322,
                        d = d,
                        a = a,
                        b_desc = b_desc,
                        scale_d = scale_d,
                        imm_scale_a = imm_scale_a,
                        imm_scale_b = imm_scale_b,

                    })
                },
            )
        }
    }
}
698
/// Parsers for the `wgmma.mma_async` forms whose A and B type modifiers are each
/// `.e4m3` or `.e5m2`.
///
/// `Atype`, `Btype`, `Dtype`, `Shape` and the two instruction structs parsed here are the
/// ones declared in `crate::r#type::instruction::wgmma_mma_async::section_3`. These forms
/// carry no `imm_trans_a` / `imm_trans_b` operands.
pub mod section_3 {
    use super::*;
    use crate::r#type::instruction::wgmma_mma_async::section_3::*;

    /// `.atype` modifier accepted in this section: `.e4m3` or `.e5m2`.
    impl PtxParser for Atype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".e4m3"), |_, _span| Atype::E4m3),
                map(string_p(".e5m2"), |_, _span| Atype::E5m2)
            )
        }
    }

    /// `.btype` modifier accepted in this section: `.e4m3` or `.e5m2`.
    impl PtxParser for Btype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".e4m3"), |_, _span| Btype::E4m3),
                map(string_p(".e5m2"), |_, _span| Btype::E5m2)
            )
        }
    }

    /// `.dtype` modifier accepted in this section: `.f16` or `.f32`.
    impl PtxParser for Dtype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".f16"), |_, _span| Dtype::F16),
                map(string_p(".f32"), |_, _span| Dtype::F32)
            )
        }
    }

    /// `.shape` modifier accepted in this section: `.m64nNk32` for every
    /// N = 8, 16, 24, ..., 256 (all multiples of 8).
    impl PtxParser for Shape {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            // Listed with three-digit N first, then two-digit N, then N = 8. None of these
            // literals is a prefix of another, so the order presumably does not change which
            // alternative matches (depends on `alt!` / `string_p` semantics — TODO confirm).
            alt!(
                map(string_p(".m64n104k32"), |_, _span| Shape::M64n104k32),
                map(string_p(".m64n112k32"), |_, _span| Shape::M64n112k32),
                map(string_p(".m64n120k32"), |_, _span| Shape::M64n120k32),
                map(string_p(".m64n128k32"), |_, _span| Shape::M64n128k32),
                map(string_p(".m64n136k32"), |_, _span| Shape::M64n136k32),
                map(string_p(".m64n144k32"), |_, _span| Shape::M64n144k32),
                map(string_p(".m64n152k32"), |_, _span| Shape::M64n152k32),
                map(string_p(".m64n160k32"), |_, _span| Shape::M64n160k32),
                map(string_p(".m64n168k32"), |_, _span| Shape::M64n168k32),
                map(string_p(".m64n176k32"), |_, _span| Shape::M64n176k32),
                map(string_p(".m64n184k32"), |_, _span| Shape::M64n184k32),
                map(string_p(".m64n192k32"), |_, _span| Shape::M64n192k32),
                map(string_p(".m64n200k32"), |_, _span| Shape::M64n200k32),
                map(string_p(".m64n208k32"), |_, _span| Shape::M64n208k32),
                map(string_p(".m64n216k32"), |_, _span| Shape::M64n216k32),
                map(string_p(".m64n224k32"), |_, _span| Shape::M64n224k32),
                map(string_p(".m64n232k32"), |_, _span| Shape::M64n232k32),
                map(string_p(".m64n240k32"), |_, _span| Shape::M64n240k32),
                map(string_p(".m64n248k32"), |_, _span| Shape::M64n248k32),
                map(string_p(".m64n256k32"), |_, _span| Shape::M64n256k32),
                map(string_p(".m64n16k32"), |_, _span| Shape::M64n16k32),
                map(string_p(".m64n24k32"), |_, _span| Shape::M64n24k32),
                map(string_p(".m64n32k32"), |_, _span| Shape::M64n32k32),
                map(string_p(".m64n40k32"), |_, _span| Shape::M64n40k32),
                map(string_p(".m64n48k32"), |_, _span| Shape::M64n48k32),
                map(string_p(".m64n56k32"), |_, _span| Shape::M64n56k32),
                map(string_p(".m64n64k32"), |_, _span| Shape::M64n64k32),
                map(string_p(".m64n72k32"), |_, _span| Shape::M64n72k32),
                map(string_p(".m64n80k32"), |_, _span| Shape::M64n80k32),
                map(string_p(".m64n88k32"), |_, _span| Shape::M64n88k32),
                map(string_p(".m64n96k32"), |_, _span| Shape::M64n96k32),
                map(string_p(".m64n8k32"), |_, _span| Shape::M64n8k32)
            )
        }
    }

    /// Parses the six-operand form whose second operand is bound to `a_desc`:
    ///
    /// `wgmma .mma_async .sync .aligned <shape> <dtype> <atype> <btype>
    ///  d, a_desc, b_desc, scale_d, imm_scale_a, imm_scale_b;`
    ///
    /// Every operand is parsed with `GeneralOperand::parse()`. The leading `wgmma`, the
    /// commas and the final `;` are matched but not stored in the result.
    impl PtxParser for WgmmaMmaAsyncSyncAlignedShapeDtypeAtypeBtype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    string_p("wgmma"),
                    string_p(".mma_async"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    Shape::parse(),
                    Dtype::parse(),
                    Atype::parse(),
                    Btype::parse(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    semicolon_p()
                ),
                // Positional pattern: exactly one binding per `seq_n!` element above, in the
                // same order. `_` slots are the `wgmma` keyword, the commas and the `;`.
                |(
                    _,
                    mma_async,
                    sync,
                    aligned,
                    shape,
                    dtype,
                    atype,
                    btype,
                    d,
                    _,
                    a_desc,
                    _,
                    b_desc,
                    _,
                    scale_d,
                    _,
                    imm_scale_a,
                    _,
                    imm_scale_b,
                    _,
                ),
                span| {
                    ok!(WgmmaMmaAsyncSyncAlignedShapeDtypeAtypeBtype {
                        mma_async = mma_async,
                        sync = sync,
                        aligned = aligned,
                        shape = shape,
                        dtype = dtype,
                        atype = atype,
                        btype = btype,
                        d = d,
                        a_desc = a_desc,
                        b_desc = b_desc,
                        scale_d = scale_d,
                        imm_scale_a = imm_scale_a,
                        imm_scale_b = imm_scale_b,

                    })
                },
            )
        }
    }

    /// Parses the six-operand form whose second operand is bound to `a` (not `a_desc`):
    ///
    /// `wgmma .mma_async .sync .aligned <shape> <dtype> <atype> <btype>
    ///  d, a, b_desc, scale_d, imm_scale_a, imm_scale_b;`
    ///
    /// Every operand is parsed with `GeneralOperand::parse()`. The leading `wgmma`, the
    /// commas and the final `;` are matched but not stored in the result.
    impl PtxParser for WgmmaMmaAsyncSyncAlignedShapeDtypeAtypeBtype1 {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    string_p("wgmma"),
                    string_p(".mma_async"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    Shape::parse(),
                    Dtype::parse(),
                    Atype::parse(),
                    Btype::parse(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    semicolon_p()
                ),
                // Positional pattern: exactly one binding per `seq_n!` element above, in the
                // same order. `_` slots are the `wgmma` keyword, the commas and the `;`.
                |(
                    _,
                    mma_async,
                    sync,
                    aligned,
                    shape,
                    dtype,
                    atype,
                    btype,
                    d,
                    _,
                    a,
                    _,
                    b_desc,
                    _,
                    scale_d,
                    _,
                    imm_scale_a,
                    _,
                    imm_scale_b,
                    _,
                ),
                span| {
                    ok!(WgmmaMmaAsyncSyncAlignedShapeDtypeAtypeBtype1 {
                        mma_async = mma_async,
                        sync = sync,
                        aligned = aligned,
                        shape = shape,
                        dtype = dtype,
                        atype = atype,
                        btype = btype,
                        d = d,
                        a = a,
                        b_desc = b_desc,
                        scale_d = scale_d,
                        imm_scale_a = imm_scale_a,
                        imm_scale_b = imm_scale_b,

                    })
                },
            )
        }
    }
}
911
/// Parsers for the integer `wgmma.mma_async` forms: optional `.satfinite`, a fixed `.s32`,
/// and A/B type modifiers that are each `.s8` or `.u8`.
///
/// `Atype`, `Btype`, `Shape` and the two instruction structs parsed here are the ones
/// declared in `crate::r#type::instruction::wgmma_mma_async::section_4`. These forms take
/// only four operands (no `imm_scale_*` / `imm_trans_*`).
pub mod section_4 {
    use super::*;
    use crate::r#type::instruction::wgmma_mma_async::section_4::*;

    /// `.atype` modifier accepted in this section: `.s8` or `.u8`.
    impl PtxParser for Atype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".s8"), |_, _span| Atype::S8),
                map(string_p(".u8"), |_, _span| Atype::U8)
            )
        }
    }

    /// `.btype` modifier accepted in this section: `.s8` or `.u8`.
    impl PtxParser for Btype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            alt!(
                map(string_p(".s8"), |_, _span| Btype::S8),
                map(string_p(".u8"), |_, _span| Btype::U8)
            )
        }
    }

    /// `.shape` modifier accepted in this section: `.m64nNk32` for
    /// N = 8, 16, 24, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224.
    ///
    /// NOTE(review): unlike the `.k256` list in `section_5`, this list stops at N = 224
    /// (no `.m64n240k32` / `.m64n256k32`). Verify against the PTX ISA `wgmma` matrix-shape
    /// table; adding shapes would also need new `Shape` variants in the type module.
    impl PtxParser for Shape {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            // Listed with three-digit N first, then two-digit N, then N = 8. None of these
            // literals is a prefix of another, so the order presumably does not change which
            // alternative matches (depends on `alt!` / `string_p` semantics — TODO confirm).
            alt!(
                map(string_p(".m64n112k32"), |_, _span| Shape::M64n112k32),
                map(string_p(".m64n128k32"), |_, _span| Shape::M64n128k32),
                map(string_p(".m64n144k32"), |_, _span| Shape::M64n144k32),
                map(string_p(".m64n160k32"), |_, _span| Shape::M64n160k32),
                map(string_p(".m64n176k32"), |_, _span| Shape::M64n176k32),
                map(string_p(".m64n192k32"), |_, _span| Shape::M64n192k32),
                map(string_p(".m64n208k32"), |_, _span| Shape::M64n208k32),
                map(string_p(".m64n224k32"), |_, _span| Shape::M64n224k32),
                map(string_p(".m64n16k32"), |_, _span| Shape::M64n16k32),
                map(string_p(".m64n24k32"), |_, _span| Shape::M64n24k32),
                map(string_p(".m64n32k32"), |_, _span| Shape::M64n32k32),
                map(string_p(".m64n48k32"), |_, _span| Shape::M64n48k32),
                map(string_p(".m64n64k32"), |_, _span| Shape::M64n64k32),
                map(string_p(".m64n80k32"), |_, _span| Shape::M64n80k32),
                map(string_p(".m64n96k32"), |_, _span| Shape::M64n96k32),
                map(string_p(".m64n8k32"), |_, _span| Shape::M64n8k32)
            )
        }
    }

    /// Parses the four-operand form whose second operand is bound to `a_desc`:
    ///
    /// `wgmma .mma_async .sync .aligned <shape> [.satfinite] .s32 <atype> <btype>
    ///  d, a_desc, b_desc, scale_d;`
    ///
    /// Every operand is parsed with `GeneralOperand::parse()`. The leading `wgmma`, the
    /// commas and the final `;` are matched but not stored in the result.
    impl PtxParser for WgmmaMmaAsyncSyncAlignedShapeSatfiniteS32AtypeBtype {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    string_p("wgmma"),
                    string_p(".mma_async"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    Shape::parse(),
                    // `.satfinite` is optional; stored as a bool (`true` when present).
                    map(optional(string_p(".satfinite")), |value, _| value.is_some()),
                    string_p(".s32"),
                    Atype::parse(),
                    Btype::parse(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    semicolon_p()
                ),
                // Positional pattern: exactly one binding per `seq_n!` element above, in the
                // same order. `_` slots are the `wgmma` keyword, the commas and the `;`.
                |(
                    _,
                    mma_async,
                    sync,
                    aligned,
                    shape,
                    satfinite,
                    s32,
                    atype,
                    btype,
                    d,
                    _,
                    a_desc,
                    _,
                    b_desc,
                    _,
                    scale_d,
                    _,
                ),
                span| {
                    ok!(WgmmaMmaAsyncSyncAlignedShapeSatfiniteS32AtypeBtype {
                        mma_async = mma_async,
                        sync = sync,
                        aligned = aligned,
                        shape = shape,
                        satfinite = satfinite,
                        s32 = s32,
                        atype = atype,
                        btype = btype,
                        d = d,
                        a_desc = a_desc,
                        b_desc = b_desc,
                        scale_d = scale_d,

                    })
                },
            )
        }
    }

    /// Parses the four-operand form whose second operand is bound to `a` (not `a_desc`):
    ///
    /// `wgmma .mma_async .sync .aligned <shape> [.satfinite] .s32 <atype> <btype>
    ///  d, a, b_desc, scale_d;`
    ///
    /// Every operand is parsed with `GeneralOperand::parse()`. The leading `wgmma`, the
    /// commas and the final `;` are matched but not stored in the result.
    impl PtxParser for WgmmaMmaAsyncSyncAlignedShapeSatfiniteS32AtypeBtype1 {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    string_p("wgmma"),
                    string_p(".mma_async"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    Shape::parse(),
                    // `.satfinite` is optional; stored as a bool (`true` when present).
                    map(optional(string_p(".satfinite")), |value, _| value.is_some()),
                    string_p(".s32"),
                    Atype::parse(),
                    Btype::parse(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    semicolon_p()
                ),
                // Positional pattern: exactly one binding per `seq_n!` element above, in the
                // same order. `_` slots are the `wgmma` keyword, the commas and the `;`.
                |(
                    _,
                    mma_async,
                    sync,
                    aligned,
                    shape,
                    satfinite,
                    s32,
                    atype,
                    btype,
                    d,
                    _,
                    a,
                    _,
                    b_desc,
                    _,
                    scale_d,
                    _,
                ),
                span| {
                    ok!(WgmmaMmaAsyncSyncAlignedShapeSatfiniteS32AtypeBtype1 {
                        mma_async = mma_async,
                        sync = sync,
                        aligned = aligned,
                        shape = shape,
                        satfinite = satfinite,
                        s32 = s32,
                        atype = atype,
                        btype = btype,
                        d = d,
                        a = a,
                        b_desc = b_desc,
                        scale_d = scale_d,

                    })
                },
            )
        }
    }
}
1085
/// Parsers for the single-bit `wgmma.mma_async` forms:
/// `.s32 .b1 .b1 <op> .popc`.
///
/// `Op`, `Shape` and the two instruction structs parsed here are the ones declared in
/// `crate::r#type::instruction::wgmma_mma_async::section_5`. These forms take only four
/// operands (no `imm_scale_*` / `imm_trans_*`).
pub mod section_5 {
    use super::*;
    use crate::r#type::instruction::wgmma_mma_async::section_5::*;

    /// `.op` modifier accepted in this section: only `.and`.
    impl PtxParser for Op {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            // Single-alternative `alt!`, kept in the same generated shape as the other sections.
            alt!(map(string_p(".and"), |_, _span| Op::And))
        }
    }

    /// `.shape` modifier accepted in this section: `.m64nNk256` for
    /// N = 8, 16, 24, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256.
    impl PtxParser for Shape {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            // Listed with three-digit N first, then two-digit N, then N = 8. None of these
            // literals is a prefix of another, so the order presumably does not change which
            // alternative matches (depends on `alt!` / `string_p` semantics — TODO confirm).
            alt!(
                map(string_p(".m64n112k256"), |_, _span| Shape::M64n112k256),
                map(string_p(".m64n128k256"), |_, _span| Shape::M64n128k256),
                map(string_p(".m64n144k256"), |_, _span| Shape::M64n144k256),
                map(string_p(".m64n160k256"), |_, _span| Shape::M64n160k256),
                map(string_p(".m64n176k256"), |_, _span| Shape::M64n176k256),
                map(string_p(".m64n192k256"), |_, _span| Shape::M64n192k256),
                map(string_p(".m64n208k256"), |_, _span| Shape::M64n208k256),
                map(string_p(".m64n224k256"), |_, _span| Shape::M64n224k256),
                map(string_p(".m64n240k256"), |_, _span| Shape::M64n240k256),
                map(string_p(".m64n256k256"), |_, _span| Shape::M64n256k256),
                map(string_p(".m64n16k256"), |_, _span| Shape::M64n16k256),
                map(string_p(".m64n24k256"), |_, _span| Shape::M64n24k256),
                map(string_p(".m64n32k256"), |_, _span| Shape::M64n32k256),
                map(string_p(".m64n48k256"), |_, _span| Shape::M64n48k256),
                map(string_p(".m64n64k256"), |_, _span| Shape::M64n64k256),
                map(string_p(".m64n80k256"), |_, _span| Shape::M64n80k256),
                map(string_p(".m64n96k256"), |_, _span| Shape::M64n96k256),
                map(string_p(".m64n8k256"), |_, _span| Shape::M64n8k256)
            )
        }
    }

    /// Parses the four-operand form whose second operand is bound to `a_desc`:
    ///
    /// `wgmma .mma_async .sync .aligned <shape> .s32 .b1 .b1 <op> .popc
    ///  d, a_desc, b_desc, scale_d;`
    ///
    /// Every operand is parsed with `GeneralOperand::parse()`. The leading `wgmma`, the
    /// commas and the final `;` are matched but not stored in the result.
    impl PtxParser for WgmmaMmaAsyncSyncAlignedShapeS32B1B1OpPopc {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    string_p("wgmma"),
                    string_p(".mma_async"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    Shape::parse(),
                    string_p(".s32"),
                    // The two fixed `.b1` type modifiers (stored as `b1` and `b12`).
                    string_p(".b1"),
                    string_p(".b1"),
                    Op::parse(),
                    string_p(".popc"),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    semicolon_p()
                ),
                // Positional pattern: exactly one binding per `seq_n!` element above, in the
                // same order. `_` slots are the `wgmma` keyword, the commas and the `;`.
                |(
                    _,
                    mma_async,
                    sync,
                    aligned,
                    shape,
                    s32,
                    b1,
                    b12,
                    op,
                    popc,
                    d,
                    _,
                    a_desc,
                    _,
                    b_desc,
                    _,
                    scale_d,
                    _,
                ),
                span| {
                    ok!(WgmmaMmaAsyncSyncAlignedShapeS32B1B1OpPopc {
                        mma_async = mma_async,
                        sync = sync,
                        aligned = aligned,
                        shape = shape,
                        s32 = s32,
                        b1 = b1,
                        b12 = b12,
                        op = op,
                        popc = popc,
                        d = d,
                        a_desc = a_desc,
                        b_desc = b_desc,
                        scale_d = scale_d,

                    })
                },
            )
        }
    }

    /// Parses the four-operand form whose second operand is bound to `a` (not `a_desc`):
    ///
    /// `wgmma .mma_async .sync .aligned <shape> .s32 .b1 .b1 <op> .popc
    ///  d, a, b_desc, scale_d;`
    ///
    /// Every operand is parsed with `GeneralOperand::parse()`. The leading `wgmma`, the
    /// commas and the final `;` are matched but not stored in the result.
    impl PtxParser for WgmmaMmaAsyncSyncAlignedShapeS32B1B1OpPopc1 {
        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
            try_map(
                seq_n!(
                    string_p("wgmma"),
                    string_p(".mma_async"),
                    string_p(".sync"),
                    string_p(".aligned"),
                    Shape::parse(),
                    string_p(".s32"),
                    // The two fixed `.b1` type modifiers (stored as `b1` and `b12`).
                    string_p(".b1"),
                    string_p(".b1"),
                    Op::parse(),
                    string_p(".popc"),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    comma_p(),
                    GeneralOperand::parse(),
                    semicolon_p()
                ),
                // Positional pattern: exactly one binding per `seq_n!` element above, in the
                // same order. `_` slots are the `wgmma` keyword, the commas and the `;`.
                |(
                    _,
                    mma_async,
                    sync,
                    aligned,
                    shape,
                    s32,
                    b1,
                    b12,
                    op,
                    popc,
                    d,
                    _,
                    a,
                    _,
                    b_desc,
                    _,
                    scale_d,
                    _,
                ),
                span| {
                    ok!(WgmmaMmaAsyncSyncAlignedShapeS32B1B1OpPopc1 {
                        mma_async = mma_async,
                        sync = sync,
                        aligned = aligned,
                        shape = shape,
                        s32 = s32,
                        b1 = b1,
                        b12 = b12,
                        op = op,
                        popc = popc,
                        d = d,
                        a = a,
                        b_desc = b_desc,
                        scale_d = scale_d,

                    })
                },
            )
        }
    }
}