Skip to main content

proof_engine/render/shader_graph/
optimizer.rs

//! Shader graph optimizer — dead-node elimination, constant folding,
//! and identity-operation bypass (`x * 1.0`, `x + 0.0`).
3
4use std::collections::HashSet;
5use super::{ShaderGraph, NodeId};
6use super::nodes::NodeType;
7
8pub struct GraphOptimizer;
9
10impl GraphOptimizer {
11    /// Run all optimization passes and return an optimized clone of the graph.
12    pub fn run(graph: &ShaderGraph) -> ShaderGraph {
13        let mut g = graph.clone();
14        Self::eliminate_dead_nodes(&mut g);
15        Self::fold_constants(&mut g);
16        Self::remove_identity_operations(&mut g);
17        g
18    }
19
20    // ── Dead-node elimination ──────────────────────────────────────────────────
21    /// Remove nodes that don't contribute to any output.
22
23    fn eliminate_dead_nodes(graph: &mut ShaderGraph) {
24        let reachable = Self::reachable_from_output(graph);
25        let dead: Vec<NodeId> = graph.nodes.keys()
26            .copied()
27            .filter(|id| !reachable.contains(id))
28            .collect();
29        for id in dead {
30            graph.remove_node(id);
31        }
32    }
33
34    fn reachable_from_output(graph: &ShaderGraph) -> HashSet<NodeId> {
35        let mut reachable = HashSet::new();
36        let mut stack = Vec::new();
37
38        if let Some(out) = graph.output_node {
39            stack.push(out);
40        } else {
41            // If no output set, everything is reachable
42            return graph.nodes.keys().copied().collect();
43        }
44
45        while let Some(id) = stack.pop() {
46            if reachable.insert(id) {
47                // Find all nodes feeding into this one
48                for edge in graph.edges.iter().filter(|e| e.to_node == id) {
49                    stack.push(edge.from_node);
50                }
51            }
52        }
53        reachable
54    }
55
56    // ── Constant folding ───────────────────────────────────────────────────────
57    /// Replace simple constant operations with ConstFloat nodes.
58
59    fn fold_constants(graph: &mut ShaderGraph) {
60        let const_values = Self::collect_constant_values(graph);
61        let mut foldable: Vec<(NodeId, f32)> = Vec::new();
62
63        for (id, node) in &graph.nodes {
64            match &node.node_type {
65                NodeType::Add => {
66                    if let (Some(a), Some(b)) = (
67                        Self::get_input_const(&const_values, graph, *id, 0),
68                        Self::get_input_const(&const_values, graph, *id, 1),
69                    ) {
70                        foldable.push((*id, a + b));
71                    }
72                }
73                NodeType::Multiply => {
74                    if let (Some(a), Some(b)) = (
75                        Self::get_input_const(&const_values, graph, *id, 0),
76                        Self::get_input_const(&const_values, graph, *id, 1),
77                    ) {
78                        foldable.push((*id, a * b));
79                    }
80                }
81                NodeType::Subtract => {
82                    if let (Some(a), Some(b)) = (
83                        Self::get_input_const(&const_values, graph, *id, 0),
84                        Self::get_input_const(&const_values, graph, *id, 1),
85                    ) {
86                        foldable.push((*id, a - b));
87                    }
88                }
89                NodeType::Sin => {
90                    if let Some(a) = Self::get_input_const(&const_values, graph, *id, 0) {
91                        foldable.push((*id, a.sin()));
92                    }
93                }
94                NodeType::Cos => {
95                    if let Some(a) = Self::get_input_const(&const_values, graph, *id, 0) {
96                        foldable.push((*id, a.cos()));
97                    }
98                }
99                NodeType::Sqrt => {
100                    if let Some(a) = Self::get_input_const(&const_values, graph, *id, 0) {
101                        if a >= 0.0 { foldable.push((*id, a.sqrt())); }
102                    }
103                }
104                NodeType::Abs => {
105                    if let Some(a) = Self::get_input_const(&const_values, graph, *id, 0) {
106                        foldable.push((*id, a.abs()));
107                    }
108                }
109                NodeType::Negate => {
110                    if let Some(a) = Self::get_input_const(&const_values, graph, *id, 0) {
111                        foldable.push((*id, -a));
112                    }
113                }
114                NodeType::OneMinus => {
115                    if let Some(a) = Self::get_input_const(&const_values, graph, *id, 0) {
116                        foldable.push((*id, 1.0 - a));
117                    }
118                }
119                NodeType::Exp => {
120                    if let Some(a) = Self::get_input_const(&const_values, graph, *id, 0) {
121                        foldable.push((*id, a.exp()));
122                    }
123                }
124                _ => {}
125            }
126        }
127
128        // Apply folds: change node type to ConstFloat, disconnect inputs
129        for (id, val) in foldable {
130            if let Some(node) = graph.nodes.get_mut(&id) {
131                node.node_type = NodeType::ConstFloat(val);
132                node.constant_inputs.clear();
133            }
134            // Remove all incoming edges to this node
135            graph.edges.retain(|e| e.to_node != id);
136        }
137    }
138
139    fn collect_constant_values(graph: &ShaderGraph) -> std::collections::HashMap<NodeId, f32> {
140        let mut map = std::collections::HashMap::new();
141        for (id, node) in &graph.nodes {
142            if let NodeType::ConstFloat(v) = node.node_type {
143                map.insert(*id, v);
144            }
145        }
146        map
147    }
148
149    fn get_input_const(
150        const_values: &std::collections::HashMap<NodeId, f32>,
151        graph:        &ShaderGraph,
152        node_id:      NodeId,
153        slot:         u8,
154    ) -> Option<f32> {
155        // Check if slot is connected to a ConstFloat node
156        for edge in graph.edges.iter().filter(|e| e.to_node == node_id && e.to_slot == slot) {
157            if let Some(&v) = const_values.get(&edge.from_node) {
158                return Some(v);
159            }
160        }
161        // Check constant_inputs fallback
162        if let Some(node) = graph.nodes.get(&node_id) {
163            if let Some(s) = node.constant_inputs.get(&(slot as usize)) {
164                return s.parse().ok();
165            }
166        }
167        None
168    }
169
170    // ── Identity operation removal ─────────────────────────────────────────────
171
172    fn remove_identity_operations(graph: &mut ShaderGraph) {
173        let mut to_bypass: Vec<NodeId> = Vec::new();
174
175        for (id, node) in &graph.nodes {
176            match &node.node_type {
177                // Multiply by 1.0 → bypass
178                NodeType::Multiply => {
179                    let b_const = Self::get_input_const(
180                        &Self::collect_constant_values(graph), graph, *id, 1
181                    );
182                    if b_const == Some(1.0) { to_bypass.push(*id); }
183                }
184                // Add 0.0 → bypass
185                NodeType::Add => {
186                    let b_const = Self::get_input_const(
187                        &Self::collect_constant_values(graph), graph, *id, 1
188                    );
189                    if b_const == Some(0.0) { to_bypass.push(*id); }
190                }
191                _ => {}
192            }
193        }
194
195        for id in to_bypass {
196            if let Some(node) = graph.nodes.get_mut(&id) {
197                node.bypassed = true;
198            }
199        }
200    }
201}
202
203// ── Tests ─────────────────────────────────────────────────────────────────────
204
#[cfg(test)]
mod tests {
    use super::*;
    use crate::render::shader_graph::ShaderGraph;
    use crate::render::shader_graph::nodes::NodeType;

    #[test]
    fn test_dead_node_elimination() {
        let mut g = ShaderGraph::new("test");
        let dead = g.add_node(NodeType::Sin);           // not connected to output
        let uv   = g.add_node(NodeType::UvCoord);
        let out  = g.add_node(NodeType::OutputColor);
        g.set_output(out);
        let _ = g.connect(uv, 0, out, 0);

        let optimized = GraphOptimizer::run(&g);
        // Dead sin node should be removed
        assert!(optimized.node(dead).is_none());
        // UV and output should survive
        assert!(optimized.node(uv).is_some());
        assert!(optimized.node(out).is_some());
    }

    #[test]
    fn test_constant_folding_add() {
        let mut g = ShaderGraph::new("test");
        let a   = g.add_node(NodeType::ConstFloat(3.0));
        let b   = g.add_node(NodeType::ConstFloat(4.0));
        let add = g.add_node(NodeType::Add);
        let out = g.add_node(NodeType::OutputColor);
        g.set_output(out);
        let _ = g.connect(a,   0, add, 0);
        let _ = g.connect(b,   0, add, 1);
        let _ = g.connect(add, 0, out, 0);

        let optimized = GraphOptimizer::run(&g);
        // The add node feeds the output, so it must survive — a missing node
        // is a failure, not a silent pass — and be folded to 3 + 4.
        let node = optimized.node(add).expect("add node feeds the output and must survive");
        assert_eq!(node.node_type, NodeType::ConstFloat(7.0));
    }

    #[test]
    fn test_constant_folding_sin() {
        let mut g   = ShaderGraph::new("test");
        let zero    = g.add_node(NodeType::ConstFloat(0.0));
        let sin_n   = g.add_node(NodeType::Sin);
        let out     = g.add_node(NodeType::OutputColor);
        g.set_output(out);
        let _ = g.connect(zero,  0, sin_n, 0);
        let _ = g.connect(sin_n, 0, out,   0);

        let optimized = GraphOptimizer::run(&g);
        let node = optimized.node(sin_n).expect("sin node feeds the output and must survive");
        // sin(0) = 0
        assert_eq!(node.node_type, NodeType::ConstFloat(0.0));
    }

    #[test]
    fn test_no_crash_empty_graph() {
        let g = ShaderGraph::new("empty");
        let opt = GraphOptimizer::run(&g);
        // No output is set, so nothing may be treated as dead.
        assert_eq!(opt.nodes.len(), g.nodes.len());
    }

    #[test]
    fn test_reachable_includes_all_ancestors() {
        let mut g  = ShaderGraph::new("test");
        let uv     = g.add_node(NodeType::UvCoord);
        let sin    = g.add_node(NodeType::Sin);
        let cos    = g.add_node(NodeType::Cos);
        let add    = g.add_node(NodeType::Add);
        let out    = g.add_node(NodeType::OutputColor);
        g.set_output(out);
        let _ = g.connect(uv,  0, sin, 0);
        let _ = g.connect(uv,  0, cos, 0);
        let _ = g.connect(sin, 0, add, 0);
        let _ = g.connect(cos, 0, add, 1);
        let _ = g.connect(add, 0, out, 0);

        let opt = GraphOptimizer::run(&g);
        assert!(opt.node(uv).is_some());
        assert!(opt.node(sin).is_some());
        assert!(opt.node(cos).is_some());
        assert!(opt.node(add).is_some());
        assert!(opt.node(out).is_some());
    }
}