bevy_compute_readback/
lib.rs

1//! Library to simplify compute shader readbacks.
2use std::{
3    fmt::Debug,
4    hash::{Hash, Hasher},
5    marker::PhantomData,
6};
7
8use bevy::app::{App, Plugin, Startup};
9use bevy::asset::DirectAssetAccessExt;
10use bevy::ecs::{
11    component::Component,
12    entity::Entity,
13    observer::Trigger,
14    query::With,
15    resource::Resource,
16    schedule::{
17        Condition, IntoScheduleConfigs,
18        common_conditions::{not, resource_changed, resource_exists, resource_exists_and_changed},
19    },
20    system::{Commands, Query, Res, ResMut, StaticSystemParam},
21    world::{DeferredWorld, FromWorld, World},
22};
23use bevy::math::UVec3;
24use bevy::render::{
25    ExtractSchedule, MainWorld, Render, RenderApp, RenderSet,
26    extract_resource::{ExtractResource, ExtractResourcePlugin, extract_resource},
27    gpu_readback::{Readback, ReadbackComplete},
28    render_graph::{self, RenderGraph, RenderLabel},
29    render_resource::{
30        AsBindGroup, BindGroup, BindGroupLayout, CachedComputePipelineId, CachedPipelineState,
31        ComputePassDescriptor, ComputePipelineDescriptor, PipelineCache, ShaderRef,
32    },
33    renderer::{RenderContext, RenderDevice},
34};
35use bevy::state::{
36    app::AppExtStates,
37    state::{NextState, OnEnter, States},
38};
39
40/// Plugin to create all the required systems for using a custom compute shader.
41pub struct ComputeShaderPlugin<S: ComputeShader> {
42    pub limit: ReadbackLimit,
43    pub remove_on_complete: bool,
44    pub _marker: PhantomData<S>,
45}
46impl<S: ComputeShader> Default for ComputeShaderPlugin<S> {
47    fn default() -> Self {
48        Self {
49            limit: ReadbackLimit::default(),
50            remove_on_complete: false,
51            _marker: PhantomData,
52        }
53    }
54}
55impl<S: ComputeShader> Plugin for ComputeShaderPlugin<S> {
56    fn build(&self, app: &mut App) {
57        app.init_resource::<S>()
58            .add_plugins(ExtractResourcePlugin::<S>::default())
59            .init_state::<ComputeNodeState<S>>()
60            .add_systems(
61                OnEnter(ComputeNodeState::<S>::from(ComputeNodeStatus::Ready)),
62                ComputeShaderReadback::<S>::on_shader_ready,
63            )
64            .add_systems(
65                OnEnter(ComputeNodeState::<S>::from(ComputeNodeStatus::Completed)),
66                ComputeShaderReadback::<S>::on_shader_complete,
67            )
68            .add_systems(Startup, ComputeShaderReadback::<S>::spawn);
69    }
70
71    fn finish(&self, app: &mut App) {
72        // Add the compute shader resources and systems to the render app.
73        let render_app = app.sub_app_mut(RenderApp);
74        render_app
75            .init_resource::<ComputePipeline<S>>()
76            .init_resource::<ComputeNodeState<S>>()
77            .add_systems(
78                ExtractSchedule,
79                ComputeNode::<S>::reset_on_change
80                    .run_if(resource_exists_and_changed::<S>)
81                    .after(extract_resource::<S>),
82            )
83            .add_systems(
84                ExtractSchedule,
85                ComputeNodeState::<S>::extract_to_main
86                    .run_if(resource_changed::<ComputeNodeState<S>>),
87            )
88            .add_systems(
89                Render,
90                (S::prepare_bind_group)
91                    .chain()
92                    .in_set(RenderSet::PrepareBindGroups)
93                    .run_if(
94                        not(resource_exists::<ComputeShaderBindGroup<S>>).or(resource_changed::<S>),
95                    ),
96            );
97
98        // Add the compute node as a top level node to the render graph
99        // This means it will only execute once per frame
100        render_app
101            .world_mut()
102            .resource_mut::<RenderGraph>()
103            .add_node(
104                ComputeNodeLabel::<S>::default(),
105                ComputeNode::<S> {
106                    limit: self.limit,
107                    ..Default::default()
108                },
109            );
110
111        // If the compute node should be removed on completion, schedule the removal systems.
112        if self.remove_on_complete {
113            render_app.add_systems(
114                ExtractSchedule,
115                ComputeNodeLabel::<S>::remove_on_complete
116                    .run_if(resource_changed::<ComputeNodeState<S>>),
117            );
118        }
119    }
120}
121
/// How many readbacks should be sent per initialization of the shader.
///
/// Limits can be compared and hashed, so callers can test which limit is configured.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum ReadbackLimit {
    /// No limit, readback will continue indefinitely.
    #[default]
    Infinite,
    /// Finite readback limit, measured in number of frames.
    Finite(usize),
}
131
132/// Component that receives readback events from the compute shader.
133#[derive(Component)]
134pub struct ComputeShaderReadback<S: ComputeShader> {
135    pub _marker: PhantomData<S>,
136}
137impl<S: ComputeShader> Default for ComputeShaderReadback<S> {
138    fn default() -> Self {
139        Self {
140            _marker: PhantomData,
141        }
142    }
143}
144impl<S: ComputeShader> ComputeShaderReadback<S> {
145    /// Spawn the readback observer on startup.
146    fn spawn(mut commands: Commands) {
147        commands.spawn(Self::default()).observe(S::on_readback);
148    }
149    /// Insert GPU readback component only when the shader is ready.
150    fn on_shader_ready(
151        mut commands: Commands,
152        compute_shader: Res<S>,
153        mut compute_shader_readbacks: Query<Entity, With<Self>>,
154    ) {
155        for entity in compute_shader_readbacks.iter_mut() {
156            if let Some(readback) = compute_shader.readback() {
157                commands.entity(entity).insert(readback);
158            }
159        }
160    }
161    /// Disable the shader when it's done.
162    fn on_shader_complete(
163        mut commands: Commands,
164        mut compute_shader_readbacks: Query<Entity, With<Self>>,
165    ) {
166        for entity in compute_shader_readbacks.iter_mut() {
167            commands.entity(entity).remove::<Readback>();
168        }
169    }
170}
171
172/// Trait to implement for a custom compute shader.
173pub trait ComputeShader: AsBindGroup + Clone + Debug + FromWorld + ExtractResource {
174    /// Asset path or handle to the shader.
175    fn compute_shader() -> ShaderRef;
176    /// Workgroup size.
177    fn workgroup_size() -> UVec3;
178    /// Optional bind group preparation.
179    fn prepare_bind_group(
180        mut commands: Commands,
181        pipeline: Res<ComputePipeline<Self>>,
182        render_device: Res<RenderDevice>,
183        input: Res<Self>,
184        param: StaticSystemParam<<Self as AsBindGroup>::Param>,
185    ) {
186        let bind_group = input
187            .as_bind_group(&pipeline.layout, &render_device, &mut param.into_inner())
188            .unwrap();
189        commands.insert_resource(ComputeShaderBindGroup::<Self> {
190            bind_group: bind_group.bind_group,
191            _marker: PhantomData,
192        });
193    }
194    /// Optional readbacks.
195    fn readback(&self) -> Option<Readback> {
196        None
197    }
198    /// Optional processing on readback. Could write back to the CPU buffer, etc.
199    fn on_readback(_trigger: Trigger<ReadbackComplete>, mut _world: DeferredWorld) {}
200}
201
202/// Stores prepared bind group data for the compute shader.
203#[derive(Resource)]
204pub struct ComputeShaderBindGroup<S: ComputeShader> {
205    pub bind_group: BindGroup,
206    pub _marker: PhantomData<S>,
207}
208
/// Lifecycle stages a compute node can report.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ComputeNodeStatus {
    /// Initial status; also reported while the pipeline is queued or being created.
    #[default]
    Loading,
    /// NOTE(review): no code visible in this file assigns this status — confirm intended use.
    Init,
    /// The pipeline is usable and the readback limit has not been reached.
    Ready,
    /// A finite readback limit has been reached.
    Completed,
    /// The pipeline cache reported an error for the pipeline.
    Error,
}
219/// Tracks compute node state.
220/// In render world, this is stored as a resource which is later extracted to main.
221/// In main world, this is a state so systems can react to state entry.
222#[derive(States, Resource, Clone, Copy, Debug)]
223pub struct ComputeNodeState<S: ComputeShader> {
224    status: ComputeNodeStatus,
225    _marker: PhantomData<S>,
226}
227impl<S: ComputeShader> Hash for ComputeNodeState<S> {
228    fn hash<H: Hasher>(&self, state: &mut H) {
229        self.status.hash(state);
230    }
231}
232impl<S: ComputeShader> PartialEq for ComputeNodeState<S> {
233    fn eq(&self, other: &Self) -> bool {
234        self.status == other.status
235    }
236}
237impl<S: ComputeShader> Eq for ComputeNodeState<S> {}
238impl<S: ComputeShader> From<ComputeNodeStatus> for ComputeNodeState<S> {
239    fn from(value: ComputeNodeStatus) -> Self {
240        Self {
241            status: value,
242            _marker: PhantomData,
243        }
244    }
245}
246impl<S: ComputeShader> Default for ComputeNodeState<S> {
247    fn default() -> Self {
248        Self {
249            status: ComputeNodeStatus::default(),
250            _marker: PhantomData,
251        }
252    }
253}
254impl<S: ComputeShader> ComputeNodeState<S> {
255    /// Extracts compute node state resource into a state
256    /// that systems can react to in the main world.
257    fn extract_to_main(compute_state: Res<ComputeNodeState<S>>, mut world: ResMut<MainWorld>) {
258        world
259            .resource_mut::<NextState<ComputeNodeState<S>>>()
260            .set(compute_state.clone());
261    }
262}
263
264/// Defines the pipeline for the compute shader.
265#[derive(Resource)]
266pub struct ComputePipeline<S: ComputeShader> {
267    pub layout: BindGroupLayout,
268    pipeline: CachedComputePipelineId,
269    _marker: PhantomData<S>,
270}
271impl<S: ComputeShader> FromWorld for ComputePipeline<S> {
272    fn from_world(world: &mut World) -> Self {
273        let render_device = world.resource::<RenderDevice>();
274        let layout = S::bind_group_layout(render_device);
275        let shader = match S::compute_shader() {
276            ShaderRef::Default => panic!("Must define compute_shader."),
277            ShaderRef::Handle(handle) => handle,
278            ShaderRef::Path(path) => world.load_asset(path),
279        };
280        let pipeline_cache = world.resource::<PipelineCache>();
281        let pipeline = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
282            label: Some("GPU readback compute shader".into()),
283            layout: vec![layout.clone()],
284            push_constant_ranges: Vec::new(),
285            shader: shader.clone(),
286            shader_defs: Vec::new(),
287            entry_point: "main".into(),
288            zero_initialize_workgroup_memory: false,
289        });
290        Self {
291            layout,
292            pipeline,
293            _marker: PhantomData,
294        }
295    }
296}
297
298/// Label to identify the node in the render graph.
299#[derive(Debug, Clone, RenderLabel)]
300struct ComputeNodeLabel<S: ComputeShader> {
301    _marker: PhantomData<S>,
302}
303impl<S: ComputeShader> Default for ComputeNodeLabel<S> {
304    fn default() -> Self {
305        Self {
306            _marker: PhantomData,
307        }
308    }
309}
310impl<S: ComputeShader> PartialEq for ComputeNodeLabel<S> {
311    fn eq(&self, _other: &Self) -> bool {
312        true
313    }
314}
315impl<S: ComputeShader> Eq for ComputeNodeLabel<S> {}
316impl<S: ComputeShader> Hash for ComputeNodeLabel<S> {
317    fn hash<H: Hasher>(&self, _state: &mut H) {}
318}
319impl<S: ComputeShader> ComputeNodeLabel<S> {
320    fn remove_on_complete(mut render_graph: ResMut<RenderGraph>, state: Res<ComputeNodeState<S>>) {
321        if state.status == ComputeNodeStatus::Completed {
322            let _ = render_graph.remove_node(Self::default());
323        }
324    }
325}
326
327/// The node that will execute the compute shader.
328/// Updates `ComputeNodeState<S>` in the `RenderWorld`.
329struct ComputeNode<S: ComputeShader> {
330    status: ComputeNodeStatus,
331    limit: ReadbackLimit,
332    count: usize,
333    _marker: PhantomData<S>,
334}
335impl<S: ComputeShader> Default for ComputeNode<S> {
336    fn default() -> Self {
337        Self {
338            status: ComputeNodeStatus::default(),
339            limit: ReadbackLimit::Infinite,
340            count: 0,
341            _marker: PhantomData,
342        }
343    }
344}
345impl<S: ComputeShader> ComputeNode<S> {
346    /// When the input shader is changed, reset.
347    fn reset_on_change(
348        mut render_graph: ResMut<RenderGraph>,
349        mut state: ResMut<ComputeNodeState<S>>,
350    ) {
351        let Ok(node) = render_graph.get_node_mut::<Self>(ComputeNodeLabel::<S>::default()) else {
352            return;
353        };
354        node.count = 0;
355        node.status = ComputeNodeStatus::Loading;
356        *state = ComputeNodeState {
357            status: ComputeNodeStatus::Loading,
358            ..Default::default()
359        };
360    }
361}
362impl<S: ComputeShader> render_graph::Node for ComputeNode<S> {
363    fn update(&mut self, world: &mut World) {
364        let pipeline = world.resource::<ComputePipeline<S>>();
365        let pipeline_cache = world.resource::<PipelineCache>();
366
367        let next_status = match pipeline_cache.get_compute_pipeline_state(pipeline.pipeline) {
368            CachedPipelineState::Ok(_) => match (self.status, self.limit) {
369                (ComputeNodeStatus::Completed, _) => ComputeNodeStatus::Completed,
370                (_, ReadbackLimit::Finite(limit)) => {
371                    if self.count < limit {
372                        self.count += 1;
373                        ComputeNodeStatus::Ready
374                    } else {
375                        self.count = 0;
376                        ComputeNodeStatus::Completed
377                    }
378                }
379                _ => ComputeNodeStatus::Ready,
380            },
381            CachedPipelineState::Creating(_) => ComputeNodeStatus::Loading,
382            CachedPipelineState::Queued => ComputeNodeStatus::Loading,
383            CachedPipelineState::Err(_) => ComputeNodeStatus::Error,
384        };
385
386        if self.status != next_status {
387            self.status = next_status;
388            world.resource_mut::<ComputeNodeState<S>>().status = next_status;
389        }
390    }
391
392    fn run(
393        &self,
394        _graph: &mut render_graph::RenderGraphContext,
395        render_context: &mut RenderContext,
396        world: &World,
397    ) -> Result<(), render_graph::NodeRunError> {
398        let pipeline_cache = world.resource::<PipelineCache>();
399        let pipeline = world.resource::<ComputePipeline<S>>();
400        let bind_group = &world.resource::<ComputeShaderBindGroup<S>>().bind_group;
401        if self.status == ComputeNodeStatus::Ready {
402            if let Some(init_pipeline) = pipeline_cache.get_compute_pipeline(pipeline.pipeline) {
403                let workgroup_size = S::workgroup_size();
404                let mut pass =
405                    render_context
406                        .command_encoder()
407                        .begin_compute_pass(&ComputePassDescriptor {
408                            label: Some("GPU readback compute pass"),
409                            ..Default::default()
410                        });
411                pass.set_bind_group(0, bind_group, &[]);
412                pass.set_pipeline(init_pipeline);
413                pass.dispatch_workgroups(workgroup_size.x, workgroup_size.y, workgroup_size.z);
414            }
415        }
416        Ok(())
417    }
418}