use super::{WgpuRenderContext, WgpuRenderResourceContext};
use bevy_ecs::world::World;
use bevy_render::{
render_graph::{Edge, NodeId, ResourceSlots, StageBorrow},
renderer::RenderResourceContext,
};
use bevy_utils::HashMap;
use parking_lot::RwLock;
use std::sync::Arc;
/// Executes a staged render graph against a wgpu device and queue.
///
/// See [`WgpuRenderGraphExecutor::execute`] for how stages, jobs and nodes are processed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct WgpuRenderGraphExecutor {
    /// Upper bound on the number of contiguous chunks each stage's jobs are split into.
    /// Every chunk is recorded into its own render context and yields at most one command
    /// buffer.
    ///
    /// NOTE(review): despite the name, the `execute` implementation in this file records the
    /// chunks one after another on the calling thread rather than spawning threads.
    pub max_thread_count: usize,
}
impl WgpuRenderGraphExecutor {
    /// Runs every node in `stages`, stage by stage, and submits the recorded command buffers
    /// to `queue`.
    ///
    /// Within a stage, the jobs are split into at most `max_thread_count` contiguous chunks
    /// (a `max_thread_count` of 0 is treated as 1). Each chunk is recorded into its own
    /// [`WgpuRenderContext`] on the calling thread. Once the whole stage has been recorded,
    /// the resulting command buffers are submitted together, in chunk order. A stage with no
    /// jobs submits an empty batch.
    ///
    /// Before a node is updated, each of its input slots is filled from the outputs that the
    /// connected upstream node recorded earlier in this call.
    ///
    /// # Panics
    ///
    /// Panics if:
    /// - `world` has no `Box<dyn RenderResourceContext>` resource;
    /// - that resource is not a [`WgpuRenderResourceContext`];
    /// - an input slot has no edge, or its edge is not a slot edge;
    /// - the upstream node's outputs have not been recorded yet, or the referenced output
    ///   cannot be read.
    pub fn execute(
        &self,
        world: &World,
        device: Arc<wgpu::Device>,
        queue: &mut wgpu::Queue,
        stages: &mut [StageBorrow],
    ) {
        let render_resource_context = world
            .get_resource::<Box<dyn RenderResourceContext>>()
            .expect("a RenderResourceContext resource should be present")
            .downcast_ref::<WgpuRenderResourceContext>()
            .expect("the RenderResourceContext should be a WgpuRenderResourceContext")
            .clone();

        // Outputs recorded by every node run so far in this call, keyed by node id, so that
        // nodes in later jobs/stages can resolve their input slots.
        let node_outputs: RwLock<HashMap<NodeId, ResourceSlots>> = Default::default();

        for stage in stages.iter_mut() {
            let chunk_size = Self::jobs_per_chunk(stage.jobs.len(), self.max_thread_count);
            let mut command_buffers = Vec::new();

            for jobs_chunk in stage.jobs.chunks_mut(chunk_size) {
                let mut render_context =
                    WgpuRenderContext::new(device.clone(), render_resource_context.clone());

                for job in jobs_chunk.iter_mut() {
                    for node_state in job.node_states.iter_mut() {
                        for (i, input_slot) in node_state.input_slots.iter_mut().enumerate() {
                            let edge = node_state
                                .edges
                                .get_input_slot_edge(i)
                                .expect("every input slot should have an edge");
                            if let Edge::SlotEdge {
                                output_node,
                                output_index,
                                ..
                            } = edge
                            {
                                // The read guard is dropped at the end of this block, before
                                // the write lock below is taken for the same map.
                                let recorded = node_outputs.read();
                                let outputs =
                                    recorded.get(output_node).expect("Node inputs not set.");
                                let output_resource =
                                    outputs.get(*output_index).expect("Output should be set.");
                                input_slot.resource = Some(output_resource);
                            } else {
                                panic!("No edge connected to input.")
                            }
                        }

                        node_state.node.update(
                            world,
                            &mut render_context,
                            &node_state.input_slots,
                            &mut node_state.output_slots,
                        );
                        node_outputs
                            .write()
                            .insert(node_state.id, node_state.output_slots.clone());
                    }
                }

                // `finish` yields an `Option`; extending with it keeps only produced buffers.
                command_buffers.extend(render_context.finish());
            }

            queue.submit(command_buffers);
        }
    }

    /// Returns the number of jobs per chunk needed to split `job_count` jobs into at most
    /// `max_chunks` chunks. A `max_chunks` of 0 is treated as 1.
    ///
    /// The result is always at least 1, because `slice::chunks_mut` panics on a chunk size
    /// of 0, which an empty stage would otherwise produce.
    fn jobs_per_chunk(job_count: usize, max_chunks: usize) -> usize {
        let max_chunks = max_chunks.max(1);
        // Ceiling division written so it cannot overflow `job_count + max_chunks - 1`.
        let per_chunk = job_count / max_chunks + usize::from(job_count % max_chunks != 0);
        per_chunk.max(1)
    }
}