1use ahash::{HashMap, HashMapExt, HashSet, HashSetExt};
40use astrelis_core::profiling::profile_function;
41use std::sync::Arc;
42
43use crate::GraphicsContext;
44
/// Identifier of a resource declared in a [`RenderGraph`].
///
/// `RenderGraph::add_texture` / `RenderGraph::add_buffer` hand these out from a
/// single counter starting at 0.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ResourceId(u64);

impl ResourceId {
    /// Wraps the raw numeric identifier `id`.
    pub fn new(id: u64) -> Self {
        ResourceId(id)
    }

    /// Returns the raw numeric identifier.
    pub fn as_u64(&self) -> u64 {
        let ResourceId(raw) = *self;
        raw
    }
}
60
/// Identifier of a pass registered with a [`RenderGraph`].
///
/// `RenderGraph::add_pass` hands these out from a counter starting at 0, so a
/// lower id means the pass was added earlier.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PassId(u64);

impl PassId {
    /// Wraps the raw numeric identifier `id`.
    pub fn new(id: u64) -> Self {
        PassId(id)
    }

    /// Returns the raw numeric identifier.
    pub fn as_u64(&self) -> u64 {
        let PassId(raw) = *self;
        raw
    }
}
76
77#[derive(Debug, Clone, PartialEq)]
79pub enum ResourceType {
80 Texture {
82 size: (u32, u32, u32),
83 format: wgpu::TextureFormat,
84 usage: wgpu::TextureUsages,
85 },
86 Buffer {
88 size: u64,
89 usage: wgpu::BufferUsages,
90 },
91}
92
93#[derive(Debug, Clone)]
95pub struct ResourceInfo {
96 pub id: ResourceId,
98 pub resource_type: ResourceType,
100 pub name: String,
102 pub first_read: Option<PassId>,
104 pub last_write: Option<PassId>,
106 pub last_read: Option<PassId>,
108}
109
110pub struct RenderContext {
112 pub graphics: Arc<GraphicsContext>,
114 pub textures: HashMap<ResourceId, wgpu::Texture>,
116 pub buffers: HashMap<ResourceId, wgpu::Buffer>,
118}
119
120impl RenderContext {
121 pub fn new(graphics: Arc<GraphicsContext>) -> Self {
123 Self {
124 graphics,
125 textures: HashMap::new(),
126 buffers: HashMap::new(),
127 }
128 }
129
130 pub fn get_texture(&self, id: ResourceId) -> Option<&wgpu::Texture> {
132 self.textures.get(&id)
133 }
134
135 pub fn get_buffer(&self, id: ResourceId) -> Option<&wgpu::Buffer> {
137 self.buffers.get(&id)
138 }
139}
140
141pub struct RenderGraphPass {
143 pub name: &'static str,
145 pub inputs: Vec<ResourceId>,
147 pub outputs: Vec<ResourceId>,
149 pub execute: Box<dyn Fn(&mut RenderContext) + Send + Sync>,
151}
152
153impl RenderGraphPass {
154 pub fn new(
156 name: &'static str,
157 inputs: Vec<ResourceId>,
158 outputs: Vec<ResourceId>,
159 execute: impl Fn(&mut RenderContext) + Send + Sync + 'static,
160 ) -> Self {
161 Self {
162 name,
163 inputs,
164 outputs,
165 execute: Box::new(execute),
166 }
167 }
168}
169
170#[derive(Debug, Clone)]
172pub struct ExecutionPlan {
173 pub pass_order: Vec<PassId>,
175}
176
177#[derive(Debug, Clone, PartialEq, Eq)]
179pub enum RenderGraphError {
180 CyclicDependency,
182 ResourceNotFound(ResourceId),
184 PassNotFound(PassId),
186 InvalidUsage(String),
188}
189
190impl std::fmt::Display for RenderGraphError {
191 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192 match self {
193 Self::CyclicDependency => write!(f, "Cyclic dependency detected in render graph"),
194 Self::ResourceNotFound(id) => write!(f, "Resource {:?} not found", id),
195 Self::PassNotFound(id) => write!(f, "Pass {:?} not found", id),
196 Self::InvalidUsage(msg) => write!(f, "Invalid resource usage: {}", msg),
197 }
198 }
199}
200
201impl std::error::Error for RenderGraphError {}
202
203pub struct RenderGraph {
205 passes: HashMap<PassId, RenderGraphPass>,
207 resources: HashMap<ResourceId, ResourceInfo>,
209 next_pass_id: u64,
211 next_resource_id: u64,
213 execution_plan: Option<ExecutionPlan>,
215}
216
217impl RenderGraph {
218 pub fn new() -> Self {
220 Self {
221 passes: HashMap::new(),
222 resources: HashMap::new(),
223 next_pass_id: 0,
224 next_resource_id: 0,
225 execution_plan: None,
226 }
227 }
228
229 pub fn add_texture(
231 &mut self,
232 name: impl Into<String>,
233 size: (u32, u32, u32),
234 format: wgpu::TextureFormat,
235 usage: wgpu::TextureUsages,
236 ) -> ResourceId {
237 let id = ResourceId::new(self.next_resource_id);
238 self.next_resource_id += 1;
239
240 let resource = ResourceInfo {
241 id,
242 resource_type: ResourceType::Texture {
243 size,
244 format,
245 usage,
246 },
247 name: name.into(),
248 first_read: None,
249 last_write: None,
250 last_read: None,
251 };
252
253 self.resources.insert(id, resource);
254 self.execution_plan = None; id
257 }
258
259 pub fn add_buffer(
261 &mut self,
262 name: impl Into<String>,
263 size: u64,
264 usage: wgpu::BufferUsages,
265 ) -> ResourceId {
266 let id = ResourceId::new(self.next_resource_id);
267 self.next_resource_id += 1;
268
269 let resource = ResourceInfo {
270 id,
271 resource_type: ResourceType::Buffer { size, usage },
272 name: name.into(),
273 first_read: None,
274 last_write: None,
275 last_read: None,
276 };
277
278 self.resources.insert(id, resource);
279 self.execution_plan = None; id
282 }
283
284 pub fn add_pass(&mut self, pass: RenderGraphPass) -> PassId {
286 let id = PassId::new(self.next_pass_id);
287 self.next_pass_id += 1;
288
289 for &input_id in &pass.inputs {
291 if let Some(resource) = self.resources.get_mut(&input_id) {
292 if resource.first_read.is_none() {
293 resource.first_read = Some(id);
294 }
295 resource.last_read = Some(id);
296 }
297 }
298
299 for &output_id in &pass.outputs {
300 if let Some(resource) = self.resources.get_mut(&output_id) {
301 resource.last_write = Some(id);
302 }
303 }
304
305 self.passes.insert(id, pass);
306 self.execution_plan = None; id
309 }
310
311 pub fn compile(&mut self) -> Result<ExecutionPlan, RenderGraphError> {
315 profile_function!();
316 let mut dependencies: HashMap<PassId, HashSet<PassId>> = HashMap::new();
318 let mut dependents: HashMap<PassId, HashSet<PassId>> = HashMap::new();
319
320 for (&pass_id, pass) in &self.passes {
321 dependencies.insert(pass_id, HashSet::new());
322 dependents.entry(pass_id).or_default();
323
324 for &input_id in &pass.inputs {
326 for (&other_pass_id, other_pass) in &self.passes {
328 if other_pass_id != pass_id && other_pass.outputs.contains(&input_id) {
329 dependencies
330 .get_mut(&pass_id)
331 .unwrap()
332 .insert(other_pass_id);
333 dependents.entry(other_pass_id).or_default().insert(pass_id);
334 }
335 }
336 }
337 }
338
339 let mut sorted = Vec::new();
341 let mut no_incoming: Vec<PassId> = dependencies
342 .iter()
343 .filter(|(_, deps)| deps.is_empty())
344 .map(|(&id, _)| id)
345 .collect();
346
347 while let Some(pass_id) = no_incoming.pop() {
348 sorted.push(pass_id);
349
350 if let Some(deps) = dependents.get(&pass_id) {
352 for &dependent_id in deps {
353 if let Some(dep_set) = dependencies.get_mut(&dependent_id) {
354 dep_set.remove(&pass_id);
355 if dep_set.is_empty() {
356 no_incoming.push(dependent_id);
357 }
358 }
359 }
360 }
361 }
362
363 if sorted.len() != self.passes.len() {
365 return Err(RenderGraphError::CyclicDependency);
366 }
367
368 let plan = ExecutionPlan { pass_order: sorted };
369
370 self.execution_plan = Some(plan.clone());
371
372 Ok(plan)
373 }
374
375 pub fn execute(&self, graphics: Arc<GraphicsContext>) -> Result<(), RenderGraphError> {
379 profile_function!();
380 let plan = self
381 .execution_plan
382 .as_ref()
383 .ok_or(RenderGraphError::InvalidUsage(
384 "Graph not compiled".to_string(),
385 ))?;
386
387 let mut context = RenderContext::new(graphics);
388
389 for (id, info) in &self.resources {
391 match &info.resource_type {
392 ResourceType::Texture {
393 size,
394 format,
395 usage,
396 } => {
397 let texture =
398 context
399 .graphics
400 .device()
401 .create_texture(&wgpu::TextureDescriptor {
402 label: Some(&info.name),
403 size: wgpu::Extent3d {
404 width: size.0,
405 height: size.1,
406 depth_or_array_layers: size.2,
407 },
408 mip_level_count: 1,
409 sample_count: 1,
410 dimension: wgpu::TextureDimension::D2,
411 format: *format,
412 usage: *usage,
413 view_formats: &[],
414 });
415 context.textures.insert(*id, texture);
416 }
417 ResourceType::Buffer { size, usage } => {
418 let buffer = context
419 .graphics
420 .device()
421 .create_buffer(&wgpu::BufferDescriptor {
422 label: Some(&info.name),
423 size: *size,
424 usage: *usage,
425 mapped_at_creation: false,
426 });
427 context.buffers.insert(*id, buffer);
428 }
429 }
430 }
431
432 for &pass_id in &plan.pass_order {
434 if let Some(pass) = self.passes.get(&pass_id) {
435 (pass.execute)(&mut context);
436 }
437 }
438
439 Ok(())
440 }
441
442 pub fn pass_count(&self) -> usize {
444 self.passes.len()
445 }
446
447 pub fn resource_count(&self) -> usize {
449 self.resources.len()
450 }
451
452 pub fn is_compiled(&self) -> bool {
454 self.execution_plan.is_some()
455 }
456}
457
458impl Default for RenderGraph {
459 fn default() -> Self {
460 Self::new()
461 }
462}
463
#[cfg(test)]
mod tests {
    use super::*;

    /// Declares an 800x600 RGBA8 render-target texture named `name`.
    fn add_target(graph: &mut RenderGraph, name: &str) -> ResourceId {
        graph.add_texture(
            name,
            (800, 600, 1),
            wgpu::TextureFormat::Rgba8Unorm,
            wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
        )
    }

    /// Index of `pass` in the compiled order; panics if the pass is missing.
    fn position(plan: &ExecutionPlan, pass: PassId) -> usize {
        plan.pass_order
            .iter()
            .position(|&id| id == pass)
            .expect("pass missing from execution plan")
    }

    #[test]
    fn test_render_graph_new() {
        let graph = RenderGraph::new();
        assert_eq!(graph.pass_count(), 0);
        assert_eq!(graph.resource_count(), 0);
        assert!(!graph.is_compiled());
    }

    #[test]
    fn test_add_texture_resource() {
        let mut graph = RenderGraph::new();
        let tex = graph.add_texture(
            "color_target",
            (800, 600, 1),
            wgpu::TextureFormat::Rgba8Unorm,
            wgpu::TextureUsages::RENDER_ATTACHMENT,
        );
        assert_eq!(graph.resource_count(), 1);
        assert_eq!(tex.as_u64(), 0);
    }

    #[test]
    fn test_add_buffer_resource() {
        let mut graph = RenderGraph::new();
        let buf = graph.add_buffer("vertex_buffer", 1024, wgpu::BufferUsages::VERTEX);
        assert_eq!(graph.resource_count(), 1);
        assert_eq!(buf.as_u64(), 0);
    }

    #[test]
    fn test_resource_ids_are_shared_across_kinds() {
        let mut graph = RenderGraph::new();
        let tex = add_target(&mut graph, "tex");
        let buf = graph.add_buffer("buf", 64, wgpu::BufferUsages::VERTEX);
        assert_eq!(tex.as_u64(), 0);
        assert_eq!(buf.as_u64(), 1);
        assert_eq!(graph.resource_count(), 2);
    }

    #[test]
    fn test_add_pass() {
        let mut graph = RenderGraph::new();
        let tex = add_target(&mut graph, "target");

        let pass = RenderGraphPass::new("test_pass", vec![], vec![tex], |_ctx| {});
        let pass_id = graph.add_pass(pass);

        assert_eq!(graph.pass_count(), 1);
        assert_eq!(pass_id.as_u64(), 0);
    }

    #[test]
    fn test_compile_simple() {
        let mut graph = RenderGraph::new();
        let tex = add_target(&mut graph, "target");

        graph.add_pass(RenderGraphPass::new("test_pass", vec![], vec![tex], |_ctx| {}));

        let result = graph.compile();
        assert!(result.is_ok());
        assert!(graph.is_compiled());
    }

    #[test]
    fn test_compile_multiple_passes() {
        let mut graph = RenderGraph::new();
        let tex1 = add_target(&mut graph, "tex1");
        let tex2 = add_target(&mut graph, "tex2");

        graph.add_pass(RenderGraphPass::new("pass1", vec![], vec![tex1], |_ctx| {}));
        graph.add_pass(RenderGraphPass::new("pass2", vec![tex1], vec![tex2], |_ctx| {}));

        let plan = graph.compile().expect("linear graph is acyclic");
        assert_eq!(plan.pass_order.len(), 2);
        assert!(plan.pass_order[0].as_u64() < plan.pass_order[1].as_u64());
    }

    #[test]
    fn test_compile_diamond_respects_dependencies() {
        let mut graph = RenderGraph::new();
        let base = add_target(&mut graph, "base");
        let left = add_target(&mut graph, "left");
        let right = add_target(&mut graph, "right");
        let out = add_target(&mut graph, "out");

        // Added in reverse dependency order: ordering must come from the
        // declared inputs/outputs, not from submission order.
        let combine = graph.add_pass(RenderGraphPass::new(
            "combine",
            vec![left, right],
            vec![out],
            |_ctx| {},
        ));
        let blur_left =
            graph.add_pass(RenderGraphPass::new("left", vec![base], vec![left], |_ctx| {}));
        let blur_right =
            graph.add_pass(RenderGraphPass::new("right", vec![base], vec![right], |_ctx| {}));
        let source = graph.add_pass(RenderGraphPass::new("source", vec![], vec![base], |_ctx| {}));

        let plan = graph.compile().expect("diamond graph is acyclic");
        assert_eq!(plan.pass_order.len(), 4);
        assert!(position(&plan, source) < position(&plan, blur_left));
        assert!(position(&plan, source) < position(&plan, blur_right));
        assert!(position(&plan, blur_left) < position(&plan, combine));
        assert!(position(&plan, blur_right) < position(&plan, combine));
    }

    #[test]
    fn test_read_write_same_resource_is_not_a_cycle() {
        let mut graph = RenderGraph::new();
        let tex = add_target(&mut graph, "accum");
        graph.add_pass(RenderGraphPass::new("accumulate", vec![tex], vec![tex], |_ctx| {}));

        let plan = graph.compile().expect("self read/write must not be a cycle");
        assert_eq!(plan.pass_order.len(), 1);
    }

    #[test]
    fn test_compile_detects_cycle() {
        let mut graph = RenderGraph::new();
        let a = add_target(&mut graph, "a");
        let b = add_target(&mut graph, "b");
        graph.add_pass(RenderGraphPass::new("a_to_b", vec![a], vec![b], |_ctx| {}));
        graph.add_pass(RenderGraphPass::new("b_to_a", vec![b], vec![a], |_ctx| {}));

        assert_eq!(
            graph.compile().unwrap_err(),
            RenderGraphError::CyclicDependency
        );
        assert!(!graph.is_compiled());
    }

    #[test]
    fn test_modifying_graph_invalidates_plan() {
        let mut graph = RenderGraph::new();
        let tex = add_target(&mut graph, "target");
        graph.add_pass(RenderGraphPass::new("draw", vec![], vec![tex], |_ctx| {}));
        graph.compile().expect("single pass compiles");
        assert!(graph.is_compiled());

        graph.add_buffer("uniforms", 256, wgpu::BufferUsages::UNIFORM);
        assert!(!graph.is_compiled());

        graph.compile().expect("recompile succeeds");
        assert!(graph.is_compiled());
        graph.add_pass(RenderGraphPass::new("post", vec![tex], vec![], |_ctx| {}));
        assert!(!graph.is_compiled());
    }

    #[test]
    fn test_resource_id_equality() {
        let id1 = ResourceId::new(1);
        let id2 = ResourceId::new(1);
        let id3 = ResourceId::new(2);
        assert_eq!(id1, id2);
        assert_ne!(id1, id3);
    }

    #[test]
    fn test_pass_id_equality() {
        let id1 = PassId::new(1);
        let id2 = PassId::new(1);
        let id3 = PassId::new(2);
        assert_eq!(id1, id2);
        assert_ne!(id1, id3);
    }

    #[test]
    fn test_error_display() {
        let err = RenderGraphError::CyclicDependency;
        assert!(format!("{}", err).contains("Cyclic"));

        let err = RenderGraphError::ResourceNotFound(ResourceId::new(42));
        assert!(format!("{}", err).contains("Resource"));

        let err = RenderGraphError::PassNotFound(PassId::new(7));
        assert!(format!("{}", err).contains("Pass"));

        let err = RenderGraphError::InvalidUsage("oops".to_string());
        assert_eq!(err.to_string(), "Invalid resource usage: oops");
    }
}