Skip to main content

oxicuda_ptx/analysis/
dead_code.rs

1//! Dead code elimination for PTX instruction sequences.
2//!
3//! This module implements a fixed-point dead code elimination (DCE) pass.
//! An instruction is *dead* if none of the registers it defines is read by any
//! instruction in the sequence **and** the instruction has no side effects.
6
7use std::collections::HashSet;
8
9use crate::ir::{Instruction, Operand, Register, WmmaOp};
10
11// ---------------------------------------------------------------------------
12// Public API
13// ---------------------------------------------------------------------------
14
15/// Remove instructions whose results are never used.
16///
17/// Performs iterative dead code elimination until a fixed point is reached:
18/// each pass removes instructions that define registers not consumed by any
19/// other instruction, provided the instruction has no side effects.
20///
21/// # Arguments
22///
23/// * `instructions` - The original instruction sequence.
24///
25/// # Returns
26///
27/// A tuple of `(optimized_instructions, eliminated_count)`.
28pub fn eliminate_dead_code(instructions: &[Instruction]) -> (Vec<Instruction>, usize) {
29    let mut current: Vec<Instruction> = instructions.to_vec();
30    let mut total_eliminated: usize = 0;
31
32    loop {
33        let (next, eliminated) = dce_pass(&current);
34        if eliminated == 0 {
35            break;
36        }
37        total_eliminated += eliminated;
38        current = next;
39    }
40
41    (current, total_eliminated)
42}
43
44// ---------------------------------------------------------------------------
45// Internal pass
46// ---------------------------------------------------------------------------
47
48/// Single pass of dead code elimination.
49///
50/// Returns `(surviving_instructions, number_eliminated)`.
51fn dce_pass(instructions: &[Instruction]) -> (Vec<Instruction>, usize) {
52    // Phase 1: Collect the set of all registers that are *used* by any instruction.
53    let mut used_regs: HashSet<String> = HashSet::new();
54    for inst in instructions {
55        for reg in uses(inst) {
56            used_regs.insert(reg.name.clone());
57        }
58    }
59
60    // Phase 2: Mark each instruction as live or dead.
61    let mut result = Vec::with_capacity(instructions.len());
62    let mut eliminated: usize = 0;
63
64    for inst in instructions {
65        if has_side_effects(inst) {
66            // Side-effecting instructions are always kept.
67            result.push(inst.clone());
68            continue;
69        }
70
71        let defined = defs(inst);
72        if defined.is_empty() {
73            // Instructions that define nothing and have no side effects
74            // are kept (e.g., they shouldn't exist, but be conservative).
75            result.push(inst.clone());
76            continue;
77        }
78
79        // The instruction is dead if *none* of its defined registers are used.
80        let any_def_used = defined.iter().any(|r| used_regs.contains(&r.name));
81
82        if any_def_used {
83            result.push(inst.clone());
84        } else {
85            eliminated += 1;
86        }
87    }
88
89    (result, eliminated)
90}
91
92// ---------------------------------------------------------------------------
93// Side-effect classification
94// ---------------------------------------------------------------------------
95
96/// Check if an instruction has side effects and therefore cannot be eliminated
97/// even when its result register is unused.
98///
99/// Side-effecting instructions include memory stores, control flow, barriers,
100/// fences, TMA operations, async copies, and meta-instructions (comments, raw).
101const fn has_side_effects(inst: &Instruction) -> bool {
102    match inst {
103        // Memory stores, async copy, control flow, synchronization, fences,
104        // TMA load, atomic operations, meta-instructions — all side-effecting.
105        Instruction::Store { .. }
106        | Instruction::CpAsync { .. }
107        | Instruction::CpAsyncCommit
108        | Instruction::CpAsyncWait { .. }
109        | Instruction::Branch { .. }
110        | Instruction::Label(_)
111        | Instruction::Return
112        | Instruction::BarSync { .. }
113        | Instruction::BarArrive { .. }
114        | Instruction::FenceAcqRel { .. }
115        | Instruction::TmaLoad { .. }
116        | Instruction::Atom { .. }
117        | Instruction::AtomCas { .. }
118        | Instruction::Red { .. }
119        | Instruction::SurfStore { .. }
120        | Instruction::Stmatrix { .. }
121        | Instruction::Setmaxnreg { .. }
122        | Instruction::Griddepcontrol { .. }
123        | Instruction::FenceProxy { .. }
124        | Instruction::MbarrierInit { .. }
125        | Instruction::MbarrierArrive { .. }
126        | Instruction::MbarrierWait { .. }
127        | Instruction::Tcgen05Mma { .. }
128        | Instruction::BarrierCluster
129        | Instruction::FenceCluster
130        | Instruction::CpAsyncBulk { .. }
131        | Instruction::Comment(_)
132        | Instruction::Raw(_) => true,
133
134        // WMMA store operations write to memory
135        Instruction::Wmma { op, .. } => matches!(op, WmmaOp::StoreD),
136
137        // Pure computation instructions: no side effects
138        Instruction::Add { .. }
139        | Instruction::Sub { .. }
140        | Instruction::Mul { .. }
141        | Instruction::Mad { .. }
142        | Instruction::Fma { .. }
143        | Instruction::MadLo { .. }
144        | Instruction::MadHi { .. }
145        | Instruction::MadWide { .. }
146        | Instruction::Neg { .. }
147        | Instruction::Abs { .. }
148        | Instruction::Min { .. }
149        | Instruction::Max { .. }
150        | Instruction::Brev { .. }
151        | Instruction::Clz { .. }
152        | Instruction::Popc { .. }
153        | Instruction::Bfind { .. }
154        | Instruction::Bfe { .. }
155        | Instruction::Bfi { .. }
156        | Instruction::Rcp { .. }
157        | Instruction::Rsqrt { .. }
158        | Instruction::Sqrt { .. }
159        | Instruction::Ex2 { .. }
160        | Instruction::Lg2 { .. }
161        | Instruction::Sin { .. }
162        | Instruction::Cos { .. }
163        | Instruction::Shl { .. }
164        | Instruction::Shr { .. }
165        | Instruction::Div { .. }
166        | Instruction::Rem { .. }
167        | Instruction::And { .. }
168        | Instruction::Or { .. }
169        | Instruction::Xor { .. }
170        | Instruction::SetP { .. }
171        | Instruction::Load { .. }
172        | Instruction::Cvt { .. }
173        | Instruction::Mma { .. }
174        | Instruction::Wgmma { .. }
175        | Instruction::MovSpecial { .. }
176        | Instruction::LoadParam { .. }
177        | Instruction::Dp4a { .. }
178        | Instruction::Dp2a { .. }
179        | Instruction::Tex1d { .. }
180        | Instruction::Tex2d { .. }
181        | Instruction::Tex3d { .. }
182        | Instruction::SurfLoad { .. }
183        | Instruction::Redux { .. }
184        | Instruction::ElectSync { .. }
185        // Pragma is a directive hint — no side effects on execution
186        | Instruction::Pragma(_)
187        // ldmatrix: warp-cooperative load — result registers are output side, no mem-side effect
188        | Instruction::Ldmatrix { .. } => false,
189    }
190}
191
192// ---------------------------------------------------------------------------
193// Register extraction helpers (mirrored from register_pressure)
194// ---------------------------------------------------------------------------
195
196/// Extract registers defined (written to) by an instruction.
197fn defs(inst: &Instruction) -> Vec<&Register> {
198    match inst {
199        Instruction::Add { dst, .. }
200        | Instruction::Sub { dst, .. }
201        | Instruction::Mul { dst, .. }
202        | Instruction::Mad { dst, .. }
203        | Instruction::MadLo { dst, .. }
204        | Instruction::MadHi { dst, .. }
205        | Instruction::MadWide { dst, .. }
206        | Instruction::Fma { dst, .. }
207        | Instruction::Neg { dst, .. }
208        | Instruction::Abs { dst, .. }
209        | Instruction::Min { dst, .. }
210        | Instruction::Max { dst, .. }
211        | Instruction::Brev { dst, .. }
212        | Instruction::Clz { dst, .. }
213        | Instruction::Popc { dst, .. }
214        | Instruction::Bfind { dst, .. }
215        | Instruction::Bfe { dst, .. }
216        | Instruction::Bfi { dst, .. }
217        | Instruction::Rcp { dst, .. }
218        | Instruction::Rsqrt { dst, .. }
219        | Instruction::Sqrt { dst, .. }
220        | Instruction::Ex2 { dst, .. }
221        | Instruction::Lg2 { dst, .. }
222        | Instruction::Sin { dst, .. }
223        | Instruction::Cos { dst, .. }
224        | Instruction::Shl { dst, .. }
225        | Instruction::Shr { dst, .. }
226        | Instruction::Div { dst, .. }
227        | Instruction::Rem { dst, .. }
228        | Instruction::And { dst, .. }
229        | Instruction::Or { dst, .. }
230        | Instruction::Xor { dst, .. }
231        | Instruction::SetP { dst, .. }
232        | Instruction::Load { dst, .. }
233        | Instruction::Cvt { dst, .. }
234        | Instruction::MovSpecial { dst, .. }
235        | Instruction::LoadParam { dst, .. }
236        | Instruction::Atom { dst, .. }
237        | Instruction::AtomCas { dst, .. }
238        | Instruction::Dp4a { dst, .. }
239        | Instruction::Dp2a { dst, .. }
240        | Instruction::Tex1d { dst, .. }
241        | Instruction::Tex2d { dst, .. }
242        | Instruction::Tex3d { dst, .. }
243        | Instruction::SurfLoad { dst, .. }
244        | Instruction::Redux { dst, .. }
245        | Instruction::ElectSync { dst, .. } => vec![dst],
246
247        Instruction::Ldmatrix { dst_regs, .. } => dst_regs.iter().collect(),
248
249        Instruction::Store { .. }
250        | Instruction::CpAsync { .. }
251        | Instruction::CpAsyncCommit
252        | Instruction::CpAsyncWait { .. }
253        | Instruction::Branch { .. }
254        | Instruction::Label(_)
255        | Instruction::Return
256        | Instruction::BarSync { .. }
257        | Instruction::BarArrive { .. }
258        | Instruction::FenceAcqRel { .. }
259        | Instruction::TmaLoad { .. }
260        | Instruction::Red { .. }
261        | Instruction::SurfStore { .. }
262        | Instruction::Stmatrix { .. }
263        | Instruction::Setmaxnreg { .. }
264        | Instruction::Griddepcontrol { .. }
265        | Instruction::FenceProxy { .. }
266        | Instruction::MbarrierInit { .. }
267        | Instruction::MbarrierArrive { .. }
268        | Instruction::MbarrierWait { .. }
269        | Instruction::Tcgen05Mma { .. }
270        | Instruction::BarrierCluster
271        | Instruction::FenceCluster
272        | Instruction::CpAsyncBulk { .. }
273        | Instruction::Comment(_)
274        | Instruction::Raw(_)
275        | Instruction::Pragma(_) => vec![],
276
277        Instruction::Wmma { op, fragments, .. } => match op {
278            WmmaOp::LoadA | WmmaOp::LoadB | WmmaOp::Mma => fragments.iter().collect(),
279            WmmaOp::StoreD => vec![],
280        },
281        Instruction::Mma { d_regs, .. } | Instruction::Wgmma { d_regs, .. } => {
282            d_regs.iter().collect()
283        }
284    }
285}
286
287/// Extract registers used (read from) by an instruction.
288#[allow(clippy::too_many_lines)]
289fn uses(inst: &Instruction) -> Vec<&Register> {
290    match inst {
291        Instruction::Add { a, b, .. }
292        | Instruction::Sub { a, b, .. }
293        | Instruction::Mul { a, b, .. }
294        | Instruction::Min { a, b, .. }
295        | Instruction::Max { a, b, .. }
296        | Instruction::Div { a, b, .. }
297        | Instruction::Rem { a, b, .. }
298        | Instruction::And { a, b, .. }
299        | Instruction::Or { a, b, .. }
300        | Instruction::Xor { a, b, .. }
301        | Instruction::SetP { a, b, .. }
302        | Instruction::Shl {
303            src: a, amount: b, ..
304        }
305        | Instruction::Shr {
306            src: a, amount: b, ..
307        } => {
308            let mut regs = operand_regs(a);
309            regs.extend(operand_regs(b));
310            regs
311        }
312
313        Instruction::Mad { a, b, c, .. }
314        | Instruction::MadLo { a, b, c, .. }
315        | Instruction::MadHi { a, b, c, .. }
316        | Instruction::MadWide { a, b, c, .. }
317        | Instruction::Fma { a, b, c, .. }
318        | Instruction::Dp4a { a, b, c, .. }
319        | Instruction::Dp2a { a, b, c, .. } => {
320            let mut regs = operand_regs(a);
321            regs.extend(operand_regs(b));
322            regs.extend(operand_regs(c));
323            regs
324        }
325
326        Instruction::Neg { src, .. }
327        | Instruction::Abs { src, .. }
328        | Instruction::Brev { src, .. }
329        | Instruction::Clz { src, .. }
330        | Instruction::Popc { src, .. }
331        | Instruction::Bfind { src, .. }
332        | Instruction::Rcp { src, .. }
333        | Instruction::Rsqrt { src, .. }
334        | Instruction::Sqrt { src, .. }
335        | Instruction::Ex2 { src, .. }
336        | Instruction::Lg2 { src, .. }
337        | Instruction::Sin { src, .. }
338        | Instruction::Cos { src, .. }
339        | Instruction::Cvt { src, .. }
340        | Instruction::Redux { src, .. } => operand_regs(src),
341
342        Instruction::Bfe {
343            src, start, len, ..
344        } => {
345            let mut regs = operand_regs(src);
346            regs.extend(operand_regs(start));
347            regs.extend(operand_regs(len));
348            regs
349        }
350
351        Instruction::Bfi {
352            insert,
353            base,
354            start,
355            len,
356            ..
357        } => {
358            let mut regs = operand_regs(insert);
359            regs.extend(operand_regs(base));
360            regs.extend(operand_regs(start));
361            regs.extend(operand_regs(len));
362            regs
363        }
364
365        Instruction::Load { addr, .. } | Instruction::MbarrierArrive { addr } => operand_regs(addr),
366
367        Instruction::Store { addr, src, .. } => {
368            let mut regs = operand_regs(addr);
369            regs.push(src);
370            regs
371        }
372
373        Instruction::CpAsync {
374            dst_shared,
375            src_global,
376            ..
377        } => {
378            let mut regs = operand_regs(dst_shared);
379            regs.extend(operand_regs(src_global));
380            regs
381        }
382
383        Instruction::CpAsyncCommit
384        | Instruction::CpAsyncWait { .. }
385        | Instruction::Label(_)
386        | Instruction::Return
387        | Instruction::BarSync { .. }
388        | Instruction::BarArrive { .. }
389        | Instruction::FenceAcqRel { .. }
390        | Instruction::MovSpecial { .. }
391        | Instruction::LoadParam { .. }
392        | Instruction::ElectSync { .. }
393        | Instruction::Setmaxnreg { .. }
394        | Instruction::Griddepcontrol { .. }
395        | Instruction::FenceProxy { .. }
396        | Instruction::BarrierCluster
397        | Instruction::FenceCluster
398        | Instruction::Comment(_)
399        | Instruction::Raw(_)
400        | Instruction::Pragma(_) => vec![],
401
402        Instruction::Branch { predicate, .. } => {
403            if let Some((reg, _)) = predicate {
404                vec![reg]
405            } else {
406                vec![]
407            }
408        }
409
410        Instruction::Wmma {
411            op,
412            fragments,
413            addr,
414            stride,
415            ..
416        } => {
417            let mut regs: Vec<&Register> = Vec::new();
418            match op {
419                WmmaOp::LoadA | WmmaOp::LoadB => {
420                    if let Some(a) = addr {
421                        regs.extend(operand_regs(a));
422                    }
423                    if let Some(s) = stride {
424                        regs.extend(operand_regs(s));
425                    }
426                }
427                WmmaOp::StoreD => {
428                    regs.extend(fragments.iter());
429                    if let Some(a) = addr {
430                        regs.extend(operand_regs(a));
431                    }
432                    if let Some(s) = stride {
433                        regs.extend(operand_regs(s));
434                    }
435                }
436                WmmaOp::Mma => {
437                    regs.extend(fragments.iter());
438                }
439            }
440            regs
441        }
442
443        Instruction::Mma {
444            a_regs,
445            b_regs,
446            c_regs,
447            ..
448        } => {
449            let mut regs: Vec<&Register> = Vec::new();
450            regs.extend(a_regs.iter());
451            regs.extend(b_regs.iter());
452            regs.extend(c_regs.iter());
453            regs
454        }
455
456        Instruction::Wgmma { desc_a, desc_b, .. } => vec![desc_a, desc_b],
457
458        Instruction::TmaLoad {
459            dst_shared,
460            desc,
461            coords,
462            barrier,
463            ..
464        } => {
465            let mut regs = operand_regs(dst_shared);
466            regs.push(desc);
467            regs.extend(coords.iter());
468            regs.push(barrier);
469            regs
470        }
471
472        // Atomic: reads addr and src
473        Instruction::Atom { addr, src, .. } | Instruction::Red { addr, src, .. } => {
474            let mut regs = operand_regs(addr);
475            regs.extend(operand_regs(src));
476            regs
477        }
478        // AtomCas: reads addr, compare, and value
479        Instruction::AtomCas {
480            addr,
481            compare,
482            value,
483            ..
484        } => {
485            let mut regs = operand_regs(addr);
486            regs.extend(operand_regs(compare));
487            regs.extend(operand_regs(value));
488            regs
489        }
490
491        // Texture: coord registers are used
492        Instruction::Tex1d { coord, .. } | Instruction::SurfLoad { coord, .. } => {
493            operand_regs(coord)
494        }
495        Instruction::Tex2d {
496            coord_x, coord_y, ..
497        } => {
498            let mut regs = operand_regs(coord_x);
499            regs.extend(operand_regs(coord_y));
500            regs
501        }
502        Instruction::Tex3d {
503            coord_x,
504            coord_y,
505            coord_z,
506            ..
507        } => {
508            let mut regs = operand_regs(coord_x);
509            regs.extend(operand_regs(coord_y));
510            regs.extend(operand_regs(coord_z));
511            regs
512        }
513        Instruction::SurfStore { coord, src, .. } => {
514            let mut regs = operand_regs(coord);
515            regs.push(src);
516            regs
517        }
518
519        // PTX 8.x instructions
520        Instruction::Stmatrix { dst_addr, src, .. } => {
521            let mut regs = operand_regs(dst_addr);
522            regs.push(src);
523            regs
524        }
525        Instruction::MbarrierInit { addr, count, .. } => {
526            let mut regs = operand_regs(addr);
527            regs.extend(operand_regs(count));
528            regs
529        }
530        Instruction::MbarrierWait { addr, phase } => {
531            let mut regs = operand_regs(addr);
532            regs.extend(operand_regs(phase));
533            regs
534        }
535
536        Instruction::Tcgen05Mma { a_desc, b_desc } => vec![a_desc, b_desc],
537
538        Instruction::CpAsyncBulk {
539            dst_smem,
540            src_gmem,
541            desc,
542        } => vec![dst_smem, src_gmem, desc],
543
544        Instruction::Ldmatrix { src_addr, .. } => operand_regs(src_addr),
545    }
546}
547
548/// Extract register references from an operand.
549fn operand_regs(op: &Operand) -> Vec<&Register> {
550    match op {
551        Operand::Register(reg) => vec![reg],
552        Operand::Address { base, .. } => vec![base],
553        Operand::Immediate(_) | Operand::Symbol(_) => vec![],
554    }
555}
556
557// ---------------------------------------------------------------------------
558// Tests
559// ---------------------------------------------------------------------------
560
561#[cfg(test)]
562mod tests {
563    use super::*;
564    use crate::ir::{
565        CacheQualifier, FenceScope, ImmValue, Instruction, MemorySpace, MulMode, Operand, PtxType,
566        Register, SpecialReg, VectorWidth, WmmaOp,
567    };
568
569    fn reg(name: &str, ty: PtxType) -> Register {
570        Register {
571            name: name.to_string(),
572            ty,
573        }
574    }
575
576    fn reg_op(name: &str, ty: PtxType) -> Operand {
577        Operand::Register(reg(name, ty))
578    }
579
580    fn imm_u32(val: u32) -> Operand {
581        Operand::Immediate(ImmValue::U32(val))
582    }
583
584    /// Unused register definition should be removed.
585    #[test]
586    fn test_unused_register_removed() {
587        let instructions = vec![
588            Instruction::Add {
589                ty: PtxType::F32,
590                dst: reg("%f0", PtxType::F32),
591                a: imm_u32(1),
592                b: imm_u32(2),
593            },
594            // %f0 is never used
595        ];
596        let (result, eliminated) = eliminate_dead_code(&instructions);
597        assert_eq!(eliminated, 1);
598        assert!(result.is_empty());
599    }
600
601    /// Used register definition should be kept.
602    #[test]
603    fn test_used_register_kept() {
604        let instructions = vec![
605            Instruction::MovSpecial {
606                dst: reg("%r0", PtxType::U32),
607                special: SpecialReg::TidX,
608            },
609            Instruction::Store {
610                space: MemorySpace::Global,
611                qualifier: CacheQualifier::None,
612                vec: VectorWidth::V1,
613                ty: PtxType::U32,
614                addr: Operand::Address {
615                    base: reg("%rd0", PtxType::U64),
616                    offset: None,
617                },
618                src: reg("%r0", PtxType::U32),
619            },
620        ];
621        let (result, eliminated) = eliminate_dead_code(&instructions);
622        assert_eq!(eliminated, 0);
623        assert_eq!(result.len(), 2);
624    }
625
626    /// Store instructions are never removed (side effect).
627    #[test]
628    fn test_stores_never_removed() {
629        let instructions = vec![Instruction::Store {
630            space: MemorySpace::Global,
631            qualifier: CacheQualifier::None,
632            vec: VectorWidth::V1,
633            ty: PtxType::F32,
634            addr: Operand::Address {
635                base: reg("%rd0", PtxType::U64),
636                offset: None,
637            },
638            src: reg("%f0", PtxType::F32),
639        }];
640        let (result, eliminated) = eliminate_dead_code(&instructions);
641        assert_eq!(eliminated, 0);
642        assert_eq!(result.len(), 1);
643    }
644
645    /// Branches are never removed (control flow side effect).
646    #[test]
647    fn test_branches_never_removed() {
648        let instructions = vec![
649            Instruction::Branch {
650                target: "loop".to_string(),
651                predicate: None,
652            },
653            Instruction::Label("loop".to_string()),
654        ];
655        let (result, eliminated) = eliminate_dead_code(&instructions);
656        assert_eq!(eliminated, 0);
657        assert_eq!(result.len(), 2);
658    }
659
660    /// Barrier is never removed (synchronization side effect).
661    #[test]
662    fn test_barrier_never_removed() {
663        let instructions = vec![Instruction::BarSync { id: 0 }];
664        let (result, eliminated) = eliminate_dead_code(&instructions);
665        assert_eq!(eliminated, 0);
666        assert_eq!(result.len(), 1);
667    }
668
669    /// `BarArrive` is never removed.
670    #[test]
671    fn test_bar_arrive_never_removed() {
672        let instructions = vec![Instruction::BarArrive { id: 0, count: 32 }];
673        let (result, eliminated) = eliminate_dead_code(&instructions);
674        assert_eq!(eliminated, 0);
675        assert_eq!(result.len(), 1);
676    }
677
678    /// `FenceAcqRel` is never removed.
679    #[test]
680    fn test_fence_never_removed() {
681        let instructions = vec![Instruction::FenceAcqRel {
682            scope: FenceScope::Gpu,
683        }];
684        let (result, eliminated) = eliminate_dead_code(&instructions);
685        assert_eq!(eliminated, 0);
686        assert_eq!(result.len(), 1);
687    }
688
689    /// Return is never removed.
690    #[test]
691    fn test_return_never_removed() {
692        let instructions = vec![Instruction::Return];
693        let (result, eliminated) = eliminate_dead_code(&instructions);
694        assert_eq!(eliminated, 0);
695        assert_eq!(result.len(), 1);
696    }
697
698    /// Comment is never removed.
699    #[test]
700    fn test_comment_never_removed() {
701        let instructions = vec![Instruction::Comment("keep me".to_string())];
702        let (result, eliminated) = eliminate_dead_code(&instructions);
703        assert_eq!(eliminated, 0);
704        assert_eq!(result.len(), 1);
705    }
706
707    /// Raw PTX is never removed.
708    #[test]
709    fn test_raw_never_removed() {
710        let instructions = vec![Instruction::Raw("nop;".to_string())];
711        let (result, eliminated) = eliminate_dead_code(&instructions);
712        assert_eq!(eliminated, 0);
713        assert_eq!(result.len(), 1);
714    }
715
716    /// Chain of dead instructions: A defines X, B uses X to define Y, Y unused.
717    /// Both should be eliminated (fixed-point iteration).
718    #[test]
719    fn test_chain_of_dead_instructions() {
720        let instructions = vec![
721            // %f0 = 1 + 2 (dead because only used by next instruction)
722            Instruction::Add {
723                ty: PtxType::F32,
724                dst: reg("%f0", PtxType::F32),
725                a: imm_u32(1),
726                b: imm_u32(2),
727            },
728            // %f1 = %f0 + 3 (dead because %f1 is never used)
729            Instruction::Add {
730                ty: PtxType::F32,
731                dst: reg("%f1", PtxType::F32),
732                a: reg_op("%f0", PtxType::F32),
733                b: imm_u32(3),
734            },
735        ];
736        let (result, eliminated) = eliminate_dead_code(&instructions);
737        // First pass: %f1 unused → remove second instruction → eliminated=1
738        // Second pass: %f0 now unused → remove first instruction → eliminated=1
739        // Total: 2
740        assert_eq!(eliminated, 2);
741        assert!(result.is_empty());
742    }
743
744    /// Fixed-point: three-level chain of dead code.
745    #[test]
746    fn test_three_level_dead_chain() {
747        let instructions = vec![
748            Instruction::Add {
749                ty: PtxType::U32,
750                dst: reg("%r0", PtxType::U32),
751                a: imm_u32(1),
752                b: imm_u32(2),
753            },
754            Instruction::Mul {
755                ty: PtxType::U32,
756                mode: MulMode::Lo,
757                dst: reg("%r1", PtxType::U32),
758                a: reg_op("%r0", PtxType::U32),
759                b: imm_u32(3),
760            },
761            Instruction::Sub {
762                ty: PtxType::U32,
763                dst: reg("%r2", PtxType::U32),
764                a: reg_op("%r1", PtxType::U32),
765                b: imm_u32(4),
766            },
767        ];
768        let (result, eliminated) = eliminate_dead_code(&instructions);
769        assert_eq!(eliminated, 3);
770        assert!(result.is_empty());
771    }
772
773    /// Function with no dead code returns identical instructions.
774    #[test]
775    fn test_no_dead_code_unchanged() {
776        let instructions = vec![
777            Instruction::MovSpecial {
778                dst: reg("%r0", PtxType::U32),
779                special: SpecialReg::TidX,
780            },
781            Instruction::Add {
782                ty: PtxType::U32,
783                dst: reg("%r1", PtxType::U32),
784                a: reg_op("%r0", PtxType::U32),
785                b: imm_u32(1),
786            },
787            Instruction::Store {
788                space: MemorySpace::Global,
789                qualifier: CacheQualifier::None,
790                vec: VectorWidth::V1,
791                ty: PtxType::U32,
792                addr: Operand::Address {
793                    base: reg("%rd0", PtxType::U64),
794                    offset: None,
795                },
796                src: reg("%r1", PtxType::U32),
797            },
798        ];
799        let (result, eliminated) = eliminate_dead_code(&instructions);
800        assert_eq!(eliminated, 0);
801        assert_eq!(result.len(), 3);
802    }
803
804    /// `CpAsync` is never removed (DMA side effect).
805    #[test]
806    fn test_cp_async_never_removed() {
807        let instructions = vec![
808            Instruction::CpAsync {
809                bytes: 16,
810                dst_shared: Operand::Address {
811                    base: reg("%rd0", PtxType::U64),
812                    offset: None,
813                },
814                src_global: Operand::Address {
815                    base: reg("%rd1", PtxType::U64),
816                    offset: None,
817                },
818            },
819            Instruction::CpAsyncCommit,
820            Instruction::CpAsyncWait { n: 0 },
821        ];
822        let (result, eliminated) = eliminate_dead_code(&instructions);
823        assert_eq!(eliminated, 0);
824        assert_eq!(result.len(), 3);
825    }
826
827    /// `TmaLoad` is never removed.
828    #[test]
829    fn test_tma_load_never_removed() {
830        let instructions = vec![Instruction::TmaLoad {
831            dst_shared: Operand::Address {
832                base: reg("%rd0", PtxType::U64),
833                offset: None,
834            },
835            desc: reg("%rd1", PtxType::U64),
836            coords: vec![reg("%r0", PtxType::U32)],
837            barrier: reg("%rd2", PtxType::U64),
838        }];
839        let (result, eliminated) = eliminate_dead_code(&instructions);
840        assert_eq!(eliminated, 0);
841        assert_eq!(result.len(), 1);
842    }
843
844    /// Mixed live and dead instructions.
845    #[test]
846    fn test_mixed_live_and_dead() {
847        let instructions = vec![
848            // Live chain: tid → add → store
849            Instruction::MovSpecial {
850                dst: reg("%r0", PtxType::U32),
851                special: SpecialReg::TidX,
852            },
853            Instruction::Add {
854                ty: PtxType::U32,
855                dst: reg("%r1", PtxType::U32),
856                a: reg_op("%r0", PtxType::U32),
857                b: imm_u32(1),
858            },
859            // Dead: %r2 is never used
860            Instruction::Mul {
861                ty: PtxType::U32,
862                mode: MulMode::Lo,
863                dst: reg("%r2", PtxType::U32),
864                a: reg_op("%r0", PtxType::U32),
865                b: imm_u32(2),
866            },
867            Instruction::Store {
868                space: MemorySpace::Global,
869                qualifier: CacheQualifier::None,
870                vec: VectorWidth::V1,
871                ty: PtxType::U32,
872                addr: Operand::Address {
873                    base: reg("%rd0", PtxType::U64),
874                    offset: None,
875                },
876                src: reg("%r1", PtxType::U32),
877            },
878        ];
879        let (result, eliminated) = eliminate_dead_code(&instructions);
880        assert_eq!(eliminated, 1);
881        assert_eq!(result.len(), 3);
882    }
883
884    /// Empty input returns empty output with zero eliminated.
885    #[test]
886    fn test_empty_instructions() {
887        let (result, eliminated) = eliminate_dead_code(&[]);
888        assert_eq!(eliminated, 0);
889        assert!(result.is_empty());
890    }
891
892    /// Load instruction without subsequent use is dead (loads have no side
893    /// effect on other memory; they only write to the destination register).
894    #[test]
895    fn test_dead_load_removed() {
896        let instructions = vec![Instruction::Load {
897            space: MemorySpace::Global,
898            qualifier: CacheQualifier::None,
899            vec: VectorWidth::V1,
900            ty: PtxType::F32,
901            dst: reg("%f0", PtxType::F32),
902            addr: Operand::Address {
903                base: reg("%rd0", PtxType::U64),
904                offset: None,
905            },
906        }];
907        let (result, eliminated) = eliminate_dead_code(&instructions);
908        assert_eq!(eliminated, 1);
909        assert!(result.is_empty());
910    }
911
912    /// `Wmma` `StoreD` is a side effect and must not be removed.
913    #[test]
914    fn test_wmma_store_never_removed() {
915        use crate::ir::{WmmaLayout, WmmaShape};
916
917        let instructions = vec![Instruction::Wmma {
918            op: WmmaOp::StoreD,
919            shape: WmmaShape::M16N16K16,
920            layout: WmmaLayout::RowMajor,
921            ty: PtxType::F16,
922            fragments: vec![reg("%f0", PtxType::F16), reg("%f1", PtxType::F16)],
923            addr: Some(Operand::Address {
924                base: reg("%rd0", PtxType::U64),
925                offset: None,
926            }),
927            stride: None,
928        }];
929        let (result, eliminated) = eliminate_dead_code(&instructions);
930        assert_eq!(eliminated, 0);
931        assert_eq!(result.len(), 1);
932    }
933
934    /// `has_side_effects` correctly classifies all instruction categories.
935    #[test]
936    fn test_side_effects_classification() {
937        // Pure computation → no side effects
938        let add = Instruction::Add {
939            ty: PtxType::F32,
940            dst: reg("%f0", PtxType::F32),
941            a: imm_u32(0),
942            b: imm_u32(0),
943        };
944        assert!(!has_side_effects(&add));
945
946        // Store → side effect
947        let store = Instruction::Store {
948            space: MemorySpace::Global,
949            qualifier: CacheQualifier::None,
950            vec: VectorWidth::V1,
951            ty: PtxType::F32,
952            addr: Operand::Address {
953                base: reg("%rd0", PtxType::U64),
954                offset: None,
955            },
956            src: reg("%f0", PtxType::F32),
957        };
958        assert!(has_side_effects(&store));
959
960        // Branch → side effect
961        let branch = Instruction::Branch {
962            target: "L1".to_string(),
963            predicate: None,
964        };
965        assert!(has_side_effects(&branch));
966
967        // Label → side effect
968        let label = Instruction::Label("L1".to_string());
969        assert!(has_side_effects(&label));
970
971        // BarSync → side effect
972        let bar = Instruction::BarSync { id: 0 };
973        assert!(has_side_effects(&bar));
974
975        // MovSpecial → no side effect (can be DCE'd if unused)
976        let mov = Instruction::MovSpecial {
977            dst: reg("%r0", PtxType::U32),
978            special: SpecialReg::TidX,
979        };
980        assert!(!has_side_effects(&mov));
981    }
982
983    // -------------------------------------------------------------------------
984    // Additional DCE quality-gate tests
985    // -------------------------------------------------------------------------
986
987    /// An unreachable chain of pure computations after a branch is dead:
988    /// the defined registers are never consumed, so DCE eliminates them.
989    /// Note: the Branch and Label are kept (side effects); the pure
990    /// computations whose results feed nowhere are eliminated.
991    #[test]
992    fn test_dce_removes_unreachable_block() {
993        let instructions = vec![
994            // Unconditional branch — side effect, kept
995            Instruction::Branch {
996                target: "after_dead".to_string(),
997                predicate: None,
998            },
999            // The following pure computations are never consumed (the branch
1000            // causes them to be skipped at runtime, and their outputs are
1001            // never used anywhere in this list).
1002            Instruction::Add {
1003                ty: PtxType::F32,
1004                dst: reg("%f_dead0", PtxType::F32),
1005                a: imm_u32(1),
1006                b: imm_u32(2),
1007            },
1008            Instruction::Mul {
1009                ty: PtxType::F32,
1010                mode: MulMode::Lo,
1011                dst: reg("%f_dead1", PtxType::F32),
1012                a: reg_op("%f_dead0", PtxType::F32),
1013                b: imm_u32(3),
1014            },
1015            // Label — side effect, kept
1016            Instruction::Label("after_dead".to_string()),
1017            Instruction::Return,
1018        ];
1019
1020        let (result, eliminated) = eliminate_dead_code(&instructions);
1021
1022        // Branch, Label, Return are always kept (side effects = 3)
1023        // Add and Mul are dead (their outputs go nowhere) = 2 eliminated
1024        // Fixed-point: pass1 eliminates Mul (uses %f_dead0 which is defined but
1025        // then %f_dead0 has no other consumer after Mul is gone), pass2 eliminates Add
1026        assert_eq!(
1027            eliminated, 2,
1028            "DCE must eliminate both unreachable pure-computation instructions"
1029        );
1030        // Branch, Label, Return survive
1031        assert_eq!(
1032            result.len(),
1033            3,
1034            "Branch, Label and Return must be preserved"
1035        );
1036    }
1037
1038    /// Reachable blocks (pure computations whose results feed a store) must NOT
1039    /// be removed by DCE.
1040    #[test]
1041    fn test_dce_keeps_reachable_blocks() {
1042        let instructions = vec![
1043            Instruction::MovSpecial {
1044                dst: reg("%r0", PtxType::U32),
1045                special: SpecialReg::TidX,
1046            },
1047            Instruction::Add {
1048                ty: PtxType::U32,
1049                dst: reg("%r1", PtxType::U32),
1050                a: reg_op("%r0", PtxType::U32),
1051                b: imm_u32(10),
1052            },
1053            Instruction::Mul {
1054                ty: PtxType::U32,
1055                mode: MulMode::Lo,
1056                dst: reg("%r2", PtxType::U32),
1057                a: reg_op("%r1", PtxType::U32),
1058                b: imm_u32(4),
1059            },
1060            // Store consumes %r2 — the whole chain is live
1061            Instruction::Store {
1062                space: MemorySpace::Global,
1063                qualifier: CacheQualifier::None,
1064                vec: VectorWidth::V1,
1065                ty: PtxType::U32,
1066                addr: Operand::Address {
1067                    base: reg("%rd0", PtxType::U64),
1068                    offset: None,
1069                },
1070                src: reg("%r2", PtxType::U32),
1071            },
1072        ];
1073
1074        let (result, eliminated) = eliminate_dead_code(&instructions);
1075
1076        assert_eq!(
1077            eliminated, 0,
1078            "no instruction should be eliminated from a fully-live chain"
1079        );
1080        assert_eq!(
1081            result.len(),
1082            instructions.len(),
1083            "all instructions must survive DCE"
1084        );
1085    }
1086
1087    /// DCE must be idempotent: running the pass twice on the same input must
1088    /// produce the same result as running it once.
1089    #[test]
1090    fn test_dce_idempotent() {
1091        let instructions = vec![
1092            // Live chain
1093            Instruction::MovSpecial {
1094                dst: reg("%r0", PtxType::U32),
1095                special: SpecialReg::TidX,
1096            },
1097            // Dead computation
1098            Instruction::Add {
1099                ty: PtxType::F32,
1100                dst: reg("%f_unused", PtxType::F32),
1101                a: imm_u32(7),
1102                b: imm_u32(8),
1103            },
1104            Instruction::Store {
1105                space: MemorySpace::Global,
1106                qualifier: CacheQualifier::None,
1107                vec: VectorWidth::V1,
1108                ty: PtxType::U32,
1109                addr: Operand::Address {
1110                    base: reg("%rd0", PtxType::U64),
1111                    offset: None,
1112                },
1113                src: reg("%r0", PtxType::U32),
1114            },
1115        ];
1116
1117        let (first_result, first_eliminated) = eliminate_dead_code(&instructions);
1118        // DCE already runs to fixed-point internally, so a second call changes nothing
1119        let (second_result, second_eliminated) = eliminate_dead_code(&first_result);
1120
1121        assert_eq!(
1122            second_eliminated, 0,
1123            "second DCE pass must not eliminate anything additional (idempotent)"
1124        );
1125        assert_eq!(
1126            first_result.len(),
1127            second_result.len(),
1128            "result length must be the same on both passes"
1129        );
1130        assert_eq!(
1131            first_eliminated, 1,
1132            "first pass must eliminate the unused Add instruction"
1133        );
1134    }
1135}