// proof_engine/rendergraph/graph.rs

//! Declarative render graph: named passes, resource nodes, dependency edges,
//! topological sort, cycle detection, conditional passes, multi-resolution
//! passes, validation, merging, and DOT export.

5use std::collections::{HashMap, HashSet, VecDeque};
6use std::fmt;
7
8use crate::rendergraph::resources::{
9    ResourceDescriptor, ResourceHandle, ResourceLifetime, ResourceTable, SizePolicy, TextureFormat,
10};
11
12// ---------------------------------------------------------------------------
13// Pass / node types
14// ---------------------------------------------------------------------------
15
/// Category of work a render pass performs.
///
/// Only `Compute` passes can qualify as async-compute candidates
/// (see `RenderPass::is_async_compute_candidate`).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PassType {
    /// Graphics work.
    Graphics,
    /// Compute work.
    Compute,
    /// Transfer work.
    Transfer,
    /// Presentation work.
    Present,
}
24
/// Preferred hardware queue for a pass, passed on to the executor as a hint.
///
/// New passes default to `Graphics`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum QueueAffinity {
    /// Graphics queue.
    Graphics,
    /// Compute queue.
    Compute,
    /// Transfer queue.
    Transfer,
    /// No preference.
    Any,
}
33
/// Predicate that decides whether a pass takes part in execution.
///
/// Leaves test the graph's enabled-feature set or a named boolean callback
/// value; `All` / `Any` combine nested conditions.
#[derive(Debug, Clone)]
pub enum PassCondition {
    /// Unconditionally active.
    Always,
    /// Active while the named feature is enabled.
    FeatureEnabled(String),
    /// Active while the named callback value is `true`. Only the callback's
    /// name lives here; the callback function itself is stored externally.
    Callback(String),
    /// Active when every nested condition is (vacuously true when empty).
    All(Vec<PassCondition>),
    /// Active when at least one nested condition is (false when empty).
    Any(Vec<PassCondition>),
}

impl PassCondition {
    /// Evaluate the condition.
    ///
    /// `features` is the set of enabled feature names and `callbacks` maps
    /// callback names to their current value. A callback with no recorded
    /// value evaluates as `false`.
    pub fn evaluate(&self, features: &HashSet<String>, callbacks: &HashMap<String, bool>) -> bool {
        match self {
            PassCondition::Always => true,
            PassCondition::FeatureEnabled(feature) => features.contains(feature),
            PassCondition::Callback(key) => callbacks.get(key) == Some(&true),
            PassCondition::All(children) => {
                for child in children {
                    if !child.evaluate(features, callbacks) {
                        return false;
                    }
                }
                true
            }
            PassCondition::Any(children) => {
                children.iter().any(|child| child.evaluate(features, callbacks))
            }
        }
    }
}
61
/// Per-axis resolution multiplier for a pass, relative to its declared
/// resource sizes.
#[derive(Debug, Clone, Copy)]
pub struct ResolutionScale {
    /// Horizontal multiplier.
    pub width_scale: f32,
    /// Vertical multiplier.
    pub height_scale: f32,
}

impl ResolutionScale {
    /// Same multiplier on both axes.
    fn uniform(factor: f32) -> Self {
        Self::custom(factor, factor)
    }

    /// 1.0 x 1.0 (native size).
    pub fn full() -> Self {
        Self::uniform(1.0)
    }

    /// 0.5 x 0.5.
    pub fn half() -> Self {
        Self::uniform(0.5)
    }

    /// 0.25 x 0.25.
    pub fn quarter() -> Self {
        Self::uniform(0.25)
    }

    /// Arbitrary per-axis multipliers, stored as given without validation.
    pub fn custom(w: f32, h: f32) -> Self {
        ResolutionScale {
            width_scale: w,
            height_scale: h,
        }
    }
}
95
96// ---------------------------------------------------------------------------
97// Render pass node
98// ---------------------------------------------------------------------------
99
100/// A named node in the render graph representing a render pass.
101#[derive(Debug, Clone)]
102pub struct RenderPass {
103    pub name: String,
104    pub pass_type: PassType,
105    pub queue: QueueAffinity,
106    pub condition: PassCondition,
107    pub resolution: ResolutionScale,
108    /// Resource handles this pass reads.
109    pub inputs: Vec<ResourceHandle>,
110    /// Resource handles this pass writes.
111    pub outputs: Vec<ResourceHandle>,
112    /// Names of input resources (for serialization / debug).
113    pub input_names: Vec<String>,
114    /// Names of output resources.
115    pub output_names: Vec<String>,
116    /// Explicit ordering dependencies (pass names that must run before this).
117    pub explicit_deps: Vec<String>,
118    /// Whether this pass has side effects (e.g., writes to swapchain).
119    pub has_side_effects: bool,
120    /// User-attached tag for grouping.
121    pub tag: Option<String>,
122}
123
124impl RenderPass {
125    pub fn new(name: &str, pass_type: PassType) -> Self {
126        Self {
127            name: name.to_string(),
128            pass_type,
129            queue: QueueAffinity::Graphics,
130            condition: PassCondition::Always,
131            resolution: ResolutionScale::full(),
132            inputs: Vec::new(),
133            outputs: Vec::new(),
134            input_names: Vec::new(),
135            output_names: Vec::new(),
136            explicit_deps: Vec::new(),
137            has_side_effects: false,
138            tag: None,
139        }
140    }
141
142    pub fn with_queue(mut self, queue: QueueAffinity) -> Self {
143        self.queue = queue;
144        self
145    }
146
147    pub fn with_condition(mut self, condition: PassCondition) -> Self {
148        self.condition = condition;
149        self
150    }
151
152    pub fn with_resolution(mut self, scale: ResolutionScale) -> Self {
153        self.resolution = scale;
154        self
155    }
156
157    pub fn with_side_effects(mut self) -> Self {
158        self.has_side_effects = true;
159        self
160    }
161
162    pub fn with_tag(mut self, tag: &str) -> Self {
163        self.tag = Some(tag.to_string());
164        self
165    }
166
167    pub fn add_input(&mut self, handle: ResourceHandle, name: &str) {
168        self.inputs.push(handle);
169        self.input_names.push(name.to_string());
170    }
171
172    pub fn add_output(&mut self, handle: ResourceHandle, name: &str) {
173        self.outputs.push(handle);
174        self.output_names.push(name.to_string());
175    }
176
177    pub fn depends_on(&mut self, pass_name: &str) {
178        if !self.explicit_deps.contains(&pass_name.to_string()) {
179            self.explicit_deps.push(pass_name.to_string());
180        }
181    }
182
183    /// True if this pass can potentially run on the async compute queue.
184    pub fn is_async_compute_candidate(&self) -> bool {
185        self.pass_type == PassType::Compute && self.queue != QueueAffinity::Graphics
186    }
187}
188
189// ---------------------------------------------------------------------------
190// Resource node
191// ---------------------------------------------------------------------------
192
193/// A resource node in the graph. Resources are vertices connected to passes
194/// via read/write edges.
195#[derive(Debug, Clone)]
196pub struct ResourceNode {
197    pub name: String,
198    pub handle: ResourceHandle,
199    pub descriptor: ResourceDescriptor,
200    pub lifetime: ResourceLifetime,
201    /// Pass that produces this resource (if any).
202    pub producer: Option<String>,
203    /// Passes that consume this resource.
204    pub consumers: Vec<String>,
205}
206
207impl ResourceNode {
208    pub fn new(name: &str, handle: ResourceHandle, descriptor: ResourceDescriptor, lifetime: ResourceLifetime) -> Self {
209        Self {
210            name: name.to_string(),
211            handle,
212            descriptor,
213            lifetime,
214            producer: None,
215            consumers: Vec::new(),
216        }
217    }
218}
219
220// ---------------------------------------------------------------------------
221// Dependency edge
222// ---------------------------------------------------------------------------
223
/// Directed edge `from_pass -> to_pass` in the render graph, either through a
/// shared resource or from an explicit ordering dependency.
#[derive(Debug, Clone)]
pub struct PassDependency {
    /// Pass that must run first.
    pub from_pass: String,
    /// Pass that must run after `from_pass`.
    pub to_pass: String,
    /// Name of the linking resource; empty for explicit edges.
    pub resource: String,
    /// Category of the dependency.
    pub kind: DependencyKind,
}

/// Why one pass has to be ordered before another.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DependencyKind {
    /// The consumer reads a resource the producer wrote.
    ReadAfterWrite,
    /// The consumer writes a resource the producer wrote earlier (ordering only).
    WriteAfterWrite,
    /// The consumer writes a resource the producer read earlier.
    WriteAfterRead,
    /// User-declared ordering with no resource involved.
    Explicit,
}

impl fmt::Display for PassDependency {
    /// Formats as `from -> to (via 'resource', Kind)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let PassDependency {
            from_pass,
            to_pass,
            resource,
            kind,
        } = self;
        write!(f, "{} -> {} (via '{}', {:?})", from_pass, to_pass, resource, kind)
    }
}
255
256// ---------------------------------------------------------------------------
257// Validation result
258// ---------------------------------------------------------------------------
259
/// Outcome of validating a render graph: errors make the graph invalid,
/// warnings are advisory.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// Problems that make the graph invalid.
    pub errors: Vec<String>,
    /// Suspicious but permitted findings.
    pub warnings: Vec<String>,
}

impl ValidationResult {
    /// A result with no errors and no warnings.
    pub fn new() -> Self {
        ValidationResult {
            errors: vec![],
            warnings: vec![],
        }
    }

    /// `true` when no errors were recorded; warnings do not count.
    pub fn is_ok(&self) -> bool {
        self.errors.is_empty()
    }

    /// Record an error message.
    pub fn error(&mut self, msg: impl Into<String>) {
        let text: String = msg.into();
        self.errors.push(text);
    }

    /// Record a warning message.
    pub fn warning(&mut self, msg: impl Into<String>) {
        let text: String = msg.into();
        self.warnings.push(text);
    }
}

impl Default for ValidationResult {
    fn default() -> Self {
        ValidationResult::new()
    }
}

impl fmt::Display for ValidationResult {
    /// A status header line, then one indented line per error and then per
    /// warning.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.errors.len() {
            0 => f.write_str("Validation OK")?,
            count => write!(f, "Validation FAILED ({} errors)", count)?,
        }
        self.errors
            .iter()
            .try_for_each(|e| write!(f, "\n  ERROR: {}", e))?;
        self.warnings
            .iter()
            .try_for_each(|w| write!(f, "\n  WARN:  {}", w))
    }
}
310
311// ---------------------------------------------------------------------------
312// Render graph
313// ---------------------------------------------------------------------------
314
315/// A declarative render graph: collection of passes (nodes) connected via
316/// resource dependencies (edges). Supports topological sorting, cycle
317/// detection, validation, conditional execution, and DOT export.
318pub struct RenderGraph {
319    /// All render passes, keyed by name.
320    passes: HashMap<String, RenderPass>,
321    /// Insertion order of passes.
322    pass_order: Vec<String>,
323    /// All resource nodes, keyed by name.
324    resource_nodes: HashMap<String, ResourceNode>,
325    /// Computed dependency edges.
326    edges: Vec<PassDependency>,
327    /// Topologically sorted pass names (computed lazily).
328    sorted_passes: Vec<String>,
329    /// Whether the sorted order is stale.
330    dirty: bool,
331    /// Resource table for handle bookkeeping.
332    pub resource_table: ResourceTable,
333    /// Enabled features for conditional passes.
334    features: HashSet<String>,
335    /// Named boolean callbacks for conditional passes.
336    callback_values: HashMap<String, bool>,
337    /// Graph label (for DOT export).
338    label: String,
339}
340
341impl RenderGraph {
342    pub fn new(label: &str) -> Self {
343        Self {
344            passes: HashMap::new(),
345            pass_order: Vec::new(),
346            resource_nodes: HashMap::new(),
347            edges: Vec::new(),
348            sorted_passes: Vec::new(),
349            dirty: true,
350            resource_table: ResourceTable::new(),
351            features: HashSet::new(),
352            callback_values: HashMap::new(),
353            label: label.to_string(),
354        }
355    }
356
357    // -- Feature / condition API ------------------------------------------
358
359    pub fn enable_feature(&mut self, feature: &str) {
360        self.features.insert(feature.to_string());
361    }
362
363    pub fn disable_feature(&mut self, feature: &str) {
364        self.features.remove(feature);
365    }
366
367    pub fn is_feature_enabled(&self, feature: &str) -> bool {
368        self.features.contains(feature)
369    }
370
371    pub fn set_callback(&mut self, name: &str, value: bool) {
372        self.callback_values.insert(name.to_string(), value);
373    }
374
375    // -- Resource declaration API -----------------------------------------
376
377    /// Declare a transient texture resource.
378    pub fn declare_resource(&mut self, descriptor: ResourceDescriptor) -> ResourceHandle {
379        let name = descriptor.name.clone();
380        let handle = self.resource_table.declare_transient(descriptor.clone());
381        self.resource_nodes
382            .entry(name.clone())
383            .or_insert_with(|| ResourceNode::new(&name, handle, descriptor, ResourceLifetime::Transient));
384        self.dirty = true;
385        handle
386    }
387
388    /// Declare an imported (externally managed) resource.
389    pub fn import_resource(&mut self, descriptor: ResourceDescriptor) -> ResourceHandle {
390        let name = descriptor.name.clone();
391        let handle = self.resource_table.declare_imported(descriptor.clone());
392        self.resource_nodes
393            .entry(name.clone())
394            .or_insert_with(|| ResourceNode::new(&name, handle, descriptor, ResourceLifetime::Imported));
395        self.dirty = true;
396        handle
397    }
398
399    // -- Pass API ---------------------------------------------------------
400
401    /// Add a render pass to the graph.
402    pub fn add_pass(&mut self, pass: RenderPass) {
403        let name = pass.name.clone();
404        // Update resource table writers/readers
405        for (h, rname) in pass.outputs.iter().zip(pass.output_names.iter()) {
406            self.resource_table.add_writer(*h, &name);
407            if let Some(rn) = self.resource_nodes.get_mut(rname) {
408                rn.producer = Some(name.clone());
409            }
410        }
411        for (h, rname) in pass.inputs.iter().zip(pass.input_names.iter()) {
412            self.resource_table.add_reader(*h, &name);
413            if let Some(rn) = self.resource_nodes.get_mut(rname) {
414                if !rn.consumers.contains(&name) {
415                    rn.consumers.push(name.clone());
416                }
417            }
418        }
419        if !self.pass_order.contains(&name) {
420            self.pass_order.push(name.clone());
421        }
422        self.passes.insert(name, pass);
423        self.dirty = true;
424    }
425
426    /// Remove a pass by name.
427    pub fn remove_pass(&mut self, name: &str) -> Option<RenderPass> {
428        self.pass_order.retain(|n| n != name);
429        self.dirty = true;
430        self.passes.remove(name)
431    }
432
433    /// Get a pass by name.
434    pub fn get_pass(&self, name: &str) -> Option<&RenderPass> {
435        self.passes.get(name)
436    }
437
438    /// Get a mutable pass by name.
439    pub fn get_pass_mut(&mut self, name: &str) -> Option<&mut RenderPass> {
440        self.dirty = true;
441        self.passes.get_mut(name)
442    }
443
444    /// All pass names in insertion order.
445    pub fn pass_names(&self) -> &[String] {
446        &self.pass_order
447    }
448
449    /// Number of passes.
450    pub fn pass_count(&self) -> usize {
451        self.passes.len()
452    }
453
454    /// Number of resource nodes.
455    pub fn resource_count(&self) -> usize {
456        self.resource_nodes.len()
457    }
458
459    // -- Dependency building ----------------------------------------------
460
461    /// Rebuild dependency edges from resource read/write declarations.
462    pub fn build_edges(&mut self) {
463        self.edges.clear();
464
465        // For each resource, connect producer -> consumers
466        for (res_name, rn) in &self.resource_nodes {
467            if let Some(ref producer) = rn.producer {
468                for consumer in &rn.consumers {
469                    if producer != consumer {
470                        self.edges.push(PassDependency {
471                            from_pass: producer.clone(),
472                            to_pass: consumer.clone(),
473                            resource: res_name.clone(),
474                            kind: DependencyKind::ReadAfterWrite,
475                        });
476                    }
477                }
478            }
479        }
480
481        // Explicit dependencies
482        let pass_names: Vec<String> = self.passes.keys().cloned().collect();
483        for name in &pass_names {
484            let deps = self.passes[name].explicit_deps.clone();
485            for dep in deps {
486                if self.passes.contains_key(&dep) {
487                    self.edges.push(PassDependency {
488                        from_pass: dep,
489                        to_pass: name.clone(),
490                        resource: String::new(),
491                        kind: DependencyKind::Explicit,
492                    });
493                }
494            }
495        }
496    }
497
498    /// Get all edges.
499    pub fn edges(&self) -> &[PassDependency] {
500        &self.edges
501    }
502
503    // -- Topological sort / cycle detection --------------------------------
504
505    /// Detect cycles in the dependency graph. Returns the set of passes
506    /// involved in a cycle, or empty if acyclic.
507    pub fn detect_cycles(&mut self) -> Vec<Vec<String>> {
508        self.build_edges();
509
510        let mut adj: HashMap<&str, Vec<&str>> = HashMap::new();
511        for name in self.passes.keys() {
512            adj.entry(name.as_str()).or_default();
513        }
514        for edge in &self.edges {
515            adj.entry(edge.from_pass.as_str())
516                .or_default()
517                .push(edge.to_pass.as_str());
518        }
519
520        // Tarjan's SCC
521        let mut index_counter: u32 = 0;
522        let mut stack: Vec<&str> = Vec::new();
523        let mut on_stack: HashSet<&str> = HashSet::new();
524        let mut indices: HashMap<&str, u32> = HashMap::new();
525        let mut lowlinks: HashMap<&str, u32> = HashMap::new();
526        let mut sccs: Vec<Vec<String>> = Vec::new();
527
528        fn strongconnect<'a>(
529            v: &'a str,
530            adj: &HashMap<&'a str, Vec<&'a str>>,
531            index_counter: &mut u32,
532            stack: &mut Vec<&'a str>,
533            on_stack: &mut HashSet<&'a str>,
534            indices: &mut HashMap<&'a str, u32>,
535            lowlinks: &mut HashMap<&'a str, u32>,
536            sccs: &mut Vec<Vec<String>>,
537        ) {
538            indices.insert(v, *index_counter);
539            lowlinks.insert(v, *index_counter);
540            *index_counter += 1;
541            stack.push(v);
542            on_stack.insert(v);
543
544            if let Some(neighbors) = adj.get(v) {
545                for &w in neighbors {
546                    if !indices.contains_key(w) {
547                        strongconnect(w, adj, index_counter, stack, on_stack, indices, lowlinks, sccs);
548                        let lw = lowlinks[w];
549                        let lv = lowlinks[v];
550                        lowlinks.insert(v, lv.min(lw));
551                    } else if on_stack.contains(w) {
552                        let iw = indices[w];
553                        let lv = lowlinks[v];
554                        lowlinks.insert(v, lv.min(iw));
555                    }
556                }
557            }
558
559            if lowlinks[v] == indices[v] {
560                let mut scc = Vec::new();
561                while let Some(w) = stack.pop() {
562                    on_stack.remove(w);
563                    scc.push(w.to_string());
564                    if w == v {
565                        break;
566                    }
567                }
568                if scc.len() > 1 {
569                    sccs.push(scc);
570                }
571            }
572        }
573
574        let nodes: Vec<&str> = adj.keys().copied().collect();
575        for node in nodes {
576            if !indices.contains_key(node) {
577                strongconnect(
578                    node,
579                    &adj,
580                    &mut index_counter,
581                    &mut stack,
582                    &mut on_stack,
583                    &mut indices,
584                    &mut lowlinks,
585                    &mut sccs,
586                );
587            }
588        }
589
590        sccs
591    }
592
593    /// Perform topological sort using Kahn's algorithm.
594    /// Returns `Err` with cycle participants if a cycle is detected.
595    pub fn topological_sort(&mut self) -> Result<Vec<String>, Vec<String>> {
596        self.build_edges();
597
598        let mut in_degree: HashMap<&str, usize> = HashMap::new();
599        for name in self.passes.keys() {
600            in_degree.entry(name.as_str()).or_insert(0);
601        }
602        let mut adj: HashMap<&str, Vec<&str>> = HashMap::new();
603        for edge in &self.edges {
604            adj.entry(edge.from_pass.as_str())
605                .or_default()
606                .push(edge.to_pass.as_str());
607            *in_degree.entry(edge.to_pass.as_str()).or_insert(0) += 1;
608        }
609
610        let mut queue: VecDeque<&str> = VecDeque::new();
611        for (&node, &deg) in &in_degree {
612            if deg == 0 {
613                queue.push_back(node);
614            }
615        }
616
617        // Sort the initial queue by pass insertion order for determinism
618        let order_map: HashMap<&str, usize> = self
619            .pass_order
620            .iter()
621            .enumerate()
622            .map(|(i, n)| (n.as_str(), i))
623            .collect();
624        let mut initial: Vec<&str> = queue.drain(..).collect();
625        initial.sort_by_key(|n| order_map.get(n).copied().unwrap_or(usize::MAX));
626        for n in initial {
627            queue.push_back(n);
628        }
629
630        let mut sorted: Vec<String> = Vec::new();
631        let mut visited = 0usize;
632
633        while let Some(node) = queue.pop_front() {
634            sorted.push(node.to_string());
635            visited += 1;
636            if let Some(neighbors) = adj.get(node) {
637                // Collect and sort neighbors for determinism
638                let mut next: Vec<&str> = Vec::new();
639                for &nb in neighbors {
640                    let deg = in_degree.get_mut(nb).unwrap();
641                    *deg -= 1;
642                    if *deg == 0 {
643                        next.push(nb);
644                    }
645                }
646                next.sort_by_key(|n| order_map.get(n).copied().unwrap_or(usize::MAX));
647                for nb in next {
648                    queue.push_back(nb);
649                }
650            }
651        }
652
653        if visited != self.passes.len() {
654            // Cycle: return nodes not in sorted
655            let sorted_set: HashSet<&str> = sorted.iter().map(|s| s.as_str()).collect();
656            let cycle_nodes: Vec<String> = self
657                .passes
658                .keys()
659                .filter(|k| !sorted_set.contains(k.as_str()))
660                .cloned()
661                .collect();
662            return Err(cycle_nodes);
663        }
664
665        self.sorted_passes = sorted.clone();
666        self.dirty = false;
667        Ok(sorted)
668    }
669
670    /// Get the sorted pass list (computes if dirty).
671    pub fn sorted(&mut self) -> Result<&[String], Vec<String>> {
672        if self.dirty {
673            self.topological_sort()?;
674        }
675        Ok(&self.sorted_passes)
676    }
677
678    /// Filter sorted passes to only those whose conditions are satisfied.
679    pub fn active_passes(&mut self) -> Result<Vec<String>, Vec<String>> {
680        let sorted = self.topological_sort()?;
681        let features = &self.features;
682        let callbacks = &self.callback_values;
683        Ok(sorted
684            .into_iter()
685            .filter(|name| {
686                self.passes
687                    .get(name)
688                    .map(|p| p.condition.evaluate(features, callbacks))
689                    .unwrap_or(false)
690            })
691            .collect())
692    }
693
694    // -- Validation -------------------------------------------------------
695
696    /// Validate the graph: check for cycles, dangling resources, disconnected
697    /// inputs, and other structural issues.
698    pub fn validate(&mut self) -> ValidationResult {
699        let mut result = ValidationResult::new();
700
701        // 1. Cycle detection
702        let cycles = self.detect_cycles();
703        for cycle in &cycles {
704            result.error(format!("Cycle detected involving passes: {}", cycle.join(", ")));
705        }
706
707        // 2. Dangling resources
708        let dangling = self.resource_table.find_dangling();
709        for d in &dangling {
710            match d.kind {
711                crate::rendergraph::resources::DanglingKind::NeverWritten => {
712                    result.error(format!("Resource '{}' is never written by any pass", d.name));
713                }
714                crate::rendergraph::resources::DanglingKind::NeverRead => {
715                    result.warning(format!("Resource '{}' is never read by any pass", d.name));
716                }
717            }
718        }
719
720        // 3. Check all pass inputs are connected
721        for pass in self.passes.values() {
722            for input_name in &pass.input_names {
723                if self.resource_table.lookup(input_name).is_none() {
724                    result.error(format!(
725                        "Pass '{}' reads resource '{}' which is not declared",
726                        pass.name, input_name
727                    ));
728                }
729            }
730            for output_name in &pass.output_names {
731                if self.resource_table.lookup(output_name).is_none() {
732                    result.error(format!(
733                        "Pass '{}' writes resource '{}' which is not declared",
734                        pass.name, output_name
735                    ));
736                }
737            }
738        }
739
740        // 4. Check explicit deps reference real passes
741        for pass in self.passes.values() {
742            for dep in &pass.explicit_deps {
743                if !self.passes.contains_key(dep) {
744                    result.error(format!(
745                        "Pass '{}' depends on '{}' which does not exist",
746                        pass.name, dep
747                    ));
748                }
749            }
750        }
751
752        // 5. Warn about passes with no inputs and no outputs (possible error)
753        for pass in self.passes.values() {
754            if pass.inputs.is_empty() && pass.outputs.is_empty() && !pass.has_side_effects {
755                result.warning(format!(
756                    "Pass '{}' has no inputs, no outputs, and no side effects",
757                    pass.name
758                ));
759            }
760        }
761
762        // 6. Multi-resolution: warn if a pass reads a resource at a different
763        //    resolution than it was written.
764        // (Informational only — this is valid for bloom/SSAO but might be a mistake)
765
766        result
767    }
768
769    // -- Graph merging ----------------------------------------------------
770
771    /// Merge another graph into this one. Passes and resources from `other`
772    /// are added. Conflicting names get a prefix.
773    pub fn merge(&mut self, other: &RenderGraph, prefix: &str) {
774        // Merge resources
775        for (name, rn) in &other.resource_nodes {
776            let new_name = if self.resource_nodes.contains_key(name) {
777                format!("{}_{}", prefix, name)
778            } else {
779                name.clone()
780            };
781            let mut desc = rn.descriptor.clone();
782            desc.name = new_name.clone();
783            let handle = if rn.lifetime == ResourceLifetime::Imported {
784                self.import_resource(desc)
785            } else {
786                self.declare_resource(desc)
787            };
788            // We need to remap, but for simplicity we store the handle in the new node
789            let _ = handle;
790        }
791
792        // Merge passes with remapped resource names
793        for (name, pass) in &other.passes {
794            let new_name = if self.passes.contains_key(name) {
795                format!("{}_{}", prefix, name)
796            } else {
797                name.clone()
798            };
799            let mut new_pass = RenderPass::new(&new_name, pass.pass_type);
800            new_pass.queue = pass.queue;
801            new_pass.condition = pass.condition.clone();
802            new_pass.resolution = pass.resolution;
803            new_pass.has_side_effects = pass.has_side_effects;
804            new_pass.tag = pass.tag.clone();
805
806            // Remap input/output names
807            for iname in &pass.input_names {
808                let mapped = if self.resource_nodes.contains_key(iname) && other.resource_nodes.contains_key(iname) {
809                    // If it existed in both, it was prefixed
810                    if self.resource_nodes.contains_key(&format!("{}_{}", prefix, iname)) {
811                        format!("{}_{}", prefix, iname)
812                    } else {
813                        iname.clone()
814                    }
815                } else {
816                    iname.clone()
817                };
818                if let Some(h) = self.resource_table.lookup(&mapped) {
819                    new_pass.add_input(h, &mapped);
820                }
821            }
822            for oname in &pass.output_names {
823                let mapped = if self.resource_nodes.contains_key(oname) && other.resource_nodes.contains_key(oname) {
824                    if self.resource_nodes.contains_key(&format!("{}_{}", prefix, oname)) {
825                        format!("{}_{}", prefix, oname)
826                    } else {
827                        oname.clone()
828                    }
829                } else {
830                    oname.clone()
831                };
832                if let Some(h) = self.resource_table.lookup(&mapped) {
833                    new_pass.add_output(h, &mapped);
834                }
835            }
836
837            // Remap explicit deps
838            for dep in &pass.explicit_deps {
839                let mapped_dep = if self.passes.contains_key(dep) && other.passes.contains_key(dep) {
840                    format!("{}_{}", prefix, dep)
841                } else {
842                    dep.clone()
843                };
844                new_pass.depends_on(&mapped_dep);
845            }
846
847            self.add_pass(new_pass);
848        }
849
850        self.dirty = true;
851    }
852
853    // -- DOT export -------------------------------------------------------
854
855    /// Export the graph in Graphviz DOT format for debug visualization.
856    pub fn export_dot(&mut self) -> String {
857        // Ensure edges are built
858        self.build_edges();
859
860        let mut dot = String::new();
861        dot.push_str(&format!("digraph \"{}\" {{\n", self.label));
862        dot.push_str("  rankdir=LR;\n");
863        dot.push_str("  node [shape=box, style=filled];\n\n");
864
865        // Pass nodes
866        dot.push_str("  // Render passes\n");
867        for (name, pass) in &self.passes {
868            let color = match pass.pass_type {
869                PassType::Graphics => "#4a90d9",
870                PassType::Compute => "#d94a4a",
871                PassType::Transfer => "#4ad94a",
872                PassType::Present => "#d9d94a",
873            };
874            let active = pass.condition.evaluate(&self.features, &self.callback_values);
875            let style = if active { "filled" } else { "filled,dashed" };
876            let label = format!(
877                "{}\\n[{:?}]{}",
878                name,
879                pass.pass_type,
880                if !active { " (DISABLED)" } else { "" }
881            );
882            dot.push_str(&format!(
883                "  \"pass_{}\" [label=\"{}\", fillcolor=\"{}\", style=\"{}\", fontcolor=white];\n",
884                name, label, color, style
885            ));
886        }
887
888        // Resource nodes
889        dot.push_str("\n  // Resources\n");
890        for (name, rn) in &self.resource_nodes {
891            let shape = match rn.lifetime {
892                ResourceLifetime::Transient => "ellipse",
893                ResourceLifetime::Imported => "diamond",
894            };
895            let label = format!(
896                "{}\\n{:?}",
897                name, rn.descriptor.format
898            );
899            dot.push_str(&format!(
900                "  \"res_{}\" [label=\"{}\", shape={}, fillcolor=\"#e0e0e0\", fontcolor=black];\n",
901                name, label, shape
902            ));
903        }
904
905        // Edges: pass -> resource (writes) and resource -> pass (reads)
906        dot.push_str("\n  // Edges\n");
907        for pass in self.passes.values() {
908            for oname in &pass.output_names {
909                dot.push_str(&format!(
910                    "  \"pass_{}\" -> \"res_{}\" [color=red, label=\"write\"];\n",
911                    pass.name, oname
912                ));
913            }
914            for iname in &pass.input_names {
915                dot.push_str(&format!(
916                    "  \"res_{}\" -> \"pass_{}\" [color=blue, label=\"read\"];\n",
917                    iname, pass.name
918                ));
919            }
920        }
921
922        // Explicit dep edges
923        for pass in self.passes.values() {
924            for dep in &pass.explicit_deps {
925                dot.push_str(&format!(
926                    "  \"pass_{}\" -> \"pass_{}\" [style=dashed, color=gray, label=\"explicit\"];\n",
927                    dep, pass.name
928                ));
929            }
930        }
931
932        dot.push_str("}\n");
933        dot
934    }
935
936    // -- Accessors --------------------------------------------------------
937
938    pub fn label(&self) -> &str {
939        &self.label
940    }
941
942    pub fn resource_node(&self, name: &str) -> Option<&ResourceNode> {
943        self.resource_nodes.get(name)
944    }
945
946    pub fn all_passes(&self) -> impl Iterator<Item = &RenderPass> {
947        self.passes.values()
948    }
949
950    pub fn all_resource_nodes(&self) -> impl Iterator<Item = &ResourceNode> {
951        self.resource_nodes.values()
952    }
953
954    pub fn features(&self) -> &HashSet<String> {
955        &self.features
956    }
957
958    /// Returns passes grouped by tag.
959    pub fn passes_by_tag(&self) -> HashMap<String, Vec<&RenderPass>> {
960        let mut map: HashMap<String, Vec<&RenderPass>> = HashMap::new();
961        for pass in self.passes.values() {
962            let tag = pass.tag.clone().unwrap_or_else(|| "untagged".to_string());
963            map.entry(tag).or_default().push(pass);
964        }
965        map
966    }
967}
968
969impl fmt::Display for RenderGraph {
970    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
971        write!(
972            f,
973            "RenderGraph '{}': {} passes, {} resources, {} edges",
974            self.label,
975            self.passes.len(),
976            self.resource_nodes.len(),
977            self.edges.len(),
978        )
979    }
980}
981
982// ---------------------------------------------------------------------------
983// Builder helpers
984// ---------------------------------------------------------------------------
985
986/// Fluent builder for constructing a RenderGraph.
987pub struct RenderGraphBuilder {
988    graph: RenderGraph,
989    backbuffer_width: u32,
990    backbuffer_height: u32,
991}
992
993impl RenderGraphBuilder {
994    pub fn new(label: &str, width: u32, height: u32) -> Self {
995        Self {
996            graph: RenderGraph::new(label),
997            backbuffer_width: width,
998            backbuffer_height: height,
999        }
1000    }
1001
1002    pub fn backbuffer_size(&self) -> (u32, u32) {
1003        (self.backbuffer_width, self.backbuffer_height)
1004    }
1005
1006    /// Declare a full-resolution transient texture.
1007    pub fn texture(&mut self, name: &str, format: TextureFormat) -> ResourceHandle {
1008        let desc = ResourceDescriptor::new(name, format);
1009        self.graph.declare_resource(desc)
1010    }
1011
1012    /// Declare a texture at a specific resolution scale.
1013    pub fn texture_scaled(
1014        &mut self,
1015        name: &str,
1016        format: TextureFormat,
1017        width_scale: f32,
1018        height_scale: f32,
1019    ) -> ResourceHandle {
1020        let desc = ResourceDescriptor::new(name, format).with_size(SizePolicy::Relative {
1021            width_scale,
1022            height_scale,
1023        });
1024        self.graph.declare_resource(desc)
1025    }
1026
1027    /// Declare a texture with explicit pixel dimensions.
1028    pub fn texture_absolute(
1029        &mut self,
1030        name: &str,
1031        format: TextureFormat,
1032        width: u32,
1033        height: u32,
1034    ) -> ResourceHandle {
1035        let desc = ResourceDescriptor::new(name, format).with_size(SizePolicy::Absolute { width, height });
1036        self.graph.declare_resource(desc)
1037    }
1038
1039    /// Import an external resource.
1040    pub fn import(&mut self, name: &str, format: TextureFormat) -> ResourceHandle {
1041        let desc = ResourceDescriptor::new(name, format);
1042        self.graph.import_resource(desc)
1043    }
1044
1045    /// Add a graphics pass.
1046    pub fn graphics_pass(&mut self, name: &str) -> PassBuilder<'_> {
1047        PassBuilder {
1048            graph: &mut self.graph,
1049            pass: RenderPass::new(name, PassType::Graphics),
1050        }
1051    }
1052
1053    /// Add a compute pass.
1054    pub fn compute_pass(&mut self, name: &str) -> PassBuilder<'_> {
1055        PassBuilder {
1056            graph: &mut self.graph,
1057            pass: RenderPass::new(name, PassType::Compute),
1058        }
1059    }
1060
1061    /// Enable a feature flag.
1062    pub fn enable_feature(&mut self, feature: &str) -> &mut Self {
1063        self.graph.enable_feature(feature);
1064        self
1065    }
1066
1067    /// Finalize and return the built graph.
1068    pub fn build(self) -> RenderGraph {
1069        self.graph
1070    }
1071}
1072
1073/// Fluent builder for a single pass within a graph.
1074pub struct PassBuilder<'a> {
1075    graph: &'a mut RenderGraph,
1076    pass: RenderPass,
1077}
1078
1079impl<'a> PassBuilder<'a> {
1080    pub fn reads(mut self, handle: ResourceHandle, name: &str) -> Self {
1081        self.pass.add_input(handle, name);
1082        self
1083    }
1084
1085    pub fn writes(mut self, handle: ResourceHandle, name: &str) -> Self {
1086        self.pass.add_output(handle, name);
1087        self
1088    }
1089
1090    pub fn depends_on(mut self, pass_name: &str) -> Self {
1091        self.pass.depends_on(pass_name);
1092        self
1093    }
1094
1095    pub fn condition(mut self, cond: PassCondition) -> Self {
1096        self.pass.condition = cond;
1097        self
1098    }
1099
1100    pub fn resolution(mut self, scale: ResolutionScale) -> Self {
1101        self.pass.resolution = scale;
1102        self
1103    }
1104
1105    pub fn queue(mut self, q: QueueAffinity) -> Self {
1106        self.pass.queue = q;
1107        self
1108    }
1109
1110    pub fn side_effects(mut self) -> Self {
1111        self.pass.has_side_effects = true;
1112        self
1113    }
1114
1115    pub fn tag(mut self, t: &str) -> Self {
1116        self.pass.tag = Some(t.to_string());
1117        self
1118    }
1119
1120    /// Finalize the pass and add it to the graph.
1121    pub fn finish(self) {
1122        self.graph.add_pass(self.pass);
1123    }
1124}
1125
1126// ---------------------------------------------------------------------------
1127// Config-driven graph building
1128// ---------------------------------------------------------------------------
1129
/// A plain-data pass description for config-driven graph rebuilding.
///
/// It is consumed by `GraphConfig::build`.
///
/// NOTE(review): originally described as "serializable", but no serde derives
/// are present here. Confirm how it is (de)serialized.
#[derive(Debug, Clone)]
pub struct PassConfig {
    /// Pass name given to `RenderPass::new`.
    ///
    /// Other passes refer to it from their `explicit_deps`.
    pub name: String,
    /// Kind of work the pass performs.
    pub pass_type: PassType,
    /// Names of resources the pass reads.
    ///
    /// Names that match no entry in `GraphConfig::resources` are skipped
    /// silently by `GraphConfig::build`.
    pub inputs: Vec<String>,
    /// Names of resources the pass writes.
    ///
    /// Unknown names are skipped silently, as for `inputs`.
    pub outputs: Vec<String>,
    /// Feature flag gating the pass; it becomes `PassCondition::FeatureEnabled`.
    ///
    /// `None` keeps the default condition set by `RenderPass::new`.
    pub condition: Option<String>,
    /// Optional `(width_scale, height_scale)` passed to `ResolutionScale::custom`.
    ///
    /// `None` keeps the default resolution.
    pub resolution_scale: Option<(f32, f32)>,
    /// Queue-affinity hint copied onto the pass.
    pub queue: QueueAffinity,
    /// Names of passes this pass explicitly depends on.
    ///
    /// Each name is forwarded to `RenderPass::depends_on`.
    pub explicit_deps: Vec<String>,
}
1142
/// A plain-data resource description consumed by `GraphConfig::build`.
#[derive(Debug, Clone)]
pub struct ResourceConfig {
    /// Resource name, which `PassConfig::inputs`/`outputs` refer to.
    pub name: String,
    /// Texture format of the resource.
    pub format: TextureFormat,
    /// Size policy applied through `ResourceDescriptor::with_size`.
    pub size: SizePolicy,
    /// When true, the resource is registered with `import_resource`.
    ///
    /// Otherwise it is declared with `declare_resource`.
    pub imported: bool,
}
1151
/// A plain-data description of a whole render graph.
///
/// Call `GraphConfig::build` to turn it into a `RenderGraph`. Presumably it
/// is rebuilt on config hot-reload; confirm with the caller.
#[derive(Debug, Clone)]
pub struct GraphConfig {
    /// Label given to `RenderGraph::new`.
    pub label: String,
    /// Resources, registered in list order before any pass is added.
    pub resources: Vec<ResourceConfig>,
    /// Passes, added in list order after resources and features.
    pub passes: Vec<PassConfig>,
    /// Feature flags enabled on the graph before passes are added.
    pub features: Vec<String>,
}
1160
1161impl GraphConfig {
1162    /// Build a RenderGraph from this configuration.
1163    pub fn build(&self) -> RenderGraph {
1164        let mut graph = RenderGraph::new(&self.label);
1165
1166        // Declare resources
1167        let mut handles: HashMap<String, ResourceHandle> = HashMap::new();
1168        for rc in &self.resources {
1169            let desc = ResourceDescriptor::new(&rc.name, rc.format).with_size(rc.size);
1170            let h = if rc.imported {
1171                graph.import_resource(desc)
1172            } else {
1173                graph.declare_resource(desc)
1174            };
1175            handles.insert(rc.name.clone(), h);
1176        }
1177
1178        // Enable features
1179        for f in &self.features {
1180            graph.enable_feature(f);
1181        }
1182
1183        // Add passes
1184        for pc in &self.passes {
1185            let mut pass = RenderPass::new(&pc.name, pc.pass_type);
1186            pass.queue = pc.queue;
1187
1188            if let Some(ref cond) = pc.condition {
1189                pass.condition = PassCondition::FeatureEnabled(cond.clone());
1190            }
1191            if let Some((ws, hs)) = pc.resolution_scale {
1192                pass.resolution = ResolutionScale::custom(ws, hs);
1193            }
1194
1195            for iname in &pc.inputs {
1196                if let Some(&h) = handles.get(iname) {
1197                    pass.add_input(h, iname);
1198                }
1199            }
1200            for oname in &pc.outputs {
1201                if let Some(&h) = handles.get(oname) {
1202                    pass.add_output(h, oname);
1203                }
1204            }
1205            for dep in &pc.explicit_deps {
1206                pass.depends_on(dep);
1207            }
1208
1209            graph.add_pass(pass);
1210        }
1211
1212        graph
1213    }
1214}
1215
1216// ---------------------------------------------------------------------------
1217// Tests
1218// ---------------------------------------------------------------------------
1219
#[cfg(test)]
mod tests {
    use super::*;

    /// Three graphics passes chained through textures:
    /// depth_pre writes "depth", lighting reads "depth" and writes "color",
    /// and tonemap reads "color" and writes "final".
    fn simple_graph() -> RenderGraph {
        let mut builder = RenderGraphBuilder::new("test", 1920, 1080);
        let depth = builder.texture("depth", TextureFormat::Depth32Float);
        let color = builder.texture("color", TextureFormat::Rgba16Float);
        let final_rt = builder.texture("final", TextureFormat::Rgba8Unorm);

        // (pass name, optional read, write, tag), added in this order.
        let stages = [
            ("depth_pre", None, (depth, "depth"), "geometry"),
            ("lighting", Some((depth, "depth")), (color, "color"), "lighting"),
            ("tonemap", Some((color, "color")), (final_rt, "final"), "post"),
        ];
        for (pass_name, read, (out_handle, out_name), tag) in stages {
            let mut pass = builder.graphics_pass(pass_name);
            if let Some((in_handle, in_name)) = read {
                pass = pass.reads(in_handle, in_name);
            }
            pass.writes(out_handle, out_name).tag(tag).finish();
        }

        builder.build()
    }

    #[test]
    fn test_topological_sort() {
        let mut graph = simple_graph();
        let order = graph.topological_sort().unwrap();
        assert_eq!(order, vec!["depth_pre", "lighting", "tonemap"]);
    }

    #[test]
    fn test_cycle_detection() {
        let mut graph = RenderGraph::new("cycle_test");
        let r1 = graph.declare_resource(ResourceDescriptor::new("r1", TextureFormat::Rgba8Unorm));
        let r2 = graph.declare_resource(ResourceDescriptor::new("r2", TextureFormat::Rgba8Unorm));

        // "a" reads r2 and writes r1; "b" reads r1 and writes r2, which forms a cycle.
        let wiring = [("a", (r2, "r2"), (r1, "r1")), ("b", (r1, "r1"), (r2, "r2"))];
        for (pass_name, (in_handle, in_name), (out_handle, out_name)) in wiring {
            let mut pass = RenderPass::new(pass_name, PassType::Graphics);
            pass.add_input(in_handle, in_name);
            pass.add_output(out_handle, out_name);
            graph.add_pass(pass);
        }

        assert!(graph.topological_sort().is_err());
    }

    #[test]
    fn test_conditional_pass() {
        let mut graph = simple_graph();
        // Gate tonemap behind a feature that is not enabled yet.
        graph.get_pass_mut("tonemap").unwrap().condition =
            PassCondition::FeatureEnabled("hdr_output".to_string());

        let before = graph.active_passes().unwrap();
        assert!(!before.iter().any(|p| p == "tonemap"));
        assert!(before.iter().any(|p| p == "depth_pre"));

        graph.enable_feature("hdr_output");
        let after = graph.active_passes().unwrap();
        assert!(after.iter().any(|p| p == "tonemap"));
    }

    #[test]
    fn test_validation() {
        let mut graph = simple_graph();
        assert!(graph.validate().is_ok());
    }

    #[test]
    fn test_dot_export() {
        let mut graph = simple_graph();
        let dot = graph.export_dot();
        for needle in ["digraph", "depth_pre", "lighting", "tonemap"] {
            assert!(dot.contains(needle), "DOT output is missing {:?}", needle);
        }
    }

    #[test]
    fn test_merge() {
        let mut merged = simple_graph();
        let other = simple_graph();
        merged.merge(&other, "post");
        // Merging a second copy must add passes.
        assert!(merged.pass_count() > 3);
    }

    #[test]
    fn test_graph_config_build() {
        fn full_res(name: &str, format: TextureFormat) -> ResourceConfig {
            ResourceConfig {
                name: name.to_string(),
                format,
                size: SizePolicy::Relative {
                    width_scale: 1.0,
                    height_scale: 1.0,
                },
                imported: false,
            }
        }

        fn graphics(name: &str, inputs: &[&str], outputs: &[&str]) -> PassConfig {
            PassConfig {
                name: name.to_string(),
                pass_type: PassType::Graphics,
                inputs: inputs.iter().map(|s| s.to_string()).collect(),
                outputs: outputs.iter().map(|s| s.to_string()).collect(),
                condition: None,
                resolution_scale: None,
                queue: QueueAffinity::Graphics,
                explicit_deps: Vec::new(),
            }
        }

        let config = GraphConfig {
            label: "from_config".to_string(),
            resources: vec![
                full_res("depth", TextureFormat::Depth32Float),
                full_res("color", TextureFormat::Rgba16Float),
            ],
            passes: vec![
                graphics("depth_pre", &[], &["depth"]),
                graphics("lighting", &["depth"], &["color"]),
            ],
            features: Vec::new(),
        };

        let mut graph = config.build();
        let order = graph.topological_sort().unwrap();
        assert_eq!(order, vec!["depth_pre", "lighting"]);
    }

    #[test]
    fn test_pass_builder_chain() {
        let mut builder = RenderGraphBuilder::new("builder_test", 1280, 720);
        let bloom_half = builder.texture_scaled("bloom_half", TextureFormat::Rgba16Float, 0.5, 0.5);
        let bloom_quarter =
            builder.texture_scaled("bloom_quarter", TextureFormat::Rgba16Float, 0.25, 0.25);
        let hdr = builder.texture("hdr_color", TextureFormat::Rgba16Float);

        builder
            .graphics_pass("bloom_down")
            .reads(hdr, "hdr_color")
            .writes(bloom_half, "bloom_half")
            .resolution(ResolutionScale::half())
            .tag("bloom")
            .finish();

        builder
            .graphics_pass("bloom_down2")
            .reads(bloom_half, "bloom_half")
            .writes(bloom_quarter, "bloom_quarter")
            .resolution(ResolutionScale::quarter())
            .tag("bloom")
            .finish();

        let graph = builder.build();
        assert_eq!(graph.pass_count(), 2);
        assert_eq!(graph.resource_count(), 3);
    }
}