custom_shader_instancing/
custom_shader_instancing.rs

//! A shader that renders a mesh multiple times in one draw call.
//!
//! Bevy will automatically batch and instance your meshes, assuming you use the same
//! `Handle<Material>` and `Handle<Mesh>` for all of your instances.
//!
//! This example is intended for advanced users and shows how to make a custom instancing
//! implementation using Bevy's low-level rendering API.
//! It's generally recommended to try the built-in instancing before going with this approach.
9
10use bevy::pbr::SetMeshViewBindingArrayBindGroup;
11use bevy::{
12    camera::visibility::NoFrustumCulling,
13    core_pipeline::core_3d::Transparent3d,
14    ecs::{
15        query::QueryItem,
16        system::{lifetimeless::*, SystemParamItem},
17    },
18    mesh::{MeshVertexBufferLayoutRef, VertexBufferLayout},
19    pbr::{
20        MeshPipeline, MeshPipelineKey, RenderMeshInstances, SetMeshBindGroup, SetMeshViewBindGroup,
21    },
22    prelude::*,
23    render::{
24        extract_component::{ExtractComponent, ExtractComponentPlugin},
25        mesh::{allocator::MeshAllocator, RenderMesh, RenderMeshBufferInfo},
26        render_asset::RenderAssets,
27        render_phase::{
28            AddRenderCommand, DrawFunctions, PhaseItem, PhaseItemExtraIndex, RenderCommand,
29            RenderCommandResult, SetItemPipeline, TrackedRenderPass, ViewSortedRenderPhases,
30        },
31        render_resource::*,
32        renderer::RenderDevice,
33        sync_world::MainEntity,
34        view::{ExtractedView, NoIndirectDrawing},
35        Render, RenderApp, RenderStartup, RenderSystems,
36    },
37};
38use bytemuck::{Pod, Zeroable};
39
/// Asset path of the WGSL shader used for both the vertex and fragment stages of the
/// custom instancing pipeline. The path is resolved by the asset server.
const SHADER_ASSET_PATH: &str = "shaders/instancing.wgsl";
42
43fn main() {
44    App::new()
45        .add_plugins((DefaultPlugins, CustomMaterialPlugin))
46        .add_systems(Startup, setup)
47        .run();
48}
49
50fn setup(mut commands: Commands, mut meshes: ResMut<Assets<Mesh>>) {
51    commands.spawn((
52        Mesh3d(meshes.add(Cuboid::new(0.5, 0.5, 0.5))),
53        InstanceMaterialData(
54            (1..=10)
55                .flat_map(|x| (1..=10).map(move |y| (x as f32 / 10.0, y as f32 / 10.0)))
56                .map(|(x, y)| InstanceData {
57                    position: Vec3::new(x * 10.0 - 5.0, y * 10.0 - 5.0, 0.0),
58                    scale: 1.0,
59                    color: LinearRgba::from(Color::hsla(x * 360., y, 0.5, 1.0)).to_f32_array(),
60                })
61                .collect(),
62        ),
63        // NOTE: Frustum culling is done based on the Aabb of the Mesh and the GlobalTransform.
64        // As the cube is at the origin, if its Aabb moves outside the view frustum, all the
65        // instanced cubes will be culled.
66        // The InstanceMaterialData contains the 'GlobalTransform' information for this custom
67        // instancing, and that is not taken into account with the built-in frustum culling.
68        // We must disable the built-in frustum culling by adding the `NoFrustumCulling` marker
69        // component to avoid incorrect culling.
70        NoFrustumCulling,
71    ));
72
73    // camera
74    commands.spawn((
75        Camera3d::default(),
76        Transform::from_xyz(0.0, 0.0, 15.0).looking_at(Vec3::ZERO, Vec3::Y),
77        // We need this component because we use `draw_indexed` and `draw`
78        // instead of `draw_indirect_indexed` and `draw_indirect` in
79        // `DrawMeshInstanced::render`.
80        NoIndirectDrawing,
81    ));
82}
83
84#[derive(Component, Deref)]
85struct InstanceMaterialData(Vec<InstanceData>);
86
87impl ExtractComponent for InstanceMaterialData {
88    type QueryData = &'static InstanceMaterialData;
89    type QueryFilter = ();
90    type Out = Self;
91
92    fn extract_component(item: QueryItem<'_, '_, Self::QueryData>) -> Option<Self> {
93        Some(InstanceMaterialData(item.0.clone()))
94    }
95}
96
97struct CustomMaterialPlugin;
98
99impl Plugin for CustomMaterialPlugin {
100    fn build(&self, app: &mut App) {
101        app.add_plugins(ExtractComponentPlugin::<InstanceMaterialData>::default());
102        app.sub_app_mut(RenderApp)
103            .add_render_command::<Transparent3d, DrawCustom>()
104            .init_resource::<SpecializedMeshPipelines<CustomPipeline>>()
105            .add_systems(RenderStartup, init_custom_pipeline)
106            .add_systems(
107                Render,
108                (
109                    queue_custom.in_set(RenderSystems::QueueMeshes),
110                    prepare_instance_buffers.in_set(RenderSystems::PrepareResources),
111                ),
112            );
113    }
114}
115
116#[derive(Clone, Copy, Pod, Zeroable)]
117#[repr(C)]
118struct InstanceData {
119    position: Vec3,
120    scale: f32,
121    color: [f32; 4],
122}
123
124fn queue_custom(
125    transparent_3d_draw_functions: Res<DrawFunctions<Transparent3d>>,
126    custom_pipeline: Res<CustomPipeline>,
127    mut pipelines: ResMut<SpecializedMeshPipelines<CustomPipeline>>,
128    pipeline_cache: Res<PipelineCache>,
129    meshes: Res<RenderAssets<RenderMesh>>,
130    render_mesh_instances: Res<RenderMeshInstances>,
131    material_meshes: Query<(Entity, &MainEntity), With<InstanceMaterialData>>,
132    mut transparent_render_phases: ResMut<ViewSortedRenderPhases<Transparent3d>>,
133    views: Query<(&ExtractedView, &Msaa)>,
134) {
135    let draw_custom = transparent_3d_draw_functions.read().id::<DrawCustom>();
136
137    for (view, msaa) in &views {
138        let Some(transparent_phase) = transparent_render_phases.get_mut(&view.retained_view_entity)
139        else {
140            continue;
141        };
142
143        let msaa_key = MeshPipelineKey::from_msaa_samples(msaa.samples());
144
145        let view_key = msaa_key | MeshPipelineKey::from_hdr(view.hdr);
146        let rangefinder = view.rangefinder3d();
147        for (entity, main_entity) in &material_meshes {
148            let Some(mesh_instance) = render_mesh_instances.render_mesh_queue_data(*main_entity)
149            else {
150                continue;
151            };
152            let Some(mesh) = meshes.get(mesh_instance.mesh_asset_id) else {
153                continue;
154            };
155            let key =
156                view_key | MeshPipelineKey::from_primitive_topology(mesh.primitive_topology());
157            let pipeline = pipelines
158                .specialize(&pipeline_cache, &custom_pipeline, key, &mesh.layout)
159                .unwrap();
160            transparent_phase.add(Transparent3d {
161                entity: (entity, *main_entity),
162                pipeline,
163                draw_function: draw_custom,
164                distance: rangefinder.distance_translation(&mesh_instance.translation),
165                batch_range: 0..1,
166                extra_index: PhaseItemExtraIndex::None,
167                indexed: true,
168            });
169        }
170    }
171}
172
173#[derive(Component)]
174struct InstanceBuffer {
175    buffer: Buffer,
176    length: usize,
177}
178
179fn prepare_instance_buffers(
180    mut commands: Commands,
181    query: Query<(Entity, &InstanceMaterialData)>,
182    render_device: Res<RenderDevice>,
183) {
184    for (entity, instance_data) in &query {
185        let buffer = render_device.create_buffer_with_data(&BufferInitDescriptor {
186            label: Some("instance data buffer"),
187            contents: bytemuck::cast_slice(instance_data.as_slice()),
188            usage: BufferUsages::VERTEX | BufferUsages::COPY_DST,
189        });
190        commands.entity(entity).insert(InstanceBuffer {
191            buffer,
192            length: instance_data.len(),
193        });
194    }
195}
196
197#[derive(Resource)]
198struct CustomPipeline {
199    shader: Handle<Shader>,
200    mesh_pipeline: MeshPipeline,
201}
202
203fn init_custom_pipeline(
204    mut commands: Commands,
205    asset_server: Res<AssetServer>,
206    mesh_pipeline: Res<MeshPipeline>,
207) {
208    commands.insert_resource(CustomPipeline {
209        shader: asset_server.load(SHADER_ASSET_PATH),
210        mesh_pipeline: mesh_pipeline.clone(),
211    });
212}
213
214impl SpecializedMeshPipeline for CustomPipeline {
215    type Key = MeshPipelineKey;
216
217    fn specialize(
218        &self,
219        key: Self::Key,
220        layout: &MeshVertexBufferLayoutRef,
221    ) -> Result<RenderPipelineDescriptor, SpecializedMeshPipelineError> {
222        let mut descriptor = self.mesh_pipeline.specialize(key, layout)?;
223
224        descriptor.vertex.shader = self.shader.clone();
225        descriptor.vertex.buffers.push(VertexBufferLayout {
226            array_stride: size_of::<InstanceData>() as u64,
227            step_mode: VertexStepMode::Instance,
228            attributes: vec![
229                VertexAttribute {
230                    format: VertexFormat::Float32x4,
231                    offset: 0,
232                    shader_location: 3, // shader locations 0-2 are taken up by Position, Normal and UV attributes
233                },
234                VertexAttribute {
235                    format: VertexFormat::Float32x4,
236                    offset: VertexFormat::Float32x4.size(),
237                    shader_location: 4,
238                },
239            ],
240        });
241        descriptor.fragment.as_mut().unwrap().shader = self.shader.clone();
242        Ok(descriptor)
243    }
244}
245
246type DrawCustom = (
247    SetItemPipeline,
248    SetMeshViewBindGroup<0>,
249    SetMeshViewBindingArrayBindGroup<1>,
250    SetMeshBindGroup<2>,
251    DrawMeshInstanced,
252);
253
254struct DrawMeshInstanced;
255
256impl<P: PhaseItem> RenderCommand<P> for DrawMeshInstanced {
257    type Param = (
258        SRes<RenderAssets<RenderMesh>>,
259        SRes<RenderMeshInstances>,
260        SRes<MeshAllocator>,
261    );
262    type ViewQuery = ();
263    type ItemQuery = Read<InstanceBuffer>;
264
265    #[inline]
266    fn render<'w>(
267        item: &P,
268        _view: (),
269        instance_buffer: Option<&'w InstanceBuffer>,
270        (meshes, render_mesh_instances, mesh_allocator): SystemParamItem<'w, '_, Self::Param>,
271        pass: &mut TrackedRenderPass<'w>,
272    ) -> RenderCommandResult {
273        // A borrow check workaround.
274        let mesh_allocator = mesh_allocator.into_inner();
275
276        let Some(mesh_instance) = render_mesh_instances.render_mesh_queue_data(item.main_entity())
277        else {
278            return RenderCommandResult::Skip;
279        };
280        let Some(gpu_mesh) = meshes.into_inner().get(mesh_instance.mesh_asset_id) else {
281            return RenderCommandResult::Skip;
282        };
283        let Some(instance_buffer) = instance_buffer else {
284            return RenderCommandResult::Skip;
285        };
286        let Some(vertex_buffer_slice) =
287            mesh_allocator.mesh_vertex_slice(&mesh_instance.mesh_asset_id)
288        else {
289            return RenderCommandResult::Skip;
290        };
291
292        pass.set_vertex_buffer(0, vertex_buffer_slice.buffer.slice(..));
293        pass.set_vertex_buffer(1, instance_buffer.buffer.slice(..));
294
295        match &gpu_mesh.buffer_info {
296            RenderMeshBufferInfo::Indexed {
297                index_format,
298                count,
299            } => {
300                let Some(index_buffer_slice) =
301                    mesh_allocator.mesh_index_slice(&mesh_instance.mesh_asset_id)
302                else {
303                    return RenderCommandResult::Skip;
304                };
305
306                pass.set_index_buffer(index_buffer_slice.buffer.slice(..), 0, *index_format);
307                pass.draw_indexed(
308                    index_buffer_slice.range.start..(index_buffer_slice.range.start + count),
309                    vertex_buffer_slice.range.start as i32,
310                    0..instance_buffer.length as u32,
311                );
312            }
313            RenderMeshBufferInfo::NonIndexed => {
314                pass.draw(vertex_buffer_slice.range, 0..instance_buffer.length as u32);
315            }
316        }
317        RenderCommandResult::Success
318    }
319}