Skip to main content

compute_mesh/
compute_mesh.rs

//! This example shows how to create an empty mesh behind a `Handle<Mesh>`
//! with render-world-only asset usage. The mesh's buffers are then filled
//! by a compute shader on the GPU, without transferring any data back
//! to the CPU.
5//!
6//! The `mesh_allocator` is used to get references to the relevant slabs
7//! that contain the mesh data we're interested in.
8//!
9//! This example does not remove the `GenerateMesh` component after
10//! generating the mesh.
11
12use std::ops::Not;
13
14use bevy::{
15    asset::RenderAssetUsages,
16    color::palettes::tailwind::{RED_400, SKY_400},
17    core_pipeline::schedule::camera_driver,
18    mesh::Indices,
19    platform::collections::HashSet,
20    prelude::*,
21    render::{
22        extract_component::{ExtractComponent, ExtractComponentPlugin},
23        mesh::allocator::{MeshAllocator, MeshAllocatorSettings},
24        render_resource::{
25            binding_types::{storage_buffer, uniform_buffer},
26            *,
27        },
28        renderer::{RenderContext, RenderGraph, RenderQueue},
29        Render, RenderApp, RenderStartup,
30    },
31};
32
/// Location of the WGSL compute shader that generates the mesh data,
/// relative to the `assets` directory.
const SHADER_ASSET_PATH: &str = "shaders/compute_mesh.wgsl";
35
36fn main() {
37    App::new()
38        .add_plugins((
39            DefaultPlugins,
40            ComputeShaderMeshGeneratorPlugin,
41            ExtractComponentPlugin::<GenerateMesh>::default(),
42        ))
43        .insert_resource(ClearColor(Color::BLACK))
44        .add_systems(Startup, setup)
45        .run();
46}
47
48// We need a plugin to organize all the systems and render node required for this example
49struct ComputeShaderMeshGeneratorPlugin;
50impl Plugin for ComputeShaderMeshGeneratorPlugin {
51    fn build(&self, app: &mut App) {
52        let Some(render_app) = app.get_sub_app_mut(RenderApp) else {
53            return;
54        };
55
56        render_app
57            .init_resource::<ChunksToProcess>()
58            .add_systems(RenderStartup, init_compute_pipeline)
59            .add_systems(Render, prepare_chunks)
60            .add_systems(RenderGraph, compute_mesh.before(camera_driver));
61    }
62    fn finish(&self, app: &mut App) {
63        let Some(render_app) = app.get_sub_app_mut(RenderApp) else {
64            return;
65        };
66        render_app
67            .world_mut()
68            .resource_mut::<MeshAllocatorSettings>()
69            // This allows using the mesh allocator slabs as
70            // storage buffers directly in the compute shader.
71            // Which means that we can write from our compute
72            // shader directly to the allocated mesh slabs.
73            .extra_buffer_usages = BufferUsages::STORAGE;
74    }
75}
76
77/// Holds a handle to the empty mesh that should be filled
78/// by the compute shader.
79#[derive(Component, ExtractComponent, Clone)]
80struct GenerateMesh(Handle<Mesh>);
81
82fn setup(
83    mut commands: Commands,
84    mut meshes: ResMut<Assets<Mesh>>,
85    mut materials: ResMut<Assets<StandardMaterial>>,
86) {
87    // a truly empty mesh will error if used in Mesh3d
88    // so we set up the data to be what we want the compute shader to output
89    // We're using 36 indices and 24 vertices which is directly taken from
90    // the Bevy Cuboid mesh implementation.
91    //
92    // We allocate 50 spots for each attribute here because
93    // it is *very important* that the amount of data allocated here is
94    // *bigger* than (or exactly equal to) the amount of data we intend to
95    // write from the compute shader. This amount of data defines how big
96    // the buffer we get from the mesh_allocator will be, which in turn
97    // defines how big the buffer is when we're in the compute shader.
98    //
99    // If it turns out you don't need all of the space when the compute shader
100    // is writing data, you can write NaN to the rest of the data.
101    let empty_mesh = {
102        let mut mesh = Mesh::new(
103            PrimitiveTopology::TriangleList,
104            RenderAssetUsages::RENDER_WORLD,
105        )
106        .with_inserted_attribute(Mesh::ATTRIBUTE_POSITION, vec![[0.; 3]; 50])
107        .with_inserted_attribute(Mesh::ATTRIBUTE_NORMAL, vec![[0.; 3]; 50])
108        .with_inserted_attribute(Mesh::ATTRIBUTE_UV_0, vec![[0.; 2]; 50])
109        .with_inserted_indices(Indices::U32(vec![0; 50]));
110
111        mesh.asset_usage = RenderAssetUsages::RENDER_WORLD;
112        mesh
113    };
114
115    let handle = meshes.add(empty_mesh);
116
117    // we spawn two "users" of the mesh handle,
118    // but only insert `GenerateMesh` on one of them
119    // to show that the mesh handle works as usual
120    commands.spawn((
121        GenerateMesh(handle.clone()),
122        Mesh3d(handle.clone()),
123        MeshMaterial3d(materials.add(StandardMaterial {
124            base_color: RED_400.into(),
125            ..default()
126        })),
127        Transform::from_xyz(-2.5, 1.5, 0.),
128    ));
129
130    commands.spawn((
131        Mesh3d(handle),
132        MeshMaterial3d(materials.add(StandardMaterial {
133            base_color: SKY_400.into(),
134            ..default()
135        })),
136        Transform::from_xyz(2.5, 1.5, 0.),
137    ));
138
139    // some additional scene elements.
140    // This mesh specifically is here so that we don't assume
141    // mesh_allocator offsets that would only work if we had
142    // one mesh in the scene.
143    commands.spawn((
144        Mesh3d(meshes.add(Circle::new(4.0))),
145        MeshMaterial3d(materials.add(Color::WHITE)),
146        Transform::from_rotation(Quat::from_rotation_x(-std::f32::consts::FRAC_PI_2)),
147    ));
148    commands.spawn((
149        PointLight {
150            shadow_maps_enabled: true,
151            ..default()
152        },
153        Transform::from_xyz(4.0, 8.0, 4.0),
154    ));
155    // camera
156    commands.spawn((
157        Camera3d::default(),
158        Transform::from_xyz(-2.5, 4.5, 9.0).looking_at(Vec3::ZERO, Vec3::Y),
159    ));
160}
161
162/// This is called `ChunksToProcess` because this example originated
163/// from a use case of generating chunks of landscape or voxels
164/// It only exists in the render world.
165#[derive(Resource, Default)]
166struct ChunksToProcess(Vec<AssetId<Mesh>>);
167
168/// `processed` is a `HashSet` contains the `AssetId`s that have been
169/// processed. We use that to remove `AssetId`s that have already
170/// been processed, which means each unique `GenerateMesh` will result
171/// in one compute shader mesh generation process instead of generating
172/// the mesh every frame.
173fn prepare_chunks(
174    meshes_to_generate: Query<&GenerateMesh>,
175    mut chunks: ResMut<ChunksToProcess>,
176    pipeline_cache: Res<PipelineCache>,
177    pipeline: Res<ComputePipeline>,
178    mut processed: Local<HashSet<AssetId<Mesh>>>,
179) {
180    // If the pipeline isn't ready, then meshes
181    // won't be processed. So we want to wait until
182    // the pipeline is ready before considering any mesh processed.
183    if pipeline_cache
184        .get_compute_pipeline(pipeline.pipeline)
185        .is_some()
186    {
187        // get the AssetId for each Handle<Mesh>
188        // which we'll use later to get the relevant buffers
189        // from the mesh_allocator
190        let chunk_data: Vec<AssetId<Mesh>> = meshes_to_generate
191            .iter()
192            .filter_map(|gmesh| {
193                let id = gmesh.0.id();
194                processed.contains(&id).not().then_some(id)
195            })
196            .collect();
197
198        // Cache any meshes we're going to process this frame
199        for id in &chunk_data {
200            processed.insert(*id);
201        }
202
203        chunks.0 = chunk_data;
204    }
205}
206
207#[derive(Resource)]
208struct ComputePipeline {
209    layout: BindGroupLayoutDescriptor,
210    pipeline: CachedComputePipelineId,
211}
212
213// init only happens once
214fn init_compute_pipeline(
215    mut commands: Commands,
216    asset_server: Res<AssetServer>,
217    pipeline_cache: Res<PipelineCache>,
218) {
219    let layout = BindGroupLayoutDescriptor::new(
220        "",
221        &BindGroupLayoutEntries::sequential(
222            ShaderStages::COMPUTE,
223            (
224                // offsets
225                uniform_buffer::<DataRanges>(false),
226                // vertices
227                storage_buffer::<Vec<u32>>(false),
228                // indices
229                storage_buffer::<Vec<u32>>(false),
230            ),
231        ),
232    );
233    let shader = asset_server.load(SHADER_ASSET_PATH);
234    let pipeline = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
235        label: Some("Mesh generation compute shader".into()),
236        layout: vec![layout.clone()],
237        shader: shader.clone(),
238        ..default()
239    });
240    commands.insert_resource(ComputePipeline { layout, pipeline });
241}
242
243// A uniform that holds the vertex and index offsets
244// for the vertex/index mesh_allocator buffer slabs
245#[derive(ShaderType)]
246struct DataRanges {
247    vertex_start: u32,
248    vertex_end: u32,
249    index_start: u32,
250    index_end: u32,
251}
252
253fn compute_mesh(
254    mut render_context: RenderContext,
255    chunks: Res<ChunksToProcess>,
256    mesh_allocator: Res<MeshAllocator>,
257    pipeline_cache: Res<PipelineCache>,
258    pipeline: Res<ComputePipeline>,
259    render_queue: Res<RenderQueue>,
260) {
261    let Some(init_pipeline) = pipeline_cache.get_compute_pipeline(pipeline.pipeline) else {
262        return;
263    };
264
265    for mesh_id in &chunks.0 {
266        info!(?mesh_id, "processing mesh");
267
268        // the mesh_allocator holds slabs of meshes, so the buffers we get here
269        // can contain more data than just the mesh we're asking for.
270        // That's why there is a range field.
271        // You should *not* touch data in these buffers that is outside of the range.
272        let vertex_buffer_slice = mesh_allocator.mesh_vertex_slice(mesh_id).unwrap();
273        let index_buffer_slice = mesh_allocator.mesh_index_slice(mesh_id).unwrap();
274
275        let first = DataRanges {
276            // there are 8 vertex data values (pos, normal, uv) per vertex
277            // and the vertex_buffer_slice.range.start is in "vertex elements"
278            // which includes all of that data, so each index is worth 8 indices
279            // to our shader code.
280            vertex_start: vertex_buffer_slice.range.start * 8,
281            vertex_end: vertex_buffer_slice.range.end * 8,
282            // but each vertex index is a single value, so the index of the
283            // vertex indices is exactly what the value is
284            index_start: index_buffer_slice.range.start,
285            index_end: index_buffer_slice.range.end,
286        };
287
288        let mut uniforms = UniformBuffer::from(first);
289        uniforms.write_buffer(render_context.render_device(), &render_queue);
290
291        // pass in the full mesh_allocator slabs as well as the first index
292        // offsets for the vertex and index buffers
293        let bind_group = render_context.render_device().create_bind_group(
294            None,
295            &pipeline_cache.get_bind_group_layout(&pipeline.layout),
296            &BindGroupEntries::sequential((
297                &uniforms,
298                vertex_buffer_slice.buffer.as_entire_buffer_binding(),
299                index_buffer_slice.buffer.as_entire_buffer_binding(),
300            )),
301        );
302
303        let mut pass =
304            render_context
305                .command_encoder()
306                .begin_compute_pass(&ComputePassDescriptor {
307                    label: Some("Mesh generation compute pass"),
308                    ..default()
309                });
310        pass.push_debug_group("compute_mesh");
311
312        pass.set_bind_group(0, &bind_group, &[]);
313        pass.set_pipeline(init_pipeline);
314        // we only dispatch 1,1,1 workgroup here, but a real compute shader
315        // would take advantage of more and larger size workgroups
316        pass.dispatch_workgroups(1, 1, 1);
317
318        pass.pop_debug_group();
319    }
320}