Skip to main content

astrelis_render/
render_graph.rs

//! Render graph system for automatic resource management and pass scheduling.
//!
//! The render graph provides:
//! - Automatic resource barriers and transitions (delegated to wgpu, which tracks
//!   resource state itself)
//! - Topological sort of render passes based on their read/write dependencies
//! - Resource usage tracking (first/last reader and last writer of each resource)
//!   as groundwork for lifetime optimization
//! - Explicit pass dependencies, declared as each pass's inputs and outputs
//!
//! # Example
//!
//! ```ignore
//! use astrelis_render::*;
//!
//! let mut graph = RenderGraph::new();
//!
//! // Add resources
//! let color_target = graph.add_texture(
//!     "color_target",
//!     (800, 600, 1),
//!     wgpu::TextureFormat::Rgba8Unorm,
//!     wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
//! );
//!
//! // Add passes
//! graph.add_pass(RenderGraphPass::new(
//!     "main_pass",
//!     vec![],
//!     vec![color_target],
//!     |ctx| {
//!         // Render code here
//!     },
//! ));
//!
//! // Compile, then execute with a shared `Arc<GraphicsContext>`
//! let plan = graph.compile()?;
//! graph.execute(Arc::clone(&graphics))?;
//! ```
38
39use ahash::{HashMap, HashMapExt, HashSet, HashSetExt};
40use astrelis_core::profiling::profile_function;
41use std::sync::Arc;
42
43use crate::GraphicsContext;
44
/// Opaque handle for a resource (texture or buffer) registered with a
/// [`RenderGraph`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ResourceId(u64);

impl ResourceId {
    /// Wrap a raw numeric value as a resource handle.
    pub fn new(id: u64) -> Self {
        ResourceId(id)
    }

    /// Return the raw numeric value behind this handle.
    pub fn as_u64(&self) -> u64 {
        let ResourceId(raw) = *self;
        raw
    }
}
60
/// Opaque handle for a pass registered with a [`RenderGraph`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PassId(u64);

impl PassId {
    /// Wrap a raw numeric value as a pass handle.
    pub fn new(id: u64) -> Self {
        PassId(id)
    }

    /// Return the raw numeric value behind this handle.
    pub fn as_u64(&self) -> u64 {
        let PassId(raw) = *self;
        raw
    }
}
76
77/// Resource type in the render graph.
78#[derive(Debug, Clone, PartialEq)]
79pub enum ResourceType {
80    /// Texture resource
81    Texture {
82        size: (u32, u32, u32),
83        format: wgpu::TextureFormat,
84        usage: wgpu::TextureUsages,
85    },
86    /// Buffer resource
87    Buffer {
88        size: u64,
89        usage: wgpu::BufferUsages,
90    },
91}
92
93/// Resource information in the render graph.
94#[derive(Debug, Clone)]
95pub struct ResourceInfo {
96    /// Resource ID
97    pub id: ResourceId,
98    /// Resource type and descriptor
99    pub resource_type: ResourceType,
100    /// Resource name for debugging
101    pub name: String,
102    /// First pass that reads this resource
103    pub first_read: Option<PassId>,
104    /// Last pass that writes this resource
105    pub last_write: Option<PassId>,
106    /// Last pass that reads this resource
107    pub last_read: Option<PassId>,
108}
109
110/// Render context passed to pass execution functions.
111pub struct RenderContext {
112    /// Graphics context
113    pub graphics: Arc<GraphicsContext>,
114    /// Resource textures (if created)
115    pub textures: HashMap<ResourceId, wgpu::Texture>,
116    /// Resource buffers (if created)
117    pub buffers: HashMap<ResourceId, wgpu::Buffer>,
118}
119
120impl RenderContext {
121    /// Create a new render context.
122    pub fn new(graphics: Arc<GraphicsContext>) -> Self {
123        Self {
124            graphics,
125            textures: HashMap::new(),
126            buffers: HashMap::new(),
127        }
128    }
129
130    /// Get a texture by resource ID.
131    pub fn get_texture(&self, id: ResourceId) -> Option<&wgpu::Texture> {
132        self.textures.get(&id)
133    }
134
135    /// Get a buffer by resource ID.
136    pub fn get_buffer(&self, id: ResourceId) -> Option<&wgpu::Buffer> {
137        self.buffers.get(&id)
138    }
139}
140
141/// A render pass in the graph.
142pub struct RenderGraphPass {
143    /// Pass name for debugging
144    pub name: &'static str,
145    /// Input resources (read)
146    pub inputs: Vec<ResourceId>,
147    /// Output resources (write)
148    pub outputs: Vec<ResourceId>,
149    /// Execution function
150    pub execute: Box<dyn Fn(&mut RenderContext) + Send + Sync>,
151}
152
153impl RenderGraphPass {
154    /// Create a new render pass.
155    pub fn new(
156        name: &'static str,
157        inputs: Vec<ResourceId>,
158        outputs: Vec<ResourceId>,
159        execute: impl Fn(&mut RenderContext) + Send + Sync + 'static,
160    ) -> Self {
161        Self {
162            name,
163            inputs,
164            outputs,
165            execute: Box::new(execute),
166        }
167    }
168}
169
170/// Execution plan for the render graph.
171#[derive(Debug, Clone)]
172pub struct ExecutionPlan {
173    /// Ordered list of pass IDs to execute
174    pub pass_order: Vec<PassId>,
175}
176
177/// Render graph error.
178#[derive(Debug, Clone, PartialEq, Eq)]
179pub enum RenderGraphError {
180    /// Cyclic dependency detected
181    CyclicDependency,
182    /// Resource not found
183    ResourceNotFound(ResourceId),
184    /// Pass not found
185    PassNotFound(PassId),
186    /// Invalid resource usage
187    InvalidUsage(String),
188}
189
190impl std::fmt::Display for RenderGraphError {
191    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192        match self {
193            Self::CyclicDependency => write!(f, "Cyclic dependency detected in render graph"),
194            Self::ResourceNotFound(id) => write!(f, "Resource {:?} not found", id),
195            Self::PassNotFound(id) => write!(f, "Pass {:?} not found", id),
196            Self::InvalidUsage(msg) => write!(f, "Invalid resource usage: {}", msg),
197        }
198    }
199}
200
201impl std::error::Error for RenderGraphError {}
202
203/// Render graph managing passes and resources.
204pub struct RenderGraph {
205    /// All render passes
206    passes: HashMap<PassId, RenderGraphPass>,
207    /// All resources
208    resources: HashMap<ResourceId, ResourceInfo>,
209    /// Next pass ID
210    next_pass_id: u64,
211    /// Next resource ID
212    next_resource_id: u64,
213    /// Execution plan (cached after compilation)
214    execution_plan: Option<ExecutionPlan>,
215}
216
217impl RenderGraph {
218    /// Create a new render graph.
219    pub fn new() -> Self {
220        Self {
221            passes: HashMap::new(),
222            resources: HashMap::new(),
223            next_pass_id: 0,
224            next_resource_id: 0,
225            execution_plan: None,
226        }
227    }
228
229    /// Add a texture resource to the graph.
230    pub fn add_texture(
231        &mut self,
232        name: impl Into<String>,
233        size: (u32, u32, u32),
234        format: wgpu::TextureFormat,
235        usage: wgpu::TextureUsages,
236    ) -> ResourceId {
237        let id = ResourceId::new(self.next_resource_id);
238        self.next_resource_id += 1;
239
240        let resource = ResourceInfo {
241            id,
242            resource_type: ResourceType::Texture {
243                size,
244                format,
245                usage,
246            },
247            name: name.into(),
248            first_read: None,
249            last_write: None,
250            last_read: None,
251        };
252
253        self.resources.insert(id, resource);
254        self.execution_plan = None; // Invalidate plan
255
256        id
257    }
258
259    /// Add a buffer resource to the graph.
260    pub fn add_buffer(
261        &mut self,
262        name: impl Into<String>,
263        size: u64,
264        usage: wgpu::BufferUsages,
265    ) -> ResourceId {
266        let id = ResourceId::new(self.next_resource_id);
267        self.next_resource_id += 1;
268
269        let resource = ResourceInfo {
270            id,
271            resource_type: ResourceType::Buffer { size, usage },
272            name: name.into(),
273            first_read: None,
274            last_write: None,
275            last_read: None,
276        };
277
278        self.resources.insert(id, resource);
279        self.execution_plan = None; // Invalidate plan
280
281        id
282    }
283
284    /// Add a render pass to the graph.
285    pub fn add_pass(&mut self, pass: RenderGraphPass) -> PassId {
286        let id = PassId::new(self.next_pass_id);
287        self.next_pass_id += 1;
288
289        // Update resource usage tracking
290        for &input_id in &pass.inputs {
291            if let Some(resource) = self.resources.get_mut(&input_id) {
292                if resource.first_read.is_none() {
293                    resource.first_read = Some(id);
294                }
295                resource.last_read = Some(id);
296            }
297        }
298
299        for &output_id in &pass.outputs {
300            if let Some(resource) = self.resources.get_mut(&output_id) {
301                resource.last_write = Some(id);
302            }
303        }
304
305        self.passes.insert(id, pass);
306        self.execution_plan = None; // Invalidate plan
307
308        id
309    }
310
311    /// Compile the render graph into an execution plan.
312    ///
313    /// This performs topological sorting of passes based on their dependencies.
314    pub fn compile(&mut self) -> Result<ExecutionPlan, RenderGraphError> {
315        profile_function!();
316        // Build dependency graph
317        let mut dependencies: HashMap<PassId, HashSet<PassId>> = HashMap::new();
318        let mut dependents: HashMap<PassId, HashSet<PassId>> = HashMap::new();
319
320        for (&pass_id, pass) in &self.passes {
321            dependencies.insert(pass_id, HashSet::new());
322            dependents.entry(pass_id).or_default();
323
324            // A pass depends on any pass that writes to its input resources
325            for &input_id in &pass.inputs {
326                // Find passes that write to this resource
327                for (&other_pass_id, other_pass) in &self.passes {
328                    if other_pass_id != pass_id && other_pass.outputs.contains(&input_id) {
329                        dependencies
330                            .get_mut(&pass_id)
331                            .unwrap()
332                            .insert(other_pass_id);
333                        dependents.entry(other_pass_id).or_default().insert(pass_id);
334                    }
335                }
336            }
337        }
338
339        // Topological sort using Kahn's algorithm
340        let mut sorted = Vec::new();
341        let mut no_incoming: Vec<PassId> = dependencies
342            .iter()
343            .filter(|(_, deps)| deps.is_empty())
344            .map(|(&id, _)| id)
345            .collect();
346
347        while let Some(pass_id) = no_incoming.pop() {
348            sorted.push(pass_id);
349
350            // Remove edges from this pass to its dependents
351            if let Some(deps) = dependents.get(&pass_id) {
352                for &dependent_id in deps {
353                    if let Some(dep_set) = dependencies.get_mut(&dependent_id) {
354                        dep_set.remove(&pass_id);
355                        if dep_set.is_empty() {
356                            no_incoming.push(dependent_id);
357                        }
358                    }
359                }
360            }
361        }
362
363        // Check for cycles
364        if sorted.len() != self.passes.len() {
365            return Err(RenderGraphError::CyclicDependency);
366        }
367
368        let plan = ExecutionPlan { pass_order: sorted };
369
370        self.execution_plan = Some(plan.clone());
371
372        Ok(plan)
373    }
374
375    /// Execute the render graph.
376    ///
377    /// This must be called after `compile()`.
378    pub fn execute(&self, graphics: Arc<GraphicsContext>) -> Result<(), RenderGraphError> {
379        profile_function!();
380        let plan = self
381            .execution_plan
382            .as_ref()
383            .ok_or(RenderGraphError::InvalidUsage(
384                "Graph not compiled".to_string(),
385            ))?;
386
387        let mut context = RenderContext::new(graphics);
388
389        // Create resources (simplified - in reality would manage lifetimes)
390        for (id, info) in &self.resources {
391            match &info.resource_type {
392                ResourceType::Texture {
393                    size,
394                    format,
395                    usage,
396                } => {
397                    let texture =
398                        context
399                            .graphics
400                            .device()
401                            .create_texture(&wgpu::TextureDescriptor {
402                                label: Some(&info.name),
403                                size: wgpu::Extent3d {
404                                    width: size.0,
405                                    height: size.1,
406                                    depth_or_array_layers: size.2,
407                                },
408                                mip_level_count: 1,
409                                sample_count: 1,
410                                dimension: wgpu::TextureDimension::D2,
411                                format: *format,
412                                usage: *usage,
413                                view_formats: &[],
414                            });
415                    context.textures.insert(*id, texture);
416                }
417                ResourceType::Buffer { size, usage } => {
418                    let buffer = context
419                        .graphics
420                        .device()
421                        .create_buffer(&wgpu::BufferDescriptor {
422                            label: Some(&info.name),
423                            size: *size,
424                            usage: *usage,
425                            mapped_at_creation: false,
426                        });
427                    context.buffers.insert(*id, buffer);
428                }
429            }
430        }
431
432        // Execute passes in order
433        for &pass_id in &plan.pass_order {
434            if let Some(pass) = self.passes.get(&pass_id) {
435                (pass.execute)(&mut context);
436            }
437        }
438
439        Ok(())
440    }
441
442    /// Get the number of passes in the graph.
443    pub fn pass_count(&self) -> usize {
444        self.passes.len()
445    }
446
447    /// Get the number of resources in the graph.
448    pub fn resource_count(&self) -> usize {
449        self.resources.len()
450    }
451
452    /// Check if the graph has been compiled.
453    pub fn is_compiled(&self) -> bool {
454        self.execution_plan.is_some()
455    }
456}
457
458impl Default for RenderGraph {
459    fn default() -> Self {
460        Self::new()
461    }
462}
463
#[cfg(test)]
mod tests {
    use super::*;

    /// Add an 800x600 RGBA8 texture with the given name and usages.
    fn add_target(
        graph: &mut RenderGraph,
        name: &str,
        usage: wgpu::TextureUsages,
    ) -> ResourceId {
        graph.add_texture(name, (800, 600, 1), wgpu::TextureFormat::Rgba8Unorm, usage)
    }

    /// Usages for a texture that is rendered to and later sampled.
    fn attachment_and_binding() -> wgpu::TextureUsages {
        wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING
    }

    #[test]
    fn test_render_graph_new() {
        let graph = RenderGraph::new();
        assert_eq!(graph.pass_count(), 0);
        assert_eq!(graph.resource_count(), 0);
        assert!(!graph.is_compiled());
    }

    #[test]
    fn test_add_texture_resource() {
        let mut graph = RenderGraph::new();
        let tex = add_target(&mut graph, "color_target", wgpu::TextureUsages::RENDER_ATTACHMENT);
        assert_eq!(graph.resource_count(), 1);
        assert_eq!(tex.as_u64(), 0);
    }

    #[test]
    fn test_add_buffer_resource() {
        let mut graph = RenderGraph::new();
        let buf = graph.add_buffer("vertex_buffer", 1024, wgpu::BufferUsages::VERTEX);
        assert_eq!(graph.resource_count(), 1);
        assert_eq!(buf.as_u64(), 0);
    }

    #[test]
    fn test_add_pass() {
        let mut graph = RenderGraph::new();
        let tex = add_target(&mut graph, "target", wgpu::TextureUsages::RENDER_ATTACHMENT);

        let pass = RenderGraphPass::new("test_pass", vec![], vec![tex], |_ctx| {});
        let pass_id = graph.add_pass(pass);

        assert_eq!(graph.pass_count(), 1);
        assert_eq!(pass_id.as_u64(), 0);
    }

    #[test]
    fn test_add_pass_tracks_resource_usage() {
        let mut graph = RenderGraph::new();
        let tex = add_target(&mut graph, "target", attachment_and_binding());

        let writer = graph.add_pass(RenderGraphPass::new("w", vec![], vec![tex], |_ctx| {}));
        let reader1 = graph.add_pass(RenderGraphPass::new("r1", vec![tex], vec![], |_ctx| {}));
        let reader2 = graph.add_pass(RenderGraphPass::new("r2", vec![tex], vec![], |_ctx| {}));

        let info = &graph.resources[&tex];
        assert_eq!(info.last_write, Some(writer));
        assert_eq!(info.first_read, Some(reader1));
        assert_eq!(info.last_read, Some(reader2));
    }

    #[test]
    fn test_compile_empty_graph() {
        let mut graph = RenderGraph::new();
        let plan = graph.compile().expect("an empty graph is trivially acyclic");
        assert!(plan.pass_order.is_empty());
        assert!(graph.is_compiled());
    }

    #[test]
    fn test_compile_simple() {
        let mut graph = RenderGraph::new();
        let tex = add_target(&mut graph, "target", wgpu::TextureUsages::RENDER_ATTACHMENT);

        let pass = RenderGraphPass::new("test_pass", vec![], vec![tex], |_ctx| {});
        graph.add_pass(pass);

        let result = graph.compile();
        assert!(result.is_ok());
        assert!(graph.is_compiled());
    }

    #[test]
    fn test_compile_multiple_passes() {
        let mut graph = RenderGraph::new();
        let tex1 = add_target(&mut graph, "tex1", attachment_and_binding());
        let tex2 = add_target(&mut graph, "tex2", wgpu::TextureUsages::RENDER_ATTACHMENT);

        // Pass 1 writes to tex1
        let pass1 = graph.add_pass(RenderGraphPass::new("pass1", vec![], vec![tex1], |_ctx| {}));

        // Pass 2 reads tex1 and writes to tex2
        let pass2 =
            graph.add_pass(RenderGraphPass::new("pass2", vec![tex1], vec![tex2], |_ctx| {}));

        let plan = graph.compile().expect("two-pass chain is acyclic");
        assert_eq!(plan.pass_order.len(), 2);
        // Pass 1 should come before pass 2
        assert!(plan.pass_order[0].as_u64() < plan.pass_order[1].as_u64());
        assert_eq!(plan.pass_order, vec![pass1, pass2]);
    }

    #[test]
    fn test_compile_orders_by_dependency_not_insertion() {
        let mut graph = RenderGraph::new();
        let t1 = add_target(&mut graph, "t1", attachment_and_binding());
        let t2 = add_target(&mut graph, "t2", attachment_and_binding());
        let t3 = add_target(&mut graph, "t3", attachment_and_binding());

        // Added consumer-first; the plan must still run producer-first.
        let last = graph.add_pass(RenderGraphPass::new("last", vec![t2], vec![t3], |_ctx| {}));
        let middle =
            graph.add_pass(RenderGraphPass::new("middle", vec![t1], vec![t2], |_ctx| {}));
        let first = graph.add_pass(RenderGraphPass::new("first", vec![], vec![t1], |_ctx| {}));

        let plan = graph.compile().expect("three-pass chain is acyclic");
        assert_eq!(plan.pass_order, vec![first, middle, last]);
    }

    #[test]
    fn test_compile_detects_cycle() {
        let mut graph = RenderGraph::new();
        let a = add_target(&mut graph, "a", attachment_and_binding());
        let b = add_target(&mut graph, "b", attachment_and_binding());

        graph.add_pass(RenderGraphPass::new("a_to_b", vec![a], vec![b], |_ctx| {}));
        graph.add_pass(RenderGraphPass::new("b_to_a", vec![b], vec![a], |_ctx| {}));

        assert_eq!(
            graph.compile().unwrap_err(),
            RenderGraphError::CyclicDependency
        );
        assert!(!graph.is_compiled());
    }

    #[test]
    fn test_read_write_same_resource_is_not_a_cycle() {
        let mut graph = RenderGraph::new();
        let tex = add_target(&mut graph, "tex", attachment_and_binding());

        graph.add_pass(RenderGraphPass::new("in_place", vec![tex], vec![tex], |_ctx| {}));

        let plan = graph
            .compile()
            .expect("a pass reading and writing one resource is not a cycle");
        assert_eq!(plan.pass_order.len(), 1);
    }

    #[test]
    fn test_modifying_graph_invalidates_plan() {
        let mut graph = RenderGraph::new();
        let tex = add_target(&mut graph, "tex", attachment_and_binding());
        graph.add_pass(RenderGraphPass::new("p", vec![], vec![tex], |_ctx| {}));
        graph.compile().expect("single pass compiles");
        assert!(graph.is_compiled());

        graph.add_pass(RenderGraphPass::new("q", vec![tex], vec![], |_ctx| {}));
        assert!(!graph.is_compiled());

        graph.compile().expect("two passes compile");
        graph.add_buffer("scratch", 256, wgpu::BufferUsages::STORAGE);
        assert!(!graph.is_compiled());
    }

    #[test]
    fn test_resource_id_equality() {
        let id1 = ResourceId::new(1);
        let id2 = ResourceId::new(1);
        let id3 = ResourceId::new(2);
        assert_eq!(id1, id2);
        assert_ne!(id1, id3);
    }

    #[test]
    fn test_pass_id_equality() {
        let id1 = PassId::new(1);
        let id2 = PassId::new(1);
        let id3 = PassId::new(2);
        assert_eq!(id1, id2);
        assert_ne!(id1, id3);
    }

    #[test]
    fn test_error_display() {
        let err = RenderGraphError::CyclicDependency;
        assert!(format!("{}", err).contains("Cyclic"));

        let err = RenderGraphError::ResourceNotFound(ResourceId::new(42));
        assert!(format!("{}", err).contains("Resource"));

        let err = RenderGraphError::PassNotFound(PassId::new(3));
        assert!(err.to_string().contains("Pass"));

        let err = RenderGraphError::InvalidUsage("Graph not compiled".to_string());
        assert!(err.to_string().contains("Graph not compiled"));
    }
}