use bevy::{
    core_pipeline::core_3d::Transparent3d,
    ecs::{
        query::QueryItem,
        system::{lifetimeless::*, SystemParamItem},
    },
    pbr::{
        MeshPipeline, MeshPipelineKey, RenderMeshInstances, SetMeshBindGroup, SetMeshViewBindGroup,
    },
    prelude::*,
    render::{
        extract_component::{ExtractComponent, ExtractComponentPlugin},
        mesh::{
            allocator::MeshAllocator, MeshVertexBufferLayoutRef, RenderMesh, RenderMeshBufferInfo,
        },
        render_asset::RenderAssets,
        render_phase::{
            AddRenderCommand, DrawFunctions, PhaseItem, PhaseItemExtraIndex, RenderCommand,
            RenderCommandResult, SetItemPipeline, TrackedRenderPass, ViewSortedRenderPhases,
        },
        render_resource::*,
        renderer::RenderDevice,
        sync_world::MainEntity,
        view::{ExtractedView, NoFrustumCulling, NoIndirectDrawing},
        Render, RenderApp, RenderSet,
    },
};
use bytemuck::{Pod, Zeroable};

/// Asset path of the WGSL file that `CustomPipeline` loads through the asset server and
/// uses for both the vertex and the fragment stage of the instancing pipeline.
const SHADER_ASSET_PATH: &str = "shaders/instancing.wgsl";

42fn main() {
43 App::new()
44 .add_plugins((DefaultPlugins, CustomMaterialPlugin))
45 .add_systems(Startup, setup)
46 .run();
47}
48
49fn setup(mut commands: Commands, mut meshes: ResMut<Assets<Mesh>>) {
50 commands.spawn((
51 Mesh3d(meshes.add(Cuboid::new(0.5, 0.5, 0.5))),
52 InstanceMaterialData(
53 (1..=10)
54 .flat_map(|x| (1..=10).map(move |y| (x as f32 / 10.0, y as f32 / 10.0)))
55 .map(|(x, y)| InstanceData {
56 position: Vec3::new(x * 10.0 - 5.0, y * 10.0 - 5.0, 0.0),
57 scale: 1.0,
58 color: LinearRgba::from(Color::hsla(x * 360., y, 0.5, 1.0)).to_f32_array(),
59 })
60 .collect(),
61 ),
62 NoFrustumCulling,
70 ));
71
72 commands.spawn((
74 Camera3d::default(),
75 Transform::from_xyz(0.0, 0.0, 15.0).looking_at(Vec3::ZERO, Vec3::Y),
76 NoIndirectDrawing,
80 ));
81}
82
83#[derive(Component, Deref)]
84struct InstanceMaterialData(Vec<InstanceData>);
85
86impl ExtractComponent for InstanceMaterialData {
87 type QueryData = &'static InstanceMaterialData;
88 type QueryFilter = ();
89 type Out = Self;
90
91 fn extract_component(item: QueryItem<'_, Self::QueryData>) -> Option<Self> {
92 Some(InstanceMaterialData(item.0.clone()))
93 }
94}
95
96struct CustomMaterialPlugin;
97
98impl Plugin for CustomMaterialPlugin {
99 fn build(&self, app: &mut App) {
100 app.add_plugins(ExtractComponentPlugin::<InstanceMaterialData>::default());
101 app.sub_app_mut(RenderApp)
102 .add_render_command::<Transparent3d, DrawCustom>()
103 .init_resource::<SpecializedMeshPipelines<CustomPipeline>>()
104 .add_systems(
105 Render,
106 (
107 queue_custom.in_set(RenderSet::QueueMeshes),
108 prepare_instance_buffers.in_set(RenderSet::PrepareResources),
109 ),
110 );
111 }
112
113 fn finish(&self, app: &mut App) {
114 app.sub_app_mut(RenderApp).init_resource::<CustomPipeline>();
115 }
116}
117
118#[derive(Clone, Copy, Pod, Zeroable)]
119#[repr(C)]
120struct InstanceData {
121 position: Vec3,
122 scale: f32,
123 color: [f32; 4],
124}
125
126fn queue_custom(
127 transparent_3d_draw_functions: Res<DrawFunctions<Transparent3d>>,
128 custom_pipeline: Res<CustomPipeline>,
129 mut pipelines: ResMut<SpecializedMeshPipelines<CustomPipeline>>,
130 pipeline_cache: Res<PipelineCache>,
131 meshes: Res<RenderAssets<RenderMesh>>,
132 render_mesh_instances: Res<RenderMeshInstances>,
133 material_meshes: Query<(Entity, &MainEntity), With<InstanceMaterialData>>,
134 mut transparent_render_phases: ResMut<ViewSortedRenderPhases<Transparent3d>>,
135 views: Query<(&ExtractedView, &Msaa)>,
136) {
137 let draw_custom = transparent_3d_draw_functions.read().id::<DrawCustom>();
138
139 for (view, msaa) in &views {
140 let Some(transparent_phase) = transparent_render_phases.get_mut(&view.retained_view_entity)
141 else {
142 continue;
143 };
144
145 let msaa_key = MeshPipelineKey::from_msaa_samples(msaa.samples());
146
147 let view_key = msaa_key | MeshPipelineKey::from_hdr(view.hdr);
148 let rangefinder = view.rangefinder3d();
149 for (entity, main_entity) in &material_meshes {
150 let Some(mesh_instance) = render_mesh_instances.render_mesh_queue_data(*main_entity)
151 else {
152 continue;
153 };
154 let Some(mesh) = meshes.get(mesh_instance.mesh_asset_id) else {
155 continue;
156 };
157 let key =
158 view_key | MeshPipelineKey::from_primitive_topology(mesh.primitive_topology());
159 let pipeline = pipelines
160 .specialize(&pipeline_cache, &custom_pipeline, key, &mesh.layout)
161 .unwrap();
162 transparent_phase.add(Transparent3d {
163 entity: (entity, *main_entity),
164 pipeline,
165 draw_function: draw_custom,
166 distance: rangefinder.distance_translation(&mesh_instance.translation),
167 batch_range: 0..1,
168 extra_index: PhaseItemExtraIndex::None,
169 indexed: true,
170 });
171 }
172 }
173}
174
175#[derive(Component)]
176struct InstanceBuffer {
177 buffer: Buffer,
178 length: usize,
179}
180
181fn prepare_instance_buffers(
182 mut commands: Commands,
183 query: Query<(Entity, &InstanceMaterialData)>,
184 render_device: Res<RenderDevice>,
185) {
186 for (entity, instance_data) in &query {
187 let buffer = render_device.create_buffer_with_data(&BufferInitDescriptor {
188 label: Some("instance data buffer"),
189 contents: bytemuck::cast_slice(instance_data.as_slice()),
190 usage: BufferUsages::VERTEX | BufferUsages::COPY_DST,
191 });
192 commands.entity(entity).insert(InstanceBuffer {
193 buffer,
194 length: instance_data.len(),
195 });
196 }
197}
198
199#[derive(Resource)]
200struct CustomPipeline {
201 shader: Handle<Shader>,
202 mesh_pipeline: MeshPipeline,
203}
204
205impl FromWorld for CustomPipeline {
206 fn from_world(world: &mut World) -> Self {
207 let mesh_pipeline = world.resource::<MeshPipeline>();
208
209 CustomPipeline {
210 shader: world.load_asset(SHADER_ASSET_PATH),
211 mesh_pipeline: mesh_pipeline.clone(),
212 }
213 }
214}
215
216impl SpecializedMeshPipeline for CustomPipeline {
217 type Key = MeshPipelineKey;
218
219 fn specialize(
220 &self,
221 key: Self::Key,
222 layout: &MeshVertexBufferLayoutRef,
223 ) -> Result<RenderPipelineDescriptor, SpecializedMeshPipelineError> {
224 let mut descriptor = self.mesh_pipeline.specialize(key, layout)?;
225
226 descriptor.vertex.shader = self.shader.clone();
227 descriptor.vertex.buffers.push(VertexBufferLayout {
228 array_stride: size_of::<InstanceData>() as u64,
229 step_mode: VertexStepMode::Instance,
230 attributes: vec![
231 VertexAttribute {
232 format: VertexFormat::Float32x4,
233 offset: 0,
234 shader_location: 3, },
236 VertexAttribute {
237 format: VertexFormat::Float32x4,
238 offset: VertexFormat::Float32x4.size(),
239 shader_location: 4,
240 },
241 ],
242 });
243 descriptor.fragment.as_mut().unwrap().shader = self.shader.clone();
244 Ok(descriptor)
245 }
246}
247
248type DrawCustom = (
249 SetItemPipeline,
250 SetMeshViewBindGroup<0>,
251 SetMeshBindGroup<1>,
252 DrawMeshInstanced,
253);
254
/// Render command that binds the mesh and instance vertex buffers and draws every instance
/// of the item's mesh in one draw call.
struct DrawMeshInstanced;

257impl<P: PhaseItem> RenderCommand<P> for DrawMeshInstanced {
258 type Param = (
259 SRes<RenderAssets<RenderMesh>>,
260 SRes<RenderMeshInstances>,
261 SRes<MeshAllocator>,
262 );
263 type ViewQuery = ();
264 type ItemQuery = Read<InstanceBuffer>;
265
266 #[inline]
267 fn render<'w>(
268 item: &P,
269 _view: (),
270 instance_buffer: Option<&'w InstanceBuffer>,
271 (meshes, render_mesh_instances, mesh_allocator): SystemParamItem<'w, '_, Self::Param>,
272 pass: &mut TrackedRenderPass<'w>,
273 ) -> RenderCommandResult {
274 let mesh_allocator = mesh_allocator.into_inner();
276
277 let Some(mesh_instance) = render_mesh_instances.render_mesh_queue_data(item.main_entity())
278 else {
279 return RenderCommandResult::Skip;
280 };
281 let Some(gpu_mesh) = meshes.into_inner().get(mesh_instance.mesh_asset_id) else {
282 return RenderCommandResult::Skip;
283 };
284 let Some(instance_buffer) = instance_buffer else {
285 return RenderCommandResult::Skip;
286 };
287 let Some(vertex_buffer_slice) =
288 mesh_allocator.mesh_vertex_slice(&mesh_instance.mesh_asset_id)
289 else {
290 return RenderCommandResult::Skip;
291 };
292
293 pass.set_vertex_buffer(0, vertex_buffer_slice.buffer.slice(..));
294 pass.set_vertex_buffer(1, instance_buffer.buffer.slice(..));
295
296 match &gpu_mesh.buffer_info {
297 RenderMeshBufferInfo::Indexed {
298 index_format,
299 count,
300 } => {
301 let Some(index_buffer_slice) =
302 mesh_allocator.mesh_index_slice(&mesh_instance.mesh_asset_id)
303 else {
304 return RenderCommandResult::Skip;
305 };
306
307 pass.set_index_buffer(index_buffer_slice.buffer.slice(..), 0, *index_format);
308 pass.draw_indexed(
309 index_buffer_slice.range.start..(index_buffer_slice.range.start + count),
310 vertex_buffer_slice.range.start as i32,
311 0..instance_buffer.length as u32,
312 );
313 }
314 RenderMeshBufferInfo::NonIndexed => {
315 pass.draw(vertex_buffer_slice.range, 0..instance_buffer.length as u32);
316 }
317 }
318 RenderCommandResult::Success
319 }
320}