Skip to main content

proof_engine/render/shader_graph/
nodes.rs

//! Shader graph node types and socket definitions.
//!
//! Each `NodeType` maps to a GLSL expression snippet. The `GraphCompiler`
//! collects these snippets and assembles them into a complete shader.
5
6use super::NodeId;
7use std::collections::HashMap;
8
// ── SocketType ────────────────────────────────────────────────────────────────

/// GLSL data type flowing through a socket.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SocketType {
    Float,
    Vec2,
    Vec3,
    Vec4,
    Int,
    Bool,
    /// Opaque texture handle. It is neither a scalar nor a vector, so it is
    /// only compatible with another `Sampler2D` socket.
    Sampler2D,
    /// Any scalar or vector — resolved at compile time.
    Any,
}

impl SocketType {
    /// GLSL type keyword for this socket type.
    ///
    /// `Any` has no concrete type of its own here and falls back to `"float"`.
    pub fn glsl_type(self) -> &'static str {
        match self {
            SocketType::Float     => "float",
            SocketType::Vec2      => "vec2",
            SocketType::Vec3      => "vec3",
            SocketType::Vec4      => "vec4",
            SocketType::Int       => "int",
            SocketType::Bool      => "bool",
            SocketType::Sampler2D => "sampler2D",
            SocketType::Any       => "float",
        }
    }

    /// GLSL literal used for a socket of this type when nothing else is supplied.
    ///
    /// Vectors default to zero, except `Vec4`, whose last component is `1.0`.
    /// `Sampler2D` has no meaningful literal and yields a GLSL comment placeholder.
    pub fn default_value(self) -> &'static str {
        match self {
            SocketType::Float     => "0.0",
            SocketType::Vec2      => "vec2(0.0)",
            SocketType::Vec3      => "vec3(0.0)",
            SocketType::Vec4      => "vec4(0.0, 0.0, 0.0, 1.0)",
            SocketType::Int       => "0",
            SocketType::Bool      => "false",
            SocketType::Sampler2D => "/* sampler */",
            SocketType::Any       => "0.0",
        }
    }

    /// Returns `true` if a socket of type `self` may be connected to one of type `other`.
    ///
    /// The relation is symmetric. Allowed pairs:
    /// - identical types (including `Sampler2D` with `Sampler2D`);
    /// - `Any` with any non-sampler type;
    /// - `Vec3` ↔ `Vec4` and `Float` ↔ `Vec2` (NOTE(review): presumably converted
    ///   by the graph compiler — confirm).
    pub fn is_compatible_with(self, other: SocketType) -> bool {
        if self == other {
            return true;
        }
        // Samplers are opaque handles, not scalar/vector values. Letting `Any`
        // accept them would let a texture flow into arithmetic or swizzle nodes
        // and produce invalid GLSL.
        if self == SocketType::Sampler2D || other == SocketType::Sampler2D {
            return false;
        }
        if self == SocketType::Any || other == SocketType::Any {
            return true;
        }
        // Implicit conversions between distinct concrete types.
        matches!(
            (self, other),
            (SocketType::Vec3, SocketType::Vec4)
                | (SocketType::Vec4, SocketType::Vec3)
                | (SocketType::Float, SocketType::Vec2)
                | (SocketType::Vec2, SocketType::Float)
        )
    }
}
62
63// ── NodeSocket ────────────────────────────────────────────────────────────────
64
65/// Definition of one input or output socket on a node.
66#[derive(Debug, Clone)]
67pub struct NodeSocket {
68    pub name:     String,
69    pub socket_type: SocketType,
70    pub required: bool,
71    /// Default value string (used when disconnected and no constant set).
72    pub default:  String,
73}
74
75impl NodeSocket {
76    pub fn required(name: &str, t: SocketType) -> Self {
77        Self { name: name.to_string(), socket_type: t, required: true, default: t.default_value().to_string() }
78    }
79    pub fn optional(name: &str, t: SocketType, default: &str) -> Self {
80        Self { name: name.to_string(), socket_type: t, required: false, default: default.to_string() }
81    }
82}
83
// ── NodeType ──────────────────────────────────────────────────────────────────

/// All supported shader node types.
///
/// Each variant is one kind of node in the shader graph. Most variants carry no
/// data. The payload-carrying ones hold:
/// - constant values (`Const*`);
/// - a uniform name and type (`Uniform`);
/// - a swizzle mask (`Swizzle`);
/// - a render-target name (`OutputTarget`).
///
/// [`NodeType::label`] returns a node's display name. [`NodeType::input_sockets`]
/// returns its input socket definitions.
///
/// Only `PartialEq` (not `Eq`) can be derived because the constant variants
/// carry `f32` payloads.
#[derive(Debug, Clone, PartialEq)]
pub enum NodeType {
    // ── Inputs ────────────────────────────────────────────────────────────────
    /// UV texture coordinates (vec2).
    UvCoord,
    /// World-space position of the fragment (vec3).
    WorldPos,
    /// Camera position (vec3).
    CameraPos,
    /// Scene time (float).
    Time,
    /// Screen resolution (vec2).
    Resolution,
    /// Constant float.
    ConstFloat(f32),
    /// Constant vec2.
    ConstVec2(f32, f32),
    /// Constant vec3 (color or vector).
    ConstVec3(f32, f32, f32),
    /// Constant vec4.
    ConstVec4(f32, f32, f32, f32),
    /// A named uniform parameter. The `String` is the uniform's name, which is
    /// also returned by `label`. The `SocketType` is presumably the uniform's
    /// GLSL type (NOTE(review): confirm against the compiler).
    Uniform(String, SocketType),
    /// Sample a texture.
    TextureSample,
    /// Vertex color passed from vertex shader.
    VertexColor,
    /// Fragment screen-space coordinates (vec2).
    ScreenCoord,

    // ── Math ──────────────────────────────────────────────────────────────────
    /// A + B
    Add,
    /// A - B
    Subtract,
    /// A * B
    Multiply,
    /// A / B (safe: returns 0 when B=0).
    /// NOTE(review): the zero guard is not implemented in this file; confirm
    /// that the code generator emits it.
    Divide,
    /// Power: A^B
    Power,
    /// sqrt(A)
    Sqrt,
    /// abs(A)
    Abs,
    /// sign(A)
    Sign,
    /// floor(A)
    Floor,
    /// ceil(A)
    Ceil,
    /// fract(A)
    Fract,
    /// min(A, B)
    Min,
    /// max(A, B)
    Max,
    /// clamp(A, min, max)
    Clamp,
    /// mix(A, B, T)
    Mix,
    /// smoothstep(edge0, edge1, x)
    Smoothstep,
    /// step(edge, x)
    Step,
    /// dot(A, B)
    Dot,
    /// cross(A, B)
    Cross,
    /// normalize(A)
    Normalize,
    /// length(A)
    Length,
    /// distance(A, B)
    Distance,
    /// reflect(I, N)
    Reflect,
    /// refract(I, N, eta)
    Refract,
    /// mod(A, B)
    Mod,
    /// sin(A)
    Sin,
    /// cos(A)
    Cos,
    /// tan(A)
    Tan,
    /// atan(A).
    /// NOTE(review): this was documented as also supporting atan(Y, X), but
    /// `input_sockets` declares only a single `In` socket for this node, so
    /// the two-argument form is not reachable through sockets.
    Atan,
    /// exp(A)
    Exp,
    /// log(A)
    Log,
    /// log2(A)
    Log2,
    /// Remap: rescales A from [in_min,in_max] to [out_min,out_max]
    Remap,
    /// 1.0 - A
    OneMinus,
    /// Saturate: clamp to [0,1]
    Saturate,
    /// Negate: -A
    Negate,
    /// Reciprocal: 1.0 / A
    Reciprocal,

    // ── Vector ────────────────────────────────────────────────────────────────
    /// Combine (float, float) → vec2
    CombineVec2,
    /// Combine (float, float, float) → vec3
    CombineVec3,
    /// Combine (vec3, float) → vec4
    CombineVec4,
    /// Split vec2 → (x, y)
    SplitVec2,
    /// Split vec3 → (x, y, z)
    SplitVec3,
    /// Split vec4 → (x, y, z, w)
    SplitVec4,
    /// Swizzle: extract named components (e.g. ".xyz", ".yyy").
    /// The mask string is also the node's label.
    Swizzle(String),
    /// Squared length of the input.
    /// NOTE(review): originally documented as vec3-only, but the declared
    /// `In` socket is `SocketType::Any`.
    LengthSquared,
    /// Rotate a 2D vector by angle (radians)
    RotateVec2,

    // ── Color ─────────────────────────────────────────────────────────────────
    /// HSV → RGB conversion
    HsvToRgb,
    /// RGB → HSV conversion
    RgbToHsv,
    /// Luminance (grayscale value)
    Luminance,
    /// Saturation adjustment
    Saturation,
    /// Hue rotation
    HueRotate,
    /// Color burn blend
    ColorBurn,
    /// Color dodge blend
    ColorDodge,
    /// Screen blend: 1 - (1-A)(1-B)
    ScreenBlend,
    /// Overlay blend
    OverlayBlend,
    /// Hard light blend
    HardLight,
    /// Soft light blend
    SoftLight,
    /// Difference blend
    Difference,
    /// Gamma correction
    GammaCorrect,
    /// Linear to sRGB
    LinearToSrgb,
    /// sRGB to linear
    SrgbToLinear,

    // ── Noise and Patterns ────────────────────────────────────────────────────
    /// Value noise
    ValueNoise,
    /// Gradient (Perlin) noise
    PerlinNoise,
    /// Simplex noise
    SimplexNoise,
    /// Fractal Brownian Motion (fBm), controlled by `Octaves`, `Lacunarity`
    /// and `Gain` inputs.
    Fbm,
    /// Voronoi / cellular noise
    Voronoi,
    /// Worley noise (F1, F2, F1-F2)
    Worley,
    /// Checkerboard pattern
    Checkerboard,
    /// Polka dots pattern
    PolkaDots,
    /// Sine wave pattern
    SineWave,
    /// Square wave pattern
    SquareWave,
    /// Triangle wave pattern
    TriangleWave,
    /// Sawtooth wave pattern
    SawtoothWave,
    /// Grid pattern
    Grid,
    /// Radial gradient from center
    RadialGradient,
    /// Linear gradient along an axis
    LinearGradient,
    /// Spiral pattern
    Spiral,
    /// Concentric rings
    Rings,
    /// Star burst pattern
    StarBurst,
    /// Hexagonal tiling
    HexTile,

    // ── Effects ───────────────────────────────────────────────────────────────
    /// Chromatic aberration (RGB split)
    ChromaticAberration,
    /// Screen-space edge detection
    EdgeDetect,
    /// Pixelation
    Pixelate,
    /// Barrel distortion
    BarrelDistort,
    /// Fish-eye distortion
    FishEye,
    /// Vignette darkening
    Vignette,
    /// Film grain
    FilmGrain,
    /// CRT scanlines
    Scanlines,
    /// Heat haze / refraction
    HeatHaze,
    /// Glitch offset
    GlitchOffset,
    /// Screen shake (UV offset)
    ScreenShake,
    /// Blur (box blur via sampling)
    BoxBlur,
    /// Sharpen filter
    Sharpen,
    /// Emboss filter
    Emboss,
    /// Invert colors
    Invert,
    /// Posterize
    Posterize,
    /// Duotone (shadows one color, highlights another)
    Duotone,
    /// Outline (find edge and colorize)
    Outline,

    // ── SDF (Signed Distance Fields) ─────────────────────────────────────────
    /// SDF Circle
    SdfCircle,
    /// SDF Box
    SdfBox,
    /// SDF Line segment
    SdfLine,
    /// SDF Triangle
    SdfTriangle,
    /// SDF Ring/Annulus
    SdfRing,
    /// SDF Star
    SdfStar,
    /// SDF smooth union
    SdfSmoothUnion,
    /// SDF smooth subtraction
    SdfSmoothSubtract,
    /// SDF smooth intersection
    SdfSmoothIntersect,
    /// SDF → alpha (step at edge)
    SdfToAlpha,
    /// SDF → soft alpha (smoothstep at edge)
    SdfToSoftAlpha,

    // ── Attractor / Math-Driven ───────────────────────────────────────────────
    /// Evaluate a Lorenz attractor at UV position
    LorenzAttractor,
    /// Mandelbrot set iteration count
    Mandelbrot,
    /// Julia set
    Julia,
    /// Burning Ship fractal
    BurningShip,
    /// Newton fractal
    NewtonFractal,
    /// Lyapunov exponent visualization
    LyapunovViz,

    // ── Logic / Conditional ────────────────────────────────────────────────────
    /// if A > threshold, output B else C
    IfGreater,
    /// if A < threshold, output B else C
    IfLess,
    /// Conditional blend (threshold with smooth transition)
    ConditionalBlend,
    /// Boolean AND
    BoolAnd,
    /// Boolean OR
    BoolOr,
    /// Boolean NOT
    BoolNot,

    // ── Output ────────────────────────────────────────────────────────────────
    /// Final fragment color output (must be vec4).
    OutputColor,
    /// Secondary output to a named render target. The name is also the node's label.
    OutputTarget(String),
    /// Output to bloom buffer simultaneously.
    OutputWithBloom,
}
383
384impl NodeType {
385    pub fn label(&self) -> &str {
386        match self {
387            NodeType::UvCoord            => "UV Coord",
388            NodeType::WorldPos           => "World Pos",
389            NodeType::CameraPos          => "Camera Pos",
390            NodeType::Time               => "Time",
391            NodeType::Resolution         => "Resolution",
392            NodeType::ConstFloat(_)      => "Float",
393            NodeType::ConstVec2(_, _)    => "Vec2",
394            NodeType::ConstVec3(..)      => "Vec3",
395            NodeType::ConstVec4(..)      => "Vec4",
396            NodeType::Uniform(n, _)      => n.as_str(),
397            NodeType::TextureSample      => "Texture Sample",
398            NodeType::VertexColor        => "Vertex Color",
399            NodeType::ScreenCoord        => "Screen Coord",
400            NodeType::Add                => "Add",
401            NodeType::Subtract           => "Subtract",
402            NodeType::Multiply           => "Multiply",
403            NodeType::Divide             => "Divide",
404            NodeType::Power              => "Power",
405            NodeType::Sqrt               => "Sqrt",
406            NodeType::Abs                => "Abs",
407            NodeType::Sign               => "Sign",
408            NodeType::Floor              => "Floor",
409            NodeType::Ceil               => "Ceil",
410            NodeType::Fract              => "Fract",
411            NodeType::Min                => "Min",
412            NodeType::Max                => "Max",
413            NodeType::Clamp              => "Clamp",
414            NodeType::Mix                => "Mix",
415            NodeType::Smoothstep         => "Smoothstep",
416            NodeType::Step               => "Step",
417            NodeType::Dot                => "Dot",
418            NodeType::Cross              => "Cross",
419            NodeType::Normalize          => "Normalize",
420            NodeType::Length             => "Length",
421            NodeType::Distance           => "Distance",
422            NodeType::Reflect            => "Reflect",
423            NodeType::Refract            => "Refract",
424            NodeType::Mod                => "Mod",
425            NodeType::Sin                => "Sin",
426            NodeType::Cos                => "Cos",
427            NodeType::Tan                => "Tan",
428            NodeType::Atan               => "Atan",
429            NodeType::Exp                => "Exp",
430            NodeType::Log                => "Log",
431            NodeType::Log2               => "Log2",
432            NodeType::Remap              => "Remap",
433            NodeType::OneMinus           => "One Minus",
434            NodeType::Saturate           => "Saturate",
435            NodeType::Negate             => "Negate",
436            NodeType::Reciprocal         => "Reciprocal",
437            NodeType::CombineVec2        => "Combine Vec2",
438            NodeType::CombineVec3        => "Combine Vec3",
439            NodeType::CombineVec4        => "Combine Vec4",
440            NodeType::SplitVec2          => "Split Vec2",
441            NodeType::SplitVec3          => "Split Vec3",
442            NodeType::SplitVec4          => "Split Vec4",
443            NodeType::Swizzle(s)         => s.as_str(),
444            NodeType::LengthSquared      => "Length²",
445            NodeType::RotateVec2         => "Rotate Vec2",
446            NodeType::HsvToRgb           => "HSV → RGB",
447            NodeType::RgbToHsv           => "RGB → HSV",
448            NodeType::Luminance          => "Luminance",
449            NodeType::Saturation         => "Saturation",
450            NodeType::HueRotate          => "Hue Rotate",
451            NodeType::ColorBurn          => "Color Burn",
452            NodeType::ColorDodge         => "Color Dodge",
453            NodeType::ScreenBlend        => "Screen",
454            NodeType::OverlayBlend       => "Overlay",
455            NodeType::HardLight          => "Hard Light",
456            NodeType::SoftLight          => "Soft Light",
457            NodeType::Difference         => "Difference",
458            NodeType::GammaCorrect       => "Gamma",
459            NodeType::LinearToSrgb       => "Linear→sRGB",
460            NodeType::SrgbToLinear       => "sRGB→Linear",
461            NodeType::ValueNoise         => "Value Noise",
462            NodeType::PerlinNoise        => "Perlin Noise",
463            NodeType::SimplexNoise       => "Simplex Noise",
464            NodeType::Fbm                => "fBm",
465            NodeType::Voronoi            => "Voronoi",
466            NodeType::Worley             => "Worley",
467            NodeType::Checkerboard       => "Checkerboard",
468            NodeType::PolkaDots          => "Polka Dots",
469            NodeType::SineWave           => "Sine Wave",
470            NodeType::SquareWave         => "Square Wave",
471            NodeType::TriangleWave       => "Triangle Wave",
472            NodeType::SawtoothWave       => "Sawtooth Wave",
473            NodeType::Grid               => "Grid",
474            NodeType::RadialGradient     => "Radial Gradient",
475            NodeType::LinearGradient     => "Linear Gradient",
476            NodeType::Spiral             => "Spiral",
477            NodeType::Rings              => "Rings",
478            NodeType::StarBurst          => "Star Burst",
479            NodeType::HexTile            => "Hex Tile",
480            NodeType::ChromaticAberration => "Chromatic Ab.",
481            NodeType::EdgeDetect         => "Edge Detect",
482            NodeType::Pixelate           => "Pixelate",
483            NodeType::BarrelDistort      => "Barrel Distort",
484            NodeType::FishEye            => "Fish Eye",
485            NodeType::Vignette           => "Vignette",
486            NodeType::FilmGrain          => "Film Grain",
487            NodeType::Scanlines          => "Scanlines",
488            NodeType::HeatHaze           => "Heat Haze",
489            NodeType::GlitchOffset       => "Glitch",
490            NodeType::ScreenShake        => "Screen Shake",
491            NodeType::BoxBlur            => "Box Blur",
492            NodeType::Sharpen            => "Sharpen",
493            NodeType::Emboss             => "Emboss",
494            NodeType::Invert             => "Invert",
495            NodeType::Posterize          => "Posterize",
496            NodeType::Duotone            => "Duotone",
497            NodeType::Outline            => "Outline",
498            NodeType::SdfCircle          => "SDF Circle",
499            NodeType::SdfBox             => "SDF Box",
500            NodeType::SdfLine            => "SDF Line",
501            NodeType::SdfTriangle        => "SDF Triangle",
502            NodeType::SdfRing            => "SDF Ring",
503            NodeType::SdfStar            => "SDF Star",
504            NodeType::SdfSmoothUnion     => "SDF Union",
505            NodeType::SdfSmoothSubtract  => "SDF Subtract",
506            NodeType::SdfSmoothIntersect => "SDF Intersect",
507            NodeType::SdfToAlpha         => "SDF Alpha",
508            NodeType::SdfToSoftAlpha     => "SDF Soft Alpha",
509            NodeType::LorenzAttractor    => "Lorenz",
510            NodeType::Mandelbrot         => "Mandelbrot",
511            NodeType::Julia              => "Julia",
512            NodeType::BurningShip        => "Burning Ship",
513            NodeType::NewtonFractal      => "Newton",
514            NodeType::LyapunovViz        => "Lyapunov",
515            NodeType::IfGreater          => "If Greater",
516            NodeType::IfLess             => "If Less",
517            NodeType::ConditionalBlend   => "Cond. Blend",
518            NodeType::BoolAnd            => "AND",
519            NodeType::BoolOr             => "OR",
520            NodeType::BoolNot            => "NOT",
521            NodeType::OutputColor        => "Output Color",
522            NodeType::OutputTarget(n)    => n.as_str(),
523            NodeType::OutputWithBloom    => "Output+Bloom",
524        }
525    }
526
527    /// Input socket definitions.
528    pub fn input_sockets(&self) -> Vec<NodeSocket> {
529        match self {
530            NodeType::Add | NodeType::Subtract | NodeType::Multiply |
531            NodeType::Divide | NodeType::Power | NodeType::Mod |
532            NodeType::Min | NodeType::Max | NodeType::Dot | NodeType::Distance => vec![
533                NodeSocket::required("A", SocketType::Any),
534                NodeSocket::required("B", SocketType::Any),
535            ],
536            NodeType::Mix => vec![
537                NodeSocket::required("A",  SocketType::Any),
538                NodeSocket::required("B",  SocketType::Any),
539                NodeSocket::required("T",  SocketType::Float),
540            ],
541            NodeType::Smoothstep => vec![
542                NodeSocket::optional("Edge0", SocketType::Float, "0.0"),
543                NodeSocket::optional("Edge1", SocketType::Float, "1.0"),
544                NodeSocket::required("X",     SocketType::Any),
545            ],
546            NodeType::Step => vec![
547                NodeSocket::optional("Edge", SocketType::Float, "0.5"),
548                NodeSocket::required("X",    SocketType::Any),
549            ],
550            NodeType::Clamp => vec![
551                NodeSocket::required("X",   SocketType::Any),
552                NodeSocket::optional("Min", SocketType::Float, "0.0"),
553                NodeSocket::optional("Max", SocketType::Float, "1.0"),
554            ],
555            NodeType::Remap => vec![
556                NodeSocket::required("X",      SocketType::Any),
557                NodeSocket::optional("InMin",  SocketType::Float, "0.0"),
558                NodeSocket::optional("InMax",  SocketType::Float, "1.0"),
559                NodeSocket::optional("OutMin", SocketType::Float, "0.0"),
560                NodeSocket::optional("OutMax", SocketType::Float, "1.0"),
561            ],
562            NodeType::Sqrt | NodeType::Abs | NodeType::Sign |
563            NodeType::Floor | NodeType::Ceil | NodeType::Fract |
564            NodeType::Normalize | NodeType::Length | NodeType::LengthSquared |
565            NodeType::OneMinus | NodeType::Saturate | NodeType::Negate |
566            NodeType::Reciprocal | NodeType::Exp | NodeType::Log |
567            NodeType::Log2 | NodeType::Sin | NodeType::Cos |
568            NodeType::Tan | NodeType::Atan | NodeType::HsvToRgb |
569            NodeType::RgbToHsv | NodeType::Luminance | NodeType::Invert |
570            NodeType::LinearToSrgb | NodeType::SrgbToLinear |
571            NodeType::BoolNot => vec![
572                NodeSocket::required("In", SocketType::Any),
573            ],
574            NodeType::Reflect | NodeType::Cross | NodeType::BoolAnd | NodeType::BoolOr => vec![
575                NodeSocket::required("A", SocketType::Any),
576                NodeSocket::required("B", SocketType::Any),
577            ],
578            NodeType::Refract => vec![
579                NodeSocket::required("I",   SocketType::Vec3),
580                NodeSocket::required("N",   SocketType::Vec3),
581                NodeSocket::optional("Eta", SocketType::Float, "1.5"),
582            ],
583            NodeType::CombineVec2 => vec![
584                NodeSocket::required("X", SocketType::Float),
585                NodeSocket::required("Y", SocketType::Float),
586            ],
587            NodeType::CombineVec3 => vec![
588                NodeSocket::required("X", SocketType::Float),
589                NodeSocket::required("Y", SocketType::Float),
590                NodeSocket::required("Z", SocketType::Float),
591            ],
592            NodeType::CombineVec4 => vec![
593                NodeSocket::required("RGB", SocketType::Vec3),
594                NodeSocket::required("A",   SocketType::Float),
595            ],
596            NodeType::SplitVec2 | NodeType::SplitVec3 | NodeType::SplitVec4 |
597            NodeType::Swizzle(_) => vec![
598                NodeSocket::required("In", SocketType::Any),
599            ],
600            NodeType::RotateVec2 => vec![
601                NodeSocket::required("UV",    SocketType::Vec2),
602                NodeSocket::optional("Angle", SocketType::Float, "0.0"),
603                NodeSocket::optional("Center",SocketType::Vec2, "vec2(0.5)"),
604            ],
605            NodeType::Saturation => vec![
606                NodeSocket::required("Color", SocketType::Vec3),
607                NodeSocket::optional("Sat",   SocketType::Float, "1.0"),
608            ],
609            NodeType::HueRotate => vec![
610                NodeSocket::required("Color",   SocketType::Vec3),
611                NodeSocket::optional("Degrees", SocketType::Float, "0.0"),
612            ],
613            NodeType::GammaCorrect => vec![
614                NodeSocket::required("Color", SocketType::Vec3),
615                NodeSocket::optional("Gamma", SocketType::Float, "2.2"),
616            ],
617            NodeType::ColorBurn | NodeType::ColorDodge | NodeType::ScreenBlend |
618            NodeType::OverlayBlend | NodeType::HardLight | NodeType::SoftLight |
619            NodeType::Difference => vec![
620                NodeSocket::required("A", SocketType::Vec3),
621                NodeSocket::required("B", SocketType::Vec3),
622            ],
623            NodeType::ValueNoise | NodeType::PerlinNoise | NodeType::SimplexNoise => vec![
624                NodeSocket::required("UV",    SocketType::Vec2),
625                NodeSocket::optional("Scale", SocketType::Float, "1.0"),
626                NodeSocket::optional("Seed",  SocketType::Float, "0.0"),
627            ],
628            NodeType::Fbm => vec![
629                NodeSocket::required("UV",       SocketType::Vec2),
630                NodeSocket::optional("Octaves",  SocketType::Float, "4.0"),
631                NodeSocket::optional("Lacunarity",SocketType::Float,"2.0"),
632                NodeSocket::optional("Gain",     SocketType::Float, "0.5"),
633            ],
634            NodeType::Voronoi | NodeType::Worley => vec![
635                NodeSocket::required("UV",    SocketType::Vec2),
636                NodeSocket::optional("Scale", SocketType::Float, "1.0"),
637                NodeSocket::optional("Jitter",SocketType::Float, "1.0"),
638            ],
639            NodeType::Checkerboard | NodeType::PolkaDots | NodeType::Grid => vec![
640                NodeSocket::required("UV",    SocketType::Vec2),
641                NodeSocket::optional("Scale", SocketType::Float, "10.0"),
642            ],
643            NodeType::SineWave | NodeType::SquareWave | NodeType::TriangleWave |
644            NodeType::SawtoothWave => vec![
645                NodeSocket::required("UV",        SocketType::Any),
646                NodeSocket::optional("Frequency", SocketType::Float, "1.0"),
647                NodeSocket::optional("Amplitude", SocketType::Float, "1.0"),
648                NodeSocket::optional("Phase",     SocketType::Float, "0.0"),
649            ],
650            NodeType::RadialGradient => vec![
651                NodeSocket::required("UV",     SocketType::Vec2),
652                NodeSocket::optional("Center", SocketType::Vec2, "vec2(0.5)"),
653                NodeSocket::optional("Radius", SocketType::Float, "0.5"),
654            ],
655            NodeType::LinearGradient => vec![
656                NodeSocket::required("UV",     SocketType::Vec2),
657                NodeSocket::optional("Angle",  SocketType::Float, "0.0"),
658            ],
659            NodeType::Spiral => vec![
660                NodeSocket::required("UV",     SocketType::Vec2),
661                NodeSocket::optional("Arms",   SocketType::Float, "3.0"),
662                NodeSocket::optional("Speed",  SocketType::Float, "1.0"),
663                NodeSocket::optional("Time",   SocketType::Float, "0.0"),
664            ],
665            NodeType::Rings => vec![
666                NodeSocket::required("UV",     SocketType::Vec2),
667                NodeSocket::optional("Count",  SocketType::Float, "5.0"),
668                NodeSocket::optional("Width",  SocketType::Float, "0.5"),
669            ],
670            NodeType::StarBurst => vec![
671                NodeSocket::required("UV",    SocketType::Vec2),
672                NodeSocket::optional("Arms",  SocketType::Float, "8.0"),
673                NodeSocket::optional("Sharp", SocketType::Float, "0.5"),
674            ],
675            NodeType::HexTile => vec![
676                NodeSocket::required("UV",    SocketType::Vec2),
677                NodeSocket::optional("Scale", SocketType::Float, "10.0"),
678            ],
679            NodeType::Vignette => vec![
680                NodeSocket::required("UV",       SocketType::Vec2),
681                NodeSocket::optional("Strength", SocketType::Float, "0.5"),
682                NodeSocket::optional("Feather",  SocketType::Float, "0.5"),
683            ],
684            NodeType::FilmGrain => vec![
685                NodeSocket::required("UV",       SocketType::Vec2),
686                NodeSocket::optional("Time",     SocketType::Float, "0.0"),
687                NodeSocket::optional("Strength", SocketType::Float, "0.05"),
688            ],
689            NodeType::Scanlines => vec![
690                NodeSocket::required("UV",        SocketType::Vec2),
691                NodeSocket::optional("Intensity", SocketType::Float, "0.2"),
692                NodeSocket::optional("Count",     SocketType::Float, "300.0"),
693            ],
694            NodeType::ChromaticAberration => vec![
695                NodeSocket::required("UV",       SocketType::Vec2),
696                NodeSocket::optional("Strength", SocketType::Float, "0.005"),
697            ],
698            NodeType::EdgeDetect | NodeType::Sharpen | NodeType::Emboss => vec![
699                NodeSocket::required("Tex",       SocketType::Sampler2D),
700                NodeSocket::required("UV",        SocketType::Vec2),
701                NodeSocket::optional("Strength",  SocketType::Float, "1.0"),
702                NodeSocket::optional("TexelSize", SocketType::Vec2, "vec2(0.001)"),
703            ],
704            NodeType::Pixelate => vec![
705                NodeSocket::required("UV",        SocketType::Vec2),
706                NodeSocket::optional("Resolution",SocketType::Float,"64.0"),
707            ],
708            NodeType::BarrelDistort | NodeType::FishEye => vec![
709                NodeSocket::required("UV",        SocketType::Vec2),
710                NodeSocket::optional("Strength",  SocketType::Float, "0.3"),
711            ],
712            NodeType::HeatHaze => vec![
713                NodeSocket::required("UV",        SocketType::Vec2),
714                NodeSocket::optional("Time",      SocketType::Float, "0.0"),
715                NodeSocket::optional("Strength",  SocketType::Float, "0.02"),
716                NodeSocket::optional("Speed",     SocketType::Float, "1.0"),
717            ],
718            NodeType::GlitchOffset => vec![
719                NodeSocket::required("UV",        SocketType::Vec2),
720                NodeSocket::optional("Time",      SocketType::Float, "0.0"),
721                NodeSocket::optional("Intensity", SocketType::Float, "0.5"),
722                NodeSocket::optional("Seed",      SocketType::Float, "0.0"),
723            ],
724            NodeType::BoxBlur => vec![
725                NodeSocket::required("Tex",      SocketType::Sampler2D),
726                NodeSocket::required("UV",       SocketType::Vec2),
727                NodeSocket::optional("Radius",   SocketType::Float,"1.0"),
728                NodeSocket::optional("TexelSize",SocketType::Vec2, "vec2(0.001)"),
729            ],
730            NodeType::Posterize => vec![
731                NodeSocket::required("Color",  SocketType::Vec3),
732                NodeSocket::optional("Steps",  SocketType::Float, "4.0"),
733            ],
734            NodeType::Duotone => vec![
735                NodeSocket::required("Color",     SocketType::Vec3),
736                NodeSocket::optional("Shadow",    SocketType::Vec3, "vec3(0.0,0.0,0.3)"),
737                NodeSocket::optional("Highlight", SocketType::Vec3, "vec3(1.0,0.8,0.2)"),
738            ],
739            NodeType::Outline => vec![
740                NodeSocket::required("SDF",      SocketType::Float),
741                NodeSocket::optional("Color",    SocketType::Vec3, "vec3(1.0)"),
742                NodeSocket::optional("Thickness",SocketType::Float,"0.02"),
743            ],
744            NodeType::SdfCircle => vec![
745                NodeSocket::required("UV",     SocketType::Vec2),
746                NodeSocket::optional("Center", SocketType::Vec2, "vec2(0.5)"),
747                NodeSocket::optional("Radius", SocketType::Float, "0.3"),
748            ],
749            NodeType::SdfBox => vec![
750                NodeSocket::required("UV",     SocketType::Vec2),
751                NodeSocket::optional("Center", SocketType::Vec2, "vec2(0.5)"),
752                NodeSocket::optional("Size",   SocketType::Vec2, "vec2(0.3)"),
753                NodeSocket::optional("Corner", SocketType::Float, "0.0"),
754            ],
755            NodeType::SdfLine => vec![
756                NodeSocket::required("UV", SocketType::Vec2),
757                NodeSocket::required("A",  SocketType::Vec2),
758                NodeSocket::required("B",  SocketType::Vec2),
759            ],
760            NodeType::SdfTriangle => vec![
761                NodeSocket::required("UV", SocketType::Vec2),
762                NodeSocket::required("A",  SocketType::Vec2),
763                NodeSocket::required("B",  SocketType::Vec2),
764                NodeSocket::required("C",  SocketType::Vec2),
765            ],
766            NodeType::SdfRing => vec![
767                NodeSocket::required("UV",         SocketType::Vec2),
768                NodeSocket::optional("Center",     SocketType::Vec2,  "vec2(0.5)"),
769                NodeSocket::optional("OuterRadius",SocketType::Float, "0.4"),
770                NodeSocket::optional("InnerRadius",SocketType::Float, "0.3"),
771            ],
772            NodeType::SdfStar => vec![
773                NodeSocket::required("UV",     SocketType::Vec2),
774                NodeSocket::optional("Points", SocketType::Float, "5.0"),
775                NodeSocket::optional("Inner",  SocketType::Float, "0.2"),
776                NodeSocket::optional("Outer",  SocketType::Float, "0.4"),
777            ],
778            NodeType::SdfSmoothUnion | NodeType::SdfSmoothSubtract | NodeType::SdfSmoothIntersect => vec![
779                NodeSocket::required("A", SocketType::Float),
780                NodeSocket::required("B", SocketType::Float),
781                NodeSocket::optional("K", SocketType::Float, "0.1"),
782            ],
783            NodeType::SdfToAlpha | NodeType::SdfToSoftAlpha => vec![
784                NodeSocket::required("SDF",       SocketType::Float),
785                NodeSocket::optional("Threshold", SocketType::Float, "0.0"),
786                NodeSocket::optional("Feather",   SocketType::Float, "0.01"),
787            ],
788            NodeType::Mandelbrot | NodeType::Julia | NodeType::BurningShip |
789            NodeType::NewtonFractal => vec![
790                NodeSocket::required("UV",       SocketType::Vec2),
791                NodeSocket::optional("MaxIter",  SocketType::Float, "100.0"),
792                NodeSocket::optional("Zoom",     SocketType::Float, "1.0"),
793                NodeSocket::optional("Cx",       SocketType::Float, "-0.7"),
794                NodeSocket::optional("Cy",       SocketType::Float, "0.27"),
795            ],
796            NodeType::LorenzAttractor => vec![
797                NodeSocket::required("UV",     SocketType::Vec2),
798                NodeSocket::optional("Time",   SocketType::Float, "0.0"),
799                NodeSocket::optional("Scale",  SocketType::Float, "0.05"),
800            ],
801            NodeType::LyapunovViz => vec![
802                NodeSocket::required("UV",     SocketType::Vec2),
803                NodeSocket::optional("Seq",    SocketType::Float, "0.0"),
804                NodeSocket::optional("Iters",  SocketType::Float, "100.0"),
805            ],
806            NodeType::IfGreater | NodeType::IfLess => vec![
807                NodeSocket::required("A",         SocketType::Any),
808                NodeSocket::optional("Threshold", SocketType::Float, "0.5"),
809                NodeSocket::required("TrueVal",   SocketType::Any),
810                NodeSocket::required("FalseVal",  SocketType::Any),
811            ],
812            NodeType::ConditionalBlend => vec![
813                NodeSocket::required("Condition", SocketType::Float),
814                NodeSocket::required("A",         SocketType::Any),
815                NodeSocket::required("B",         SocketType::Any),
816                NodeSocket::optional("Feather",   SocketType::Float, "0.05"),
817            ],
818            NodeType::TextureSample => vec![
819                NodeSocket::required("Tex", SocketType::Sampler2D),
820                NodeSocket::required("UV",  SocketType::Vec2),
821            ],
822            NodeType::ScreenShake => vec![
823                NodeSocket::required("UV",        SocketType::Vec2),
824                NodeSocket::optional("Strength",  SocketType::Float, "0.0"),
825                NodeSocket::optional("Time",      SocketType::Float, "0.0"),
826            ],
827            NodeType::OutputColor | NodeType::OutputWithBloom => vec![
828                NodeSocket::required("Color", SocketType::Vec4),
829            ],
830            NodeType::OutputTarget(_) => vec![
831                NodeSocket::required("Color", SocketType::Vec4),
832            ],
833            // Input nodes — no inputs
834            _ => vec![],
835        }
836    }
837
838    /// Output socket definitions.
839    pub fn output_sockets(&self) -> Vec<NodeSocket> {
840        match self {
841            NodeType::UvCoord | NodeType::ScreenCoord | NodeType::RotateVec2 => vec![
842                NodeSocket::optional("UV", SocketType::Vec2, "vec2(0.0)"),
843            ],
844            NodeType::WorldPos | NodeType::CameraPos => vec![
845                NodeSocket::optional("Pos", SocketType::Vec3, "vec3(0.0)"),
846            ],
847            NodeType::Time => vec![
848                NodeSocket::optional("T", SocketType::Float, "0.0"),
849            ],
850            NodeType::Resolution => vec![
851                NodeSocket::optional("Res", SocketType::Vec2, "vec2(1.0)"),
852            ],
853            NodeType::ConstFloat(_) => vec![
854                NodeSocket::optional("Value", SocketType::Float, "0.0"),
855            ],
856            NodeType::ConstVec2(_, _) => vec![
857                NodeSocket::optional("Value", SocketType::Vec2, "vec2(0.0)"),
858            ],
859            NodeType::ConstVec3(..) => vec![
860                NodeSocket::optional("Value", SocketType::Vec3, "vec3(0.0)"),
861            ],
862            NodeType::ConstVec4(..) | NodeType::VertexColor => vec![
863                NodeSocket::optional("Value", SocketType::Vec4, "vec4(0.0)"),
864            ],
865            NodeType::Uniform(_, t) => vec![
866                NodeSocket::optional("Value", *t, t.default_value()),
867            ],
868            NodeType::TextureSample => vec![
869                NodeSocket::optional("RGBA", SocketType::Vec4, "vec4(0.0)"),
870                NodeSocket::optional("RGB",  SocketType::Vec3, "vec3(0.0)"),
871                NodeSocket::optional("A",    SocketType::Float, "0.0"),
872            ],
873            NodeType::CombineVec2 => vec![
874                NodeSocket::optional("XY", SocketType::Vec2, "vec2(0.0)"),
875            ],
876            NodeType::CombineVec3 | NodeType::HsvToRgb | NodeType::RgbToHsv |
877            NodeType::WorldPos | NodeType::Normalize | NodeType::Reflect | NodeType::Cross => vec![
878                NodeSocket::optional("Out", SocketType::Vec3, "vec3(0.0)"),
879            ],
880            NodeType::CombineVec4 => vec![
881                NodeSocket::optional("RGBA", SocketType::Vec4, "vec4(0.0)"),
882            ],
883            NodeType::SplitVec2 => vec![
884                NodeSocket::optional("X", SocketType::Float, "0.0"),
885                NodeSocket::optional("Y", SocketType::Float, "0.0"),
886            ],
887            NodeType::SplitVec3 => vec![
888                NodeSocket::optional("X", SocketType::Float, "0.0"),
889                NodeSocket::optional("Y", SocketType::Float, "0.0"),
890                NodeSocket::optional("Z", SocketType::Float, "0.0"),
891            ],
892            NodeType::SplitVec4 => vec![
893                NodeSocket::optional("X", SocketType::Float, "0.0"),
894                NodeSocket::optional("Y", SocketType::Float, "0.0"),
895                NodeSocket::optional("Z", SocketType::Float, "0.0"),
896                NodeSocket::optional("W", SocketType::Float, "0.0"),
897            ],
898            // Output nodes — no outputs
899            NodeType::OutputColor | NodeType::OutputTarget(_) | NodeType::OutputWithBloom => vec![],
900            // Most nodes produce a single "Out" of Any type
901            _ => vec![
902                NodeSocket::optional("Out", SocketType::Any, "0.0"),
903            ],
904        }
905    }
906
907    pub fn is_input_node(&self) -> bool {
908        matches!(self,
909            NodeType::UvCoord | NodeType::WorldPos | NodeType::CameraPos |
910            NodeType::Time | NodeType::Resolution | NodeType::ConstFloat(_) |
911            NodeType::ConstVec2(..) | NodeType::ConstVec3(..) | NodeType::ConstVec4(..) |
912            NodeType::Uniform(..) | NodeType::VertexColor | NodeType::ScreenCoord
913        )
914    }
915
916    pub fn is_output_node(&self) -> bool {
917        matches!(self,
918            NodeType::OutputColor | NodeType::OutputTarget(_) | NodeType::OutputWithBloom
919        )
920    }
921
922    pub fn output_count(&self) -> usize { self.output_sockets().len() }
923    pub fn input_count(&self)  -> usize { self.input_sockets().len() }
924}
925
926// ── ShaderNode ────────────────────────────────────────────────────────────────
927
928/// A single node in the shader graph.
929#[derive(Debug, Clone)]
930pub struct ShaderNode {
931    pub id:        NodeId,
932    pub node_type: NodeType,
933    /// Editor layout position.
934    pub editor_x:  f32,
935    pub editor_y:  f32,
936    /// Per-input constant fallback values (used when socket is not connected).
937    pub constant_inputs: HashMap<usize, String>,
938    /// Optional label override.
939    pub label:     Option<String>,
940    /// Whether this node is bypassed (output = first input).
941    pub bypassed:  bool,
942    /// Whether this node is muted (output = zero/transparent).
943    pub muted:     bool,
944}
945
946impl ShaderNode {
947    pub fn new(id: NodeId, node_type: NodeType) -> Self {
948        Self {
949            id, node_type,
950            editor_x:        0.0,
951            editor_y:        0.0,
952            constant_inputs: HashMap::new(),
953            label:           None,
954            bypassed:        false,
955            muted:           false,
956        }
957    }
958
959    pub fn with_label(mut self, label: impl Into<String>) -> Self {
960        self.label = Some(label.into());
961        self
962    }
963
964    pub fn with_constant(mut self, slot: usize, value: impl Into<String>) -> Self {
965        self.constant_inputs.insert(slot, value.into());
966        self
967    }
968
969    pub fn display_label(&self) -> &str {
970        self.label.as_deref().unwrap_or_else(|| self.node_type.label())
971    }
972
973    /// Variable name used in compiled GLSL for the output of this node.
974    pub fn var_name(&self, slot: usize) -> String {
975        format!("n{}_{}", self.id.0, slot)
976    }
977}
978
979// ── Tests ─────────────────────────────────────────────────────────────────────
980
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_all_node_types_have_labels() {
        let sample = [
            NodeType::Add,
            NodeType::Multiply,
            NodeType::Sin,
            NodeType::Cos,
            NodeType::PerlinNoise,
            NodeType::Mandelbrot,
            NodeType::OutputColor,
        ];
        for node_type in sample.iter() {
            assert!(!node_type.label().is_empty(), "empty label for {:?}", node_type);
        }
    }

    #[test]
    fn test_socket_compatibility() {
        let compatible = [
            (SocketType::Float, SocketType::Float),
            (SocketType::Any, SocketType::Vec3),
        ];
        for &(from, to) in compatible.iter() {
            assert!(from.is_compatible_with(to), "{:?} -> {:?}", from, to);
        }
        assert!(!SocketType::Float.is_compatible_with(SocketType::Vec4));
    }

    #[test]
    fn test_node_input_output_counts() {
        // (node type, expected inputs, expected outputs)
        let cases: [(NodeType, usize, usize); 3] = [
            (NodeType::Add, 2, 1),
            (NodeType::UvCoord, 0, 1),
            (NodeType::OutputColor, 1, 0),
        ];
        for (node_type, inputs, outputs) in cases.iter() {
            assert_eq!(node_type.input_count(), *inputs, "inputs of {:?}", node_type);
            assert_eq!(node_type.output_count(), *outputs, "outputs of {:?}", node_type);
        }
    }

    #[test]
    fn test_var_name() {
        let node = ShaderNode::new(NodeId(42), NodeType::Add);
        for slot in 0..2 {
            assert_eq!(node.var_name(slot), format!("n42_{}", slot));
        }
    }

    #[test]
    fn test_socket_default_values() {
        let socket = NodeSocket::optional("test", SocketType::Vec3, "vec3(0.0)");
        assert!(!socket.required);
        assert_eq!(socket.default, "vec3(0.0)");
    }

    #[test]
    fn test_is_input_output_node() {
        let add = NodeType::Add;
        assert!(NodeType::UvCoord.is_input_node());
        assert!(!add.is_input_node());
        assert!(NodeType::OutputColor.is_output_node());
        assert!(!add.is_output_node());
    }
}