1use rlx_fusion::pass::Pass;
28use rlx_ir::{Graph, NodeId};
29use std::collections::{HashMap, HashSet, VecDeque};
30
/// Graph pass that removes every node whose value cannot reach a graph output.
///
/// The pass carries no state, so the unit value can be used directly,
/// e.g. `DeadCodeElimination.run(graph)`.
#[derive(Debug, Clone, Copy, Default)]
pub struct DeadCodeElimination;
32
33impl Pass for DeadCodeElimination {
34 fn name(&self) -> &str {
35 "dead_code_elimination"
36 }
37
38 fn run(&self, graph: Graph) -> Graph {
39 let mut live: HashSet<NodeId> = HashSet::new();
41 let mut queue: VecDeque<NodeId> = graph.outputs.iter().copied().collect();
42 while let Some(id) = queue.pop_front() {
43 if !live.insert(id) {
44 continue;
45 }
46 for &input in &graph.node(id).inputs {
47 queue.push_back(input);
48 }
49 }
50
51 let mut new_graph = Graph::new(&graph.name);
53 let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
54 for node in graph.nodes() {
55 if !live.contains(&node.id) {
56 continue;
57 }
58 let new_inputs: Vec<NodeId> = node.inputs.iter().map(|id| id_map[id]).collect();
60 let new_id = new_graph.add_node(node.op.clone(), new_inputs, node.shape.clone());
61 if node.name.is_some() || node.origin.is_some() {
62 let n = new_graph.node_mut(new_id);
63 n.name = node.name.clone();
64 n.origin = node.origin.clone();
65 }
66 id_map.insert(node.id, new_id);
67 }
68 let new_outputs: Vec<NodeId> = graph.outputs.iter().map(|id| id_map[id]).collect();
69 new_graph.set_outputs(new_outputs);
70 new_graph
71 }
72}
73
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::*;

    /// An unused parameter is removed, and the output is remapped to the new matmul node.
    #[test]
    fn drops_unreferenced_nodes() {
        let mut g = Graph::new("test");
        let x = g.input("x", Shape::new(&[2, 4], DType::F32));
        let w = g.param("w", Shape::new(&[4, 3], DType::F32));
        let _dead = g.param("unused", Shape::new(&[8], DType::F32));
        let mm = g.matmul(x, w, Shape::new(&[2, 3], DType::F32));
        g.set_outputs(vec![mm]);

        assert_eq!(g.len(), 4);

        let after = DeadCodeElimination.run(g);
        assert_eq!(after.len(), 3);
        // The single output must point at a node that still has both operands.
        assert_eq!(after.outputs.len(), 1);
        assert_eq!(after.node(after.outputs[0]).inputs.iter().count(), 2);
    }

    /// A graph in which every node feeds the output is left unchanged in size.
    #[test]
    fn keeps_used_nodes() {
        let mut g = Graph::new("test");
        let x = g.input("x", Shape::new(&[4], DType::F32));
        let y = g.input("y", Shape::new(&[4], DType::F32));
        let z = g.binary(op::BinaryOp::Add, x, y, Shape::new(&[4], DType::F32));
        g.set_outputs(vec![z]);

        let before = g.len();
        let after = DeadCodeElimination.run(g);
        assert_eq!(after.len(), before);
    }

    /// Nodes that consume live values but never reach an output are removed,
    /// including a whole chain of them.
    #[test]
    fn drops_transitively_dead_chain() {
        let mut g = Graph::new("test");
        let x = g.input("x", Shape::new(&[4], DType::F32));
        let y = g.input("y", Shape::new(&[4], DType::F32));
        let live = g.binary(op::BinaryOp::Add, x, y, Shape::new(&[4], DType::F32));
        let dead_a = g.binary(op::BinaryOp::Add, live, x, Shape::new(&[4], DType::F32));
        let _dead_b = g.binary(op::BinaryOp::Add, dead_a, y, Shape::new(&[4], DType::F32));
        g.set_outputs(vec![live]);

        assert_eq!(g.len(), 5);

        let after = DeadCodeElimination.run(g);
        assert_eq!(after.len(), 3);
        assert_eq!(after.outputs.len(), 1);
    }
}