use rlx_fusion::pass::Pass;
use rlx_ir::{Graph, NodeId};
use std::collections::{HashMap, HashSet, VecDeque};
/// Graph pass that removes every node that cannot reach a graph output.
///
/// A node is *live* if it is listed in `graph.outputs` or is (transitively) an
/// input of a live node. All other nodes are dropped — including `input` and
/// `param` nodes that nothing consumes — and the survivors are copied into a
/// fresh [`Graph`] whose node ids are assigned by the new graph.
pub struct DeadCodeElimination;

impl DeadCodeElimination {
    /// Returns the set of nodes reachable from `graph.outputs` by following
    /// `inputs` edges backwards (breadth-first).
    fn live_nodes(graph: &Graph) -> HashSet<NodeId> {
        let mut live: HashSet<NodeId> = HashSet::new();
        let mut queue: VecDeque<NodeId> = graph.outputs.iter().copied().collect();
        while let Some(id) = queue.pop_front() {
            // A node may be queued more than once before it is first popped;
            // `insert` returning false means it was already processed.
            if !live.insert(id) {
                continue;
            }
            // Skip inputs already known to be live to keep the queue small.
            queue.extend(
                graph
                    .node(id)
                    .inputs
                    .iter()
                    .copied()
                    .filter(|input| !live.contains(input)),
            );
        }
        live
    }

    /// Translates a node id of the source graph into the id of its copy in the
    /// rebuilt graph.
    ///
    /// # Panics
    ///
    /// Panics if `old` has not been copied yet, which means the source graph's
    /// node list does not place producers before their consumers.
    fn remapped(id_map: &HashMap<NodeId, NodeId>, old: &NodeId) -> NodeId {
        *id_map.get(old).expect(
            "dead_code_elimination: node id has no copy in the rebuilt graph; \
             graph.nodes() must list every live node, producers before consumers",
        )
    }
}

impl Pass for DeadCodeElimination {
    fn name(&self) -> &str {
        "dead_code_elimination"
    }

    /// Rebuilds `graph`, keeping only the nodes that contribute to an output.
    ///
    /// Surviving nodes keep their op, shape, `name` and `origin`, and the
    /// relative order in which `graph.nodes()` yields them. Outputs are
    /// re-pointed at the copied nodes.
    ///
    /// # Panics
    ///
    /// Panics if a live node is yielded by `graph.nodes()` before one of its
    /// inputs (the node list must be topologically ordered).
    fn run(&self, graph: Graph) -> Graph {
        let live = Self::live_nodes(&graph);

        let mut new_graph = Graph::new(&graph.name);
        // Old id (in `graph`) -> new id (in `new_graph`) for every copied node.
        let mut id_map: HashMap<NodeId, NodeId> = HashMap::with_capacity(live.len());
        for node in graph.nodes() {
            if !live.contains(&node.id) {
                continue;
            }
            let new_inputs: Vec<NodeId> = node
                .inputs
                .iter()
                .map(|old| Self::remapped(&id_map, old))
                .collect();
            let new_id = new_graph.add_node(node.op.clone(), new_inputs, node.shape.clone());
            // Metadata is only written back when the source node carried some.
            if node.name.is_some() || node.origin.is_some() {
                let copied = new_graph.node_mut(new_id);
                copied.name = node.name.clone();
                copied.origin = node.origin.clone();
            }
            id_map.insert(node.id, new_id);
        }

        let new_outputs: Vec<NodeId> = graph
            .outputs
            .iter()
            .map(|old| Self::remapped(&id_map, old))
            .collect();
        new_graph.set_outputs(new_outputs);
        new_graph
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::*;

    #[test]
    fn drops_unreferenced_nodes() {
        let mut g = Graph::new("test");
        let x = g.input("x", Shape::new(&[2, 4], DType::F32));
        let w = g.param("w", Shape::new(&[4, 3], DType::F32));
        let _dead = g.param("unused", Shape::new(&[8], DType::F32));
        let mm = g.matmul(x, w, Shape::new(&[2, 3], DType::F32));
        g.set_outputs(vec![mm]);
        assert_eq!(g.len(), 4);

        let after = DeadCodeElimination.run(g);
        assert_eq!(after.len(), 3);
        // The remapped output must still be wired to both of its operands.
        assert_eq!(after.outputs.len(), 1);
        assert_eq!(after.node(after.outputs[0]).inputs.len(), 2);
    }

    #[test]
    fn drops_transitively_dead_chain() {
        let mut g = Graph::new("test");
        let x = g.input("x", Shape::new(&[4], DType::F32));
        let y = g.input("y", Shape::new(&[4], DType::F32));
        let z = g.binary(op::BinaryOp::Add, x, y, Shape::new(&[4], DType::F32));
        // `dead_head` is unreferenced and is the only consumer of `dead_tail`,
        // so both must be removed even though `dead_tail` has a consumer.
        let dead_tail = g.binary(op::BinaryOp::Add, x, x, Shape::new(&[4], DType::F32));
        let _dead_head = g.binary(
            op::BinaryOp::Add,
            dead_tail,
            dead_tail,
            Shape::new(&[4], DType::F32),
        );
        g.set_outputs(vec![z]);
        assert_eq!(g.len(), 5);

        let after = DeadCodeElimination.run(g);
        assert_eq!(after.len(), 3);
        assert_eq!(after.outputs.len(), 1);
        assert_eq!(after.node(after.outputs[0]).inputs.len(), 2);
    }

    #[test]
    fn keeps_used_nodes() {
        let mut g = Graph::new("test");
        let x = g.input("x", Shape::new(&[4], DType::F32));
        let y = g.input("y", Shape::new(&[4], DType::F32));
        let z = g.binary(op::BinaryOp::Add, x, y, Shape::new(&[4], DType::F32));
        g.set_outputs(vec![z]);
        let before = g.len();

        let after = DeadCodeElimination.run(g);
        assert_eq!(after.len(), before);
        assert_eq!(after.outputs.len(), 1);
    }
}