1use std::ops::Not;
13
14use bevy::{
15 asset::RenderAssetUsages,
16 color::palettes::tailwind::{RED_400, SKY_400},
17 core_pipeline::schedule::camera_driver,
18 mesh::Indices,
19 platform::collections::HashSet,
20 prelude::*,
21 render::{
22 extract_component::{ExtractComponent, ExtractComponentPlugin},
23 mesh::allocator::{MeshAllocator, MeshAllocatorSettings},
24 render_resource::{
25 binding_types::{storage_buffer, uniform_buffer},
26 *,
27 },
28 renderer::{RenderContext, RenderGraph, RenderQueue},
29 Render, RenderApp, RenderStartup,
30 },
31};
32
/// Path of the WGSL compute shader that generates the mesh data.
///
/// Loaded through the `AssetServer` in `init_compute_pipeline`
/// (presumably relative to the project's assets directory — confirm).
const SHADER_ASSET_PATH: &str = "shaders/compute_mesh.wgsl";
35
36fn main() {
37 App::new()
38 .add_plugins((
39 DefaultPlugins,
40 ComputeShaderMeshGeneratorPlugin,
41 ExtractComponentPlugin::<GenerateMesh>::default(),
42 ))
43 .insert_resource(ClearColor(Color::BLACK))
44 .add_systems(Startup, setup)
45 .run();
46}
47
48struct ComputeShaderMeshGeneratorPlugin;
50impl Plugin for ComputeShaderMeshGeneratorPlugin {
51 fn build(&self, app: &mut App) {
52 let Some(render_app) = app.get_sub_app_mut(RenderApp) else {
53 return;
54 };
55
56 render_app
57 .init_resource::<ChunksToProcess>()
58 .add_systems(RenderStartup, init_compute_pipeline)
59 .add_systems(Render, prepare_chunks)
60 .add_systems(RenderGraph, compute_mesh.before(camera_driver));
61 }
62 fn finish(&self, app: &mut App) {
63 let Some(render_app) = app.get_sub_app_mut(RenderApp) else {
64 return;
65 };
66 render_app
67 .world_mut()
68 .resource_mut::<MeshAllocatorSettings>()
69 .extra_buffer_usages = BufferUsages::STORAGE;
74 }
75}
76
/// Marks an entity whose mesh should be filled in by the compute shader.
///
/// Holds the handle of the placeholder mesh created in `setup`. The component is
/// copied into the render world by the `ExtractComponentPlugin::<GenerateMesh>`
/// registered in `main`, where `prepare_chunks` reads the handle's asset id.
#[derive(Component, ExtractComponent, Clone)]
struct GenerateMesh(Handle<Mesh>);
81
82fn setup(
83 mut commands: Commands,
84 mut meshes: ResMut<Assets<Mesh>>,
85 mut materials: ResMut<Assets<StandardMaterial>>,
86) {
87 let empty_mesh = {
102 let mut mesh = Mesh::new(
103 PrimitiveTopology::TriangleList,
104 RenderAssetUsages::RENDER_WORLD,
105 )
106 .with_inserted_attribute(Mesh::ATTRIBUTE_POSITION, vec![[0.; 3]; 50])
107 .with_inserted_attribute(Mesh::ATTRIBUTE_NORMAL, vec![[0.; 3]; 50])
108 .with_inserted_attribute(Mesh::ATTRIBUTE_UV_0, vec![[0.; 2]; 50])
109 .with_inserted_indices(Indices::U32(vec![0; 50]));
110
111 mesh.asset_usage = RenderAssetUsages::RENDER_WORLD;
112 mesh
113 };
114
115 let handle = meshes.add(empty_mesh);
116
117 commands.spawn((
121 GenerateMesh(handle.clone()),
122 Mesh3d(handle.clone()),
123 MeshMaterial3d(materials.add(StandardMaterial {
124 base_color: RED_400.into(),
125 ..default()
126 })),
127 Transform::from_xyz(-2.5, 1.5, 0.),
128 ));
129
130 commands.spawn((
131 Mesh3d(handle),
132 MeshMaterial3d(materials.add(StandardMaterial {
133 base_color: SKY_400.into(),
134 ..default()
135 })),
136 Transform::from_xyz(2.5, 1.5, 0.),
137 ));
138
139 commands.spawn((
144 Mesh3d(meshes.add(Circle::new(4.0))),
145 MeshMaterial3d(materials.add(Color::WHITE)),
146 Transform::from_rotation(Quat::from_rotation_x(-std::f32::consts::FRAC_PI_2)),
147 ));
148 commands.spawn((
149 PointLight {
150 shadow_maps_enabled: true,
151 ..default()
152 },
153 Transform::from_xyz(4.0, 8.0, 4.0),
154 ));
155 commands.spawn((
157 Camera3d::default(),
158 Transform::from_xyz(-2.5, 4.5, 9.0).looking_at(Vec3::ZERO, Vec3::Y),
159 ));
160}
161
/// Render-world list of mesh asset ids to run the compute shader on this frame.
///
/// Overwritten by `prepare_chunks` (only once the pipeline is ready) and read by
/// `compute_mesh`. Registered with `init_resource` in the plugin, so it starts empty.
#[derive(Resource, Default)]
struct ChunksToProcess(Vec<AssetId<Mesh>>);
167
168fn prepare_chunks(
174 meshes_to_generate: Query<&GenerateMesh>,
175 mut chunks: ResMut<ChunksToProcess>,
176 pipeline_cache: Res<PipelineCache>,
177 pipeline: Res<ComputePipeline>,
178 mut processed: Local<HashSet<AssetId<Mesh>>>,
179) {
180 if pipeline_cache
184 .get_compute_pipeline(pipeline.pipeline)
185 .is_some()
186 {
187 let chunk_data: Vec<AssetId<Mesh>> = meshes_to_generate
191 .iter()
192 .filter_map(|gmesh| {
193 let id = gmesh.0.id();
194 processed.contains(&id).not().then_some(id)
195 })
196 .collect();
197
198 for id in &chunk_data {
200 processed.insert(*id);
201 }
202
203 chunks.0 = chunk_data;
204 }
205}
206
/// Render-world resource holding the mesh-generation compute pipeline.
#[derive(Resource)]
struct ComputePipeline {
    /// Layout of the single bind group the shader uses (ranges uniform, vertex
    /// buffer, index buffer). `compute_mesh` resolves it through the pipeline
    /// cache when it creates bind groups.
    layout: BindGroupLayoutDescriptor,
    /// Id of the queued pipeline. It is usable once
    /// `PipelineCache::get_compute_pipeline` returns `Some` for it.
    pipeline: CachedComputePipelineId,
}
212
213fn init_compute_pipeline(
215 mut commands: Commands,
216 asset_server: Res<AssetServer>,
217 pipeline_cache: Res<PipelineCache>,
218) {
219 let layout = BindGroupLayoutDescriptor::new(
220 "",
221 &BindGroupLayoutEntries::sequential(
222 ShaderStages::COMPUTE,
223 (
224 uniform_buffer::<DataRanges>(false),
226 storage_buffer::<Vec<u32>>(false),
228 storage_buffer::<Vec<u32>>(false),
230 ),
231 ),
232 );
233 let shader = asset_server.load(SHADER_ASSET_PATH);
234 let pipeline = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
235 label: Some("Mesh generation compute shader".into()),
236 layout: vec![layout.clone()],
237 shader: shader.clone(),
238 ..default()
239 });
240 commands.insert_resource(ComputePipeline { layout, pipeline });
241}
242
/// Uniform telling the compute shader where this mesh's data lives inside the
/// mesh allocator's vertex and index buffers.
///
/// NOTE(review): assumed to mirror a uniform struct in
/// `shaders/compute_mesh.wgsl`; keep field order and types in sync with it.
#[derive(ShaderType)]
struct DataRanges {
    /// Start of this mesh's vertex data, in u32 words
    /// (`compute_mesh` multiplies the vertex index by the per-vertex stride).
    vertex_start: u32,
    /// Exclusive end of this mesh's vertex data, in u32 words.
    vertex_end: u32,
    /// First index slot of this mesh in the index buffer.
    index_start: u32,
    /// Exclusive end of this mesh's index slots in the index buffer.
    index_end: u32,
}
252
253fn compute_mesh(
254 mut render_context: RenderContext,
255 chunks: Res<ChunksToProcess>,
256 mesh_allocator: Res<MeshAllocator>,
257 pipeline_cache: Res<PipelineCache>,
258 pipeline: Res<ComputePipeline>,
259 render_queue: Res<RenderQueue>,
260) {
261 let Some(init_pipeline) = pipeline_cache.get_compute_pipeline(pipeline.pipeline) else {
262 return;
263 };
264
265 for mesh_id in &chunks.0 {
266 info!(?mesh_id, "processing mesh");
267
268 let vertex_buffer_slice = mesh_allocator.mesh_vertex_slice(mesh_id).unwrap();
273 let index_buffer_slice = mesh_allocator.mesh_index_slice(mesh_id).unwrap();
274
275 let first = DataRanges {
276 vertex_start: vertex_buffer_slice.range.start * 8,
281 vertex_end: vertex_buffer_slice.range.end * 8,
282 index_start: index_buffer_slice.range.start,
285 index_end: index_buffer_slice.range.end,
286 };
287
288 let mut uniforms = UniformBuffer::from(first);
289 uniforms.write_buffer(render_context.render_device(), &render_queue);
290
291 let bind_group = render_context.render_device().create_bind_group(
294 None,
295 &pipeline_cache.get_bind_group_layout(&pipeline.layout),
296 &BindGroupEntries::sequential((
297 &uniforms,
298 vertex_buffer_slice.buffer.as_entire_buffer_binding(),
299 index_buffer_slice.buffer.as_entire_buffer_binding(),
300 )),
301 );
302
303 let mut pass =
304 render_context
305 .command_encoder()
306 .begin_compute_pass(&ComputePassDescriptor {
307 label: Some("Mesh generation compute pass"),
308 ..default()
309 });
310 pass.push_debug_group("compute_mesh");
311
312 pass.set_bind_group(0, &bind_group, &[]);
313 pass.set_pipeline(init_pipeline);
314 pass.dispatch_workgroups(1, 1, 1);
317
318 pass.pop_debug_group();
319 }
320}