Skip to main content

custom_shader_instancing/
custom_shader_instancing.rs

1//! A shader that renders a mesh multiple times in one draw call.
2//!
3//! Bevy will automatically batch and instance your meshes assuming you use the same
4//! `Handle<Material>` and `Handle<Mesh>` for all of your instances.
5//!
6//! This example is intended for advanced users and shows how to make a custom instancing
7//! implementation using bevy's low level rendering api.
8//! It's generally recommended to try the built-in instancing before going with this approach.
9
10use bevy::core_pipeline::core_3d::TransparentSortingInfo3d;
11use bevy::pbr::{
12    self, MeshInputUniform, MeshPipelineSystems, MeshUniform, SetMeshViewBindingArrayBindGroup,
13    ViewKeyCache,
14};
15use bevy::{
16    camera::visibility::NoFrustumCulling,
17    core_pipeline::core_3d::Transparent3d,
18    ecs::{
19        query::QueryItem,
20        system::{lifetimeless::*, SystemParamItem},
21    },
22    mesh::{MeshVertexBufferLayoutRef, VertexBufferLayout},
23    pbr::{
24        MeshPipeline, MeshPipelineKey, RenderMeshInstances, SetMeshBindGroup, SetMeshViewBindGroup,
25    },
26    prelude::*,
27    render::{
28        batching::gpu_preprocessing::BatchedInstanceBuffers,
29        extract_component::{ExtractComponent, ExtractComponentPlugin},
30        mesh::{allocator::MeshAllocator, RenderMesh, RenderMeshBufferInfo},
31        render_asset::RenderAssets,
32        render_phase::{
33            AddRenderCommand, DrawFunctions, PhaseItem, PhaseItemExtraIndex, RenderCommand,
34            RenderCommandResult, SetItemPipeline, TrackedRenderPass, ViewSortedRenderPhases,
35        },
36        render_resource::*,
37        renderer::RenderDevice,
38        sync_component::SyncComponent,
39        sync_world::MainEntity,
40        view::{ExtractedView, NoIndirectDrawing},
41        Render, RenderApp, RenderStartup, RenderSystems,
42    },
43};
44use bytemuck::{Pod, Zeroable};
45
46/// This example uses a shader source file from the assets subdirectory
47const SHADER_ASSET_PATH: &str = "shaders/instancing.wgsl";
48
49fn main() {
50    App::new()
51        .add_plugins((DefaultPlugins, CustomMaterialPlugin))
52        .add_systems(Startup, setup)
53        .run();
54}
55
56fn setup(mut commands: Commands, mut meshes: ResMut<Assets<Mesh>>) {
57    commands.spawn((
58        Mesh3d(meshes.add(Cuboid::new(0.5, 0.5, 0.5))),
59        InstanceMaterialData(
60            (1..=10)
61                .flat_map(|x| (1..=10).map(move |y| (x as f32 / 10.0, y as f32 / 10.0)))
62                .map(|(x, y)| InstanceData {
63                    position: Vec3::new(x * 10.0 - 5.0, y * 10.0 - 5.0, 0.0),
64                    scale: 1.0,
65                    color: LinearRgba::from(Color::hsla(x * 360., y, 0.5, 1.0)).to_f32_array(),
66                })
67                .collect(),
68        ),
69        // NOTE: Frustum culling is done based on the Aabb of the Mesh and the GlobalTransform.
70        // As the cube is at the origin, if its Aabb moves outside the view frustum, all the
71        // instanced cubes will be culled.
72        // The InstanceMaterialData contains the 'GlobalTransform' information for this custom
73        // instancing, and that is not taken into account with the built-in frustum culling.
74        // We must disable the built-in frustum culling by adding the `NoFrustumCulling` marker
75        // component to avoid incorrect culling.
76        NoFrustumCulling,
77    ));
78
79    // camera
80    commands.spawn((
81        Camera3d::default(),
82        Transform::from_xyz(0.0, 0.0, 15.0).looking_at(Vec3::ZERO, Vec3::Y),
83        // We need this component because we use `draw_indexed` and `draw`
84        // instead of `draw_indirect_indexed` and `draw_indirect` in
85        // `DrawMeshInstanced::render`.
86        NoIndirectDrawing,
87    ));
88}
89
90#[derive(Component, Deref)]
91struct InstanceMaterialData(Vec<InstanceData>);
92
93impl SyncComponent for InstanceMaterialData {
94    type Target = Self;
95}
96
97impl ExtractComponent for InstanceMaterialData {
98    type QueryData = &'static InstanceMaterialData;
99    type QueryFilter = ();
100    type Out = Self;
101
102    fn extract_component(item: QueryItem<'_, '_, Self::QueryData>) -> Option<Self> {
103        Some(InstanceMaterialData(item.0.clone()))
104    }
105}
106
107struct CustomMaterialPlugin;
108
109impl Plugin for CustomMaterialPlugin {
110    fn build(&self, app: &mut App) {
111        app.add_plugins(ExtractComponentPlugin::<InstanceMaterialData>::default());
112        app.sub_app_mut(RenderApp)
113            .add_render_command::<Transparent3d, DrawCustom>()
114            .init_resource::<SpecializedMeshPipelines<CustomPipeline>>()
115            .add_systems(
116                RenderStartup,
117                init_custom_pipeline.after(MeshPipelineSystems),
118            )
119            .add_systems(
120                Render,
121                (
122                    queue_custom.in_set(RenderSystems::QueueMeshes),
123                    prepare_instance_buffers.in_set(RenderSystems::PrepareResources),
124                ),
125            );
126    }
127}
128
129#[derive(Clone, Copy, Pod, Zeroable)]
130#[repr(C)]
131struct InstanceData {
132    position: Vec3,
133    scale: f32,
134    color: [f32; 4],
135}
136
137fn queue_custom(
138    transparent_3d_draw_functions: Res<DrawFunctions<Transparent3d>>,
139    custom_pipeline: Res<CustomPipeline>,
140    mut pipelines: ResMut<SpecializedMeshPipelines<CustomPipeline>>,
141    pipeline_cache: Res<PipelineCache>,
142    meshes: Res<RenderAssets<RenderMesh>>,
143    render_mesh_instances: Res<RenderMeshInstances>,
144    maybe_batched_instance_buffers: Option<
145        Res<BatchedInstanceBuffers<MeshUniform, MeshInputUniform>>,
146    >,
147    material_meshes: Query<(Entity, &MainEntity), With<InstanceMaterialData>>,
148    mut transparent_render_phases: ResMut<ViewSortedRenderPhases<Transparent3d>>,
149    views: Query<&ExtractedView>,
150    view_key_cache: Res<ViewKeyCache>,
151) {
152    let draw_custom = transparent_3d_draw_functions.read().id::<DrawCustom>();
153
154    for view in &views {
155        let Some(transparent_phase) = transparent_render_phases.get_mut(&view.retained_view_entity)
156        else {
157            continue;
158        };
159
160        let Some(&view_key) = view_key_cache.get(&view.retained_view_entity) else {
161            continue;
162        };
163
164        for (entity, main_entity) in &material_meshes {
165            let Some(mesh_instance) = render_mesh_instances.render_mesh_queue_data(*main_entity)
166            else {
167                continue;
168            };
169            let Some(mesh) = meshes.get(mesh_instance.mesh_asset_id()) else {
170                continue;
171            };
172            let key = view_key
173                | MeshPipelineKey::from_primitive_topology_and_strip_index(
174                    mesh.primitive_topology(),
175                    mesh.index_format(),
176                );
177            let pipeline = pipelines
178                .specialize(&pipeline_cache, &custom_pipeline, key, &mesh.layout)
179                .unwrap();
180            transparent_phase.add_retained(Transparent3d {
181                sorting_info: TransparentSortingInfo3d::Sorted {
182                    mesh_center: pbr::get_mesh_instance_world_from_local(
183                        *main_entity,
184                        mesh_instance.current_uniform_index,
185                        &render_mesh_instances,
186                        maybe_batched_instance_buffers.as_deref(),
187                    )
188                    .transform_point3(
189                        meshes
190                            .get(mesh_instance.mesh_asset_id())
191                            .unwrap()
192                            .aabb_center,
193                    ),
194                    depth_bias: 0.0,
195                },
196                entity: (entity, *main_entity),
197                pipeline,
198                draw_function: draw_custom,
199                distance: 0.0,
200                batch_range: 0..1,
201                extra_index: PhaseItemExtraIndex::None,
202                indexed: true,
203            });
204        }
205    }
206}
207
208#[derive(Component)]
209struct InstanceBuffer {
210    buffer: Buffer,
211    length: usize,
212}
213
214fn prepare_instance_buffers(
215    mut commands: Commands,
216    query: Query<(Entity, &InstanceMaterialData)>,
217    render_device: Res<RenderDevice>,
218) {
219    for (entity, instance_data) in &query {
220        let buffer = render_device.create_buffer_with_data(&BufferInitDescriptor {
221            label: Some("instance data buffer"),
222            contents: bytemuck::cast_slice(instance_data.as_slice()),
223            usage: BufferUsages::VERTEX | BufferUsages::COPY_DST,
224        });
225        commands.entity(entity).insert(InstanceBuffer {
226            buffer,
227            length: instance_data.len(),
228        });
229    }
230}
231
232#[derive(Resource)]
233struct CustomPipeline {
234    shader: Handle<Shader>,
235    mesh_pipeline: MeshPipeline,
236}
237
238fn init_custom_pipeline(
239    mut commands: Commands,
240    asset_server: Res<AssetServer>,
241    mesh_pipeline: Res<MeshPipeline>,
242) {
243    commands.insert_resource(CustomPipeline {
244        shader: asset_server.load(SHADER_ASSET_PATH),
245        mesh_pipeline: mesh_pipeline.clone(),
246    });
247}
248
249impl SpecializedMeshPipeline for CustomPipeline {
250    type Key = MeshPipelineKey;
251
252    fn specialize(
253        &self,
254        key: Self::Key,
255        layout: &MeshVertexBufferLayoutRef,
256    ) -> Result<RenderPipelineDescriptor, SpecializedMeshPipelineError> {
257        let mut descriptor = self.mesh_pipeline.specialize(key, layout)?;
258
259        descriptor.vertex.shader = self.shader.clone();
260        descriptor.vertex.buffers.push(VertexBufferLayout {
261            array_stride: size_of::<InstanceData>() as u64,
262            step_mode: VertexStepMode::Instance,
263            attributes: vec![
264                VertexAttribute {
265                    format: VertexFormat::Float32x4,
266                    offset: 0,
267                    shader_location: 3, // shader locations 0-2 are taken up by Position, Normal and UV attributes
268                },
269                VertexAttribute {
270                    format: VertexFormat::Float32x4,
271                    offset: VertexFormat::Float32x4.size(),
272                    shader_location: 4,
273                },
274            ],
275        });
276        descriptor.fragment.as_mut().unwrap().shader = self.shader.clone();
277        Ok(descriptor)
278    }
279}
280
281type DrawCustom = (
282    SetItemPipeline,
283    SetMeshViewBindGroup<0>,
284    SetMeshViewBindingArrayBindGroup<1>,
285    SetMeshBindGroup<2>,
286    DrawMeshInstanced,
287);
288
289struct DrawMeshInstanced;
290
291impl<P: PhaseItem> RenderCommand<P> for DrawMeshInstanced {
292    type Param = (
293        SRes<RenderAssets<RenderMesh>>,
294        SRes<RenderMeshInstances>,
295        SRes<MeshAllocator>,
296    );
297    type ViewQuery = ();
298    type ItemQuery = Read<InstanceBuffer>;
299
300    #[inline]
301    fn render<'w>(
302        item: &P,
303        _view: (),
304        instance_buffer: Option<&'w InstanceBuffer>,
305        (meshes, render_mesh_instances, mesh_allocator): SystemParamItem<'w, '_, Self::Param>,
306        pass: &mut TrackedRenderPass<'w>,
307    ) -> RenderCommandResult {
308        // A borrow check workaround.
309        let mesh_allocator = mesh_allocator.into_inner();
310
311        let Some(mesh_instance) = render_mesh_instances.render_mesh_queue_data(item.main_entity())
312        else {
313            return RenderCommandResult::Skip;
314        };
315        let Some(gpu_mesh) = meshes.into_inner().get(mesh_instance.mesh_asset_id()) else {
316            return RenderCommandResult::Skip;
317        };
318        let Some(instance_buffer) = instance_buffer else {
319            return RenderCommandResult::Skip;
320        };
321        let Some(vertex_buffer_slice) =
322            mesh_allocator.mesh_vertex_slice(&mesh_instance.mesh_asset_id())
323        else {
324            return RenderCommandResult::Skip;
325        };
326
327        pass.set_vertex_buffer(0, vertex_buffer_slice.buffer.slice(..));
328        pass.set_vertex_buffer(1, instance_buffer.buffer.slice(..));
329
330        match &gpu_mesh.buffer_info {
331            RenderMeshBufferInfo::Indexed {
332                index_format,
333                count,
334            } => {
335                let Some(index_buffer_slice) =
336                    mesh_allocator.mesh_index_slice(&mesh_instance.mesh_asset_id())
337                else {
338                    return RenderCommandResult::Skip;
339                };
340
341                pass.set_index_buffer(index_buffer_slice.buffer.slice(..), *index_format);
342                pass.draw_indexed(
343                    index_buffer_slice.range.start..(index_buffer_slice.range.start + count),
344                    vertex_buffer_slice.range.start as i32,
345                    0..instance_buffer.length as u32,
346                );
347            }
348            RenderMeshBufferInfo::NonIndexed => {
349                pass.draw(vertex_buffer_slice.range, 0..instance_buffer.length as u32);
350            }
351        }
352        RenderCommandResult::Success
353    }
354}