Skip to main content

rlx_compile/
dce.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Dead Code Elimination — drop nodes that aren't reachable from any output.
//!
//! Walks the graph backwards from `graph.outputs`, marks every transitively
//! consumed node as live, then rebuilds the graph keeping only live nodes.
//! Liveness is purely reachability-based, so unused `Input` and `Param` nodes
//! are dropped as well. The rebuild re-adds nodes in the order
//! `Graph::nodes()` yields them, which must be topological (operands before
//! their consumers).
//!
//! Why it matters:
//! - Frees arena memory for buffers nobody reads.
//! - Avoids running kernels whose outputs are discarded.
//! - Catches accidental dead code (e.g., the early Vision graph builder
//!   emitted a patch projection that was never wired into the encoder).
26
27use rlx_fusion::pass::Pass;
28use rlx_ir::{Graph, NodeId};
29use std::collections::{HashMap, HashSet, VecDeque};
30
31pub struct DeadCodeElimination;
32
33impl Pass for DeadCodeElimination {
34    fn name(&self) -> &str {
35        "dead_code_elimination"
36    }
37
38    fn run(&self, graph: Graph) -> Graph {
39        // BFS backwards from outputs to find all reachable nodes.
40        let mut live: HashSet<NodeId> = HashSet::new();
41        let mut queue: VecDeque<NodeId> = graph.outputs.iter().copied().collect();
42        while let Some(id) = queue.pop_front() {
43            if !live.insert(id) {
44                continue;
45            }
46            for &input in &graph.node(id).inputs {
47                queue.push_back(input);
48            }
49        }
50
51        // Rebuild graph keeping only live nodes (preserves topological order).
52        let mut new_graph = Graph::new(&graph.name);
53        let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
54        for node in graph.nodes() {
55            if !live.contains(&node.id) {
56                continue;
57            }
58            // Inputs and Params are kept as-is; everything else gets remapped inputs.
59            let new_inputs: Vec<NodeId> = node.inputs.iter().map(|id| id_map[id]).collect();
60            let new_id = new_graph.add_node(node.op.clone(), new_inputs, node.shape.clone());
61            if node.name.is_some() || node.origin.is_some() {
62                let n = new_graph.node_mut(new_id);
63                n.name = node.name.clone();
64                n.origin = node.origin.clone();
65            }
66            id_map.insert(node.id, new_id);
67        }
68        let new_outputs: Vec<NodeId> = graph.outputs.iter().map(|id| id_map[id]).collect();
69        new_graph.set_outputs(new_outputs);
70        new_graph
71    }
72}
73
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::*;

    #[test]
    fn drops_unreferenced_nodes() {
        let mut g = Graph::new("test");
        let x = g.input("x", Shape::new(&[2, 4], DType::F32));
        let w = g.param("w", Shape::new(&[4, 3], DType::F32));
        let _dead = g.param("unused", Shape::new(&[8], DType::F32)); // never referenced
        let mm = g.matmul(x, w, Shape::new(&[2, 3], DType::F32));
        g.set_outputs(vec![mm]);

        // Original has 4 nodes (x, w, unused, mm)
        assert_eq!(g.len(), 4);
        let after = DeadCodeElimination.run(g);
        // After DCE: 3 nodes (x, w, mm) — `unused` is gone
        assert_eq!(after.len(), 3);
    }

    #[test]
    fn keeps_used_nodes() {
        let mut g = Graph::new("test");
        let x = g.input("x", Shape::new(&[4], DType::F32));
        let y = g.input("y", Shape::new(&[4], DType::F32));
        let z = g.binary(op::BinaryOp::Add, x, y, Shape::new(&[4], DType::F32));
        g.set_outputs(vec![z]);

        let before = g.len();
        let after = DeadCodeElimination.run(g);
        assert_eq!(after.len(), before); // nothing dead
    }

    #[test]
    fn drops_transitively_dead_chain_and_remaps_ids() {
        // `d1` sits between live nodes, which forces `y` and `z` to be
        // renumbered. `d2` consumes both a dead node and a live one. Neither
        // is reachable from the output, so both must go.
        let mut g = Graph::new("test");
        let x = g.input("x", Shape::new(&[4], DType::F32));
        let d1 = g.binary(op::BinaryOp::Add, x, x, Shape::new(&[4], DType::F32));
        let y = g.input("y", Shape::new(&[4], DType::F32));
        let z = g.binary(op::BinaryOp::Add, x, y, Shape::new(&[4], DType::F32));
        let _d2 = g.binary(op::BinaryOp::Add, d1, z, Shape::new(&[4], DType::F32));
        g.set_outputs(vec![z]);

        assert_eq!(g.len(), 5);
        let after = DeadCodeElimination.run(g);
        // Only x, y and z survive.
        assert_eq!(after.len(), 3);
        assert_eq!(after.outputs.len(), 1);

        // The output keeps both operands, and after remapping they point at
        // the surviving input nodes (which themselves have no operands).
        let out = after.outputs[0];
        assert_eq!(after.node(out).inputs.len(), 2);
        for &operand in &after.node(out).inputs {
            assert!(after.node(operand).inputs.is_empty());
        }
    }
}
107}