1use std::{
3 fmt::Debug,
4 hash::{Hash, Hasher},
5 marker::PhantomData,
6};
7
8use bevy::app::{App, Plugin, Startup};
9use bevy::asset::DirectAssetAccessExt;
10use bevy::ecs::{
11 component::Component,
12 entity::Entity,
13 observer::Trigger,
14 query::With,
15 resource::Resource,
16 schedule::{
17 Condition, IntoScheduleConfigs,
18 common_conditions::{not, resource_changed, resource_exists, resource_exists_and_changed},
19 },
20 system::{Commands, Query, Res, ResMut, StaticSystemParam},
21 world::{DeferredWorld, FromWorld, World},
22};
23use bevy::math::UVec3;
24use bevy::render::{
25 ExtractSchedule, MainWorld, Render, RenderApp, RenderSet,
26 extract_resource::{ExtractResource, ExtractResourcePlugin, extract_resource},
27 gpu_readback::{Readback, ReadbackComplete},
28 render_graph::{self, RenderGraph, RenderLabel},
29 render_resource::{
30 AsBindGroup, BindGroup, BindGroupLayout, CachedComputePipelineId, CachedPipelineState,
31 ComputePassDescriptor, ComputePipelineDescriptor, PipelineCache, ShaderRef,
32 },
33 renderer::{RenderContext, RenderDevice},
34};
35use bevy::state::{
36 app::AppExtStates,
37 state::{NextState, OnEnter, States},
38};
39
40pub struct ComputeShaderPlugin<S: ComputeShader> {
42 pub limit: ReadbackLimit,
43 pub remove_on_complete: bool,
44 pub _marker: PhantomData<S>,
45}
46impl<S: ComputeShader> Default for ComputeShaderPlugin<S> {
47 fn default() -> Self {
48 Self {
49 limit: ReadbackLimit::default(),
50 remove_on_complete: false,
51 _marker: PhantomData,
52 }
53 }
54}
55impl<S: ComputeShader> Plugin for ComputeShaderPlugin<S> {
56 fn build(&self, app: &mut App) {
57 app.init_resource::<S>()
58 .add_plugins(ExtractResourcePlugin::<S>::default())
59 .init_state::<ComputeNodeState<S>>()
60 .add_systems(
61 OnEnter(ComputeNodeState::<S>::from(ComputeNodeStatus::Ready)),
62 ComputeShaderReadback::<S>::on_shader_ready,
63 )
64 .add_systems(
65 OnEnter(ComputeNodeState::<S>::from(ComputeNodeStatus::Completed)),
66 ComputeShaderReadback::<S>::on_shader_complete,
67 )
68 .add_systems(Startup, ComputeShaderReadback::<S>::spawn);
69 }
70
71 fn finish(&self, app: &mut App) {
72 let render_app = app.sub_app_mut(RenderApp);
74 render_app
75 .init_resource::<ComputePipeline<S>>()
76 .init_resource::<ComputeNodeState<S>>()
77 .add_systems(
78 ExtractSchedule,
79 ComputeNode::<S>::reset_on_change
80 .run_if(resource_exists_and_changed::<S>)
81 .after(extract_resource::<S>),
82 )
83 .add_systems(
84 ExtractSchedule,
85 ComputeNodeState::<S>::extract_to_main
86 .run_if(resource_changed::<ComputeNodeState<S>>),
87 )
88 .add_systems(
89 Render,
90 (S::prepare_bind_group)
91 .chain()
92 .in_set(RenderSet::PrepareBindGroups)
93 .run_if(
94 not(resource_exists::<ComputeShaderBindGroup<S>>).or(resource_changed::<S>),
95 ),
96 );
97
98 render_app
101 .world_mut()
102 .resource_mut::<RenderGraph>()
103 .add_node(
104 ComputeNodeLabel::<S>::default(),
105 ComputeNode::<S> {
106 limit: self.limit,
107 ..Default::default()
108 },
109 );
110
111 if self.remove_on_complete {
113 render_app.add_systems(
114 ExtractSchedule,
115 ComputeNodeLabel::<S>::remove_on_complete
116 .run_if(resource_changed::<ComputeNodeState<S>>),
117 );
118 }
119 }
120}
121
/// How many updates the compute node stays ready to dispatch before it
/// reports completion.
#[derive(Default, Debug, Copy, Clone)]
pub enum ReadbackLimit {
    /// No limit: the node keeps reporting ready while its pipeline is usable.
    #[default]
    Infinite,
    /// The node reports ready for this many updates, then reports completed.
    Finite(usize),
}
131
132#[derive(Component)]
134pub struct ComputeShaderReadback<S: ComputeShader> {
135 pub _marker: PhantomData<S>,
136}
137impl<S: ComputeShader> Default for ComputeShaderReadback<S> {
138 fn default() -> Self {
139 Self {
140 _marker: PhantomData,
141 }
142 }
143}
144impl<S: ComputeShader> ComputeShaderReadback<S> {
145 fn spawn(mut commands: Commands) {
147 commands.spawn(Self::default()).observe(S::on_readback);
148 }
149 fn on_shader_ready(
151 mut commands: Commands,
152 compute_shader: Res<S>,
153 mut compute_shader_readbacks: Query<Entity, With<Self>>,
154 ) {
155 for entity in compute_shader_readbacks.iter_mut() {
156 if let Some(readback) = compute_shader.readback() {
157 commands.entity(entity).insert(readback);
158 }
159 }
160 }
161 fn on_shader_complete(
163 mut commands: Commands,
164 mut compute_shader_readbacks: Query<Entity, With<Self>>,
165 ) {
166 for entity in compute_shader_readbacks.iter_mut() {
167 commands.entity(entity).remove::<Readback>();
168 }
169 }
170}
171
172pub trait ComputeShader: AsBindGroup + Clone + Debug + FromWorld + ExtractResource {
174 fn compute_shader() -> ShaderRef;
176 fn workgroup_size() -> UVec3;
178 fn prepare_bind_group(
180 mut commands: Commands,
181 pipeline: Res<ComputePipeline<Self>>,
182 render_device: Res<RenderDevice>,
183 input: Res<Self>,
184 param: StaticSystemParam<<Self as AsBindGroup>::Param>,
185 ) {
186 let bind_group = input
187 .as_bind_group(&pipeline.layout, &render_device, &mut param.into_inner())
188 .unwrap();
189 commands.insert_resource(ComputeShaderBindGroup::<Self> {
190 bind_group: bind_group.bind_group,
191 _marker: PhantomData,
192 });
193 }
194 fn readback(&self) -> Option<Readback> {
196 None
197 }
198 fn on_readback(_trigger: Trigger<ReadbackComplete>, mut _world: DeferredWorld) {}
200}
201
202#[derive(Resource)]
204pub struct ComputeShaderBindGroup<S: ComputeShader> {
205 pub bind_group: BindGroup,
206 pub _marker: PhantomData<S>,
207}
208
/// Lifecycle status of a compute node.
#[derive(Default, Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub enum ComputeNodeStatus {
    /// The compute pipeline is still queued or being created (initial value).
    #[default]
    Loading,
    /// Not assigned by any code visible in this file.
    Init,
    /// The pipeline is usable and the node dispatches the shader.
    Ready,
    /// A finite dispatch limit has been used up.
    Completed,
    /// The pipeline cache reported an error for the pipeline.
    Error,
}
219#[derive(States, Resource, Clone, Copy, Debug)]
223pub struct ComputeNodeState<S: ComputeShader> {
224 status: ComputeNodeStatus,
225 _marker: PhantomData<S>,
226}
227impl<S: ComputeShader> Hash for ComputeNodeState<S> {
228 fn hash<H: Hasher>(&self, state: &mut H) {
229 self.status.hash(state);
230 }
231}
232impl<S: ComputeShader> PartialEq for ComputeNodeState<S> {
233 fn eq(&self, other: &Self) -> bool {
234 self.status == other.status
235 }
236}
237impl<S: ComputeShader> Eq for ComputeNodeState<S> {}
238impl<S: ComputeShader> From<ComputeNodeStatus> for ComputeNodeState<S> {
239 fn from(value: ComputeNodeStatus) -> Self {
240 Self {
241 status: value,
242 _marker: PhantomData,
243 }
244 }
245}
246impl<S: ComputeShader> Default for ComputeNodeState<S> {
247 fn default() -> Self {
248 Self {
249 status: ComputeNodeStatus::default(),
250 _marker: PhantomData,
251 }
252 }
253}
254impl<S: ComputeShader> ComputeNodeState<S> {
255 fn extract_to_main(compute_state: Res<ComputeNodeState<S>>, mut world: ResMut<MainWorld>) {
258 world
259 .resource_mut::<NextState<ComputeNodeState<S>>>()
260 .set(compute_state.clone());
261 }
262}
263
264#[derive(Resource)]
266pub struct ComputePipeline<S: ComputeShader> {
267 pub layout: BindGroupLayout,
268 pipeline: CachedComputePipelineId,
269 _marker: PhantomData<S>,
270}
271impl<S: ComputeShader> FromWorld for ComputePipeline<S> {
272 fn from_world(world: &mut World) -> Self {
273 let render_device = world.resource::<RenderDevice>();
274 let layout = S::bind_group_layout(render_device);
275 let shader = match S::compute_shader() {
276 ShaderRef::Default => panic!("Must define compute_shader."),
277 ShaderRef::Handle(handle) => handle,
278 ShaderRef::Path(path) => world.load_asset(path),
279 };
280 let pipeline_cache = world.resource::<PipelineCache>();
281 let pipeline = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
282 label: Some("GPU readback compute shader".into()),
283 layout: vec![layout.clone()],
284 push_constant_ranges: Vec::new(),
285 shader: shader.clone(),
286 shader_defs: Vec::new(),
287 entry_point: "main".into(),
288 zero_initialize_workgroup_memory: false,
289 });
290 Self {
291 layout,
292 pipeline,
293 _marker: PhantomData,
294 }
295 }
296}
297
298#[derive(Debug, Clone, RenderLabel)]
300struct ComputeNodeLabel<S: ComputeShader> {
301 _marker: PhantomData<S>,
302}
303impl<S: ComputeShader> Default for ComputeNodeLabel<S> {
304 fn default() -> Self {
305 Self {
306 _marker: PhantomData,
307 }
308 }
309}
310impl<S: ComputeShader> PartialEq for ComputeNodeLabel<S> {
311 fn eq(&self, _other: &Self) -> bool {
312 true
313 }
314}
315impl<S: ComputeShader> Eq for ComputeNodeLabel<S> {}
316impl<S: ComputeShader> Hash for ComputeNodeLabel<S> {
317 fn hash<H: Hasher>(&self, _state: &mut H) {}
318}
319impl<S: ComputeShader> ComputeNodeLabel<S> {
320 fn remove_on_complete(mut render_graph: ResMut<RenderGraph>, state: Res<ComputeNodeState<S>>) {
321 if state.status == ComputeNodeStatus::Completed {
322 let _ = render_graph.remove_node(Self::default());
323 }
324 }
325}
326
327struct ComputeNode<S: ComputeShader> {
330 status: ComputeNodeStatus,
331 limit: ReadbackLimit,
332 count: usize,
333 _marker: PhantomData<S>,
334}
335impl<S: ComputeShader> Default for ComputeNode<S> {
336 fn default() -> Self {
337 Self {
338 status: ComputeNodeStatus::default(),
339 limit: ReadbackLimit::Infinite,
340 count: 0,
341 _marker: PhantomData,
342 }
343 }
344}
345impl<S: ComputeShader> ComputeNode<S> {
346 fn reset_on_change(
348 mut render_graph: ResMut<RenderGraph>,
349 mut state: ResMut<ComputeNodeState<S>>,
350 ) {
351 let Ok(node) = render_graph.get_node_mut::<Self>(ComputeNodeLabel::<S>::default()) else {
352 return;
353 };
354 node.count = 0;
355 node.status = ComputeNodeStatus::Loading;
356 *state = ComputeNodeState {
357 status: ComputeNodeStatus::Loading,
358 ..Default::default()
359 };
360 }
361}
362impl<S: ComputeShader> render_graph::Node for ComputeNode<S> {
363 fn update(&mut self, world: &mut World) {
364 let pipeline = world.resource::<ComputePipeline<S>>();
365 let pipeline_cache = world.resource::<PipelineCache>();
366
367 let next_status = match pipeline_cache.get_compute_pipeline_state(pipeline.pipeline) {
368 CachedPipelineState::Ok(_) => match (self.status, self.limit) {
369 (ComputeNodeStatus::Completed, _) => ComputeNodeStatus::Completed,
370 (_, ReadbackLimit::Finite(limit)) => {
371 if self.count < limit {
372 self.count += 1;
373 ComputeNodeStatus::Ready
374 } else {
375 self.count = 0;
376 ComputeNodeStatus::Completed
377 }
378 }
379 _ => ComputeNodeStatus::Ready,
380 },
381 CachedPipelineState::Creating(_) => ComputeNodeStatus::Loading,
382 CachedPipelineState::Queued => ComputeNodeStatus::Loading,
383 CachedPipelineState::Err(_) => ComputeNodeStatus::Error,
384 };
385
386 if self.status != next_status {
387 self.status = next_status;
388 world.resource_mut::<ComputeNodeState<S>>().status = next_status;
389 }
390 }
391
392 fn run(
393 &self,
394 _graph: &mut render_graph::RenderGraphContext,
395 render_context: &mut RenderContext,
396 world: &World,
397 ) -> Result<(), render_graph::NodeRunError> {
398 let pipeline_cache = world.resource::<PipelineCache>();
399 let pipeline = world.resource::<ComputePipeline<S>>();
400 let bind_group = &world.resource::<ComputeShaderBindGroup<S>>().bind_group;
401 if self.status == ComputeNodeStatus::Ready {
402 if let Some(init_pipeline) = pipeline_cache.get_compute_pipeline(pipeline.pipeline) {
403 let workgroup_size = S::workgroup_size();
404 let mut pass =
405 render_context
406 .command_encoder()
407 .begin_compute_pass(&ComputePassDescriptor {
408 label: Some("GPU readback compute pass"),
409 ..Default::default()
410 });
411 pass.set_bind_group(0, bind_group, &[]);
412 pass.set_pipeline(init_pipeline);
413 pass.dispatch_workgroups(workgroup_size.x, workgroup_size.y, workgroup_size.z);
414 }
415 }
416 Ok(())
417 }
418}