Skip to main content

gpu_readback/
gpu_readback.rs

//! Simple example demonstrating the use of the [`Readback`] component to read back data from the GPU
//! using both a storage buffer and a storage texture.
3
4use bevy::{
5    asset::RenderAssetUsages,
6    prelude::*,
7    render::{
8        extract_resource::{ExtractResource, ExtractResourcePlugin},
9        gpu_readback::{Readback, ReadbackComplete},
10        render_asset::RenderAssets,
11        render_resource::{
12            binding_types::{storage_buffer, texture_storage_2d},
13            *,
14        },
15        renderer::{RenderContext, RenderDevice, RenderGraph},
16        storage::{GpuShaderBuffer, ShaderBuffer},
17        texture::GpuImage,
18        Render, RenderApp, RenderStartup, RenderSystems,
19    },
20};
21
/// Path, relative to the `assets` directory, of the WGSL compute shader used by this example.
const SHADER_ASSET_PATH: &str = "shaders/gpu_readback.wgsl";

/// Number of `u32` elements in the storage buffer sent to the GPU. It is also the width, in
/// texels, of the single-row storage texture and the number of workgroups dispatched along `x`.
const BUFFER_LEN: usize = 16;
27
28fn main() {
29    App::new()
30        .add_plugins((
31            DefaultPlugins,
32            GpuReadbackPlugin,
33            ExtractResourcePlugin::<ReadbackBuffer>::default(),
34            ExtractResourcePlugin::<ReadbackImage>::default(),
35        ))
36        .insert_resource(ClearColor(Color::BLACK))
37        .add_systems(Startup, setup)
38        .run();
39}
40
41// We need a plugin to organize all the systems and render node required for this example
42struct GpuReadbackPlugin;
43impl Plugin for GpuReadbackPlugin {
44    fn build(&self, app: &mut App) {
45        let Some(render_app) = app.get_sub_app_mut(RenderApp) else {
46            return;
47        };
48        render_app
49            .add_systems(RenderStartup, init_compute_pipeline)
50            .add_systems(
51                Render,
52                prepare_bind_group
53                    .in_set(RenderSystems::PrepareBindGroups)
54                    // We don't need to recreate the bind group every frame
55                    .run_if(not(resource_exists::<GpuBufferBindGroup>)),
56            )
57            .add_systems(RenderGraph, compute);
58    }
59}
60
61#[derive(Resource, ExtractResource, Clone)]
62struct ReadbackBuffer(Handle<ShaderBuffer>);
63
64#[derive(Resource, ExtractResource, Clone)]
65struct ReadbackImage(Handle<Image>);
66
67fn setup(
68    mut commands: Commands,
69    mut images: ResMut<Assets<Image>>,
70    mut buffers: ResMut<Assets<ShaderBuffer>>,
71) {
72    // Create a storage buffer with some data
73    let buffer: Vec<u32> = (0..BUFFER_LEN as u32).collect();
74    let mut buffer = ShaderBuffer::from(buffer);
75    // We need to enable the COPY_SRC usage so we can copy the buffer to the cpu
76    buffer.buffer_description.usage |= BufferUsages::COPY_SRC;
77    let buffer = buffers.add(buffer);
78
79    // Create a storage texture with some data
80    let size = Extent3d {
81        width: BUFFER_LEN as u32,
82        height: 1,
83        ..default()
84    };
85    // We create an uninitialized image since this texture will only be used for getting data out
86    // of the compute shader, not getting data in, so there's no reason for it to exist on the CPU
87    let mut image = Image::new_uninit(
88        size,
89        TextureDimension::D2,
90        TextureFormat::R32Uint,
91        RenderAssetUsages::RENDER_WORLD,
92    );
93    // We also need to enable the COPY_SRC, as well as STORAGE_BINDING so we can use it in the
94    // compute shader
95    image.texture_descriptor.usage |= TextureUsages::COPY_SRC | TextureUsages::STORAGE_BINDING;
96    let image = images.add(image);
97
98    // Spawn the readback components. For each frame, the data will be read back from the GPU
99    // asynchronously and trigger the `ReadbackComplete` event on this entity. Despawn the entity
100    // to stop reading back the data.
101    commands
102        .spawn(Readback::buffer(buffer.clone()))
103        .observe(|event: On<ReadbackComplete>| {
104            // This matches the type which was used to create the `ShaderBuffer` above,
105            // and is a convenient way to interpret the data.
106            let data: Vec<u32> = event.to_shader_type();
107            info!("Buffer {:?}", data);
108        });
109
110    // It is also possible to read only a range of the buffer.
111    commands
112        .spawn(Readback::buffer_range(
113            buffer.clone(),
114            4 * u32::SHADER_SIZE.get(), // skip the first four elements
115            8 * u32::SHADER_SIZE.get(), // read eight elements
116        ))
117        .observe(|event: On<ReadbackComplete>| {
118            let data: Vec<u32> = event.to_shader_type();
119            info!("Buffer range {:?}", data);
120        });
121
122    // This is just a simple way to pass the buffer handle to the render app for our compute node
123    commands.insert_resource(ReadbackBuffer(buffer));
124
125    // Textures can also be read back from the GPU. Pay careful attention to the format of the
126    // texture, as it will affect how the data is interpreted.
127    commands
128        .spawn(Readback::texture(image.clone()))
129        .observe(|event: On<ReadbackComplete>| {
130            // You probably want to interpret the data as a color rather than a `ShaderType`,
131            // but in this case we know the data is a single channel storage texture, so we can
132            // interpret it as a `Vec<u32>`
133            let data: Vec<u32> = event.to_shader_type();
134            info!("Image {:?}", data);
135        });
136    commands.insert_resource(ReadbackImage(image));
137}
138
139#[derive(Resource)]
140struct GpuBufferBindGroup(BindGroup);
141
142fn prepare_bind_group(
143    mut commands: Commands,
144    pipeline: Res<ComputePipeline>,
145    render_device: Res<RenderDevice>,
146    pipeline_cache: Res<PipelineCache>,
147    buffer: Res<ReadbackBuffer>,
148    image: Res<ReadbackImage>,
149    buffers: Res<RenderAssets<GpuShaderBuffer>>,
150    images: Res<RenderAssets<GpuImage>>,
151) {
152    let buffer = buffers.get(&buffer.0).unwrap();
153    let image = images.get(&image.0).unwrap();
154    let bind_group = render_device.create_bind_group(
155        None,
156        &pipeline_cache.get_bind_group_layout(&pipeline.layout),
157        &BindGroupEntries::sequential((
158            buffer.buffer.as_entire_buffer_binding(),
159            image.texture_view.into_binding(),
160        )),
161    );
162    commands.insert_resource(GpuBufferBindGroup(bind_group));
163}
164
165#[derive(Resource)]
166struct ComputePipeline {
167    layout: BindGroupLayoutDescriptor,
168    pipeline: CachedComputePipelineId,
169}
170
171fn init_compute_pipeline(
172    mut commands: Commands,
173    asset_server: Res<AssetServer>,
174    pipeline_cache: Res<PipelineCache>,
175) {
176    let layout = BindGroupLayoutDescriptor::new(
177        "",
178        &BindGroupLayoutEntries::sequential(
179            ShaderStages::COMPUTE,
180            (
181                storage_buffer::<Vec<u32>>(false),
182                texture_storage_2d(TextureFormat::R32Uint, StorageTextureAccess::WriteOnly),
183            ),
184        ),
185    );
186    let shader = asset_server.load(SHADER_ASSET_PATH);
187    let pipeline = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
188        label: Some("GPU readback compute shader".into()),
189        layout: vec![layout.clone()],
190        shader: shader.clone(),
191        ..default()
192    });
193    commands.insert_resource(ComputePipeline { layout, pipeline });
194}
195
196fn compute(
197    mut render_context: RenderContext,
198    pipeline_cache: Res<PipelineCache>,
199    pipeline: Res<ComputePipeline>,
200    bind_group: Res<GpuBufferBindGroup>,
201) {
202    if let Some(init_pipeline) = pipeline_cache.get_compute_pipeline(pipeline.pipeline) {
203        let mut pass =
204            render_context
205                .command_encoder()
206                .begin_compute_pass(&ComputePassDescriptor {
207                    label: Some("GPU readback compute pass"),
208                    ..default()
209                });
210
211        pass.set_bind_group(0, &bind_group.0, &[]);
212        pass.set_pipeline(init_pipeline);
213        pass.dispatch_workgroups(BUFFER_LEN as u32, 1, 1);
214    }
215}