Skip to main content

proof_engine/shader_graph/
optimizer.rs

1//! Advanced optimization passes for shader graphs: type inference, algebraic simplification,
2//! redundant cast removal, loop detection, node merging, instruction count estimation,
3//! and shader variant caching.
4
5use std::collections::{HashMap, HashSet};
6use super::nodes::{
7    Connection, DataType, NodeId, NodeType, ParamValue, ShaderGraph, ShaderNode,
8};
9
10// ---------------------------------------------------------------------------
11// Optimization pass enum
12// ---------------------------------------------------------------------------
13
/// One step of the optimizer pipeline.
///
/// Each variant selects a single analysis or rewrite that
/// `ShaderOptimizer::optimize` dispatches on while walking the configured pass list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizationPass {
    /// Work out the data type of every node output and record it in the report.
    TypeInference,
    /// Drop cast-like nodes whose input and output types are identical (e.g., float -> float).
    RedundantCastRemoval,
    /// Rewrite trivial arithmetic identities: x*1=x, x+0=x, x*0=0, and similar.
    AlgebraicSimplification,
    /// Look for cycles in the graph and flag them in the report.
    LoopDetection,
    /// Combine back-to-back compatible math operations.
    NodeMerging,
    /// Estimate the instruction cost of the graph and compare it with the budget.
    InstructionCounting,
    /// Remove every node that no output node depends on.
    DeadCodeElimination,
    /// Push known constant values forward through pure math nodes.
    ConstantPropagation,
}
34
35// ---------------------------------------------------------------------------
36// Optimizer config
37// ---------------------------------------------------------------------------
38
39/// Configuration for the shader optimizer.
40#[derive(Debug, Clone)]
41pub struct OptimizerConfig {
42    /// Which passes to run, in order.
43    pub passes: Vec<OptimizationPass>,
44    /// Maximum number of iterations for iterative passes.
45    pub max_iterations: usize,
46    /// If true, log optimization statistics.
47    pub verbose: bool,
48    /// Maximum allowed instruction count before warning.
49    pub instruction_budget: u32,
50}
51
52impl Default for OptimizerConfig {
53    fn default() -> Self {
54        Self {
55            passes: vec![
56                OptimizationPass::TypeInference,
57                OptimizationPass::DeadCodeElimination,
58                OptimizationPass::AlgebraicSimplification,
59                OptimizationPass::RedundantCastRemoval,
60                OptimizationPass::NodeMerging,
61                OptimizationPass::ConstantPropagation,
62                OptimizationPass::InstructionCounting,
63                OptimizationPass::LoopDetection,
64            ],
65            max_iterations: 10,
66            verbose: false,
67            instruction_budget: 512,
68        }
69    }
70}
71
72// ---------------------------------------------------------------------------
73// Optimization report
74// ---------------------------------------------------------------------------
75
76/// Report generated after optimization, containing statistics.
77#[derive(Debug, Clone)]
78pub struct OptimizationReport {
79    /// Number of nodes before optimization.
80    pub nodes_before: usize,
81    /// Number of nodes after optimization.
82    pub nodes_after: usize,
83    /// Number of connections before.
84    pub connections_before: usize,
85    /// Number of connections after.
86    pub connections_after: usize,
87    /// Number of nodes removed by dead code elimination.
88    pub dead_nodes_removed: usize,
89    /// Number of algebraic simplifications applied.
90    pub algebraic_simplifications: usize,
91    /// Number of redundant casts removed.
92    pub redundant_casts_removed: usize,
93    /// Number of nodes merged.
94    pub nodes_merged: usize,
95    /// Whether a cycle was detected.
96    pub cycle_detected: bool,
97    /// Estimated total instruction count.
98    pub estimated_instructions: u32,
99    /// Whether the instruction budget was exceeded.
100    pub over_budget: bool,
101    /// Inferred types for each node output: (node_id, socket_idx) -> DataType.
102    pub inferred_types: HashMap<(u64, usize), DataType>,
103    /// Warnings generated during optimization.
104    pub warnings: Vec<String>,
105}
106
107impl OptimizationReport {
108    fn new(graph: &ShaderGraph) -> Self {
109        Self {
110            nodes_before: graph.node_count(),
111            nodes_after: graph.node_count(),
112            connections_before: graph.connections().len(),
113            connections_after: graph.connections().len(),
114            dead_nodes_removed: 0,
115            algebraic_simplifications: 0,
116            redundant_casts_removed: 0,
117            nodes_merged: 0,
118            cycle_detected: false,
119            estimated_instructions: 0,
120            over_budget: false,
121            inferred_types: HashMap::new(),
122            warnings: Vec::new(),
123        }
124    }
125}
126
127// ---------------------------------------------------------------------------
128// Shader Optimizer
129// ---------------------------------------------------------------------------
130
131/// The main shader graph optimizer.
132pub struct ShaderOptimizer {
133    config: OptimizerConfig,
134}
135
136impl ShaderOptimizer {
137    pub fn new(config: OptimizerConfig) -> Self {
138        Self { config }
139    }
140
141    pub fn with_defaults() -> Self {
142        Self::new(OptimizerConfig::default())
143    }
144
145    /// Run all configured optimization passes on the graph.
146    /// Returns the optimized graph and a report.
147    pub fn optimize(&self, graph: &ShaderGraph) -> (ShaderGraph, OptimizationReport) {
148        let mut optimized = graph.clone();
149        let mut report = OptimizationReport::new(graph);
150
151        for pass in &self.config.passes {
152            match pass {
153                OptimizationPass::TypeInference => {
154                    self.run_type_inference(&optimized, &mut report);
155                }
156                OptimizationPass::RedundantCastRemoval => {
157                    let removed = self.run_redundant_cast_removal(&mut optimized, &report.inferred_types);
158                    report.redundant_casts_removed += removed;
159                }
160                OptimizationPass::AlgebraicSimplification => {
161                    let count = self.run_algebraic_simplification(&mut optimized);
162                    report.algebraic_simplifications += count;
163                }
164                OptimizationPass::LoopDetection => {
165                    report.cycle_detected = self.detect_loops(&optimized);
166                    if report.cycle_detected {
167                        report.warnings.push("Cycle detected in shader graph".to_string());
168                    }
169                }
170                OptimizationPass::NodeMerging => {
171                    let merged = self.run_node_merging(&mut optimized);
172                    report.nodes_merged += merged;
173                }
174                OptimizationPass::InstructionCounting => {
175                    report.estimated_instructions = self.estimate_instructions(&optimized);
176                    report.over_budget = report.estimated_instructions > self.config.instruction_budget;
177                    if report.over_budget {
178                        report.warnings.push(format!(
179                            "Instruction count {} exceeds budget {}",
180                            report.estimated_instructions, self.config.instruction_budget
181                        ));
182                    }
183                }
184                OptimizationPass::DeadCodeElimination => {
185                    let removed = self.run_dead_code_elimination(&mut optimized);
186                    report.dead_nodes_removed += removed;
187                }
188                OptimizationPass::ConstantPropagation => {
189                    self.run_constant_propagation(&mut optimized);
190                }
191            }
192        }
193
194        report.nodes_after = optimized.node_count();
195        report.connections_after = optimized.connections().len();
196
197        (optimized, report)
198    }
199
200    // -----------------------------------------------------------------------
201    // Type inference
202    // -----------------------------------------------------------------------
203
204    /// Infer output types for all nodes based on their input connections.
205    fn run_type_inference(&self, graph: &ShaderGraph, report: &mut OptimizationReport) {
206        for node in graph.nodes() {
207            for (idx, socket) in node.outputs.iter().enumerate() {
208                let inferred = self.infer_output_type(graph, node, idx);
209                report.inferred_types.insert((node.id.0, idx), inferred.unwrap_or(socket.data_type));
210            }
211        }
212    }
213
214    /// Infer the output type for a specific socket of a node, considering connected inputs.
215    fn infer_output_type(&self, graph: &ShaderGraph, node: &ShaderNode, output_idx: usize) -> Option<DataType> {
216        // For most nodes, the output type is fixed by the node definition
217        let base_type = node.outputs.get(output_idx)?.data_type;
218
219        // For math ops, the output type should match the "widest" input type
220        match &node.node_type {
221            NodeType::Add | NodeType::Sub | NodeType::Mul | NodeType::Div
222            | NodeType::Lerp | NodeType::Clamp | NodeType::Smoothstep => {
223                let incoming = graph.incoming_connections(node.id);
224                let mut widest = base_type;
225                for conn in &incoming {
226                    if let Some(src_node) = graph.node(conn.from_node) {
227                        if let Some(src_type) = src_node.output_type(conn.from_socket) {
228                            widest = wider_type(widest, src_type);
229                        }
230                    }
231                }
232                Some(widest)
233            }
234            _ => Some(base_type),
235        }
236    }
237
238    // -----------------------------------------------------------------------
239    // Redundant cast removal
240    // -----------------------------------------------------------------------
241
242    /// Remove nodes that perform identity casts (same type in and out).
243    fn run_redundant_cast_removal(
244        &self,
245        graph: &mut ShaderGraph,
246        inferred_types: &HashMap<(u64, usize), DataType>,
247    ) -> usize {
248        let mut to_remove: Vec<NodeId> = Vec::new();
249
250        // Find all connections where source and dest types match and the node
251        // is essentially a pass-through
252        let node_ids: Vec<NodeId> = graph.node_ids().collect();
253        for nid in &node_ids {
254            let node = match graph.node(*nid) {
255                Some(n) => n,
256                None => continue,
257            };
258
259            // Check if this is a single-input, single-output math node where
260            // the operation is identity-like
261            if node.inputs.len() != 1 || node.outputs.len() != 1 {
262                continue;
263            }
264
265            let incoming = graph.incoming_connections(*nid);
266            if incoming.len() != 1 {
267                continue;
268            }
269
270            let conn = incoming[0];
271            let src_type = inferred_types.get(&(conn.from_node.0, conn.from_socket))
272                .copied()
273                .unwrap_or(DataType::Float);
274            let dst_type = node.outputs[0].data_type;
275
276            // If types are the same and this is a Normalize of an already-normalized vector,
277            // or similar identity operation, we could remove it.
278            // For now, just check if types match exactly for pass-through detection
279            if src_type == dst_type {
280                // Check if the node type is effectively a no-op
281                let is_noop = match &node.node_type {
282                    NodeType::Abs => {
283                        // abs is no-op if input is known non-negative
284                        false // conservative
285                    }
286                    _ => false,
287                };
288                if is_noop {
289                    to_remove.push(*nid);
290                }
291            }
292        }
293
294        let count = to_remove.len();
295
296        for nid in to_remove {
297            self.bypass_node(graph, nid);
298        }
299
300        count
301    }
302
303    /// Remove a single-input, single-output node by connecting its input source
304    /// directly to all its output destinations.
305    fn bypass_node(&self, graph: &mut ShaderGraph, node_id: NodeId) {
306        // Find the single incoming connection
307        let incoming: Vec<Connection> = graph.incoming_connections(node_id)
308            .into_iter().cloned().collect();
309        let outgoing: Vec<Connection> = graph.outgoing_connections(node_id)
310            .into_iter().cloned().collect();
311
312        if incoming.len() != 1 {
313            return;
314        }
315
316        let source = &incoming[0];
317
318        // Redirect all outgoing connections to point to the source
319        for out_conn in &outgoing {
320            graph.disconnect(node_id, out_conn.from_socket, out_conn.to_node, out_conn.to_socket);
321            graph.connect(source.from_node, source.from_socket, out_conn.to_node, out_conn.to_socket);
322        }
323
324        // Remove the node
325        graph.remove_node(node_id);
326    }
327
328    // -----------------------------------------------------------------------
329    // Algebraic simplification
330    // -----------------------------------------------------------------------
331
332    /// Apply algebraic simplifications: x*1=x, x+0=x, x*0=0, x-0=x, x/1=x, x^1=x, x^0=1.
333    fn run_algebraic_simplification(&self, graph: &mut ShaderGraph) -> usize {
334        let mut simplifications = 0;
335
336        for _iteration in 0..self.config.max_iterations {
337            let mut changes_this_round = 0;
338
339            let node_ids: Vec<NodeId> = graph.node_ids().collect();
340            for &nid in &node_ids {
341                let node = match graph.node(&nid) {
342                    Some(n) => n.clone(),
343                    None => continue,
344                };
345
346                let result = self.try_simplify_node(graph, &node);
347                match result {
348                    SimplifyResult::NoChange => {}
349                    SimplifyResult::ReplaceWithInput(input_idx) => {
350                        // This node reduces to one of its inputs — bypass it
351                        let incoming: Vec<Connection> = graph.incoming_connections(nid)
352                            .into_iter().cloned().collect();
353                        let source_conn = incoming.iter().find(|c| c.to_socket == input_idx);
354                        if let Some(src) = source_conn {
355                            let outgoing: Vec<Connection> = graph.outgoing_connections(nid)
356                                .into_iter().cloned().collect();
357                            for out in &outgoing {
358                                graph.connect(src.from_node, src.from_socket, out.to_node, out.to_socket);
359                            }
360                            graph.remove_node(nid);
361                            changes_this_round += 1;
362                        }
363                    }
364                    SimplifyResult::ReplaceWithConstant(value) => {
365                        // Replace this node with a Color source holding the constant
366                        let outgoing: Vec<Connection> = graph.outgoing_connections(nid)
367                            .into_iter().cloned().collect();
368
369                        // Create a replacement Color node with the constant value
370                        let mut replacement = ShaderNode::new(NodeId(0), NodeType::Color);
371                        replacement.inputs[0].default_value = Some(match &value {
372                            ParamValue::Float(v) => ParamValue::Vec4([*v, *v, *v, 1.0]),
373                            ParamValue::Vec3(v) => ParamValue::Vec4([v[0], v[1], v[2], 1.0]),
374                            other => other.clone(),
375                        });
376                        replacement.properties.insert("folded_constant".to_string(), value);
377
378                        let new_id = graph.add_node_with(replacement);
379
380                        // Redirect outputs
381                        for out in &outgoing {
382                            graph.connect(new_id, 0, out.to_node, out.to_socket);
383                        }
384
385                        graph.remove_node(nid);
386                        changes_this_round += 1;
387                    }
388                }
389            }
390
391            simplifications += changes_this_round;
392            if changes_this_round == 0 {
393                break;
394            }
395        }
396
397        simplifications
398    }
399
400    /// Try to simplify a single node.
401    fn try_simplify_node(&self, graph: &ShaderGraph, node: &ShaderNode) -> SimplifyResult {
402        let incoming: Vec<&Connection> = graph.incoming_connections(node.id);
403
404        match &node.node_type {
405            // x + 0 = x
406            NodeType::Add => {
407                if let Some(result) = self.check_identity_binary(graph, node, &incoming, 0.0) {
408                    return result;
409                }
410            }
411            // x - 0 = x (only right operand)
412            NodeType::Sub => {
413                if self.is_input_constant(graph, node, &incoming, 1, 0.0) {
414                    return SimplifyResult::ReplaceWithInput(0);
415                }
416            }
417            // x * 1 = x; x * 0 = 0
418            NodeType::Mul => {
419                if let Some(result) = self.check_identity_binary(graph, node, &incoming, 1.0) {
420                    return result;
421                }
422                // x * 0 = 0
423                if self.is_input_constant(graph, node, &incoming, 0, 0.0) {
424                    return SimplifyResult::ReplaceWithConstant(ParamValue::Float(0.0));
425                }
426                if self.is_input_constant(graph, node, &incoming, 1, 0.0) {
427                    return SimplifyResult::ReplaceWithConstant(ParamValue::Float(0.0));
428                }
429            }
430            // x / 1 = x
431            NodeType::Div => {
432                if self.is_input_constant(graph, node, &incoming, 1, 1.0) {
433                    return SimplifyResult::ReplaceWithInput(0);
434                }
435            }
436            // pow(x, 1) = x; pow(x, 0) = 1
437            NodeType::Pow => {
438                if self.is_input_constant(graph, node, &incoming, 1, 1.0) {
439                    return SimplifyResult::ReplaceWithInput(0);
440                }
441                if self.is_input_constant(graph, node, &incoming, 1, 0.0) {
442                    return SimplifyResult::ReplaceWithConstant(ParamValue::Float(1.0));
443                }
444            }
445            // lerp(a, b, 0) = a; lerp(a, b, 1) = b
446            NodeType::Lerp => {
447                if self.is_input_constant(graph, node, &incoming, 2, 0.0) {
448                    return SimplifyResult::ReplaceWithInput(0);
449                }
450                if self.is_input_constant(graph, node, &incoming, 2, 1.0) {
451                    return SimplifyResult::ReplaceWithInput(1);
452                }
453            }
454            // clamp(x, -inf, inf) effectively = x (we check 0..1 identity)
455            NodeType::Clamp => {
456                // If min=0.0, max=1.0, and x is known to be in [0,1], this is identity
457                // For now, conservative — no simplification
458            }
459            // step(0, x) = 1 for all x >= 0
460            NodeType::Step => {
461                if self.is_input_constant(graph, node, &incoming, 0, 0.0) {
462                    // step(0, x) = 1 if x >= 0 — we can't prove x >= 0 in general
463                }
464            }
465            _ => {}
466        }
467
468        SimplifyResult::NoChange
469    }
470
471    /// Check if one of the two inputs to a binary op is a specific constant (identity element).
472    /// If so, the result equals the other input.
473    fn check_identity_binary(
474        &self,
475        graph: &ShaderGraph,
476        node: &ShaderNode,
477        incoming: &[&Connection],
478        identity: f32,
479    ) -> Option<SimplifyResult> {
480        if self.is_input_constant(graph, node, incoming, 0, identity) {
481            return Some(SimplifyResult::ReplaceWithInput(1));
482        }
483        if self.is_input_constant(graph, node, incoming, 1, identity) {
484            return Some(SimplifyResult::ReplaceWithInput(0));
485        }
486        None
487    }
488
489    /// Check if a specific input socket has a constant float value.
490    fn is_input_constant(
491        &self,
492        _graph: &ShaderGraph,
493        node: &ShaderNode,
494        incoming: &[&Connection],
495        socket_idx: usize,
496        expected: f32,
497    ) -> bool {
498        // First, check if there's a connection to this socket
499        let has_connection = incoming.iter().any(|c| c.to_socket == socket_idx);
500        if has_connection {
501            // We'd need to trace back to the source node to check if it's a constant
502            // For simplicity, we only check unconnected sockets with default values
503            return false;
504        }
505
506        // Check the default value
507        if let Some(default) = node.input_default(socket_idx) {
508            if let Some(val) = default.as_float() {
509                return (val - expected).abs() < 1e-7;
510            }
511        }
512
513        false
514    }
515
516    // -----------------------------------------------------------------------
517    // Loop/cycle detection
518    // -----------------------------------------------------------------------
519
520    /// Detect if the graph contains any cycles using DFS coloring.
521    fn detect_loops(&self, graph: &ShaderGraph) -> bool {
522        let mut color: HashMap<NodeId, u8> = HashMap::new(); // 0=white, 1=grey, 2=black
523        for nid in graph.node_ids() {
524            color.insert(nid, 0);
525        }
526
527        for nid in graph.node_ids() {
528            if color[&nid] == 0 {
529                if self.dfs_cycle(graph, nid, &mut color) {
530                    return true;
531                }
532            }
533        }
534
535        false
536    }
537
538    fn dfs_cycle(&self, graph: &ShaderGraph, node_id: NodeId, color: &mut HashMap<NodeId, u8>) -> bool {
539        color.insert(node_id, 1); // grey
540
541        for conn in graph.outgoing_connections(node_id) {
542            let neighbor = conn.to_node;
543            match color.get(&neighbor) {
544                Some(1) => return true,  // back edge => cycle
545                Some(0) => {
546                    if self.dfs_cycle(graph, neighbor, color) {
547                        return true;
548                    }
549                }
550                _ => {} // already visited (black)
551            }
552        }
553
554        color.insert(node_id, 2); // black
555        false
556    }
557
558    // -----------------------------------------------------------------------
559    // Node merging
560    // -----------------------------------------------------------------------
561
562    /// Merge chains of compatible sequential math operations.
563    /// E.g., Add(Add(a, b), c) can note that it's a 3-way add (though GLSL
564    /// doesn't have a single instruction, we can eliminate intermediate variables).
565    fn run_node_merging(&self, graph: &mut ShaderGraph) -> usize {
566        let mut merged = 0;
567
568        // Strategy: find chains of the same binary op where the intermediate result
569        // is used only once. E.g., if Add(a,b) feeds only into Add(_, c), we can
570        // eliminate the intermediate by rewriting as Add(a, Add_inline(b, c)).
571        // In practice, we mark the intermediate node as "inline" by removing it
572        // and adjusting the downstream node's GLSL.
573
574        let node_ids: Vec<NodeId> = graph.node_ids().collect();
575        let mut removed_set: HashSet<NodeId> = HashSet::new();
576
577        for &nid in &node_ids {
578            if removed_set.contains(&nid) {
579                continue;
580            }
581
582            let node = match graph.node(&nid) {
583                Some(n) => n,
584                None => continue,
585            };
586
587            // Only merge binary math ops
588            let is_mergeable = matches!(
589                node.node_type,
590                NodeType::Add | NodeType::Sub | NodeType::Mul
591            );
592            if !is_mergeable {
593                continue;
594            }
595
596            // Check if this node has exactly one outgoing connection
597            let outgoing = graph.outgoing_connections(nid);
598            if outgoing.len() != 1 {
599                continue;
600            }
601
602            let out_conn = outgoing[0].clone();
603            let downstream = match graph.node(&out_conn.to_node) {
604                Some(n) => n,
605                None => continue,
606            };
607
608            // Must be the same operation type
609            if downstream.node_type != node.node_type {
610                continue;
611            }
612
613            // Don't merge if the downstream node is already in the removed set
614            if removed_set.contains(&out_conn.to_node) {
615                continue;
616            }
617
618            // The current node's output feeds into one of the downstream's inputs.
619            // We'll propagate the current node's inputs to the downstream node's properties
620            // so that the GLSL generator can inline the expression.
621
622            // For now, mark the merge in properties and skip actual structural changes
623            // to avoid complex graph rewiring. The compiler will handle inlining.
624            if let Some(downstream_mut) = graph.node_mut(out_conn.to_node) {
625                downstream_mut.properties.insert(
626                    format!("merged_from_{}", nid.0),
627                    ParamValue::Bool(true),
628                );
629                merged += 1;
630            }
631        }
632
633        merged
634    }
635
636    // -----------------------------------------------------------------------
637    // Dead code elimination
638    // -----------------------------------------------------------------------
639
640    fn run_dead_code_elimination(&self, graph: &mut ShaderGraph) -> usize {
641        let outputs = graph.output_nodes();
642        if outputs.is_empty() {
643            return 0;
644        }
645
646        // BFS from outputs to find all reachable nodes
647        let mut reachable: HashSet<NodeId> = HashSet::new();
648        let mut queue: Vec<NodeId> = outputs;
649
650        while let Some(nid) = queue.pop() {
651            if !reachable.insert(nid) {
652                continue;
653            }
654            for conn in graph.connections() {
655                if conn.to_node == nid && !reachable.contains(&conn.from_node) {
656                    queue.push(conn.from_node);
657                }
658            }
659        }
660
661        // Remove unreachable nodes
662        let all_ids: Vec<NodeId> = graph.node_ids().collect();
663        let mut removed = 0;
664        for nid in all_ids {
665            if !reachable.contains(&nid) {
666                graph.remove_node(nid);
667                removed += 1;
668            }
669        }
670
671        removed
672    }
673
674    // -----------------------------------------------------------------------
675    // Constant propagation
676    // -----------------------------------------------------------------------
677
678    /// Propagate known constant values through chains of pure math nodes.
679    fn run_constant_propagation(&self, graph: &mut ShaderGraph) {
680        // Build a map of known constant outputs
681        let mut known_constants: HashMap<(NodeId, usize), ParamValue> = HashMap::new();
682
683        // First, find all Color nodes with explicit constant values
684        let node_ids: Vec<NodeId> = graph.node_ids().collect();
685        for &nid in &node_ids {
686            let node = match graph.node(&nid) {
687                Some(n) => n,
688                None => continue,
689            };
690
691            if node.node_type == NodeType::Color {
692                if let Some(val) = &node.inputs[0].default_value {
693                    // Check if this node has no incoming connections (truly constant)
694                    let incoming = graph.incoming_connections(nid);
695                    if incoming.is_empty() {
696                        known_constants.insert((nid, 0), val.clone());
697                    }
698                }
699            }
700        }
701
702        // Propagate through pure math nodes
703        // (In a full implementation, we would do a topological traversal here.
704        // For now, we store the constants for downstream use by the compiler.)
705        for &nid in &node_ids {
706            let node = match graph.node(&nid) {
707                Some(n) => n,
708                None => continue,
709            };
710
711            if !node.node_type.is_pure_math() {
712                continue;
713            }
714
715            let incoming = graph.incoming_connections(nid);
716            let mut all_inputs_known = true;
717            let mut input_vals: Vec<ParamValue> = Vec::new();
718
719            for (idx, socket) in node.inputs.iter().enumerate() {
720                let conn = incoming.iter().find(|c| c.to_socket == idx);
721                if let Some(c) = conn {
722                    if let Some(val) = known_constants.get(&(c.from_node, c.from_socket)) {
723                        input_vals.push(val.clone());
724                    } else {
725                        all_inputs_known = false;
726                        break;
727                    }
728                } else if let Some(def) = &socket.default_value {
729                    input_vals.push(def.clone());
730                } else {
731                    all_inputs_known = false;
732                    break;
733                }
734            }
735
736            if all_inputs_known && !input_vals.is_empty() {
737                // Try to evaluate
738                if let Some(result) = evaluate_pure_node(&node.node_type, &input_vals) {
739                    for (idx, val) in result.iter().enumerate() {
740                        known_constants.insert((nid, idx), val.clone());
741                    }
742                    // Store the folded value in the node's properties for the compiler
743                    if let Some(node_mut) = graph.node_mut(nid) {
744                        if let Some(first) = result.into_iter().next() {
745                            node_mut.properties.insert(
746                                "propagated_constant".to_string(),
747                                first,
748                            );
749                        }
750                    }
751                }
752            }
753        }
754    }
755
756    // -----------------------------------------------------------------------
757    // Instruction counting
758    // -----------------------------------------------------------------------
759
760    fn estimate_instructions(&self, graph: &ShaderGraph) -> u32 {
761        graph.estimated_cost()
762    }
763}
764
765// ---------------------------------------------------------------------------
766// Helper types and functions
767// ---------------------------------------------------------------------------
768
769enum SimplifyResult {
770    NoChange,
771    /// Replace node with one of its inputs (by socket index).
772    ReplaceWithInput(usize),
773    /// Replace node with a constant value.
774    ReplaceWithConstant(ParamValue),
775}
776
777/// Return the "wider" of two types for type promotion.
778fn wider_type(a: DataType, b: DataType) -> DataType {
779    let rank = |t: DataType| -> u8 {
780        match t {
781            DataType::Bool => 0,
782            DataType::Int => 1,
783            DataType::Float => 2,
784            DataType::Vec2 => 3,
785            DataType::Vec3 => 4,
786            DataType::Vec4 => 5,
787            DataType::Mat3 => 6,
788            DataType::Mat4 => 7,
789            DataType::Sampler2D => 8,
790        }
791    };
792    if rank(a) >= rank(b) { a } else { b }
793}
794
795/// Evaluate a pure math node with known inputs (used for constant propagation).
796fn evaluate_pure_node(node_type: &NodeType, inputs: &[ParamValue]) -> Option<Vec<ParamValue>> {
797    match node_type {
798        NodeType::Add => {
799            let a = inputs.first()?.as_float()?;
800            let b = inputs.get(1)?.as_float()?;
801            Some(vec![ParamValue::Float(a + b)])
802        }
803        NodeType::Sub => {
804            let a = inputs.first()?.as_float()?;
805            let b = inputs.get(1)?.as_float()?;
806            Some(vec![ParamValue::Float(a - b)])
807        }
808        NodeType::Mul => {
809            let a = inputs.first()?.as_float()?;
810            let b = inputs.get(1)?.as_float()?;
811            Some(vec![ParamValue::Float(a * b)])
812        }
813        NodeType::Div => {
814            let a = inputs.first()?.as_float()?;
815            let b = inputs.get(1)?.as_float()?;
816            if b.abs() < 1e-10 { return None; }
817            Some(vec![ParamValue::Float(a / b)])
818        }
819        NodeType::Abs => {
820            let x = inputs.first()?.as_float()?;
821            Some(vec![ParamValue::Float(x.abs())])
822        }
823        NodeType::Floor => {
824            let x = inputs.first()?.as_float()?;
825            Some(vec![ParamValue::Float(x.floor())])
826        }
827        NodeType::Ceil => {
828            let x = inputs.first()?.as_float()?;
829            Some(vec![ParamValue::Float(x.ceil())])
830        }
831        NodeType::Fract => {
832            let x = inputs.first()?.as_float()?;
833            Some(vec![ParamValue::Float(x.fract())])
834        }
835        NodeType::Sqrt => {
836            let x = inputs.first()?.as_float()?;
837            Some(vec![ParamValue::Float(x.max(0.0).sqrt())])
838        }
839        NodeType::Sin => {
840            let x = inputs.first()?.as_float()?;
841            Some(vec![ParamValue::Float(x.sin())])
842        }
843        NodeType::Cos => {
844            let x = inputs.first()?.as_float()?;
845            Some(vec![ParamValue::Float(x.cos())])
846        }
847        NodeType::Pow => {
848            let base = inputs.first()?.as_float()?;
849            let exp = inputs.get(1)?.as_float()?;
850            Some(vec![ParamValue::Float(base.max(0.0).powf(exp))])
851        }
852        NodeType::Lerp => {
853            let a = inputs.first()?.as_float()?;
854            let b = inputs.get(1)?.as_float()?;
855            let t = inputs.get(2)?.as_float()?;
856            Some(vec![ParamValue::Float(a + (b - a) * t)])
857        }
858        NodeType::Clamp => {
859            let x = inputs.first()?.as_float()?;
860            let lo = inputs.get(1)?.as_float()?;
861            let hi = inputs.get(2)?.as_float()?;
862            Some(vec![ParamValue::Float(x.clamp(lo, hi))])
863        }
864        NodeType::Step => {
865            let edge = inputs.first()?.as_float()?;
866            let x = inputs.get(1)?.as_float()?;
867            Some(vec![ParamValue::Float(if x >= edge { 1.0 } else { 0.0 })])
868        }
869        NodeType::Invert => {
870            let c = inputs.first()?.as_vec3()?;
871            Some(vec![ParamValue::Vec3([1.0 - c[0], 1.0 - c[1], 1.0 - c[2]])])
872        }
873        _ => None,
874    }
875}
876
877// ---------------------------------------------------------------------------
878// Convenience
879// ---------------------------------------------------------------------------
880
881/// Optimize a shader graph with default settings.
882pub fn optimize_graph(graph: &ShaderGraph) -> (ShaderGraph, OptimizationReport) {
883    ShaderOptimizer::with_defaults().optimize(graph)
884}
885
886/// Estimate the instruction count of a shader graph.
887pub fn estimate_instruction_count(graph: &ShaderGraph) -> u32 {
888    graph.estimated_cost()
889}
890
891/// Check if a graph has cycles.
892pub fn has_cycles(graph: &ShaderGraph) -> bool {
893    ShaderOptimizer::with_defaults().detect_loops(graph)
894}