1use std::collections::{HashMap, HashSet, VecDeque};
5use super::nodes::{
6 Connection, DataType, GlslSnippet, NodeId, NodeType, ParamValue, ShaderGraph, ShaderNode,
7};
8
/// Switches that control which passes [`ShaderCompiler`] runs and how the
/// GLSL text is emitted.
#[derive(Debug, Clone)]
pub struct CompileOptions {
    /// Restrict compilation to nodes reachable from an output node.
    pub dead_node_elimination: bool,
    /// Pre-evaluate pure-math nodes whose inputs are all constants.
    pub constant_folding: bool,
    /// Alias nodes with identical signatures to a single canonical node.
    pub common_subexpression_elimination: bool,
    /// Emit `// ...` comments naming each node in the generated GLSL.
    pub debug_comments: bool,
    /// Text placed after `#version ` in both generated shaders.
    pub glsl_version: String,
    /// Wrap nodes that carry a conditional variable in an `if (...)` guard.
    pub enable_conditionals: bool,
    /// Value copied into `is_animated` for uniforms derived from node properties.
    pub animated_uniforms: bool,
}

impl Default for CompileOptions {
    /// All optimisation passes on, debug comments off, GLSL `330 core`.
    fn default() -> Self {
        let glsl_version = String::from("330 core");
        CompileOptions {
            glsl_version,
            debug_comments: false,
            dead_node_elimination: true,
            constant_folding: true,
            common_subexpression_elimination: true,
            enable_conditionals: true,
            animated_uniforms: true,
        }
    }
}
45
46#[derive(Debug, Clone)]
52pub enum CompileError {
53 CycleDetected(Vec<NodeId>),
55 MissingInput { node_id: NodeId, socket_index: usize, socket_name: String },
57 NoOutputNodes,
59 TypeMismatch {
61 from_node: NodeId,
62 from_socket: usize,
63 from_type: DataType,
64 to_node: NodeId,
65 to_socket: usize,
66 to_type: DataType,
67 },
68 ValidationErrors(Vec<String>),
70 Internal(String),
72}
73
74impl std::fmt::Display for CompileError {
75 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76 match self {
77 CompileError::CycleDetected(ids) => {
78 write!(f, "Cycle detected involving nodes: {:?}",
79 ids.iter().map(|id| id.0).collect::<Vec<_>>())
80 }
81 CompileError::MissingInput { node_id, socket_index, socket_name } => {
82 write!(f, "Node {} missing input at socket {} ('{}')",
83 node_id.0, socket_index, socket_name)
84 }
85 CompileError::NoOutputNodes => write!(f, "Graph has no output nodes"),
86 CompileError::TypeMismatch { from_node, from_socket, from_type, to_node, to_socket, to_type } => {
87 write!(f, "Type mismatch: node {}:{} ({}) -> node {}:{} ({})",
88 from_node.0, from_socket, from_type,
89 to_node.0, to_socket, to_type)
90 }
91 CompileError::ValidationErrors(errs) => {
92 write!(f, "Validation errors: {}", errs.join("; "))
93 }
94 CompileError::Internal(msg) => write!(f, "Internal error: {}", msg),
95 }
96 }
97}
98
/// Result of a successful [`ShaderCompiler::compile`] call.
#[derive(Debug, Clone)]
pub struct CompiledShader {
    /// Generated fragment shader text.
    pub fragment_source: String,
    /// Generated vertex shader text.
    pub vertex_source: String,
    /// Uniforms collected for the compiled nodes.
    pub uniforms: Vec<UniformDecl>,
    /// Varyings passed from the vertex to the fragment stage.
    pub varyings: Vec<VaryingDecl>,
    /// Sum of `estimated_cost()` over every node in `node_order`.
    pub instruction_count: u32,
    /// Number of entries in `uniforms` whose type is `DataType::Sampler2D`.
    pub sampler_count: u32,
    /// Length of `node_order`.
    pub live_node_count: usize,
    /// Value of `graph.topology_hash()` at compile time.
    pub topology_hash: u64,
    /// Nodes in the order they were processed (topologically sorted).
    pub node_order: Vec<NodeId>,
    /// Maps `(node id, output socket index)` to the GLSL variable holding it.
    pub output_var_map: HashMap<(u64, usize), String>,
}
127
/// One `uniform` declaration emitted into the fragment shader.
#[derive(Debug, Clone)]
pub struct UniformDecl {
    /// GLSL identifier of the uniform.
    pub name: String,
    /// GLSL type of the uniform.
    pub data_type: DataType,
    /// Initial value, when one is known.
    pub default_value: Option<ParamValue>,
    /// Flag set for `u_time`, and for property uniforms when the
    /// `animated_uniforms` option is on.
    pub is_animated: bool,
}
136
/// One varying: written as `out` in the vertex shader and `in` in the
/// fragment shader.
#[derive(Debug, Clone)]
pub struct VaryingDecl {
    /// GLSL identifier of the varying.
    pub name: String,
    /// GLSL type of the varying.
    pub data_type: DataType,
}
143
/// Compiles a `ShaderGraph` into GLSL vertex and fragment sources.
pub struct ShaderCompiler {
    /// Pass and emission settings applied by every `compile` call.
    options: CompileOptions,
}
152
153impl ShaderCompiler {
154 pub fn new(options: CompileOptions) -> Self {
155 Self { options }
156 }
157
158 pub fn with_defaults() -> Self {
159 Self::new(CompileOptions::default())
160 }
161
162 pub fn compile(&self, graph: &ShaderGraph) -> Result<CompiledShader, CompileError> {
164 let errors = graph.validate();
166 if !errors.is_empty() {
167 return Err(CompileError::ValidationErrors(errors));
168 }
169
170 let output_nodes = graph.output_nodes();
172 if output_nodes.is_empty() {
173 return Err(CompileError::NoOutputNodes);
174 }
175
176 let live_nodes = if self.options.dead_node_elimination {
178 self.find_live_nodes(graph, &output_nodes)
179 } else {
180 graph.node_ids().collect()
181 };
182
183 let sorted = self.topological_sort(graph, &live_nodes)?;
185
186 let folded_values = if self.options.constant_folding {
188 self.constant_fold(graph, &sorted)
189 } else {
190 HashMap::new()
191 };
192
193 let cse_map = if self.options.common_subexpression_elimination {
195 self.find_common_subexpressions(graph, &sorted)
196 } else {
197 HashMap::new()
198 };
199
200 let (uniforms, varyings) = self.collect_declarations(graph, &sorted);
202
203 let (fragment_source, output_var_map) = self.generate_glsl(
205 graph, &sorted, &folded_values, &cse_map, &uniforms, &varyings,
206 );
207
208 let vertex_source = self.generate_vertex_shader(&varyings);
210
211 let instruction_count: u32 = sorted.iter()
213 .filter_map(|id| graph.node(*id).map(|n| n.estimated_cost()))
214 .sum();
215 let sampler_count = uniforms.iter()
216 .filter(|u| u.data_type == DataType::Sampler2D)
217 .count() as u32;
218
219 Ok(CompiledShader {
220 fragment_source,
221 vertex_source,
222 uniforms,
223 varyings,
224 instruction_count,
225 sampler_count,
226 live_node_count: sorted.len(),
227 topology_hash: graph.topology_hash(),
228 node_order: sorted,
229 output_var_map,
230 })
231 }
232
233 fn find_live_nodes(&self, graph: &ShaderGraph, outputs: &[NodeId]) -> HashSet<NodeId> {
239 let mut live = HashSet::new();
240 let mut queue: VecDeque<NodeId> = outputs.iter().copied().collect();
241
242 while let Some(node_id) = queue.pop_front() {
243 if !live.insert(node_id) {
244 continue; }
246 for conn in graph.connections() {
248 if conn.to_node == node_id && !live.contains(&conn.from_node) {
249 queue.push_back(conn.from_node);
250 }
251 }
252 }
253
254 live
255 }
256
257 fn topological_sort(
263 &self,
264 graph: &ShaderGraph,
265 live_nodes: &HashSet<NodeId>,
266 ) -> Result<Vec<NodeId>, CompileError> {
267 let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
269 let mut adjacency: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
270
271 for &nid in live_nodes {
272 in_degree.entry(nid).or_insert(0);
273 adjacency.entry(nid).or_insert_with(Vec::new);
274 }
275
276 for conn in graph.connections() {
277 if live_nodes.contains(&conn.from_node) && live_nodes.contains(&conn.to_node) {
278 adjacency.entry(conn.from_node).or_insert_with(Vec::new).push(conn.to_node);
279 *in_degree.entry(conn.to_node).or_insert(0) += 1;
280 }
281 }
282
283 let mut queue: VecDeque<NodeId> = in_degree.iter()
285 .filter(|(_, °)| deg == 0)
286 .map(|(&id, _)| id)
287 .collect();
288
289 let mut queue_vec: Vec<NodeId> = queue.drain(..).collect();
291 queue_vec.sort_by_key(|id| id.0);
292 queue = queue_vec.into_iter().collect();
293
294 let mut sorted = Vec::new();
295
296 while let Some(node_id) = queue.pop_front() {
297 sorted.push(node_id);
298 if let Some(neighbors) = adjacency.get(&node_id) {
299 let mut next_neighbors: Vec<NodeId> = Vec::new();
300 for &neighbor in neighbors {
301 if let Some(deg) = in_degree.get_mut(&neighbor) {
302 *deg -= 1;
303 if *deg == 0 {
304 next_neighbors.push(neighbor);
305 }
306 }
307 }
308 next_neighbors.sort_by_key(|id| id.0);
309 for n in next_neighbors {
310 queue.push_back(n);
311 }
312 }
313 }
314
315 if sorted.len() != live_nodes.len() {
316 let sorted_set: HashSet<NodeId> = sorted.iter().copied().collect();
318 let cycle_nodes: Vec<NodeId> = live_nodes.iter()
319 .filter(|id| !sorted_set.contains(id))
320 .copied()
321 .collect();
322 return Err(CompileError::CycleDetected(cycle_nodes));
323 }
324
325 Ok(sorted)
326 }
327
328 fn constant_fold(
335 &self,
336 graph: &ShaderGraph,
337 sorted: &[NodeId],
338 ) -> HashMap<NodeId, Vec<ParamValue>> {
339 let mut folded: HashMap<NodeId, Vec<ParamValue>> = HashMap::new();
340
341 for &node_id in sorted {
342 let node = match graph.node(node_id) {
343 Some(n) => n,
344 None => continue,
345 };
346
347 if !node.node_type.is_pure_math() {
348 continue;
349 }
350
351 let incoming = graph.incoming_connections(node_id);
353 let mut input_values: Vec<Option<ParamValue>> = Vec::new();
354 let mut all_constant = true;
355
356 for (idx, socket) in node.inputs.iter().enumerate() {
357 let conn = incoming.iter().find(|c| c.to_socket == idx);
359 if let Some(c) = conn {
360 if let Some(folded_vals) = folded.get(&c.from_node) {
362 if c.from_socket < folded_vals.len() {
363 input_values.push(Some(folded_vals[c.from_socket].clone()));
364 continue;
365 }
366 }
367 all_constant = false;
368 break;
369 } else if let Some(def) = &socket.default_value {
370 input_values.push(Some(def.clone()));
371 } else {
372 all_constant = false;
373 break;
374 }
375 }
376
377 if !all_constant {
378 continue;
379 }
380
381 let values: Vec<ParamValue> = input_values.into_iter().filter_map(|v| v).collect();
383 if let Some(result) = self.evaluate_constant(&node.node_type, &values) {
384 folded.insert(node_id, result);
385 }
386 }
387
388 folded
389 }
390
391 fn evaluate_constant(&self, node_type: &NodeType, inputs: &[ParamValue]) -> Option<Vec<ParamValue>> {
393 match node_type {
394 NodeType::Add => {
395 let a = inputs.first()?.as_float()?;
396 let b = inputs.get(1)?.as_float()?;
397 Some(vec![ParamValue::Float(a + b)])
398 }
399 NodeType::Sub => {
400 let a = inputs.first()?.as_float()?;
401 let b = inputs.get(1)?.as_float()?;
402 Some(vec![ParamValue::Float(a - b)])
403 }
404 NodeType::Mul => {
405 let a = inputs.first()?.as_float()?;
406 let b = inputs.get(1)?.as_float()?;
407 Some(vec![ParamValue::Float(a * b)])
408 }
409 NodeType::Div => {
410 let a = inputs.first()?.as_float()?;
411 let b = inputs.get(1)?.as_float()?;
412 if b.abs() < 1e-10 { return None; }
413 Some(vec![ParamValue::Float(a / b)])
414 }
415 NodeType::Abs => {
416 let x = inputs.first()?.as_float()?;
417 Some(vec![ParamValue::Float(x.abs())])
418 }
419 NodeType::Floor => {
420 let x = inputs.first()?.as_float()?;
421 Some(vec![ParamValue::Float(x.floor())])
422 }
423 NodeType::Ceil => {
424 let x = inputs.first()?.as_float()?;
425 Some(vec![ParamValue::Float(x.ceil())])
426 }
427 NodeType::Fract => {
428 let x = inputs.first()?.as_float()?;
429 Some(vec![ParamValue::Float(x.fract())])
430 }
431 NodeType::Mod => {
432 let x = inputs.first()?.as_float()?;
433 let y = inputs.get(1)?.as_float()?;
434 if y.abs() < 1e-10 { return None; }
435 Some(vec![ParamValue::Float(x % y)])
436 }
437 NodeType::Pow => {
438 let base = inputs.first()?.as_float()?;
439 let exp = inputs.get(1)?.as_float()?;
440 Some(vec![ParamValue::Float(base.max(0.0).powf(exp))])
441 }
442 NodeType::Sqrt => {
443 let x = inputs.first()?.as_float()?;
444 Some(vec![ParamValue::Float(x.max(0.0).sqrt())])
445 }
446 NodeType::Sin => {
447 let x = inputs.first()?.as_float()?;
448 Some(vec![ParamValue::Float(x.sin())])
449 }
450 NodeType::Cos => {
451 let x = inputs.first()?.as_float()?;
452 Some(vec![ParamValue::Float(x.cos())])
453 }
454 NodeType::Tan => {
455 let x = inputs.first()?.as_float()?;
456 Some(vec![ParamValue::Float(x.tan())])
457 }
458 NodeType::Atan2 => {
459 let y = inputs.first()?.as_float()?;
460 let x = inputs.get(1)?.as_float()?;
461 Some(vec![ParamValue::Float(y.atan2(x))])
462 }
463 NodeType::Lerp => {
464 let a = inputs.first()?.as_float()?;
465 let b = inputs.get(1)?.as_float()?;
466 let t = inputs.get(2)?.as_float()?;
467 Some(vec![ParamValue::Float(a + (b - a) * t)])
468 }
469 NodeType::Clamp => {
470 let x = inputs.first()?.as_float()?;
471 let lo = inputs.get(1)?.as_float()?;
472 let hi = inputs.get(2)?.as_float()?;
473 Some(vec![ParamValue::Float(x.clamp(lo, hi))])
474 }
475 NodeType::Smoothstep => {
476 let e0 = inputs.first()?.as_float()?;
477 let e1 = inputs.get(1)?.as_float()?;
478 let x = inputs.get(2)?.as_float()?;
479 let range = e1 - e0;
480 if range.abs() < 1e-10 {
481 return Some(vec![ParamValue::Float(if x < e0 { 0.0 } else { 1.0 })]);
482 }
483 let t = ((x - e0) / range).clamp(0.0, 1.0);
484 Some(vec![ParamValue::Float(t * t * (3.0 - 2.0 * t))])
485 }
486 NodeType::Remap => {
487 let x = inputs.first()?.as_float()?;
488 let in_min = inputs.get(1)?.as_float()?;
489 let in_max = inputs.get(2)?.as_float()?;
490 let out_min = inputs.get(3)?.as_float()?;
491 let out_max = inputs.get(4)?.as_float()?;
492 let range = in_max - in_min;
493 if range.abs() < 1e-10 { return None; }
494 let t = (x - in_min) / range;
495 Some(vec![ParamValue::Float(out_min + (out_max - out_min) * t)])
496 }
497 NodeType::Step => {
498 let edge = inputs.first()?.as_float()?;
499 let x = inputs.get(1)?.as_float()?;
500 Some(vec![ParamValue::Float(if x >= edge { 1.0 } else { 0.0 })])
501 }
502 NodeType::Invert => {
503 let c = inputs.first()?.as_vec3()?;
504 Some(vec![ParamValue::Vec3([1.0 - c[0], 1.0 - c[1], 1.0 - c[2]])])
505 }
506 NodeType::Posterize => {
507 let c = inputs.first()?.as_vec3()?;
508 let levels = inputs.get(1)?.as_float()?;
509 if levels < 1.0 { return None; }
510 Some(vec![ParamValue::Vec3([
511 (c[0] * levels).floor() / levels,
512 (c[1] * levels).floor() / levels,
513 (c[2] * levels).floor() / levels,
514 ])])
515 }
516 NodeType::Contrast => {
517 let c = inputs.first()?.as_vec3()?;
518 let amount = inputs.get(1)?.as_float()?;
519 Some(vec![ParamValue::Vec3([
520 (c[0] - 0.5) * amount + 0.5,
521 (c[1] - 0.5) * amount + 0.5,
522 (c[2] - 0.5) * amount + 0.5,
523 ])])
524 }
525 NodeType::Saturation => {
526 let c = inputs.first()?.as_vec3()?;
527 let amount = inputs.get(1)?.as_float()?;
528 let lum = c[0] * 0.299 + c[1] * 0.587 + c[2] * 0.114;
529 Some(vec![ParamValue::Vec3([
530 lum + (c[0] - lum) * amount,
531 lum + (c[1] - lum) * amount,
532 lum + (c[2] - lum) * amount,
533 ])])
534 }
535 _ => None, }
537 }
538
539 fn find_common_subexpressions(
545 &self,
546 graph: &ShaderGraph,
547 sorted: &[NodeId],
548 ) -> HashMap<NodeId, NodeId> {
549 let mut cse_map: HashMap<NodeId, NodeId> = HashMap::new();
550 let mut signatures: HashMap<String, NodeId> = HashMap::new();
552
553 for &node_id in sorted {
554 let node = match graph.node(node_id) {
555 Some(n) => n,
556 None => continue,
557 };
558
559 let incoming = graph.incoming_connections(node_id);
561 let mut sig_parts: Vec<String> = vec![node.node_type.display_name().to_string()];
562
563 for (idx, socket) in node.inputs.iter().enumerate() {
564 let conn = incoming.iter().find(|c| c.to_socket == idx);
565 if let Some(c) = conn {
566 let resolved = cse_map.get(&c.from_node).copied().unwrap_or(c.from_node);
568 sig_parts.push(format!("c{}:{}", resolved.0, c.from_socket));
569 } else if let Some(def) = &socket.default_value {
570 sig_parts.push(format!("d:{}", def.to_glsl()));
571 } else {
572 sig_parts.push("none".to_string());
573 }
574 }
575
576 let signature = sig_parts.join("|");
577
578 if let Some(&canonical) = signatures.get(&signature) {
579 cse_map.insert(node_id, canonical);
580 } else {
581 signatures.insert(signature, node_id);
582 }
583 }
584
585 cse_map
586 }
587
588 fn collect_declarations(
593 &self,
594 graph: &ShaderGraph,
595 sorted: &[NodeId],
596 ) -> (Vec<UniformDecl>, Vec<VaryingDecl>) {
597 let mut uniforms: Vec<UniformDecl> = Vec::new();
598 let mut uniform_names: HashSet<String> = HashSet::new();
599 let mut varyings: Vec<VaryingDecl> = Vec::new();
600 let mut varying_names: HashSet<String> = HashSet::new();
601
602 let standard_uniforms = vec![
604 ("u_time", DataType::Float, true),
605 ("u_model", DataType::Mat4, false),
606 ("u_view", DataType::Mat4, false),
607 ("u_projection", DataType::Mat4, false),
608 ("u_camera_pos", DataType::Vec3, false),
609 ("u_inv_model", DataType::Mat4, false),
610 ];
611 for (name, dt, animated) in standard_uniforms {
612 if uniform_names.insert(name.to_string()) {
613 uniforms.push(UniformDecl {
614 name: name.to_string(),
615 data_type: dt,
616 default_value: None,
617 is_animated: animated,
618 });
619 }
620 }
621
622 let standard_varyings = vec![
624 ("v_position", DataType::Vec3),
625 ("v_normal", DataType::Vec3),
626 ("v_uv", DataType::Vec2),
627 ];
628 for (name, dt) in standard_varyings {
629 if varying_names.insert(name.to_string()) {
630 varyings.push(VaryingDecl { name: name.to_string(), data_type: dt });
631 }
632 }
633
634 for &node_id in sorted {
635 let node = match graph.node(node_id) {
636 Some(n) => n,
637 None => continue,
638 };
639
640 match &node.node_type {
641 NodeType::Texture => {
642 let sampler_idx = node.inputs.get(1)
644 .and_then(|s| s.default_value.as_ref())
645 .and_then(|v| v.as_int())
646 .unwrap_or(0);
647 let name = format!("u_texture{}", sampler_idx);
648 if uniform_names.insert(name.clone()) {
649 uniforms.push(UniformDecl {
650 name,
651 data_type: DataType::Sampler2D,
652 default_value: None,
653 is_animated: false,
654 });
655 }
656 }
657 NodeType::GameStateVar => {
658 let var_name = node.inputs.first()
660 .and_then(|s| s.default_value.as_ref())
661 .and_then(|v| v.as_string())
662 .unwrap_or("game_var_0");
663 let name = format!("u_gs_{}", var_name);
664 if uniform_names.insert(name.clone()) {
665 uniforms.push(UniformDecl {
666 name,
667 data_type: DataType::Float,
668 default_value: Some(ParamValue::Float(0.0)),
669 is_animated: false,
670 });
671 }
672 }
673 _ => {}
674 }
675
676 if let Some(ref var_name) = node.conditional_var {
678 let name = format!("u_gs_{}", var_name);
679 if uniform_names.insert(name.clone()) {
680 uniforms.push(UniformDecl {
681 name,
682 data_type: DataType::Float,
683 default_value: Some(ParamValue::Float(0.0)),
684 is_animated: false,
685 });
686 }
687 }
688
689 for (key, val) in &node.properties {
691 if key.starts_with("uniform_") {
692 let name = format!("u_prop_{}_{}", node.id.0, key.trim_start_matches("uniform_"));
693 if uniform_names.insert(name.clone()) {
694 uniforms.push(UniformDecl {
695 name,
696 data_type: val.data_type(),
697 default_value: Some(val.clone()),
698 is_animated: self.options.animated_uniforms,
699 });
700 }
701 }
702 }
703 }
704
705 (uniforms, varyings)
706 }
707
708 fn generate_glsl(
713 &self,
714 graph: &ShaderGraph,
715 sorted: &[NodeId],
716 folded: &HashMap<NodeId, Vec<ParamValue>>,
717 cse_map: &HashMap<NodeId, NodeId>,
718 uniforms: &[UniformDecl],
719 varyings: &[VaryingDecl],
720 ) -> (String, HashMap<(u64, usize), String>) {
721 let mut code = String::new();
722 let mut output_var_map: HashMap<(u64, usize), String> = HashMap::new();
723
724 code.push_str(&format!("#version {}\n", self.options.glsl_version));
726 code.push_str("precision highp float;\n\n");
727
728 for u in uniforms {
730 code.push_str(&format!("uniform {} {};\n", u.data_type, u.name));
731 }
732 code.push('\n');
733
734 for v in varyings {
736 code.push_str(&format!("in {} {};\n", v.data_type, v.name));
737 }
738 code.push('\n');
739
740 code.push_str("layout(location = 0) out vec4 fragColor;\n");
742 code.push_str("layout(location = 1) out vec4 fragEmission;\n");
743 code.push_str("layout(location = 2) out vec4 fragBloom;\n");
744 code.push_str("layout(location = 3) out vec4 fragNormal;\n");
745 code.push('\n');
746
747 code.push_str("void main() {\n");
749
750 let mut emitted_cse: HashSet<NodeId> = HashSet::new();
752
753 for &node_id in sorted {
754 if let Some(&canonical) = cse_map.get(&node_id) {
756 if let Some(node) = graph.node(node_id) {
758 for (idx, _) in node.outputs.iter().enumerate() {
759 if let Some(var) = output_var_map.get(&(canonical.0, idx)) {
760 output_var_map.insert((node_id.0, idx), var.clone());
761 }
762 }
763 }
764 continue;
765 }
766
767 let node = match graph.node(node_id) {
768 Some(n) => n,
769 None => continue,
770 };
771
772 if !node.enabled {
773 continue;
774 }
775
776 if let Some(folded_vals) = folded.get(&node_id) {
778 if self.options.debug_comments {
779 code.push_str(&format!(" // [FOLDED] {} (node {})\n",
780 node.node_type.display_name(), node_id.0));
781 }
782 for (idx, val) in folded_vals.iter().enumerate() {
783 let var_name = format!("n{}_{}", node_id.0, idx);
784 code.push_str(&format!(" {} {} = {};\n",
785 val.data_type(), var_name, val.to_glsl()));
786 output_var_map.insert((node_id.0, idx), var_name);
787 }
788 continue;
789 }
790
791 if self.options.debug_comments {
793 code.push_str(&format!(" // {} (node {})\n",
794 node.node_type.display_name(), node_id.0));
795 }
796
797 let has_condition = self.options.enable_conditionals && node.conditional_var.is_some();
799 if has_condition {
800 let var_name = node.conditional_var.as_ref().unwrap();
801 code.push_str(&format!(" if (u_gs_{} > {}) {{\n",
802 var_name, format_float_glsl(node.conditional_threshold)));
803 }
804
805 let incoming = graph.incoming_connections(node_id);
807 let mut input_vars: Vec<String> = Vec::new();
808 for (idx, socket) in node.inputs.iter().enumerate() {
809 let conn = incoming.iter().find(|c| c.to_socket == idx);
810 if let Some(c) = conn {
811 let resolved_from = cse_map.get(&c.from_node).copied().unwrap_or(c.from_node);
812 if let Some(var) = output_var_map.get(&(resolved_from.0, c.from_socket)) {
813 input_vars.push(var.clone());
814 } else {
815 input_vars.push(socket.default_value.as_ref()
817 .map(|v| v.to_glsl())
818 .unwrap_or_default());
819 }
820 } else {
821 input_vars.push(String::new());
822 }
823 }
824
825 let prefix = node.var_prefix();
827 let snippet = node.node_type.generate_glsl(&prefix, &input_vars);
828
829 let indent = if has_condition { " " } else { " " };
830 for line in &snippet.lines {
831 code.push_str(&format!("{}{}\n", indent, line));
832 }
833
834 for (idx, var) in snippet.output_vars.iter().enumerate() {
836 output_var_map.insert((node_id.0, idx), var.clone());
837 }
838
839 let _ = emitted_cse.insert(node_id);
840
841 if has_condition {
843 code.push_str(" }\n");
844 }
845 }
846
847 code.push_str("}\n");
848
849 (code, output_var_map)
850 }
851
852 fn generate_vertex_shader(&self, varyings: &[VaryingDecl]) -> String {
853 let mut code = String::new();
854 code.push_str(&format!("#version {}\n", self.options.glsl_version));
855 code.push_str("precision highp float;\n\n");
856
857 code.push_str("layout(location = 0) in vec3 a_position;\n");
859 code.push_str("layout(location = 1) in vec3 a_normal;\n");
860 code.push_str("layout(location = 2) in vec2 a_uv;\n\n");
861
862 code.push_str("uniform mat4 u_model;\n");
864 code.push_str("uniform mat4 u_view;\n");
865 code.push_str("uniform mat4 u_projection;\n\n");
866
867 for v in varyings {
869 code.push_str(&format!("out {} {};\n", v.data_type, v.name));
870 }
871 code.push('\n');
872
873 code.push_str("void main() {\n");
874 code.push_str(" vec4 world_pos = u_model * vec4(a_position, 1.0);\n");
875 code.push_str(" v_position = world_pos.xyz;\n");
876 code.push_str(" v_normal = normalize((u_model * vec4(a_normal, 0.0)).xyz);\n");
877 code.push_str(" v_uv = a_uv;\n");
878 code.push_str(" gl_Position = u_projection * u_view * world_pos;\n");
879 code.push_str("}\n");
880
881 code
882 }
883}
884
/// Formats `v` as a GLSL float literal.
///
/// Finite values always contain a decimal point (`3.0`, not `3`), so GLSL
/// never reads them as integer literals, whatever their magnitude.
/// Non-finite values are returned as Rust prints them (`inf`, `NaN`), which
/// is not valid GLSL; callers are expected to pass finite values.
fn format_float_glsl(v: f32) -> String {
    // `Display` for floats prints plain decimal notation and omits the
    // fraction for integer-valued numbers.
    let mut text = v.to_string();
    if v.is_finite() && !text.contains('.') {
        text.push_str(".0");
    }
    text
}
892
893pub fn compile_graph(graph: &ShaderGraph) -> Result<CompiledShader, CompileError> {
899 ShaderCompiler::with_defaults().compile(graph)
900}
901
902pub fn compile_graph_with(graph: &ShaderGraph, options: CompileOptions) -> Result<CompiledShader, CompileError> {
904 ShaderCompiler::new(options).compile(graph)
905}
906
907pub fn types_compatible(from: DataType, to: DataType) -> bool {
913 if from == to {
914 return true;
915 }
916 matches!((from, to),
918 (DataType::Float, DataType::Vec2)
919 | (DataType::Float, DataType::Vec3)
920 | (DataType::Float, DataType::Vec4)
921 | (DataType::Int, DataType::Float)
922 | (DataType::Bool, DataType::Float)
923 | (DataType::Bool, DataType::Int)
924 )
925}
926
927pub fn generate_cast(expr: &str, from: DataType, to: DataType) -> String {
929 if from == to {
930 return expr.to_string();
931 }
932 match (from, to) {
933 (DataType::Float, DataType::Vec2) => format!("vec2({})", expr),
934 (DataType::Float, DataType::Vec3) => format!("vec3({})", expr),
935 (DataType::Float, DataType::Vec4) => format!("vec4({})", expr),
936 (DataType::Int, DataType::Float) => format!("float({})", expr),
937 (DataType::Bool, DataType::Float) => format!("float({})", expr),
938 (DataType::Bool, DataType::Int) => format!("int({})", expr),
939 (DataType::Vec2, DataType::Vec3) => format!("vec3({}, 0.0)", expr),
940 (DataType::Vec2, DataType::Vec4) => format!("vec4({}, 0.0, 1.0)", expr),
941 (DataType::Vec3, DataType::Vec4) => format!("vec4({}, 1.0)", expr),
942 (DataType::Vec4, DataType::Vec3) => format!("{}.xyz", expr),
943 (DataType::Vec3, DataType::Vec2) => format!("{}.xy", expr),
944 (DataType::Vec4, DataType::Vec2) => format!("{}.xy", expr),
945 (DataType::Vec3, DataType::Float) => format!("length({})", expr),
946 (DataType::Vec4, DataType::Float) => format!("{}.x", expr),
947 _ => format!("{}({})", to, expr), }
949}
950
951pub struct ShaderVariantCache {
957 cache: HashMap<u64, CompiledShader>,
958}
959
960impl ShaderVariantCache {
961 pub fn new() -> Self {
962 Self { cache: HashMap::new() }
963 }
964
965 pub fn get_or_compile(
967 &mut self,
968 graph: &ShaderGraph,
969 compiler: &ShaderCompiler,
970 ) -> Result<&CompiledShader, CompileError> {
971 let hash = graph.topology_hash();
972 if !self.cache.contains_key(&hash) {
973 let compiled = compiler.compile(graph)?;
974 self.cache.insert(hash, compiled);
975 }
976 Ok(self.cache.get(&hash).unwrap())
977 }
978
979 pub fn invalidate(&mut self, hash: u64) {
981 self.cache.remove(&hash);
982 }
983
984 pub fn clear(&mut self) {
986 self.cache.clear();
987 }
988
989 pub fn len(&self) -> usize {
991 self.cache.len()
992 }
993
994 pub fn is_empty(&self) -> bool {
996 self.cache.is_empty()
997 }
998}
999
1000impl Default for ShaderVariantCache {
1001 fn default() -> Self {
1002 Self::new()
1003 }
1004}