Skip to main content

proof_engine/render/shader_graph/
mod.rs

1//! Node-based composable shader graph system.
2//!
3//! The shader graph compiles a directed acyclic graph of nodes into GLSL
4//! fragment shader source code at runtime. Every visual effect in Proof
5//! Engine can be described as a graph of mathematical operations.
6//!
7//! ## Architecture
8//! - `ShaderGraph`      — owns nodes and edges, validates, compiles
9//! - `ShaderNode`       — individual processing unit (40+ types)
10//! - `ShaderEdge`       — connects an output socket to an input socket
11//! - `GraphCompiler`    — walks the graph and emits GLSL
12//! - `GraphOptimizer`   — dead-node elimination, constant folding
13//! - `ShaderPreset`     — named, pre-built graphs for common effects
14//! - `ShaderParameter`  — runtime-controllable uniform (bound to MathFunction)
15//!
16//! ## Quick Start
17//! ```rust,no_run
18//! use proof_engine::render::shader_graph::{ShaderGraph, ShaderPreset};
19//! let graph = ShaderPreset::void_protocol();
20//! let glsl  = graph.compile().unwrap();
21//! println!("{}", glsl.fragment_source);
22//! ```
23
24pub mod nodes;
25pub mod compiler;
26pub mod optimizer;
27pub mod presets;
28
29pub use nodes::{ShaderNode, NodeType, SocketType, NodeSocket};
30pub use compiler::{GraphCompiler, CompiledShader};
31pub use optimizer::GraphOptimizer;
32pub use presets::ShaderPreset;
33
34use std::collections::HashMap;
35use crate::math::MathFunction;
36
37// ── Identifiers ───────────────────────────────────────────────────────────────
38
/// Newtype handle identifying a node; used as the key of `ShaderGraph::nodes`.
///
/// The wrapped `u32` is public, so a handle can also be built or inspected
/// directly.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct NodeId(pub u32);
41
/// Newtype handle identifying an edge (see `ShaderEdge::id`).
///
/// The wrapped `u32` is public, so a handle can also be built or inspected
/// directly.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct EdgeId(pub u32);
44
45// ── ShaderEdge ────────────────────────────────────────────────────────────────
46
47/// A directed connection from one node's output socket to another's input.
48#[derive(Debug, Clone)]
49pub struct ShaderEdge {
50    pub id:        EdgeId,
51    pub from_node: NodeId,
52    pub from_slot: u8,
53    pub to_node:   NodeId,
54    pub to_slot:   u8,
55}
56
57// ── ShaderParameter ───────────────────────────────────────────────────────────
58
59/// A runtime-controllable parameter bound to a GLSL uniform.
60#[derive(Debug, Clone)]
61pub struct ShaderParameter {
62    pub name:     String,
63    pub glsl_name: String,
64    pub value:    ParameterValue,
65    /// Optional MathFunction driving this parameter over time.
66    pub driver:   Option<MathFunction>,
67    pub min:      f32,
68    pub max:      f32,
69}
70
/// A parameter value tagged with its GLSL type.
#[derive(Debug, Clone)]
pub enum ParameterValue {
    Float(f32),
    Vec2(f32, f32),
    Vec3(f32, f32, f32),
    Vec4(f32, f32, f32, f32),
    Int(i32),
    Bool(bool),
}

impl ParameterValue {
    /// Returns the scalar when this is a `Float`; every other variant
    /// (including `Int`) yields `None`.
    pub fn as_float(&self) -> Option<f32> {
        match *self {
            Self::Float(scalar) => Some(scalar),
            _ => None,
        }
    }

    /// Name of the GLSL type corresponding to this variant.
    pub fn glsl_type(&self) -> &'static str {
        match self {
            Self::Float(..) => "float",
            Self::Vec2(..) => "vec2",
            Self::Vec3(..) => "vec3",
            Self::Vec4(..) => "vec4",
            Self::Int(..) => "int",
            Self::Bool(..) => "bool",
        }
    }

    /// Renders the value as GLSL source text.
    ///
    /// Float components are printed with six decimal places; vectors are
    /// wrapped in the matching constructor call.
    pub fn glsl_literal(&self) -> String {
        match *self {
            Self::Float(v) => format!("{:.6}", v),
            Self::Vec2(x, y) => format!("vec2({:.6}, {:.6})", x, y),
            Self::Vec3(x, y, z) => format!("vec3({:.6}, {:.6}, {:.6})", x, y, z),
            // The vec4 form has no spaces after the commas; kept as-is so the
            // emitted text does not change.
            Self::Vec4(x, y, z, w) => format!("vec4({:.6},{:.6},{:.6},{:.6})", x, y, z, w),
            Self::Int(v) => v.to_string(),
            Self::Bool(v) => v.to_string(),
        }
    }
}
108
109// ── ShaderGraph ───────────────────────────────────────────────────────────────
110
111/// A directed acyclic graph of shader processing nodes.
112#[derive(Debug, Clone)]
113pub struct ShaderGraph {
114    pub name:       String,
115    pub nodes:      HashMap<NodeId, ShaderNode>,
116    pub edges:      Vec<ShaderEdge>,
117    pub parameters: Vec<ShaderParameter>,
118    /// The node whose output is the final fragment color.
119    pub output_node: Option<NodeId>,
120    next_node_id:   u32,
121    next_edge_id:   u32,
122}
123
124impl ShaderGraph {
125    pub fn new(name: impl Into<String>) -> Self {
126        Self {
127            name:         name.into(),
128            nodes:        HashMap::new(),
129            edges:        Vec::new(),
130            parameters:   Vec::new(),
131            output_node:  None,
132            next_node_id: 0,
133            next_edge_id: 0,
134        }
135    }
136
137    // ── Node management ────────────────────────────────────────────────────────
138
139    pub fn add_node(&mut self, node_type: NodeType) -> NodeId {
140        let id = NodeId(self.next_node_id);
141        self.next_node_id += 1;
142        self.nodes.insert(id, ShaderNode::new(id, node_type));
143        id
144    }
145
146    pub fn add_node_at(&mut self, node_type: NodeType, x: f32, y: f32) -> NodeId {
147        let id = self.add_node(node_type);
148        if let Some(n) = self.nodes.get_mut(&id) {
149            n.editor_x = x;
150            n.editor_y = y;
151        }
152        id
153    }
154
155    pub fn remove_node(&mut self, id: NodeId) -> bool {
156        if self.nodes.remove(&id).is_some() {
157            self.edges.retain(|e| e.from_node != id && e.to_node != id);
158            if self.output_node == Some(id) { self.output_node = None; }
159            true
160        } else {
161            false
162        }
163    }
164
165    pub fn node(&self, id: NodeId) -> Option<&ShaderNode> {
166        self.nodes.get(&id)
167    }
168
169    pub fn node_mut(&mut self, id: NodeId) -> Option<&mut ShaderNode> {
170        self.nodes.get_mut(&id)
171    }
172
173    pub fn set_output(&mut self, id: NodeId) {
174        self.output_node = Some(id);
175    }
176
177    // ── Edge management ────────────────────────────────────────────────────────
178
179    pub fn connect(
180        &mut self,
181        from_node: NodeId, from_slot: u8,
182        to_node:   NodeId, to_slot:   u8,
183    ) -> Result<EdgeId, GraphError> {
184        // Validate nodes exist
185        if !self.nodes.contains_key(&from_node) {
186            return Err(GraphError::NodeNotFound(from_node));
187        }
188        if !self.nodes.contains_key(&to_node) {
189            return Err(GraphError::NodeNotFound(to_node));
190        }
191        // Prevent duplicate connections to same input slot
192        if self.edges.iter().any(|e| e.to_node == to_node && e.to_slot == to_slot) {
193            return Err(GraphError::SlotAlreadyConnected { node: to_node, slot: to_slot });
194        }
195        // Prevent cycles (simple reachability check)
196        if self.would_create_cycle(from_node, to_node) {
197            return Err(GraphError::CycleDetected);
198        }
199        let id = EdgeId(self.next_edge_id);
200        self.next_edge_id += 1;
201        self.edges.push(ShaderEdge { id, from_node, from_slot, to_node, to_slot });
202        Ok(id)
203    }
204
205    pub fn disconnect(&mut self, edge_id: EdgeId) -> bool {
206        let before = self.edges.len();
207        self.edges.retain(|e| e.id != edge_id);
208        self.edges.len() < before
209    }
210
211    pub fn disconnect_input(&mut self, to_node: NodeId, to_slot: u8) {
212        self.edges.retain(|e| !(e.to_node == to_node && e.to_slot == to_slot));
213    }
214
215    // ── Parameter management ───────────────────────────────────────────────────
216
217    pub fn add_parameter(&mut self, param: ShaderParameter) -> usize {
218        let idx = self.parameters.len();
219        self.parameters.push(param);
220        idx
221    }
222
223    pub fn set_parameter_float(&mut self, name: &str, value: f32) {
224        for p in &mut self.parameters {
225            if p.name == name {
226                p.value = ParameterValue::Float(value.clamp(p.min, p.max));
227                break;
228            }
229        }
230    }
231
232    /// Update animated parameters by evaluating their MathFunction drivers.
233    pub fn update_parameters(&mut self, time: f32) {
234        for p in &mut self.parameters {
235            if let Some(ref func) = p.driver {
236                let v = func.evaluate(time, 0.0).clamp(p.min, p.max);
237                p.value = ParameterValue::Float(v);
238            }
239        }
240    }
241
242    // ── Compilation ────────────────────────────────────────────────────────────
243
244    /// Compile the graph to GLSL. Returns an error if the graph is invalid.
245    pub fn compile(&self) -> Result<CompiledShader, GraphError> {
246        let optimized = GraphOptimizer::run(self);
247        compiler::GraphCompiler::compile(&optimized)
248    }
249
250    /// Validate graph structure without compiling.
251    pub fn validate(&self) -> Vec<GraphError> {
252        let mut errors = Vec::new();
253        if self.output_node.is_none() {
254            errors.push(GraphError::NoOutputNode);
255        }
256        if let Some(out) = self.output_node {
257            if !self.nodes.contains_key(&out) {
258                errors.push(GraphError::NodeNotFound(out));
259            }
260        }
261        // Check for disconnected required inputs
262        for (id, node) in &self.nodes {
263            for (slot, sock) in node.node_type.input_sockets().iter().enumerate() {
264                if sock.required {
265                    let connected = self.edges.iter()
266                        .any(|e| e.to_node == *id && e.to_slot == slot as u8);
267                    if !connected && node.constant_inputs.get(&slot).is_none() {
268                        errors.push(GraphError::RequiredInputDisconnected {
269                            node: *id, slot: slot as u8,
270                        });
271                    }
272                }
273            }
274        }
275        errors
276    }
277
278    // ── Topological sort ───────────────────────────────────────────────────────
279
280    /// Returns nodes in evaluation order (inputs before outputs).
281    pub fn topological_order(&self) -> Result<Vec<NodeId>, GraphError> {
282        let mut visited = std::collections::HashSet::new();
283        let mut order   = Vec::new();
284
285        fn visit(
286            id: NodeId,
287            graph: &ShaderGraph,
288            visited: &mut std::collections::HashSet<NodeId>,
289            order:   &mut Vec<NodeId>,
290            stack:   &mut std::collections::HashSet<NodeId>,
291        ) -> Result<(), GraphError> {
292            if stack.contains(&id) { return Err(GraphError::CycleDetected); }
293            if visited.contains(&id) { return Ok(()); }
294            stack.insert(id);
295            // Visit all nodes feeding into this one
296            for edge in graph.edges.iter().filter(|e| e.to_node == id) {
297                visit(edge.from_node, graph, visited, order, stack)?;
298            }
299            stack.remove(&id);
300            visited.insert(id);
301            order.push(id);
302            Ok(())
303        }
304
305        let mut stack = std::collections::HashSet::new();
306        if let Some(out) = self.output_node {
307            visit(out, self, &mut visited, &mut order, &mut stack)?;
308        } else {
309            // Visit all nodes if no output set
310            let ids: Vec<NodeId> = self.nodes.keys().copied().collect();
311            for id in ids {
312                visit(id, self, &mut visited, &mut order, &mut stack)?;
313            }
314        }
315        Ok(order)
316    }
317
318    fn would_create_cycle(&self, from: NodeId, to: NodeId) -> bool {
319        // DFS from `to` — if we can reach `from`, adding from→to creates a cycle
320        let mut visited = std::collections::HashSet::new();
321        let mut stack   = vec![to];
322        while let Some(cur) = stack.pop() {
323            if cur == from { return true; }
324            if visited.insert(cur) {
325                for e in self.edges.iter().filter(|e| e.from_node == cur) {
326                    stack.push(e.to_node);
327                }
328            }
329        }
330        false
331    }
332
333    // ── Serialization ──────────────────────────────────────────────────────────
334
335    pub fn to_toml(&self) -> String {
336        let mut out = format!("[graph]\nname = {:?}\n\n", self.name);
337        for (id, node) in &self.nodes {
338            out.push_str(&format!(
339                "[[nodes]]\nid = {}\ntype = {:?}\nx = {:.1}\ny = {:.1}\n\n",
340                id.0, node.node_type.label(), node.editor_x, node.editor_y
341            ));
342        }
343        for edge in &self.edges {
344            out.push_str(&format!(
345                "[[edges]]\nfrom = {}\nfrom_slot = {}\nto = {}\nto_slot = {}\n\n",
346                edge.from_node.0, edge.from_slot, edge.to_node.0, edge.to_slot
347            ));
348        }
349        out
350    }
351
352    /// Statistics about the graph.
353    pub fn stats(&self) -> GraphStats {
354        GraphStats {
355            node_count:      self.nodes.len(),
356            edge_count:      self.edges.len(),
357            parameter_count: self.parameters.len(),
358        }
359    }
360}
361
/// Size summary of a graph, as returned by `ShaderGraph::stats`.
#[derive(Debug)]
pub struct GraphStats {
    /// Number of nodes in the graph.
    pub node_count: usize,
    /// Number of edges in the graph.
    pub edge_count: usize,
    /// Number of parameters registered on the graph.
    pub parameter_count: usize,
}
368
369// ── GraphError ────────────────────────────────────────────────────────────────
370
371#[derive(Debug, Clone, PartialEq)]
372pub enum GraphError {
373    NodeNotFound(NodeId),
374    CycleDetected,
375    NoOutputNode,
376    SlotAlreadyConnected { node: NodeId, slot: u8 },
377    RequiredInputDisconnected { node: NodeId, slot: u8 },
378    TypeMismatch { from: SocketType, to: SocketType },
379    CompileError(String),
380}
381
382impl std::fmt::Display for GraphError {
383    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
384        match self {
385            GraphError::NodeNotFound(id)         => write!(f, "Node {:?} not found", id),
386            GraphError::CycleDetected            => write!(f, "Graph contains a cycle"),
387            GraphError::NoOutputNode             => write!(f, "No output node set"),
388            GraphError::SlotAlreadyConnected { node, slot } =>
389                write!(f, "Node {:?} slot {} already has an incoming connection", node, slot),
390            GraphError::RequiredInputDisconnected { node, slot } =>
391                write!(f, "Node {:?} required slot {} is not connected", node, slot),
392            GraphError::TypeMismatch { from, to } =>
393                write!(f, "Type mismatch: {:?} -> {:?}", from, to),
394            GraphError::CompileError(msg)        => write!(f, "Compile error: {}", msg),
395        }
396    }
397}
398
399// ── Tests ─────────────────────────────────────────────────────────────────────
400
#[cfg(test)]
mod tests {
    use super::*;
    use nodes::NodeType;

    /// A node is retrievable after insertion and gone after removal.
    #[test]
    fn test_add_remove_node() {
        let mut graph = ShaderGraph::new("test");
        let node_id = graph.add_node(NodeType::UvCoord);
        assert!(graph.node(node_id).is_some());
        assert!(graph.remove_node(node_id));
        assert!(graph.node(node_id).is_none());
    }

    /// Connecting two existing nodes through a free input slot succeeds.
    #[test]
    fn test_connect_nodes() {
        let mut graph = ShaderGraph::new("test");
        let uv_id = graph.add_node(NodeType::UvCoord);
        let out_id = graph.add_node(NodeType::OutputColor);
        graph.set_output(out_id);
        assert!(graph.connect(uv_id, 0, out_id, 0).is_ok());
    }

    /// An edge that would close a loop is rejected with `CycleDetected`.
    #[test]
    fn test_cycle_detection() {
        let mut graph = ShaderGraph::new("test");
        let first = graph.add_node(NodeType::Add);
        let second = graph.add_node(NodeType::Add);
        let _ = graph.connect(first, 0, second, 0);
        let back_edge = graph.connect(second, 0, first, 0);
        assert_eq!(back_edge, Err(GraphError::CycleDetected));
    }

    /// A second edge into an already-connected input slot is rejected.
    #[test]
    fn test_duplicate_input_rejected() {
        let mut graph = ShaderGraph::new("test");
        let one = graph.add_node(NodeType::ConstFloat(1.0));
        let two = graph.add_node(NodeType::ConstFloat(2.0));
        let sum = graph.add_node(NodeType::Add);
        let _ = graph.connect(one, 0, sum, 0);
        let second_attempt = graph.connect(two, 0, sum, 0);
        assert!(matches!(
            second_attempt,
            Err(GraphError::SlotAlreadyConnected { .. })
        ));
    }

    /// A linear chain comes back in source-to-output order.
    #[test]
    fn test_topological_order() {
        let mut graph = ShaderGraph::new("test");
        let uv_id = graph.add_node(NodeType::UvCoord);
        let sine_id = graph.add_node(NodeType::SineWave);
        let out_id = graph.add_node(NodeType::OutputColor);
        graph.set_output(out_id);
        let _ = graph.connect(uv_id, 0, sine_id, 0);
        let _ = graph.connect(sine_id, 0, out_id, 0);
        let order = graph.topological_order().unwrap();
        assert_eq!(order[0], uv_id);
        assert_eq!(order[1], sine_id);
        assert_eq!(order[2], out_id);
    }

    /// Setting a float parameter by name stores the (in-range) value.
    #[test]
    fn test_parameter_update() {
        let mut graph = ShaderGraph::new("test");
        let brightness = ShaderParameter {
            name: "brightness".to_string(),
            glsl_name: "u_brightness".to_string(),
            value: ParameterValue::Float(0.5),
            driver: None,
            min: 0.0,
            max: 2.0,
        };
        graph.add_parameter(brightness);
        graph.set_parameter_float("brightness", 1.5);
        assert_eq!(graph.parameters[0].value.as_float(), Some(1.5));
    }

    /// `stats` reports the number of nodes that were added.
    #[test]
    fn test_stats() {
        let mut graph = ShaderGraph::new("test");
        graph.add_node(NodeType::UvCoord);
        graph.add_node(NodeType::OutputColor);
        let summary = graph.stats();
        assert_eq!(summary.node_count, 2);
    }
}