1use std::collections::HashSet;
8
9use crate::ir::{Instruction, Operand, Register, WmmaOp};
10
11pub fn eliminate_dead_code(instructions: &[Instruction]) -> (Vec<Instruction>, usize) {
29 let mut current: Vec<Instruction> = instructions.to_vec();
30 let mut total_eliminated: usize = 0;
31
32 loop {
33 let (next, eliminated) = dce_pass(¤t);
34 if eliminated == 0 {
35 break;
36 }
37 total_eliminated += eliminated;
38 current = next;
39 }
40
41 (current, total_eliminated)
42}
43
44fn dce_pass(instructions: &[Instruction]) -> (Vec<Instruction>, usize) {
52 let mut used_regs: HashSet<String> = HashSet::new();
54 for inst in instructions {
55 for reg in uses(inst) {
56 used_regs.insert(reg.name.clone());
57 }
58 }
59
60 let mut result = Vec::with_capacity(instructions.len());
62 let mut eliminated: usize = 0;
63
64 for inst in instructions {
65 if has_side_effects(inst) {
66 result.push(inst.clone());
68 continue;
69 }
70
71 let defined = defs(inst);
72 if defined.is_empty() {
73 result.push(inst.clone());
76 continue;
77 }
78
79 let any_def_used = defined.iter().any(|r| used_regs.contains(&r.name));
81
82 if any_def_used {
83 result.push(inst.clone());
84 } else {
85 eliminated += 1;
86 }
87 }
88
89 (result, eliminated)
90}
91
92const fn has_side_effects(inst: &Instruction) -> bool {
102 match inst {
103 Instruction::Store { .. }
106 | Instruction::CpAsync { .. }
107 | Instruction::CpAsyncCommit
108 | Instruction::CpAsyncWait { .. }
109 | Instruction::Branch { .. }
110 | Instruction::Label(_)
111 | Instruction::Return
112 | Instruction::BarSync { .. }
113 | Instruction::BarArrive { .. }
114 | Instruction::FenceAcqRel { .. }
115 | Instruction::TmaLoad { .. }
116 | Instruction::Atom { .. }
117 | Instruction::AtomCas { .. }
118 | Instruction::Red { .. }
119 | Instruction::SurfStore { .. }
120 | Instruction::Stmatrix { .. }
121 | Instruction::Setmaxnreg { .. }
122 | Instruction::Griddepcontrol { .. }
123 | Instruction::FenceProxy { .. }
124 | Instruction::MbarrierInit { .. }
125 | Instruction::MbarrierArrive { .. }
126 | Instruction::MbarrierWait { .. }
127 | Instruction::Tcgen05Mma { .. }
128 | Instruction::BarrierCluster
129 | Instruction::FenceCluster
130 | Instruction::CpAsyncBulk { .. }
131 | Instruction::Comment(_)
132 | Instruction::Raw(_) => true,
133
134 Instruction::Wmma { op, .. } => matches!(op, WmmaOp::StoreD),
136
137 Instruction::Add { .. }
139 | Instruction::Sub { .. }
140 | Instruction::Mul { .. }
141 | Instruction::Mad { .. }
142 | Instruction::Fma { .. }
143 | Instruction::MadLo { .. }
144 | Instruction::MadHi { .. }
145 | Instruction::MadWide { .. }
146 | Instruction::Neg { .. }
147 | Instruction::Abs { .. }
148 | Instruction::Min { .. }
149 | Instruction::Max { .. }
150 | Instruction::Brev { .. }
151 | Instruction::Clz { .. }
152 | Instruction::Popc { .. }
153 | Instruction::Bfind { .. }
154 | Instruction::Bfe { .. }
155 | Instruction::Bfi { .. }
156 | Instruction::Rcp { .. }
157 | Instruction::Rsqrt { .. }
158 | Instruction::Sqrt { .. }
159 | Instruction::Ex2 { .. }
160 | Instruction::Lg2 { .. }
161 | Instruction::Sin { .. }
162 | Instruction::Cos { .. }
163 | Instruction::Shl { .. }
164 | Instruction::Shr { .. }
165 | Instruction::Div { .. }
166 | Instruction::Rem { .. }
167 | Instruction::And { .. }
168 | Instruction::Or { .. }
169 | Instruction::Xor { .. }
170 | Instruction::SetP { .. }
171 | Instruction::Load { .. }
172 | Instruction::Cvt { .. }
173 | Instruction::Mma { .. }
174 | Instruction::Wgmma { .. }
175 | Instruction::MovSpecial { .. }
176 | Instruction::LoadParam { .. }
177 | Instruction::Dp4a { .. }
178 | Instruction::Dp2a { .. }
179 | Instruction::Tex1d { .. }
180 | Instruction::Tex2d { .. }
181 | Instruction::Tex3d { .. }
182 | Instruction::SurfLoad { .. }
183 | Instruction::Redux { .. }
184 | Instruction::ElectSync { .. }
185 | Instruction::Pragma(_)
187 | Instruction::Ldmatrix { .. } => false,
189 }
190}
191
192fn defs(inst: &Instruction) -> Vec<&Register> {
198 match inst {
199 Instruction::Add { dst, .. }
200 | Instruction::Sub { dst, .. }
201 | Instruction::Mul { dst, .. }
202 | Instruction::Mad { dst, .. }
203 | Instruction::MadLo { dst, .. }
204 | Instruction::MadHi { dst, .. }
205 | Instruction::MadWide { dst, .. }
206 | Instruction::Fma { dst, .. }
207 | Instruction::Neg { dst, .. }
208 | Instruction::Abs { dst, .. }
209 | Instruction::Min { dst, .. }
210 | Instruction::Max { dst, .. }
211 | Instruction::Brev { dst, .. }
212 | Instruction::Clz { dst, .. }
213 | Instruction::Popc { dst, .. }
214 | Instruction::Bfind { dst, .. }
215 | Instruction::Bfe { dst, .. }
216 | Instruction::Bfi { dst, .. }
217 | Instruction::Rcp { dst, .. }
218 | Instruction::Rsqrt { dst, .. }
219 | Instruction::Sqrt { dst, .. }
220 | Instruction::Ex2 { dst, .. }
221 | Instruction::Lg2 { dst, .. }
222 | Instruction::Sin { dst, .. }
223 | Instruction::Cos { dst, .. }
224 | Instruction::Shl { dst, .. }
225 | Instruction::Shr { dst, .. }
226 | Instruction::Div { dst, .. }
227 | Instruction::Rem { dst, .. }
228 | Instruction::And { dst, .. }
229 | Instruction::Or { dst, .. }
230 | Instruction::Xor { dst, .. }
231 | Instruction::SetP { dst, .. }
232 | Instruction::Load { dst, .. }
233 | Instruction::Cvt { dst, .. }
234 | Instruction::MovSpecial { dst, .. }
235 | Instruction::LoadParam { dst, .. }
236 | Instruction::Atom { dst, .. }
237 | Instruction::AtomCas { dst, .. }
238 | Instruction::Dp4a { dst, .. }
239 | Instruction::Dp2a { dst, .. }
240 | Instruction::Tex1d { dst, .. }
241 | Instruction::Tex2d { dst, .. }
242 | Instruction::Tex3d { dst, .. }
243 | Instruction::SurfLoad { dst, .. }
244 | Instruction::Redux { dst, .. }
245 | Instruction::ElectSync { dst, .. } => vec![dst],
246
247 Instruction::Ldmatrix { dst_regs, .. } => dst_regs.iter().collect(),
248
249 Instruction::Store { .. }
250 | Instruction::CpAsync { .. }
251 | Instruction::CpAsyncCommit
252 | Instruction::CpAsyncWait { .. }
253 | Instruction::Branch { .. }
254 | Instruction::Label(_)
255 | Instruction::Return
256 | Instruction::BarSync { .. }
257 | Instruction::BarArrive { .. }
258 | Instruction::FenceAcqRel { .. }
259 | Instruction::TmaLoad { .. }
260 | Instruction::Red { .. }
261 | Instruction::SurfStore { .. }
262 | Instruction::Stmatrix { .. }
263 | Instruction::Setmaxnreg { .. }
264 | Instruction::Griddepcontrol { .. }
265 | Instruction::FenceProxy { .. }
266 | Instruction::MbarrierInit { .. }
267 | Instruction::MbarrierArrive { .. }
268 | Instruction::MbarrierWait { .. }
269 | Instruction::Tcgen05Mma { .. }
270 | Instruction::BarrierCluster
271 | Instruction::FenceCluster
272 | Instruction::CpAsyncBulk { .. }
273 | Instruction::Comment(_)
274 | Instruction::Raw(_)
275 | Instruction::Pragma(_) => vec![],
276
277 Instruction::Wmma { op, fragments, .. } => match op {
278 WmmaOp::LoadA | WmmaOp::LoadB | WmmaOp::Mma => fragments.iter().collect(),
279 WmmaOp::StoreD => vec![],
280 },
281 Instruction::Mma { d_regs, .. } | Instruction::Wgmma { d_regs, .. } => {
282 d_regs.iter().collect()
283 }
284 }
285}
286
287#[allow(clippy::too_many_lines)]
289fn uses(inst: &Instruction) -> Vec<&Register> {
290 match inst {
291 Instruction::Add { a, b, .. }
292 | Instruction::Sub { a, b, .. }
293 | Instruction::Mul { a, b, .. }
294 | Instruction::Min { a, b, .. }
295 | Instruction::Max { a, b, .. }
296 | Instruction::Div { a, b, .. }
297 | Instruction::Rem { a, b, .. }
298 | Instruction::And { a, b, .. }
299 | Instruction::Or { a, b, .. }
300 | Instruction::Xor { a, b, .. }
301 | Instruction::SetP { a, b, .. }
302 | Instruction::Shl {
303 src: a, amount: b, ..
304 }
305 | Instruction::Shr {
306 src: a, amount: b, ..
307 } => {
308 let mut regs = operand_regs(a);
309 regs.extend(operand_regs(b));
310 regs
311 }
312
313 Instruction::Mad { a, b, c, .. }
314 | Instruction::MadLo { a, b, c, .. }
315 | Instruction::MadHi { a, b, c, .. }
316 | Instruction::MadWide { a, b, c, .. }
317 | Instruction::Fma { a, b, c, .. }
318 | Instruction::Dp4a { a, b, c, .. }
319 | Instruction::Dp2a { a, b, c, .. } => {
320 let mut regs = operand_regs(a);
321 regs.extend(operand_regs(b));
322 regs.extend(operand_regs(c));
323 regs
324 }
325
326 Instruction::Neg { src, .. }
327 | Instruction::Abs { src, .. }
328 | Instruction::Brev { src, .. }
329 | Instruction::Clz { src, .. }
330 | Instruction::Popc { src, .. }
331 | Instruction::Bfind { src, .. }
332 | Instruction::Rcp { src, .. }
333 | Instruction::Rsqrt { src, .. }
334 | Instruction::Sqrt { src, .. }
335 | Instruction::Ex2 { src, .. }
336 | Instruction::Lg2 { src, .. }
337 | Instruction::Sin { src, .. }
338 | Instruction::Cos { src, .. }
339 | Instruction::Cvt { src, .. }
340 | Instruction::Redux { src, .. } => operand_regs(src),
341
342 Instruction::Bfe {
343 src, start, len, ..
344 } => {
345 let mut regs = operand_regs(src);
346 regs.extend(operand_regs(start));
347 regs.extend(operand_regs(len));
348 regs
349 }
350
351 Instruction::Bfi {
352 insert,
353 base,
354 start,
355 len,
356 ..
357 } => {
358 let mut regs = operand_regs(insert);
359 regs.extend(operand_regs(base));
360 regs.extend(operand_regs(start));
361 regs.extend(operand_regs(len));
362 regs
363 }
364
365 Instruction::Load { addr, .. } | Instruction::MbarrierArrive { addr } => operand_regs(addr),
366
367 Instruction::Store { addr, src, .. } => {
368 let mut regs = operand_regs(addr);
369 regs.push(src);
370 regs
371 }
372
373 Instruction::CpAsync {
374 dst_shared,
375 src_global,
376 ..
377 } => {
378 let mut regs = operand_regs(dst_shared);
379 regs.extend(operand_regs(src_global));
380 regs
381 }
382
383 Instruction::CpAsyncCommit
384 | Instruction::CpAsyncWait { .. }
385 | Instruction::Label(_)
386 | Instruction::Return
387 | Instruction::BarSync { .. }
388 | Instruction::BarArrive { .. }
389 | Instruction::FenceAcqRel { .. }
390 | Instruction::MovSpecial { .. }
391 | Instruction::LoadParam { .. }
392 | Instruction::ElectSync { .. }
393 | Instruction::Setmaxnreg { .. }
394 | Instruction::Griddepcontrol { .. }
395 | Instruction::FenceProxy { .. }
396 | Instruction::BarrierCluster
397 | Instruction::FenceCluster
398 | Instruction::Comment(_)
399 | Instruction::Raw(_)
400 | Instruction::Pragma(_) => vec![],
401
402 Instruction::Branch { predicate, .. } => {
403 if let Some((reg, _)) = predicate {
404 vec![reg]
405 } else {
406 vec![]
407 }
408 }
409
410 Instruction::Wmma {
411 op,
412 fragments,
413 addr,
414 stride,
415 ..
416 } => {
417 let mut regs: Vec<&Register> = Vec::new();
418 match op {
419 WmmaOp::LoadA | WmmaOp::LoadB => {
420 if let Some(a) = addr {
421 regs.extend(operand_regs(a));
422 }
423 if let Some(s) = stride {
424 regs.extend(operand_regs(s));
425 }
426 }
427 WmmaOp::StoreD => {
428 regs.extend(fragments.iter());
429 if let Some(a) = addr {
430 regs.extend(operand_regs(a));
431 }
432 if let Some(s) = stride {
433 regs.extend(operand_regs(s));
434 }
435 }
436 WmmaOp::Mma => {
437 regs.extend(fragments.iter());
438 }
439 }
440 regs
441 }
442
443 Instruction::Mma {
444 a_regs,
445 b_regs,
446 c_regs,
447 ..
448 } => {
449 let mut regs: Vec<&Register> = Vec::new();
450 regs.extend(a_regs.iter());
451 regs.extend(b_regs.iter());
452 regs.extend(c_regs.iter());
453 regs
454 }
455
456 Instruction::Wgmma { desc_a, desc_b, .. } => vec![desc_a, desc_b],
457
458 Instruction::TmaLoad {
459 dst_shared,
460 desc,
461 coords,
462 barrier,
463 ..
464 } => {
465 let mut regs = operand_regs(dst_shared);
466 regs.push(desc);
467 regs.extend(coords.iter());
468 regs.push(barrier);
469 regs
470 }
471
472 Instruction::Atom { addr, src, .. } | Instruction::Red { addr, src, .. } => {
474 let mut regs = operand_regs(addr);
475 regs.extend(operand_regs(src));
476 regs
477 }
478 Instruction::AtomCas {
480 addr,
481 compare,
482 value,
483 ..
484 } => {
485 let mut regs = operand_regs(addr);
486 regs.extend(operand_regs(compare));
487 regs.extend(operand_regs(value));
488 regs
489 }
490
491 Instruction::Tex1d { coord, .. } | Instruction::SurfLoad { coord, .. } => {
493 operand_regs(coord)
494 }
495 Instruction::Tex2d {
496 coord_x, coord_y, ..
497 } => {
498 let mut regs = operand_regs(coord_x);
499 regs.extend(operand_regs(coord_y));
500 regs
501 }
502 Instruction::Tex3d {
503 coord_x,
504 coord_y,
505 coord_z,
506 ..
507 } => {
508 let mut regs = operand_regs(coord_x);
509 regs.extend(operand_regs(coord_y));
510 regs.extend(operand_regs(coord_z));
511 regs
512 }
513 Instruction::SurfStore { coord, src, .. } => {
514 let mut regs = operand_regs(coord);
515 regs.push(src);
516 regs
517 }
518
519 Instruction::Stmatrix { dst_addr, src, .. } => {
521 let mut regs = operand_regs(dst_addr);
522 regs.push(src);
523 regs
524 }
525 Instruction::MbarrierInit { addr, count, .. } => {
526 let mut regs = operand_regs(addr);
527 regs.extend(operand_regs(count));
528 regs
529 }
530 Instruction::MbarrierWait { addr, phase } => {
531 let mut regs = operand_regs(addr);
532 regs.extend(operand_regs(phase));
533 regs
534 }
535
536 Instruction::Tcgen05Mma { a_desc, b_desc } => vec![a_desc, b_desc],
537
538 Instruction::CpAsyncBulk {
539 dst_smem,
540 src_gmem,
541 desc,
542 } => vec![dst_smem, src_gmem, desc],
543
544 Instruction::Ldmatrix { src_addr, .. } => operand_regs(src_addr),
545 }
546}
547
548fn operand_regs(op: &Operand) -> Vec<&Register> {
550 match op {
551 Operand::Register(reg) => vec![reg],
552 Operand::Address { base, .. } => vec![base],
553 Operand::Immediate(_) | Operand::Symbol(_) => vec![],
554 }
555}
556
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{
        CacheQualifier, FenceScope, ImmValue, Instruction, MemorySpace, MulMode, Operand, PtxType,
        Register, SpecialReg, VectorWidth, WmmaOp,
    };

    // ---- Fixture builders --------------------------------------------------

    /// Builds a register with the given name and type.
    fn reg(name: &str, ty: PtxType) -> Register {
        Register {
            name: name.to_string(),
            ty,
        }
    }

    /// Builds a register operand.
    fn reg_op(name: &str, ty: PtxType) -> Operand {
        Operand::Register(reg(name, ty))
    }

    /// Builds a `u32` immediate operand.
    fn imm_u32(val: u32) -> Operand {
        Operand::Immediate(ImmValue::U32(val))
    }

    /// Builds an offset-free address operand based on the named `u64` register.
    fn addr(base: &str) -> Operand {
        Operand::Address {
            base: reg(base, PtxType::U64),
            offset: None,
        }
    }

    /// Builds a scalar global store of `src` to the address in `%rd0`.
    fn store_global(ty: PtxType, src: Register) -> Instruction {
        Instruction::Store {
            space: MemorySpace::Global,
            qualifier: CacheQualifier::None,
            vec: VectorWidth::V1,
            ty,
            addr: addr("%rd0"),
            src,
        }
    }

    /// Builds a `MovSpecial` that writes `TidX` into `%r0`.
    fn tid_x_into_r0() -> Instruction {
        Instruction::MovSpecial {
            dst: reg("%r0", PtxType::U32),
            special: SpecialReg::TidX,
        }
    }

    /// Builds an `Add` instruction.
    fn add(ty: PtxType, dst: Register, a: Operand, b: Operand) -> Instruction {
        Instruction::Add { ty, dst, a, b }
    }

    /// Builds a `Mul` instruction in `MulMode::Lo`.
    fn mul_lo(ty: PtxType, dst: Register, a: Operand, b: Operand) -> Instruction {
        Instruction::Mul {
            ty,
            mode: MulMode::Lo,
            dst,
            a,
            b,
        }
    }

    // ---- Shared assertions -------------------------------------------------

    /// Asserts that DCE removes nothing from `instructions`.
    #[track_caller]
    fn assert_all_kept(instructions: &[Instruction]) {
        let (result, eliminated) = eliminate_dead_code(instructions);
        assert_eq!(eliminated, 0);
        assert_eq!(result.len(), instructions.len());
    }

    /// Asserts that DCE removes every instruction in `instructions`.
    #[track_caller]
    fn assert_all_removed(instructions: &[Instruction]) {
        let (result, eliminated) = eliminate_dead_code(instructions);
        assert_eq!(eliminated, instructions.len());
        assert!(result.is_empty());
    }

    // ---- Basic liveness ----------------------------------------------------

    #[test]
    fn test_unused_register_removed() {
        assert_all_removed(&[add(
            PtxType::F32,
            reg("%f0", PtxType::F32),
            imm_u32(1),
            imm_u32(2),
        )]);
    }

    #[test]
    fn test_used_register_kept() {
        assert_all_kept(&[
            tid_x_into_r0(),
            store_global(PtxType::U32, reg("%r0", PtxType::U32)),
        ]);
    }

    // ---- Instructions that are never removed -------------------------------

    #[test]
    fn test_stores_never_removed() {
        assert_all_kept(&[store_global(PtxType::F32, reg("%f0", PtxType::F32))]);
    }

    #[test]
    fn test_branches_never_removed() {
        assert_all_kept(&[
            Instruction::Branch {
                target: "loop".to_string(),
                predicate: None,
            },
            Instruction::Label("loop".to_string()),
        ]);
    }

    #[test]
    fn test_barrier_never_removed() {
        assert_all_kept(&[Instruction::BarSync { id: 0 }]);
    }

    #[test]
    fn test_bar_arrive_never_removed() {
        assert_all_kept(&[Instruction::BarArrive { id: 0, count: 32 }]);
    }

    #[test]
    fn test_fence_never_removed() {
        assert_all_kept(&[Instruction::FenceAcqRel {
            scope: FenceScope::Gpu,
        }]);
    }

    #[test]
    fn test_return_never_removed() {
        assert_all_kept(&[Instruction::Return]);
    }

    #[test]
    fn test_comment_never_removed() {
        assert_all_kept(&[Instruction::Comment("keep me".to_string())]);
    }

    #[test]
    fn test_raw_never_removed() {
        assert_all_kept(&[Instruction::Raw("nop;".to_string())]);
    }

    // ---- Chains that need several rounds -----------------------------------

    /// `%f1` is unread, and once its `Add` is gone `%f0` is unread as well.
    #[test]
    fn test_chain_of_dead_instructions() {
        assert_all_removed(&[
            add(
                PtxType::F32,
                reg("%f0", PtxType::F32),
                imm_u32(1),
                imm_u32(2),
            ),
            add(
                PtxType::F32,
                reg("%f1", PtxType::F32),
                reg_op("%f0", PtxType::F32),
                imm_u32(3),
            ),
        ]);
    }

    #[test]
    fn test_three_level_dead_chain() {
        assert_all_removed(&[
            add(
                PtxType::U32,
                reg("%r0", PtxType::U32),
                imm_u32(1),
                imm_u32(2),
            ),
            mul_lo(
                PtxType::U32,
                reg("%r1", PtxType::U32),
                reg_op("%r0", PtxType::U32),
                imm_u32(3),
            ),
            Instruction::Sub {
                ty: PtxType::U32,
                dst: reg("%r2", PtxType::U32),
                a: reg_op("%r1", PtxType::U32),
                b: imm_u32(4),
            },
        ]);
    }

    #[test]
    fn test_no_dead_code_unchanged() {
        assert_all_kept(&[
            tid_x_into_r0(),
            add(
                PtxType::U32,
                reg("%r1", PtxType::U32),
                reg_op("%r0", PtxType::U32),
                imm_u32(1),
            ),
            store_global(PtxType::U32, reg("%r1", PtxType::U32)),
        ]);
    }

    #[test]
    fn test_cp_async_never_removed() {
        assert_all_kept(&[
            Instruction::CpAsync {
                bytes: 16,
                dst_shared: addr("%rd0"),
                src_global: addr("%rd1"),
            },
            Instruction::CpAsyncCommit,
            Instruction::CpAsyncWait { n: 0 },
        ]);
    }

    #[test]
    fn test_tma_load_never_removed() {
        assert_all_kept(&[Instruction::TmaLoad {
            dst_shared: addr("%rd0"),
            desc: reg("%rd1", PtxType::U64),
            coords: vec![reg("%r0", PtxType::U32)],
            barrier: reg("%rd2", PtxType::U64),
        }]);
    }

    /// Only the `Mul` writing the unread `%r2` is removed.
    #[test]
    fn test_mixed_live_and_dead() {
        let instructions = vec![
            tid_x_into_r0(),
            add(
                PtxType::U32,
                reg("%r1", PtxType::U32),
                reg_op("%r0", PtxType::U32),
                imm_u32(1),
            ),
            mul_lo(
                PtxType::U32,
                reg("%r2", PtxType::U32),
                reg_op("%r0", PtxType::U32),
                imm_u32(2),
            ),
            store_global(PtxType::U32, reg("%r1", PtxType::U32)),
        ];
        let (result, eliminated) = eliminate_dead_code(&instructions);
        assert_eq!(eliminated, 1);
        assert_eq!(result.len(), 3);
    }

    #[test]
    fn test_empty_instructions() {
        let (result, eliminated) = eliminate_dead_code(&[]);
        assert_eq!(eliminated, 0);
        assert!(result.is_empty());
    }

    /// A load whose destination is unread is removed like any other
    /// side-effect-free instruction.
    #[test]
    fn test_dead_load_removed() {
        assert_all_removed(&[Instruction::Load {
            space: MemorySpace::Global,
            qualifier: CacheQualifier::None,
            vec: VectorWidth::V1,
            ty: PtxType::F32,
            dst: reg("%f0", PtxType::F32),
            addr: addr("%rd0"),
        }]);
    }

    #[test]
    fn test_wmma_store_never_removed() {
        use crate::ir::{WmmaLayout, WmmaShape};

        assert_all_kept(&[Instruction::Wmma {
            op: WmmaOp::StoreD,
            shape: WmmaShape::M16N16K16,
            layout: WmmaLayout::RowMajor,
            ty: PtxType::F16,
            fragments: vec![reg("%f0", PtxType::F16), reg("%f1", PtxType::F16)],
            addr: Some(addr("%rd0")),
            stride: None,
        }]);
    }

    // ---- Direct checks of the classifier -----------------------------------

    #[test]
    fn test_side_effects_classification() {
        let add_inst = add(
            PtxType::F32,
            reg("%f0", PtxType::F32),
            imm_u32(0),
            imm_u32(0),
        );
        assert!(!has_side_effects(&add_inst));

        let store = store_global(PtxType::F32, reg("%f0", PtxType::F32));
        assert!(has_side_effects(&store));

        let branch = Instruction::Branch {
            target: "L1".to_string(),
            predicate: None,
        };
        assert!(has_side_effects(&branch));

        let label = Instruction::Label("L1".to_string());
        assert!(has_side_effects(&label));

        let bar = Instruction::BarSync { id: 0 };
        assert!(has_side_effects(&bar));

        let mov = tid_x_into_r0();
        assert!(!has_side_effects(&mov));
    }

    // ---- Whole-program scenarios -------------------------------------------

    /// The two computations between the branch and its target are removed
    /// because their results are never read; the pass itself performs no
    /// reachability analysis.
    #[test]
    fn test_dce_removes_unreachable_block() {
        let instructions = vec![
            Instruction::Branch {
                target: "after_dead".to_string(),
                predicate: None,
            },
            add(
                PtxType::F32,
                reg("%f_dead0", PtxType::F32),
                imm_u32(1),
                imm_u32(2),
            ),
            mul_lo(
                PtxType::F32,
                reg("%f_dead1", PtxType::F32),
                reg_op("%f_dead0", PtxType::F32),
                imm_u32(3),
            ),
            Instruction::Label("after_dead".to_string()),
            Instruction::Return,
        ];

        let (result, eliminated) = eliminate_dead_code(&instructions);

        assert_eq!(
            eliminated, 2,
            "DCE must eliminate both unreachable pure-computation instructions"
        );
        assert_eq!(
            result.len(),
            3,
            "Branch, Label and Return must be preserved"
        );
    }

    /// Every instruction feeds the final store, so nothing may be removed.
    #[test]
    fn test_dce_keeps_reachable_blocks() {
        let instructions = vec![
            tid_x_into_r0(),
            add(
                PtxType::U32,
                reg("%r1", PtxType::U32),
                reg_op("%r0", PtxType::U32),
                imm_u32(10),
            ),
            mul_lo(
                PtxType::U32,
                reg("%r2", PtxType::U32),
                reg_op("%r1", PtxType::U32),
                imm_u32(4),
            ),
            store_global(PtxType::U32, reg("%r2", PtxType::U32)),
        ];

        let (result, eliminated) = eliminate_dead_code(&instructions);

        assert_eq!(
            eliminated, 0,
            "no instruction should be eliminated from a fully-live chain"
        );
        assert_eq!(
            result.len(),
            instructions.len(),
            "all instructions must survive DCE"
        );
    }

    /// Running DCE on its own output must be a no-op.
    #[test]
    fn test_dce_idempotent() {
        let instructions = vec![
            tid_x_into_r0(),
            add(
                PtxType::F32,
                reg("%f_unused", PtxType::F32),
                imm_u32(7),
                imm_u32(8),
            ),
            store_global(PtxType::U32, reg("%r0", PtxType::U32)),
        ];

        let (first_result, first_eliminated) = eliminate_dead_code(&instructions);
        let (second_result, second_eliminated) = eliminate_dead_code(&first_result);

        assert_eq!(
            second_eliminated, 0,
            "second DCE pass must not eliminate anything additional (idempotent)"
        );
        assert_eq!(
            first_result.len(),
            second_result.len(),
            "result length must be the same on both passes"
        );
        assert_eq!(
            first_eliminated, 1,
            "first pass must eliminate the unused Add instruction"
        );
    }
}