1use bevy::{
5 prelude::*,
6 render::{
7 extract_resource::{ExtractResource, ExtractResourcePlugin},
8 gpu_readback::{Readback, ReadbackComplete},
9 render_asset::{RenderAssetUsages, RenderAssets},
10 render_graph::{self, RenderGraph, RenderLabel},
11 render_resource::{
12 binding_types::{storage_buffer, texture_storage_2d},
13 *,
14 },
15 renderer::{RenderContext, RenderDevice},
16 storage::{GpuShaderStorageBuffer, ShaderStorageBuffer},
17 texture::GpuImage,
18 Render, RenderApp, RenderSet,
19 },
20};
21
/// Asset path of the WGSL compute shader loaded by [`ComputePipeline`]'s `FromWorld` impl.
const SHADER_ASSET_PATH: &str = "shaders/gpu_readback.wgsl";

/// Number of `u32` elements in the readback storage buffer. It is also the width (in texels) of
/// the readback storage texture and the number of workgroups dispatched by [`ComputeNode`].
const BUFFER_LEN: usize = 16;
27
28fn main() {
29 App::new()
30 .add_plugins((
31 DefaultPlugins,
32 GpuReadbackPlugin,
33 ExtractResourcePlugin::<ReadbackBuffer>::default(),
34 ExtractResourcePlugin::<ReadbackImage>::default(),
35 ))
36 .insert_resource(ClearColor(Color::BLACK))
37 .add_systems(Startup, setup)
38 .run();
39}
40
41struct GpuReadbackPlugin;
43impl Plugin for GpuReadbackPlugin {
44 fn build(&self, _app: &mut App) {}
45
46 fn finish(&self, app: &mut App) {
47 let render_app = app.sub_app_mut(RenderApp);
48 render_app.init_resource::<ComputePipeline>().add_systems(
49 Render,
50 prepare_bind_group
51 .in_set(RenderSet::PrepareBindGroups)
52 .run_if(not(resource_exists::<GpuBufferBindGroup>)),
54 );
55
56 render_app
59 .world_mut()
60 .resource_mut::<RenderGraph>()
61 .add_node(ComputeNodeLabel, ComputeNode::default());
62 }
63}
64
65#[derive(Resource, ExtractResource, Clone)]
66struct ReadbackBuffer(Handle<ShaderStorageBuffer>);
67
68#[derive(Resource, ExtractResource, Clone)]
69struct ReadbackImage(Handle<Image>);
70
71fn setup(
72 mut commands: Commands,
73 mut images: ResMut<Assets<Image>>,
74 mut buffers: ResMut<Assets<ShaderStorageBuffer>>,
75) {
76 let buffer = vec![0u32; BUFFER_LEN];
78 let mut buffer = ShaderStorageBuffer::from(buffer);
79 buffer.buffer_description.usage |= BufferUsages::COPY_SRC;
81 let buffer = buffers.add(buffer);
82
83 let size = Extent3d {
85 width: BUFFER_LEN as u32,
86 height: 1,
87 ..default()
88 };
89 let mut image = Image::new_uninit(
92 size,
93 TextureDimension::D2,
94 TextureFormat::R32Uint,
95 RenderAssetUsages::RENDER_WORLD,
96 );
97 image.texture_descriptor.usage |= TextureUsages::COPY_SRC | TextureUsages::STORAGE_BINDING;
100 let image = images.add(image);
101
102 commands.spawn(Readback::buffer(buffer.clone())).observe(
106 |trigger: Trigger<ReadbackComplete>| {
107 let data: Vec<u32> = trigger.event().to_shader_type();
110 info!("Buffer {:?}", data);
111 },
112 );
113 commands.insert_resource(ReadbackBuffer(buffer));
115
116 commands.spawn(Readback::texture(image.clone())).observe(
119 |trigger: Trigger<ReadbackComplete>| {
120 let data: Vec<u32> = trigger.event().to_shader_type();
124 info!("Image {:?}", data);
125 },
126 );
127 commands.insert_resource(ReadbackImage(image));
128}
129
130#[derive(Resource)]
131struct GpuBufferBindGroup(BindGroup);
132
133fn prepare_bind_group(
134 mut commands: Commands,
135 pipeline: Res<ComputePipeline>,
136 render_device: Res<RenderDevice>,
137 buffer: Res<ReadbackBuffer>,
138 image: Res<ReadbackImage>,
139 buffers: Res<RenderAssets<GpuShaderStorageBuffer>>,
140 images: Res<RenderAssets<GpuImage>>,
141) {
142 let buffer = buffers.get(&buffer.0).unwrap();
143 let image = images.get(&image.0).unwrap();
144 let bind_group = render_device.create_bind_group(
145 None,
146 &pipeline.layout,
147 &BindGroupEntries::sequential((
148 buffer.buffer.as_entire_buffer_binding(),
149 image.texture_view.into_binding(),
150 )),
151 );
152 commands.insert_resource(GpuBufferBindGroup(bind_group));
153}
154
155#[derive(Resource)]
156struct ComputePipeline {
157 layout: BindGroupLayout,
158 pipeline: CachedComputePipelineId,
159}
160
161impl FromWorld for ComputePipeline {
162 fn from_world(world: &mut World) -> Self {
163 let render_device = world.resource::<RenderDevice>();
164 let layout = render_device.create_bind_group_layout(
165 None,
166 &BindGroupLayoutEntries::sequential(
167 ShaderStages::COMPUTE,
168 (
169 storage_buffer::<Vec<u32>>(false),
170 texture_storage_2d(TextureFormat::R32Uint, StorageTextureAccess::WriteOnly),
171 ),
172 ),
173 );
174 let shader = world.load_asset(SHADER_ASSET_PATH);
175 let pipeline_cache = world.resource::<PipelineCache>();
176 let pipeline = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
177 label: Some("GPU readback compute shader".into()),
178 layout: vec![layout.clone()],
179 push_constant_ranges: Vec::new(),
180 shader: shader.clone(),
181 shader_defs: Vec::new(),
182 entry_point: "main".into(),
183 zero_initialize_workgroup_memory: false,
184 });
185 ComputePipeline { layout, pipeline }
186 }
187}
188
189#[derive(Debug, Hash, PartialEq, Eq, Clone, RenderLabel)]
191struct ComputeNodeLabel;
192
193#[derive(Default)]
195struct ComputeNode {}
196impl render_graph::Node for ComputeNode {
197 fn run(
198 &self,
199 _graph: &mut render_graph::RenderGraphContext,
200 render_context: &mut RenderContext,
201 world: &World,
202 ) -> Result<(), render_graph::NodeRunError> {
203 let pipeline_cache = world.resource::<PipelineCache>();
204 let pipeline = world.resource::<ComputePipeline>();
205 let bind_group = world.resource::<GpuBufferBindGroup>();
206
207 if let Some(init_pipeline) = pipeline_cache.get_compute_pipeline(pipeline.pipeline) {
208 let mut pass =
209 render_context
210 .command_encoder()
211 .begin_compute_pass(&ComputePassDescriptor {
212 label: Some("GPU readback compute pass"),
213 ..default()
214 });
215
216 pass.set_bind_group(0, &bind_group.0, &[]);
217 pass.set_pipeline(init_pipeline);
218 pass.dispatch_workgroups(BUFFER_LEN as u32, 1, 1);
219 }
220 Ok(())
221 }
222}