astrelis_render/
render_graph.rs

//! Render graph system for automatic resource management and pass scheduling.
//!
//! The render graph is designed to provide:
//! - Automatic resource barriers and transitions
//! - Topological sort of render passes based on the resources they read and write
//! - Resource lifetime tracking for optimization
//! - Clear dependency visualization
//!
//! # Example
//!
//! ```ignore
//! use astrelis_render::*;
//!
//! let mut graph = RenderGraph::new();
//!
//! // Declare resources (GPU objects are created when the graph is executed)
//! let color_target = graph.add_texture(
//!     "color_target",
//!     (800, 600, 1),
//!     wgpu::TextureFormat::Rgba8Unorm,
//!     wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
//! );
//!
//! // Add passes
//! graph.add_pass(RenderGraphPass::new(
//!     "main_pass",
//!     vec![],
//!     vec![color_target],
//!     |ctx| {
//!         // Render code here
//!     },
//! ));
//!
//! // Compile, then execute with an `Arc<GraphicsContext>`
//! let plan = graph.compile()?;
//! graph.execute(graphics.clone())?;
//! ```

use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashMap, HashSet};
use std::sync::Arc;

use crate::GraphicsContext;

/// Opaque handle naming a texture or buffer declared on a render graph.
///
/// Handles are plain integers: two handles compare equal exactly when their raw values match.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ResourceId(u64);

impl ResourceId {
    /// Wrap a raw integer as a resource handle.
    pub fn new(id: u64) -> Self {
        ResourceId(id)
    }

    /// Return the raw integer behind this handle.
    pub fn as_u64(&self) -> u64 {
        let ResourceId(raw) = *self;
        raw
    }
}

/// Opaque handle naming a pass registered on a render graph.
///
/// Handles are plain integers: two handles compare equal exactly when their raw values match.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PassId(u64);

impl PassId {
    /// Wrap a raw integer as a pass handle.
    pub fn new(id: u64) -> Self {
        PassId(id)
    }

    /// Return the raw integer behind this handle.
    pub fn as_u64(&self) -> u64 {
        let PassId(raw) = *self;
        raw
    }
}

76/// Resource type in the render graph.
77#[derive(Debug, Clone, PartialEq)]
78pub enum ResourceType {
79    /// Texture resource
80    Texture {
81        size: (u32, u32, u32),
82        format: wgpu::TextureFormat,
83        usage: wgpu::TextureUsages,
84    },
85    /// Buffer resource
86    Buffer {
87        size: u64,
88        usage: wgpu::BufferUsages,
89    },
90}
91
92/// Resource information in the render graph.
93#[derive(Debug, Clone)]
94pub struct ResourceInfo {
95    /// Resource ID
96    pub id: ResourceId,
97    /// Resource type and descriptor
98    pub resource_type: ResourceType,
99    /// Resource name for debugging
100    pub name: String,
101    /// First pass that reads this resource
102    pub first_read: Option<PassId>,
103    /// Last pass that writes this resource
104    pub last_write: Option<PassId>,
105    /// Last pass that reads this resource
106    pub last_read: Option<PassId>,
107}
108
109/// Render context passed to pass execution functions.
110pub struct RenderContext {
111    /// Graphics context
112    pub graphics: Arc<GraphicsContext>,
113    /// Resource textures (if created)
114    pub textures: HashMap<ResourceId, wgpu::Texture>,
115    /// Resource buffers (if created)
116    pub buffers: HashMap<ResourceId, wgpu::Buffer>,
117}
118
119impl RenderContext {
120    /// Create a new render context.
121    pub fn new(graphics: Arc<GraphicsContext>) -> Self {
122        Self {
123            graphics,
124            textures: HashMap::new(),
125            buffers: HashMap::new(),
126        }
127    }
128
129    /// Get a texture by resource ID.
130    pub fn get_texture(&self, id: ResourceId) -> Option<&wgpu::Texture> {
131        self.textures.get(&id)
132    }
133
134    /// Get a buffer by resource ID.
135    pub fn get_buffer(&self, id: ResourceId) -> Option<&wgpu::Buffer> {
136        self.buffers.get(&id)
137    }
138}
139
140/// A render pass in the graph.
141pub struct RenderGraphPass {
142    /// Pass name for debugging
143    pub name: &'static str,
144    /// Input resources (read)
145    pub inputs: Vec<ResourceId>,
146    /// Output resources (write)
147    pub outputs: Vec<ResourceId>,
148    /// Execution function
149    pub execute: Box<dyn Fn(&mut RenderContext) + Send + Sync>,
150}
151
152impl RenderGraphPass {
153    /// Create a new render pass.
154    pub fn new(
155        name: &'static str,
156        inputs: Vec<ResourceId>,
157        outputs: Vec<ResourceId>,
158        execute: impl Fn(&mut RenderContext) + Send + Sync + 'static,
159    ) -> Self {
160        Self {
161            name,
162            inputs,
163            outputs,
164            execute: Box::new(execute),
165        }
166    }
167}
168
169/// Execution plan for the render graph.
170#[derive(Debug, Clone)]
171pub struct ExecutionPlan {
172    /// Ordered list of pass IDs to execute
173    pub pass_order: Vec<PassId>,
174}
175
176/// Render graph error.
177#[derive(Debug, Clone, PartialEq, Eq)]
178pub enum RenderGraphError {
179    /// Cyclic dependency detected
180    CyclicDependency,
181    /// Resource not found
182    ResourceNotFound(ResourceId),
183    /// Pass not found
184    PassNotFound(PassId),
185    /// Invalid resource usage
186    InvalidUsage(String),
187}
188
189impl std::fmt::Display for RenderGraphError {
190    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191        match self {
192            Self::CyclicDependency => write!(f, "Cyclic dependency detected in render graph"),
193            Self::ResourceNotFound(id) => write!(f, "Resource {:?} not found", id),
194            Self::PassNotFound(id) => write!(f, "Pass {:?} not found", id),
195            Self::InvalidUsage(msg) => write!(f, "Invalid resource usage: {}", msg),
196        }
197    }
198}
199
200impl std::error::Error for RenderGraphError {}
201
202/// Render graph managing passes and resources.
203pub struct RenderGraph {
204    /// All render passes
205    passes: HashMap<PassId, RenderGraphPass>,
206    /// All resources
207    resources: HashMap<ResourceId, ResourceInfo>,
208    /// Next pass ID
209    next_pass_id: u64,
210    /// Next resource ID
211    next_resource_id: u64,
212    /// Execution plan (cached after compilation)
213    execution_plan: Option<ExecutionPlan>,
214}
215
216impl RenderGraph {
217    /// Create a new render graph.
218    pub fn new() -> Self {
219        Self {
220            passes: HashMap::new(),
221            resources: HashMap::new(),
222            next_pass_id: 0,
223            next_resource_id: 0,
224            execution_plan: None,
225        }
226    }
227
228    /// Add a texture resource to the graph.
229    pub fn add_texture(
230        &mut self,
231        name: impl Into<String>,
232        size: (u32, u32, u32),
233        format: wgpu::TextureFormat,
234        usage: wgpu::TextureUsages,
235    ) -> ResourceId {
236        let id = ResourceId::new(self.next_resource_id);
237        self.next_resource_id += 1;
238
239        let resource = ResourceInfo {
240            id,
241            resource_type: ResourceType::Texture {
242                size,
243                format,
244                usage,
245            },
246            name: name.into(),
247            first_read: None,
248            last_write: None,
249            last_read: None,
250        };
251
252        self.resources.insert(id, resource);
253        self.execution_plan = None; // Invalidate plan
254
255        id
256    }
257
258    /// Add a buffer resource to the graph.
259    pub fn add_buffer(
260        &mut self,
261        name: impl Into<String>,
262        size: u64,
263        usage: wgpu::BufferUsages,
264    ) -> ResourceId {
265        let id = ResourceId::new(self.next_resource_id);
266        self.next_resource_id += 1;
267
268        let resource = ResourceInfo {
269            id,
270            resource_type: ResourceType::Buffer { size, usage },
271            name: name.into(),
272            first_read: None,
273            last_write: None,
274            last_read: None,
275        };
276
277        self.resources.insert(id, resource);
278        self.execution_plan = None; // Invalidate plan
279
280        id
281    }
282
283    /// Add a render pass to the graph.
284    pub fn add_pass(&mut self, pass: RenderGraphPass) -> PassId {
285        let id = PassId::new(self.next_pass_id);
286        self.next_pass_id += 1;
287
288        // Update resource usage tracking
289        for &input_id in &pass.inputs {
290            if let Some(resource) = self.resources.get_mut(&input_id) {
291                if resource.first_read.is_none() {
292                    resource.first_read = Some(id);
293                }
294                resource.last_read = Some(id);
295            }
296        }
297
298        for &output_id in &pass.outputs {
299            if let Some(resource) = self.resources.get_mut(&output_id) {
300                resource.last_write = Some(id);
301            }
302        }
303
304        self.passes.insert(id, pass);
305        self.execution_plan = None; // Invalidate plan
306
307        id
308    }
309
310    /// Compile the render graph into an execution plan.
311    ///
312    /// This performs topological sorting of passes based on their dependencies.
313    pub fn compile(&mut self) -> Result<ExecutionPlan, RenderGraphError> {
314        // Build dependency graph
315        let mut dependencies: HashMap<PassId, HashSet<PassId>> = HashMap::new();
316        let mut dependents: HashMap<PassId, HashSet<PassId>> = HashMap::new();
317
318        for (&pass_id, pass) in &self.passes {
319            dependencies.insert(pass_id, HashSet::new());
320            dependents.entry(pass_id).or_insert_with(HashSet::new);
321
322            // A pass depends on any pass that writes to its input resources
323            for &input_id in &pass.inputs {
324                // Find passes that write to this resource
325                for (&other_pass_id, other_pass) in &self.passes {
326                    if other_pass_id != pass_id && other_pass.outputs.contains(&input_id) {
327                        dependencies.get_mut(&pass_id).unwrap().insert(other_pass_id);
328                        dependents.entry(other_pass_id).or_insert_with(HashSet::new).insert(pass_id);
329                    }
330                }
331            }
332        }
333
334        // Topological sort using Kahn's algorithm
335        let mut sorted = Vec::new();
336        let mut no_incoming: Vec<PassId> = dependencies
337            .iter()
338            .filter(|(_, deps)| deps.is_empty())
339            .map(|(&id, _)| id)
340            .collect();
341
342        while let Some(pass_id) = no_incoming.pop() {
343            sorted.push(pass_id);
344
345            // Remove edges from this pass to its dependents
346            if let Some(deps) = dependents.get(&pass_id) {
347                for &dependent_id in deps {
348                    if let Some(dep_set) = dependencies.get_mut(&dependent_id) {
349                        dep_set.remove(&pass_id);
350                        if dep_set.is_empty() {
351                            no_incoming.push(dependent_id);
352                        }
353                    }
354                }
355            }
356        }
357
358        // Check for cycles
359        if sorted.len() != self.passes.len() {
360            return Err(RenderGraphError::CyclicDependency);
361        }
362
363        let plan = ExecutionPlan {
364            pass_order: sorted,
365        };
366
367        self.execution_plan = Some(plan.clone());
368
369        Ok(plan)
370    }
371
372    /// Execute the render graph.
373    ///
374    /// This must be called after `compile()`.
375    pub fn execute(&self, graphics: Arc<GraphicsContext>) -> Result<(), RenderGraphError> {
376        let plan = self
377            .execution_plan
378            .as_ref()
379            .ok_or(RenderGraphError::InvalidUsage(
380                "Graph not compiled".to_string(),
381            ))?;
382
383        let mut context = RenderContext::new(graphics);
384
385        // Create resources (simplified - in reality would manage lifetimes)
386        for (id, info) in &self.resources {
387            match &info.resource_type {
388                ResourceType::Texture {
389                    size,
390                    format,
391                    usage,
392                } => {
393                    let texture = context.graphics.device.create_texture(&wgpu::TextureDescriptor {
394                        label: Some(&info.name),
395                        size: wgpu::Extent3d {
396                            width: size.0,
397                            height: size.1,
398                            depth_or_array_layers: size.2,
399                        },
400                        mip_level_count: 1,
401                        sample_count: 1,
402                        dimension: wgpu::TextureDimension::D2,
403                        format: *format,
404                        usage: *usage,
405                        view_formats: &[],
406                    });
407                    context.textures.insert(*id, texture);
408                }
409                ResourceType::Buffer { size, usage } => {
410                    let buffer = context.graphics.device.create_buffer(&wgpu::BufferDescriptor {
411                        label: Some(&info.name),
412                        size: *size,
413                        usage: *usage,
414                        mapped_at_creation: false,
415                    });
416                    context.buffers.insert(*id, buffer);
417                }
418            }
419        }
420
421        // Execute passes in order
422        for &pass_id in &plan.pass_order {
423            if let Some(pass) = self.passes.get(&pass_id) {
424                (pass.execute)(&mut context);
425            }
426        }
427
428        Ok(())
429    }
430
431    /// Get the number of passes in the graph.
432    pub fn pass_count(&self) -> usize {
433        self.passes.len()
434    }
435
436    /// Get the number of resources in the graph.
437    pub fn resource_count(&self) -> usize {
438        self.resources.len()
439    }
440
441    /// Check if the graph has been compiled.
442    pub fn is_compiled(&self) -> bool {
443        self.execution_plan.is_some()
444    }
445}
446
447impl Default for RenderGraph {
448    fn default() -> Self {
449        Self::new()
450    }
451}
452
#[cfg(test)]
mod tests {
    use super::*;

    /// Declare an 800x600 RGBA8 texture that can be both rendered to and sampled.
    fn color_texture(graph: &mut RenderGraph, name: &str) -> ResourceId {
        graph.add_texture(
            name,
            (800, 600, 1),
            wgpu::TextureFormat::Rgba8Unorm,
            wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
        )
    }

    /// Index of `pass` in the compiled order; panics if the pass was not scheduled.
    fn position(plan: &ExecutionPlan, pass: PassId) -> usize {
        plan.pass_order
            .iter()
            .position(|&id| id == pass)
            .expect("pass should be present in the execution plan")
    }

    #[test]
    fn test_render_graph_new() {
        let graph = RenderGraph::new();
        assert_eq!(graph.pass_count(), 0);
        assert_eq!(graph.resource_count(), 0);
        assert!(!graph.is_compiled());
    }

    #[test]
    fn test_add_texture_resource() {
        let mut graph = RenderGraph::new();
        let tex = graph.add_texture(
            "color_target",
            (800, 600, 1),
            wgpu::TextureFormat::Rgba8Unorm,
            wgpu::TextureUsages::RENDER_ATTACHMENT,
        );
        assert_eq!(graph.resource_count(), 1);
        assert_eq!(tex.as_u64(), 0);
    }

    #[test]
    fn test_add_buffer_resource() {
        let mut graph = RenderGraph::new();
        let buf = graph.add_buffer("vertex_buffer", 1024, wgpu::BufferUsages::VERTEX);
        assert_eq!(graph.resource_count(), 1);
        assert_eq!(buf.as_u64(), 0);
    }

    #[test]
    fn test_add_pass() {
        let mut graph = RenderGraph::new();
        let tex = color_texture(&mut graph, "target");

        let pass = RenderGraphPass::new("test_pass", vec![], vec![tex], |_ctx| {});
        let pass_id = graph.add_pass(pass);

        assert_eq!(graph.pass_count(), 1);
        assert_eq!(pass_id.as_u64(), 0);
    }

    #[test]
    fn test_add_pass_tracks_resource_usage() {
        let mut graph = RenderGraph::new();
        let tex = color_texture(&mut graph, "tex");

        let writer = graph.add_pass(RenderGraphPass::new("writer", vec![], vec![tex], |_ctx| {}));
        let first_reader =
            graph.add_pass(RenderGraphPass::new("reader_a", vec![tex], vec![], |_ctx| {}));
        let last_reader =
            graph.add_pass(RenderGraphPass::new("reader_b", vec![tex], vec![], |_ctx| {}));

        let info = &graph.resources[&tex];
        assert_eq!(info.last_write, Some(writer));
        assert_eq!(info.first_read, Some(first_reader));
        assert_eq!(info.last_read, Some(last_reader));
    }

    #[test]
    fn test_compile_simple() {
        let mut graph = RenderGraph::new();
        let tex = color_texture(&mut graph, "target");

        let pass = RenderGraphPass::new("test_pass", vec![], vec![tex], |_ctx| {});
        graph.add_pass(pass);

        let result = graph.compile();
        assert!(result.is_ok());
        assert!(graph.is_compiled());
    }

    #[test]
    fn test_compile_multiple_passes() {
        let mut graph = RenderGraph::new();
        let tex1 = color_texture(&mut graph, "tex1");
        let tex2 = graph.add_texture(
            "tex2",
            (800, 600, 1),
            wgpu::TextureFormat::Rgba8Unorm,
            wgpu::TextureUsages::RENDER_ATTACHMENT,
        );

        // Pass 1 writes to tex1
        let pass1 = RenderGraphPass::new("pass1", vec![], vec![tex1], |_ctx| {});
        graph.add_pass(pass1);

        // Pass 2 reads tex1 and writes to tex2
        let pass2 = RenderGraphPass::new("pass2", vec![tex1], vec![tex2], |_ctx| {});
        graph.add_pass(pass2);

        let result = graph.compile();
        assert!(result.is_ok());

        let plan = result.unwrap();
        assert_eq!(plan.pass_order.len(), 2);
        // Pass 1 should come before pass 2
        assert!(plan.pass_order[0].as_u64() < plan.pass_order[1].as_u64());
    }

    #[test]
    fn test_compile_diamond_respects_dependencies() {
        let mut graph = RenderGraph::new();
        let a = color_texture(&mut graph, "a");
        let b = color_texture(&mut graph, "b");
        let c = color_texture(&mut graph, "c");
        let out = color_texture(&mut graph, "out");

        let produce = graph.add_pass(RenderGraphPass::new("produce", vec![], vec![a], |_ctx| {}));
        let left = graph.add_pass(RenderGraphPass::new("left", vec![a], vec![b], |_ctx| {}));
        let right = graph.add_pass(RenderGraphPass::new("right", vec![a], vec![c], |_ctx| {}));
        let combine =
            graph.add_pass(RenderGraphPass::new("combine", vec![b, c], vec![out], |_ctx| {}));

        let plan = graph.compile().expect("diamond graph is acyclic");
        assert_eq!(plan.pass_order.len(), 4);
        assert!(position(&plan, produce) < position(&plan, left));
        assert!(position(&plan, produce) < position(&plan, right));
        assert!(position(&plan, left) < position(&plan, combine));
        assert!(position(&plan, right) < position(&plan, combine));
    }

    #[test]
    fn test_compile_read_modify_write_is_not_a_cycle() {
        let mut graph = RenderGraph::new();
        let tex = color_texture(&mut graph, "accum");
        graph.add_pass(RenderGraphPass::new("accumulate", vec![tex], vec![tex], |_ctx| {}));

        let plan = graph.compile().expect("a pass must not depend on itself");
        assert_eq!(plan.pass_order.len(), 1);
    }

    #[test]
    fn test_compile_detects_cycle() {
        let mut graph = RenderGraph::new();
        let a = color_texture(&mut graph, "a");
        let b = color_texture(&mut graph, "b");
        graph.add_pass(RenderGraphPass::new("a_to_b", vec![a], vec![b], |_ctx| {}));
        graph.add_pass(RenderGraphPass::new("b_to_a", vec![b], vec![a], |_ctx| {}));

        let err = graph.compile().unwrap_err();
        assert_eq!(err, RenderGraphError::CyclicDependency);
        assert!(!graph.is_compiled());
    }

    #[test]
    fn test_mutation_invalidates_plan() {
        let mut graph = RenderGraph::new();
        let tex = color_texture(&mut graph, "target");
        graph.add_pass(RenderGraphPass::new("draw", vec![], vec![tex], |_ctx| {}));

        graph.compile().expect("single pass compiles");
        assert!(graph.is_compiled());

        graph.add_buffer("scratch", 256, wgpu::BufferUsages::STORAGE);
        assert!(!graph.is_compiled());

        graph.compile().expect("graph still compiles after adding a buffer");
        assert!(graph.is_compiled());

        graph.add_pass(RenderGraphPass::new("late", vec![tex], vec![], |_ctx| {}));
        assert!(!graph.is_compiled());
    }

    #[test]
    fn test_resource_id_equality() {
        let id1 = ResourceId::new(1);
        let id2 = ResourceId::new(1);
        let id3 = ResourceId::new(2);
        assert_eq!(id1, id2);
        assert_ne!(id1, id3);
    }

    #[test]
    fn test_pass_id_equality() {
        let id1 = PassId::new(1);
        let id2 = PassId::new(1);
        let id3 = PassId::new(2);
        assert_eq!(id1, id2);
        assert_ne!(id1, id3);
    }

    #[test]
    fn test_error_display() {
        let err = RenderGraphError::CyclicDependency;
        assert!(format!("{}", err).contains("Cyclic"));

        let err = RenderGraphError::ResourceNotFound(ResourceId::new(42));
        assert!(format!("{}", err).contains("Resource"));

        let err = RenderGraphError::PassNotFound(PassId::new(7));
        assert!(format!("{}", err).contains("Pass"));

        let err = RenderGraphError::InvalidUsage("Graph not compiled".to_string());
        assert!(format!("{}", err).contains("Graph not compiled"));
    }
}