custom_shader_instancing/
custom_shader_instancing.rs

1//! A shader that renders a mesh multiple times in one draw call.
2//!
3//! Bevy will automatically batch and instance your meshes assuming you use the same
4//! `Handle<Material>` and `Handle<Mesh>` for all of your instances.
5//!
6//! This example is intended for advanced users and shows how to make a custom instancing
7//! implementation using bevy's low level rendering api.
8//! It's generally recommended to try the built-in instancing before going with this approach.
9
10use bevy::{
11    core_pipeline::core_3d::Transparent3d,
12    ecs::{
13        query::QueryItem,
14        system::{lifetimeless::*, SystemParamItem},
15    },
16    pbr::{
17        MeshPipeline, MeshPipelineKey, RenderMeshInstances, SetMeshBindGroup, SetMeshViewBindGroup,
18    },
19    prelude::*,
20    render::{
21        extract_component::{ExtractComponent, ExtractComponentPlugin},
22        mesh::{
23            allocator::MeshAllocator, MeshVertexBufferLayoutRef, RenderMesh, RenderMeshBufferInfo,
24        },
25        render_asset::RenderAssets,
26        render_phase::{
27            AddRenderCommand, DrawFunctions, PhaseItem, PhaseItemExtraIndex, RenderCommand,
28            RenderCommandResult, SetItemPipeline, TrackedRenderPass, ViewSortedRenderPhases,
29        },
30        render_resource::*,
31        renderer::RenderDevice,
32        sync_world::MainEntity,
33        view::{ExtractedView, NoFrustumCulling, NoIndirectDrawing},
34        Render, RenderApp, RenderSet,
35    },
36};
37use bytemuck::{Pod, Zeroable};
38
/// Path of the WGSL shader used by this example, relative to the `assets` directory.
/// The same shader handle is installed as both the vertex and fragment stage of
/// `CustomPipeline`.
const SHADER_ASSET_PATH: &str = "shaders/instancing.wgsl";
41
42fn main() {
43    App::new()
44        .add_plugins((DefaultPlugins, CustomMaterialPlugin))
45        .add_systems(Startup, setup)
46        .run();
47}
48
49fn setup(mut commands: Commands, mut meshes: ResMut<Assets<Mesh>>) {
50    commands.spawn((
51        Mesh3d(meshes.add(Cuboid::new(0.5, 0.5, 0.5))),
52        InstanceMaterialData(
53            (1..=10)
54                .flat_map(|x| (1..=10).map(move |y| (x as f32 / 10.0, y as f32 / 10.0)))
55                .map(|(x, y)| InstanceData {
56                    position: Vec3::new(x * 10.0 - 5.0, y * 10.0 - 5.0, 0.0),
57                    scale: 1.0,
58                    color: LinearRgba::from(Color::hsla(x * 360., y, 0.5, 1.0)).to_f32_array(),
59                })
60                .collect(),
61        ),
62        // NOTE: Frustum culling is done based on the Aabb of the Mesh and the GlobalTransform.
63        // As the cube is at the origin, if its Aabb moves outside the view frustum, all the
64        // instanced cubes will be culled.
65        // The InstanceMaterialData contains the 'GlobalTransform' information for this custom
66        // instancing, and that is not taken into account with the built-in frustum culling.
67        // We must disable the built-in frustum culling by adding the `NoFrustumCulling` marker
68        // component to avoid incorrect culling.
69        NoFrustumCulling,
70    ));
71
72    // camera
73    commands.spawn((
74        Camera3d::default(),
75        Transform::from_xyz(0.0, 0.0, 15.0).looking_at(Vec3::ZERO, Vec3::Y),
76        // We need this component because we use `draw_indexed` and `draw`
77        // instead of `draw_indirect_indexed` and `draw_indirect` in
78        // `DrawMeshInstanced::render`.
79        NoIndirectDrawing,
80    ));
81}
82
83#[derive(Component, Deref)]
84struct InstanceMaterialData(Vec<InstanceData>);
85
86impl ExtractComponent for InstanceMaterialData {
87    type QueryData = &'static InstanceMaterialData;
88    type QueryFilter = ();
89    type Out = Self;
90
91    fn extract_component(item: QueryItem<'_, Self::QueryData>) -> Option<Self> {
92        Some(InstanceMaterialData(item.0.clone()))
93    }
94}
95
96struct CustomMaterialPlugin;
97
98impl Plugin for CustomMaterialPlugin {
99    fn build(&self, app: &mut App) {
100        app.add_plugins(ExtractComponentPlugin::<InstanceMaterialData>::default());
101        app.sub_app_mut(RenderApp)
102            .add_render_command::<Transparent3d, DrawCustom>()
103            .init_resource::<SpecializedMeshPipelines<CustomPipeline>>()
104            .add_systems(
105                Render,
106                (
107                    queue_custom.in_set(RenderSet::QueueMeshes),
108                    prepare_instance_buffers.in_set(RenderSet::PrepareResources),
109                ),
110            );
111    }
112
113    fn finish(&self, app: &mut App) {
114        app.sub_app_mut(RenderApp).init_resource::<CustomPipeline>();
115    }
116}
117
/// Per-instance data uploaded to the GPU as a vertex buffer stepped once per instance
/// (see `CustomPipeline::specialize`).
///
/// `#[repr(C)]` fixes the field order and offsets. The instance vertex layout reads
/// `position` and `scale` together as one `Float32x4` at offset 0 (shader location 3),
/// and `color` as a `Float32x4` at offset 16 (shader location 4).
#[derive(Clone, Copy, Pod, Zeroable)]
#[repr(C)]
struct InstanceData {
    // Instance position; packed with `scale` into a single vec4 on the GPU side.
    position: Vec3,
    // Per-instance scale (`setup` uses 1.0); how it is applied is up to the WGSL shader.
    scale: f32,
    // Linear RGBA color as four f32 components (from `LinearRgba::to_f32_array`).
    color: [f32; 4],
}
125
126fn queue_custom(
127    transparent_3d_draw_functions: Res<DrawFunctions<Transparent3d>>,
128    custom_pipeline: Res<CustomPipeline>,
129    mut pipelines: ResMut<SpecializedMeshPipelines<CustomPipeline>>,
130    pipeline_cache: Res<PipelineCache>,
131    meshes: Res<RenderAssets<RenderMesh>>,
132    render_mesh_instances: Res<RenderMeshInstances>,
133    material_meshes: Query<(Entity, &MainEntity), With<InstanceMaterialData>>,
134    mut transparent_render_phases: ResMut<ViewSortedRenderPhases<Transparent3d>>,
135    views: Query<(&ExtractedView, &Msaa)>,
136) {
137    let draw_custom = transparent_3d_draw_functions.read().id::<DrawCustom>();
138
139    for (view, msaa) in &views {
140        let Some(transparent_phase) = transparent_render_phases.get_mut(&view.retained_view_entity)
141        else {
142            continue;
143        };
144
145        let msaa_key = MeshPipelineKey::from_msaa_samples(msaa.samples());
146
147        let view_key = msaa_key | MeshPipelineKey::from_hdr(view.hdr);
148        let rangefinder = view.rangefinder3d();
149        for (entity, main_entity) in &material_meshes {
150            let Some(mesh_instance) = render_mesh_instances.render_mesh_queue_data(*main_entity)
151            else {
152                continue;
153            };
154            let Some(mesh) = meshes.get(mesh_instance.mesh_asset_id) else {
155                continue;
156            };
157            let key =
158                view_key | MeshPipelineKey::from_primitive_topology(mesh.primitive_topology());
159            let pipeline = pipelines
160                .specialize(&pipeline_cache, &custom_pipeline, key, &mesh.layout)
161                .unwrap();
162            transparent_phase.add(Transparent3d {
163                entity: (entity, *main_entity),
164                pipeline,
165                draw_function: draw_custom,
166                distance: rangefinder.distance_translation(&mesh_instance.translation),
167                batch_range: 0..1,
168                extra_index: PhaseItemExtraIndex::None,
169                indexed: true,
170            });
171        }
172    }
173}
174
/// GPU-side copy of an entity's [`InstanceMaterialData`].
///
/// Created by `prepare_instance_buffers` and bound as vertex buffer slot 1 by
/// `DrawMeshInstanced`.
#[derive(Component)]
struct InstanceBuffer {
    // Vertex buffer holding the `InstanceData` array as raw bytes (via `bytemuck::cast_slice`).
    buffer: Buffer,
    // Number of `InstanceData` elements in `buffer`; used as the instance count when drawing.
    length: usize,
}
180
181fn prepare_instance_buffers(
182    mut commands: Commands,
183    query: Query<(Entity, &InstanceMaterialData)>,
184    render_device: Res<RenderDevice>,
185) {
186    for (entity, instance_data) in &query {
187        let buffer = render_device.create_buffer_with_data(&BufferInitDescriptor {
188            label: Some("instance data buffer"),
189            contents: bytemuck::cast_slice(instance_data.as_slice()),
190            usage: BufferUsages::VERTEX | BufferUsages::COPY_DST,
191        });
192        commands.entity(entity).insert(InstanceBuffer {
193            buffer,
194            length: instance_data.len(),
195        });
196    }
197}
198
199#[derive(Resource)]
200struct CustomPipeline {
201    shader: Handle<Shader>,
202    mesh_pipeline: MeshPipeline,
203}
204
205impl FromWorld for CustomPipeline {
206    fn from_world(world: &mut World) -> Self {
207        let mesh_pipeline = world.resource::<MeshPipeline>();
208
209        CustomPipeline {
210            shader: world.load_asset(SHADER_ASSET_PATH),
211            mesh_pipeline: mesh_pipeline.clone(),
212        }
213    }
214}
215
216impl SpecializedMeshPipeline for CustomPipeline {
217    type Key = MeshPipelineKey;
218
219    fn specialize(
220        &self,
221        key: Self::Key,
222        layout: &MeshVertexBufferLayoutRef,
223    ) -> Result<RenderPipelineDescriptor, SpecializedMeshPipelineError> {
224        let mut descriptor = self.mesh_pipeline.specialize(key, layout)?;
225
226        descriptor.vertex.shader = self.shader.clone();
227        descriptor.vertex.buffers.push(VertexBufferLayout {
228            array_stride: size_of::<InstanceData>() as u64,
229            step_mode: VertexStepMode::Instance,
230            attributes: vec![
231                VertexAttribute {
232                    format: VertexFormat::Float32x4,
233                    offset: 0,
234                    shader_location: 3, // shader locations 0-2 are taken up by Position, Normal and UV attributes
235                },
236                VertexAttribute {
237                    format: VertexFormat::Float32x4,
238                    offset: VertexFormat::Float32x4.size(),
239                    shader_location: 4,
240                },
241            ],
242        });
243        descriptor.fragment.as_mut().unwrap().shader = self.shader.clone();
244        Ok(descriptor)
245    }
246}
247
/// Render command sequence for the instanced mesh. It sets the specialized pipeline,
/// binds the mesh view bind group at group 0 and the mesh bind group at group 1, then
/// issues the instanced draw via [`DrawMeshInstanced`].
type DrawCustom = (
    SetItemPipeline,
    SetMeshViewBindGroup<0>,
    SetMeshBindGroup<1>,
    DrawMeshInstanced,
);
254
255struct DrawMeshInstanced;
256
257impl<P: PhaseItem> RenderCommand<P> for DrawMeshInstanced {
258    type Param = (
259        SRes<RenderAssets<RenderMesh>>,
260        SRes<RenderMeshInstances>,
261        SRes<MeshAllocator>,
262    );
263    type ViewQuery = ();
264    type ItemQuery = Read<InstanceBuffer>;
265
266    #[inline]
267    fn render<'w>(
268        item: &P,
269        _view: (),
270        instance_buffer: Option<&'w InstanceBuffer>,
271        (meshes, render_mesh_instances, mesh_allocator): SystemParamItem<'w, '_, Self::Param>,
272        pass: &mut TrackedRenderPass<'w>,
273    ) -> RenderCommandResult {
274        // A borrow check workaround.
275        let mesh_allocator = mesh_allocator.into_inner();
276
277        let Some(mesh_instance) = render_mesh_instances.render_mesh_queue_data(item.main_entity())
278        else {
279            return RenderCommandResult::Skip;
280        };
281        let Some(gpu_mesh) = meshes.into_inner().get(mesh_instance.mesh_asset_id) else {
282            return RenderCommandResult::Skip;
283        };
284        let Some(instance_buffer) = instance_buffer else {
285            return RenderCommandResult::Skip;
286        };
287        let Some(vertex_buffer_slice) =
288            mesh_allocator.mesh_vertex_slice(&mesh_instance.mesh_asset_id)
289        else {
290            return RenderCommandResult::Skip;
291        };
292
293        pass.set_vertex_buffer(0, vertex_buffer_slice.buffer.slice(..));
294        pass.set_vertex_buffer(1, instance_buffer.buffer.slice(..));
295
296        match &gpu_mesh.buffer_info {
297            RenderMeshBufferInfo::Indexed {
298                index_format,
299                count,
300            } => {
301                let Some(index_buffer_slice) =
302                    mesh_allocator.mesh_index_slice(&mesh_instance.mesh_asset_id)
303                else {
304                    return RenderCommandResult::Skip;
305                };
306
307                pass.set_index_buffer(index_buffer_slice.buffer.slice(..), 0, *index_format);
308                pass.draw_indexed(
309                    index_buffer_slice.range.start..(index_buffer_slice.range.start + count),
310                    vertex_buffer_slice.range.start as i32,
311                    0..instance_buffer.length as u32,
312                );
313            }
314            RenderMeshBufferInfo::NonIndexed => {
315                pass.draw(vertex_buffer_slice.range, 0..instance_buffer.length as u32);
316            }
317        }
318        RenderCommandResult::Success
319    }
320}