1use std::collections::{HashMap, HashSet};
40use std::sync::Arc;
41
42use crate::GraphicsContext;
43
/// Opaque handle to a resource (texture or buffer) declared on a [`RenderGraph`].
///
/// [`RenderGraph::add_texture`] and [`RenderGraph::add_buffer`] hand these out from a
/// counter that starts at 0 and increases by one per resource, so ordering ids orders
/// resources by declaration. `Ord` is derived so ids can key ordered collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ResourceId(u64);

impl ResourceId {
    /// Wraps a raw numeric id. Usable in `const` contexts.
    pub const fn new(id: u64) -> Self {
        Self(id)
    }

    /// Returns the raw numeric id.
    pub const fn as_u64(&self) -> u64 {
        self.0
    }
}
59
/// Opaque handle to a pass registered with a [`RenderGraph`].
///
/// [`RenderGraph::add_pass`] hands these out from a counter that starts at 0 and increases
/// by one per pass, so ordering ids orders passes by registration. `Ord` is derived so ids
/// can be sorted or used as keys of ordered collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct PassId(u64);

impl PassId {
    /// Wraps a raw numeric id. Usable in `const` contexts.
    pub const fn new(id: u64) -> Self {
        Self(id)
    }

    /// Returns the raw numeric id.
    pub const fn as_u64(&self) -> u64 {
        self.0
    }
}
75
76#[derive(Debug, Clone, PartialEq)]
78pub enum ResourceType {
79 Texture {
81 size: (u32, u32, u32),
82 format: wgpu::TextureFormat,
83 usage: wgpu::TextureUsages,
84 },
85 Buffer {
87 size: u64,
88 usage: wgpu::BufferUsages,
89 },
90}
91
92#[derive(Debug, Clone)]
94pub struct ResourceInfo {
95 pub id: ResourceId,
97 pub resource_type: ResourceType,
99 pub name: String,
101 pub first_read: Option<PassId>,
103 pub last_write: Option<PassId>,
105 pub last_read: Option<PassId>,
107}
108
109pub struct RenderContext {
111 pub graphics: Arc<GraphicsContext>,
113 pub textures: HashMap<ResourceId, wgpu::Texture>,
115 pub buffers: HashMap<ResourceId, wgpu::Buffer>,
117}
118
119impl RenderContext {
120 pub fn new(graphics: Arc<GraphicsContext>) -> Self {
122 Self {
123 graphics,
124 textures: HashMap::new(),
125 buffers: HashMap::new(),
126 }
127 }
128
129 pub fn get_texture(&self, id: ResourceId) -> Option<&wgpu::Texture> {
131 self.textures.get(&id)
132 }
133
134 pub fn get_buffer(&self, id: ResourceId) -> Option<&wgpu::Buffer> {
136 self.buffers.get(&id)
137 }
138}
139
140pub struct RenderGraphPass {
142 pub name: &'static str,
144 pub inputs: Vec<ResourceId>,
146 pub outputs: Vec<ResourceId>,
148 pub execute: Box<dyn Fn(&mut RenderContext) + Send + Sync>,
150}
151
152impl RenderGraphPass {
153 pub fn new(
155 name: &'static str,
156 inputs: Vec<ResourceId>,
157 outputs: Vec<ResourceId>,
158 execute: impl Fn(&mut RenderContext) + Send + Sync + 'static,
159 ) -> Self {
160 Self {
161 name,
162 inputs,
163 outputs,
164 execute: Box::new(execute),
165 }
166 }
167}
168
169#[derive(Debug, Clone)]
171pub struct ExecutionPlan {
172 pub pass_order: Vec<PassId>,
174}
175
176#[derive(Debug, Clone, PartialEq, Eq)]
178pub enum RenderGraphError {
179 CyclicDependency,
181 ResourceNotFound(ResourceId),
183 PassNotFound(PassId),
185 InvalidUsage(String),
187}
188
189impl std::fmt::Display for RenderGraphError {
190 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191 match self {
192 Self::CyclicDependency => write!(f, "Cyclic dependency detected in render graph"),
193 Self::ResourceNotFound(id) => write!(f, "Resource {:?} not found", id),
194 Self::PassNotFound(id) => write!(f, "Pass {:?} not found", id),
195 Self::InvalidUsage(msg) => write!(f, "Invalid resource usage: {}", msg),
196 }
197 }
198}
199
200impl std::error::Error for RenderGraphError {}
201
202pub struct RenderGraph {
204 passes: HashMap<PassId, RenderGraphPass>,
206 resources: HashMap<ResourceId, ResourceInfo>,
208 next_pass_id: u64,
210 next_resource_id: u64,
212 execution_plan: Option<ExecutionPlan>,
214}
215
216impl RenderGraph {
217 pub fn new() -> Self {
219 Self {
220 passes: HashMap::new(),
221 resources: HashMap::new(),
222 next_pass_id: 0,
223 next_resource_id: 0,
224 execution_plan: None,
225 }
226 }
227
228 pub fn add_texture(
230 &mut self,
231 name: impl Into<String>,
232 size: (u32, u32, u32),
233 format: wgpu::TextureFormat,
234 usage: wgpu::TextureUsages,
235 ) -> ResourceId {
236 let id = ResourceId::new(self.next_resource_id);
237 self.next_resource_id += 1;
238
239 let resource = ResourceInfo {
240 id,
241 resource_type: ResourceType::Texture {
242 size,
243 format,
244 usage,
245 },
246 name: name.into(),
247 first_read: None,
248 last_write: None,
249 last_read: None,
250 };
251
252 self.resources.insert(id, resource);
253 self.execution_plan = None; id
256 }
257
258 pub fn add_buffer(
260 &mut self,
261 name: impl Into<String>,
262 size: u64,
263 usage: wgpu::BufferUsages,
264 ) -> ResourceId {
265 let id = ResourceId::new(self.next_resource_id);
266 self.next_resource_id += 1;
267
268 let resource = ResourceInfo {
269 id,
270 resource_type: ResourceType::Buffer { size, usage },
271 name: name.into(),
272 first_read: None,
273 last_write: None,
274 last_read: None,
275 };
276
277 self.resources.insert(id, resource);
278 self.execution_plan = None; id
281 }
282
283 pub fn add_pass(&mut self, pass: RenderGraphPass) -> PassId {
285 let id = PassId::new(self.next_pass_id);
286 self.next_pass_id += 1;
287
288 for &input_id in &pass.inputs {
290 if let Some(resource) = self.resources.get_mut(&input_id) {
291 if resource.first_read.is_none() {
292 resource.first_read = Some(id);
293 }
294 resource.last_read = Some(id);
295 }
296 }
297
298 for &output_id in &pass.outputs {
299 if let Some(resource) = self.resources.get_mut(&output_id) {
300 resource.last_write = Some(id);
301 }
302 }
303
304 self.passes.insert(id, pass);
305 self.execution_plan = None; id
308 }
309
310 pub fn compile(&mut self) -> Result<ExecutionPlan, RenderGraphError> {
314 let mut dependencies: HashMap<PassId, HashSet<PassId>> = HashMap::new();
316 let mut dependents: HashMap<PassId, HashSet<PassId>> = HashMap::new();
317
318 for (&pass_id, pass) in &self.passes {
319 dependencies.insert(pass_id, HashSet::new());
320 dependents.entry(pass_id).or_insert_with(HashSet::new);
321
322 for &input_id in &pass.inputs {
324 for (&other_pass_id, other_pass) in &self.passes {
326 if other_pass_id != pass_id && other_pass.outputs.contains(&input_id) {
327 dependencies.get_mut(&pass_id).unwrap().insert(other_pass_id);
328 dependents.entry(other_pass_id).or_insert_with(HashSet::new).insert(pass_id);
329 }
330 }
331 }
332 }
333
334 let mut sorted = Vec::new();
336 let mut no_incoming: Vec<PassId> = dependencies
337 .iter()
338 .filter(|(_, deps)| deps.is_empty())
339 .map(|(&id, _)| id)
340 .collect();
341
342 while let Some(pass_id) = no_incoming.pop() {
343 sorted.push(pass_id);
344
345 if let Some(deps) = dependents.get(&pass_id) {
347 for &dependent_id in deps {
348 if let Some(dep_set) = dependencies.get_mut(&dependent_id) {
349 dep_set.remove(&pass_id);
350 if dep_set.is_empty() {
351 no_incoming.push(dependent_id);
352 }
353 }
354 }
355 }
356 }
357
358 if sorted.len() != self.passes.len() {
360 return Err(RenderGraphError::CyclicDependency);
361 }
362
363 let plan = ExecutionPlan {
364 pass_order: sorted,
365 };
366
367 self.execution_plan = Some(plan.clone());
368
369 Ok(plan)
370 }
371
372 pub fn execute(&self, graphics: Arc<GraphicsContext>) -> Result<(), RenderGraphError> {
376 let plan = self
377 .execution_plan
378 .as_ref()
379 .ok_or(RenderGraphError::InvalidUsage(
380 "Graph not compiled".to_string(),
381 ))?;
382
383 let mut context = RenderContext::new(graphics);
384
385 for (id, info) in &self.resources {
387 match &info.resource_type {
388 ResourceType::Texture {
389 size,
390 format,
391 usage,
392 } => {
393 let texture = context.graphics.device.create_texture(&wgpu::TextureDescriptor {
394 label: Some(&info.name),
395 size: wgpu::Extent3d {
396 width: size.0,
397 height: size.1,
398 depth_or_array_layers: size.2,
399 },
400 mip_level_count: 1,
401 sample_count: 1,
402 dimension: wgpu::TextureDimension::D2,
403 format: *format,
404 usage: *usage,
405 view_formats: &[],
406 });
407 context.textures.insert(*id, texture);
408 }
409 ResourceType::Buffer { size, usage } => {
410 let buffer = context.graphics.device.create_buffer(&wgpu::BufferDescriptor {
411 label: Some(&info.name),
412 size: *size,
413 usage: *usage,
414 mapped_at_creation: false,
415 });
416 context.buffers.insert(*id, buffer);
417 }
418 }
419 }
420
421 for &pass_id in &plan.pass_order {
423 if let Some(pass) = self.passes.get(&pass_id) {
424 (pass.execute)(&mut context);
425 }
426 }
427
428 Ok(())
429 }
430
431 pub fn pass_count(&self) -> usize {
433 self.passes.len()
434 }
435
436 pub fn resource_count(&self) -> usize {
438 self.resources.len()
439 }
440
441 pub fn is_compiled(&self) -> bool {
443 self.execution_plan.is_some()
444 }
445}
446
447impl Default for RenderGraph {
448 fn default() -> Self {
449 Self::new()
450 }
451}
452
#[cfg(test)]
mod tests {
    use super::*;

    /// Declares the 800x600x1 `Rgba8Unorm` texture shape shared by these tests.
    fn rgba_texture(
        graph: &mut RenderGraph,
        name: &str,
        usage: wgpu::TextureUsages,
    ) -> ResourceId {
        graph.add_texture(name, (800, 600, 1), wgpu::TextureFormat::Rgba8Unorm, usage)
    }

    /// Builds a pass whose callback does nothing; only its resource lists matter here.
    fn noop_pass(
        name: &'static str,
        inputs: Vec<ResourceId>,
        outputs: Vec<ResourceId>,
    ) -> RenderGraphPass {
        RenderGraphPass::new(name, inputs, outputs, |_ctx| {})
    }

    #[test]
    fn test_render_graph_new() {
        let graph = RenderGraph::new();
        assert!(!graph.is_compiled());
        assert_eq!(graph.resource_count(), 0);
        assert_eq!(graph.pass_count(), 0);
    }

    #[test]
    fn test_add_texture_resource() {
        let mut graph = RenderGraph::new();
        let tex = rgba_texture(
            &mut graph,
            "color_target",
            wgpu::TextureUsages::RENDER_ATTACHMENT,
        );
        assert_eq!(graph.resource_count(), 1);
        assert_eq!(tex.as_u64(), 0);
    }

    #[test]
    fn test_add_buffer_resource() {
        let mut graph = RenderGraph::new();
        let buf = graph.add_buffer("vertex_buffer", 1024, wgpu::BufferUsages::VERTEX);
        assert_eq!(graph.resource_count(), 1);
        assert_eq!(buf.as_u64(), 0);
    }

    #[test]
    fn test_add_pass() {
        let mut graph = RenderGraph::new();
        let target = rgba_texture(&mut graph, "target", wgpu::TextureUsages::RENDER_ATTACHMENT);

        let pass_id = graph.add_pass(noop_pass("test_pass", vec![], vec![target]));

        assert_eq!(graph.pass_count(), 1);
        assert_eq!(pass_id.as_u64(), 0);
    }

    #[test]
    fn test_compile_simple() {
        let mut graph = RenderGraph::new();
        let target = rgba_texture(&mut graph, "target", wgpu::TextureUsages::RENDER_ATTACHMENT);
        graph.add_pass(noop_pass("test_pass", vec![], vec![target]));

        let outcome = graph.compile();
        assert!(outcome.is_ok());
        assert!(graph.is_compiled());
    }

    #[test]
    fn test_compile_multiple_passes() {
        let mut graph = RenderGraph::new();
        let sampled_usage =
            wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING;
        let tex1 = rgba_texture(&mut graph, "tex1", sampled_usage);
        let tex2 = rgba_texture(&mut graph, "tex2", wgpu::TextureUsages::RENDER_ATTACHMENT);

        graph.add_pass(noop_pass("pass1", vec![], vec![tex1]));
        graph.add_pass(noop_pass("pass2", vec![tex1], vec![tex2]));

        let outcome = graph.compile();
        assert!(outcome.is_ok());

        let plan = outcome.unwrap();
        let order = &plan.pass_order;
        assert_eq!(order.len(), 2);
        assert!(order[0].as_u64() < order[1].as_u64());
    }

    #[test]
    fn test_resource_id_equality() {
        let (first, same, other) = (ResourceId::new(1), ResourceId::new(1), ResourceId::new(2));
        assert_eq!(first, same);
        assert_ne!(first, other);
    }

    #[test]
    fn test_pass_id_equality() {
        let (first, same, other) = (PassId::new(1), PassId::new(1), PassId::new(2));
        assert_eq!(first, same);
        assert_ne!(first, other);
    }

    #[test]
    fn test_error_display() {
        let cyclic = RenderGraphError::CyclicDependency;
        assert!(format!("{}", cyclic).contains("Cyclic"));

        let missing = RenderGraphError::ResourceNotFound(ResourceId::new(42));
        assert!(format!("{}", missing).contains("Resource"));
    }
}