1use crate::arch::SmVersion;
10use crate::error::PtxGenError;
11use crate::ir::{
12 CacheQualifier, CmpOp, FenceScope, ImmValue, Instruction, MemorySpace, MmaShape, MulMode,
13 Operand, PtxType, Register, RegisterAllocator, RoundingMode, SpecialReg, VectorWidth,
14};
15
/// Incrementally builds the instruction stream for one kernel body.
///
/// Borrows the register allocator and instruction buffer from the
/// enclosing kernel builder; helper methods append IR instructions in
/// program order and return freshly allocated destination registers.
pub struct BodyBuilder<'a> {
    /// Allocator for the virtual registers referenced by emitted instructions.
    pub(super) regs: &'a mut RegisterAllocator,
    /// Output buffer; `emit` appends here in program order.
    pub(super) instructions: &'a mut Vec<Instruction>,
    /// Monotonic counter backing `fresh_label`.
    label_counter: u32,
    /// Kernel parameter names; queried by `has_param`.
    param_names: &'a [String],
    /// Target SM architecture; gates feature checks (ldmatrix, wgmma).
    pub(super) target: SmVersion,
}
38
39impl<'a> BodyBuilder<'a> {
    /// Creates a builder that appends to `instructions`, allocating from
    /// `regs`, for the given target SM. Label numbering starts at zero.
    pub(crate) const fn new(
        regs: &'a mut RegisterAllocator,
        instructions: &'a mut Vec<Instruction>,
        param_names: &'a [String],
        target: SmVersion,
    ) -> Self {
        Self {
            regs,
            instructions,
            label_counter: 0,
            param_names,
            target,
        }
    }
60
    /// Loads a `u32` kernel parameter into a fresh register.
    pub fn load_param_u32(&mut self, name: &str) -> Register {
        self.load_param(name, PtxType::U32)
    }

    /// Loads a `u64` kernel parameter (e.g. a pointer) into a fresh register.
    pub fn load_param_u64(&mut self, name: &str) -> Register {
        self.load_param(name, PtxType::U64)
    }

    /// Loads an `f32` kernel parameter into a fresh register.
    pub fn load_param_f32(&mut self, name: &str) -> Register {
        self.load_param(name, PtxType::F32)
    }

    /// Loads an `f64` kernel parameter into a fresh register.
    pub fn load_param_f64(&mut self, name: &str) -> Register {
        self.load_param(name, PtxType::F64)
    }

    /// Shared lowering: emits `LoadParam` from the `%param_<name>` slot
    /// (naming convention must match the kernel-prologue generator).
    fn load_param(&mut self, name: &str, ty: PtxType) -> Register {
        let dst = self.regs.alloc(ty);
        self.emit(Instruction::LoadParam {
            ty,
            dst: dst.clone(),
            param_name: format!("%param_{name}"),
        });
        dst
    }
106
107 pub fn global_thread_id_x(&mut self) -> Register {
122 let tid = self.read_special_reg(SpecialReg::TidX);
123 let ntid = self.read_special_reg(SpecialReg::NtidX);
124 let ctaid = self.read_special_reg(SpecialReg::CtaidX);
125 let gid = self.regs.alloc(PtxType::U32);
126 self.emit(Instruction::Mad {
127 ty: PtxType::U32,
128 mode: MulMode::Lo,
129 dst: gid.clone(),
130 a: Operand::Register(ctaid),
131 b: Operand::Register(ntid),
132 c: Operand::Register(tid),
133 });
134 gid
135 }
136
137 pub fn global_thread_id_y(&mut self) -> Register {
141 let tid = self.read_special_reg(SpecialReg::TidY);
142 let ntid = self.read_special_reg(SpecialReg::NtidY);
143 let ctaid = self.read_special_reg(SpecialReg::CtaidY);
144 let gid = self.regs.alloc(PtxType::U32);
145 self.emit(Instruction::Mad {
146 ty: PtxType::U32,
147 mode: MulMode::Lo,
148 dst: gid.clone(),
149 a: Operand::Register(ctaid),
150 b: Operand::Register(ntid),
151 c: Operand::Register(tid),
152 });
153 gid
154 }
155
156 pub fn global_thread_id_2d(&mut self) -> (Register, Register) {
161 let col = self.global_thread_id_x();
162 let row = self.global_thread_id_y();
163 (row, col)
164 }
165
    /// Reads `%tid.x` (thread index within the block) into a fresh register.
    pub fn thread_id_x(&mut self) -> Register {
        self.read_special_reg(SpecialReg::TidX)
    }

    /// Reads `%ctaid.x` (block index within the grid) into a fresh register.
    pub fn block_id_x(&mut self) -> Register {
        self.read_special_reg(SpecialReg::CtaidX)
    }

    /// Reads `%ntid.x` (threads per block) into a fresh register.
    pub fn block_dim_x(&mut self) -> Register {
        self.read_special_reg(SpecialReg::NtidX)
    }

    /// Emits a `mov` from the given special register into a fresh u32.
    fn read_special_reg(&mut self, sreg: SpecialReg) -> Register {
        let dst = self.regs.alloc(PtxType::U32);
        self.emit(Instruction::MovSpecial {
            dst: dst.clone(),
            special: sreg,
        });
        dst
    }
190
    /// Emits `add.u32` and returns the destination register.
    pub fn add_u32(&mut self, a: Register, b: Register) -> Register {
        self.add_typed(PtxType::U32, a, b)
    }

    /// Emits `add.u64` (typically pointer arithmetic) and returns the result.
    pub fn add_u64(&mut self, a: Register, b: Register) -> Register {
        self.add_typed(PtxType::U64, a, b)
    }

    /// Emits `add.f32` and returns the destination register.
    pub fn add_f32(&mut self, a: Register, b: Register) -> Register {
        self.add_typed(PtxType::F32, a, b)
    }

    /// Emits `add.f64` and returns the destination register.
    pub fn add_f64(&mut self, a: Register, b: Register) -> Register {
        self.add_typed(PtxType::F64, a, b)
    }

    /// Shared lowering for `add.<ty>` with two register operands.
    fn add_typed(&mut self, ty: PtxType, a: Register, b: Register) -> Register {
        let dst = self.regs.alloc(ty);
        self.emit(Instruction::Add {
            ty,
            dst: dst.clone(),
            a: Operand::Register(a),
            b: Operand::Register(b),
        });
        dst
    }

    /// Emits `sub.f32` (`a - b`) and returns the destination register.
    pub fn sub_f32(&mut self, a: Register, b: Register) -> Register {
        self.sub_typed(PtxType::F32, a, b)
    }

    /// Emits `sub.f64` (`a - b`) and returns the destination register.
    pub fn sub_f64(&mut self, a: Register, b: Register) -> Register {
        self.sub_typed(PtxType::F64, a, b)
    }

    /// Shared lowering for `sub.<ty>` with two register operands.
    fn sub_typed(&mut self, ty: PtxType, a: Register, b: Register) -> Register {
        let dst = self.regs.alloc(ty);
        self.emit(Instruction::Sub {
            ty,
            dst: dst.clone(),
            a: Operand::Register(a),
            b: Operand::Register(b),
        });
        dst
    }
248
    /// Emits `mul.lo.u32` (low 32 bits of the product).
    pub fn mul_lo_u32(&mut self, a: Register, b: Register) -> Register {
        let dst = self.regs.alloc(PtxType::U32);
        self.emit(Instruction::Mul {
            ty: PtxType::U32,
            mode: MulMode::Lo,
            dst: dst.clone(),
            a: Operand::Register(a),
            b: Operand::Register(b),
        });
        dst
    }

    /// Emits `mul.wide.u32`: full 64-bit product of two u32 operands.
    /// Note the instruction type is the *source* type (u32); the
    /// destination register is u64.
    pub fn mul_wide_u32_to_u64(&mut self, a: Register, b: Register) -> Register {
        let dst = self.regs.alloc(PtxType::U64);
        self.emit(Instruction::Mul {
            ty: PtxType::U32,
            mode: MulMode::Wide,
            dst: dst.clone(),
            a: Operand::Register(a),
            b: Operand::Register(b),
        });
        dst
    }
275
    /// Emits `mad.lo.s32`: `a * b + c`, keeping the low 32 bits.
    pub fn mad_lo_s32(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.mad_lo_typed(PtxType::S32, a, b, c)
    }

    /// Emits `mad.lo.u32`: `a * b + c`, keeping the low 32 bits.
    pub fn mad_lo_u32(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.mad_lo_typed(PtxType::U32, a, b, c)
    }

    /// Emits `mad.lo.s64`: `a * b + c`, keeping the low 64 bits.
    pub fn mad_lo_s64(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.mad_lo_typed(PtxType::S64, a, b, c)
    }

    /// Emits `mad.lo.u64`: `a * b + c`, keeping the low 64 bits.
    pub fn mad_lo_u64(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.mad_lo_typed(PtxType::U64, a, b, c)
    }

    /// Shared lowering for `mad.lo.<typ>` with three register operands.
    fn mad_lo_typed(&mut self, typ: PtxType, a: Register, b: Register, c: Register) -> Register {
        let dst = self.regs.alloc(typ);
        self.emit(Instruction::MadLo {
            typ,
            dst: dst.clone(),
            a: Operand::Register(a),
            b: Operand::Register(b),
            c: Operand::Register(c),
        });
        dst
    }

    /// Emits `mad.hi.s32`: high half of `a * b`, plus `c`.
    pub fn mad_hi_s32(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.mad_hi_typed(PtxType::S32, a, b, c)
    }

    /// Emits `mad.hi.u32`: high half of `a * b`, plus `c`.
    pub fn mad_hi_u32(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.mad_hi_typed(PtxType::U32, a, b, c)
    }

    /// Emits `mad.hi.s64`: high half of `a * b`, plus `c`.
    pub fn mad_hi_s64(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.mad_hi_typed(PtxType::S64, a, b, c)
    }

    /// Emits `mad.hi.u64`: high half of `a * b`, plus `c`.
    pub fn mad_hi_u64(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.mad_hi_typed(PtxType::U64, a, b, c)
    }

    /// Shared lowering for `mad.hi.<typ>` with three register operands.
    fn mad_hi_typed(&mut self, typ: PtxType, a: Register, b: Register, c: Register) -> Register {
        let dst = self.regs.alloc(typ);
        self.emit(Instruction::MadHi {
            typ,
            dst: dst.clone(),
            a: Operand::Register(a),
            b: Operand::Register(b),
            c: Operand::Register(c),
        });
        dst
    }
345
    /// Emits `mad.wide.s16`: full s32 product of s16 operands, plus `c`.
    /// The destination register is the widened (s32) type.
    pub fn mad_wide_s16(&mut self, a: Register, b: Register, c: Register) -> Register {
        let dst = self.regs.alloc(PtxType::S32);
        self.emit(Instruction::MadWide {
            src_typ: PtxType::S16,
            dst: dst.clone(),
            a: Operand::Register(a),
            b: Operand::Register(b),
            c: Operand::Register(c),
        });
        dst
    }

    /// Emits `mad.wide.u16`: full u32 product of u16 operands, plus `c`.
    pub fn mad_wide_u16(&mut self, a: Register, b: Register, c: Register) -> Register {
        let dst = self.regs.alloc(PtxType::U32);
        self.emit(Instruction::MadWide {
            src_typ: PtxType::U16,
            dst: dst.clone(),
            a: Operand::Register(a),
            b: Operand::Register(b),
            c: Operand::Register(c),
        });
        dst
    }

    /// Emits `mad.wide.s32`: full s64 product of s32 operands, plus `c`.
    pub fn mad_wide_s32(&mut self, a: Register, b: Register, c: Register) -> Register {
        let dst = self.regs.alloc(PtxType::S64);
        self.emit(Instruction::MadWide {
            src_typ: PtxType::S32,
            dst: dst.clone(),
            a: Operand::Register(a),
            b: Operand::Register(b),
            c: Operand::Register(c),
        });
        dst
    }

    /// Emits `mad.wide.u32`: full u64 product of u32 operands, plus `c`.
    pub fn mad_wide_u32(&mut self, a: Register, b: Register, c: Register) -> Register {
        let dst = self.regs.alloc(PtxType::U64);
        self.emit(Instruction::MadWide {
            src_typ: PtxType::U32,
            dst: dst.clone(),
            a: Operand::Register(a),
            b: Operand::Register(b),
            c: Operand::Register(c),
        });
        dst
    }
397
    /// Emits `fma.rn.f32`: fused `a * b + c` with round-to-nearest-even.
    pub fn fma_f32(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.fma_typed(PtxType::F32, a, b, c)
    }

    /// Emits `fma.rn.f64`: fused `a * b + c` with round-to-nearest-even.
    pub fn fma_f64(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.fma_typed(PtxType::F64, a, b, c)
    }

    /// Shared lowering for `fma.rn.<ty>`; rounding mode is fixed to `.rn`.
    fn fma_typed(&mut self, ty: PtxType, a: Register, b: Register, c: Register) -> Register {
        let dst = self.regs.alloc(ty);
        self.emit(Instruction::Fma {
            rnd: RoundingMode::Rn,
            ty,
            dst: dst.clone(),
            a: Operand::Register(a),
            b: Operand::Register(b),
            c: Operand::Register(c),
        });
        dst
    }
425
    /// Emits `neg.f32` (arithmetic negation) into a fresh register.
    pub fn neg_f32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F32);
        self.emit(Instruction::Neg {
            ty: PtxType::F32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Emits `abs.f32` (absolute value) into a fresh register.
    pub fn abs_f32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F32);
        self.emit(Instruction::Abs {
            ty: PtxType::F32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }
447
    /// Emits `min.f32` and returns the destination register.
    pub fn min_f32(&mut self, a: Register, b: Register) -> Register {
        self.min_typed(PtxType::F32, a, b)
    }

    /// Emits `max.f32` and returns the destination register.
    pub fn max_f32(&mut self, a: Register, b: Register) -> Register {
        self.max_typed(PtxType::F32, a, b)
    }

    /// Emits `min.u32` and returns the destination register.
    pub fn min_u32(&mut self, a: Register, b: Register) -> Register {
        self.min_typed(PtxType::U32, a, b)
    }

    /// Emits `max.u32` and returns the destination register.
    pub fn max_u32(&mut self, a: Register, b: Register) -> Register {
        self.max_typed(PtxType::U32, a, b)
    }

    /// Shared lowering for `min.<ty>` with two register operands.
    fn min_typed(&mut self, ty: PtxType, a: Register, b: Register) -> Register {
        let dst = self.regs.alloc(ty);
        self.emit(Instruction::Min {
            ty,
            dst: dst.clone(),
            a: Operand::Register(a),
            b: Operand::Register(b),
        });
        dst
    }

    /// Shared lowering for `max.<ty>` with two register operands.
    fn max_typed(&mut self, ty: PtxType, a: Register, b: Register) -> Register {
        let dst = self.regs.alloc(ty);
        self.emit(Instruction::Max {
            ty,
            dst: dst.clone(),
            a: Operand::Register(a),
            b: Operand::Register(b),
        });
        dst
    }
491
    /// Emits `brev.b32` (bit reversal) into a fresh b32 register.
    pub fn brev_b32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::B32);
        self.emit(Instruction::Brev {
            ty: PtxType::B32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Emits `brev.b64` (bit reversal) into a fresh b64 register.
    pub fn brev_b64(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::B64);
        self.emit(Instruction::Brev {
            ty: PtxType::B64,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Emits `clz.b32` (count leading zeros); the result is a u32.
    pub fn clz_b32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::U32);
        self.emit(Instruction::Clz {
            ty: PtxType::B32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Emits `popc.b32` (population count); the result is a u32.
    pub fn popc_b32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::U32);
        self.emit(Instruction::Popc {
            ty: PtxType::B32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Emits `popc.b64`; the result is still a u32 count.
    pub fn popc_b64(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::U32);
        self.emit(Instruction::Popc {
            ty: PtxType::B64,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }
550
    /// Emits `bfind.u32` (position of the most significant set bit).
    pub fn bfind_u32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::U32);
        self.emit(Instruction::Bfind {
            ty: PtxType::U32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Emits `bfind.s32`; result register is u32 (a bit position).
    pub fn bfind_s32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::U32);
        self.emit(Instruction::Bfind {
            ty: PtxType::S32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Emits `bfe.u32`: extract `len` bits of `src` starting at `start`
    /// (zero-extended).
    pub fn bfe_u32(&mut self, src: Register, start: Register, len: Register) -> Register {
        let dst = self.regs.alloc(PtxType::U32);
        self.emit(Instruction::Bfe {
            ty: PtxType::U32,
            dst: dst.clone(),
            src: Operand::Register(src),
            start: Operand::Register(start),
            len: Operand::Register(len),
        });
        dst
    }

    /// Emits `bfe.s32`: extract `len` bits of `src` starting at `start`
    /// (sign-extended).
    pub fn bfe_s32(&mut self, src: Register, start: Register, len: Register) -> Register {
        let dst = self.regs.alloc(PtxType::S32);
        self.emit(Instruction::Bfe {
            ty: PtxType::S32,
            dst: dst.clone(),
            src: Operand::Register(src),
            start: Operand::Register(start),
            len: Operand::Register(len),
        });
        dst
    }

    /// Emits `bfi.b32`: insert `len` bits of `insert` into `base` at bit
    /// offset `start`.
    pub fn bfi_b32(
        &mut self,
        insert: Register,
        base: Register,
        start: Register,
        len: Register,
    ) -> Register {
        let dst = self.regs.alloc(PtxType::B32);
        self.emit(Instruction::Bfi {
            ty: PtxType::B32,
            dst: dst.clone(),
            insert: Operand::Register(insert),
            base: Operand::Register(base),
            start: Operand::Register(start),
            len: Operand::Register(len),
        });
        dst
    }
618
    /// Emits `shl.b32` (logical left shift by `amount`).
    pub fn shl_b32(&mut self, src: Register, amount: Register) -> Register {
        let dst = self.regs.alloc(PtxType::B32);
        self.emit(Instruction::Shl {
            ty: PtxType::B32,
            dst: dst.clone(),
            src: Operand::Register(src),
            amount: Operand::Register(amount),
        });
        dst
    }

    /// Emits `shl.b64` (logical left shift by `amount`).
    pub fn shl_b64(&mut self, src: Register, amount: Register) -> Register {
        let dst = self.regs.alloc(PtxType::B64);
        self.emit(Instruction::Shl {
            ty: PtxType::B64,
            dst: dst.clone(),
            src: Operand::Register(src),
            amount: Operand::Register(amount),
        });
        dst
    }

    /// Emits `shr.b32` (untyped right shift).
    pub fn shr_b32(&mut self, src: Register, amount: Register) -> Register {
        let dst = self.regs.alloc(PtxType::B32);
        self.emit(Instruction::Shr {
            ty: PtxType::B32,
            dst: dst.clone(),
            src: Operand::Register(src),
            amount: Operand::Register(amount),
        });
        dst
    }

    /// Emits `shr.b64` (untyped right shift).
    pub fn shr_b64(&mut self, src: Register, amount: Register) -> Register {
        let dst = self.regs.alloc(PtxType::B64);
        self.emit(Instruction::Shr {
            ty: PtxType::B64,
            dst: dst.clone(),
            src: Operand::Register(src),
            amount: Operand::Register(amount),
        });
        dst
    }

    /// Emits `shr.u32` (logical/zero-fill right shift).
    pub fn shr_u32(&mut self, src: Register, amount: Register) -> Register {
        let dst = self.regs.alloc(PtxType::U32);
        self.emit(Instruction::Shr {
            ty: PtxType::U32,
            dst: dst.clone(),
            src: Operand::Register(src),
            amount: Operand::Register(amount),
        });
        dst
    }

    /// Emits `shr.s32` (arithmetic/sign-propagating right shift).
    pub fn shr_s32(&mut self, src: Register, amount: Register) -> Register {
        let dst = self.regs.alloc(PtxType::S32);
        self.emit(Instruction::Shr {
            ty: PtxType::S32,
            dst: dst.clone(),
            src: Operand::Register(src),
            amount: Operand::Register(amount),
        });
        dst
    }
694
    /// Emits `rcp.rn.f32` (IEEE-rounded reciprocal).
    pub fn rcp_f32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F32);
        self.emit(Instruction::Rcp {
            rnd: Some(RoundingMode::Rn),
            ty: PtxType::F32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Emits `rcp.rn.f64` (IEEE-rounded reciprocal).
    pub fn rcp_f64(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F64);
        self.emit(Instruction::Rcp {
            rnd: Some(RoundingMode::Rn),
            ty: PtxType::F64,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Emits the approximate reciprocal (`rnd: None` selects the
    /// approx form in the printer) — faster, lower precision.
    pub fn rcp_approx_f32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F32);
        self.emit(Instruction::Rcp {
            rnd: None,
            ty: PtxType::F32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Emits `rsqrt.approx.f32` (approximate reciprocal square root).
    pub fn rsqrt_approx_f32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F32);
        self.emit(Instruction::Rsqrt {
            approx: true,
            ty: PtxType::F32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Emits `rsqrt.approx.f64` (approximate reciprocal square root).
    pub fn rsqrt_approx_f64(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F64);
        self.emit(Instruction::Rsqrt {
            approx: true,
            ty: PtxType::F64,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }
760
    /// Emits `sqrt.rn.f32` (IEEE-rounded square root).
    pub fn sqrt_rn_f32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F32);
        self.emit(Instruction::Sqrt {
            rnd: Some(RoundingMode::Rn),
            ty: PtxType::F32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Emits `sqrt.rn.f64` (IEEE-rounded square root).
    pub fn sqrt_rn_f64(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F64);
        self.emit(Instruction::Sqrt {
            rnd: Some(RoundingMode::Rn),
            ty: PtxType::F64,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Emits `ex2.approx.f32` (approximate 2^x).
    pub fn ex2_approx_f32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F32);
        self.emit(Instruction::Ex2 {
            approx: true,
            ty: PtxType::F32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Emits `lg2.approx.f32` (approximate log2).
    pub fn lg2_approx_f32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F32);
        self.emit(Instruction::Lg2 {
            approx: true,
            ty: PtxType::F32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Emits `sin.approx.f32` (approximate sine, argument in radians).
    pub fn sin_approx_f32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F32);
        self.emit(Instruction::Sin {
            approx: true,
            ty: PtxType::F32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Emits `cos.approx.f32` (approximate cosine, argument in radians).
    pub fn cos_approx_f32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F32);
        self.emit(Instruction::Cos {
            approx: true,
            ty: PtxType::F32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }
832
    /// Loads an f32 from global memory at `addr` (a 64-bit address).
    pub fn load_global_f32(&mut self, addr: Register) -> Register {
        self.load_global_scalar(PtxType::F32, addr)
    }

    /// Loads an f64 from global memory at `addr`.
    pub fn load_global_f64(&mut self, addr: Register) -> Register {
        self.load_global_scalar(PtxType::F64, addr)
    }

    /// Loads an s32 from global memory at `addr`.
    pub fn load_global_i32(&mut self, addr: Register) -> Register {
        self.load_global_scalar(PtxType::S32, addr)
    }

    /// Loads a u32 from global memory at `addr`.
    pub fn load_global_u32(&mut self, addr: Register) -> Register {
        self.load_global_scalar(PtxType::U32, addr)
    }

    /// Shared lowering: scalar `ld.global.<ty>` with no cache qualifier
    /// and no address offset.
    fn load_global_scalar(&mut self, ty: PtxType, addr: Register) -> Register {
        let dst = self.regs.alloc(ty);
        self.emit(Instruction::Load {
            space: MemorySpace::Global,
            qualifier: CacheQualifier::None,
            vec: VectorWidth::V1,
            ty,
            dst: dst.clone(),
            addr: Operand::Address {
                base: addr,
                offset: None,
            },
        });
        dst
    }

    /// Vectorized `ld.global.v4.f32`, returning four f32 registers.
    /// Emitted as raw PTX text; `addr` must be 16-byte aligned for the
    /// vector load to be valid.
    /// NOTE(review): takes `&Register` while the scalar loads take
    /// `Register` by value — confirm whether this asymmetry is intended.
    pub fn load_global_f32x4(&mut self, addr: &Register) -> [Register; 4] {
        let r0 = self.regs.alloc(PtxType::F32);
        let r1 = self.regs.alloc(PtxType::F32);
        let r2 = self.regs.alloc(PtxType::F32);
        let r3 = self.regs.alloc(PtxType::F32);
        self.emit(Instruction::Raw(format!(
            "ld.global.v4.f32 {{{r0}, {r1}, {r2}, {r3}}}, [{addr}];"
        )));
        [r0, r1, r2, r3]
    }
899
    /// Stores an f32 register to global memory at `addr`.
    pub fn store_global_f32(&mut self, addr: Register, val: Register) {
        self.store_global_scalar(PtxType::F32, addr, val);
    }

    /// Stores an f64 register to global memory at `addr`.
    pub fn store_global_f64(&mut self, addr: Register, val: Register) {
        self.store_global_scalar(PtxType::F64, addr, val);
    }

    /// Stores an s32 register to global memory at `addr`.
    pub fn store_global_i32(&mut self, addr: Register, val: Register) {
        self.store_global_scalar(PtxType::S32, addr, val);
    }

    /// Stores a u32 register to global memory at `addr`.
    pub fn store_global_u32(&mut self, addr: Register, val: Register) {
        self.store_global_scalar(PtxType::U32, addr, val);
    }

    /// Shared lowering: scalar `st.global.<ty>` with no cache qualifier
    /// and no address offset.
    fn store_global_scalar(&mut self, ty: PtxType, addr: Register, val: Register) {
        self.emit(Instruction::Store {
            space: MemorySpace::Global,
            qualifier: CacheQualifier::None,
            vec: VectorWidth::V1,
            ty,
            addr: Operand::Address {
                base: addr,
                offset: None,
            },
            src: val,
        });
    }
940
    /// Loads an f32 from shared memory at `addr` (a shared-space address).
    pub fn load_shared_f32(&mut self, addr: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F32);
        self.emit(Instruction::Load {
            space: MemorySpace::Shared,
            qualifier: CacheQualifier::None,
            vec: VectorWidth::V1,
            ty: PtxType::F32,
            dst: dst.clone(),
            addr: Operand::Address {
                base: addr,
                offset: None,
            },
        });
        dst
    }

    /// Stores an f32 register to shared memory at `addr`.
    pub fn store_shared_f32(&mut self, addr: Register, val: Register) {
        self.emit(Instruction::Store {
            space: MemorySpace::Shared,
            qualifier: CacheQualifier::None,
            vec: VectorWidth::V1,
            ty: PtxType::F32,
            addr: Operand::Address {
                base: addr,
                offset: None,
            },
            src: val,
        });
    }
978
    /// Emits a 4-byte `cp.async` copy from global to shared memory.
    /// Must be followed by `cp_async_commit` / `cp_async_wait` before the
    /// shared data is read.
    pub fn cp_async_32bit(&mut self, dst_shared: Register, src_global: Register) {
        self.emit(Instruction::CpAsync {
            bytes: 4,
            dst_shared: Operand::Register(dst_shared),
            src_global: Operand::Register(src_global),
        });
    }

    /// Emits an 8-byte `cp.async` copy from global to shared memory.
    pub fn cp_async_64bit(&mut self, dst_shared: Register, src_global: Register) {
        self.emit(Instruction::CpAsync {
            bytes: 8,
            dst_shared: Operand::Register(dst_shared),
            src_global: Operand::Register(src_global),
        });
    }

    /// Emits a 16-byte `cp.async` copy from global to shared memory.
    pub fn cp_async_128bit(&mut self, dst_shared: Register, src_global: Register) {
        self.emit(Instruction::CpAsync {
            bytes: 16,
            dst_shared: Operand::Register(dst_shared),
            src_global: Operand::Register(src_global),
        });
    }

    /// Emits `cp.async.commit_group`, closing the current async-copy group.
    pub fn cp_async_commit(&mut self) {
        self.emit(Instruction::CpAsyncCommit);
    }

    /// Emits `cp.async.wait_group n`: wait until at most `n` committed
    /// groups remain in flight.
    pub fn cp_async_wait(&mut self, n: u32) {
        self.emit(Instruction::CpAsyncWait { n });
    }
1031
    /// Emits `ldmatrix` loading four 8x8 fragments from shared memory into
    /// four b32 registers (non-transposed).
    ///
    /// # Errors
    /// Returns `PtxGenError::UnsupportedFeature` when the target SM lacks
    /// ldmatrix support (requires SM >= 75 per the error text).
    pub fn ldmatrix_x4(&mut self, src_addr: Register) -> Result<[Register; 4], PtxGenError> {
        use crate::ir::Instruction as I;
        if !self.target.capabilities().has_ldmatrix {
            return Err(PtxGenError::UnsupportedFeature {
                arch: self.target.as_ptx_str().to_string(),
                feature: "ldmatrix (SM >= 75)".to_string(),
            });
        }
        let r0 = self.regs.alloc(PtxType::B32);
        let r1 = self.regs.alloc(PtxType::B32);
        let r2 = self.regs.alloc(PtxType::B32);
        let r3 = self.regs.alloc(PtxType::B32);
        self.emit(I::Ldmatrix {
            num_fragments: 4,
            trans: false,
            dst_regs: vec![r0.clone(), r1.clone(), r2.clone(), r3.clone()],
            src_addr: Operand::Register(src_addr),
        });
        Ok([r0, r1, r2, r3])
    }
1062
1063 pub fn if_lt_u32<F>(&mut self, a: Register, b: Register, body: F)
1080 where
1081 F: FnOnce(&mut BodyBuilder<'_>),
1082 {
1083 let pred = self.regs.alloc(PtxType::Pred);
1084 self.emit(Instruction::SetP {
1085 cmp: CmpOp::Lo,
1086 ty: PtxType::U32,
1087 dst: pred.clone(),
1088 a: Operand::Register(a),
1089 b: Operand::Register(b),
1090 });
1091 let skip_label = self.fresh_label("skip");
1092 self.emit(Instruction::Branch {
1094 target: skip_label.clone(),
1095 predicate: Some((pred, true)),
1096 });
1097 body(self);
1098 self.emit(Instruction::Label(skip_label));
1099 }
1100
1101 pub fn if_ge_u32<F>(&mut self, a: Register, b: Register, body: F)
1103 where
1104 F: FnOnce(&mut BodyBuilder<'_>),
1105 {
1106 let pred = self.regs.alloc(PtxType::Pred);
1107 self.emit(Instruction::SetP {
1108 cmp: CmpOp::Hs,
1109 ty: PtxType::U32,
1110 dst: pred.clone(),
1111 a: Operand::Register(a),
1112 b: Operand::Register(b),
1113 });
1114 let skip_label = self.fresh_label("skip");
1115 self.emit(Instruction::Branch {
1116 target: skip_label.clone(),
1117 predicate: Some((pred, true)),
1118 });
1119 body(self);
1120 self.emit(Instruction::Label(skip_label));
1121 }
1122
    /// Build-time unrolling: invokes `body` `count` times, emitting a
    /// progress comment before each iteration. No loop appears in the
    /// generated PTX.
    pub fn unroll<F>(&mut self, count: u32, mut body: F)
    where
        F: FnMut(&mut BodyBuilder<'_>, u32),
    {
        for i in 0..count {
            self.comment(&format!("unroll iteration {i}/{count}"));
            body(self, i);
        }
    }

    /// Emits a `.pragma` directive: `unroll N` for `Some(n)`, or
    /// `nounroll` for `None`.
    pub fn pragma_unroll(&mut self, factor: Option<u32>) {
        let text = factor.map_or_else(|| "nounroll".to_string(), |n| format!("unroll {n}"));
        self.emit(Instruction::Pragma(text));
    }
1149
    /// Places a named label at the current position in the stream.
    pub fn label(&mut self, name: &str) {
        self.emit(Instruction::Label(name.to_string()));
    }

    /// Emits an unconditional branch to `target`.
    pub fn branch(&mut self, target: &str) {
        self.emit(Instruction::Branch {
            target: target.to_string(),
            predicate: None,
        });
    }

    /// Emits a branch taken when `pred` is true (the `false` flag means
    /// the predicate is not negated).
    pub fn branch_if(&mut self, pred: Register, target: &str) {
        self.emit(Instruction::Branch {
            target: target.to_string(),
            predicate: Some((pred, false)),
        });
    }

    /// Emits `ret`, returning from the kernel.
    pub fn ret(&mut self) {
        self.emit(Instruction::Return);
    }

    /// Emits `bar.sync id`: block-wide barrier on the given barrier slot.
    pub fn bar_sync(&mut self, id: u32) {
        self.emit(Instruction::BarSync { id });
    }
1190
    /// Emits an acquire-release memory fence at the given scope.
    pub fn fence_acq_rel(&mut self, scope: FenceScope) {
        self.emit(Instruction::FenceAcqRel { scope });
    }
1199
    /// Zero-extends a u32 register to u64 (`cvt.u64.u32`, no rounding).
    pub fn cvt_u32_to_u64(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::U64);
        self.emit(Instruction::Cvt {
            rnd: None,
            dst_ty: PtxType::U64,
            src_ty: PtxType::U32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Widens f32 to f64 — exact, so no rounding mode is required.
    pub fn cvt_f32_to_f64(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F64);
        self.emit(Instruction::Cvt {
            rnd: None,
            dst_ty: PtxType::F64,
            src_ty: PtxType::F32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Narrows f64 to f32 with round-to-nearest-even (narrowing
    /// conversions require an explicit rounding mode).
    pub fn cvt_f64_to_f32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F32);
        self.emit(Instruction::Cvt {
            rnd: Some(RoundingMode::Rn),
            dst_ty: PtxType::F32,
            src_ty: PtxType::F64,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Widens f16 to f32 — exact, no rounding mode.
    pub fn cvt_f16_to_f32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F32);
        self.emit(Instruction::Cvt {
            rnd: None,
            dst_ty: PtxType::F32,
            src_ty: PtxType::F16,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Narrows f32 to f16 with round-to-nearest-even.
    pub fn cvt_f32_to_f16(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F16);
        self.emit(Instruction::Cvt {
            rnd: Some(RoundingMode::Rn),
            dst_ty: PtxType::F16,
            src_ty: PtxType::F32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Widens bf16 to f32 — exact, no rounding mode.
    pub fn cvt_bf16_to_f32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F32);
        self.emit(Instruction::Cvt {
            rnd: None,
            dst_ty: PtxType::F32,
            src_ty: PtxType::BF16,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Narrows f32 to bf16 with round-to-nearest-even.
    pub fn cvt_f32_to_bf16(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::BF16);
        self.emit(Instruction::Cvt {
            rnd: Some(RoundingMode::Rn),
            dst_ty: PtxType::BF16,
            src_ty: PtxType::F32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Narrows f32 to the e4m3 FP8 format with round-to-nearest-even.
    pub fn cvt_f32_to_e4m3(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::E4M3);
        self.emit(Instruction::Cvt {
            rnd: Some(RoundingMode::Rn),
            dst_ty: PtxType::E4M3,
            src_ty: PtxType::F32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Widens e4m3 FP8 to f32 — exact, no rounding mode.
    pub fn cvt_e4m3_to_f32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F32);
        self.emit(Instruction::Cvt {
            rnd: None,
            dst_ty: PtxType::F32,
            src_ty: PtxType::E4M3,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Narrows f32 to the e5m2 FP8 format with round-to-nearest-even.
    pub fn cvt_f32_to_e5m2(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::E5M2);
        self.emit(Instruction::Cvt {
            rnd: Some(RoundingMode::Rn),
            dst_ty: PtxType::E5M2,
            src_ty: PtxType::F32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Widens e5m2 FP8 to f32 — exact, no rounding mode.
    pub fn cvt_e5m2_to_f32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F32);
        self.emit(Instruction::Cvt {
            rnd: None,
            dst_ty: PtxType::F32,
            src_ty: PtxType::E5M2,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }
1369
    /// Emits a tensor-core `mma` with shape m16n8k16, f16 A/B operands and
    /// f32 C/D accumulators, returning the four f32 result fragments.
    /// `a_regs`/`b_regs`/`c_regs` carry the per-thread fragment registers;
    /// their required counts depend on the shape (not validated here).
    pub fn mma_m16n8k16_f16_f32(
        &mut self,
        a_regs: &[Register],
        b_regs: &[Register],
        c_regs: &[Register],
    ) -> [Register; 4] {
        let dst = self.regs.alloc_group(PtxType::F32, 4);
        self.emit(Instruction::Mma {
            shape: MmaShape::M16N8K16,
            a_ty: PtxType::F16,
            b_ty: PtxType::F16,
            c_ty: PtxType::F32,
            d_ty: PtxType::F32,
            d_regs: dst.clone(),
            a_regs: a_regs.to_vec(),
            b_regs: b_regs.to_vec(),
            c_regs: c_regs.to_vec(),
        });
        [
            dst[0].clone(),
            dst[1].clone(),
            dst[2].clone(),
            dst[3].clone(),
        ]
    }
1414
    /// Emits a raw `wgmma.mma_async` (Hopper warp-group MMA) using shared
    /// memory descriptors `a_desc`/`b_desc`. The destination register list
    /// is left as a `{...}` placeholder in the emitted text — presumably a
    /// template to be completed by the caller; TODO confirm.
    ///
    /// # Errors
    /// Returns `PtxGenError::GenerationFailed` when the target SM lacks
    /// wgmma support (requires SM >= 90).
    pub fn wgmma_mma_async_m64n128k16_f16(
        &mut self,
        a_desc: &str,
        b_desc: &str,
    ) -> Result<(), PtxGenError> {
        if !self.target.capabilities().has_wgmma {
            return Err(PtxGenError::GenerationFailed(format!(
                "wgmma.mma_async requires SM >= 90 (Hopper), target is {}",
                self.target
            )));
        }
        self.raw_ptx(&format!(
            "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {{...}}, {a_desc}, {b_desc}, 1, 1, 1, 0, 0;"
        ));
        Ok(())
    }
1444
    /// Emits `dp4a.u32.u32`: 4-way byte dot product of `a` and `b`
    /// accumulated into `c` (both operands unsigned).
    pub fn dp4a_u32_u32(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.dp4a_typed(a, b, c, false, false)
    }

    /// Emits `dp4a.s32.s32` (both operands signed).
    pub fn dp4a_s32_s32(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.dp4a_typed(a, b, c, true, true)
    }

    /// Emits `dp4a.s32.u32` (signed `a`, unsigned `b`).
    pub fn dp4a_s32_u32(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.dp4a_typed(a, b, c, true, false)
    }

    /// Emits `dp4a.u32.s32` (unsigned `a`, signed `b`).
    pub fn dp4a_u32_s32(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.dp4a_typed(a, b, c, false, true)
    }

    /// Shared lowering; the accumulator/result register is always s32.
    fn dp4a_typed(
        &mut self,
        a: Register,
        b: Register,
        c: Register,
        signed_a: bool,
        signed_b: bool,
    ) -> Register {
        let dst = self.regs.alloc(PtxType::S32);
        self.emit(Instruction::Dp4a {
            dst: dst.clone(),
            a: Operand::Register(a),
            b: Operand::Register(b),
            c: Operand::Register(c),
            signed_a,
            signed_b,
        });
        dst
    }
1489
    /// Emits `dp2a.lo.u32.u32`: 2-way 16-bit dot product (low halves of
    /// `b`) accumulated into `c`.
    pub fn dp2a_lo_u32_u32(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.dp2a_typed(a, b, c, false, false, true)
    }

    /// Emits `dp2a.hi.u32.u32` (high halves of `b`).
    pub fn dp2a_hi_u32_u32(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.dp2a_typed(a, b, c, false, false, false)
    }

    /// Emits `dp2a.lo.s32.s32` (signed operands, low halves of `b`).
    pub fn dp2a_lo_s32_s32(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.dp2a_typed(a, b, c, true, true, true)
    }

    /// Emits `dp2a.hi.s32.s32` (signed operands, high halves of `b`).
    pub fn dp2a_hi_s32_s32(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.dp2a_typed(a, b, c, true, true, false)
    }

    /// Shared lowering; `lo` selects the low/high half mode and the
    /// accumulator/result register is always s32.
    fn dp2a_typed(
        &mut self,
        a: Register,
        b: Register,
        c: Register,
        signed_a: bool,
        signed_b: bool,
        lo: bool,
    ) -> Register {
        let dst = self.regs.alloc(PtxType::S32);
        self.emit(Instruction::Dp2a {
            dst: dst.clone(),
            a: Operand::Register(a),
            b: Operand::Register(b),
            c: Operand::Register(c),
            signed_a,
            signed_b,
            lo,
        });
        dst
    }
1532
    /// Wraps a `u32` literal as an immediate operand (no instruction emitted).
    #[must_use]
    pub const fn imm_u32(&self, val: u32) -> Operand {
        Operand::Immediate(ImmValue::U32(val))
    }

    /// Materializes a `u32` immediate into a register by emitting
    /// `add.u32 dst, 0, val` — presumably because the IR has no generic
    /// mov-with-immediate; NOTE(review): confirm and consider a `Mov`
    /// variant if one exists.
    pub fn mov_imm_u32(&mut self, val: u32) -> Register {
        let dst = self.regs.alloc(PtxType::U32);
        self.emit(Instruction::Add {
            ty: PtxType::U32,
            dst: dst.clone(),
            a: Operand::Immediate(ImmValue::U32(0)),
            b: Operand::Immediate(ImmValue::U32(val)),
        });
        dst
    }

    /// Wraps a `u64` literal as an immediate operand.
    #[must_use]
    pub const fn imm_u64(&self, val: u64) -> Operand {
        Operand::Immediate(ImmValue::U64(val))
    }

    /// Wraps an `f32` literal as an immediate operand.
    #[must_use]
    pub const fn imm_f32(&self, val: f32) -> Operand {
        Operand::Immediate(ImmValue::F32(val))
    }

    /// Wraps an `f64` literal as an immediate operand.
    #[must_use]
    pub const fn imm_f64(&self, val: f64) -> Operand {
        Operand::Immediate(ImmValue::F64(val))
    }
1572
    /// Emits a comment line into the instruction stream.
    pub fn comment(&mut self, text: &str) {
        self.emit(Instruction::Comment(text.to_string()));
    }

    /// Emits raw PTX text verbatim, first scanning it for register tokens
    /// so they get declared before use.
    ///
    /// The scan walks the bytes looking for `%`-prefixed identifiers
    /// (`[A-Za-z0-9_]*`). Any identifier containing `_` whose prefix maps
    /// to a known register class (`%rd_` -> b64, `%f_` -> f32, `%p_` ->
    /// pred, `%r_` -> b32) is declared via `declare_named`; other
    /// `%`-tokens (e.g. special registers, `%param_*`) are left alone.
    /// Note: the byte scan is safe on UTF-8 input because slicing stops at
    /// the first non-ASCII-identifier byte, which is always a char boundary.
    pub fn raw_ptx(&mut self, text: &str) {
        let mut i = 0;
        let bytes = text.as_bytes();
        while i < bytes.len() {
            if bytes[i] == b'%' {
                let start = i;
                i += 1;
                // Consume the identifier body after '%'.
                while i < bytes.len() && (bytes[i].is_ascii_alphanumeric() || bytes[i] == b'_') {
                    i += 1;
                }
                let name = &text[start..i];
                if name.contains('_') {
                    let ty = if name.starts_with("%rd_") {
                        PtxType::B64
                    } else if name.starts_with("%f_") {
                        PtxType::F32
                    } else if name.starts_with("%p_") {
                        PtxType::Pred
                    } else if name.starts_with("%r_") {
                        PtxType::B32
                    } else {
                        // Unknown prefix: not one of ours, skip declaration.
                        continue;
                    };
                    self.regs.declare_named(name, ty);
                }
            } else {
                i += 1;
            }
        }
        self.emit(Instruction::Raw(text.to_string()));
    }
1625
    /// Computes `base + index * stride_bytes` as a u64 address.
    ///
    /// Emits: zero-extend of `index` to u64, a raw `mov.u64` loading the
    /// stride constant into a register, `mul.lo.u64`, then `add.u64`.
    /// The stride goes through a register (rather than an immediate
    /// operand on the multiply) — presumably to keep the multiply
    /// register-register; NOTE(review): confirm whether an immediate
    /// operand would be acceptable here.
    pub fn byte_offset_addr(
        &mut self,
        base: Register,
        index: Register,
        stride_bytes: u32,
    ) -> Register {
        let idx64 = self.cvt_u32_to_u64(index);
        let stride_reg = self.regs.alloc(PtxType::U64);
        self.emit(Instruction::Raw(format!(
            "mov.u64 {}, {};",
            stride_reg,
            u64::from(stride_bytes)
        )));
        let offset = self.regs.alloc(PtxType::U64);
        self.emit(Instruction::Mul {
            ty: PtxType::U64,
            mode: MulMode::Lo,
            dst: offset.clone(),
            a: Operand::Register(idx64),
            b: Operand::Register(stride_reg),
        });
        self.add_u64(base, offset)
    }

    /// Address of element `index` in an f32 array at `base` (stride 4).
    pub fn f32_elem_addr(&mut self, base: Register, index: Register) -> Register {
        self.byte_offset_addr(base, index, 4)
    }

    /// Address of element `index` in an f64 array at `base` (stride 8).
    pub fn f64_elem_addr(&mut self, base: Register, index: Register) -> Register {
        self.byte_offset_addr(base, index, 8)
    }
1671
    /// Allocates a fresh virtual register of the given type without
    /// emitting any instruction.
    pub fn alloc_reg(&mut self, ty: PtxType) -> Register {
        self.regs.alloc(ty)
    }

    /// Declares a register with an explicit name (for use by raw PTX).
    pub fn declare_named_reg(&mut self, name: &str, ty: PtxType) {
        self.regs.declare_named(name, ty);
    }

    /// Appends one instruction to the output stream.
    fn emit(&mut self, inst: Instruction) {
        self.instructions.push(inst);
    }

    /// Returns a label name unique within this builder:
    /// `L__<prefix>_<counter>`.
    pub fn fresh_label(&mut self, prefix: &str) -> String {
        let id = self.label_counter;
        self.label_counter += 1;
        format!("L__{prefix}_{id}")
    }

    /// The SM version this builder targets.
    #[must_use]
    pub const fn target_sm(&self) -> SmVersion {
        self.target
    }

    /// Whether the kernel declares a parameter with this name.
    #[must_use]
    pub fn has_param(&self, name: &str) -> bool {
        self.param_names.iter().any(|p| p == name)
    }
1724}
1725
1726pub(super) mod body_builder_ext;
1729
1730pub(super) mod tensor_core_ops;
1732
1733#[cfg(test)]
1734#[path = "body_builder_tests.rs"]
1735mod tests;