1pub mod nodes;
25pub mod compiler;
26pub mod optimizer;
27pub mod presets;
28
29pub use nodes::{ShaderNode, NodeType, SocketType, NodeSocket};
30pub use compiler::{GraphCompiler, CompiledShader};
31pub use optimizer::GraphOptimizer;
32pub use presets::ShaderPreset;
33
34use std::collections::HashMap;
35use crate::math::MathFunction;
36
/// Lightweight, copyable handle naming one node of a [`ShaderGraph`].
///
/// The inner `u32` is public; [`ShaderGraph::add_node`] allocates the values
/// from an internal counter.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct NodeId(pub u32);
41
/// Lightweight, copyable handle naming one edge of a [`ShaderGraph`].
///
/// Values are allocated by [`ShaderGraph::connect`] and accepted by
/// [`ShaderGraph::disconnect`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct EdgeId(pub u32);
44
45#[derive(Debug, Clone)]
49pub struct ShaderEdge {
50 pub id: EdgeId,
51 pub from_node: NodeId,
52 pub from_slot: u8,
53 pub to_node: NodeId,
54 pub to_slot: u8,
55}
56
57#[derive(Debug, Clone)]
61pub struct ShaderParameter {
62 pub name: String,
63 pub glsl_name: String,
64 pub value: ParameterValue,
65 pub driver: Option<MathFunction>,
67 pub min: f32,
68 pub max: f32,
69}
70
/// The value carried by a [`ShaderParameter`]; each variant maps to one GLSL type.
#[derive(Clone, Debug)]
pub enum ParameterValue {
    /// GLSL `float`.
    Float(f32),
    /// GLSL `vec2`, components in `(x, y)` order.
    Vec2(f32, f32),
    /// GLSL `vec3`, components in `(x, y, z)` order.
    Vec3(f32, f32, f32),
    /// GLSL `vec4`, components in `(x, y, z, w)` order.
    Vec4(f32, f32, f32, f32),
    /// GLSL `int`.
    Int(i32),
    /// GLSL `bool`.
    Bool(bool),
}

impl ParameterValue {
    /// Returns the payload of a `Float`; every other variant yields `None`
    /// (no numeric conversion is attempted, not even from `Int`).
    pub fn as_float(&self) -> Option<f32> {
        match self {
            Self::Float(v) => Some(*v),
            Self::Vec2(..) | Self::Vec3(..) | Self::Vec4(..) | Self::Int(_) | Self::Bool(_) => None,
        }
    }

    /// GLSL type keyword matching this variant.
    pub fn glsl_type(&self) -> &'static str {
        match self {
            Self::Float(..) => "float",
            Self::Vec2(..) => "vec2",
            Self::Vec3(..) => "vec3",
            Self::Vec4(..) => "vec4",
            Self::Int(..) => "int",
            Self::Bool(..) => "bool",
        }
    }

    /// Renders the value as a GLSL literal; floats always use six decimals.
    ///
    /// NOTE(review): non-finite floats render as `NaN` / `inf`, which are not
    /// valid GLSL literals. The `vec4` form also omits the spaces after commas,
    /// unlike `vec2`/`vec3`. Both are kept as-is because callers may match the
    /// exact text.
    pub fn glsl_literal(&self) -> String {
        match *self {
            Self::Float(v) => format!("{:.6}", v),
            Self::Vec2(x, y) => format!("vec2({:.6}, {:.6})", x, y),
            Self::Vec3(x, y, z) => format!("vec3({:.6}, {:.6}, {:.6})", x, y, z),
            Self::Vec4(x, y, z, w) => format!("vec4({:.6},{:.6},{:.6},{:.6})", x, y, z, w),
            Self::Int(v) => v.to_string(),
            Self::Bool(v) => v.to_string(),
        }
    }
}
108
109#[derive(Debug, Clone)]
113pub struct ShaderGraph {
114 pub name: String,
115 pub nodes: HashMap<NodeId, ShaderNode>,
116 pub edges: Vec<ShaderEdge>,
117 pub parameters: Vec<ShaderParameter>,
118 pub output_node: Option<NodeId>,
120 next_node_id: u32,
121 next_edge_id: u32,
122}
123
124impl ShaderGraph {
125 pub fn new(name: impl Into<String>) -> Self {
126 Self {
127 name: name.into(),
128 nodes: HashMap::new(),
129 edges: Vec::new(),
130 parameters: Vec::new(),
131 output_node: None,
132 next_node_id: 0,
133 next_edge_id: 0,
134 }
135 }
136
137 pub fn add_node(&mut self, node_type: NodeType) -> NodeId {
140 let id = NodeId(self.next_node_id);
141 self.next_node_id += 1;
142 self.nodes.insert(id, ShaderNode::new(id, node_type));
143 id
144 }
145
146 pub fn add_node_at(&mut self, node_type: NodeType, x: f32, y: f32) -> NodeId {
147 let id = self.add_node(node_type);
148 if let Some(n) = self.nodes.get_mut(&id) {
149 n.editor_x = x;
150 n.editor_y = y;
151 }
152 id
153 }
154
155 pub fn remove_node(&mut self, id: NodeId) -> bool {
156 if self.nodes.remove(&id).is_some() {
157 self.edges.retain(|e| e.from_node != id && e.to_node != id);
158 if self.output_node == Some(id) { self.output_node = None; }
159 true
160 } else {
161 false
162 }
163 }
164
165 pub fn node(&self, id: NodeId) -> Option<&ShaderNode> {
166 self.nodes.get(&id)
167 }
168
169 pub fn node_mut(&mut self, id: NodeId) -> Option<&mut ShaderNode> {
170 self.nodes.get_mut(&id)
171 }
172
173 pub fn set_output(&mut self, id: NodeId) {
174 self.output_node = Some(id);
175 }
176
177 pub fn connect(
180 &mut self,
181 from_node: NodeId, from_slot: u8,
182 to_node: NodeId, to_slot: u8,
183 ) -> Result<EdgeId, GraphError> {
184 if !self.nodes.contains_key(&from_node) {
186 return Err(GraphError::NodeNotFound(from_node));
187 }
188 if !self.nodes.contains_key(&to_node) {
189 return Err(GraphError::NodeNotFound(to_node));
190 }
191 if self.edges.iter().any(|e| e.to_node == to_node && e.to_slot == to_slot) {
193 return Err(GraphError::SlotAlreadyConnected { node: to_node, slot: to_slot });
194 }
195 if self.would_create_cycle(from_node, to_node) {
197 return Err(GraphError::CycleDetected);
198 }
199 let id = EdgeId(self.next_edge_id);
200 self.next_edge_id += 1;
201 self.edges.push(ShaderEdge { id, from_node, from_slot, to_node, to_slot });
202 Ok(id)
203 }
204
205 pub fn disconnect(&mut self, edge_id: EdgeId) -> bool {
206 let before = self.edges.len();
207 self.edges.retain(|e| e.id != edge_id);
208 self.edges.len() < before
209 }
210
211 pub fn disconnect_input(&mut self, to_node: NodeId, to_slot: u8) {
212 self.edges.retain(|e| !(e.to_node == to_node && e.to_slot == to_slot));
213 }
214
215 pub fn add_parameter(&mut self, param: ShaderParameter) -> usize {
218 let idx = self.parameters.len();
219 self.parameters.push(param);
220 idx
221 }
222
223 pub fn set_parameter_float(&mut self, name: &str, value: f32) {
224 for p in &mut self.parameters {
225 if p.name == name {
226 p.value = ParameterValue::Float(value.clamp(p.min, p.max));
227 break;
228 }
229 }
230 }
231
232 pub fn update_parameters(&mut self, time: f32) {
234 for p in &mut self.parameters {
235 if let Some(ref func) = p.driver {
236 let v = func.evaluate(time, 0.0).clamp(p.min, p.max);
237 p.value = ParameterValue::Float(v);
238 }
239 }
240 }
241
242 pub fn compile(&self) -> Result<CompiledShader, GraphError> {
246 let optimized = GraphOptimizer::run(self);
247 compiler::GraphCompiler::compile(&optimized)
248 }
249
250 pub fn validate(&self) -> Vec<GraphError> {
252 let mut errors = Vec::new();
253 if self.output_node.is_none() {
254 errors.push(GraphError::NoOutputNode);
255 }
256 if let Some(out) = self.output_node {
257 if !self.nodes.contains_key(&out) {
258 errors.push(GraphError::NodeNotFound(out));
259 }
260 }
261 for (id, node) in &self.nodes {
263 for (slot, sock) in node.node_type.input_sockets().iter().enumerate() {
264 if sock.required {
265 let connected = self.edges.iter()
266 .any(|e| e.to_node == *id && e.to_slot == slot as u8);
267 if !connected && node.constant_inputs.get(&slot).is_none() {
268 errors.push(GraphError::RequiredInputDisconnected {
269 node: *id, slot: slot as u8,
270 });
271 }
272 }
273 }
274 }
275 errors
276 }
277
278 pub fn topological_order(&self) -> Result<Vec<NodeId>, GraphError> {
282 let mut visited = std::collections::HashSet::new();
283 let mut order = Vec::new();
284
285 fn visit(
286 id: NodeId,
287 graph: &ShaderGraph,
288 visited: &mut std::collections::HashSet<NodeId>,
289 order: &mut Vec<NodeId>,
290 stack: &mut std::collections::HashSet<NodeId>,
291 ) -> Result<(), GraphError> {
292 if stack.contains(&id) { return Err(GraphError::CycleDetected); }
293 if visited.contains(&id) { return Ok(()); }
294 stack.insert(id);
295 for edge in graph.edges.iter().filter(|e| e.to_node == id) {
297 visit(edge.from_node, graph, visited, order, stack)?;
298 }
299 stack.remove(&id);
300 visited.insert(id);
301 order.push(id);
302 Ok(())
303 }
304
305 let mut stack = std::collections::HashSet::new();
306 if let Some(out) = self.output_node {
307 visit(out, self, &mut visited, &mut order, &mut stack)?;
308 } else {
309 let ids: Vec<NodeId> = self.nodes.keys().copied().collect();
311 for id in ids {
312 visit(id, self, &mut visited, &mut order, &mut stack)?;
313 }
314 }
315 Ok(order)
316 }
317
318 fn would_create_cycle(&self, from: NodeId, to: NodeId) -> bool {
319 let mut visited = std::collections::HashSet::new();
321 let mut stack = vec![to];
322 while let Some(cur) = stack.pop() {
323 if cur == from { return true; }
324 if visited.insert(cur) {
325 for e in self.edges.iter().filter(|e| e.from_node == cur) {
326 stack.push(e.to_node);
327 }
328 }
329 }
330 false
331 }
332
333 pub fn to_toml(&self) -> String {
336 let mut out = format!("[graph]\nname = {:?}\n\n", self.name);
337 for (id, node) in &self.nodes {
338 out.push_str(&format!(
339 "[[nodes]]\nid = {}\ntype = {:?}\nx = {:.1}\ny = {:.1}\n\n",
340 id.0, node.node_type.label(), node.editor_x, node.editor_y
341 ));
342 }
343 for edge in &self.edges {
344 out.push_str(&format!(
345 "[[edges]]\nfrom = {}\nfrom_slot = {}\nto = {}\nto_slot = {}\n\n",
346 edge.from_node.0, edge.from_slot, edge.to_node.0, edge.to_slot
347 ));
348 }
349 out
350 }
351
352 pub fn stats(&self) -> GraphStats {
354 GraphStats {
355 node_count: self.nodes.len(),
356 edge_count: self.edges.len(),
357 parameter_count: self.parameters.len(),
358 }
359 }
360}
361
/// Summary counts returned by [`ShaderGraph::stats`].
///
/// It is a plain bundle of counters, so it derives the cheap standard traits
/// (copying, comparison, zeroed default).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct GraphStats {
    /// Number of nodes in the graph.
    pub node_count: usize,
    /// Number of edges (connections) in the graph.
    pub edge_count: usize,
    /// Number of registered shader parameters.
    pub parameter_count: usize,
}
368
369#[derive(Debug, Clone, PartialEq)]
372pub enum GraphError {
373 NodeNotFound(NodeId),
374 CycleDetected,
375 NoOutputNode,
376 SlotAlreadyConnected { node: NodeId, slot: u8 },
377 RequiredInputDisconnected { node: NodeId, slot: u8 },
378 TypeMismatch { from: SocketType, to: SocketType },
379 CompileError(String),
380}
381
382impl std::fmt::Display for GraphError {
383 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
384 match self {
385 GraphError::NodeNotFound(id) => write!(f, "Node {:?} not found", id),
386 GraphError::CycleDetected => write!(f, "Graph contains a cycle"),
387 GraphError::NoOutputNode => write!(f, "No output node set"),
388 GraphError::SlotAlreadyConnected { node, slot } =>
389 write!(f, "Node {:?} slot {} already has an incoming connection", node, slot),
390 GraphError::RequiredInputDisconnected { node, slot } =>
391 write!(f, "Node {:?} required slot {} is not connected", node, slot),
392 GraphError::TypeMismatch { from, to } =>
393 write!(f, "Type mismatch: {:?} -> {:?}", from, to),
394 GraphError::CompileError(msg) => write!(f, "Compile error: {}", msg),
395 }
396 }
397}
398
#[cfg(test)]
mod tests {
    use super::*;

    /// Float parameter in `[0.0, 2.0]` shared by the parameter tests.
    fn brightness() -> ShaderParameter {
        ShaderParameter {
            name: "brightness".to_string(),
            glsl_name: "u_brightness".to_string(),
            value: ParameterValue::Float(0.5),
            driver: None,
            min: 0.0,
            max: 2.0,
        }
    }

    #[test]
    fn test_add_remove_node() {
        let mut g = ShaderGraph::new("test");
        let id = g.add_node(NodeType::UvCoord);
        assert!(g.node(id).is_some());
        assert!(g.remove_node(id));
        assert!(g.node(id).is_none());
        // Removing twice reports that nothing was removed.
        assert!(!g.remove_node(id));
    }

    #[test]
    fn test_remove_node_drops_edges_and_output() {
        let mut g = ShaderGraph::new("test");
        let uv = g.add_node(NodeType::UvCoord);
        let out = g.add_node(NodeType::OutputColor);
        g.set_output(out);
        g.connect(uv, 0, out, 0).expect("uv -> out should connect");
        assert!(g.remove_node(out));
        assert!(g.edges.is_empty());
        assert_eq!(g.output_node, None);
    }

    #[test]
    fn test_connect_nodes() {
        let mut g = ShaderGraph::new("test");
        let uv = g.add_node(NodeType::UvCoord);
        let out = g.add_node(NodeType::OutputColor);
        g.set_output(out);
        let result = g.connect(uv, 0, out, 0);
        assert!(result.is_ok());
        assert_eq!(g.stats().edge_count, 1);
    }

    #[test]
    fn test_disconnect() {
        let mut g = ShaderGraph::new("test");
        let uv = g.add_node(NodeType::UvCoord);
        let out = g.add_node(NodeType::OutputColor);
        let edge = g.connect(uv, 0, out, 0).expect("uv -> out should connect");
        assert!(g.disconnect(edge));
        assert!(!g.disconnect(edge));
        assert_eq!(g.stats().edge_count, 0);
    }

    #[test]
    fn test_cycle_detection() {
        let mut g = ShaderGraph::new("test");
        let a = g.add_node(NodeType::Add);
        let b = g.add_node(NodeType::Add);
        g.connect(a, 0, b, 0).expect("a -> b should connect");
        let result = g.connect(b, 0, a, 0);
        assert_eq!(result, Err(GraphError::CycleDetected));
    }

    #[test]
    fn test_self_loop_rejected() {
        let mut g = ShaderGraph::new("test");
        let a = g.add_node(NodeType::Add);
        assert_eq!(g.connect(a, 0, a, 1), Err(GraphError::CycleDetected));
    }

    #[test]
    fn test_duplicate_input_rejected() {
        let mut g = ShaderGraph::new("test");
        let src1 = g.add_node(NodeType::ConstFloat(1.0));
        let src2 = g.add_node(NodeType::ConstFloat(2.0));
        let dst = g.add_node(NodeType::Add);
        g.connect(src1, 0, dst, 0).expect("first connection should succeed");
        let r = g.connect(src2, 0, dst, 0);
        assert!(matches!(r, Err(GraphError::SlotAlreadyConnected { .. })));
    }

    #[test]
    fn test_topological_order() {
        let mut g = ShaderGraph::new("test");
        let uv = g.add_node(NodeType::UvCoord);
        let sin = g.add_node(NodeType::SineWave);
        let out = g.add_node(NodeType::OutputColor);
        g.set_output(out);
        g.connect(uv, 0, sin, 0).expect("uv -> sin should connect");
        g.connect(sin, 0, out, 0).expect("sin -> out should connect");
        let order = g.topological_order().unwrap();
        assert_eq!(order, vec![uv, sin, out]);
    }

    #[test]
    fn test_validate_reports_output_problems() {
        let empty = ShaderGraph::new("empty");
        assert_eq!(empty.validate(), vec![GraphError::NoOutputNode]);

        let mut dangling = ShaderGraph::new("dangling");
        dangling.set_output(NodeId(42));
        assert_eq!(dangling.validate(), vec![GraphError::NodeNotFound(NodeId(42))]);
    }

    #[test]
    fn test_parameter_update() {
        let mut g = ShaderGraph::new("test");
        g.add_parameter(brightness());
        g.set_parameter_float("brightness", 1.5);
        assert_eq!(g.parameters[0].value.as_float(), Some(1.5));
    }

    #[test]
    fn test_parameter_clamped_and_unknown_ignored() {
        let mut g = ShaderGraph::new("test");
        assert_eq!(g.add_parameter(brightness()), 0);
        g.set_parameter_float("brightness", 5.0);
        assert_eq!(g.parameters[0].value.as_float(), Some(2.0));
        g.set_parameter_float("brightness", -1.0);
        assert_eq!(g.parameters[0].value.as_float(), Some(0.0));
        g.set_parameter_float("missing", 1.0);
        assert_eq!(g.parameters[0].value.as_float(), Some(0.0));
    }

    #[test]
    fn test_to_toml_contains_graph_and_edges() {
        let mut g = ShaderGraph::new("test");
        let uv = g.add_node(NodeType::UvCoord);
        let out = g.add_node(NodeType::OutputColor);
        g.connect(uv, 0, out, 0).expect("uv -> out should connect");
        let toml = g.to_toml();
        assert!(toml.starts_with("[graph]\nname = \"test\"\n\n"));
        assert!(toml.contains("[[edges]]\nfrom = 0\nfrom_slot = 0\nto = 1\nto_slot = 0\n\n"));
    }

    #[test]
    fn test_stats() {
        let mut g = ShaderGraph::new("test");
        let uv = g.add_node(NodeType::UvCoord);
        let out = g.add_node(NodeType::OutputColor);
        g.connect(uv, 0, out, 0).expect("uv -> out should connect");
        g.add_parameter(brightness());
        let s = g.stats();
        assert_eq!(s.node_count, 2);
        assert_eq!(s.edge_count, 1);
        assert_eq!(s.parameter_count, 1);
    }
}