gpu_readback/
gpu_readback.rs

//! Simple example demonstrating the use of the [`Readback`] component to read back data from the GPU
//! using both a storage buffer and a texture.
3
4use bevy::{
5    prelude::*,
6    render::{
7        extract_resource::{ExtractResource, ExtractResourcePlugin},
8        gpu_readback::{Readback, ReadbackComplete},
9        render_asset::{RenderAssetUsages, RenderAssets},
10        render_graph::{self, RenderGraph, RenderLabel},
11        render_resource::{
12            binding_types::{storage_buffer, texture_storage_2d},
13            *,
14        },
15        renderer::{RenderContext, RenderDevice},
16        storage::{GpuShaderStorageBuffer, ShaderStorageBuffer},
17        texture::GpuImage,
18        Render, RenderApp, RenderSet,
19    },
20};
21
/// Path of the WGSL compute shader used by this example, relative to the `assets` directory.
const SHADER_ASSET_PATH: &str = "shaders/gpu_readback.wgsl";

/// Number of `u32` elements in the storage buffer sent to the GPU. The same value is used as
/// the width (in texels) of the storage texture and as the number of workgroups dispatched.
const BUFFER_LEN: usize = 16;
27
28fn main() {
29    App::new()
30        .add_plugins((
31            DefaultPlugins,
32            GpuReadbackPlugin,
33            ExtractResourcePlugin::<ReadbackBuffer>::default(),
34            ExtractResourcePlugin::<ReadbackImage>::default(),
35        ))
36        .insert_resource(ClearColor(Color::BLACK))
37        .add_systems(Startup, setup)
38        .run();
39}
40
41// We need a plugin to organize all the systems and render node required for this example
42struct GpuReadbackPlugin;
43impl Plugin for GpuReadbackPlugin {
44    fn build(&self, _app: &mut App) {}
45
46    fn finish(&self, app: &mut App) {
47        let render_app = app.sub_app_mut(RenderApp);
48        render_app.init_resource::<ComputePipeline>().add_systems(
49            Render,
50            prepare_bind_group
51                .in_set(RenderSet::PrepareBindGroups)
52                // We don't need to recreate the bind group every frame
53                .run_if(not(resource_exists::<GpuBufferBindGroup>)),
54        );
55
56        // Add the compute node as a top level node to the render graph
57        // This means it will only execute once per frame
58        render_app
59            .world_mut()
60            .resource_mut::<RenderGraph>()
61            .add_node(ComputeNodeLabel, ComputeNode::default());
62    }
63}
64
65#[derive(Resource, ExtractResource, Clone)]
66struct ReadbackBuffer(Handle<ShaderStorageBuffer>);
67
68#[derive(Resource, ExtractResource, Clone)]
69struct ReadbackImage(Handle<Image>);
70
71fn setup(
72    mut commands: Commands,
73    mut images: ResMut<Assets<Image>>,
74    mut buffers: ResMut<Assets<ShaderStorageBuffer>>,
75) {
76    // Create a storage buffer with some data
77    let buffer = vec![0u32; BUFFER_LEN];
78    let mut buffer = ShaderStorageBuffer::from(buffer);
79    // We need to enable the COPY_SRC usage so we can copy the buffer to the cpu
80    buffer.buffer_description.usage |= BufferUsages::COPY_SRC;
81    let buffer = buffers.add(buffer);
82
83    // Create a storage texture with some data
84    let size = Extent3d {
85        width: BUFFER_LEN as u32,
86        height: 1,
87        ..default()
88    };
89    // We create an uninitialized image since this texture will only be used for getting data out
90    // of the compute shader, not getting data in, so there's no reason for it to exist on the CPU
91    let mut image = Image::new_uninit(
92        size,
93        TextureDimension::D2,
94        TextureFormat::R32Uint,
95        RenderAssetUsages::RENDER_WORLD,
96    );
97    // We also need to enable the COPY_SRC, as well as STORAGE_BINDING so we can use it in the
98    // compute shader
99    image.texture_descriptor.usage |= TextureUsages::COPY_SRC | TextureUsages::STORAGE_BINDING;
100    let image = images.add(image);
101
102    // Spawn the readback components. For each frame, the data will be read back from the GPU
103    // asynchronously and trigger the `ReadbackComplete` event on this entity. Despawn the entity
104    // to stop reading back the data.
105    commands.spawn(Readback::buffer(buffer.clone())).observe(
106        |trigger: Trigger<ReadbackComplete>| {
107            // This matches the type which was used to create the `ShaderStorageBuffer` above,
108            // and is a convenient way to interpret the data.
109            let data: Vec<u32> = trigger.event().to_shader_type();
110            info!("Buffer {:?}", data);
111        },
112    );
113    // This is just a simple way to pass the buffer handle to the render app for our compute node
114    commands.insert_resource(ReadbackBuffer(buffer));
115
116    // Textures can also be read back from the GPU. Pay careful attention to the format of the
117    // texture, as it will affect how the data is interpreted.
118    commands.spawn(Readback::texture(image.clone())).observe(
119        |trigger: Trigger<ReadbackComplete>| {
120            // You probably want to interpret the data as a color rather than a `ShaderType`,
121            // but in this case we know the data is a single channel storage texture, so we can
122            // interpret it as a `Vec<u32>`
123            let data: Vec<u32> = trigger.event().to_shader_type();
124            info!("Image {:?}", data);
125        },
126    );
127    commands.insert_resource(ReadbackImage(image));
128}
129
130#[derive(Resource)]
131struct GpuBufferBindGroup(BindGroup);
132
133fn prepare_bind_group(
134    mut commands: Commands,
135    pipeline: Res<ComputePipeline>,
136    render_device: Res<RenderDevice>,
137    buffer: Res<ReadbackBuffer>,
138    image: Res<ReadbackImage>,
139    buffers: Res<RenderAssets<GpuShaderStorageBuffer>>,
140    images: Res<RenderAssets<GpuImage>>,
141) {
142    let buffer = buffers.get(&buffer.0).unwrap();
143    let image = images.get(&image.0).unwrap();
144    let bind_group = render_device.create_bind_group(
145        None,
146        &pipeline.layout,
147        &BindGroupEntries::sequential((
148            buffer.buffer.as_entire_buffer_binding(),
149            image.texture_view.into_binding(),
150        )),
151    );
152    commands.insert_resource(GpuBufferBindGroup(bind_group));
153}
154
155#[derive(Resource)]
156struct ComputePipeline {
157    layout: BindGroupLayout,
158    pipeline: CachedComputePipelineId,
159}
160
161impl FromWorld for ComputePipeline {
162    fn from_world(world: &mut World) -> Self {
163        let render_device = world.resource::<RenderDevice>();
164        let layout = render_device.create_bind_group_layout(
165            None,
166            &BindGroupLayoutEntries::sequential(
167                ShaderStages::COMPUTE,
168                (
169                    storage_buffer::<Vec<u32>>(false),
170                    texture_storage_2d(TextureFormat::R32Uint, StorageTextureAccess::WriteOnly),
171                ),
172            ),
173        );
174        let shader = world.load_asset(SHADER_ASSET_PATH);
175        let pipeline_cache = world.resource::<PipelineCache>();
176        let pipeline = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
177            label: Some("GPU readback compute shader".into()),
178            layout: vec![layout.clone()],
179            push_constant_ranges: Vec::new(),
180            shader: shader.clone(),
181            shader_defs: Vec::new(),
182            entry_point: "main".into(),
183            zero_initialize_workgroup_memory: false,
184        });
185        ComputePipeline { layout, pipeline }
186    }
187}
188
189/// Label to identify the node in the render graph
190#[derive(Debug, Hash, PartialEq, Eq, Clone, RenderLabel)]
191struct ComputeNodeLabel;
192
193/// The node that will execute the compute shader
194#[derive(Default)]
195struct ComputeNode {}
196impl render_graph::Node for ComputeNode {
197    fn run(
198        &self,
199        _graph: &mut render_graph::RenderGraphContext,
200        render_context: &mut RenderContext,
201        world: &World,
202    ) -> Result<(), render_graph::NodeRunError> {
203        let pipeline_cache = world.resource::<PipelineCache>();
204        let pipeline = world.resource::<ComputePipeline>();
205        let bind_group = world.resource::<GpuBufferBindGroup>();
206
207        if let Some(init_pipeline) = pipeline_cache.get_compute_pipeline(pipeline.pipeline) {
208            let mut pass =
209                render_context
210                    .command_encoder()
211                    .begin_compute_pass(&ComputePassDescriptor {
212                        label: Some("GPU readback compute pass"),
213                        ..default()
214                    });
215
216            pass.set_bind_group(0, &bind_group.0, &[]);
217            pass.set_pipeline(init_pipeline);
218            pass.dispatch_workgroups(BUFFER_LEN as u32, 1, 1);
219        }
220        Ok(())
221    }
222}