gpu_readback/
gpu_readback.rs

//! Simple example demonstrating the use of the [`Readback`] component to read back data from the GPU
//! using both a storage buffer and a storage texture.
3
4use bevy::{
5    asset::RenderAssetUsages,
6    prelude::*,
7    render::{
8        extract_resource::{ExtractResource, ExtractResourcePlugin},
9        gpu_readback::{Readback, ReadbackComplete},
10        render_asset::RenderAssets,
11        render_graph::{self, RenderGraph, RenderLabel},
12        render_resource::{
13            binding_types::{storage_buffer, texture_storage_2d},
14            *,
15        },
16        renderer::{RenderContext, RenderDevice},
17        storage::{GpuShaderStorageBuffer, ShaderStorageBuffer},
18        texture::GpuImage,
19        Render, RenderApp, RenderStartup, RenderSystems,
20    },
21};
22
/// Path of the WGSL compute shader used by this example, relative to the `assets` directory.
const SHADER_ASSET_PATH: &str = "shaders/gpu_readback.wgsl";

/// Number of `u32` elements in the storage buffer sent to the GPU.
///
/// The same value is used as the width of the storage texture and as the number of workgroups
/// dispatched along the x axis.
const BUFFER_LEN: usize = 16;
28
29fn main() {
30    App::new()
31        .add_plugins((
32            DefaultPlugins,
33            GpuReadbackPlugin,
34            ExtractResourcePlugin::<ReadbackBuffer>::default(),
35            ExtractResourcePlugin::<ReadbackImage>::default(),
36        ))
37        .insert_resource(ClearColor(Color::BLACK))
38        .add_systems(Startup, setup)
39        .run();
40}
41
42// We need a plugin to organize all the systems and render node required for this example
43struct GpuReadbackPlugin;
44impl Plugin for GpuReadbackPlugin {
45    fn build(&self, app: &mut App) {
46        let Some(render_app) = app.get_sub_app_mut(RenderApp) else {
47            return;
48        };
49        render_app
50            .add_systems(
51                RenderStartup,
52                (init_compute_pipeline, add_compute_render_graph_node),
53            )
54            .add_systems(
55                Render,
56                prepare_bind_group
57                    .in_set(RenderSystems::PrepareBindGroups)
58                    // We don't need to recreate the bind group every frame
59                    .run_if(not(resource_exists::<GpuBufferBindGroup>)),
60            );
61    }
62}
63
64#[derive(Resource, ExtractResource, Clone)]
65struct ReadbackBuffer(Handle<ShaderStorageBuffer>);
66
67#[derive(Resource, ExtractResource, Clone)]
68struct ReadbackImage(Handle<Image>);
69
70fn setup(
71    mut commands: Commands,
72    mut images: ResMut<Assets<Image>>,
73    mut buffers: ResMut<Assets<ShaderStorageBuffer>>,
74) {
75    // Create a storage buffer with some data
76    let buffer: Vec<u32> = (0..BUFFER_LEN as u32).collect();
77    let mut buffer = ShaderStorageBuffer::from(buffer);
78    // We need to enable the COPY_SRC usage so we can copy the buffer to the cpu
79    buffer.buffer_description.usage |= BufferUsages::COPY_SRC;
80    let buffer = buffers.add(buffer);
81
82    // Create a storage texture with some data
83    let size = Extent3d {
84        width: BUFFER_LEN as u32,
85        height: 1,
86        ..default()
87    };
88    // We create an uninitialized image since this texture will only be used for getting data out
89    // of the compute shader, not getting data in, so there's no reason for it to exist on the CPU
90    let mut image = Image::new_uninit(
91        size,
92        TextureDimension::D2,
93        TextureFormat::R32Uint,
94        RenderAssetUsages::RENDER_WORLD,
95    );
96    // We also need to enable the COPY_SRC, as well as STORAGE_BINDING so we can use it in the
97    // compute shader
98    image.texture_descriptor.usage |= TextureUsages::COPY_SRC | TextureUsages::STORAGE_BINDING;
99    let image = images.add(image);
100
101    // Spawn the readback components. For each frame, the data will be read back from the GPU
102    // asynchronously and trigger the `ReadbackComplete` event on this entity. Despawn the entity
103    // to stop reading back the data.
104    commands
105        .spawn(Readback::buffer(buffer.clone()))
106        .observe(|event: On<ReadbackComplete>| {
107            // This matches the type which was used to create the `ShaderStorageBuffer` above,
108            // and is a convenient way to interpret the data.
109            let data: Vec<u32> = event.to_shader_type();
110            info!("Buffer {:?}", data);
111        });
112
113    // It is also possible to read only a range of the buffer.
114    commands
115        .spawn(Readback::buffer_range(
116            buffer.clone(),
117            4 * u32::SHADER_SIZE.get(), // skip the first four elements
118            8 * u32::SHADER_SIZE.get(), // read eight elements
119        ))
120        .observe(|event: On<ReadbackComplete>| {
121            let data: Vec<u32> = event.to_shader_type();
122            info!("Buffer range {:?}", data);
123        });
124
125    // This is just a simple way to pass the buffer handle to the render app for our compute node
126    commands.insert_resource(ReadbackBuffer(buffer));
127
128    // Textures can also be read back from the GPU. Pay careful attention to the format of the
129    // texture, as it will affect how the data is interpreted.
130    commands
131        .spawn(Readback::texture(image.clone()))
132        .observe(|event: On<ReadbackComplete>| {
133            // You probably want to interpret the data as a color rather than a `ShaderType`,
134            // but in this case we know the data is a single channel storage texture, so we can
135            // interpret it as a `Vec<u32>`
136            let data: Vec<u32> = event.to_shader_type();
137            info!("Image {:?}", data);
138        });
139    commands.insert_resource(ReadbackImage(image));
140}
141
142fn add_compute_render_graph_node(mut render_graph: ResMut<RenderGraph>) {
143    // Add the compute node as a top-level node to the render graph. This means it will only execute
144    // once per frame. Normally, adding a node would use the `RenderGraphApp::add_render_graph_node`
145    // method, but it does not allow adding as a top-level node.
146    render_graph.add_node(ComputeNodeLabel, ComputeNode::default());
147}
148
149#[derive(Resource)]
150struct GpuBufferBindGroup(BindGroup);
151
152fn prepare_bind_group(
153    mut commands: Commands,
154    pipeline: Res<ComputePipeline>,
155    render_device: Res<RenderDevice>,
156    buffer: Res<ReadbackBuffer>,
157    image: Res<ReadbackImage>,
158    buffers: Res<RenderAssets<GpuShaderStorageBuffer>>,
159    images: Res<RenderAssets<GpuImage>>,
160) {
161    let buffer = buffers.get(&buffer.0).unwrap();
162    let image = images.get(&image.0).unwrap();
163    let bind_group = render_device.create_bind_group(
164        None,
165        &pipeline.layout,
166        &BindGroupEntries::sequential((
167            buffer.buffer.as_entire_buffer_binding(),
168            image.texture_view.into_binding(),
169        )),
170    );
171    commands.insert_resource(GpuBufferBindGroup(bind_group));
172}
173
174#[derive(Resource)]
175struct ComputePipeline {
176    layout: BindGroupLayout,
177    pipeline: CachedComputePipelineId,
178}
179
180fn init_compute_pipeline(
181    mut commands: Commands,
182    render_device: Res<RenderDevice>,
183    asset_server: Res<AssetServer>,
184    pipeline_cache: Res<PipelineCache>,
185) {
186    let layout = render_device.create_bind_group_layout(
187        None,
188        &BindGroupLayoutEntries::sequential(
189            ShaderStages::COMPUTE,
190            (
191                storage_buffer::<Vec<u32>>(false),
192                texture_storage_2d(TextureFormat::R32Uint, StorageTextureAccess::WriteOnly),
193            ),
194        ),
195    );
196    let shader = asset_server.load(SHADER_ASSET_PATH);
197    let pipeline = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
198        label: Some("GPU readback compute shader".into()),
199        layout: vec![layout.clone()],
200        shader: shader.clone(),
201        ..default()
202    });
203    commands.insert_resource(ComputePipeline { layout, pipeline });
204}
205
206/// Label to identify the node in the render graph
207#[derive(Debug, Hash, PartialEq, Eq, Clone, RenderLabel)]
208struct ComputeNodeLabel;
209
210/// The node that will execute the compute shader
211#[derive(Default)]
212struct ComputeNode {}
213
214impl render_graph::Node for ComputeNode {
215    fn run(
216        &self,
217        _graph: &mut render_graph::RenderGraphContext,
218        render_context: &mut RenderContext,
219        world: &World,
220    ) -> Result<(), render_graph::NodeRunError> {
221        let pipeline_cache = world.resource::<PipelineCache>();
222        let pipeline = world.resource::<ComputePipeline>();
223        let bind_group = world.resource::<GpuBufferBindGroup>();
224
225        if let Some(init_pipeline) = pipeline_cache.get_compute_pipeline(pipeline.pipeline) {
226            let mut pass =
227                render_context
228                    .command_encoder()
229                    .begin_compute_pass(&ComputePassDescriptor {
230                        label: Some("GPU readback compute pass"),
231                        ..default()
232                    });
233
234            pass.set_bind_group(0, &bind_group.0, &[]);
235            pass.set_pipeline(init_pipeline);
236            pass.dispatch_workgroups(BUFFER_LEN as u32, 1, 1);
237        }
238        Ok(())
239    }
240}