1use std::collections::{HashMap, HashSet, VecDeque};
20use crate::render::compute::ResourceHandle;
21
/// Pixel formats a render-graph texture resource can be declared with.
///
/// `TextureFormat::bytes_per_pixel` gives the per-texel size this module uses,
/// and `TextureFormat::is_depth` tells which variants are depth formats.
/// Variant order is kept stable because the derived `Hash` depends on it.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum TextureFormat {
    Rgba8Unorm,
    Rgba16Float,
    Rgba32Float,
    R32Float,
    Rg16Float,
    Depth24Stencil8,
    Depth32Float,
    Rgb10A2Unorm,
    Bgra8Unorm,
}
37
38impl TextureFormat {
39 pub fn is_depth(self) -> bool {
40 matches!(self, Self::Depth24Stencil8 | Self::Depth32Float)
41 }
42
43 pub fn bytes_per_pixel(self) -> u32 {
44 match self {
45 Self::Rgba8Unorm | Self::Bgra8Unorm | Self::Rgb10A2Unorm | Self::Depth24Stencil8 => 4,
46 Self::R32Float => 4,
47 Self::Rg16Float => 4,
48 Self::Rgba16Float | Self::Depth32Float => 8,
49 Self::Rgba32Float => 16,
50 }
51 }
52}
53
/// How big a texture resource is: an explicit size, or one derived from the
/// backbuffer. Turned into concrete dimensions by `CompiledGraph::resolve_size`.
#[derive(Clone, Copy, Debug)]
pub enum ResourceSize {
    /// Explicit `(width, height)`.
    Fixed(u32, u32),
    /// Backbuffer dimensions multiplied by this factor (`0.5` = half resolution).
    Relative(f32),
    /// Exactly the backbuffer dimensions.
    Backbuffer,
}
64
65#[derive(Debug, Clone)]
67pub struct ResourceDesc {
68 pub name: String,
69 pub format: TextureFormat,
70 pub size: ResourceSize,
71 pub samples: u32,
72 pub mips: u32,
74 pub layers: u32,
76 pub persistent: bool,
78}
79
80impl ResourceDesc {
81 pub fn color(name: impl Into<String>, format: TextureFormat) -> Self {
82 Self {
83 name: name.into(),
84 format,
85 size: ResourceSize::Backbuffer,
86 samples: 1,
87 mips: 1,
88 layers: 1,
89 persistent: false,
90 }
91 }
92
93 pub fn depth(name: impl Into<String>) -> Self {
94 Self::color(name, TextureFormat::Depth24Stencil8)
95 }
96
97 pub fn half_res(mut self) -> Self {
98 self.size = ResourceSize::Relative(0.5);
99 self
100 }
101
102 pub fn fixed_size(mut self, w: u32, h: u32) -> Self {
103 self.size = ResourceSize::Fixed(w, h);
104 self
105 }
106
107 pub fn with_mips(mut self, mips: u32) -> Self {
108 self.mips = mips;
109 self
110 }
111
112 pub fn persistent(mut self) -> Self {
113 self.persistent = true;
114 self
115 }
116
117 pub fn msaa(mut self, samples: u32) -> Self {
118 self.samples = samples;
119 self
120 }
121}
122
/// How a pass touches a resource.
///
/// `ResourceAccess::is_write` classifies the variants; everything it does not
/// list counts as a read. The builder methods on `PassBuilder` produce:
/// `read` → `ShaderRead`, `write` → `RenderTarget`, `read_depth` → `DepthRead`,
/// `write_depth` → `DepthWrite`, `compute_read`/`compute_write` →
/// `ComputeRead`/`ComputeWrite`, `transfer_src`/`transfer_dst` →
/// `TransferSrc`/`TransferDst`. `ImageReadWrite` and `Present` are not produced
/// by any builder in this module.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ResourceAccess {
    ShaderRead,
    RenderTarget,
    DepthWrite,
    DepthRead,
    ImageReadWrite,
    ComputeWrite,
    ComputeRead,
    TransferSrc,
    TransferDst,
    Present,
}
149
150impl ResourceAccess {
151 pub fn is_write(self) -> bool {
152 matches!(self,
153 Self::RenderTarget
154 | Self::DepthWrite
155 | Self::ImageReadWrite
156 | Self::ComputeWrite
157 | Self::TransferDst
158 )
159 }
160
161 pub fn is_read(self) -> bool { !self.is_write() }
162}
163
164#[derive(Debug, Clone)]
168pub struct Barrier {
169 pub resource: String,
170 pub src_access: ResourceAccess,
171 pub dst_access: ResourceAccess,
172}
173
174impl Barrier {
175 pub fn new(resource: impl Into<String>, src: ResourceAccess, dst: ResourceAccess) -> Self {
176 Self { resource: resource.into(), src_access: src, dst_access: dst }
177 }
178}
179
/// Category of work a pass performs.
///
/// `RenderGraph::graphics_pass`, `compute_pass` and `transfer_pass` create
/// `Graphics`, `Compute` and `Transfer` passes respectively. No builder in this
/// module creates `Present`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PassKind {
    Graphics,
    Compute,
    Transfer,
    Present,
}
194
195#[derive(Debug, Clone)]
199pub struct RenderPass {
200 pub name: String,
201 pub kind: PassKind,
202 pub reads: Vec<(String, ResourceAccess)>,
204 pub writes: Vec<(String, ResourceAccess)>,
206 pub priority: i32,
208 pub optional: bool,
210 pub depends: Vec<String>,
212}
213
214impl RenderPass {
215 pub fn new(name: impl Into<String>, kind: PassKind) -> Self {
216 Self {
217 name: name.into(),
218 kind,
219 reads: Vec::new(),
220 writes: Vec::new(),
221 priority: 0,
222 optional: false,
223 depends: Vec::new(),
224 }
225 }
226
227 pub fn reads(&mut self, res: impl Into<String>, access: ResourceAccess) {
228 self.reads.push((res.into(), access));
229 }
230
231 pub fn writes(&mut self, res: impl Into<String>, access: ResourceAccess) {
232 self.writes.push((res.into(), access));
233 }
234}
235
236pub struct PassBuilder<'g> {
240 graph: &'g mut RenderGraph,
241 pass: RenderPass,
242}
243
244impl<'g> PassBuilder<'g> {
245 fn new(graph: &'g mut RenderGraph, name: impl Into<String>, kind: PassKind) -> Self {
246 Self { graph, pass: RenderPass::new(name, kind) }
247 }
248
249 pub fn read(mut self, resource: impl Into<String>) -> Self {
250 self.pass.reads(resource, ResourceAccess::ShaderRead);
251 self
252 }
253
254 pub fn read_depth(mut self, resource: impl Into<String>) -> Self {
255 self.pass.reads(resource, ResourceAccess::DepthRead);
256 self
257 }
258
259 pub fn write(mut self, resource: impl Into<String>) -> Self {
260 self.pass.writes(resource, ResourceAccess::RenderTarget);
261 self
262 }
263
264 pub fn write_depth(mut self, resource: impl Into<String>) -> Self {
265 self.pass.writes(resource, ResourceAccess::DepthWrite);
266 self
267 }
268
269 pub fn compute_read(mut self, resource: impl Into<String>) -> Self {
270 self.pass.reads(resource, ResourceAccess::ComputeRead);
271 self
272 }
273
274 pub fn compute_write(mut self, resource: impl Into<String>) -> Self {
275 self.pass.writes(resource, ResourceAccess::ComputeWrite);
276 self
277 }
278
279 pub fn transfer_src(mut self, resource: impl Into<String>) -> Self {
280 self.pass.reads(resource, ResourceAccess::TransferSrc);
281 self
282 }
283
284 pub fn transfer_dst(mut self, resource: impl Into<String>) -> Self {
285 self.pass.writes(resource, ResourceAccess::TransferDst);
286 self
287 }
288
289 pub fn priority(mut self, p: i32) -> Self {
290 self.pass.priority = p;
291 self
292 }
293
294 pub fn optional(mut self) -> Self {
295 self.pass.optional = true;
296 self
297 }
298
299 pub fn after(mut self, pass_name: impl Into<String>) -> Self {
300 self.pass.depends.push(pass_name.into());
301 self
302 }
303
304 pub fn build(self) {
306 self.graph.add_pass(self.pass);
307 }
308}
309
310pub struct RenderGraph {
314 passes: Vec<RenderPass>,
315 resources: HashMap<String, ResourceDesc>,
316 output: Option<String>,
318}
319
320impl RenderGraph {
321 pub fn new() -> Self {
322 Self {
323 passes: Vec::new(),
324 resources: HashMap::new(),
325 output: None,
326 }
327 }
328
329 pub fn declare_resource(&mut self, desc: ResourceDesc) -> ResourceHandle {
331 let id = self.resources.len() as u32 + 1;
332 self.resources.insert(desc.name.clone(), desc);
333 ResourceHandle(id)
334 }
335
336 pub fn set_output(&mut self, resource: impl Into<String>) {
338 self.output = Some(resource.into());
339 }
340
341 pub fn add_pass(&mut self, pass: RenderPass) {
343 self.passes.push(pass);
344 }
345
346 pub fn graphics_pass<'g>(&'g mut self, name: impl Into<String>) -> PassBuilder<'g> {
348 PassBuilder::new(self, name, PassKind::Graphics)
349 }
350
351 pub fn compute_pass<'g>(&'g mut self, name: impl Into<String>) -> PassBuilder<'g> {
353 PassBuilder::new(self, name, PassKind::Compute)
354 }
355
356 pub fn transfer_pass<'g>(&'g mut self, name: impl Into<String>) -> PassBuilder<'g> {
358 PassBuilder::new(self, name, PassKind::Transfer)
359 }
360
361 pub fn compile(self) -> Result<CompiledGraph, GraphError> {
363 let compiler = GraphCompiler::new(self);
364 compiler.compile()
365 }
366}
367
368impl Default for RenderGraph {
369 fn default() -> Self { Self::new() }
370}
371
/// Reasons a render graph can fail to compile.
#[derive(Clone, Debug)]
pub enum GraphError {
    /// Names of passes the topological sort could not schedule. These are the
    /// cycle members plus anything downstream of them.
    CycleDetected(Vec<String>),
    /// A pass reads a resource that was never declared.
    /// NOTE(review): not produced by the compiler in this file.
    UnresolvedResource { pass: String, resource: String },
    /// A pass's `after(...)` dependency names no pass in the graph.
    UnknownDependency { pass: String, dep: String },
}
383
384impl std::fmt::Display for GraphError {
385 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
386 match self {
387 Self::CycleDetected(cycle) =>
388 write!(f, "render graph cycle: {}", cycle.join(" → ")),
389 Self::UnresolvedResource { pass, resource } =>
390 write!(f, "pass '{}' reads undeclared resource '{}'", pass, resource),
391 Self::UnknownDependency { pass, dep } =>
392 write!(f, "pass '{}' depends on unknown pass '{}'", pass, dep),
393 }
394 }
395}
396
397#[derive(Debug, Clone)]
401pub struct CompiledPass {
402 pub pass: RenderPass,
403 pub pre_barriers: Vec<Barrier>,
405}
406
407pub struct CompiledGraph {
411 pub passes: Vec<CompiledPass>,
413 pub resources: HashMap<String, ResourceDesc>,
414 pub output: Option<String>,
415 pub stats: CompileStats,
417}
418
/// Summary counters produced by graph compilation.
#[derive(Clone, Debug, Default)]
pub struct CompileStats {
    /// Number of passes in the compiled schedule.
    pub pass_count: usize,
    /// Total pre-pass barriers emitted across all passes.
    pub barrier_count: usize,
    /// Number of passes removed by culling.
    pub culled_passes: usize,
}
425
426impl CompiledGraph {
427 pub fn iter(&self) -> impl Iterator<Item = &CompiledPass> {
429 self.passes.iter()
430 }
431
432 pub fn pass_count(&self) -> usize { self.passes.len() }
433
434 pub fn resource(&self, name: &str) -> Option<&ResourceDesc> {
436 self.resources.get(name)
437 }
438
439 pub fn resolve_size(&self, name: &str, bb_w: u32, bb_h: u32) -> Option<(u32, u32)> {
441 let desc = self.resources.get(name)?;
442 Some(match desc.size {
443 ResourceSize::Fixed(w, h) => (w, h),
444 ResourceSize::Backbuffer => (bb_w, bb_h),
445 ResourceSize::Relative(s) => (
446 ((bb_w as f32 * s) as u32).max(1),
447 ((bb_h as f32 * s) as u32).max(1),
448 ),
449 })
450 }
451}
452
453struct GraphCompiler {
456 graph: RenderGraph,
457}
458
459impl GraphCompiler {
460 fn new(graph: RenderGraph) -> Self { Self { graph } }
461
462 fn compile(mut self) -> Result<CompiledGraph, GraphError> {
463 let pass_names: HashSet<String> = self.graph.passes.iter()
465 .map(|p| p.name.clone())
466 .collect();
467
468 for pass in &self.graph.passes {
469 for dep in &pass.depends {
470 if !pass_names.contains(dep) {
471 return Err(GraphError::UnknownDependency {
472 pass: pass.name.clone(),
473 dep: dep.clone(),
474 });
475 }
476 }
477 }
478
479 let mut writers: HashMap<String, Vec<String>> = HashMap::new();
481 for pass in &self.graph.passes {
482 for (res, _) in &pass.writes {
483 writers.entry(res.clone()).or_default().push(pass.name.clone());
484 }
485 }
486
487 let live_passes = self.compute_live_passes(&writers);
489 let original_count = self.graph.passes.len();
490 let culled = original_count - live_passes.len();
491 self.graph.passes.retain(|p| live_passes.contains(&p.name));
492
493 let mut adj: HashMap<String, HashSet<String>> = HashMap::new(); let mut in_deg: HashMap<String, usize> = HashMap::new();
497 for pass in &self.graph.passes {
498 adj.entry(pass.name.clone()).or_default();
499 in_deg.entry(pass.name.clone()).or_insert(0);
500 }
501
502 for pass_b in &self.graph.passes {
503 for (res, _) in &pass_b.reads {
504 if let Some(ws) = writers.get(res) {
505 for pass_a in ws {
506 if pass_a != &pass_b.name && live_passes.contains(pass_a) {
507 if adj.get(pass_a).map_or(true, |s| !s.contains(&pass_b.name)) {
508 adj.entry(pass_a.clone()).or_default().insert(pass_b.name.clone());
509 *in_deg.entry(pass_b.name.clone()).or_insert(0) += 1;
510 }
511 }
512 }
513 }
514 }
515 for dep in &pass_b.depends {
517 if live_passes.contains(dep) {
518 if adj.get(dep).map_or(true, |s| !s.contains(&pass_b.name)) {
519 adj.entry(dep.clone()).or_default().insert(pass_b.name.clone());
520 *in_deg.entry(pass_b.name.clone()).or_insert(0) += 1;
521 }
522 }
523 }
524 }
525
526 let pass_map: HashMap<String, RenderPass> = self.graph.passes.drain(..)
528 .map(|p| (p.name.clone(), p))
529 .collect();
530
531 let mut queue: VecDeque<String> = in_deg.iter()
532 .filter(|(_, &d)| d == 0)
533 .map(|(n, _)| n.clone())
534 .collect();
535
536 let mut sorted: Vec<String> = Vec::new();
538 let mut cycle_check = 0usize;
539
540 while !queue.is_empty() {
541 let best_idx = queue.iter().enumerate()
543 .min_by_key(|(_, n)| pass_map.get(*n).map_or(0, |p| p.priority))
544 .map(|(i, _)| i)
545 .unwrap_or(0);
546 let node = queue.remove(best_idx).unwrap();
547 sorted.push(node.clone());
548 cycle_check += 1;
549
550 if let Some(successors) = adj.get(&node) {
551 for succ in successors {
552 let deg = in_deg.get_mut(succ).unwrap();
553 *deg -= 1;
554 if *deg == 0 {
555 queue.push_back(succ.clone());
556 }
557 }
558 }
559 }
560
561 if cycle_check != pass_map.len() {
562 let cycle_nodes: Vec<String> = in_deg.iter()
564 .filter(|(_, &d)| d > 0)
565 .map(|(n, _)| n.clone())
566 .collect();
567 return Err(GraphError::CycleDetected(cycle_nodes));
568 }
569
570 let mut last_access: HashMap<String, ResourceAccess> = HashMap::new();
573 let mut compiled: Vec<CompiledPass> = Vec::new();
574 let mut total_barriers = 0usize;
575
576 for pass_name in &sorted {
577 let pass = pass_map.get(pass_name).unwrap().clone();
578 let mut pre_barriers = Vec::new();
579
580 for (res, access) in &pass.reads {
582 if let Some(&prev) = last_access.get(res) {
583 if needs_barrier(prev, *access) {
584 pre_barriers.push(Barrier::new(res, prev, *access));
585 total_barriers += 1;
586 }
587 }
588 last_access.insert(res.clone(), *access);
589 }
590
591 for (res, access) in &pass.writes {
593 if let Some(&prev) = last_access.get(res) {
594 if needs_barrier(prev, *access) {
595 pre_barriers.push(Barrier::new(res, prev, *access));
596 total_barriers += 1;
597 }
598 }
599 last_access.insert(res.clone(), *access);
600 }
601
602 compiled.push(CompiledPass { pass, pre_barriers });
603 }
604
605 let stats = CompileStats {
606 pass_count: compiled.len(),
607 barrier_count: total_barriers,
608 culled_passes: culled,
609 };
610
611 Ok(CompiledGraph {
612 passes: compiled,
613 resources: self.graph.resources,
614 output: self.graph.output,
615 stats,
616 })
617 }
618
619 fn compute_live_passes(&self, writers: &HashMap<String, Vec<String>>) -> HashSet<String> {
621 let mut live: HashSet<String> = self.graph.passes.iter()
623 .filter(|p| !p.optional)
624 .map(|p| p.name.clone())
625 .collect();
626
627 if let Some(output) = &self.graph.output {
629 let mut stack: Vec<String> = Vec::new();
630 if let Some(ws) = writers.get(output) {
631 stack.extend(ws.clone());
632 }
633 while let Some(pass_name) = stack.pop() {
634 if live.insert(pass_name.clone()) {
635 if let Some(pass) = self.graph.passes.iter().find(|p| p.name == pass_name) {
637 for (res, _) in &pass.reads {
638 if let Some(ws) = writers.get(res) {
639 for w in ws {
640 if !live.contains(w) {
641 stack.push(w.clone());
642 }
643 }
644 }
645 }
646 }
647 }
648 }
649 }
650
651 live
652 }
653}
654
655fn needs_barrier(src: ResourceAccess, dst: ResourceAccess) -> bool {
657 use ResourceAccess::*;
658 if src == dst && src.is_read() { return false; }
662 src.is_write() || dst.is_write()
663}
664
665pub fn standard_frame_graph() -> RenderGraph {
681 let mut g = RenderGraph::new();
682
683 g.declare_resource(ResourceDesc::depth("depth").persistent());
685 g.declare_resource(ResourceDesc::color("gbuffer_albedo", TextureFormat::Rgba8Unorm));
686 g.declare_resource(ResourceDesc::color("gbuffer_normal", TextureFormat::Rgba16Float));
687 g.declare_resource(ResourceDesc::color("gbuffer_emissive", TextureFormat::Rgba16Float));
688 g.declare_resource(ResourceDesc::color("ssao", TextureFormat::R32Float).half_res());
689 g.declare_resource(ResourceDesc::color("hdr", TextureFormat::Rgba16Float));
690 g.declare_resource(ResourceDesc::color("bloom_half", TextureFormat::Rgba16Float).half_res());
691 g.declare_resource(ResourceDesc::color("bloom", TextureFormat::Rgba16Float).half_res());
692 g.declare_resource(ResourceDesc::color("ldr", TextureFormat::Rgba8Unorm));
693 g.declare_resource(ResourceDesc::color("particle_buf", TextureFormat::Rgba32Float).persistent());
694
695 g.set_output("backbuffer");
696
697 g.graphics_pass("depth_prepass")
699 .write_depth("depth")
700 .priority(-100)
701 .build();
702
703 g.compute_pass("particle_update")
704 .compute_read("particle_buf")
705 .compute_write("particle_buf")
706 .priority(-90)
707 .build();
708
709 g.graphics_pass("gbuffer")
710 .write("gbuffer_albedo")
711 .write("gbuffer_normal")
712 .write("gbuffer_emissive")
713 .read_depth("depth")
714 .priority(-80)
715 .build();
716
717 g.compute_pass("ssao")
718 .read("gbuffer_normal")
719 .read_depth("depth")
720 .compute_write("ssao")
721 .priority(-70)
722 .build();
723
724 g.graphics_pass("lighting")
725 .read("gbuffer_albedo")
726 .read("gbuffer_normal")
727 .read("gbuffer_emissive")
728 .read("ssao")
729 .write("hdr")
730 .priority(-60)
731 .build();
732
733 g.graphics_pass("particle_draw")
734 .compute_read("particle_buf")
735 .read_depth("depth")
736 .write("hdr")
737 .priority(-50)
738 .build();
739
740 g.graphics_pass("bloom_down")
741 .read("hdr")
742 .write("bloom_half")
743 .priority(-40)
744 .build();
745
746 g.graphics_pass("bloom_up")
747 .read("bloom_half")
748 .write("bloom")
749 .priority(-30)
750 .build();
751
752 g.graphics_pass("tonemap")
753 .read("hdr")
754 .read("bloom")
755 .write("ldr")
756 .priority(-20)
757 .build();
758
759 g.graphics_pass("fxaa")
760 .read("ldr")
761 .write("backbuffer")
762 .priority(-10)
763 .build();
764
765 g
766}
767
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a graph with one `Rgba8Unorm` color resource per name, declared
    /// in the given order.
    fn graph_with(names: &[&str]) -> RenderGraph {
        let mut graph = RenderGraph::new();
        for name in names {
            graph.declare_resource(ResourceDesc::color(*name, TextureFormat::Rgba8Unorm));
        }
        graph
    }

    #[test]
    fn test_standard_graph_compiles() {
        let compiled = standard_frame_graph()
            .compile()
            .expect("standard graph should compile");
        assert!(compiled.pass_count() >= 9);
    }

    #[test]
    fn test_barrier_inserted_between_write_and_read() {
        let mut graph = graph_with(&["tex"]);
        graph.graphics_pass("writer").write("tex").build();
        graph.graphics_pass("reader").read("tex").build();

        let compiled = graph.compile().unwrap();
        let reader = compiled
            .iter()
            .find(|p| p.pass.name == "reader")
            .unwrap();
        assert!(!reader.pre_barriers.is_empty(), "barrier expected before reader");
    }

    #[test]
    fn test_cycle_detection() {
        let mut graph = graph_with(&["a", "b"]);
        graph.graphics_pass("pass_a").write("a").read("b").build();
        graph.graphics_pass("pass_b").write("b").read("a").build();

        let result = graph.compile();
        assert!(matches!(result, Err(GraphError::CycleDetected(_))));
    }

    #[test]
    fn test_no_barrier_between_two_reads() {
        let mut graph = graph_with(&["tex", "out1", "out2"]);
        graph.graphics_pass("init").write("tex").build();
        graph.graphics_pass("r1").read("tex").write("out1").after("init").build();
        graph.graphics_pass("r2").read("tex").write("out2").after("init").build();

        let compiled = graph.compile().unwrap();
        let offending = compiled
            .iter()
            .flat_map(|p| &p.pre_barriers)
            .find(|b| {
                b.resource == "tex"
                    && b.src_access == ResourceAccess::ShaderRead
                    && b.dst_access == ResourceAccess::ShaderRead
            });
        assert!(offending.is_none(), "read→read should not emit a barrier");
    }

    #[test]
    fn test_resolve_size() {
        let compiled = standard_frame_graph().compile().unwrap();
        let (width, height) = compiled.resolve_size("ssao", 1920, 1080).unwrap();
        assert_eq!(width, 960);
        assert_eq!(height, 540);
    }
}
834}