1use bevy::{
5 asset::RenderAssetUsages,
6 prelude::*,
7 render::{
8 extract_resource::{ExtractResource, ExtractResourcePlugin},
9 gpu_readback::{Readback, ReadbackComplete},
10 render_asset::RenderAssets,
11 render_graph::{self, RenderGraph, RenderLabel},
12 render_resource::{
13 binding_types::{storage_buffer, texture_storage_2d},
14 *,
15 },
16 renderer::{RenderContext, RenderDevice},
17 storage::{GpuShaderStorageBuffer, ShaderStorageBuffer},
18 texture::GpuImage,
19 Render, RenderApp, RenderStartup, RenderSystems,
20 },
21};
22
/// Path handed to `AssetServer::load` for the readback compute shader.
const SHADER_ASSET_PATH: &str = "shaders/gpu_readback.wgsl";

/// Number of `u32` elements in the storage buffer. Also used as the width (in texels) of the
/// storage texture and as the number of workgroups dispatched along X.
const BUFFER_LEN: usize = 16;
28
29fn main() {
30 App::new()
31 .add_plugins((
32 DefaultPlugins,
33 GpuReadbackPlugin,
34 ExtractResourcePlugin::<ReadbackBuffer>::default(),
35 ExtractResourcePlugin::<ReadbackImage>::default(),
36 ))
37 .insert_resource(ClearColor(Color::BLACK))
38 .add_systems(Startup, setup)
39 .run();
40}
41
42struct GpuReadbackPlugin;
44impl Plugin for GpuReadbackPlugin {
45 fn build(&self, app: &mut App) {
46 let Some(render_app) = app.get_sub_app_mut(RenderApp) else {
47 return;
48 };
49 render_app
50 .add_systems(
51 RenderStartup,
52 (init_compute_pipeline, add_compute_render_graph_node),
53 )
54 .add_systems(
55 Render,
56 prepare_bind_group
57 .in_set(RenderSystems::PrepareBindGroups)
58 .run_if(not(resource_exists::<GpuBufferBindGroup>)),
60 );
61 }
62}
63
64#[derive(Resource, ExtractResource, Clone)]
65struct ReadbackBuffer(Handle<ShaderStorageBuffer>);
66
67#[derive(Resource, ExtractResource, Clone)]
68struct ReadbackImage(Handle<Image>);
69
70fn setup(
71 mut commands: Commands,
72 mut images: ResMut<Assets<Image>>,
73 mut buffers: ResMut<Assets<ShaderStorageBuffer>>,
74) {
75 let buffer: Vec<u32> = (0..BUFFER_LEN as u32).collect();
77 let mut buffer = ShaderStorageBuffer::from(buffer);
78 buffer.buffer_description.usage |= BufferUsages::COPY_SRC;
80 let buffer = buffers.add(buffer);
81
82 let size = Extent3d {
84 width: BUFFER_LEN as u32,
85 height: 1,
86 ..default()
87 };
88 let mut image = Image::new_uninit(
91 size,
92 TextureDimension::D2,
93 TextureFormat::R32Uint,
94 RenderAssetUsages::RENDER_WORLD,
95 );
96 image.texture_descriptor.usage |= TextureUsages::COPY_SRC | TextureUsages::STORAGE_BINDING;
99 let image = images.add(image);
100
101 commands
105 .spawn(Readback::buffer(buffer.clone()))
106 .observe(|event: On<ReadbackComplete>| {
107 let data: Vec<u32> = event.to_shader_type();
110 info!("Buffer {:?}", data);
111 });
112
113 commands
115 .spawn(Readback::buffer_range(
116 buffer.clone(),
117 4 * u32::SHADER_SIZE.get(), 8 * u32::SHADER_SIZE.get(), ))
120 .observe(|event: On<ReadbackComplete>| {
121 let data: Vec<u32> = event.to_shader_type();
122 info!("Buffer range {:?}", data);
123 });
124
125 commands.insert_resource(ReadbackBuffer(buffer));
127
128 commands
131 .spawn(Readback::texture(image.clone()))
132 .observe(|event: On<ReadbackComplete>| {
133 let data: Vec<u32> = event.to_shader_type();
137 info!("Image {:?}", data);
138 });
139 commands.insert_resource(ReadbackImage(image));
140}
141
142fn add_compute_render_graph_node(mut render_graph: ResMut<RenderGraph>) {
143 render_graph.add_node(ComputeNodeLabel, ComputeNode::default());
147}
148
149#[derive(Resource)]
150struct GpuBufferBindGroup(BindGroup);
151
152fn prepare_bind_group(
153 mut commands: Commands,
154 pipeline: Res<ComputePipeline>,
155 render_device: Res<RenderDevice>,
156 buffer: Res<ReadbackBuffer>,
157 image: Res<ReadbackImage>,
158 buffers: Res<RenderAssets<GpuShaderStorageBuffer>>,
159 images: Res<RenderAssets<GpuImage>>,
160) {
161 let buffer = buffers.get(&buffer.0).unwrap();
162 let image = images.get(&image.0).unwrap();
163 let bind_group = render_device.create_bind_group(
164 None,
165 &pipeline.layout,
166 &BindGroupEntries::sequential((
167 buffer.buffer.as_entire_buffer_binding(),
168 image.texture_view.into_binding(),
169 )),
170 );
171 commands.insert_resource(GpuBufferBindGroup(bind_group));
172}
173
174#[derive(Resource)]
175struct ComputePipeline {
176 layout: BindGroupLayout,
177 pipeline: CachedComputePipelineId,
178}
179
180fn init_compute_pipeline(
181 mut commands: Commands,
182 render_device: Res<RenderDevice>,
183 asset_server: Res<AssetServer>,
184 pipeline_cache: Res<PipelineCache>,
185) {
186 let layout = render_device.create_bind_group_layout(
187 None,
188 &BindGroupLayoutEntries::sequential(
189 ShaderStages::COMPUTE,
190 (
191 storage_buffer::<Vec<u32>>(false),
192 texture_storage_2d(TextureFormat::R32Uint, StorageTextureAccess::WriteOnly),
193 ),
194 ),
195 );
196 let shader = asset_server.load(SHADER_ASSET_PATH);
197 let pipeline = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
198 label: Some("GPU readback compute shader".into()),
199 layout: vec![layout.clone()],
200 shader: shader.clone(),
201 ..default()
202 });
203 commands.insert_resource(ComputePipeline { layout, pipeline });
204}
205
206#[derive(Debug, Hash, PartialEq, Eq, Clone, RenderLabel)]
208struct ComputeNodeLabel;
209
210#[derive(Default)]
212struct ComputeNode {}
213
214impl render_graph::Node for ComputeNode {
215 fn run(
216 &self,
217 _graph: &mut render_graph::RenderGraphContext,
218 render_context: &mut RenderContext,
219 world: &World,
220 ) -> Result<(), render_graph::NodeRunError> {
221 let pipeline_cache = world.resource::<PipelineCache>();
222 let pipeline = world.resource::<ComputePipeline>();
223 let bind_group = world.resource::<GpuBufferBindGroup>();
224
225 if let Some(init_pipeline) = pipeline_cache.get_compute_pipeline(pipeline.pipeline) {
226 let mut pass =
227 render_context
228 .command_encoder()
229 .begin_compute_pass(&ComputePassDescriptor {
230 label: Some("GPU readback compute pass"),
231 ..default()
232 });
233
234 pass.set_bind_group(0, &bind_group.0, &[]);
235 pass.set_pipeline(init_pipeline);
236 pass.dispatch_workgroups(BUFFER_LEN as u32, 1, 1);
237 }
238 Ok(())
239 }
240}