Skip to main content

proof_engine/render/shader_graph/
compiler.rs

//! Shader graph → GLSL compiler.
//!
//! Walks the topologically sorted node list and emits a GLSL fragment shader.
//! Each node type is lowered by its own arm of `GraphCompiler::emit_node`, which
//! receives the GLSL expressions feeding the node's inputs and registers the names
//! of the variables that hold the node's outputs.
6
7use std::collections::HashMap;
8use super::{ShaderGraph, NodeId, GraphError};
9use super::nodes::{NodeType, SocketType};
10
11// ── CompiledShader ────────────────────────────────────────────────────────────
12
13/// The result of compiling a ShaderGraph.
14#[derive(Debug, Clone)]
15pub struct CompiledShader {
16    /// Complete GLSL fragment shader source.
17    pub fragment_source: String,
18    /// Vertex shader (pass-through).
19    pub vertex_source:   String,
20    /// Uniform declarations extracted from the graph.
21    pub uniforms:        Vec<UniformDecl>,
22    /// Named render targets this shader writes to.
23    pub render_targets:  Vec<String>,
24}
25
/// A single `uniform` declaration emitted into the fragment shader.
#[derive(Debug, Clone)]
pub struct UniformDecl {
    /// Identifier used both in the GLSL source and for binding lookups.
    pub name: String,
    /// GLSL type keyword, e.g. `float` or `vec3`.
    pub glsl_type: String,
    /// Default value written as a GLSL expression, e.g. `0.0` or `vec3(0.0)`.
    pub default: String,
}
32
33impl CompiledShader {
34    /// Return all uniform names for binding.
35    pub fn uniform_names(&self) -> Vec<&str> {
36        self.uniforms.iter().map(|u| u.name.as_str()).collect()
37    }
38}
39
40// ── GraphCompiler ─────────────────────────────────────────────────────────────
41
/// Stateless entry point that lowers a `ShaderGraph` to GLSL via `GraphCompiler::compile`.
///
/// Carries no data; the derives let it be debug-printed, defaulted and copied like any
/// other public value type.
#[derive(Debug, Clone, Copy, Default)]
pub struct GraphCompiler;
43
44impl GraphCompiler {
45    pub fn compile(graph: &ShaderGraph) -> Result<CompiledShader, GraphError> {
46        let order = graph.topological_order()?;
47
48        let mut uniforms:       Vec<UniformDecl>     = Vec::new();
49        let mut render_targets: Vec<String>          = Vec::new();
50        let mut body:           Vec<String>          = Vec::new();
51        // Map from (NodeId, slot) → variable name
52        let mut var_map: HashMap<(NodeId, u8), String> = HashMap::new();
53
54        // ── Preamble uniforms from parameters ──────────────────────────────────
55        for param in &graph.parameters {
56            uniforms.push(UniformDecl {
57                name:      param.glsl_name.clone(),
58                glsl_type: param.value.glsl_type().to_string(),
59                default:   param.value.glsl_literal(),
60            });
61        }
62
63        // ── Node emission ──────────────────────────────────────────────────────
64        for &node_id in &order {
65            let node = match graph.node(node_id) {
66                Some(n) => n,
67                None    => continue,
68            };
69
70            if node.muted {
71                // Muted: emit zero for all outputs
72                for (i, sock) in node.node_type.output_sockets().iter().enumerate() {
73                    let var = node.var_name(i);
74                    body.push(format!("{} {} = {};",
75                        sock.socket_type.glsl_type(), var, sock.socket_type.default_value()));
76                    var_map.insert((node_id, i as u8), var);
77                }
78                continue;
79            }
80
81            // Collect input variable names (from connected edges or constants)
82            let inputs = node.node_type.input_sockets();
83            let mut input_vars: Vec<String> = Vec::new();
84            for (slot_idx, sock) in inputs.iter().enumerate() {
85                let connected = graph.edges.iter()
86                    .find(|e| e.to_node == node_id && e.to_slot == slot_idx as u8);
87                let var = if let Some(edge) = connected {
88                    var_map.get(&(edge.from_node, edge.from_slot))
89                        .cloned()
90                        .unwrap_or_else(|| sock.default.clone())
91                } else if let Some(const_val) = node.constant_inputs.get(&slot_idx) {
92                    const_val.clone()
93                } else {
94                    sock.default.clone()
95                };
96                input_vars.push(var);
97            }
98
99            // Emit node code
100            Self::emit_node(node_id, &node.node_type, &input_vars, &mut body,
101                            &mut var_map, &mut uniforms, &mut render_targets);
102        }
103
104        // ── Output collection ──────────────────────────────────────────────────
105        let output_var = if let Some(out_id) = graph.output_node {
106            let out_node = graph.node(out_id)
107                .ok_or(GraphError::NodeNotFound(out_id))?;
108            // For output nodes, the input (color) is their first input
109            let edge = graph.edges.iter()
110                .find(|e| e.to_node == out_id && e.to_slot == 0);
111            if let Some(e) = edge {
112                var_map.get(&(e.from_node, e.from_slot))
113                    .cloned()
114                    .unwrap_or_else(|| "vec4(0.0, 0.0, 0.0, 1.0)".to_string())
115            } else {
116                out_node.constant_inputs.get(&0)
117                    .cloned()
118                    .unwrap_or_else(|| "vec4(0.0, 0.0, 0.0, 1.0)".to_string())
119            }
120        } else {
121            "vec4(0.0, 0.0, 0.0, 1.0)".to_string()
122        };
123
124        // ── Assemble fragment shader ───────────────────────────────────────────
125        let fragment_source = Self::assemble_fragment(&uniforms, &body, &output_var);
126        let vertex_source   = PASSTHROUGH_VERTEX.to_string();
127
128        Ok(CompiledShader { fragment_source, vertex_source, uniforms, render_targets })
129    }
130
131    fn emit_node(
132        node_id:        NodeId,
133        node_type:      &NodeType,
134        inputs:         &[String],
135        body:           &mut Vec<String>,
136        var_map:        &mut HashMap<(NodeId, u8), String>,
137        uniforms:       &mut Vec<UniformDecl>,
138        render_targets: &mut Vec<String>,
139    ) {
140        let out0 = format!("n{}_{}", node_id.0, 0);
141        let out1 = format!("n{}_{}", node_id.0, 1);
142        let out2 = format!("n{}_{}", node_id.0, 2);
143        let out3 = format!("n{}_{}", node_id.0, 3);
144
145        let i = |n: usize| inputs.get(n).cloned().unwrap_or_else(|| "0.0".to_string());
146
147        match node_type {
148            // ── Input nodes ────────────────────────────────────────────────────
149            NodeType::UvCoord => {
150                body.push(format!("vec2 {} = vUv;", out0));
151                var_map.insert((node_id, 0), out0);
152            }
153            NodeType::WorldPos => {
154                body.push(format!("vec3 {} = vWorldPos;", out0));
155                var_map.insert((node_id, 0), out0);
156            }
157            NodeType::CameraPos => {
158                body.push(format!("vec3 {} = uCameraPos;", out0));
159                var_map.insert((node_id, 0), out0);
160                uniforms.push(UniformDecl { name: "uCameraPos".into(), glsl_type: "vec3".into(), default: "vec3(0.0)".into() });
161            }
162            NodeType::Time => {
163                body.push(format!("float {} = uTime;", out0));
164                var_map.insert((node_id, 0), out0);
165                if !uniforms.iter().any(|u| u.name == "uTime") {
166                    uniforms.push(UniformDecl { name: "uTime".into(), glsl_type: "float".into(), default: "0.0".into() });
167                }
168            }
169            NodeType::Resolution => {
170                body.push(format!("vec2 {} = uResolution;", out0));
171                var_map.insert((node_id, 0), out0);
172                uniforms.push(UniformDecl { name: "uResolution".into(), glsl_type: "vec2".into(), default: "vec2(1.0)".into() });
173            }
174            NodeType::ConstFloat(v) => {
175                body.push(format!("float {} = {:.6};", out0, v));
176                var_map.insert((node_id, 0), out0);
177            }
178            NodeType::ConstVec2(x, y) => {
179                body.push(format!("vec2 {} = vec2({:.6}, {:.6});", out0, x, y));
180                var_map.insert((node_id, 0), out0);
181            }
182            NodeType::ConstVec3(x, y, z) => {
183                body.push(format!("vec3 {} = vec3({:.6}, {:.6}, {:.6});", out0, x, y, z));
184                var_map.insert((node_id, 0), out0);
185            }
186            NodeType::ConstVec4(x, y, z, w) => {
187                body.push(format!("vec4 {} = vec4({:.6},{:.6},{:.6},{:.6});", out0, x, y, z, w));
188                var_map.insert((node_id, 0), out0);
189            }
190            NodeType::VertexColor => {
191                body.push(format!("vec4 {} = vColor;", out0));
192                var_map.insert((node_id, 0), out0);
193            }
194            NodeType::ScreenCoord => {
195                body.push(format!("vec2 {} = gl_FragCoord.xy;", out0));
196                var_map.insert((node_id, 0), out0);
197            }
198            NodeType::Uniform(name, t) => {
199                let glsl_t = t.glsl_type();
200                uniforms.push(UniformDecl { name: name.clone(), glsl_type: glsl_t.to_string(), default: t.default_value().to_string() });
201                body.push(format!("{} {} = {};", glsl_t, out0, name));
202                var_map.insert((node_id, 0), out0);
203            }
204            // ── Math ───────────────────────────────────────────────────────────
205            NodeType::Add       => { body.push(format!("auto {} = {} + {};", out0, i(0), i(1))); var_map.insert((node_id,0),out0); }
206            NodeType::Subtract  => { body.push(format!("auto {} = {} - {};", out0, i(0), i(1))); var_map.insert((node_id,0),out0); }
207            NodeType::Multiply  => { body.push(format!("auto {} = {} * {};", out0, i(0), i(1))); var_map.insert((node_id,0),out0); }
208            NodeType::Divide    => { body.push(format!("auto {} = ({} != 0.0) ? {} / {} : 0.0;", out0, i(1), i(0), i(1))); var_map.insert((node_id,0),out0); }
209            NodeType::Power     => { body.push(format!("float {} = pow(max({}, 0.0), {});", out0, i(0), i(1))); var_map.insert((node_id,0),out0); }
210            NodeType::Sqrt      => { body.push(format!("auto {} = sqrt(max({}, 0.0));", out0, i(0))); var_map.insert((node_id,0),out0); }
211            NodeType::Abs       => { body.push(format!("auto {} = abs({});", out0, i(0))); var_map.insert((node_id,0),out0); }
212            NodeType::Sign      => { body.push(format!("auto {} = sign({});", out0, i(0))); var_map.insert((node_id,0),out0); }
213            NodeType::Floor     => { body.push(format!("auto {} = floor({});", out0, i(0))); var_map.insert((node_id,0),out0); }
214            NodeType::Ceil      => { body.push(format!("auto {} = ceil({});", out0, i(0))); var_map.insert((node_id,0),out0); }
215            NodeType::Fract     => { body.push(format!("auto {} = fract({});", out0, i(0))); var_map.insert((node_id,0),out0); }
216            NodeType::Min       => { body.push(format!("auto {} = min({}, {});", out0, i(0), i(1))); var_map.insert((node_id,0),out0); }
217            NodeType::Max       => { body.push(format!("auto {} = max({}, {});", out0, i(0), i(1))); var_map.insert((node_id,0),out0); }
218            NodeType::Clamp     => { body.push(format!("auto {} = clamp({}, {}, {});", out0, i(0), i(1), i(2))); var_map.insert((node_id,0),out0); }
219            NodeType::Mix       => { body.push(format!("auto {} = mix({}, {}, {});", out0, i(0), i(1), i(2))); var_map.insert((node_id,0),out0); }
220            NodeType::Smoothstep=> { body.push(format!("float {} = smoothstep({}, {}, {});", out0, i(0), i(1), i(2))); var_map.insert((node_id,0),out0); }
221            NodeType::Step      => { body.push(format!("auto {} = step({}, {});", out0, i(0), i(1))); var_map.insert((node_id,0),out0); }
222            NodeType::Mod       => { body.push(format!("auto {} = mod({}, {});", out0, i(0), i(1))); var_map.insert((node_id,0),out0); }
223            NodeType::Sin       => { body.push(format!("auto {} = sin({});", out0, i(0))); var_map.insert((node_id,0),out0); }
224            NodeType::Cos       => { body.push(format!("auto {} = cos({});", out0, i(0))); var_map.insert((node_id,0),out0); }
225            NodeType::Tan       => { body.push(format!("auto {} = tan({});", out0, i(0))); var_map.insert((node_id,0),out0); }
226            NodeType::Atan      => { body.push(format!("float {} = atan({});", out0, i(0))); var_map.insert((node_id,0),out0); }
227            NodeType::Exp       => { body.push(format!("auto {} = exp({});", out0, i(0))); var_map.insert((node_id,0),out0); }
228            NodeType::Log       => { body.push(format!("auto {} = log(max({}, 1e-6));", out0, i(0))); var_map.insert((node_id,0),out0); }
229            NodeType::Log2      => { body.push(format!("auto {} = log2(max({}, 1e-6));", out0, i(0))); var_map.insert((node_id,0),out0); }
230            NodeType::OneMinus  => { body.push(format!("auto {} = 1.0 - {};", out0, i(0))); var_map.insert((node_id,0),out0); }
231            NodeType::Saturate  => { body.push(format!("auto {} = clamp({}, 0.0, 1.0);", out0, i(0))); var_map.insert((node_id,0),out0); }
232            NodeType::Negate    => { body.push(format!("auto {} = -{};", out0, i(0))); var_map.insert((node_id,0),out0); }
233            NodeType::Reciprocal=> { body.push(format!("float {} = 1.0 / max({}, 1e-6);", out0, i(0))); var_map.insert((node_id,0),out0); }
234            NodeType::Dot       => { body.push(format!("float {} = dot({}, {});", out0, i(0), i(1))); var_map.insert((node_id,0),out0); }
235            NodeType::Cross     => { body.push(format!("vec3 {} = cross({}, {});", out0, i(0), i(1))); var_map.insert((node_id,0),out0); }
236            NodeType::Normalize => { body.push(format!("auto {} = normalize({});", out0, i(0))); var_map.insert((node_id,0),out0); }
237            NodeType::Length    => { body.push(format!("float {} = length({});", out0, i(0))); var_map.insert((node_id,0),out0); }
238            NodeType::LengthSquared => { body.push(format!("float {} = dot({},{});", out0, i(0), i(0))); var_map.insert((node_id,0),out0); }
239            NodeType::Distance  => { body.push(format!("float {} = distance({}, {});", out0, i(0), i(1))); var_map.insert((node_id,0),out0); }
240            NodeType::Reflect   => { body.push(format!("auto {} = reflect({}, {});", out0, i(0), i(1))); var_map.insert((node_id,0),out0); }
241            NodeType::Refract   => { body.push(format!("vec3 {} = refract({}, {}, {});", out0, i(0), i(1), i(2))); var_map.insert((node_id,0),out0); }
242            NodeType::Remap     => {
243                body.push(format!(
244                    "float {} = ({} - {}) / max({} - {}, 1e-6) * ({} - {}) + {};",
245                    out0, i(0), i(1), i(2), i(1), i(4), i(3), i(3)
246                ));
247                var_map.insert((node_id,0),out0);
248            }
249            // ── Vector ─────────────────────────────────────────────────────────
250            NodeType::CombineVec2 => { body.push(format!("vec2 {} = vec2({}, {});", out0, i(0), i(1))); var_map.insert((node_id,0),out0); }
251            NodeType::CombineVec3 => { body.push(format!("vec3 {} = vec3({}, {}, {});", out0, i(0), i(1), i(2))); var_map.insert((node_id,0),out0); }
252            NodeType::CombineVec4 => { body.push(format!("vec4 {} = vec4({}, {});", out0, i(0), i(1))); var_map.insert((node_id,0),out0); }
253            NodeType::SplitVec2  => {
254                body.push(format!("float {} = ({}).x;", out0, i(0)));
255                body.push(format!("float {} = ({}).y;", out1, i(0)));
256                var_map.insert((node_id,0),out0);
257                var_map.insert((node_id,1),out1);
258            }
259            NodeType::SplitVec3  => {
260                body.push(format!("float {} = ({}).x;", out0, i(0)));
261                body.push(format!("float {} = ({}).y;", out1, i(0)));
262                body.push(format!("float {} = ({}).z;", out2, i(0)));
263                var_map.insert((node_id,0),out0); var_map.insert((node_id,1),out1); var_map.insert((node_id,2),out2);
264            }
265            NodeType::SplitVec4  => {
266                body.push(format!("float {} = ({}).x;", out0, i(0)));
267                body.push(format!("float {} = ({}).y;", out1, i(0)));
268                body.push(format!("float {} = ({}).z;", out2, i(0)));
269                body.push(format!("float {} = ({}).w;", out3, i(0)));
270                var_map.insert((node_id,0),out0); var_map.insert((node_id,1),out1);
271                var_map.insert((node_id,2),out2); var_map.insert((node_id,3),out3);
272            }
273            NodeType::Swizzle(s) => {
274                body.push(format!("auto {} = ({}).{};", out0, i(0), s));
275                var_map.insert((node_id,0),out0);
276            }
277            NodeType::RotateVec2 => {
278                body.push(format!(
279                    "vec2 {} = vec2(({} - {}).x * cos({}) - ({} - {}).y * sin({}), ({} - {}).x * sin({}) + ({} - {}).y * cos({})) + {};",
280                    out0, i(0), i(2), i(1), i(0), i(2), i(1), i(0), i(2), i(1), i(0), i(2), i(1), i(2)
281                ));
282                var_map.insert((node_id,0),out0);
283            }
284            // ── Color ──────────────────────────────────────────────────────────
285            NodeType::HsvToRgb => {
286                body.push(format!("vec3 {} = hsv2rgb({});", out0, i(0)));
287                var_map.insert((node_id,0),out0);
288            }
289            NodeType::RgbToHsv => {
290                body.push(format!("vec3 {} = rgb2hsv({});", out0, i(0)));
291                var_map.insert((node_id,0),out0);
292            }
293            NodeType::Luminance => {
294                body.push(format!("float {} = dot({}, vec3(0.2126, 0.7152, 0.0722));", out0, i(0)));
295                var_map.insert((node_id,0),out0);
296            }
297            NodeType::Saturation => {
298                body.push(format!(
299                    "vec3 {} = mix(vec3(dot({}, vec3(0.2126,0.7152,0.0722))), {}, clamp({}, 0.0, 2.0));",
300                    out0, i(0), i(0), i(1)
301                ));
302                var_map.insert((node_id,0),out0);
303            }
304            NodeType::HueRotate => {
305                body.push(format!("vec3 {} = rotateHue({}, radians({}));", out0, i(0), i(1)));
306                var_map.insert((node_id,0),out0);
307            }
308            NodeType::LinearToSrgb => {
309                body.push(format!("vec3 {} = pow(max({}, 0.0), vec3(1.0/2.2));", out0, i(0)));
310                var_map.insert((node_id,0),out0);
311            }
312            NodeType::SrgbToLinear => {
313                body.push(format!("vec3 {} = pow(max({}, 0.0), vec3(2.2));", out0, i(0)));
314                var_map.insert((node_id,0),out0);
315            }
316            NodeType::GammaCorrect => {
317                body.push(format!("vec3 {} = pow(max({}, 0.0), vec3(1.0 / max({}, 0.001)));", out0, i(0), i(1)));
318                var_map.insert((node_id,0),out0);
319            }
320            NodeType::ScreenBlend => {
321                body.push(format!("vec3 {} = 1.0 - (1.0 - {}) * (1.0 - {});", out0, i(0), i(1)));
322                var_map.insert((node_id,0),out0);
323            }
324            NodeType::OverlayBlend => {
325                body.push(format!(
326                    "vec3 {} = mix(2.0*{}*{}, 1.0 - 2.0*(1.0-{})*(1.0-{}), step(vec3(0.5), {}));",
327                    out0, i(0), i(1), i(0), i(1), i(0)
328                ));
329                var_map.insert((node_id,0),out0);
330            }
331            NodeType::HardLight => {
332                body.push(format!(
333                    "vec3 {} = mix(2.0*{}*{}, 1.0-2.0*(1.0-{})*(1.0-{}), step(vec3(0.5), {}));",
334                    out0, i(0), i(1), i(0), i(1), i(1)
335                ));
336                var_map.insert((node_id,0),out0);
337            }
338            NodeType::SoftLight => {
339                body.push(format!(
340                    "vec3 {} = mix({} - (1.0-2.0*{})*{}*(1.0-{}), {} + (2.0*{}-1.0)*(sqrt({})-{}), step(vec3(0.5), {}));",
341                    out0, i(0), i(1), i(0), i(0), i(0), i(1), i(0), i(0), i(1)
342                ));
343                var_map.insert((node_id,0),out0);
344            }
345            NodeType::ColorBurn => {
346                body.push(format!("vec3 {} = 1.0 - (1.0 - {}) / max({}, 0.001);", out0, i(0), i(1)));
347                var_map.insert((node_id,0),out0);
348            }
349            NodeType::ColorDodge => {
350                body.push(format!("vec3 {} = {} / max(1.0 - {}, 0.001);", out0, i(0), i(1)));
351                var_map.insert((node_id,0),out0);
352            }
353            NodeType::Difference => {
354                body.push(format!("vec3 {} = abs({} - {});", out0, i(0), i(1)));
355                var_map.insert((node_id,0),out0);
356            }
357            NodeType::Invert => {
358                body.push(format!("auto {} = 1.0 - {};", out0, i(0)));
359                var_map.insert((node_id,0),out0);
360            }
361            NodeType::Posterize => {
362                body.push(format!("vec3 {} = floor({} * {}) / max({}, 1.0);", out0, i(0), i(1), i(1)));
363                var_map.insert((node_id,0),out0);
364            }
365            NodeType::Duotone => {
366                body.push(format!(
367                    "vec3 {} = mix({}, {}, dot({}, vec3(0.2126, 0.7152, 0.0722)));",
368                    out0, i(1), i(2), i(0)
369                ));
370                var_map.insert((node_id,0),out0);
371            }
372            // ── Noise ──────────────────────────────────────────────────────────
373            NodeType::ValueNoise => {
374                body.push(format!("float {} = valueNoise({} * {});", out0, i(0), i(1)));
375                var_map.insert((node_id,0),out0);
376            }
377            NodeType::PerlinNoise => {
378                body.push(format!("float {} = perlinNoise({} * {});", out0, i(0), i(1)));
379                var_map.insert((node_id,0),out0);
380            }
381            NodeType::SimplexNoise => {
382                body.push(format!("float {} = simplexNoise({} * {});", out0, i(0), i(1)));
383                var_map.insert((node_id,0),out0);
384            }
385            NodeType::Fbm => {
386                body.push(format!("float {} = fbm({}, int({}), {}, {});", out0, i(0), i(1), i(2), i(3)));
387                var_map.insert((node_id,0),out0);
388            }
389            NodeType::Voronoi => {
390                body.push(format!("float {} = voronoi({} * {}, {}).x;", out0, i(0), i(1), i(2)));
391                var_map.insert((node_id,0),out0);
392            }
393            NodeType::Worley => {
394                body.push(format!("float {} = worley({} * {}).x;", out0, i(0), i(1)));
395                var_map.insert((node_id,0),out0);
396            }
397            NodeType::Checkerboard => {
398                body.push(format!(
399                    "float {} = mod(floor({}.x * {}) + floor({}.y * {}), 2.0);",
400                    out0, i(0), i(1), i(0), i(1)
401                ));
402                var_map.insert((node_id,0),out0);
403            }
404            NodeType::SineWave => {
405                body.push(format!(
406                    "float {} = {} * sin({} * {} * 6.28318 + {});",
407                    out0, i(2), i(0), i(1), i(3)
408                ));
409                var_map.insert((node_id,0),out0);
410            }
411            NodeType::RadialGradient => {
412                body.push(format!("float {} = distance({}, {}) / max({}, 0.001);", out0, i(0), i(1), i(2)));
413                var_map.insert((node_id,0),out0);
414            }
415            NodeType::LinearGradient => {
416                body.push(format!(
417                    "float {} = dot({} - vec2(0.5), vec2(cos({}), sin({}))) * 0.5 + 0.5;",
418                    out0, i(0), i(1), i(1)
419                ));
420                var_map.insert((node_id,0),out0);
421            }
422            NodeType::Spiral => {
423                body.push(format!(
424                    "float {} = fract(atan({}.y - 0.5, {}.x - 0.5) / 6.28318 * {} + length({} - vec2(0.5)) * {} - {} * {});",
425                    out0, i(0), i(0), i(1), i(0), i(1), i(2), i(3)
426                ));
427                var_map.insert((node_id,0),out0);
428            }
429            NodeType::Rings => {
430                body.push(format!(
431                    "float {} = fract(length({} - vec2(0.5)) * {}) < {};",
432                    out0, i(0), i(1), i(2)
433                ));
434                var_map.insert((node_id,0),out0);
435            }
436            NodeType::StarBurst => {
437                body.push(format!(
438                    "float {} = abs(sin(atan({}.y-0.5,{}.x-0.5)*{}*0.5)) * pow(length({}-vec2(0.5))*2.0, {});",
439                    out0, i(0), i(0), i(1), i(0), i(2)
440                ));
441                var_map.insert((node_id,0),out0);
442            }
443            NodeType::Grid => {
444                body.push(format!(
445                    "float {} = max(step(0.95, fract({}.x * {})), step(0.95, fract({}.y * {})));",
446                    out0, i(0), i(1), i(0), i(1)
447                ));
448                var_map.insert((node_id,0),out0);
449            }
450            // ── SDF ────────────────────────────────────────────────────────────
451            NodeType::SdfCircle => {
452                body.push(format!("float {} = length({} - {}) - {};", out0, i(0), i(1), i(2)));
453                var_map.insert((node_id,0),out0);
454            }
455            NodeType::SdfBox => {
456                body.push(format!(
457                    "{{ vec2 _q{} = abs({}-{}) - {}; float {} = length(max(_q{}, 0.0)) + min(max(_q{}.x,_q{}.y), 0.0) - {}; }}",
458                    node_id.0, i(0), i(1), i(2), out0, node_id.0, node_id.0, node_id.0, i(3)
459                ));
460                var_map.insert((node_id,0),out0);
461            }
462            NodeType::SdfLine => {
463                body.push(format!(
464                    "{{ vec2 _pa{} = {} - {}; vec2 _ba{} = {} - {}; float _h{} = clamp(dot(_pa{},_ba{})/dot(_ba{},_ba{}),0.0,1.0); float {} = length(_pa{} - _ba{}*_h{}); }}",
465                    node_id.0, i(0), i(1), node_id.0, i(2), i(1), node_id.0,
466                    node_id.0, node_id.0, node_id.0, node_id.0,
467                    out0, node_id.0, node_id.0, node_id.0
468                ));
469                var_map.insert((node_id,0),out0);
470            }
471            NodeType::SdfSmoothUnion => {
472                body.push(format!(
473                    "{{ float _h{} = clamp(0.5+0.5*({}-{})/{}, 0.0, 1.0); float {} = mix({},{},_h{}) - {}*_h{}*(1.0-_h{}); }}",
474                    node_id.0, i(1), i(0), i(2), out0, i(1), i(0), node_id.0, i(2), node_id.0, node_id.0
475                ));
476                var_map.insert((node_id,0),out0);
477            }
478            NodeType::SdfSmoothSubtract => {
479                body.push(format!(
480                    "{{ float _h{} = clamp(0.5-0.5*({} + {})/{}, 0.0, 1.0); float {} = mix({}, -{}, _h{}) + {}*_h{}*(1.0-_h{}); }}",
481                    node_id.0, i(1), i(0), i(2), out0, i(1), i(0), node_id.0, i(2), node_id.0, node_id.0
482                ));
483                var_map.insert((node_id,0),out0);
484            }
485            NodeType::SdfToAlpha => {
486                body.push(format!("float {} = step({}, -{});", out0, i(1), i(0)));
487                var_map.insert((node_id,0),out0);
488            }
489            NodeType::SdfToSoftAlpha => {
490                body.push(format!("float {} = 1.0 - smoothstep(-{} - {}, -{} + {}, {});", out0, i(2), i(1), i(2), i(1), i(0)));
491                var_map.insert((node_id,0),out0);
492            }
493            // ── Fractals ───────────────────────────────────────────────────────
494            NodeType::Mandelbrot => {
495                body.push(format!(
496                    r#"float {} = mandelbrotIter({} * {} - vec2(0.5), int({}));"#,
497                    out0, i(0), i(2), i(1)
498                ));
499                var_map.insert((node_id,0),out0);
500            }
501            NodeType::Julia => {
502                body.push(format!(
503                    "float {} = juliaIter({} * {}, vec2({}, {}), int({}));",
504                    out0, i(0), i(2), i(3), i(4), i(1)
505                ));
506                var_map.insert((node_id,0),out0);
507            }
508            // ── Vignette / Grain ───────────────────────────────────────────────
509            NodeType::Vignette => {
510                body.push(format!(
511                    "float {} = 1.0 - smoothstep(1.0 - {}, 1.0 - {} + {}, length(({} - vec2(0.5)) * 2.0));",
512                    out0, i(1), i(1), i(2), i(0)
513                ));
514                var_map.insert((node_id,0),out0);
515            }
516            NodeType::FilmGrain => {
517                body.push(format!(
518                    "float {} = {} * (fract(sin(dot({} + vec2({} * 123.456), vec2(12.9898, 78.233))) * 43758.5453) - 0.5) * 2.0;",
519                    out0, i(2), i(0), i(1)
520                ));
521                var_map.insert((node_id,0),out0);
522            }
523            NodeType::Scanlines => {
524                body.push(format!(
525                    "float {} = 1.0 - {} * (sin({}.y * {} * 3.14159) * 0.5 + 0.5);",
526                    out0, i(1), i(0), i(2)
527                ));
528                var_map.insert((node_id,0),out0);
529            }
530            NodeType::Pixelate => {
531                body.push(format!(
532                    "vec2 {} = floor({} * {}) / {};",
533                    out0, i(0), i(1), i(1)
534                ));
535                var_map.insert((node_id,0),out0);
536            }
537            NodeType::BarrelDistort => {
538                body.push(format!(
539                    "{{ vec2 _uv{} = {} - vec2(0.5); float _r{} = dot(_uv{},_uv{}); vec2 {} = {} + _uv{} * _r{} * {}; }}",
540                    node_id.0, i(0), node_id.0, node_id.0, node_id.0,
541                    out0, i(0), node_id.0, node_id.0, i(1)
542                ));
543                var_map.insert((node_id,0),out0);
544            }
545            NodeType::HeatHaze => {
546                body.push(format!(
547                    "vec2 {} = {} + vec2(sin({}.y * 20.0 + {} * {}) * {}, 0.0);",
548                    out0, i(0), i(0), i(1), i(3), i(2)
549                ));
550                var_map.insert((node_id,0),out0);
551            }
552            NodeType::GlitchOffset => {
553                body.push(format!(
554                    "vec2 {} = {} + vec2(step(0.9, fract(sin({}.y * 100.0 + {}) * 43758.5)) * {} * 0.1, 0.0);",
555                    out0, i(0), i(0), i(1), i(2)
556                ));
557                var_map.insert((node_id,0),out0);
558            }
559            // ── Logic ──────────────────────────────────────────────────────────
560            NodeType::IfGreater => {
561                body.push(format!("auto {} = ({} > {}) ? {} : {};", out0, i(0), i(1), i(2), i(3)));
562                var_map.insert((node_id,0),out0);
563            }
564            NodeType::IfLess => {
565                body.push(format!("auto {} = ({} < {}) ? {} : {};", out0, i(0), i(1), i(2), i(3)));
566                var_map.insert((node_id,0),out0);
567            }
568            NodeType::ConditionalBlend => {
569                body.push(format!(
570                    "auto {} = mix({}, {}, smoothstep(0.5 - {}, 0.5 + {}, {}));",
571                    out0, i(1), i(2), i(3), i(3), i(0)
572                ));
573                var_map.insert((node_id,0),out0);
574            }
575            // ── Output (handled separately above) ─────────────────────────────
576            NodeType::OutputColor | NodeType::OutputTarget(_) | NodeType::OutputWithBloom => {}
577            // ── Default for unimplemented nodes ───────────────────────────────
578            _ => {
579                body.push(format!("float {} = 0.0; // TODO: {}", out0, node_type.label()));
580                var_map.insert((node_id,0),out0);
581            }
582        }
583    }
584
585    fn assemble_fragment(
586        uniforms:     &[UniformDecl],
587        body:         &[String],
588        output_var:   &str,
589    ) -> String {
590        let mut src = String::from("#version 330 core\n");
591
592        // Varyings from vertex shader
593        src.push_str("in  vec2 vUv;\n");
594        src.push_str("in  vec3 vWorldPos;\n");
595        src.push_str("in  vec4 vColor;\n");
596        src.push_str("out vec4 fragColor;\n\n");
597
598        // Uniforms
599        for u in uniforms {
600            src.push_str(&format!("uniform {} {};\n", u.glsl_type, u.name));
601        }
602        src.push('\n');
603
604        // Standard math helpers
605        src.push_str(SHADER_HELPERS);
606        src.push('\n');
607
608        // Main function
609        src.push_str("void main() {\n");
610        for line in body {
611            src.push_str("    ");
612            src.push_str(line);
613            src.push('\n');
614        }
615        // Coerce output to vec4 if needed
616        src.push_str(&format!("    fragColor = vec4({});\n", output_var));
617        src.push_str("}\n");
618        src
619    }
620}
621
622// ── Passthrough vertex shader ─────────────────────────────────────────────────
623
/// Pass-through GLSL 330 core vertex shader.
///
/// Attributes: `aPos` (location 0, `vec2`), `aUv` (location 1, `vec2`) and
/// `aColor` (location 2, `vec4`). The uniform `uMVP` transforms
/// `vec4(aPos, 0.0, 1.0)` to clip space. The shader forwards `vUv`, `vWorldPos`
/// and `vColor`, which are exactly the inputs declared by
/// `GraphCompiler::assemble_fragment` for the generated fragment shaders.
///
/// NOTE(review): `vWorldPos` is `aPos` extended with `z = 0.0` *before* the
/// `uMVP` transform, so it is only a true world-space position if vertices are
/// supplied in world space — confirm against the mesh pipeline.
pub const PASSTHROUGH_VERTEX: &str = r#"
#version 330 core
layout(location = 0) in vec2 aPos;
layout(location = 1) in vec2 aUv;
layout(location = 2) in vec4 aColor;

out vec2 vUv;
out vec3 vWorldPos;
out vec4 vColor;

uniform mat4 uMVP;

void main() {
    vUv       = aUv;
    vWorldPos = vec3(aPos, 0.0);
    vColor    = aColor;
    gl_Position = uMVP * vec4(aPos, 0.0, 1.0);
}
"#;
643
644// ── GLSL helper functions ─────────────────────────────────────────────────────
645
/// GLSL helper library pasted verbatim into every generated fragment shader
/// (by `GraphCompiler::assemble_fragment`), so a GLSL error in any helper
/// breaks every compiled graph.
///
/// Provides:
/// * colour: `hsv2rgb`, `rgb2hsv`, `rotateHue` (angle in radians);
/// * noise: `hash`, `valueNoise`, `perlinNoise`, `simplexNoise`, `fbm`,
///   `voronoi`, `worley`;
/// * fractals: `mandelbrotIter`, `juliaIter`.
///
/// Notes:
/// * `perlinNoise` is value noise with a quintic fade remapped to `[-1, 1]`,
///   not gradient (Perlin) noise, despite its name.
/// * `simplexNoise` hashes its three simplex corners with the
///   `(34x^2 + x) mod 289` permutation polynomial. Its lattice coordinate is a
///   `vec2`, so only `.x`/`.y` may be selected from it.
/// * `voronoi` offsets each cell's feature point by `jitter * [0, 1)^2`, so for
///   `jitter <= 1` every point stays inside its own cell and the 3x3
///   neighbourhood search finds the true nearest points. It returns
///   `(nearest, second-nearest)` distances; `worley` is `voronoi` with jitter 1.
/// * `fbm` evaluates at most 8 octaves. The fractal iterators run at most 512
///   iterations and return `iter / maxIter` on escape, or `0.0` if the point
///   never escapes.
pub const SHADER_HELPERS: &str = r#"
// ── Color helpers ─────────────────────────────────────────────────────────────
vec3 hsv2rgb(vec3 c) {
    vec4 K = vec4(1.0, 2.0/3.0, 1.0/3.0, 3.0);
    vec3 p = abs(fract(c.xxx + K.xyz) * 6.0 - K.www);
    return c.z * mix(K.xxx, clamp(p - K.xxx, 0.0, 1.0), c.y);
}
vec3 rgb2hsv(vec3 c) {
    vec4 K = vec4(0.0,-1.0/3.0,2.0/3.0,-1.0);
    vec4 p = mix(vec4(c.bg, K.wz), vec4(c.gb, K.xy), step(c.b, c.g));
    vec4 q = mix(vec4(p.xyw, c.r), vec4(c.r, p.yzx), step(p.x, c.r));
    float d = q.x - min(q.w, q.y);
    float e = 1e-10;
    return vec3(abs(q.z+(q.w-q.y)/(6.0*d+e)), d/(q.x+e), q.x);
}
vec3 rotateHue(vec3 c, float angle) {
    vec3 hsv = rgb2hsv(c);
    hsv.x = fract(hsv.x + angle / 6.28318);
    return hsv2rgb(hsv);
}

// ── Noise helpers ─────────────────────────────────────────────────────────────
float hash(vec2 p) { return fract(sin(dot(p, vec2(127.1, 311.7))) * 43758.5453); }
float valueNoise(vec2 p) {
    vec2 i = floor(p); vec2 f = fract(p);
    vec2 u = f*f*(3.0-2.0*f);
    return mix(mix(hash(i),hash(i+vec2(1,0)),u.x),mix(hash(i+vec2(0,1)),hash(i+vec2(1,1)),u.x),u.y);
}
// Value noise with a quintic fade, remapped to [-1, 1].
float perlinNoise(vec2 p) {
    vec2 i = floor(p); vec2 f = fract(p);
    vec2 u = f*f*f*(f*(f*6.0-15.0)+10.0);
    float a = hash(i), b = hash(i+vec2(1,0)), c = hash(i+vec2(0,1)), d = hash(i+vec2(1,1));
    return mix(mix(a,b,u.x),mix(c,d,u.x),u.y)*2.0-1.0;
}
float simplexNoise(vec2 v) {
    const vec4 C = vec4(0.211324865405187,0.366025403784439,-0.577350269189626,0.024390243902439);
    vec2 i = floor(v + dot(v, C.yy));
    vec2 x0 = v - i + dot(i, C.xx);
    vec2 i1 = (x0.x > x0.y) ? vec2(1.0,0.0) : vec2(0.0,1.0);
    vec4 x12 = x0.xyxy + C.xxzz;
    x12.xy -= i1;
    i = mod(i, 289.0);
    // Hash the three corners (0, i1, 1) with the (34x^2 + x) mod 289 permutation.
    vec3 h = i.y + vec3(0.0, i1.y, 1.0);
    h = mod((h * 34.0 + 1.0) * h, 289.0);
    h += i.x + vec3(0.0, i1.x, 1.0);
    h = mod((h * 34.0 + 1.0) * h, 289.0);
    // Map each hash to [0, 1) so it selects a gradient angle below.
    vec3 p = h / 289.0;
    vec3 m = max(0.5 - vec3(dot(x0,x0), dot(x12.xy,x12.xy), dot(x12.zw,x12.zw)), 0.0);
    m = m*m*m*m;
    vec3 g; g.x = dot(vec2(cos(p.x*6.28318),sin(p.x*6.28318)),x0);
    g.y = dot(vec2(cos(p.y*6.28318),sin(p.y*6.28318)),x12.xy);
    g.z = dot(vec2(cos(p.z*6.28318),sin(p.z*6.28318)),x12.zw);
    return 130.0 * dot(m, g);
}
float fbm(vec2 p, int octaves, float lacunarity, float gain) {
    float v = 0.0, amp = 0.5;
    for (int i = 0; i < 8; i++) {
        if (i >= octaves) break;
        v += amp * perlinNoise(p);
        p *= lacunarity; amp *= gain;
    }
    return v;
}
vec2 voronoi(vec2 p, float jitter) {
    vec2 i = floor(p); vec2 f = fract(p);
    float d1 = 8.0, d2 = 8.0;
    for (int y = -1; y <= 1; y++) for (int x = -1; x <= 1; x++) {
        vec2 n = vec2(x,y);
        // Independent per-axis offsets in [0, 1) keep the feature point inside its cell.
        vec2 o = vec2(hash(i+n), hash(i+n+vec2(57.0,113.0)));
        vec2 g = n + jitter*o;
        float d = length(g - f);
        if (d < d1) { d2 = d1; d1 = d; }
        else if (d < d2) { d2 = d; }
    }
    return vec2(d1, d2);
}
vec2 worley(vec2 p) { return voronoi(p, 1.0); }

// ── Fractal helpers ───────────────────────────────────────────────────────────
float mandelbrotIter(vec2 c, int maxIter) {
    vec2 z = vec2(0.0);
    for (int i = 0; i < 512; i++) {
        if (i >= maxIter) break;
        if (dot(z,z) > 4.0) return float(i) / float(maxIter);
        z = vec2(z.x*z.x - z.y*z.y + c.x, 2.0*z.x*z.y + c.y);
    }
    return 0.0;
}
float juliaIter(vec2 z, vec2 c, int maxIter) {
    for (int i = 0; i < 512; i++) {
        if (i >= maxIter) break;
        if (dot(z,z) > 4.0) return float(i) / float(maxIter);
        z = vec2(z.x*z.x - z.y*z.y + c.x, 2.0*z.x*z.y + c.y);
    }
    return 0.0;
}
"#;
737
738// ── Tests ─────────────────────────────────────────────────────────────────────
739
#[cfg(test)]
mod tests {
    use super::*;
    use crate::render::shader_graph::ShaderGraph;
    use crate::render::shader_graph::nodes::NodeType;

    /// Builds a deliberately incomplete graph: the UV coordinate feeds a `Sin`
    /// node, but nothing is connected to the output node.
    fn simple_graph() -> ShaderGraph {
        let mut g = ShaderGraph::new("test");
        let uv  = g.add_node(NodeType::UvCoord);
        let out = g.add_node(NodeType::OutputColor);
        g.set_output(out);
        let sin = g.add_node(NodeType::Sin);
        // The connection result is ignored: this graph only needs to exercise
        // the compiler, not to be well-formed.
        let _ = g.connect(uv, 0, sin, 0);
        g
    }

    #[test]
    fn test_compile_simple_graph() {
        let g = simple_graph();
        // An incomplete graph may legitimately fail to compile, but it must not
        // panic; when it does compile, the result must be a GLSL 330 shader.
        if let Ok(shader) = g.compile() {
            assert!(!shader.fragment_source.is_empty());
            assert!(shader.fragment_source.contains("#version 330"));
        }
    }

    #[test]
    fn test_uniform_decl() {
        let u = UniformDecl {
            name:      "uTime".to_string(),
            glsl_type: "float".to_string(),
            default:   "0.0".to_string(),
        };
        assert_eq!(u.name, "uTime");
        assert_eq!(u.glsl_type, "float");
        assert_eq!(u.default, "0.0");
    }

    #[test]
    fn test_uniform_names_preserve_order() {
        let shader = CompiledShader {
            fragment_source: String::new(),
            vertex_source:   PASSTHROUGH_VERTEX.to_string(),
            uniforms: vec![
                UniformDecl {
                    name:      "uTime".to_string(),
                    glsl_type: "float".to_string(),
                    default:   "0.0".to_string(),
                },
                UniformDecl {
                    name:      "uResolution".to_string(),
                    glsl_type: "vec2".to_string(),
                    default:   "vec2(1.0)".to_string(),
                },
            ],
            render_targets: Vec::new(),
        };
        assert_eq!(shader.uniform_names(), vec!["uTime", "uResolution"]);
    }

    #[test]
    fn test_passthrough_vertex_has_version() {
        assert!(PASSTHROUGH_VERTEX.contains("#version 330"));
    }

    #[test]
    fn test_helpers_contain_hsv() {
        assert!(SHADER_HELPERS.contains("hsv2rgb"));
        assert!(SHADER_HELPERS.contains("perlinNoise"));
        assert!(SHADER_HELPERS.contains("mandelbrotIter"));
    }
}