Skip to main content

astrelis_render/
render_graph.rs

//! Render graph system for automatic resource management and pass scheduling.
//!
//! The render graph provides:
//! - Declaration of texture and buffer resources that are created automatically at execution time
//! - Topological sort of render passes based on their resource read/write dependencies
//! - Resource usage tracking (first/last reading pass and last writing pass of each resource)
//! - Cycle detection, reported as [`RenderGraphError::CyclicDependency`]
//!
//! # Example
//!
//! ```ignore
//! use astrelis_render::*;
//!
//! let mut graph = RenderGraph::new();
//!
//! // Add resources
//! let color_target = graph.add_texture(
//!     "color_target",
//!     (800, 600, 1),
//!     wgpu::TextureFormat::Rgba8Unorm,
//!     wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
//! );
//!
//! // Add passes
//! graph.add_pass(RenderGraphPass {
//!     name: "main_pass",
//!     inputs: vec![],
//!     outputs: vec![color_target],
//!     execute: Box::new(|ctx| {
//!         // Render code here
//!     }),
//! });
//!
//! // Compile, then execute with a shared `Arc<GraphicsContext>`
//! let plan = graph.compile()?;
//! graph.execute(graphics.clone())?;
//! ```
38
39use ahash::{HashMap, HashMapExt, HashSet, HashSetExt};
40use astrelis_core::profiling::profile_function;
41use std::sync::Arc;
42
43use crate::GraphicsContext;
44
/// Opaque handle naming a resource inside a render graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ResourceId(u64);

impl ResourceId {
    /// Wrap a raw `u64` as a resource handle.
    pub fn new(id: u64) -> Self {
        ResourceId(id)
    }

    /// Unwrap the handle back into its raw `u64`.
    pub fn as_u64(&self) -> u64 {
        let ResourceId(raw) = *self;
        raw
    }
}
60
/// Opaque handle naming a pass inside a render graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PassId(u64);

impl PassId {
    /// Wrap a raw `u64` as a pass handle.
    pub fn new(id: u64) -> Self {
        PassId(id)
    }

    /// Unwrap the handle back into its raw `u64`.
    pub fn as_u64(&self) -> u64 {
        let PassId(raw) = *self;
        raw
    }
}
76
77/// Resource type in the render graph.
78#[derive(Debug, Clone, PartialEq)]
79pub enum ResourceType {
80    /// Texture resource
81    Texture {
82        size: (u32, u32, u32),
83        format: wgpu::TextureFormat,
84        usage: wgpu::TextureUsages,
85    },
86    /// Buffer resource
87    Buffer {
88        size: u64,
89        usage: wgpu::BufferUsages,
90    },
91}
92
93/// Resource information in the render graph.
94#[derive(Debug, Clone)]
95pub struct ResourceInfo {
96    /// Resource ID
97    pub id: ResourceId,
98    /// Resource type and descriptor
99    pub resource_type: ResourceType,
100    /// Resource name for debugging
101    pub name: String,
102    /// First pass that reads this resource
103    pub first_read: Option<PassId>,
104    /// Last pass that writes this resource
105    pub last_write: Option<PassId>,
106    /// Last pass that reads this resource
107    pub last_read: Option<PassId>,
108}
109
110/// Render context passed to pass execution functions.
111pub struct RenderContext {
112    /// Graphics context
113    pub graphics: Arc<GraphicsContext>,
114    /// Resource textures (if created)
115    pub textures: HashMap<ResourceId, wgpu::Texture>,
116    /// Resource buffers (if created)
117    pub buffers: HashMap<ResourceId, wgpu::Buffer>,
118}
119
120impl RenderContext {
121    /// Create a new render context.
122    pub fn new(graphics: Arc<GraphicsContext>) -> Self {
123        Self {
124            graphics,
125            textures: HashMap::new(),
126            buffers: HashMap::new(),
127        }
128    }
129
130    /// Get a texture by resource ID.
131    pub fn get_texture(&self, id: ResourceId) -> Option<&wgpu::Texture> {
132        self.textures.get(&id)
133    }
134
135    /// Get a buffer by resource ID.
136    pub fn get_buffer(&self, id: ResourceId) -> Option<&wgpu::Buffer> {
137        self.buffers.get(&id)
138    }
139}
140
141/// A render pass in the graph.
142pub struct RenderGraphPass {
143    /// Pass name for debugging
144    pub name: &'static str,
145    /// Input resources (read)
146    pub inputs: Vec<ResourceId>,
147    /// Output resources (write)
148    pub outputs: Vec<ResourceId>,
149    /// Execution function
150    pub execute: Box<dyn Fn(&mut RenderContext) + Send + Sync>,
151}
152
153impl RenderGraphPass {
154    /// Create a new render pass.
155    pub fn new(
156        name: &'static str,
157        inputs: Vec<ResourceId>,
158        outputs: Vec<ResourceId>,
159        execute: impl Fn(&mut RenderContext) + Send + Sync + 'static,
160    ) -> Self {
161        Self {
162            name,
163            inputs,
164            outputs,
165            execute: Box::new(execute),
166        }
167    }
168}
169
170/// Execution plan for the render graph.
171#[derive(Debug, Clone)]
172pub struct ExecutionPlan {
173    /// Ordered list of pass IDs to execute
174    pub pass_order: Vec<PassId>,
175}
176
177/// Render graph error.
178#[derive(Debug, Clone, PartialEq, Eq)]
179pub enum RenderGraphError {
180    /// Cyclic dependency detected
181    CyclicDependency,
182    /// Resource not found
183    ResourceNotFound(ResourceId),
184    /// Pass not found
185    PassNotFound(PassId),
186    /// Invalid resource usage
187    InvalidUsage(String),
188}
189
190impl std::fmt::Display for RenderGraphError {
191    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192        match self {
193            Self::CyclicDependency => write!(f, "Cyclic dependency detected in render graph"),
194            Self::ResourceNotFound(id) => write!(f, "Resource {:?} not found", id),
195            Self::PassNotFound(id) => write!(f, "Pass {:?} not found", id),
196            Self::InvalidUsage(msg) => write!(f, "Invalid resource usage: {}", msg),
197        }
198    }
199}
200
201impl std::error::Error for RenderGraphError {}
202
203/// Render graph managing passes and resources.
204pub struct RenderGraph {
205    /// All render passes
206    passes: HashMap<PassId, RenderGraphPass>,
207    /// All resources
208    resources: HashMap<ResourceId, ResourceInfo>,
209    /// Next pass ID
210    next_pass_id: u64,
211    /// Next resource ID
212    next_resource_id: u64,
213    /// Execution plan (cached after compilation)
214    execution_plan: Option<ExecutionPlan>,
215}
216
217impl RenderGraph {
218    /// Create a new render graph.
219    pub fn new() -> Self {
220        Self {
221            passes: HashMap::new(),
222            resources: HashMap::new(),
223            next_pass_id: 0,
224            next_resource_id: 0,
225            execution_plan: None,
226        }
227    }
228
229    /// Add a texture resource to the graph.
230    pub fn add_texture(
231        &mut self,
232        name: impl Into<String>,
233        size: (u32, u32, u32),
234        format: wgpu::TextureFormat,
235        usage: wgpu::TextureUsages,
236    ) -> ResourceId {
237        let id = ResourceId::new(self.next_resource_id);
238        self.next_resource_id += 1;
239
240        let resource = ResourceInfo {
241            id,
242            resource_type: ResourceType::Texture {
243                size,
244                format,
245                usage,
246            },
247            name: name.into(),
248            first_read: None,
249            last_write: None,
250            last_read: None,
251        };
252
253        self.resources.insert(id, resource);
254        self.execution_plan = None; // Invalidate plan
255
256        id
257    }
258
259    /// Add a buffer resource to the graph.
260    pub fn add_buffer(
261        &mut self,
262        name: impl Into<String>,
263        size: u64,
264        usage: wgpu::BufferUsages,
265    ) -> ResourceId {
266        let id = ResourceId::new(self.next_resource_id);
267        self.next_resource_id += 1;
268
269        let resource = ResourceInfo {
270            id,
271            resource_type: ResourceType::Buffer { size, usage },
272            name: name.into(),
273            first_read: None,
274            last_write: None,
275            last_read: None,
276        };
277
278        self.resources.insert(id, resource);
279        self.execution_plan = None; // Invalidate plan
280
281        id
282    }
283
284    /// Add a render pass to the graph.
285    pub fn add_pass(&mut self, pass: RenderGraphPass) -> PassId {
286        let id = PassId::new(self.next_pass_id);
287        self.next_pass_id += 1;
288
289        // Update resource usage tracking
290        for &input_id in &pass.inputs {
291            if let Some(resource) = self.resources.get_mut(&input_id) {
292                if resource.first_read.is_none() {
293                    resource.first_read = Some(id);
294                }
295                resource.last_read = Some(id);
296            }
297        }
298
299        for &output_id in &pass.outputs {
300            if let Some(resource) = self.resources.get_mut(&output_id) {
301                resource.last_write = Some(id);
302            }
303        }
304
305        self.passes.insert(id, pass);
306        self.execution_plan = None; // Invalidate plan
307
308        id
309    }
310
311    /// Compile the render graph into an execution plan.
312    ///
313    /// This performs topological sorting of passes based on their dependencies.
314    pub fn compile(&mut self) -> Result<ExecutionPlan, RenderGraphError> {
315        profile_function!();
316        // Build dependency graph
317        let mut dependencies: HashMap<PassId, HashSet<PassId>> = HashMap::new();
318        let mut dependents: HashMap<PassId, HashSet<PassId>> = HashMap::new();
319
320        for (&pass_id, pass) in &self.passes {
321            dependencies.insert(pass_id, HashSet::new());
322            dependents.entry(pass_id).or_default();
323
324            // A pass depends on any pass that writes to its input resources
325            for &input_id in &pass.inputs {
326                // Find passes that write to this resource
327                for (&other_pass_id, other_pass) in &self.passes {
328                    if other_pass_id != pass_id && other_pass.outputs.contains(&input_id) {
329                        dependencies.get_mut(&pass_id).unwrap().insert(other_pass_id);
330                        dependents.entry(other_pass_id).or_default().insert(pass_id);
331                    }
332                }
333            }
334        }
335
336        // Topological sort using Kahn's algorithm
337        let mut sorted = Vec::new();
338        let mut no_incoming: Vec<PassId> = dependencies
339            .iter()
340            .filter(|(_, deps)| deps.is_empty())
341            .map(|(&id, _)| id)
342            .collect();
343
344        while let Some(pass_id) = no_incoming.pop() {
345            sorted.push(pass_id);
346
347            // Remove edges from this pass to its dependents
348            if let Some(deps) = dependents.get(&pass_id) {
349                for &dependent_id in deps {
350                    if let Some(dep_set) = dependencies.get_mut(&dependent_id) {
351                        dep_set.remove(&pass_id);
352                        if dep_set.is_empty() {
353                            no_incoming.push(dependent_id);
354                        }
355                    }
356                }
357            }
358        }
359
360        // Check for cycles
361        if sorted.len() != self.passes.len() {
362            return Err(RenderGraphError::CyclicDependency);
363        }
364
365        let plan = ExecutionPlan {
366            pass_order: sorted,
367        };
368
369        self.execution_plan = Some(plan.clone());
370
371        Ok(plan)
372    }
373
374    /// Execute the render graph.
375    ///
376    /// This must be called after `compile()`.
377    pub fn execute(&self, graphics: Arc<GraphicsContext>) -> Result<(), RenderGraphError> {
378        profile_function!();
379        let plan = self
380            .execution_plan
381            .as_ref()
382            .ok_or(RenderGraphError::InvalidUsage(
383                "Graph not compiled".to_string(),
384            ))?;
385
386        let mut context = RenderContext::new(graphics);
387
388        // Create resources (simplified - in reality would manage lifetimes)
389        for (id, info) in &self.resources {
390            match &info.resource_type {
391                ResourceType::Texture {
392                    size,
393                    format,
394                    usage,
395                } => {
396                    let texture = context.graphics.device().create_texture(&wgpu::TextureDescriptor {
397                        label: Some(&info.name),
398                        size: wgpu::Extent3d {
399                            width: size.0,
400                            height: size.1,
401                            depth_or_array_layers: size.2,
402                        },
403                        mip_level_count: 1,
404                        sample_count: 1,
405                        dimension: wgpu::TextureDimension::D2,
406                        format: *format,
407                        usage: *usage,
408                        view_formats: &[],
409                    });
410                    context.textures.insert(*id, texture);
411                }
412                ResourceType::Buffer { size, usage } => {
413                    let buffer = context.graphics.device().create_buffer(&wgpu::BufferDescriptor {
414                        label: Some(&info.name),
415                        size: *size,
416                        usage: *usage,
417                        mapped_at_creation: false,
418                    });
419                    context.buffers.insert(*id, buffer);
420                }
421            }
422        }
423
424        // Execute passes in order
425        for &pass_id in &plan.pass_order {
426            if let Some(pass) = self.passes.get(&pass_id) {
427                (pass.execute)(&mut context);
428            }
429        }
430
431        Ok(())
432    }
433
434    /// Get the number of passes in the graph.
435    pub fn pass_count(&self) -> usize {
436        self.passes.len()
437    }
438
439    /// Get the number of resources in the graph.
440    pub fn resource_count(&self) -> usize {
441        self.resources.len()
442    }
443
444    /// Check if the graph has been compiled.
445    pub fn is_compiled(&self) -> bool {
446        self.execution_plan.is_some()
447    }
448}
449
450impl Default for RenderGraph {
451    fn default() -> Self {
452        Self::new()
453    }
454}
455
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a pass with the given resource lists and an empty execution body.
    fn noop_pass(
        name: &'static str,
        inputs: Vec<ResourceId>,
        outputs: Vec<ResourceId>,
    ) -> RenderGraphPass {
        RenderGraphPass::new(name, inputs, outputs, |_ctx| {})
    }

    /// Register an 800x600 RGBA8 texture usable as attachment and sampled input.
    fn add_target(graph: &mut RenderGraph, name: &str) -> ResourceId {
        graph.add_texture(
            name,
            (800, 600, 1),
            wgpu::TextureFormat::Rgba8Unorm,
            wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
        )
    }

    /// Index of `id` in the plan's pass order; panics if the pass was not scheduled.
    fn position(plan: &ExecutionPlan, id: PassId) -> usize {
        plan.pass_order
            .iter()
            .position(|&scheduled| scheduled == id)
            .expect("pass missing from execution plan")
    }

    #[test]
    fn test_render_graph_new() {
        let graph = RenderGraph::new();
        assert_eq!(graph.pass_count(), 0);
        assert_eq!(graph.resource_count(), 0);
        assert!(!graph.is_compiled());
    }

    #[test]
    fn test_add_texture_resource() {
        let mut graph = RenderGraph::new();
        let tex = graph.add_texture(
            "color_target",
            (800, 600, 1),
            wgpu::TextureFormat::Rgba8Unorm,
            wgpu::TextureUsages::RENDER_ATTACHMENT,
        );
        assert_eq!(graph.resource_count(), 1);
        assert_eq!(tex.as_u64(), 0);
    }

    #[test]
    fn test_add_buffer_resource() {
        let mut graph = RenderGraph::new();
        let buf = graph.add_buffer("vertex_buffer", 1024, wgpu::BufferUsages::VERTEX);
        assert_eq!(graph.resource_count(), 1);
        assert_eq!(buf.as_u64(), 0);
    }

    #[test]
    fn test_add_pass() {
        let mut graph = RenderGraph::new();
        let tex = graph.add_texture(
            "target",
            (800, 600, 1),
            wgpu::TextureFormat::Rgba8Unorm,
            wgpu::TextureUsages::RENDER_ATTACHMENT,
        );

        let pass = RenderGraphPass::new("test_pass", vec![], vec![tex], |_ctx| {});
        let pass_id = graph.add_pass(pass);

        assert_eq!(graph.pass_count(), 1);
        assert_eq!(pass_id.as_u64(), 0);
    }

    #[test]
    fn test_compile_simple() {
        let mut graph = RenderGraph::new();
        let tex = graph.add_texture(
            "target",
            (800, 600, 1),
            wgpu::TextureFormat::Rgba8Unorm,
            wgpu::TextureUsages::RENDER_ATTACHMENT,
        );

        let pass = RenderGraphPass::new("test_pass", vec![], vec![tex], |_ctx| {});
        graph.add_pass(pass);

        let result = graph.compile();
        assert!(result.is_ok());
        assert!(graph.is_compiled());
    }

    #[test]
    fn test_compile_multiple_passes() {
        let mut graph = RenderGraph::new();
        let tex1 = graph.add_texture(
            "tex1",
            (800, 600, 1),
            wgpu::TextureFormat::Rgba8Unorm,
            wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
        );
        let tex2 = graph.add_texture(
            "tex2",
            (800, 600, 1),
            wgpu::TextureFormat::Rgba8Unorm,
            wgpu::TextureUsages::RENDER_ATTACHMENT,
        );

        // Pass 1 writes to tex1
        let pass1 = RenderGraphPass::new("pass1", vec![], vec![tex1], |_ctx| {});
        graph.add_pass(pass1);

        // Pass 2 reads tex1 and writes to tex2
        let pass2 = RenderGraphPass::new("pass2", vec![tex1], vec![tex2], |_ctx| {});
        graph.add_pass(pass2);

        let result = graph.compile();
        assert!(result.is_ok());

        let plan = result.unwrap();
        assert_eq!(plan.pass_order.len(), 2);
        // Pass 1 should come before pass 2
        assert!(plan.pass_order[0].as_u64() < plan.pass_order[1].as_u64());
    }

    #[test]
    fn test_compile_empty_graph() {
        let mut graph = RenderGraph::new();
        let plan = graph.compile().expect("an empty graph is trivially acyclic");
        assert!(plan.pass_order.is_empty());
        assert!(graph.is_compiled());
    }

    #[test]
    fn test_compile_orders_producer_first_even_when_added_last() {
        let mut graph = RenderGraph::new();
        let tex = add_target(&mut graph, "intermediate");
        let out = add_target(&mut graph, "out");

        let consumer = graph.add_pass(noop_pass("consumer", vec![tex], vec![out]));
        let producer = graph.add_pass(noop_pass("producer", vec![], vec![tex]));

        let plan = graph.compile().expect("chain is acyclic");
        assert_eq!(plan.pass_order, vec![producer, consumer]);
    }

    #[test]
    fn test_compile_diamond_respects_all_dependencies() {
        let mut graph = RenderGraph::new();
        let base = add_target(&mut graph, "base");
        let left = add_target(&mut graph, "left");
        let right = add_target(&mut graph, "right");
        let out = add_target(&mut graph, "out");

        // Added in reverse dependency order so insertion order alone cannot satisfy it.
        let combine = graph.add_pass(noop_pass("combine", vec![left, right], vec![out]));
        let right_pass = graph.add_pass(noop_pass("right", vec![base], vec![right]));
        let left_pass = graph.add_pass(noop_pass("left", vec![base], vec![left]));
        let source = graph.add_pass(noop_pass("source", vec![], vec![base]));

        let plan = graph.compile().expect("diamond is acyclic");
        assert_eq!(plan.pass_order.len(), 4);
        assert!(position(&plan, source) < position(&plan, left_pass));
        assert!(position(&plan, source) < position(&plan, right_pass));
        assert!(position(&plan, left_pass) < position(&plan, combine));
        assert!(position(&plan, right_pass) < position(&plan, combine));
    }

    #[test]
    fn test_pass_reading_its_own_output_is_not_a_cycle() {
        let mut graph = RenderGraph::new();
        let tex = add_target(&mut graph, "accumulator");
        graph.add_pass(noop_pass("accumulate", vec![tex], vec![tex]));
        assert!(graph.compile().is_ok());
    }

    #[test]
    fn test_compile_detects_cycle() {
        let mut graph = RenderGraph::new();
        let a = graph.add_buffer("a", 16, wgpu::BufferUsages::STORAGE);
        let b = graph.add_buffer("b", 16, wgpu::BufferUsages::STORAGE);
        graph.add_pass(noop_pass("p1", vec![a], vec![b]));
        graph.add_pass(noop_pass("p2", vec![b], vec![a]));

        assert_eq!(
            graph.compile().unwrap_err(),
            RenderGraphError::CyclicDependency
        );
        assert!(!graph.is_compiled());
    }

    #[test]
    fn test_adding_to_graph_invalidates_plan() {
        let mut graph = RenderGraph::new();
        let tex = add_target(&mut graph, "target");
        graph.add_pass(noop_pass("draw", vec![], vec![tex]));
        graph.compile().expect("single pass compiles");
        assert!(graph.is_compiled());

        graph.add_pass(noop_pass("post", vec![tex], vec![]));
        assert!(!graph.is_compiled());

        graph.compile().expect("two-pass chain compiles");
        graph.add_buffer("uniforms", 64, wgpu::BufferUsages::UNIFORM);
        assert!(!graph.is_compiled());
    }

    #[test]
    fn test_resource_usage_tracking() {
        let mut graph = RenderGraph::new();
        let tex = add_target(&mut graph, "target");
        let writer = graph.add_pass(noop_pass("write", vec![], vec![tex]));
        let reader1 = graph.add_pass(noop_pass("read1", vec![tex], vec![]));
        let reader2 = graph.add_pass(noop_pass("read2", vec![tex], vec![]));

        let info = graph.resources.get(&tex).expect("texture was registered");
        assert_eq!(info.first_read, Some(reader1));
        assert_eq!(info.last_read, Some(reader2));
        assert_eq!(info.last_write, Some(writer));
    }

    #[test]
    fn test_resource_id_equality() {
        let id1 = ResourceId::new(1);
        let id2 = ResourceId::new(1);
        let id3 = ResourceId::new(2);
        assert_eq!(id1, id2);
        assert_ne!(id1, id3);
    }

    #[test]
    fn test_pass_id_equality() {
        let id1 = PassId::new(1);
        let id2 = PassId::new(1);
        let id3 = PassId::new(2);
        assert_eq!(id1, id2);
        assert_ne!(id1, id3);
    }

    #[test]
    fn test_error_display() {
        let err = RenderGraphError::CyclicDependency;
        assert!(format!("{}", err).contains("Cyclic"));

        let err = RenderGraphError::ResourceNotFound(ResourceId::new(42));
        assert!(format!("{}", err).contains("Resource"));

        let err = RenderGraphError::PassNotFound(PassId::new(7));
        assert!(format!("{}", err).contains("Pass"));

        let err = RenderGraphError::InvalidUsage("Graph not compiled".to_string());
        assert_eq!(
            err.to_string(),
            "Invalid resource usage: Graph not compiled"
        );
    }
}