rafx_framework/graph/prepared_graph.rs

1use super::PhysicalImageId;
2use crate::graph::graph_buffer::PhysicalBufferId;
3use crate::graph::graph_image::PhysicalImageViewId;
4use crate::graph::graph_node::{RenderGraphNodeId, RenderGraphNodeName};
5use crate::graph::graph_pass::{PrepassBufferBarrier, PrepassImageBarrier, RenderGraphOutputPass};
6use crate::graph::graph_plan::RenderGraphPlan;
7use crate::graph::{
8    RenderGraphBufferUsageId, RenderGraphBuilder, RenderGraphImageUsageId,
9    RenderGraphNodeVisitNodeCallback,
10};
11use crate::render_features::{
12    PreparedRenderData, RenderJobBeginExecuteGraphContext, RenderJobCommandBufferContext,
13    RenderJobWriteContext, RenderPhase, RenderView,
14};
15use crate::resources::DynCommandBuffer;
16use crate::{BufferResource, GraphicsPipelineRenderTargetMeta, ImageResource, RenderResources};
17use crate::{ImageViewResource, ResourceArc, ResourceContext};
18use fnv::FnvHashMap;
19use rafx_api::{
20    RafxBarrierQueueTransition, RafxBufferBarrier, RafxColorRenderTargetBinding, RafxCommandBuffer,
21    RafxCommandBufferDef, RafxCommandPoolDef, RafxDepthStencilRenderTargetBinding,
22    RafxDeviceContext, RafxExtents2D, RafxFormat, RafxQueue, RafxResult, RafxSwapchainColorSpace,
23    RafxTextureBarrier,
24};
25use std::hash::Hash;
26
27#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
28pub struct SwapchainSurfaceInfo {
29    pub extents: RafxExtents2D,
30    pub format: RafxFormat,
31    pub color_space: RafxSwapchainColorSpace,
32}
33
34#[derive(Copy, Clone)]
35pub struct RenderGraphContext<'graph, 'write> {
36    prepared_render_graph: &'graph PreparedRenderGraph,
37    prepared_render_data: &'graph PreparedRenderData<'write>,
38    render_resources: &'graph RenderResources,
39}
40
41impl<'graph, 'write> RenderGraphContext<'graph, 'write> {
42    pub fn buffer(
43        &self,
44        buffer: RenderGraphBufferUsageId,
45    ) -> Option<ResourceArc<BufferResource>> {
46        self.prepared_render_graph.buffer(buffer)
47    }
48
49    pub fn image_view(
50        &self,
51        image: RenderGraphImageUsageId,
52    ) -> Option<ResourceArc<ImageViewResource>> {
53        self.prepared_render_graph.image_view(image)
54    }
55
56    pub fn device_context(&self) -> &RafxDeviceContext {
57        &self.prepared_render_graph.device_context
58    }
59
60    pub fn resource_context(&self) -> &ResourceContext {
61        &self.prepared_render_graph.resource_context
62    }
63
64    pub fn prepared_render_data(&self) -> &PreparedRenderData<'write> {
65        &self.prepared_render_data
66    }
67
68    pub fn render_resources(&self) -> &RenderResources {
69        &self.render_resources
70    }
71}
72
73pub struct OnBeginExecuteGraphArgs<'graph, 'write> {
74    pub command_buffer: DynCommandBuffer,
75    pub graph_context: RenderGraphContext<'graph, 'write>,
76}
77
78pub struct VisitComputeNodeArgs<'graph, 'write> {
79    pub command_buffer: DynCommandBuffer,
80    pub graph_context: RenderGraphContext<'graph, 'write>,
81}
82
83pub struct VisitRenderpassNodeArgs<'graph, 'write> {
84    pub command_buffer: DynCommandBuffer,
85    pub render_target_meta: GraphicsPipelineRenderTargetMeta,
86    pub graph_context: RenderGraphContext<'graph, 'write>,
87}
88
89// Convenience function for creating a write context and triggering writing a phase for a view.
90// (Alternatively you can make your own write context, which allows calling write_view_phase
91// multiple times with the same context)
92impl<'graph, 'write> VisitRenderpassNodeArgs<'graph, 'write> {
93    pub fn write_view_phase<PhaseT: RenderPhase>(
94        &self,
95        render_view: &RenderView,
96    ) -> RafxResult<()> {
97        let mut write_context =
98            RenderJobCommandBufferContext::from_graph_visit_render_pass_args(self);
99        self.graph_context
100            .prepared_render_data()
101            .write_view_phase::<PhaseT>(render_view, &mut write_context)
102    }
103}
104
105/// Encapsulates a render graph plan and all resources required to execute it
106pub struct PreparedRenderGraph {
107    device_context: RafxDeviceContext,
108    resource_context: ResourceContext,
109    buffer_resources: FnvHashMap<PhysicalBufferId, ResourceArc<BufferResource>>,
110    image_resources: FnvHashMap<PhysicalImageId, ResourceArc<ImageResource>>,
111    image_view_resources: FnvHashMap<PhysicalImageViewId, ResourceArc<ImageViewResource>>,
112    graph_plan: RenderGraphPlan,
113}
114
115impl PreparedRenderGraph {
116    pub fn node_debug_name(
117        &self,
118        node_id: RenderGraphNodeId,
119    ) -> Option<RenderGraphNodeName> {
120        let pass_index = *self.graph_plan.node_to_pass_index.get(&node_id)?;
121        self.graph_plan.passes[pass_index].debug_name()
122    }
123
124    pub fn new(
125        device_context: &RafxDeviceContext,
126        resource_context: &ResourceContext,
127        graph: RenderGraphBuilder,
128        swapchain_surface_info: &SwapchainSurfaceInfo,
129    ) -> RafxResult<Self> {
130        let graph_plan = graph.build_plan(swapchain_surface_info);
131        let mut cache_guard = resource_context.render_graph_cache().inner.lock().unwrap();
132        let cache = &mut *cache_guard;
133
134        profiling::scope!("allocate resources");
135        let buffer_resources =
136            cache.allocate_buffers(device_context, &graph_plan, resource_context.resources())?;
137
138        let image_resources = cache.allocate_images(
139            device_context,
140            &graph_plan,
141            resource_context.resources(),
142            swapchain_surface_info,
143        )?;
144
145        let image_view_resources = cache.allocate_image_views(
146            &graph_plan,
147            resource_context.resources(),
148            &image_resources,
149        )?;
150
151        Ok(PreparedRenderGraph {
152            device_context: device_context.clone(),
153            resource_context: resource_context.clone(),
154            buffer_resources,
155            image_resources,
156            image_view_resources,
157            graph_plan,
158        })
159    }
160
161    pub fn buffer(
162        &self,
163        buffer: RenderGraphBufferUsageId,
164    ) -> Option<ResourceArc<BufferResource>> {
165        let physical_buffer = self.graph_plan.buffer_usage_to_physical.get(&buffer)?;
166        self.buffer_resources.get(physical_buffer).cloned()
167    }
168
169    // pub fn image(
170    //     &self,
171    //     image_usage: RenderGraphImageUsageId,
172    // ) -> Option<ResourceArc<ImageResource>> {
173    //     let image = self.graph_plan.image_usage_to_physical.get(&image_usage)?;
174    //     self.image_resources.get(image).cloned()
175    // }
176
177    pub fn image_view(
178        &self,
179        image: RenderGraphImageUsageId,
180    ) -> Option<ResourceArc<ImageViewResource>> {
181        let physical_image = self.graph_plan.image_usage_to_view.get(&image)?;
182        self.image_view_resources.get(physical_image).cloned()
183    }
184
185    fn insert_barriers(
186        &self,
187        command_buffer: &RafxCommandBuffer,
188        pass_buffer_barriers: &[PrepassBufferBarrier],
189        pass_image_barriers: &[PrepassImageBarrier],
190    ) -> RafxResult<()> {
191        assert!(!pass_buffer_barriers.is_empty() || !pass_image_barriers.is_empty());
192
193        let mut buffer_barriers = Vec::with_capacity(pass_buffer_barriers.len());
194        let buffers: Vec<_> = pass_buffer_barriers
195            .iter()
196            .map(|x| self.buffer_resources[&x.buffer].get_raw().buffer.clone())
197            .collect();
198        for (buffer_barrier, buffer) in pass_buffer_barriers.iter().zip(&buffers) {
199            log::trace!(
200                "add buffer barrier for buffer {:?} state {:?} -> {:?}",
201                buffer_barrier.buffer,
202                buffer_barrier.old_state,
203                buffer_barrier.new_state
204            );
205
206            buffer_barriers.push(RafxBufferBarrier {
207                buffer: buffer.as_ref(),
208                src_state: buffer_barrier.old_state,
209                dst_state: buffer_barrier.new_state,
210                queue_transition: RafxBarrierQueueTransition::None,
211                offset_size: None,
212            });
213        }
214
215        let mut image_barriers = Vec::with_capacity(pass_image_barriers.len());
216        let images: Vec<_> = pass_image_barriers
217            .iter()
218            .map(|x| self.image_resources[&x.image].get_raw().image.clone())
219            .collect();
220        for (image_barrier, image) in pass_image_barriers.iter().zip(&images) {
221            log::trace!(
222                "add image barrier for image {:?} state {:?} -> {:?}",
223                image_barrier.image,
224                image_barrier.old_state,
225                image_barrier.new_state
226            );
227
228            image_barriers.push(RafxTextureBarrier {
229                texture: image,
230                src_state: image_barrier.old_state,
231                dst_state: image_barrier.new_state,
232                array_slice: None,
233                mip_slice: None,
234                queue_transition: RafxBarrierQueueTransition::None,
235            });
236        }
237
238        // for buffer_barrier in rafx_buffer_barriers {
239        //     println!("{:?}", buffer_barrier);
240        // }
241        //
242        // for rt_barrier in rt_barriers {
243        //     println!("{:?}", rt_barrier);
244        // }
245
246        command_buffer.cmd_resource_barrier(&buffer_barriers, &image_barriers)
247    }
248
249    fn visit_render_node(
250        &self,
251        node_id: RenderGraphNodeId,
252        args: VisitRenderpassNodeArgs,
253    ) -> RafxResult<()> {
254        if let Some(callback) = self.graph_plan.visit_node_callbacks.get(&node_id) {
255            if let RenderGraphNodeVisitNodeCallback::Render(render_callback) = callback {
256                (render_callback)(args)?
257            } else {
258                let debug_name = args
259                    .graph_context
260                    .prepared_render_graph
261                    .node_debug_name(node_id);
262                log::error!("Tried to call a render node callback but a simple callback was registered for node {:?} ({:?})", node_id, debug_name);
263            }
264        } else {
265            //let debug_name = args.graph_context.prepared_render_graph.node_debug_name(node_id);
266            //log::error!("No callback found for node {:?} ({:?})", node_id, debug_name);
267        }
268
269        Ok(())
270    }
271
272    fn visit_callback_node(
273        &self,
274        node_id: RenderGraphNodeId,
275        args: VisitComputeNodeArgs,
276    ) -> RafxResult<()> {
277        if let Some(callback) = self.graph_plan.visit_node_callbacks.get(&node_id) {
278            if let RenderGraphNodeVisitNodeCallback::Callback(callback) = callback {
279                (callback)(args)?
280            } else {
281                let debug_name = args
282                    .graph_context
283                    .prepared_render_graph
284                    .node_debug_name(node_id);
285                log::error!("Tried to call a simple callback node callback but a render node callback was registered for node {:?} ({:?})", node_id, debug_name);
286            }
287        } else {
288            //let debug_name = args.graph_context.prepared_render_graph.node_debug_name(node_id);
289            //log::error!("No callback found for node {:?} {:?}", node_id, debug_name);
290        }
291
292        Ok(())
293    }
294
295    pub fn execute_graph<'write>(
296        &'write self,
297        write_context: &RenderJobWriteContext,
298        prepared_render_data: PreparedRenderData<'write>,
299        queue: &RafxQueue,
300    ) -> RafxResult<Vec<DynCommandBuffer>> {
301        profiling::scope!("Execute Graph");
302        //
303        // Start a command writer. For now just do a single primary writer, later we can multithread this.
304        //
305        let mut command_writer = self
306            .resource_context
307            .create_dyn_command_pool_allocator()
308            .allocate_dyn_pool(queue, &RafxCommandPoolDef { transient: true }, 0)?;
309
310        let command_buffer = command_writer.allocate_dyn_command_buffer(&RafxCommandBufferDef {
311            is_secondary: false,
312        })?;
313
314        command_buffer.begin()?;
315
316        let render_graph_context = RenderGraphContext {
317            prepared_render_graph: &self,
318            prepared_render_data: &prepared_render_data,
319            render_resources: write_context.render_resources,
320        };
321
322        render_graph_context
323            .prepared_render_data()
324            .on_begin_execute_graph(
325                &mut RenderJobBeginExecuteGraphContext::from_on_begin_execute_graph_args(
326                    &OnBeginExecuteGraphArgs {
327                        graph_context: render_graph_context.clone(),
328                        command_buffer: command_buffer.clone(),
329                    },
330                ),
331            )?;
332
333        //
334        // Iterate through all passes
335        //
336        for (pass_index, pass) in self.graph_plan.passes.iter().enumerate() {
337            //TODO output pass is?
338            //TODO: add_compute_node/add_render_node?
339
340            profiling::scope!("pass", pass.debug_name().unwrap_or("unnamed"));
341            log::trace!("Execute pass name: {:?}", pass.debug_name());
342
343            if let Some(name) = pass.debug_name() {
344                command_buffer.cmd_push_group_debug_name(name);
345            }
346
347            let node_id = pass.node();
348
349            if let Some(pre_pass_barrier) = pass.pre_pass_barrier() {
350                log::trace!(
351                    "prepass barriers for pass {} {:?}",
352                    pass_index,
353                    pass.debug_name()
354                );
355                self.insert_barriers(
356                    &command_buffer,
357                    &pre_pass_barrier.buffer_barriers,
358                    &pre_pass_barrier.image_barriers,
359                )?;
360            }
361
362            match pass {
363                RenderGraphOutputPass::Render(pass) => {
364                    let color_images: Vec<_> = pass
365                        .color_render_targets
366                        .iter()
367                        .map(|x| self.image_resources[&x.image].get_raw().image.clone())
368                        .collect();
369
370                    let resolve_images: Vec<_> = pass
371                        .color_render_targets
372                        .iter()
373                        .map(|x| {
374                            //x.map(|x| self.image_resources[&x.image].get_raw().image.clone())
375                            x.resolve_image
376                                .map(|x| self.image_resources[&x].get_raw().image.clone())
377                        })
378                        .collect();
379
380                    let color_target_bindings: Vec<_> = pass
381                        .color_render_targets
382                        .iter()
383                        .enumerate()
384                        .map(
385                            |(color_image_index, color_image)| RafxColorRenderTargetBinding {
386                                texture: &color_images[color_image_index],
387                                clear_value: color_image.clear_value.clone(),
388                                load_op: color_image.load_op,
389                                store_op: color_image.store_op,
390                                array_slice: color_image.array_slice,
391                                mip_slice: color_image.mip_slice,
392                                resolve_target: resolve_images[color_image_index].as_ref(),
393                                resolve_store_op: color_image.resolve_store_op.into(),
394                                resolve_array_slice: color_image.resolve_array_slice,
395                                resolve_mip_slice: color_image.resolve_mip_slice,
396                            },
397                        )
398                        .collect();
399
400                    let mut depth_stencil_image = None;
401                    let depth_target_binding = pass.depth_stencil_render_target.as_ref().map(|x| {
402                        depth_stencil_image =
403                            Some(self.image_resources[&x.image].get_raw().image.clone());
404                        RafxDepthStencilRenderTargetBinding {
405                            texture: depth_stencil_image.as_ref().unwrap(),
406                            clear_value: x.clear_value.clone(),
407                            depth_load_op: x.depth_load_op,
408                            stencil_load_op: x.stencil_load_op,
409                            depth_store_op: x.depth_store_op,
410                            stencil_store_op: x.stencil_store_op,
411                            array_slice: x.array_slice,
412                            mip_slice: x.mip_slice,
413                        }
414                    });
415
416                    //println!("color bindings:\n{:#?}", color_target_bindings);
417                    //println!("depth binding:\n{:#?}", depth_target_binding);
418
419                    command_buffer
420                        .cmd_begin_render_pass(&color_target_bindings, depth_target_binding)?;
421
422                    let args = VisitRenderpassNodeArgs {
423                        render_target_meta: pass.render_target_meta.clone(),
424                        graph_context: render_graph_context,
425                        command_buffer: command_buffer.clone(),
426                    };
427
428                    self.visit_render_node(node_id, args)?;
429
430                    command_buffer.cmd_end_render_pass()?;
431                }
432                RenderGraphOutputPass::Callback(_pass) => {
433                    let args = VisitComputeNodeArgs {
434                        graph_context: render_graph_context,
435                        command_buffer: command_buffer.clone(),
436                    };
437
438                    self.visit_callback_node(node_id, args)?;
439                }
440            }
441
442            if pass.debug_name().is_some() {
443                command_buffer.cmd_pop_group_debug_name();
444            }
445        }
446
447        command_buffer.end()?;
448
449        Ok(vec![command_buffer])
450    }
451}