// proof_engine/render/shader_graph/optimizer.rs
use std::collections::{HashMap, HashSet};

use super::nodes::NodeType;
use super::{NodeId, ShaderGraph};

8pub struct GraphOptimizer;
9
10impl GraphOptimizer {
11 pub fn run(graph: &ShaderGraph) -> ShaderGraph {
13 let mut g = graph.clone();
14 Self::eliminate_dead_nodes(&mut g);
15 Self::fold_constants(&mut g);
16 Self::remove_identity_operations(&mut g);
17 g
18 }
19
20 fn eliminate_dead_nodes(graph: &mut ShaderGraph) {
24 let reachable = Self::reachable_from_output(graph);
25 let dead: Vec<NodeId> = graph.nodes.keys()
26 .copied()
27 .filter(|id| !reachable.contains(id))
28 .collect();
29 for id in dead {
30 graph.remove_node(id);
31 }
32 }
33
34 fn reachable_from_output(graph: &ShaderGraph) -> HashSet<NodeId> {
35 let mut reachable = HashSet::new();
36 let mut stack = Vec::new();
37
38 if let Some(out) = graph.output_node {
39 stack.push(out);
40 } else {
41 return graph.nodes.keys().copied().collect();
43 }
44
45 while let Some(id) = stack.pop() {
46 if reachable.insert(id) {
47 for edge in graph.edges.iter().filter(|e| e.to_node == id) {
49 stack.push(edge.from_node);
50 }
51 }
52 }
53 reachable
54 }
55
56 fn fold_constants(graph: &mut ShaderGraph) {
60 let const_values = Self::collect_constant_values(graph);
61 let mut foldable: Vec<(NodeId, f32)> = Vec::new();
62
63 for (id, node) in &graph.nodes {
64 match &node.node_type {
65 NodeType::Add => {
66 if let (Some(a), Some(b)) = (
67 Self::get_input_const(&const_values, graph, *id, 0),
68 Self::get_input_const(&const_values, graph, *id, 1),
69 ) {
70 foldable.push((*id, a + b));
71 }
72 }
73 NodeType::Multiply => {
74 if let (Some(a), Some(b)) = (
75 Self::get_input_const(&const_values, graph, *id, 0),
76 Self::get_input_const(&const_values, graph, *id, 1),
77 ) {
78 foldable.push((*id, a * b));
79 }
80 }
81 NodeType::Subtract => {
82 if let (Some(a), Some(b)) = (
83 Self::get_input_const(&const_values, graph, *id, 0),
84 Self::get_input_const(&const_values, graph, *id, 1),
85 ) {
86 foldable.push((*id, a - b));
87 }
88 }
89 NodeType::Sin => {
90 if let Some(a) = Self::get_input_const(&const_values, graph, *id, 0) {
91 foldable.push((*id, a.sin()));
92 }
93 }
94 NodeType::Cos => {
95 if let Some(a) = Self::get_input_const(&const_values, graph, *id, 0) {
96 foldable.push((*id, a.cos()));
97 }
98 }
99 NodeType::Sqrt => {
100 if let Some(a) = Self::get_input_const(&const_values, graph, *id, 0) {
101 if a >= 0.0 { foldable.push((*id, a.sqrt())); }
102 }
103 }
104 NodeType::Abs => {
105 if let Some(a) = Self::get_input_const(&const_values, graph, *id, 0) {
106 foldable.push((*id, a.abs()));
107 }
108 }
109 NodeType::Negate => {
110 if let Some(a) = Self::get_input_const(&const_values, graph, *id, 0) {
111 foldable.push((*id, -a));
112 }
113 }
114 NodeType::OneMinus => {
115 if let Some(a) = Self::get_input_const(&const_values, graph, *id, 0) {
116 foldable.push((*id, 1.0 - a));
117 }
118 }
119 NodeType::Exp => {
120 if let Some(a) = Self::get_input_const(&const_values, graph, *id, 0) {
121 foldable.push((*id, a.exp()));
122 }
123 }
124 _ => {}
125 }
126 }
127
128 for (id, val) in foldable {
130 if let Some(node) = graph.nodes.get_mut(&id) {
131 node.node_type = NodeType::ConstFloat(val);
132 node.constant_inputs.clear();
133 }
134 graph.edges.retain(|e| e.to_node != id);
136 }
137 }
138
139 fn collect_constant_values(graph: &ShaderGraph) -> std::collections::HashMap<NodeId, f32> {
140 let mut map = std::collections::HashMap::new();
141 for (id, node) in &graph.nodes {
142 if let NodeType::ConstFloat(v) = node.node_type {
143 map.insert(*id, v);
144 }
145 }
146 map
147 }
148
149 fn get_input_const(
150 const_values: &std::collections::HashMap<NodeId, f32>,
151 graph: &ShaderGraph,
152 node_id: NodeId,
153 slot: u8,
154 ) -> Option<f32> {
155 for edge in graph.edges.iter().filter(|e| e.to_node == node_id && e.to_slot == slot) {
157 if let Some(&v) = const_values.get(&edge.from_node) {
158 return Some(v);
159 }
160 }
161 if let Some(node) = graph.nodes.get(&node_id) {
163 if let Some(s) = node.constant_inputs.get(&(slot as usize)) {
164 return s.parse().ok();
165 }
166 }
167 None
168 }
169
170 fn remove_identity_operations(graph: &mut ShaderGraph) {
173 let mut to_bypass: Vec<NodeId> = Vec::new();
174
175 for (id, node) in &graph.nodes {
176 match &node.node_type {
177 NodeType::Multiply => {
179 let b_const = Self::get_input_const(
180 &Self::collect_constant_values(graph), graph, *id, 1
181 );
182 if b_const == Some(1.0) { to_bypass.push(*id); }
183 }
184 NodeType::Add => {
186 let b_const = Self::get_input_const(
187 &Self::collect_constant_values(graph), graph, *id, 1
188 );
189 if b_const == Some(0.0) { to_bypass.push(*id); }
190 }
191 _ => {}
192 }
193 }
194
195 for id in to_bypass {
196 if let Some(node) = graph.nodes.get_mut(&id) {
197 node.bypassed = true;
198 }
199 }
200 }
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use crate::render::shader_graph::ShaderGraph;
    use crate::render::shader_graph::nodes::NodeType;

    /// A node with no path to the output is removed; the output and its inputs stay.
    #[test]
    fn test_dead_node_elimination() {
        let mut g = ShaderGraph::new("test");
        let dead = g.add_node(NodeType::Sin);
        let uv = g.add_node(NodeType::UvCoord);
        let out = g.add_node(NodeType::OutputColor);
        g.set_output(out);
        let _ = g.connect(uv, 0, out, 0);

        let optimized = GraphOptimizer::run(&g);
        assert!(optimized.node(dead).is_none());
        assert!(optimized.node(uv).is_some());
        assert!(optimized.node(out).is_some());
    }

    /// `3 + 4` folds to `7`, and the now-unused constant sources are removed.
    #[test]
    fn test_constant_folding_add() {
        let mut g = ShaderGraph::new("test");
        let a = g.add_node(NodeType::ConstFloat(3.0));
        let b = g.add_node(NodeType::ConstFloat(4.0));
        let add = g.add_node(NodeType::Add);
        let out = g.add_node(NodeType::OutputColor);
        g.set_output(out);
        let _ = g.connect(a, 0, add, 0);
        let _ = g.connect(b, 0, add, 1);
        let _ = g.connect(add, 0, out, 0);

        let optimized = GraphOptimizer::run(&g);
        let node = optimized.node(add).expect("folded add node must survive");
        assert_eq!(node.node_type, NodeType::ConstFloat(7.0));
        assert!(optimized.node(a).is_none());
        assert!(optimized.node(b).is_none());
    }

    /// `sin(0)` folds to `0`.
    #[test]
    fn test_constant_folding_sin() {
        let mut g = ShaderGraph::new("test");
        let zero = g.add_node(NodeType::ConstFloat(0.0));
        let sin_n = g.add_node(NodeType::Sin);
        let out = g.add_node(NodeType::OutputColor);
        g.set_output(out);
        let _ = g.connect(zero, 0, sin_n, 0);
        let _ = g.connect(sin_n, 0, out, 0);

        let optimized = GraphOptimizer::run(&g);
        let node = optimized.node(sin_n).expect("folded sin node must survive");
        assert_eq!(node.node_type, NodeType::ConstFloat(0.0));
    }

    /// `(1 + 2) * 3` folds all the way to `9`, not just the inner addition.
    #[test]
    fn test_constant_folding_chain() {
        let mut g = ShaderGraph::new("test");
        let one = g.add_node(NodeType::ConstFloat(1.0));
        let two = g.add_node(NodeType::ConstFloat(2.0));
        let three = g.add_node(NodeType::ConstFloat(3.0));
        let add = g.add_node(NodeType::Add);
        let mul = g.add_node(NodeType::Multiply);
        let out = g.add_node(NodeType::OutputColor);
        g.set_output(out);
        let _ = g.connect(one, 0, add, 0);
        let _ = g.connect(two, 0, add, 1);
        let _ = g.connect(add, 0, mul, 0);
        let _ = g.connect(three, 0, mul, 1);
        let _ = g.connect(mul, 0, out, 0);

        let optimized = GraphOptimizer::run(&g);
        let node = optimized.node(mul).expect("folded multiply node must survive");
        assert_eq!(node.node_type, NodeType::ConstFloat(9.0));
        assert!(optimized.node(add).is_none());
    }

    /// A node fed by non-constant data is never folded, whatever inline defaults it has.
    #[test]
    fn test_dynamic_input_not_folded() {
        let mut g = ShaderGraph::new("test");
        let uv = g.add_node(NodeType::UvCoord);
        let sin = g.add_node(NodeType::Sin);
        let two = g.add_node(NodeType::ConstFloat(2.0));
        let add = g.add_node(NodeType::Add);
        let out = g.add_node(NodeType::OutputColor);
        g.set_output(out);
        let _ = g.connect(uv, 0, sin, 0);
        let _ = g.connect(sin, 0, add, 0);
        let _ = g.connect(two, 0, add, 1);
        let _ = g.connect(add, 0, out, 0);

        let optimized = GraphOptimizer::run(&g);
        let sin_node = optimized.node(sin).expect("sin node must survive");
        assert!(matches!(sin_node.node_type, NodeType::Sin));
        let add_node = optimized.node(add).expect("add node must survive");
        assert!(matches!(add_node.node_type, NodeType::Add));
        assert!(optimized.node(uv).is_some());
    }

    /// `x * 1` is marked as bypassed.
    #[test]
    fn test_multiply_by_one_is_bypassed() {
        let mut g = ShaderGraph::new("test");
        let uv = g.add_node(NodeType::UvCoord);
        let sin = g.add_node(NodeType::Sin);
        let one = g.add_node(NodeType::ConstFloat(1.0));
        let mul = g.add_node(NodeType::Multiply);
        let out = g.add_node(NodeType::OutputColor);
        g.set_output(out);
        let _ = g.connect(uv, 0, sin, 0);
        let _ = g.connect(sin, 0, mul, 0);
        let _ = g.connect(one, 0, mul, 1);
        let _ = g.connect(mul, 0, out, 0);

        let optimized = GraphOptimizer::run(&g);
        let node = optimized.node(mul).expect("multiply node must survive");
        assert!(node.bypassed);
    }

    /// Optimising an empty graph must not panic.
    #[test]
    fn test_no_crash_empty_graph() {
        let g = ShaderGraph::new("empty");
        let _ = GraphOptimizer::run(&g);
    }

    /// Every ancestor of the output, including shared ones, is kept.
    #[test]
    fn test_reachable_includes_all_ancestors() {
        let mut g = ShaderGraph::new("test");
        let uv = g.add_node(NodeType::UvCoord);
        let sin = g.add_node(NodeType::Sin);
        let cos = g.add_node(NodeType::Cos);
        let add = g.add_node(NodeType::Add);
        let out = g.add_node(NodeType::OutputColor);
        g.set_output(out);
        let _ = g.connect(uv, 0, sin, 0);
        let _ = g.connect(uv, 0, cos, 0);
        let _ = g.connect(sin, 0, add, 0);
        let _ = g.connect(cos, 0, add, 1);
        let _ = g.connect(add, 0, out, 0);

        let opt = GraphOptimizer::run(&g);
        assert!(opt.node(uv).is_some());
        assert!(opt.node(sin).is_some());
        assert!(opt.node(cos).is_some());
        assert!(opt.node(add).is_some());
    }
}