1use std::collections::{HashMap, HashSet};
6use super::nodes::{
7 Connection, DataType, NodeId, NodeType, ParamValue, ShaderGraph, ShaderNode,
8};
9
/// One optimization pass that [`ShaderOptimizer`] can run.
///
/// Passes execute in the order they are listed in [`OptimizerConfig::passes`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizationPass {
    /// Records an inferred data type for every node output in the report.
    TypeInference,
    /// Bypasses single-input, single-output nodes classified as no-ops whose
    /// output type matches the inferred type of their source.
    RedundantCastRemoval,
    /// Applies identities such as `x + 0`, `x * 1`, `x * 0`, `x / 1`, `pow(x, 0)`.
    AlgebraicSimplification,
    /// Searches the graph for a directed cycle and warns if one is found.
    LoopDetection,
    /// Annotates Add/Sub/Mul nodes that feed exactly one node of the same kind.
    NodeMerging,
    /// Estimates the instruction count and compares it with the budget.
    InstructionCounting,
    /// Removes nodes from which no output node is reachable.
    DeadCodeElimination,
    /// Evaluates pure math nodes whose inputs are all constant.
    ConstantPropagation,
}
34
35#[derive(Debug, Clone)]
41pub struct OptimizerConfig {
42 pub passes: Vec<OptimizationPass>,
44 pub max_iterations: usize,
46 pub verbose: bool,
48 pub instruction_budget: u32,
50}
51
52impl Default for OptimizerConfig {
53 fn default() -> Self {
54 Self {
55 passes: vec![
56 OptimizationPass::TypeInference,
57 OptimizationPass::DeadCodeElimination,
58 OptimizationPass::AlgebraicSimplification,
59 OptimizationPass::RedundantCastRemoval,
60 OptimizationPass::NodeMerging,
61 OptimizationPass::ConstantPropagation,
62 OptimizationPass::InstructionCounting,
63 OptimizationPass::LoopDetection,
64 ],
65 max_iterations: 10,
66 verbose: false,
67 instruction_budget: 512,
68 }
69 }
70}
71
72#[derive(Debug, Clone)]
78pub struct OptimizationReport {
79 pub nodes_before: usize,
81 pub nodes_after: usize,
83 pub connections_before: usize,
85 pub connections_after: usize,
87 pub dead_nodes_removed: usize,
89 pub algebraic_simplifications: usize,
91 pub redundant_casts_removed: usize,
93 pub nodes_merged: usize,
95 pub cycle_detected: bool,
97 pub estimated_instructions: u32,
99 pub over_budget: bool,
101 pub inferred_types: HashMap<(u64, usize), DataType>,
103 pub warnings: Vec<String>,
105}
106
107impl OptimizationReport {
108 fn new(graph: &ShaderGraph) -> Self {
109 Self {
110 nodes_before: graph.node_count(),
111 nodes_after: graph.node_count(),
112 connections_before: graph.connections().len(),
113 connections_after: graph.connections().len(),
114 dead_nodes_removed: 0,
115 algebraic_simplifications: 0,
116 redundant_casts_removed: 0,
117 nodes_merged: 0,
118 cycle_detected: false,
119 estimated_instructions: 0,
120 over_budget: false,
121 inferred_types: HashMap::new(),
122 warnings: Vec::new(),
123 }
124 }
125}
126
127pub struct ShaderOptimizer {
133 config: OptimizerConfig,
134}
135
136impl ShaderOptimizer {
137 pub fn new(config: OptimizerConfig) -> Self {
138 Self { config }
139 }
140
141 pub fn with_defaults() -> Self {
142 Self::new(OptimizerConfig::default())
143 }
144
145 pub fn optimize(&self, graph: &ShaderGraph) -> (ShaderGraph, OptimizationReport) {
148 let mut optimized = graph.clone();
149 let mut report = OptimizationReport::new(graph);
150
151 for pass in &self.config.passes {
152 match pass {
153 OptimizationPass::TypeInference => {
154 self.run_type_inference(&optimized, &mut report);
155 }
156 OptimizationPass::RedundantCastRemoval => {
157 let removed = self.run_redundant_cast_removal(&mut optimized, &report.inferred_types);
158 report.redundant_casts_removed += removed;
159 }
160 OptimizationPass::AlgebraicSimplification => {
161 let count = self.run_algebraic_simplification(&mut optimized);
162 report.algebraic_simplifications += count;
163 }
164 OptimizationPass::LoopDetection => {
165 report.cycle_detected = self.detect_loops(&optimized);
166 if report.cycle_detected {
167 report.warnings.push("Cycle detected in shader graph".to_string());
168 }
169 }
170 OptimizationPass::NodeMerging => {
171 let merged = self.run_node_merging(&mut optimized);
172 report.nodes_merged += merged;
173 }
174 OptimizationPass::InstructionCounting => {
175 report.estimated_instructions = self.estimate_instructions(&optimized);
176 report.over_budget = report.estimated_instructions > self.config.instruction_budget;
177 if report.over_budget {
178 report.warnings.push(format!(
179 "Instruction count {} exceeds budget {}",
180 report.estimated_instructions, self.config.instruction_budget
181 ));
182 }
183 }
184 OptimizationPass::DeadCodeElimination => {
185 let removed = self.run_dead_code_elimination(&mut optimized);
186 report.dead_nodes_removed += removed;
187 }
188 OptimizationPass::ConstantPropagation => {
189 self.run_constant_propagation(&mut optimized);
190 }
191 }
192 }
193
194 report.nodes_after = optimized.node_count();
195 report.connections_after = optimized.connections().len();
196
197 (optimized, report)
198 }
199
200 fn run_type_inference(&self, graph: &ShaderGraph, report: &mut OptimizationReport) {
206 for node in graph.nodes() {
207 for (idx, socket) in node.outputs.iter().enumerate() {
208 let inferred = self.infer_output_type(graph, node, idx);
209 report.inferred_types.insert((node.id.0, idx), inferred.unwrap_or(socket.data_type));
210 }
211 }
212 }
213
214 fn infer_output_type(&self, graph: &ShaderGraph, node: &ShaderNode, output_idx: usize) -> Option<DataType> {
216 let base_type = node.outputs.get(output_idx)?.data_type;
218
219 match &node.node_type {
221 NodeType::Add | NodeType::Sub | NodeType::Mul | NodeType::Div
222 | NodeType::Lerp | NodeType::Clamp | NodeType::Smoothstep => {
223 let incoming = graph.incoming_connections(node.id);
224 let mut widest = base_type;
225 for conn in &incoming {
226 if let Some(src_node) = graph.node(conn.from_node) {
227 if let Some(src_type) = src_node.output_type(conn.from_socket) {
228 widest = wider_type(widest, src_type);
229 }
230 }
231 }
232 Some(widest)
233 }
234 _ => Some(base_type),
235 }
236 }
237
238 fn run_redundant_cast_removal(
244 &self,
245 graph: &mut ShaderGraph,
246 inferred_types: &HashMap<(u64, usize), DataType>,
247 ) -> usize {
248 let mut to_remove: Vec<NodeId> = Vec::new();
249
250 let node_ids: Vec<NodeId> = graph.node_ids().collect();
253 for nid in &node_ids {
254 let node = match graph.node(*nid) {
255 Some(n) => n,
256 None => continue,
257 };
258
259 if node.inputs.len() != 1 || node.outputs.len() != 1 {
262 continue;
263 }
264
265 let incoming = graph.incoming_connections(*nid);
266 if incoming.len() != 1 {
267 continue;
268 }
269
270 let conn = incoming[0];
271 let src_type = inferred_types.get(&(conn.from_node.0, conn.from_socket))
272 .copied()
273 .unwrap_or(DataType::Float);
274 let dst_type = node.outputs[0].data_type;
275
276 if src_type == dst_type {
280 let is_noop = match &node.node_type {
282 NodeType::Abs => {
283 false }
286 _ => false,
287 };
288 if is_noop {
289 to_remove.push(*nid);
290 }
291 }
292 }
293
294 let count = to_remove.len();
295
296 for nid in to_remove {
297 self.bypass_node(graph, nid);
298 }
299
300 count
301 }
302
303 fn bypass_node(&self, graph: &mut ShaderGraph, node_id: NodeId) {
306 let incoming: Vec<Connection> = graph.incoming_connections(node_id)
308 .into_iter().cloned().collect();
309 let outgoing: Vec<Connection> = graph.outgoing_connections(node_id)
310 .into_iter().cloned().collect();
311
312 if incoming.len() != 1 {
313 return;
314 }
315
316 let source = &incoming[0];
317
318 for out_conn in &outgoing {
320 graph.disconnect(node_id, out_conn.from_socket, out_conn.to_node, out_conn.to_socket);
321 graph.connect(source.from_node, source.from_socket, out_conn.to_node, out_conn.to_socket);
322 }
323
324 graph.remove_node(node_id);
326 }
327
328 fn run_algebraic_simplification(&self, graph: &mut ShaderGraph) -> usize {
334 let mut simplifications = 0;
335
336 for _iteration in 0..self.config.max_iterations {
337 let mut changes_this_round = 0;
338
339 let node_ids: Vec<NodeId> = graph.node_ids().collect();
340 for &nid in &node_ids {
341 let node = match graph.node(&nid) {
342 Some(n) => n.clone(),
343 None => continue,
344 };
345
346 let result = self.try_simplify_node(graph, &node);
347 match result {
348 SimplifyResult::NoChange => {}
349 SimplifyResult::ReplaceWithInput(input_idx) => {
350 let incoming: Vec<Connection> = graph.incoming_connections(nid)
352 .into_iter().cloned().collect();
353 let source_conn = incoming.iter().find(|c| c.to_socket == input_idx);
354 if let Some(src) = source_conn {
355 let outgoing: Vec<Connection> = graph.outgoing_connections(nid)
356 .into_iter().cloned().collect();
357 for out in &outgoing {
358 graph.connect(src.from_node, src.from_socket, out.to_node, out.to_socket);
359 }
360 graph.remove_node(nid);
361 changes_this_round += 1;
362 }
363 }
364 SimplifyResult::ReplaceWithConstant(value) => {
365 let outgoing: Vec<Connection> = graph.outgoing_connections(nid)
367 .into_iter().cloned().collect();
368
369 let mut replacement = ShaderNode::new(NodeId(0), NodeType::Color);
371 replacement.inputs[0].default_value = Some(match &value {
372 ParamValue::Float(v) => ParamValue::Vec4([*v, *v, *v, 1.0]),
373 ParamValue::Vec3(v) => ParamValue::Vec4([v[0], v[1], v[2], 1.0]),
374 other => other.clone(),
375 });
376 replacement.properties.insert("folded_constant".to_string(), value);
377
378 let new_id = graph.add_node_with(replacement);
379
380 for out in &outgoing {
382 graph.connect(new_id, 0, out.to_node, out.to_socket);
383 }
384
385 graph.remove_node(nid);
386 changes_this_round += 1;
387 }
388 }
389 }
390
391 simplifications += changes_this_round;
392 if changes_this_round == 0 {
393 break;
394 }
395 }
396
397 simplifications
398 }
399
400 fn try_simplify_node(&self, graph: &ShaderGraph, node: &ShaderNode) -> SimplifyResult {
402 let incoming: Vec<&Connection> = graph.incoming_connections(node.id);
403
404 match &node.node_type {
405 NodeType::Add => {
407 if let Some(result) = self.check_identity_binary(graph, node, &incoming, 0.0) {
408 return result;
409 }
410 }
411 NodeType::Sub => {
413 if self.is_input_constant(graph, node, &incoming, 1, 0.0) {
414 return SimplifyResult::ReplaceWithInput(0);
415 }
416 }
417 NodeType::Mul => {
419 if let Some(result) = self.check_identity_binary(graph, node, &incoming, 1.0) {
420 return result;
421 }
422 if self.is_input_constant(graph, node, &incoming, 0, 0.0) {
424 return SimplifyResult::ReplaceWithConstant(ParamValue::Float(0.0));
425 }
426 if self.is_input_constant(graph, node, &incoming, 1, 0.0) {
427 return SimplifyResult::ReplaceWithConstant(ParamValue::Float(0.0));
428 }
429 }
430 NodeType::Div => {
432 if self.is_input_constant(graph, node, &incoming, 1, 1.0) {
433 return SimplifyResult::ReplaceWithInput(0);
434 }
435 }
436 NodeType::Pow => {
438 if self.is_input_constant(graph, node, &incoming, 1, 1.0) {
439 return SimplifyResult::ReplaceWithInput(0);
440 }
441 if self.is_input_constant(graph, node, &incoming, 1, 0.0) {
442 return SimplifyResult::ReplaceWithConstant(ParamValue::Float(1.0));
443 }
444 }
445 NodeType::Lerp => {
447 if self.is_input_constant(graph, node, &incoming, 2, 0.0) {
448 return SimplifyResult::ReplaceWithInput(0);
449 }
450 if self.is_input_constant(graph, node, &incoming, 2, 1.0) {
451 return SimplifyResult::ReplaceWithInput(1);
452 }
453 }
454 NodeType::Clamp => {
456 }
459 NodeType::Step => {
461 if self.is_input_constant(graph, node, &incoming, 0, 0.0) {
462 }
464 }
465 _ => {}
466 }
467
468 SimplifyResult::NoChange
469 }
470
471 fn check_identity_binary(
474 &self,
475 graph: &ShaderGraph,
476 node: &ShaderNode,
477 incoming: &[&Connection],
478 identity: f32,
479 ) -> Option<SimplifyResult> {
480 if self.is_input_constant(graph, node, incoming, 0, identity) {
481 return Some(SimplifyResult::ReplaceWithInput(1));
482 }
483 if self.is_input_constant(graph, node, incoming, 1, identity) {
484 return Some(SimplifyResult::ReplaceWithInput(0));
485 }
486 None
487 }
488
489 fn is_input_constant(
491 &self,
492 _graph: &ShaderGraph,
493 node: &ShaderNode,
494 incoming: &[&Connection],
495 socket_idx: usize,
496 expected: f32,
497 ) -> bool {
498 let has_connection = incoming.iter().any(|c| c.to_socket == socket_idx);
500 if has_connection {
501 return false;
504 }
505
506 if let Some(default) = node.input_default(socket_idx) {
508 if let Some(val) = default.as_float() {
509 return (val - expected).abs() < 1e-7;
510 }
511 }
512
513 false
514 }
515
516 fn detect_loops(&self, graph: &ShaderGraph) -> bool {
522 let mut color: HashMap<NodeId, u8> = HashMap::new(); for nid in graph.node_ids() {
524 color.insert(nid, 0);
525 }
526
527 for nid in graph.node_ids() {
528 if color[&nid] == 0 {
529 if self.dfs_cycle(graph, nid, &mut color) {
530 return true;
531 }
532 }
533 }
534
535 false
536 }
537
538 fn dfs_cycle(&self, graph: &ShaderGraph, node_id: NodeId, color: &mut HashMap<NodeId, u8>) -> bool {
539 color.insert(node_id, 1); for conn in graph.outgoing_connections(node_id) {
542 let neighbor = conn.to_node;
543 match color.get(&neighbor) {
544 Some(1) => return true, Some(0) => {
546 if self.dfs_cycle(graph, neighbor, color) {
547 return true;
548 }
549 }
550 _ => {} }
552 }
553
554 color.insert(node_id, 2); false
556 }
557
558 fn run_node_merging(&self, graph: &mut ShaderGraph) -> usize {
566 let mut merged = 0;
567
568 let node_ids: Vec<NodeId> = graph.node_ids().collect();
575 let mut removed_set: HashSet<NodeId> = HashSet::new();
576
577 for &nid in &node_ids {
578 if removed_set.contains(&nid) {
579 continue;
580 }
581
582 let node = match graph.node(&nid) {
583 Some(n) => n,
584 None => continue,
585 };
586
587 let is_mergeable = matches!(
589 node.node_type,
590 NodeType::Add | NodeType::Sub | NodeType::Mul
591 );
592 if !is_mergeable {
593 continue;
594 }
595
596 let outgoing = graph.outgoing_connections(nid);
598 if outgoing.len() != 1 {
599 continue;
600 }
601
602 let out_conn = outgoing[0].clone();
603 let downstream = match graph.node(&out_conn.to_node) {
604 Some(n) => n,
605 None => continue,
606 };
607
608 if downstream.node_type != node.node_type {
610 continue;
611 }
612
613 if removed_set.contains(&out_conn.to_node) {
615 continue;
616 }
617
618 if let Some(downstream_mut) = graph.node_mut(out_conn.to_node) {
625 downstream_mut.properties.insert(
626 format!("merged_from_{}", nid.0),
627 ParamValue::Bool(true),
628 );
629 merged += 1;
630 }
631 }
632
633 merged
634 }
635
636 fn run_dead_code_elimination(&self, graph: &mut ShaderGraph) -> usize {
641 let outputs = graph.output_nodes();
642 if outputs.is_empty() {
643 return 0;
644 }
645
646 let mut reachable: HashSet<NodeId> = HashSet::new();
648 let mut queue: Vec<NodeId> = outputs;
649
650 while let Some(nid) = queue.pop() {
651 if !reachable.insert(nid) {
652 continue;
653 }
654 for conn in graph.connections() {
655 if conn.to_node == nid && !reachable.contains(&conn.from_node) {
656 queue.push(conn.from_node);
657 }
658 }
659 }
660
661 let all_ids: Vec<NodeId> = graph.node_ids().collect();
663 let mut removed = 0;
664 for nid in all_ids {
665 if !reachable.contains(&nid) {
666 graph.remove_node(nid);
667 removed += 1;
668 }
669 }
670
671 removed
672 }
673
674 fn run_constant_propagation(&self, graph: &mut ShaderGraph) {
680 let mut known_constants: HashMap<(NodeId, usize), ParamValue> = HashMap::new();
682
683 let node_ids: Vec<NodeId> = graph.node_ids().collect();
685 for &nid in &node_ids {
686 let node = match graph.node(&nid) {
687 Some(n) => n,
688 None => continue,
689 };
690
691 if node.node_type == NodeType::Color {
692 if let Some(val) = &node.inputs[0].default_value {
693 let incoming = graph.incoming_connections(nid);
695 if incoming.is_empty() {
696 known_constants.insert((nid, 0), val.clone());
697 }
698 }
699 }
700 }
701
702 for &nid in &node_ids {
706 let node = match graph.node(&nid) {
707 Some(n) => n,
708 None => continue,
709 };
710
711 if !node.node_type.is_pure_math() {
712 continue;
713 }
714
715 let incoming = graph.incoming_connections(nid);
716 let mut all_inputs_known = true;
717 let mut input_vals: Vec<ParamValue> = Vec::new();
718
719 for (idx, socket) in node.inputs.iter().enumerate() {
720 let conn = incoming.iter().find(|c| c.to_socket == idx);
721 if let Some(c) = conn {
722 if let Some(val) = known_constants.get(&(c.from_node, c.from_socket)) {
723 input_vals.push(val.clone());
724 } else {
725 all_inputs_known = false;
726 break;
727 }
728 } else if let Some(def) = &socket.default_value {
729 input_vals.push(def.clone());
730 } else {
731 all_inputs_known = false;
732 break;
733 }
734 }
735
736 if all_inputs_known && !input_vals.is_empty() {
737 if let Some(result) = evaluate_pure_node(&node.node_type, &input_vals) {
739 for (idx, val) in result.iter().enumerate() {
740 known_constants.insert((nid, idx), val.clone());
741 }
742 if let Some(node_mut) = graph.node_mut(nid) {
744 if let Some(first) = result.into_iter().next() {
745 node_mut.properties.insert(
746 "propagated_constant".to_string(),
747 first,
748 );
749 }
750 }
751 }
752 }
753 }
754 }
755
756 fn estimate_instructions(&self, graph: &ShaderGraph) -> u32 {
761 graph.estimated_cost()
762 }
763}
764
765enum SimplifyResult {
770 NoChange,
771 ReplaceWithInput(usize),
773 ReplaceWithConstant(ParamValue),
775}
776
777fn wider_type(a: DataType, b: DataType) -> DataType {
779 let rank = |t: DataType| -> u8 {
780 match t {
781 DataType::Bool => 0,
782 DataType::Int => 1,
783 DataType::Float => 2,
784 DataType::Vec2 => 3,
785 DataType::Vec3 => 4,
786 DataType::Vec4 => 5,
787 DataType::Mat3 => 6,
788 DataType::Mat4 => 7,
789 DataType::Sampler2D => 8,
790 }
791 };
792 if rank(a) >= rank(b) { a } else { b }
793}
794
795fn evaluate_pure_node(node_type: &NodeType, inputs: &[ParamValue]) -> Option<Vec<ParamValue>> {
797 match node_type {
798 NodeType::Add => {
799 let a = inputs.first()?.as_float()?;
800 let b = inputs.get(1)?.as_float()?;
801 Some(vec![ParamValue::Float(a + b)])
802 }
803 NodeType::Sub => {
804 let a = inputs.first()?.as_float()?;
805 let b = inputs.get(1)?.as_float()?;
806 Some(vec![ParamValue::Float(a - b)])
807 }
808 NodeType::Mul => {
809 let a = inputs.first()?.as_float()?;
810 let b = inputs.get(1)?.as_float()?;
811 Some(vec![ParamValue::Float(a * b)])
812 }
813 NodeType::Div => {
814 let a = inputs.first()?.as_float()?;
815 let b = inputs.get(1)?.as_float()?;
816 if b.abs() < 1e-10 { return None; }
817 Some(vec![ParamValue::Float(a / b)])
818 }
819 NodeType::Abs => {
820 let x = inputs.first()?.as_float()?;
821 Some(vec![ParamValue::Float(x.abs())])
822 }
823 NodeType::Floor => {
824 let x = inputs.first()?.as_float()?;
825 Some(vec![ParamValue::Float(x.floor())])
826 }
827 NodeType::Ceil => {
828 let x = inputs.first()?.as_float()?;
829 Some(vec![ParamValue::Float(x.ceil())])
830 }
831 NodeType::Fract => {
832 let x = inputs.first()?.as_float()?;
833 Some(vec![ParamValue::Float(x.fract())])
834 }
835 NodeType::Sqrt => {
836 let x = inputs.first()?.as_float()?;
837 Some(vec![ParamValue::Float(x.max(0.0).sqrt())])
838 }
839 NodeType::Sin => {
840 let x = inputs.first()?.as_float()?;
841 Some(vec![ParamValue::Float(x.sin())])
842 }
843 NodeType::Cos => {
844 let x = inputs.first()?.as_float()?;
845 Some(vec![ParamValue::Float(x.cos())])
846 }
847 NodeType::Pow => {
848 let base = inputs.first()?.as_float()?;
849 let exp = inputs.get(1)?.as_float()?;
850 Some(vec![ParamValue::Float(base.max(0.0).powf(exp))])
851 }
852 NodeType::Lerp => {
853 let a = inputs.first()?.as_float()?;
854 let b = inputs.get(1)?.as_float()?;
855 let t = inputs.get(2)?.as_float()?;
856 Some(vec![ParamValue::Float(a + (b - a) * t)])
857 }
858 NodeType::Clamp => {
859 let x = inputs.first()?.as_float()?;
860 let lo = inputs.get(1)?.as_float()?;
861 let hi = inputs.get(2)?.as_float()?;
862 Some(vec![ParamValue::Float(x.clamp(lo, hi))])
863 }
864 NodeType::Step => {
865 let edge = inputs.first()?.as_float()?;
866 let x = inputs.get(1)?.as_float()?;
867 Some(vec![ParamValue::Float(if x >= edge { 1.0 } else { 0.0 })])
868 }
869 NodeType::Invert => {
870 let c = inputs.first()?.as_vec3()?;
871 Some(vec![ParamValue::Vec3([1.0 - c[0], 1.0 - c[1], 1.0 - c[2]])])
872 }
873 _ => None,
874 }
875}
876
877pub fn optimize_graph(graph: &ShaderGraph) -> (ShaderGraph, OptimizationReport) {
883 ShaderOptimizer::with_defaults().optimize(graph)
884}
885
886pub fn estimate_instruction_count(graph: &ShaderGraph) -> u32 {
888 graph.estimated_cost()
889}
890
891pub fn has_cycles(graph: &ShaderGraph) -> bool {
893 ShaderOptimizer::with_defaults().detect_loops(graph)
894}