Skip to main content

proof_engine/shader_graph/
compiler.rs

1//! Shader graph compiler: topological sort, dead-node elimination, constant folding,
2//! common subexpression elimination, and GLSL code generation.
3
4use std::collections::{HashMap, HashSet, VecDeque};
5use super::nodes::{
6    Connection, DataType, GlslSnippet, NodeId, NodeType, ParamValue, ShaderGraph, ShaderNode,
7};
8
9// ---------------------------------------------------------------------------
10// Compilation options
11// ---------------------------------------------------------------------------
12
/// Options controlling the compilation process.
///
/// Each optimization pass can be switched on or off independently; the
/// remaining fields influence the text of the generated GLSL. Two option sets
/// compare equal when every field matches, which makes it possible to detect
/// whether a previously compiled shader was built with the same settings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompileOptions {
    /// If true, run dead-node elimination (remove nodes unreachable from outputs).
    /// When false, every node of the graph is treated as live.
    pub dead_node_elimination: bool,
    /// If true, evaluate constant subtrees at compile time.
    pub constant_folding: bool,
    /// If true, merge common subexpressions.
    pub common_subexpression_elimination: bool,
    /// If true, include comments in generated GLSL for debugging.
    pub debug_comments: bool,
    /// GLSL version string (e.g., "330 core", "300 es"); emitted verbatim after `#version`.
    pub glsl_version: String,
    /// If true, generate conditional branches for nodes with conditions.
    pub enable_conditionals: bool,
    /// If true, generate animated uniform declarations for time-dependent parameters.
    pub animated_uniforms: bool,
}
31
32impl Default for CompileOptions {
33    fn default() -> Self {
34        Self {
35            dead_node_elimination: true,
36            constant_folding: true,
37            common_subexpression_elimination: true,
38            debug_comments: false,
39            glsl_version: "330 core".to_string(),
40            enable_conditionals: true,
41            animated_uniforms: true,
42        }
43    }
44}
45
46// ---------------------------------------------------------------------------
47// Compile errors
48// ---------------------------------------------------------------------------
49
/// Errors that can occur during shader graph compilation.
///
/// A human-readable, single-line rendering of each variant is available
/// through the `Display` implementation.
#[derive(Debug, Clone)]
pub enum CompileError {
    /// The graph contains a cycle, making topological sort impossible.
    /// Carries the live nodes that could not be placed in the sorted order.
    CycleDetected(Vec<NodeId>),
    /// A required input socket has no connection and no default value.
    MissingInput { node_id: NodeId, socket_index: usize, socket_name: String },
    /// The graph has no output nodes.
    NoOutputNodes,
    /// A type mismatch between connected sockets.
    /// The `from_*` fields describe the producing side of the connection,
    /// the `to_*` fields the consuming side.
    TypeMismatch {
        from_node: NodeId,
        from_socket: usize,
        from_type: DataType,
        to_node: NodeId,
        to_socket: usize,
        to_type: DataType,
    },
    /// Graph validation failed; holds the messages returned by the graph's
    /// own `validate()` call.
    ValidationErrors(Vec<String>),
    /// Internal compiler error.
    Internal(String),
}
73
74impl std::fmt::Display for CompileError {
75    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76        match self {
77            CompileError::CycleDetected(ids) => {
78                write!(f, "Cycle detected involving nodes: {:?}",
79                    ids.iter().map(|id| id.0).collect::<Vec<_>>())
80            }
81            CompileError::MissingInput { node_id, socket_index, socket_name } => {
82                write!(f, "Node {} missing input at socket {} ('{}')",
83                    node_id.0, socket_index, socket_name)
84            }
85            CompileError::NoOutputNodes => write!(f, "Graph has no output nodes"),
86            CompileError::TypeMismatch { from_node, from_socket, from_type, to_node, to_socket, to_type } => {
87                write!(f, "Type mismatch: node {}:{} ({}) -> node {}:{} ({})",
88                    from_node.0, from_socket, from_type,
89                    to_node.0, to_socket, to_type)
90            }
91            CompileError::ValidationErrors(errs) => {
92                write!(f, "Validation errors: {}", errs.join("; "))
93            }
94            CompileError::Internal(msg) => write!(f, "Internal error: {}", msg),
95        }
96    }
97}
98
99// ---------------------------------------------------------------------------
100// Compiled shader output
101// ---------------------------------------------------------------------------
102
/// The result of compiling a shader graph.
///
/// Bundles the generated shader sources with the declarations and statistics
/// gathered while compiling.
#[derive(Debug, Clone)]
pub struct CompiledShader {
    /// The generated GLSL fragment shader source.
    pub fragment_source: String,
    /// The generated GLSL vertex shader source (boilerplate).
    pub vertex_source: String,
    /// All uniform declarations needed.
    pub uniforms: Vec<UniformDecl>,
    /// All varying declarations needed.
    pub varyings: Vec<VaryingDecl>,
    /// Estimated total instruction count: the sum of each sorted node's
    /// `estimated_cost()`.
    pub instruction_count: u32,
    /// Number of texture samplers used (uniforms whose type is `Sampler2D`).
    pub sampler_count: u32,
    /// Number of nodes after dead-node elimination.
    pub live_node_count: usize,
    /// Topology hash for caching.
    pub topology_hash: u64,
    /// Nodes in topological order after all optimizations.
    pub node_order: Vec<NodeId>,
    /// Map from node output (node_id, socket_index) to GLSL variable name.
    /// The first key component is the raw inner value of the `NodeId`.
    pub output_var_map: HashMap<(u64, usize), String>,
}
127
/// A uniform variable declaration.
///
/// Emitted into the fragment shader as `uniform <data_type> <name>;`.
#[derive(Debug, Clone)]
pub struct UniformDecl {
    /// GLSL identifier of the uniform (e.g. `u_time`, `u_texture0`).
    pub name: String,
    /// GLSL type of the uniform.
    pub data_type: DataType,
    /// Initial value, when one is known; `None` for the standard uniforms and samplers.
    pub default_value: Option<ParamValue>,
    /// Whether the uniform is flagged as animated (for example `u_time`).
    pub is_animated: bool,
}
136
/// A varying variable declaration.
///
/// Emitted into the fragment shader as `in <data_type> <name>;`.
#[derive(Debug, Clone)]
pub struct VaryingDecl {
    /// GLSL identifier of the varying (e.g. `v_uv`).
    pub name: String,
    /// GLSL type of the varying.
    pub data_type: DataType,
}
143
144// ---------------------------------------------------------------------------
145// Shader Compiler
146// ---------------------------------------------------------------------------
147
/// The main shader graph compiler.
///
/// Holds only the compile options; the graph to compile is passed to
/// `compile` on each call.
pub struct ShaderCompiler {
    // Settings consulted by every pass of `compile`.
    options: CompileOptions,
}
152
153impl ShaderCompiler {
154    pub fn new(options: CompileOptions) -> Self {
155        Self { options }
156    }
157
158    pub fn with_defaults() -> Self {
159        Self::new(CompileOptions::default())
160    }
161
162    /// Compile a shader graph into GLSL source code.
163    pub fn compile(&self, graph: &ShaderGraph) -> Result<CompiledShader, CompileError> {
164        // Step 0: Validate
165        let errors = graph.validate();
166        if !errors.is_empty() {
167            return Err(CompileError::ValidationErrors(errors));
168        }
169
170        // Step 1: Find output nodes
171        let output_nodes = graph.output_nodes();
172        if output_nodes.is_empty() {
173            return Err(CompileError::NoOutputNodes);
174        }
175
176        // Step 2: Dead node elimination — find all nodes reachable from outputs
177        let live_nodes = if self.options.dead_node_elimination {
178            self.find_live_nodes(graph, &output_nodes)
179        } else {
180            graph.node_ids().collect()
181        };
182
183        // Step 3: Topological sort of live nodes
184        let sorted = self.topological_sort(graph, &live_nodes)?;
185
186        // Step 4: Constant folding
187        let folded_values = if self.options.constant_folding {
188            self.constant_fold(graph, &sorted)
189        } else {
190            HashMap::new()
191        };
192
193        // Step 5: Common subexpression elimination
194        let cse_map = if self.options.common_subexpression_elimination {
195            self.find_common_subexpressions(graph, &sorted)
196        } else {
197            HashMap::new()
198        };
199
200        // Step 6: Collect uniforms and varyings
201        let (uniforms, varyings) = self.collect_declarations(graph, &sorted);
202
203        // Step 7: Generate GLSL
204        let (fragment_source, output_var_map) = self.generate_glsl(
205            graph, &sorted, &folded_values, &cse_map, &uniforms, &varyings,
206        );
207
208        // Step 8: Generate vertex shader
209        let vertex_source = self.generate_vertex_shader(&varyings);
210
211        // Step 9: Compute stats
212        let instruction_count: u32 = sorted.iter()
213            .filter_map(|id| graph.node(*id).map(|n| n.estimated_cost()))
214            .sum();
215        let sampler_count = uniforms.iter()
216            .filter(|u| u.data_type == DataType::Sampler2D)
217            .count() as u32;
218
219        Ok(CompiledShader {
220            fragment_source,
221            vertex_source,
222            uniforms,
223            varyings,
224            instruction_count,
225            sampler_count,
226            live_node_count: sorted.len(),
227            topology_hash: graph.topology_hash(),
228            node_order: sorted,
229            output_var_map,
230        })
231    }
232
233    // -----------------------------------------------------------------------
234    // Dead node elimination
235    // -----------------------------------------------------------------------
236
237    /// Walk backwards from output nodes, collecting all reachable node IDs.
238    fn find_live_nodes(&self, graph: &ShaderGraph, outputs: &[NodeId]) -> HashSet<NodeId> {
239        let mut live = HashSet::new();
240        let mut queue: VecDeque<NodeId> = outputs.iter().copied().collect();
241
242        while let Some(node_id) = queue.pop_front() {
243            if !live.insert(node_id) {
244                continue; // already visited
245            }
246            // Walk incoming connections
247            for conn in graph.connections() {
248                if conn.to_node == node_id && !live.contains(&conn.from_node) {
249                    queue.push_back(conn.from_node);
250                }
251            }
252        }
253
254        live
255    }
256
257    // -----------------------------------------------------------------------
258    // Topological sort
259    // -----------------------------------------------------------------------
260
261    /// Kahn's algorithm for topological sorting. Returns sorted node IDs or a cycle error.
262    fn topological_sort(
263        &self,
264        graph: &ShaderGraph,
265        live_nodes: &HashSet<NodeId>,
266    ) -> Result<Vec<NodeId>, CompileError> {
267        // Build adjacency and in-degree maps considering only live nodes
268        let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
269        let mut adjacency: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
270
271        for &nid in live_nodes {
272            in_degree.entry(nid).or_insert(0);
273            adjacency.entry(nid).or_insert_with(Vec::new);
274        }
275
276        for conn in graph.connections() {
277            if live_nodes.contains(&conn.from_node) && live_nodes.contains(&conn.to_node) {
278                adjacency.entry(conn.from_node).or_insert_with(Vec::new).push(conn.to_node);
279                *in_degree.entry(conn.to_node).or_insert(0) += 1;
280            }
281        }
282
283        // Start with all nodes that have zero in-degree
284        let mut queue: VecDeque<NodeId> = in_degree.iter()
285            .filter(|(_, &deg)| deg == 0)
286            .map(|(&id, _)| id)
287            .collect();
288
289        // Sort the queue for deterministic output
290        let mut queue_vec: Vec<NodeId> = queue.drain(..).collect();
291        queue_vec.sort_by_key(|id| id.0);
292        queue = queue_vec.into_iter().collect();
293
294        let mut sorted = Vec::new();
295
296        while let Some(node_id) = queue.pop_front() {
297            sorted.push(node_id);
298            if let Some(neighbors) = adjacency.get(&node_id) {
299                let mut next_neighbors: Vec<NodeId> = Vec::new();
300                for &neighbor in neighbors {
301                    if let Some(deg) = in_degree.get_mut(&neighbor) {
302                        *deg -= 1;
303                        if *deg == 0 {
304                            next_neighbors.push(neighbor);
305                        }
306                    }
307                }
308                next_neighbors.sort_by_key(|id| id.0);
309                for n in next_neighbors {
310                    queue.push_back(n);
311                }
312            }
313        }
314
315        if sorted.len() != live_nodes.len() {
316            // Cycle detected — find participating nodes
317            let sorted_set: HashSet<NodeId> = sorted.iter().copied().collect();
318            let cycle_nodes: Vec<NodeId> = live_nodes.iter()
319                .filter(|id| !sorted_set.contains(id))
320                .copied()
321                .collect();
322            return Err(CompileError::CycleDetected(cycle_nodes));
323        }
324
325        Ok(sorted)
326    }
327
328    // -----------------------------------------------------------------------
329    // Constant folding
330    // -----------------------------------------------------------------------
331
332    /// Identify nodes whose inputs are all constant (literal defaults or other folded nodes)
333    /// and evaluate them at compile time.
334    fn constant_fold(
335        &self,
336        graph: &ShaderGraph,
337        sorted: &[NodeId],
338    ) -> HashMap<NodeId, Vec<ParamValue>> {
339        let mut folded: HashMap<NodeId, Vec<ParamValue>> = HashMap::new();
340
341        for &node_id in sorted {
342            let node = match graph.node(node_id) {
343                Some(n) => n,
344                None => continue,
345            };
346
347            if !node.node_type.is_pure_math() {
348                continue;
349            }
350
351            // Check if all inputs are constants
352            let incoming = graph.incoming_connections(node_id);
353            let mut input_values: Vec<Option<ParamValue>> = Vec::new();
354            let mut all_constant = true;
355
356            for (idx, socket) in node.inputs.iter().enumerate() {
357                // Find connection to this socket
358                let conn = incoming.iter().find(|c| c.to_socket == idx);
359                if let Some(c) = conn {
360                    // Check if source is folded
361                    if let Some(folded_vals) = folded.get(&c.from_node) {
362                        if c.from_socket < folded_vals.len() {
363                            input_values.push(Some(folded_vals[c.from_socket].clone()));
364                            continue;
365                        }
366                    }
367                    all_constant = false;
368                    break;
369                } else if let Some(def) = &socket.default_value {
370                    input_values.push(Some(def.clone()));
371                } else {
372                    all_constant = false;
373                    break;
374                }
375            }
376
377            if !all_constant {
378                continue;
379            }
380
381            // Try to evaluate
382            let values: Vec<ParamValue> = input_values.into_iter().filter_map(|v| v).collect();
383            if let Some(result) = self.evaluate_constant(&node.node_type, &values) {
384                folded.insert(node_id, result);
385            }
386        }
387
388        folded
389    }
390
391    /// Evaluate a pure-math node with constant inputs.
392    fn evaluate_constant(&self, node_type: &NodeType, inputs: &[ParamValue]) -> Option<Vec<ParamValue>> {
393        match node_type {
394            NodeType::Add => {
395                let a = inputs.first()?.as_float()?;
396                let b = inputs.get(1)?.as_float()?;
397                Some(vec![ParamValue::Float(a + b)])
398            }
399            NodeType::Sub => {
400                let a = inputs.first()?.as_float()?;
401                let b = inputs.get(1)?.as_float()?;
402                Some(vec![ParamValue::Float(a - b)])
403            }
404            NodeType::Mul => {
405                let a = inputs.first()?.as_float()?;
406                let b = inputs.get(1)?.as_float()?;
407                Some(vec![ParamValue::Float(a * b)])
408            }
409            NodeType::Div => {
410                let a = inputs.first()?.as_float()?;
411                let b = inputs.get(1)?.as_float()?;
412                if b.abs() < 1e-10 { return None; }
413                Some(vec![ParamValue::Float(a / b)])
414            }
415            NodeType::Abs => {
416                let x = inputs.first()?.as_float()?;
417                Some(vec![ParamValue::Float(x.abs())])
418            }
419            NodeType::Floor => {
420                let x = inputs.first()?.as_float()?;
421                Some(vec![ParamValue::Float(x.floor())])
422            }
423            NodeType::Ceil => {
424                let x = inputs.first()?.as_float()?;
425                Some(vec![ParamValue::Float(x.ceil())])
426            }
427            NodeType::Fract => {
428                let x = inputs.first()?.as_float()?;
429                Some(vec![ParamValue::Float(x.fract())])
430            }
431            NodeType::Mod => {
432                let x = inputs.first()?.as_float()?;
433                let y = inputs.get(1)?.as_float()?;
434                if y.abs() < 1e-10 { return None; }
435                Some(vec![ParamValue::Float(x % y)])
436            }
437            NodeType::Pow => {
438                let base = inputs.first()?.as_float()?;
439                let exp = inputs.get(1)?.as_float()?;
440                Some(vec![ParamValue::Float(base.max(0.0).powf(exp))])
441            }
442            NodeType::Sqrt => {
443                let x = inputs.first()?.as_float()?;
444                Some(vec![ParamValue::Float(x.max(0.0).sqrt())])
445            }
446            NodeType::Sin => {
447                let x = inputs.first()?.as_float()?;
448                Some(vec![ParamValue::Float(x.sin())])
449            }
450            NodeType::Cos => {
451                let x = inputs.first()?.as_float()?;
452                Some(vec![ParamValue::Float(x.cos())])
453            }
454            NodeType::Tan => {
455                let x = inputs.first()?.as_float()?;
456                Some(vec![ParamValue::Float(x.tan())])
457            }
458            NodeType::Atan2 => {
459                let y = inputs.first()?.as_float()?;
460                let x = inputs.get(1)?.as_float()?;
461                Some(vec![ParamValue::Float(y.atan2(x))])
462            }
463            NodeType::Lerp => {
464                let a = inputs.first()?.as_float()?;
465                let b = inputs.get(1)?.as_float()?;
466                let t = inputs.get(2)?.as_float()?;
467                Some(vec![ParamValue::Float(a + (b - a) * t)])
468            }
469            NodeType::Clamp => {
470                let x = inputs.first()?.as_float()?;
471                let lo = inputs.get(1)?.as_float()?;
472                let hi = inputs.get(2)?.as_float()?;
473                Some(vec![ParamValue::Float(x.clamp(lo, hi))])
474            }
475            NodeType::Smoothstep => {
476                let e0 = inputs.first()?.as_float()?;
477                let e1 = inputs.get(1)?.as_float()?;
478                let x = inputs.get(2)?.as_float()?;
479                let range = e1 - e0;
480                if range.abs() < 1e-10 {
481                    return Some(vec![ParamValue::Float(if x < e0 { 0.0 } else { 1.0 })]);
482                }
483                let t = ((x - e0) / range).clamp(0.0, 1.0);
484                Some(vec![ParamValue::Float(t * t * (3.0 - 2.0 * t))])
485            }
486            NodeType::Remap => {
487                let x = inputs.first()?.as_float()?;
488                let in_min = inputs.get(1)?.as_float()?;
489                let in_max = inputs.get(2)?.as_float()?;
490                let out_min = inputs.get(3)?.as_float()?;
491                let out_max = inputs.get(4)?.as_float()?;
492                let range = in_max - in_min;
493                if range.abs() < 1e-10 { return None; }
494                let t = (x - in_min) / range;
495                Some(vec![ParamValue::Float(out_min + (out_max - out_min) * t)])
496            }
497            NodeType::Step => {
498                let edge = inputs.first()?.as_float()?;
499                let x = inputs.get(1)?.as_float()?;
500                Some(vec![ParamValue::Float(if x >= edge { 1.0 } else { 0.0 })])
501            }
502            NodeType::Invert => {
503                let c = inputs.first()?.as_vec3()?;
504                Some(vec![ParamValue::Vec3([1.0 - c[0], 1.0 - c[1], 1.0 - c[2]])])
505            }
506            NodeType::Posterize => {
507                let c = inputs.first()?.as_vec3()?;
508                let levels = inputs.get(1)?.as_float()?;
509                if levels < 1.0 { return None; }
510                Some(vec![ParamValue::Vec3([
511                    (c[0] * levels).floor() / levels,
512                    (c[1] * levels).floor() / levels,
513                    (c[2] * levels).floor() / levels,
514                ])])
515            }
516            NodeType::Contrast => {
517                let c = inputs.first()?.as_vec3()?;
518                let amount = inputs.get(1)?.as_float()?;
519                Some(vec![ParamValue::Vec3([
520                    (c[0] - 0.5) * amount + 0.5,
521                    (c[1] - 0.5) * amount + 0.5,
522                    (c[2] - 0.5) * amount + 0.5,
523                ])])
524            }
525            NodeType::Saturation => {
526                let c = inputs.first()?.as_vec3()?;
527                let amount = inputs.get(1)?.as_float()?;
528                let lum = c[0] * 0.299 + c[1] * 0.587 + c[2] * 0.114;
529                Some(vec![ParamValue::Vec3([
530                    lum + (c[0] - lum) * amount,
531                    lum + (c[1] - lum) * amount,
532                    lum + (c[2] - lum) * amount,
533                ])])
534            }
535            _ => None, // Not implemented for this node type
536        }
537    }
538
539    // -----------------------------------------------------------------------
540    // Common subexpression elimination
541    // -----------------------------------------------------------------------
542
543    /// Identify nodes that produce identical results and map duplicates to the canonical version.
544    fn find_common_subexpressions(
545        &self,
546        graph: &ShaderGraph,
547        sorted: &[NodeId],
548    ) -> HashMap<NodeId, NodeId> {
549        let mut cse_map: HashMap<NodeId, NodeId> = HashMap::new();
550        // Signature: (node_type_name, inputs_signature) -> canonical node ID
551        let mut signatures: HashMap<String, NodeId> = HashMap::new();
552
553        for &node_id in sorted {
554            let node = match graph.node(node_id) {
555                Some(n) => n,
556                None => continue,
557            };
558
559            // Build signature
560            let incoming = graph.incoming_connections(node_id);
561            let mut sig_parts: Vec<String> = vec![node.node_type.display_name().to_string()];
562
563            for (idx, socket) in node.inputs.iter().enumerate() {
564                let conn = incoming.iter().find(|c| c.to_socket == idx);
565                if let Some(c) = conn {
566                    // Resolve through CSE map
567                    let resolved = cse_map.get(&c.from_node).copied().unwrap_or(c.from_node);
568                    sig_parts.push(format!("c{}:{}", resolved.0, c.from_socket));
569                } else if let Some(def) = &socket.default_value {
570                    sig_parts.push(format!("d:{}", def.to_glsl()));
571                } else {
572                    sig_parts.push("none".to_string());
573                }
574            }
575
576            let signature = sig_parts.join("|");
577
578            if let Some(&canonical) = signatures.get(&signature) {
579                cse_map.insert(node_id, canonical);
580            } else {
581                signatures.insert(signature, node_id);
582            }
583        }
584
585        cse_map
586    }
587
588    // -----------------------------------------------------------------------
589    // Declaration collection
590    // -----------------------------------------------------------------------
591
592    fn collect_declarations(
593        &self,
594        graph: &ShaderGraph,
595        sorted: &[NodeId],
596    ) -> (Vec<UniformDecl>, Vec<VaryingDecl>) {
597        let mut uniforms: Vec<UniformDecl> = Vec::new();
598        let mut uniform_names: HashSet<String> = HashSet::new();
599        let mut varyings: Vec<VaryingDecl> = Vec::new();
600        let mut varying_names: HashSet<String> = HashSet::new();
601
602        // Always include standard uniforms
603        let standard_uniforms = vec![
604            ("u_time", DataType::Float, true),
605            ("u_model", DataType::Mat4, false),
606            ("u_view", DataType::Mat4, false),
607            ("u_projection", DataType::Mat4, false),
608            ("u_camera_pos", DataType::Vec3, false),
609            ("u_inv_model", DataType::Mat4, false),
610        ];
611        for (name, dt, animated) in standard_uniforms {
612            if uniform_names.insert(name.to_string()) {
613                uniforms.push(UniformDecl {
614                    name: name.to_string(),
615                    data_type: dt,
616                    default_value: None,
617                    is_animated: animated,
618                });
619            }
620        }
621
622        // Standard varyings
623        let standard_varyings = vec![
624            ("v_position", DataType::Vec3),
625            ("v_normal", DataType::Vec3),
626            ("v_uv", DataType::Vec2),
627        ];
628        for (name, dt) in standard_varyings {
629            if varying_names.insert(name.to_string()) {
630                varyings.push(VaryingDecl { name: name.to_string(), data_type: dt });
631            }
632        }
633
634        for &node_id in sorted {
635            let node = match graph.node(node_id) {
636                Some(n) => n,
637                None => continue,
638            };
639
640            match &node.node_type {
641                NodeType::Texture => {
642                    // Add sampler uniform
643                    let sampler_idx = node.inputs.get(1)
644                        .and_then(|s| s.default_value.as_ref())
645                        .and_then(|v| v.as_int())
646                        .unwrap_or(0);
647                    let name = format!("u_texture{}", sampler_idx);
648                    if uniform_names.insert(name.clone()) {
649                        uniforms.push(UniformDecl {
650                            name,
651                            data_type: DataType::Sampler2D,
652                            default_value: None,
653                            is_animated: false,
654                        });
655                    }
656                }
657                NodeType::GameStateVar => {
658                    // Add game state uniform
659                    let var_name = node.inputs.first()
660                        .and_then(|s| s.default_value.as_ref())
661                        .and_then(|v| v.as_string())
662                        .unwrap_or("game_var_0");
663                    let name = format!("u_gs_{}", var_name);
664                    if uniform_names.insert(name.clone()) {
665                        uniforms.push(UniformDecl {
666                            name,
667                            data_type: DataType::Float,
668                            default_value: Some(ParamValue::Float(0.0)),
669                            is_animated: false,
670                        });
671                    }
672                }
673                _ => {}
674            }
675
676            // Conditional node uniforms
677            if let Some(ref var_name) = node.conditional_var {
678                let name = format!("u_gs_{}", var_name);
679                if uniform_names.insert(name.clone()) {
680                    uniforms.push(UniformDecl {
681                        name,
682                        data_type: DataType::Float,
683                        default_value: Some(ParamValue::Float(0.0)),
684                        is_animated: false,
685                    });
686                }
687            }
688
689            // Check properties for any that need uniform binding
690            for (key, val) in &node.properties {
691                if key.starts_with("uniform_") {
692                    let name = format!("u_prop_{}_{}", node.id.0, key.trim_start_matches("uniform_"));
693                    if uniform_names.insert(name.clone()) {
694                        uniforms.push(UniformDecl {
695                            name,
696                            data_type: val.data_type(),
697                            default_value: Some(val.clone()),
698                            is_animated: self.options.animated_uniforms,
699                        });
700                    }
701                }
702            }
703        }
704
705        (uniforms, varyings)
706    }
707
708    // -----------------------------------------------------------------------
709    // GLSL code generation
710    // -----------------------------------------------------------------------
711
712    fn generate_glsl(
713        &self,
714        graph: &ShaderGraph,
715        sorted: &[NodeId],
716        folded: &HashMap<NodeId, Vec<ParamValue>>,
717        cse_map: &HashMap<NodeId, NodeId>,
718        uniforms: &[UniformDecl],
719        varyings: &[VaryingDecl],
720    ) -> (String, HashMap<(u64, usize), String>) {
721        let mut code = String::new();
722        let mut output_var_map: HashMap<(u64, usize), String> = HashMap::new();
723
724        // Header
725        code.push_str(&format!("#version {}\n", self.options.glsl_version));
726        code.push_str("precision highp float;\n\n");
727
728        // Uniform declarations
729        for u in uniforms {
730            code.push_str(&format!("uniform {} {};\n", u.data_type, u.name));
731        }
732        code.push('\n');
733
734        // Varying declarations
735        for v in varyings {
736            code.push_str(&format!("in {} {};\n", v.data_type, v.name));
737        }
738        code.push('\n');
739
740        // Output declarations for MRT
741        code.push_str("layout(location = 0) out vec4 fragColor;\n");
742        code.push_str("layout(location = 1) out vec4 fragEmission;\n");
743        code.push_str("layout(location = 2) out vec4 fragBloom;\n");
744        code.push_str("layout(location = 3) out vec4 fragNormal;\n");
745        code.push('\n');
746
747        // Main function
748        code.push_str("void main() {\n");
749
750        // Track which CSE nodes have already been emitted
751        let mut emitted_cse: HashSet<NodeId> = HashSet::new();
752
753        for &node_id in sorted {
754            // If this node is a CSE duplicate, skip it but register its output vars
755            if let Some(&canonical) = cse_map.get(&node_id) {
756                // Map this node's outputs to the canonical node's outputs
757                if let Some(node) = graph.node(node_id) {
758                    for (idx, _) in node.outputs.iter().enumerate() {
759                        if let Some(var) = output_var_map.get(&(canonical.0, idx)) {
760                            output_var_map.insert((node_id.0, idx), var.clone());
761                        }
762                    }
763                }
764                continue;
765            }
766
767            let node = match graph.node(node_id) {
768                Some(n) => n,
769                None => continue,
770            };
771
772            if !node.enabled {
773                continue;
774            }
775
776            // Handle constant-folded nodes
777            if let Some(folded_vals) = folded.get(&node_id) {
778                if self.options.debug_comments {
779                    code.push_str(&format!("  // [FOLDED] {} (node {})\n",
780                        node.node_type.display_name(), node_id.0));
781                }
782                for (idx, val) in folded_vals.iter().enumerate() {
783                    let var_name = format!("n{}_{}", node_id.0, idx);
784                    code.push_str(&format!("  {} {} = {};\n",
785                        val.data_type(), var_name, val.to_glsl()));
786                    output_var_map.insert((node_id.0, idx), var_name);
787                }
788                continue;
789            }
790
791            // Debug comment
792            if self.options.debug_comments {
793                code.push_str(&format!("  // {} (node {})\n",
794                    node.node_type.display_name(), node_id.0));
795            }
796
797            // Conditional open
798            let has_condition = self.options.enable_conditionals && node.conditional_var.is_some();
799            if has_condition {
800                let var_name = node.conditional_var.as_ref().unwrap();
801                code.push_str(&format!("  if (u_gs_{} > {}) {{\n",
802                    var_name, format_float_glsl(node.conditional_threshold)));
803            }
804
805            // Resolve input variables
806            let incoming = graph.incoming_connections(node_id);
807            let mut input_vars: Vec<String> = Vec::new();
808            for (idx, socket) in node.inputs.iter().enumerate() {
809                let conn = incoming.iter().find(|c| c.to_socket == idx);
810                if let Some(c) = conn {
811                    let resolved_from = cse_map.get(&c.from_node).copied().unwrap_or(c.from_node);
812                    if let Some(var) = output_var_map.get(&(resolved_from.0, c.from_socket)) {
813                        input_vars.push(var.clone());
814                    } else {
815                        // Fallback: use default
816                        input_vars.push(socket.default_value.as_ref()
817                            .map(|v| v.to_glsl())
818                            .unwrap_or_default());
819                    }
820                } else {
821                    input_vars.push(String::new());
822                }
823            }
824
825            // Generate GLSL for this node
826            let prefix = node.var_prefix();
827            let snippet = node.node_type.generate_glsl(&prefix, &input_vars);
828
829            let indent = if has_condition { "    " } else { "  " };
830            for line in &snippet.lines {
831                code.push_str(&format!("{}{}\n", indent, line));
832            }
833
834            // Register output variables
835            for (idx, var) in snippet.output_vars.iter().enumerate() {
836                output_var_map.insert((node_id.0, idx), var.clone());
837            }
838
839            let _ = emitted_cse.insert(node_id);
840
841            // Conditional close
842            if has_condition {
843                code.push_str("  }\n");
844            }
845        }
846
847        code.push_str("}\n");
848
849        (code, output_var_map)
850    }
851
852    fn generate_vertex_shader(&self, varyings: &[VaryingDecl]) -> String {
853        let mut code = String::new();
854        code.push_str(&format!("#version {}\n", self.options.glsl_version));
855        code.push_str("precision highp float;\n\n");
856
857        // Vertex attributes
858        code.push_str("layout(location = 0) in vec3 a_position;\n");
859        code.push_str("layout(location = 1) in vec3 a_normal;\n");
860        code.push_str("layout(location = 2) in vec2 a_uv;\n\n");
861
862        // Uniforms
863        code.push_str("uniform mat4 u_model;\n");
864        code.push_str("uniform mat4 u_view;\n");
865        code.push_str("uniform mat4 u_projection;\n\n");
866
867        // Varyings
868        for v in varyings {
869            code.push_str(&format!("out {} {};\n", v.data_type, v.name));
870        }
871        code.push('\n');
872
873        code.push_str("void main() {\n");
874        code.push_str("  vec4 world_pos = u_model * vec4(a_position, 1.0);\n");
875        code.push_str("  v_position = world_pos.xyz;\n");
876        code.push_str("  v_normal = normalize((u_model * vec4(a_normal, 0.0)).xyz);\n");
877        code.push_str("  v_uv = a_uv;\n");
878        code.push_str("  gl_Position = u_projection * u_view * world_pos;\n");
879        code.push_str("}\n");
880
881        code
882    }
883}
884
/// Formats `v` as a GLSL expression of type `float`.
///
/// Finite values are written in plain decimal notation and always contain a
/// decimal point (`1.0`, `-3.0`, `0.25`, `2000000000.0`). Without the point a
/// whole value such as `1000000000` would be an *integer* literal in GLSL,
/// which is a type error in dialects without implicit int-to-float
/// conversion and overflows for large magnitudes.
///
/// Infinities and NaN have no literal spelling in GLSL, so they are emitted
/// as `uintBitsToFloat(<IEEE-754 bit pattern>u)`. That built-in needs
/// GLSL 3.30 / ESSL 3.00.
fn format_float_glsl(v: f32) -> String {
    if v.is_nan() {
        return "uintBitsToFloat(0x7FC00000u)".to_string();
    }
    if v.is_infinite() {
        let bits = if v > 0.0 {
            "uintBitsToFloat(0x7F800000u)"
        } else {
            "uintBitsToFloat(0xFF800000u)"
        };
        return bits.to_string();
    }

    // `Display` for f32 never uses exponent notation, so a finite value
    // either already contains a '.' or is a whole number that needs ".0".
    let mut literal = v.to_string();
    if !literal.contains('.') {
        literal.push_str(".0");
    }
    literal
}
892
893// ---------------------------------------------------------------------------
894// Convenience function
895// ---------------------------------------------------------------------------
896
897/// Compile a shader graph with default options.
898pub fn compile_graph(graph: &ShaderGraph) -> Result<CompiledShader, CompileError> {
899    ShaderCompiler::with_defaults().compile(graph)
900}
901
902/// Compile a shader graph with custom options.
903pub fn compile_graph_with(graph: &ShaderGraph, options: CompileOptions) -> Result<CompiledShader, CompileError> {
904    ShaderCompiler::new(options).compile(graph)
905}
906
907// ---------------------------------------------------------------------------
908// Type compatibility checking
909// ---------------------------------------------------------------------------
910
911/// Check if a source type can be implicitly cast to a destination type.
912pub fn types_compatible(from: DataType, to: DataType) -> bool {
913    if from == to {
914        return true;
915    }
916    // Implicit promotions
917    matches!((from, to),
918        (DataType::Float, DataType::Vec2)
919        | (DataType::Float, DataType::Vec3)
920        | (DataType::Float, DataType::Vec4)
921        | (DataType::Int, DataType::Float)
922        | (DataType::Bool, DataType::Float)
923        | (DataType::Bool, DataType::Int)
924    )
925}
926
927/// Generate GLSL cast expression from one type to another.
928pub fn generate_cast(expr: &str, from: DataType, to: DataType) -> String {
929    if from == to {
930        return expr.to_string();
931    }
932    match (from, to) {
933        (DataType::Float, DataType::Vec2) => format!("vec2({})", expr),
934        (DataType::Float, DataType::Vec3) => format!("vec3({})", expr),
935        (DataType::Float, DataType::Vec4) => format!("vec4({})", expr),
936        (DataType::Int, DataType::Float) => format!("float({})", expr),
937        (DataType::Bool, DataType::Float) => format!("float({})", expr),
938        (DataType::Bool, DataType::Int) => format!("int({})", expr),
939        (DataType::Vec2, DataType::Vec3) => format!("vec3({}, 0.0)", expr),
940        (DataType::Vec2, DataType::Vec4) => format!("vec4({}, 0.0, 1.0)", expr),
941        (DataType::Vec3, DataType::Vec4) => format!("vec4({}, 1.0)", expr),
942        (DataType::Vec4, DataType::Vec3) => format!("{}.xyz", expr),
943        (DataType::Vec3, DataType::Vec2) => format!("{}.xy", expr),
944        (DataType::Vec4, DataType::Vec2) => format!("{}.xy", expr),
945        (DataType::Vec3, DataType::Float) => format!("length({})", expr),
946        (DataType::Vec4, DataType::Float) => format!("{}.x", expr),
947        _ => format!("{}({})", to, expr), // best-effort
948    }
949}
950
951// ---------------------------------------------------------------------------
952// Shader variant cache
953// ---------------------------------------------------------------------------
954
955/// A cache for compiled shader variants, keyed by topology hash.
956pub struct ShaderVariantCache {
957    cache: HashMap<u64, CompiledShader>,
958}
959
960impl ShaderVariantCache {
961    pub fn new() -> Self {
962        Self { cache: HashMap::new() }
963    }
964
965    /// Get a cached shader by topology hash, or compile and cache it.
966    pub fn get_or_compile(
967        &mut self,
968        graph: &ShaderGraph,
969        compiler: &ShaderCompiler,
970    ) -> Result<&CompiledShader, CompileError> {
971        let hash = graph.topology_hash();
972        if !self.cache.contains_key(&hash) {
973            let compiled = compiler.compile(graph)?;
974            self.cache.insert(hash, compiled);
975        }
976        Ok(self.cache.get(&hash).unwrap())
977    }
978
979    /// Invalidate a specific cache entry.
980    pub fn invalidate(&mut self, hash: u64) {
981        self.cache.remove(&hash);
982    }
983
984    /// Clear the entire cache.
985    pub fn clear(&mut self) {
986        self.cache.clear();
987    }
988
989    /// Number of cached variants.
990    pub fn len(&self) -> usize {
991        self.cache.len()
992    }
993
994    /// Whether the cache is empty.
995    pub fn is_empty(&self) -> bool {
996        self.cache.is_empty()
997    }
998}
999
1000impl Default for ShaderVariantCache {
1001    fn default() -> Self {
1002        Self::new()
1003    }
1004}