1use std::collections::{HashMap, HashSet, VecDeque};
6use std::fmt;
7
8use crate::rendergraph::resources::{
9 ResourceDescriptor, ResourceHandle, ResourceLifetime, ResourceTable, SizePolicy, TextureFormat,
10};
11
/// The category of GPU work a [`RenderPass`] records.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PassType {
    /// Rasterization work.
    Graphics,
    /// Compute dispatches.
    Compute,
    /// Copy / transfer operations.
    Transfer,
    /// Presentation to the swapchain.
    Present,
}
24
/// The hardware queue a pass prefers to be submitted on.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum QueueAffinity {
    /// Main graphics queue.
    Graphics,
    /// Dedicated compute queue.
    Compute,
    /// Dedicated transfer queue.
    Transfer,
    /// No preference.
    Any,
}
33
/// A predicate deciding whether a pass is active for the current frame.
#[derive(Debug, Clone)]
pub enum PassCondition {
    /// Always active.
    Always,
    /// Active when the named feature is in the enabled feature set.
    FeatureEnabled(String),
    /// Active when the named callback value is `true`; unset callbacks count as `false`.
    Callback(String),
    /// Active when every child is active (an empty list is active).
    All(Vec<PassCondition>),
    /// Active when at least one child is active (an empty list is inactive).
    Any(Vec<PassCondition>),
}

impl PassCondition {
    /// Evaluates the condition against the enabled `features` and recorded `callbacks`.
    pub fn evaluate(&self, features: &HashSet<String>, callbacks: &HashMap<String, bool>) -> bool {
        let check = |child: &PassCondition| child.evaluate(features, callbacks);
        match self {
            PassCondition::Always => true,
            PassCondition::FeatureEnabled(feature) => features.contains(feature),
            PassCondition::Callback(key) => callbacks.get(key) == Some(&true),
            PassCondition::All(children) => children.iter().all(check),
            PassCondition::Any(children) => children.iter().any(check),
        }
    }
}
61
/// Per-axis scale factors relative to full resolution (`1.0` means full size).
#[derive(Debug, Clone, Copy)]
pub struct ResolutionScale {
    pub width_scale: f32,
    pub height_scale: f32,
}

impl ResolutionScale {
    /// Scale of 1.0 on both axes.
    pub fn full() -> Self {
        Self::custom(1.0, 1.0)
    }

    /// Scale of 0.5 on both axes.
    pub fn half() -> Self {
        Self::custom(0.5, 0.5)
    }

    /// Scale of 0.25 on both axes.
    pub fn quarter() -> Self {
        Self::custom(0.25, 0.25)
    }

    /// Arbitrary width (`w`) and height (`h`) scale factors.
    pub fn custom(w: f32, h: f32) -> Self {
        Self {
            width_scale: w,
            height_scale: h,
        }
    }
}
95
96#[derive(Debug, Clone)]
102pub struct RenderPass {
103 pub name: String,
104 pub pass_type: PassType,
105 pub queue: QueueAffinity,
106 pub condition: PassCondition,
107 pub resolution: ResolutionScale,
108 pub inputs: Vec<ResourceHandle>,
110 pub outputs: Vec<ResourceHandle>,
112 pub input_names: Vec<String>,
114 pub output_names: Vec<String>,
116 pub explicit_deps: Vec<String>,
118 pub has_side_effects: bool,
120 pub tag: Option<String>,
122}
123
124impl RenderPass {
125 pub fn new(name: &str, pass_type: PassType) -> Self {
126 Self {
127 name: name.to_string(),
128 pass_type,
129 queue: QueueAffinity::Graphics,
130 condition: PassCondition::Always,
131 resolution: ResolutionScale::full(),
132 inputs: Vec::new(),
133 outputs: Vec::new(),
134 input_names: Vec::new(),
135 output_names: Vec::new(),
136 explicit_deps: Vec::new(),
137 has_side_effects: false,
138 tag: None,
139 }
140 }
141
142 pub fn with_queue(mut self, queue: QueueAffinity) -> Self {
143 self.queue = queue;
144 self
145 }
146
147 pub fn with_condition(mut self, condition: PassCondition) -> Self {
148 self.condition = condition;
149 self
150 }
151
152 pub fn with_resolution(mut self, scale: ResolutionScale) -> Self {
153 self.resolution = scale;
154 self
155 }
156
157 pub fn with_side_effects(mut self) -> Self {
158 self.has_side_effects = true;
159 self
160 }
161
162 pub fn with_tag(mut self, tag: &str) -> Self {
163 self.tag = Some(tag.to_string());
164 self
165 }
166
167 pub fn add_input(&mut self, handle: ResourceHandle, name: &str) {
168 self.inputs.push(handle);
169 self.input_names.push(name.to_string());
170 }
171
172 pub fn add_output(&mut self, handle: ResourceHandle, name: &str) {
173 self.outputs.push(handle);
174 self.output_names.push(name.to_string());
175 }
176
177 pub fn depends_on(&mut self, pass_name: &str) {
178 if !self.explicit_deps.contains(&pass_name.to_string()) {
179 self.explicit_deps.push(pass_name.to_string());
180 }
181 }
182
183 pub fn is_async_compute_candidate(&self) -> bool {
185 self.pass_type == PassType::Compute && self.queue != QueueAffinity::Graphics
186 }
187}
188
189#[derive(Debug, Clone)]
196pub struct ResourceNode {
197 pub name: String,
198 pub handle: ResourceHandle,
199 pub descriptor: ResourceDescriptor,
200 pub lifetime: ResourceLifetime,
201 pub producer: Option<String>,
203 pub consumers: Vec<String>,
205}
206
207impl ResourceNode {
208 pub fn new(name: &str, handle: ResourceHandle, descriptor: ResourceDescriptor, lifetime: ResourceLifetime) -> Self {
209 Self {
210 name: name.to_string(),
211 handle,
212 descriptor,
213 lifetime,
214 producer: None,
215 consumers: Vec::new(),
216 }
217 }
218}
219
/// A directed ordering edge: `from_pass` must execute before `to_pass`.
#[derive(Debug, Clone)]
pub struct PassDependency {
    pub from_pass: String,
    pub to_pass: String,
    /// Resource that induced the edge; empty for explicit dependencies.
    pub resource: String,
    pub kind: DependencyKind,
}

/// Why a [`PassDependency`] exists.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DependencyKind {
    /// The target reads what the source wrote.
    ReadAfterWrite,
    /// Both passes write the same resource.
    WriteAfterWrite,
    /// The target writes what the source read.
    WriteAfterRead,
    /// Declared directly via `depends_on`.
    Explicit,
}

impl fmt::Display for PassDependency {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "{from} -> {to} (via '{res}', {kind:?})",
            from = self.from_pass,
            to = self.to_pass,
            res = self.resource,
            kind = self.kind,
        )
    }
}
255
/// Findings of a graph validation: `errors` make the result fail, `warnings` are advisory.
#[derive(Debug, Clone, Default)]
pub struct ValidationResult {
    pub errors: Vec<String>,
    pub warnings: Vec<String>,
}

impl ValidationResult {
    /// An empty (passing) result.
    pub fn new() -> Self {
        Self::default()
    }

    /// True when no errors were recorded; warnings do not affect this.
    pub fn is_ok(&self) -> bool {
        self.errors.is_empty()
    }

    /// Appends an error message.
    pub fn error(&mut self, msg: impl Into<String>) {
        let text: String = msg.into();
        self.errors.push(text);
    }

    /// Appends a warning message.
    pub fn warning(&mut self, msg: impl Into<String>) {
        let text: String = msg.into();
        self.warnings.push(text);
    }
}

impl fmt::Display for ValidationResult {
    /// A status header, then one line per error, then one line per warning.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.errors.len() {
            0 => f.write_str("Validation OK")?,
            count => write!(f, "Validation FAILED ({} errors)", count)?,
        }
        self.errors
            .iter()
            .try_for_each(|e| write!(f, "\n ERROR: {}", e))?;
        self.warnings
            .iter()
            .try_for_each(|w| write!(f, "\n WARN: {}", w))
    }
}
310
311pub struct RenderGraph {
319 passes: HashMap<String, RenderPass>,
321 pass_order: Vec<String>,
323 resource_nodes: HashMap<String, ResourceNode>,
325 edges: Vec<PassDependency>,
327 sorted_passes: Vec<String>,
329 dirty: bool,
331 pub resource_table: ResourceTable,
333 features: HashSet<String>,
335 callback_values: HashMap<String, bool>,
337 label: String,
339}
340
341impl RenderGraph {
342 pub fn new(label: &str) -> Self {
343 Self {
344 passes: HashMap::new(),
345 pass_order: Vec::new(),
346 resource_nodes: HashMap::new(),
347 edges: Vec::new(),
348 sorted_passes: Vec::new(),
349 dirty: true,
350 resource_table: ResourceTable::new(),
351 features: HashSet::new(),
352 callback_values: HashMap::new(),
353 label: label.to_string(),
354 }
355 }
356
357 pub fn enable_feature(&mut self, feature: &str) {
360 self.features.insert(feature.to_string());
361 }
362
363 pub fn disable_feature(&mut self, feature: &str) {
364 self.features.remove(feature);
365 }
366
367 pub fn is_feature_enabled(&self, feature: &str) -> bool {
368 self.features.contains(feature)
369 }
370
371 pub fn set_callback(&mut self, name: &str, value: bool) {
372 self.callback_values.insert(name.to_string(), value);
373 }
374
375 pub fn declare_resource(&mut self, descriptor: ResourceDescriptor) -> ResourceHandle {
379 let name = descriptor.name.clone();
380 let handle = self.resource_table.declare_transient(descriptor.clone());
381 self.resource_nodes
382 .entry(name.clone())
383 .or_insert_with(|| ResourceNode::new(&name, handle, descriptor, ResourceLifetime::Transient));
384 self.dirty = true;
385 handle
386 }
387
388 pub fn import_resource(&mut self, descriptor: ResourceDescriptor) -> ResourceHandle {
390 let name = descriptor.name.clone();
391 let handle = self.resource_table.declare_imported(descriptor.clone());
392 self.resource_nodes
393 .entry(name.clone())
394 .or_insert_with(|| ResourceNode::new(&name, handle, descriptor, ResourceLifetime::Imported));
395 self.dirty = true;
396 handle
397 }
398
399 pub fn add_pass(&mut self, pass: RenderPass) {
403 let name = pass.name.clone();
404 for (h, rname) in pass.outputs.iter().zip(pass.output_names.iter()) {
406 self.resource_table.add_writer(*h, &name);
407 if let Some(rn) = self.resource_nodes.get_mut(rname) {
408 rn.producer = Some(name.clone());
409 }
410 }
411 for (h, rname) in pass.inputs.iter().zip(pass.input_names.iter()) {
412 self.resource_table.add_reader(*h, &name);
413 if let Some(rn) = self.resource_nodes.get_mut(rname) {
414 if !rn.consumers.contains(&name) {
415 rn.consumers.push(name.clone());
416 }
417 }
418 }
419 if !self.pass_order.contains(&name) {
420 self.pass_order.push(name.clone());
421 }
422 self.passes.insert(name, pass);
423 self.dirty = true;
424 }
425
426 pub fn remove_pass(&mut self, name: &str) -> Option<RenderPass> {
428 self.pass_order.retain(|n| n != name);
429 self.dirty = true;
430 self.passes.remove(name)
431 }
432
433 pub fn get_pass(&self, name: &str) -> Option<&RenderPass> {
435 self.passes.get(name)
436 }
437
438 pub fn get_pass_mut(&mut self, name: &str) -> Option<&mut RenderPass> {
440 self.dirty = true;
441 self.passes.get_mut(name)
442 }
443
444 pub fn pass_names(&self) -> &[String] {
446 &self.pass_order
447 }
448
449 pub fn pass_count(&self) -> usize {
451 self.passes.len()
452 }
453
454 pub fn resource_count(&self) -> usize {
456 self.resource_nodes.len()
457 }
458
459 pub fn build_edges(&mut self) {
463 self.edges.clear();
464
465 for (res_name, rn) in &self.resource_nodes {
467 if let Some(ref producer) = rn.producer {
468 for consumer in &rn.consumers {
469 if producer != consumer {
470 self.edges.push(PassDependency {
471 from_pass: producer.clone(),
472 to_pass: consumer.clone(),
473 resource: res_name.clone(),
474 kind: DependencyKind::ReadAfterWrite,
475 });
476 }
477 }
478 }
479 }
480
481 let pass_names: Vec<String> = self.passes.keys().cloned().collect();
483 for name in &pass_names {
484 let deps = self.passes[name].explicit_deps.clone();
485 for dep in deps {
486 if self.passes.contains_key(&dep) {
487 self.edges.push(PassDependency {
488 from_pass: dep,
489 to_pass: name.clone(),
490 resource: String::new(),
491 kind: DependencyKind::Explicit,
492 });
493 }
494 }
495 }
496 }
497
498 pub fn edges(&self) -> &[PassDependency] {
500 &self.edges
501 }
502
503 pub fn detect_cycles(&mut self) -> Vec<Vec<String>> {
508 self.build_edges();
509
510 let mut adj: HashMap<&str, Vec<&str>> = HashMap::new();
511 for name in self.passes.keys() {
512 adj.entry(name.as_str()).or_default();
513 }
514 for edge in &self.edges {
515 adj.entry(edge.from_pass.as_str())
516 .or_default()
517 .push(edge.to_pass.as_str());
518 }
519
520 let mut index_counter: u32 = 0;
522 let mut stack: Vec<&str> = Vec::new();
523 let mut on_stack: HashSet<&str> = HashSet::new();
524 let mut indices: HashMap<&str, u32> = HashMap::new();
525 let mut lowlinks: HashMap<&str, u32> = HashMap::new();
526 let mut sccs: Vec<Vec<String>> = Vec::new();
527
528 fn strongconnect<'a>(
529 v: &'a str,
530 adj: &HashMap<&'a str, Vec<&'a str>>,
531 index_counter: &mut u32,
532 stack: &mut Vec<&'a str>,
533 on_stack: &mut HashSet<&'a str>,
534 indices: &mut HashMap<&'a str, u32>,
535 lowlinks: &mut HashMap<&'a str, u32>,
536 sccs: &mut Vec<Vec<String>>,
537 ) {
538 indices.insert(v, *index_counter);
539 lowlinks.insert(v, *index_counter);
540 *index_counter += 1;
541 stack.push(v);
542 on_stack.insert(v);
543
544 if let Some(neighbors) = adj.get(v) {
545 for &w in neighbors {
546 if !indices.contains_key(w) {
547 strongconnect(w, adj, index_counter, stack, on_stack, indices, lowlinks, sccs);
548 let lw = lowlinks[w];
549 let lv = lowlinks[v];
550 lowlinks.insert(v, lv.min(lw));
551 } else if on_stack.contains(w) {
552 let iw = indices[w];
553 let lv = lowlinks[v];
554 lowlinks.insert(v, lv.min(iw));
555 }
556 }
557 }
558
559 if lowlinks[v] == indices[v] {
560 let mut scc = Vec::new();
561 while let Some(w) = stack.pop() {
562 on_stack.remove(w);
563 scc.push(w.to_string());
564 if w == v {
565 break;
566 }
567 }
568 if scc.len() > 1 {
569 sccs.push(scc);
570 }
571 }
572 }
573
574 let nodes: Vec<&str> = adj.keys().copied().collect();
575 for node in nodes {
576 if !indices.contains_key(node) {
577 strongconnect(
578 node,
579 &adj,
580 &mut index_counter,
581 &mut stack,
582 &mut on_stack,
583 &mut indices,
584 &mut lowlinks,
585 &mut sccs,
586 );
587 }
588 }
589
590 sccs
591 }
592
593 pub fn topological_sort(&mut self) -> Result<Vec<String>, Vec<String>> {
596 self.build_edges();
597
598 let mut in_degree: HashMap<&str, usize> = HashMap::new();
599 for name in self.passes.keys() {
600 in_degree.entry(name.as_str()).or_insert(0);
601 }
602 let mut adj: HashMap<&str, Vec<&str>> = HashMap::new();
603 for edge in &self.edges {
604 adj.entry(edge.from_pass.as_str())
605 .or_default()
606 .push(edge.to_pass.as_str());
607 *in_degree.entry(edge.to_pass.as_str()).or_insert(0) += 1;
608 }
609
610 let mut queue: VecDeque<&str> = VecDeque::new();
611 for (&node, °) in &in_degree {
612 if deg == 0 {
613 queue.push_back(node);
614 }
615 }
616
617 let order_map: HashMap<&str, usize> = self
619 .pass_order
620 .iter()
621 .enumerate()
622 .map(|(i, n)| (n.as_str(), i))
623 .collect();
624 let mut initial: Vec<&str> = queue.drain(..).collect();
625 initial.sort_by_key(|n| order_map.get(n).copied().unwrap_or(usize::MAX));
626 for n in initial {
627 queue.push_back(n);
628 }
629
630 let mut sorted: Vec<String> = Vec::new();
631 let mut visited = 0usize;
632
633 while let Some(node) = queue.pop_front() {
634 sorted.push(node.to_string());
635 visited += 1;
636 if let Some(neighbors) = adj.get(node) {
637 let mut next: Vec<&str> = Vec::new();
639 for &nb in neighbors {
640 let deg = in_degree.get_mut(nb).unwrap();
641 *deg -= 1;
642 if *deg == 0 {
643 next.push(nb);
644 }
645 }
646 next.sort_by_key(|n| order_map.get(n).copied().unwrap_or(usize::MAX));
647 for nb in next {
648 queue.push_back(nb);
649 }
650 }
651 }
652
653 if visited != self.passes.len() {
654 let sorted_set: HashSet<&str> = sorted.iter().map(|s| s.as_str()).collect();
656 let cycle_nodes: Vec<String> = self
657 .passes
658 .keys()
659 .filter(|k| !sorted_set.contains(k.as_str()))
660 .cloned()
661 .collect();
662 return Err(cycle_nodes);
663 }
664
665 self.sorted_passes = sorted.clone();
666 self.dirty = false;
667 Ok(sorted)
668 }
669
670 pub fn sorted(&mut self) -> Result<&[String], Vec<String>> {
672 if self.dirty {
673 self.topological_sort()?;
674 }
675 Ok(&self.sorted_passes)
676 }
677
678 pub fn active_passes(&mut self) -> Result<Vec<String>, Vec<String>> {
680 let sorted = self.topological_sort()?;
681 let features = &self.features;
682 let callbacks = &self.callback_values;
683 Ok(sorted
684 .into_iter()
685 .filter(|name| {
686 self.passes
687 .get(name)
688 .map(|p| p.condition.evaluate(features, callbacks))
689 .unwrap_or(false)
690 })
691 .collect())
692 }
693
694 pub fn validate(&mut self) -> ValidationResult {
699 let mut result = ValidationResult::new();
700
701 let cycles = self.detect_cycles();
703 for cycle in &cycles {
704 result.error(format!("Cycle detected involving passes: {}", cycle.join(", ")));
705 }
706
707 let dangling = self.resource_table.find_dangling();
709 for d in &dangling {
710 match d.kind {
711 crate::rendergraph::resources::DanglingKind::NeverWritten => {
712 result.error(format!("Resource '{}' is never written by any pass", d.name));
713 }
714 crate::rendergraph::resources::DanglingKind::NeverRead => {
715 result.warning(format!("Resource '{}' is never read by any pass", d.name));
716 }
717 }
718 }
719
720 for pass in self.passes.values() {
722 for input_name in &pass.input_names {
723 if self.resource_table.lookup(input_name).is_none() {
724 result.error(format!(
725 "Pass '{}' reads resource '{}' which is not declared",
726 pass.name, input_name
727 ));
728 }
729 }
730 for output_name in &pass.output_names {
731 if self.resource_table.lookup(output_name).is_none() {
732 result.error(format!(
733 "Pass '{}' writes resource '{}' which is not declared",
734 pass.name, output_name
735 ));
736 }
737 }
738 }
739
740 for pass in self.passes.values() {
742 for dep in &pass.explicit_deps {
743 if !self.passes.contains_key(dep) {
744 result.error(format!(
745 "Pass '{}' depends on '{}' which does not exist",
746 pass.name, dep
747 ));
748 }
749 }
750 }
751
752 for pass in self.passes.values() {
754 if pass.inputs.is_empty() && pass.outputs.is_empty() && !pass.has_side_effects {
755 result.warning(format!(
756 "Pass '{}' has no inputs, no outputs, and no side effects",
757 pass.name
758 ));
759 }
760 }
761
762 result
767 }
768
769 pub fn merge(&mut self, other: &RenderGraph, prefix: &str) {
774 for (name, rn) in &other.resource_nodes {
776 let new_name = if self.resource_nodes.contains_key(name) {
777 format!("{}_{}", prefix, name)
778 } else {
779 name.clone()
780 };
781 let mut desc = rn.descriptor.clone();
782 desc.name = new_name.clone();
783 let handle = if rn.lifetime == ResourceLifetime::Imported {
784 self.import_resource(desc)
785 } else {
786 self.declare_resource(desc)
787 };
788 let _ = handle;
790 }
791
792 for (name, pass) in &other.passes {
794 let new_name = if self.passes.contains_key(name) {
795 format!("{}_{}", prefix, name)
796 } else {
797 name.clone()
798 };
799 let mut new_pass = RenderPass::new(&new_name, pass.pass_type);
800 new_pass.queue = pass.queue;
801 new_pass.condition = pass.condition.clone();
802 new_pass.resolution = pass.resolution;
803 new_pass.has_side_effects = pass.has_side_effects;
804 new_pass.tag = pass.tag.clone();
805
806 for iname in &pass.input_names {
808 let mapped = if self.resource_nodes.contains_key(iname) && other.resource_nodes.contains_key(iname) {
809 if self.resource_nodes.contains_key(&format!("{}_{}", prefix, iname)) {
811 format!("{}_{}", prefix, iname)
812 } else {
813 iname.clone()
814 }
815 } else {
816 iname.clone()
817 };
818 if let Some(h) = self.resource_table.lookup(&mapped) {
819 new_pass.add_input(h, &mapped);
820 }
821 }
822 for oname in &pass.output_names {
823 let mapped = if self.resource_nodes.contains_key(oname) && other.resource_nodes.contains_key(oname) {
824 if self.resource_nodes.contains_key(&format!("{}_{}", prefix, oname)) {
825 format!("{}_{}", prefix, oname)
826 } else {
827 oname.clone()
828 }
829 } else {
830 oname.clone()
831 };
832 if let Some(h) = self.resource_table.lookup(&mapped) {
833 new_pass.add_output(h, &mapped);
834 }
835 }
836
837 for dep in &pass.explicit_deps {
839 let mapped_dep = if self.passes.contains_key(dep) && other.passes.contains_key(dep) {
840 format!("{}_{}", prefix, dep)
841 } else {
842 dep.clone()
843 };
844 new_pass.depends_on(&mapped_dep);
845 }
846
847 self.add_pass(new_pass);
848 }
849
850 self.dirty = true;
851 }
852
853 pub fn export_dot(&mut self) -> String {
857 self.build_edges();
859
860 let mut dot = String::new();
861 dot.push_str(&format!("digraph \"{}\" {{\n", self.label));
862 dot.push_str(" rankdir=LR;\n");
863 dot.push_str(" node [shape=box, style=filled];\n\n");
864
865 dot.push_str(" // Render passes\n");
867 for (name, pass) in &self.passes {
868 let color = match pass.pass_type {
869 PassType::Graphics => "#4a90d9",
870 PassType::Compute => "#d94a4a",
871 PassType::Transfer => "#4ad94a",
872 PassType::Present => "#d9d94a",
873 };
874 let active = pass.condition.evaluate(&self.features, &self.callback_values);
875 let style = if active { "filled" } else { "filled,dashed" };
876 let label = format!(
877 "{}\\n[{:?}]{}",
878 name,
879 pass.pass_type,
880 if !active { " (DISABLED)" } else { "" }
881 );
882 dot.push_str(&format!(
883 " \"pass_{}\" [label=\"{}\", fillcolor=\"{}\", style=\"{}\", fontcolor=white];\n",
884 name, label, color, style
885 ));
886 }
887
888 dot.push_str("\n // Resources\n");
890 for (name, rn) in &self.resource_nodes {
891 let shape = match rn.lifetime {
892 ResourceLifetime::Transient => "ellipse",
893 ResourceLifetime::Imported => "diamond",
894 };
895 let label = format!(
896 "{}\\n{:?}",
897 name, rn.descriptor.format
898 );
899 dot.push_str(&format!(
900 " \"res_{}\" [label=\"{}\", shape={}, fillcolor=\"#e0e0e0\", fontcolor=black];\n",
901 name, label, shape
902 ));
903 }
904
905 dot.push_str("\n // Edges\n");
907 for pass in self.passes.values() {
908 for oname in &pass.output_names {
909 dot.push_str(&format!(
910 " \"pass_{}\" -> \"res_{}\" [color=red, label=\"write\"];\n",
911 pass.name, oname
912 ));
913 }
914 for iname in &pass.input_names {
915 dot.push_str(&format!(
916 " \"res_{}\" -> \"pass_{}\" [color=blue, label=\"read\"];\n",
917 iname, pass.name
918 ));
919 }
920 }
921
922 for pass in self.passes.values() {
924 for dep in &pass.explicit_deps {
925 dot.push_str(&format!(
926 " \"pass_{}\" -> \"pass_{}\" [style=dashed, color=gray, label=\"explicit\"];\n",
927 dep, pass.name
928 ));
929 }
930 }
931
932 dot.push_str("}\n");
933 dot
934 }
935
936 pub fn label(&self) -> &str {
939 &self.label
940 }
941
942 pub fn resource_node(&self, name: &str) -> Option<&ResourceNode> {
943 self.resource_nodes.get(name)
944 }
945
946 pub fn all_passes(&self) -> impl Iterator<Item = &RenderPass> {
947 self.passes.values()
948 }
949
950 pub fn all_resource_nodes(&self) -> impl Iterator<Item = &ResourceNode> {
951 self.resource_nodes.values()
952 }
953
954 pub fn features(&self) -> &HashSet<String> {
955 &self.features
956 }
957
958 pub fn passes_by_tag(&self) -> HashMap<String, Vec<&RenderPass>> {
960 let mut map: HashMap<String, Vec<&RenderPass>> = HashMap::new();
961 for pass in self.passes.values() {
962 let tag = pass.tag.clone().unwrap_or_else(|| "untagged".to_string());
963 map.entry(tag).or_default().push(pass);
964 }
965 map
966 }
967}
968
969impl fmt::Display for RenderGraph {
970 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
971 write!(
972 f,
973 "RenderGraph '{}': {} passes, {} resources, {} edges",
974 self.label,
975 self.passes.len(),
976 self.resource_nodes.len(),
977 self.edges.len(),
978 )
979 }
980}
981
982pub struct RenderGraphBuilder {
988 graph: RenderGraph,
989 backbuffer_width: u32,
990 backbuffer_height: u32,
991}
992
993impl RenderGraphBuilder {
994 pub fn new(label: &str, width: u32, height: u32) -> Self {
995 Self {
996 graph: RenderGraph::new(label),
997 backbuffer_width: width,
998 backbuffer_height: height,
999 }
1000 }
1001
1002 pub fn backbuffer_size(&self) -> (u32, u32) {
1003 (self.backbuffer_width, self.backbuffer_height)
1004 }
1005
1006 pub fn texture(&mut self, name: &str, format: TextureFormat) -> ResourceHandle {
1008 let desc = ResourceDescriptor::new(name, format);
1009 self.graph.declare_resource(desc)
1010 }
1011
1012 pub fn texture_scaled(
1014 &mut self,
1015 name: &str,
1016 format: TextureFormat,
1017 width_scale: f32,
1018 height_scale: f32,
1019 ) -> ResourceHandle {
1020 let desc = ResourceDescriptor::new(name, format).with_size(SizePolicy::Relative {
1021 width_scale,
1022 height_scale,
1023 });
1024 self.graph.declare_resource(desc)
1025 }
1026
1027 pub fn texture_absolute(
1029 &mut self,
1030 name: &str,
1031 format: TextureFormat,
1032 width: u32,
1033 height: u32,
1034 ) -> ResourceHandle {
1035 let desc = ResourceDescriptor::new(name, format).with_size(SizePolicy::Absolute { width, height });
1036 self.graph.declare_resource(desc)
1037 }
1038
1039 pub fn import(&mut self, name: &str, format: TextureFormat) -> ResourceHandle {
1041 let desc = ResourceDescriptor::new(name, format);
1042 self.graph.import_resource(desc)
1043 }
1044
1045 pub fn graphics_pass(&mut self, name: &str) -> PassBuilder<'_> {
1047 PassBuilder {
1048 graph: &mut self.graph,
1049 pass: RenderPass::new(name, PassType::Graphics),
1050 }
1051 }
1052
1053 pub fn compute_pass(&mut self, name: &str) -> PassBuilder<'_> {
1055 PassBuilder {
1056 graph: &mut self.graph,
1057 pass: RenderPass::new(name, PassType::Compute),
1058 }
1059 }
1060
1061 pub fn enable_feature(&mut self, feature: &str) -> &mut Self {
1063 self.graph.enable_feature(feature);
1064 self
1065 }
1066
1067 pub fn build(self) -> RenderGraph {
1069 self.graph
1070 }
1071}
1072
1073pub struct PassBuilder<'a> {
1075 graph: &'a mut RenderGraph,
1076 pass: RenderPass,
1077}
1078
1079impl<'a> PassBuilder<'a> {
1080 pub fn reads(mut self, handle: ResourceHandle, name: &str) -> Self {
1081 self.pass.add_input(handle, name);
1082 self
1083 }
1084
1085 pub fn writes(mut self, handle: ResourceHandle, name: &str) -> Self {
1086 self.pass.add_output(handle, name);
1087 self
1088 }
1089
1090 pub fn depends_on(mut self, pass_name: &str) -> Self {
1091 self.pass.depends_on(pass_name);
1092 self
1093 }
1094
1095 pub fn condition(mut self, cond: PassCondition) -> Self {
1096 self.pass.condition = cond;
1097 self
1098 }
1099
1100 pub fn resolution(mut self, scale: ResolutionScale) -> Self {
1101 self.pass.resolution = scale;
1102 self
1103 }
1104
1105 pub fn queue(mut self, q: QueueAffinity) -> Self {
1106 self.pass.queue = q;
1107 self
1108 }
1109
1110 pub fn side_effects(mut self) -> Self {
1111 self.pass.has_side_effects = true;
1112 self
1113 }
1114
1115 pub fn tag(mut self, t: &str) -> Self {
1116 self.pass.tag = Some(t.to_string());
1117 self
1118 }
1119
1120 pub fn finish(self) {
1122 self.graph.add_pass(self.pass);
1123 }
1124}
1125
1126#[derive(Debug, Clone)]
1132pub struct PassConfig {
1133 pub name: String,
1134 pub pass_type: PassType,
1135 pub inputs: Vec<String>,
1136 pub outputs: Vec<String>,
1137 pub condition: Option<String>,
1138 pub resolution_scale: Option<(f32, f32)>,
1139 pub queue: QueueAffinity,
1140 pub explicit_deps: Vec<String>,
1141}
1142
1143#[derive(Debug, Clone)]
1145pub struct ResourceConfig {
1146 pub name: String,
1147 pub format: TextureFormat,
1148 pub size: SizePolicy,
1149 pub imported: bool,
1150}
1151
1152#[derive(Debug, Clone)]
1154pub struct GraphConfig {
1155 pub label: String,
1156 pub resources: Vec<ResourceConfig>,
1157 pub passes: Vec<PassConfig>,
1158 pub features: Vec<String>,
1159}
1160
1161impl GraphConfig {
1162 pub fn build(&self) -> RenderGraph {
1164 let mut graph = RenderGraph::new(&self.label);
1165
1166 let mut handles: HashMap<String, ResourceHandle> = HashMap::new();
1168 for rc in &self.resources {
1169 let desc = ResourceDescriptor::new(&rc.name, rc.format).with_size(rc.size);
1170 let h = if rc.imported {
1171 graph.import_resource(desc)
1172 } else {
1173 graph.declare_resource(desc)
1174 };
1175 handles.insert(rc.name.clone(), h);
1176 }
1177
1178 for f in &self.features {
1180 graph.enable_feature(f);
1181 }
1182
1183 for pc in &self.passes {
1185 let mut pass = RenderPass::new(&pc.name, pc.pass_type);
1186 pass.queue = pc.queue;
1187
1188 if let Some(ref cond) = pc.condition {
1189 pass.condition = PassCondition::FeatureEnabled(cond.clone());
1190 }
1191 if let Some((ws, hs)) = pc.resolution_scale {
1192 pass.resolution = ResolutionScale::custom(ws, hs);
1193 }
1194
1195 for iname in &pc.inputs {
1196 if let Some(&h) = handles.get(iname) {
1197 pass.add_input(h, iname);
1198 }
1199 }
1200 for oname in &pc.outputs {
1201 if let Some(&h) = handles.get(oname) {
1202 pass.add_output(h, oname);
1203 }
1204 }
1205 for dep in &pc.explicit_deps {
1206 pass.depends_on(dep);
1207 }
1208
1209 graph.add_pass(pass);
1210 }
1211
1212 graph
1213 }
1214}
1215
#[cfg(test)]
mod tests {
    use super::*;

    /// Linear chain: depth_pre writes "depth" -> lighting reads it and writes "color" ->
    /// tonemap reads "color" and writes "final".
    fn simple_graph() -> RenderGraph {
        let mut builder = RenderGraphBuilder::new("test", 1920, 1080);
        let depth_tex = builder.texture("depth", TextureFormat::Depth32Float);
        let color_tex = builder.texture("color", TextureFormat::Rgba16Float);
        let output_tex = builder.texture("final", TextureFormat::Rgba8Unorm);

        builder
            .graphics_pass("depth_pre")
            .writes(depth_tex, "depth")
            .tag("geometry")
            .finish();

        builder
            .graphics_pass("lighting")
            .reads(depth_tex, "depth")
            .writes(color_tex, "color")
            .tag("lighting")
            .finish();

        builder
            .graphics_pass("tonemap")
            .reads(color_tex, "color")
            .writes(output_tex, "final")
            .tag("post")
            .finish();

        builder.build()
    }

    #[test]
    fn test_topological_sort() {
        let mut graph = simple_graph();
        let order = graph.topological_sort().unwrap();
        assert_eq!(order, vec!["depth_pre", "lighting", "tonemap"]);
    }

    #[test]
    fn test_cycle_detection() {
        // a reads r2 and writes r1; b reads r1 and writes r2 -> a <-> b.
        let mut graph = RenderGraph::new("cycle_test");
        let res1 = graph.declare_resource(ResourceDescriptor::new("r1", TextureFormat::Rgba8Unorm));
        let res2 = graph.declare_resource(ResourceDescriptor::new("r2", TextureFormat::Rgba8Unorm));

        let mut pass_a = RenderPass::new("a", PassType::Graphics);
        pass_a.add_input(res2, "r2");
        pass_a.add_output(res1, "r1");
        graph.add_pass(pass_a);

        let mut pass_b = RenderPass::new("b", PassType::Graphics);
        pass_b.add_input(res1, "r1");
        pass_b.add_output(res2, "r2");
        graph.add_pass(pass_b);

        let outcome = graph.topological_sort();
        assert!(outcome.is_err());
    }

    #[test]
    fn test_conditional_pass() {
        let mut graph = simple_graph();
        graph.get_pass_mut("tonemap").unwrap().condition =
            PassCondition::FeatureEnabled("hdr_output".to_string());

        let running = graph.active_passes().unwrap();
        assert!(!running.contains(&"tonemap".to_string()));
        assert!(running.contains(&"depth_pre".to_string()));

        graph.enable_feature("hdr_output");
        let running = graph.active_passes().unwrap();
        assert!(running.contains(&"tonemap".to_string()));
    }

    #[test]
    fn test_validation() {
        let mut graph = simple_graph();
        let report = graph.validate();
        assert!(report.is_ok());
    }

    #[test]
    fn test_dot_export() {
        let mut graph = simple_graph();
        let text = graph.export_dot();
        for needle in ["digraph", "depth_pre", "lighting", "tonemap"].iter() {
            assert!(text.contains(needle));
        }
    }

    #[test]
    fn test_merge() {
        // Every name collides, so the merged copies are added under prefixed names.
        let mut base = simple_graph();
        let incoming = simple_graph();
        base.merge(&incoming, "post");
        assert!(base.pass_count() > 3);
    }

    #[test]
    fn test_graph_config_build() {
        let full_res = || SizePolicy::Relative {
            width_scale: 1.0,
            height_scale: 1.0,
        };
        let config = GraphConfig {
            label: "from_config".to_string(),
            resources: vec![
                ResourceConfig {
                    name: "depth".to_string(),
                    format: TextureFormat::Depth32Float,
                    size: full_res(),
                    imported: false,
                },
                ResourceConfig {
                    name: "color".to_string(),
                    format: TextureFormat::Rgba16Float,
                    size: full_res(),
                    imported: false,
                },
            ],
            passes: vec![
                PassConfig {
                    name: "depth_pre".to_string(),
                    pass_type: PassType::Graphics,
                    inputs: vec![],
                    outputs: vec!["depth".to_string()],
                    condition: None,
                    resolution_scale: None,
                    queue: QueueAffinity::Graphics,
                    explicit_deps: vec![],
                },
                PassConfig {
                    name: "lighting".to_string(),
                    pass_type: PassType::Graphics,
                    inputs: vec!["depth".to_string()],
                    outputs: vec!["color".to_string()],
                    condition: None,
                    resolution_scale: None,
                    queue: QueueAffinity::Graphics,
                    explicit_deps: vec![],
                },
            ],
            features: vec![],
        };
        let mut graph = config.build();
        let order = graph.topological_sort().unwrap();
        assert_eq!(order, vec!["depth_pre", "lighting"]);
    }

    #[test]
    fn test_pass_builder_chain() {
        let mut builder = RenderGraphBuilder::new("builder_test", 1280, 720);
        let half_tex = builder.texture_scaled("bloom_half", TextureFormat::Rgba16Float, 0.5, 0.5);
        let quarter_tex =
            builder.texture_scaled("bloom_quarter", TextureFormat::Rgba16Float, 0.25, 0.25);
        let hdr_tex = builder.texture("hdr_color", TextureFormat::Rgba16Float);

        builder
            .graphics_pass("bloom_down")
            .reads(hdr_tex, "hdr_color")
            .writes(half_tex, "bloom_half")
            .resolution(ResolutionScale::half())
            .tag("bloom")
            .finish();

        builder
            .graphics_pass("bloom_down2")
            .reads(half_tex, "bloom_half")
            .writes(quarter_tex, "bloom_quarter")
            .resolution(ResolutionScale::quarter())
            .tag("bloom")
            .finish();

        let graph = builder.build();
        assert_eq!(graph.pass_count(), 2);
        assert_eq!(graph.resource_count(), 3);
    }
}