1use ahash::{HashMap, HashMapExt, HashSet, HashSetExt};
40use astrelis_core::profiling::profile_function;
41use std::sync::Arc;
42
43use crate::GraphicsContext;
44
/// Opaque handle naming a resource registered with a [`RenderGraph`].
///
/// The handle is a thin wrapper around a `u64`; `RenderGraph` hands these out
/// sequentially, starting from 0.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ResourceId(u64);

impl ResourceId {
    /// Wraps a raw integer as a resource handle.
    pub fn new(id: u64) -> Self {
        ResourceId(id)
    }

    /// Returns the raw integer behind this handle.
    pub fn as_u64(&self) -> u64 {
        let ResourceId(raw) = *self;
        raw
    }
}
60
/// Opaque handle naming a pass added to a [`RenderGraph`].
///
/// The handle is a thin wrapper around a `u64`; `RenderGraph` hands these out
/// sequentially, starting from 0.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PassId(u64);

impl PassId {
    /// Wraps a raw integer as a pass handle.
    pub fn new(id: u64) -> Self {
        PassId(id)
    }

    /// Returns the raw integer behind this handle.
    pub fn as_u64(&self) -> u64 {
        let PassId(raw) = *self;
        raw
    }
}
76
77#[derive(Debug, Clone, PartialEq)]
79pub enum ResourceType {
80 Texture {
82 size: (u32, u32, u32),
83 format: wgpu::TextureFormat,
84 usage: wgpu::TextureUsages,
85 },
86 Buffer {
88 size: u64,
89 usage: wgpu::BufferUsages,
90 },
91}
92
93#[derive(Debug, Clone)]
95pub struct ResourceInfo {
96 pub id: ResourceId,
98 pub resource_type: ResourceType,
100 pub name: String,
102 pub first_read: Option<PassId>,
104 pub last_write: Option<PassId>,
106 pub last_read: Option<PassId>,
108}
109
110pub struct RenderContext {
112 pub graphics: Arc<GraphicsContext>,
114 pub textures: HashMap<ResourceId, wgpu::Texture>,
116 pub buffers: HashMap<ResourceId, wgpu::Buffer>,
118}
119
120impl RenderContext {
121 pub fn new(graphics: Arc<GraphicsContext>) -> Self {
123 Self {
124 graphics,
125 textures: HashMap::new(),
126 buffers: HashMap::new(),
127 }
128 }
129
130 pub fn get_texture(&self, id: ResourceId) -> Option<&wgpu::Texture> {
132 self.textures.get(&id)
133 }
134
135 pub fn get_buffer(&self, id: ResourceId) -> Option<&wgpu::Buffer> {
137 self.buffers.get(&id)
138 }
139}
140
141pub struct RenderGraphPass {
143 pub name: &'static str,
145 pub inputs: Vec<ResourceId>,
147 pub outputs: Vec<ResourceId>,
149 pub execute: Box<dyn Fn(&mut RenderContext) + Send + Sync>,
151}
152
153impl RenderGraphPass {
154 pub fn new(
156 name: &'static str,
157 inputs: Vec<ResourceId>,
158 outputs: Vec<ResourceId>,
159 execute: impl Fn(&mut RenderContext) + Send + Sync + 'static,
160 ) -> Self {
161 Self {
162 name,
163 inputs,
164 outputs,
165 execute: Box::new(execute),
166 }
167 }
168}
169
170#[derive(Debug, Clone)]
172pub struct ExecutionPlan {
173 pub pass_order: Vec<PassId>,
175}
176
177#[derive(Debug, Clone, PartialEq, Eq)]
179pub enum RenderGraphError {
180 CyclicDependency,
182 ResourceNotFound(ResourceId),
184 PassNotFound(PassId),
186 InvalidUsage(String),
188}
189
190impl std::fmt::Display for RenderGraphError {
191 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192 match self {
193 Self::CyclicDependency => write!(f, "Cyclic dependency detected in render graph"),
194 Self::ResourceNotFound(id) => write!(f, "Resource {:?} not found", id),
195 Self::PassNotFound(id) => write!(f, "Pass {:?} not found", id),
196 Self::InvalidUsage(msg) => write!(f, "Invalid resource usage: {}", msg),
197 }
198 }
199}
200
201impl std::error::Error for RenderGraphError {}
202
203pub struct RenderGraph {
205 passes: HashMap<PassId, RenderGraphPass>,
207 resources: HashMap<ResourceId, ResourceInfo>,
209 next_pass_id: u64,
211 next_resource_id: u64,
213 execution_plan: Option<ExecutionPlan>,
215}
216
217impl RenderGraph {
218 pub fn new() -> Self {
220 Self {
221 passes: HashMap::new(),
222 resources: HashMap::new(),
223 next_pass_id: 0,
224 next_resource_id: 0,
225 execution_plan: None,
226 }
227 }
228
229 pub fn add_texture(
231 &mut self,
232 name: impl Into<String>,
233 size: (u32, u32, u32),
234 format: wgpu::TextureFormat,
235 usage: wgpu::TextureUsages,
236 ) -> ResourceId {
237 let id = ResourceId::new(self.next_resource_id);
238 self.next_resource_id += 1;
239
240 let resource = ResourceInfo {
241 id,
242 resource_type: ResourceType::Texture {
243 size,
244 format,
245 usage,
246 },
247 name: name.into(),
248 first_read: None,
249 last_write: None,
250 last_read: None,
251 };
252
253 self.resources.insert(id, resource);
254 self.execution_plan = None; id
257 }
258
259 pub fn add_buffer(
261 &mut self,
262 name: impl Into<String>,
263 size: u64,
264 usage: wgpu::BufferUsages,
265 ) -> ResourceId {
266 let id = ResourceId::new(self.next_resource_id);
267 self.next_resource_id += 1;
268
269 let resource = ResourceInfo {
270 id,
271 resource_type: ResourceType::Buffer { size, usage },
272 name: name.into(),
273 first_read: None,
274 last_write: None,
275 last_read: None,
276 };
277
278 self.resources.insert(id, resource);
279 self.execution_plan = None; id
282 }
283
284 pub fn add_pass(&mut self, pass: RenderGraphPass) -> PassId {
286 let id = PassId::new(self.next_pass_id);
287 self.next_pass_id += 1;
288
289 for &input_id in &pass.inputs {
291 if let Some(resource) = self.resources.get_mut(&input_id) {
292 if resource.first_read.is_none() {
293 resource.first_read = Some(id);
294 }
295 resource.last_read = Some(id);
296 }
297 }
298
299 for &output_id in &pass.outputs {
300 if let Some(resource) = self.resources.get_mut(&output_id) {
301 resource.last_write = Some(id);
302 }
303 }
304
305 self.passes.insert(id, pass);
306 self.execution_plan = None; id
309 }
310
311 pub fn compile(&mut self) -> Result<ExecutionPlan, RenderGraphError> {
315 profile_function!();
316 let mut dependencies: HashMap<PassId, HashSet<PassId>> = HashMap::new();
318 let mut dependents: HashMap<PassId, HashSet<PassId>> = HashMap::new();
319
320 for (&pass_id, pass) in &self.passes {
321 dependencies.insert(pass_id, HashSet::new());
322 dependents.entry(pass_id).or_default();
323
324 for &input_id in &pass.inputs {
326 for (&other_pass_id, other_pass) in &self.passes {
328 if other_pass_id != pass_id && other_pass.outputs.contains(&input_id) {
329 dependencies.get_mut(&pass_id).unwrap().insert(other_pass_id);
330 dependents.entry(other_pass_id).or_default().insert(pass_id);
331 }
332 }
333 }
334 }
335
336 let mut sorted = Vec::new();
338 let mut no_incoming: Vec<PassId> = dependencies
339 .iter()
340 .filter(|(_, deps)| deps.is_empty())
341 .map(|(&id, _)| id)
342 .collect();
343
344 while let Some(pass_id) = no_incoming.pop() {
345 sorted.push(pass_id);
346
347 if let Some(deps) = dependents.get(&pass_id) {
349 for &dependent_id in deps {
350 if let Some(dep_set) = dependencies.get_mut(&dependent_id) {
351 dep_set.remove(&pass_id);
352 if dep_set.is_empty() {
353 no_incoming.push(dependent_id);
354 }
355 }
356 }
357 }
358 }
359
360 if sorted.len() != self.passes.len() {
362 return Err(RenderGraphError::CyclicDependency);
363 }
364
365 let plan = ExecutionPlan {
366 pass_order: sorted,
367 };
368
369 self.execution_plan = Some(plan.clone());
370
371 Ok(plan)
372 }
373
374 pub fn execute(&self, graphics: Arc<GraphicsContext>) -> Result<(), RenderGraphError> {
378 profile_function!();
379 let plan = self
380 .execution_plan
381 .as_ref()
382 .ok_or(RenderGraphError::InvalidUsage(
383 "Graph not compiled".to_string(),
384 ))?;
385
386 let mut context = RenderContext::new(graphics);
387
388 for (id, info) in &self.resources {
390 match &info.resource_type {
391 ResourceType::Texture {
392 size,
393 format,
394 usage,
395 } => {
396 let texture = context.graphics.device().create_texture(&wgpu::TextureDescriptor {
397 label: Some(&info.name),
398 size: wgpu::Extent3d {
399 width: size.0,
400 height: size.1,
401 depth_or_array_layers: size.2,
402 },
403 mip_level_count: 1,
404 sample_count: 1,
405 dimension: wgpu::TextureDimension::D2,
406 format: *format,
407 usage: *usage,
408 view_formats: &[],
409 });
410 context.textures.insert(*id, texture);
411 }
412 ResourceType::Buffer { size, usage } => {
413 let buffer = context.graphics.device().create_buffer(&wgpu::BufferDescriptor {
414 label: Some(&info.name),
415 size: *size,
416 usage: *usage,
417 mapped_at_creation: false,
418 });
419 context.buffers.insert(*id, buffer);
420 }
421 }
422 }
423
424 for &pass_id in &plan.pass_order {
426 if let Some(pass) = self.passes.get(&pass_id) {
427 (pass.execute)(&mut context);
428 }
429 }
430
431 Ok(())
432 }
433
434 pub fn pass_count(&self) -> usize {
436 self.passes.len()
437 }
438
439 pub fn resource_count(&self) -> usize {
441 self.resources.len()
442 }
443
444 pub fn is_compiled(&self) -> bool {
446 self.execution_plan.is_some()
447 }
448}
449
450impl Default for RenderGraph {
451 fn default() -> Self {
452 Self::new()
453 }
454}
455
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_render_graph_new() {
        let graph = RenderGraph::new();
        assert_eq!(graph.pass_count(), 0);
        assert_eq!(graph.resource_count(), 0);
        assert!(!graph.is_compiled());
    }

    #[test]
    fn test_add_texture_resource() {
        let mut graph = RenderGraph::new();
        let tex = graph.add_texture(
            "color_target",
            (800, 600, 1),
            wgpu::TextureFormat::Rgba8Unorm,
            wgpu::TextureUsages::RENDER_ATTACHMENT,
        );
        assert_eq!(graph.resource_count(), 1);
        assert_eq!(tex.as_u64(), 0);
    }

    #[test]
    fn test_add_buffer_resource() {
        let mut graph = RenderGraph::new();
        let buf = graph.add_buffer("vertex_buffer", 1024, wgpu::BufferUsages::VERTEX);
        assert_eq!(graph.resource_count(), 1);
        assert_eq!(buf.as_u64(), 0);
    }

    #[test]
    fn test_add_pass() {
        let mut graph = RenderGraph::new();
        let tex = graph.add_texture(
            "target",
            (800, 600, 1),
            wgpu::TextureFormat::Rgba8Unorm,
            wgpu::TextureUsages::RENDER_ATTACHMENT,
        );

        let pass = RenderGraphPass::new("test_pass", vec![], vec![tex], |_ctx| {});
        let pass_id = graph.add_pass(pass);

        assert_eq!(graph.pass_count(), 1);
        assert_eq!(pass_id.as_u64(), 0);
    }

    #[test]
    fn test_compile_simple() {
        let mut graph = RenderGraph::new();
        let tex = graph.add_texture(
            "target",
            (800, 600, 1),
            wgpu::TextureFormat::Rgba8Unorm,
            wgpu::TextureUsages::RENDER_ATTACHMENT,
        );

        let pass = RenderGraphPass::new("test_pass", vec![], vec![tex], |_ctx| {});
        graph.add_pass(pass);

        let result = graph.compile();
        assert!(result.is_ok());
        assert!(graph.is_compiled());
    }

    #[test]
    fn test_compile_multiple_passes() {
        let mut graph = RenderGraph::new();
        let tex1 = graph.add_texture(
            "tex1",
            (800, 600, 1),
            wgpu::TextureFormat::Rgba8Unorm,
            wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
        );
        let tex2 = graph.add_texture(
            "tex2",
            (800, 600, 1),
            wgpu::TextureFormat::Rgba8Unorm,
            wgpu::TextureUsages::RENDER_ATTACHMENT,
        );

        let pass1 = RenderGraphPass::new("pass1", vec![], vec![tex1], |_ctx| {});
        graph.add_pass(pass1);

        let pass2 = RenderGraphPass::new("pass2", vec![tex1], vec![tex2], |_ctx| {});
        graph.add_pass(pass2);

        let result = graph.compile();
        assert!(result.is_ok());

        let plan = result.unwrap();
        assert_eq!(plan.pass_order.len(), 2);
        assert!(plan.pass_order[0].as_u64() < plan.pass_order[1].as_u64());
    }

    /// Producers must precede consumers even when they were registered after them.
    #[test]
    fn test_compile_orders_producers_before_consumers() {
        let mut graph = RenderGraph::new();
        let first = graph.add_buffer("first", 64, wgpu::BufferUsages::STORAGE);
        let second = graph.add_buffer("second", 64, wgpu::BufferUsages::STORAGE);

        let consumer =
            graph.add_pass(RenderGraphPass::new("consumer", vec![second], vec![], |_ctx| {}));
        let middle =
            graph.add_pass(RenderGraphPass::new("middle", vec![first], vec![second], |_ctx| {}));
        let producer =
            graph.add_pass(RenderGraphPass::new("producer", vec![], vec![first], |_ctx| {}));

        let plan = graph.compile().expect("acyclic graph must compile");
        assert_eq!(plan.pass_order, vec![producer, middle, consumer]);
    }

    /// Two passes that each read what the other writes can never be ordered.
    #[test]
    fn test_compile_detects_cycle() {
        let mut graph = RenderGraph::new();
        let a = graph.add_buffer("a", 64, wgpu::BufferUsages::STORAGE);
        let b = graph.add_buffer("b", 64, wgpu::BufferUsages::STORAGE);

        graph.add_pass(RenderGraphPass::new("a_to_b", vec![a], vec![b], |_ctx| {}));
        graph.add_pass(RenderGraphPass::new("b_to_a", vec![b], vec![a], |_ctx| {}));

        assert_eq!(
            graph.compile().unwrap_err(),
            RenderGraphError::CyclicDependency
        );
        assert!(!graph.is_compiled());
    }

    /// Any structural change must drop a previously compiled plan.
    #[test]
    fn test_mutation_invalidates_compiled_plan() {
        let mut graph = RenderGraph::new();
        let tex = graph.add_texture(
            "target",
            (64, 64, 1),
            wgpu::TextureFormat::Rgba8Unorm,
            wgpu::TextureUsages::RENDER_ATTACHMENT,
        );
        graph.add_pass(RenderGraphPass::new("write", vec![], vec![tex], |_ctx| {}));

        graph.compile().expect("single pass must compile");
        assert!(graph.is_compiled());

        graph.add_buffer("late", 16, wgpu::BufferUsages::UNIFORM);
        assert!(!graph.is_compiled());

        graph.compile().expect("recompile must succeed");
        graph.add_pass(RenderGraphPass::new("read", vec![tex], vec![], |_ctx| {}));
        assert!(!graph.is_compiled());
    }

    #[test]
    fn test_resource_id_equality() {
        let id1 = ResourceId::new(1);
        let id2 = ResourceId::new(1);
        let id3 = ResourceId::new(2);
        assert_eq!(id1, id2);
        assert_ne!(id1, id3);
    }

    #[test]
    fn test_pass_id_equality() {
        let id1 = PassId::new(1);
        let id2 = PassId::new(1);
        let id3 = PassId::new(2);
        assert_eq!(id1, id2);
        assert_ne!(id1, id3);
    }

    #[test]
    fn test_error_display() {
        let err = RenderGraphError::CyclicDependency;
        assert!(format!("{}", err).contains("Cyclic"));

        let err = RenderGraphError::ResourceNotFound(ResourceId::new(42));
        assert!(format!("{}", err).contains("Resource"));
    }
}