//! custom_shader_instancing: renders one mesh many times in a single draw call by feeding a
//! custom per-instance vertex buffer to a specialized mesh pipeline.

use std::mem::offset_of;

use bevy::core_pipeline::core_3d::TransparentSortingInfo3d;
use bevy::pbr::{
    self, MeshInputUniform, MeshPipelineSystems, MeshUniform, SetMeshViewBindingArrayBindGroup,
    ViewKeyCache,
};
use bevy::{
    camera::visibility::NoFrustumCulling,
    core_pipeline::core_3d::Transparent3d,
    ecs::{
        query::QueryItem,
        system::{lifetimeless::*, SystemParamItem},
    },
    mesh::{MeshVertexBufferLayoutRef, VertexBufferLayout},
    pbr::{
        MeshPipeline, MeshPipelineKey, RenderMeshInstances, SetMeshBindGroup, SetMeshViewBindGroup,
    },
    prelude::*,
    render::{
        batching::gpu_preprocessing::BatchedInstanceBuffers,
        extract_component::{ExtractComponent, ExtractComponentPlugin},
        mesh::{allocator::MeshAllocator, RenderMesh, RenderMeshBufferInfo},
        render_asset::RenderAssets,
        render_phase::{
            AddRenderCommand, DrawFunctions, PhaseItem, PhaseItemExtraIndex, RenderCommand,
            RenderCommandResult, SetItemPipeline, TrackedRenderPass, ViewSortedRenderPhases,
        },
        render_resource::*,
        renderer::RenderDevice,
        sync_component::SyncComponent,
        sync_world::MainEntity,
        view::{ExtractedView, NoIndirectDrawing},
        Render, RenderApp, RenderStartup, RenderSystems,
    },
};
use bytemuck::{Pod, Zeroable};

46const SHADER_ASSET_PATH: &str = "shaders/instancing.wgsl";
48
49fn main() {
50 App::new()
51 .add_plugins((DefaultPlugins, CustomMaterialPlugin))
52 .add_systems(Startup, setup)
53 .run();
54}
55
56fn setup(mut commands: Commands, mut meshes: ResMut<Assets<Mesh>>) {
57 commands.spawn((
58 Mesh3d(meshes.add(Cuboid::new(0.5, 0.5, 0.5))),
59 InstanceMaterialData(
60 (1..=10)
61 .flat_map(|x| (1..=10).map(move |y| (x as f32 / 10.0, y as f32 / 10.0)))
62 .map(|(x, y)| InstanceData {
63 position: Vec3::new(x * 10.0 - 5.0, y * 10.0 - 5.0, 0.0),
64 scale: 1.0,
65 color: LinearRgba::from(Color::hsla(x * 360., y, 0.5, 1.0)).to_f32_array(),
66 })
67 .collect(),
68 ),
69 NoFrustumCulling,
77 ));
78
79 commands.spawn((
81 Camera3d::default(),
82 Transform::from_xyz(0.0, 0.0, 15.0).looking_at(Vec3::ZERO, Vec3::Y),
83 NoIndirectDrawing,
87 ));
88}
89
90#[derive(Component, Deref)]
91struct InstanceMaterialData(Vec<InstanceData>);
92
93impl SyncComponent for InstanceMaterialData {
94 type Target = Self;
95}
96
97impl ExtractComponent for InstanceMaterialData {
98 type QueryData = &'static InstanceMaterialData;
99 type QueryFilter = ();
100 type Out = Self;
101
102 fn extract_component(item: QueryItem<'_, '_, Self::QueryData>) -> Option<Self> {
103 Some(InstanceMaterialData(item.0.clone()))
104 }
105}
106
107struct CustomMaterialPlugin;
108
109impl Plugin for CustomMaterialPlugin {
110 fn build(&self, app: &mut App) {
111 app.add_plugins(ExtractComponentPlugin::<InstanceMaterialData>::default());
112 app.sub_app_mut(RenderApp)
113 .add_render_command::<Transparent3d, DrawCustom>()
114 .init_resource::<SpecializedMeshPipelines<CustomPipeline>>()
115 .add_systems(
116 RenderStartup,
117 init_custom_pipeline.after(MeshPipelineSystems),
118 )
119 .add_systems(
120 Render,
121 (
122 queue_custom.in_set(RenderSystems::QueueMeshes),
123 prepare_instance_buffers.in_set(RenderSystems::PrepareResources),
124 ),
125 );
126 }
127}
128
129#[derive(Clone, Copy, Pod, Zeroable)]
130#[repr(C)]
131struct InstanceData {
132 position: Vec3,
133 scale: f32,
134 color: [f32; 4],
135}
136
137fn queue_custom(
138 transparent_3d_draw_functions: Res<DrawFunctions<Transparent3d>>,
139 custom_pipeline: Res<CustomPipeline>,
140 mut pipelines: ResMut<SpecializedMeshPipelines<CustomPipeline>>,
141 pipeline_cache: Res<PipelineCache>,
142 meshes: Res<RenderAssets<RenderMesh>>,
143 render_mesh_instances: Res<RenderMeshInstances>,
144 maybe_batched_instance_buffers: Option<
145 Res<BatchedInstanceBuffers<MeshUniform, MeshInputUniform>>,
146 >,
147 material_meshes: Query<(Entity, &MainEntity), With<InstanceMaterialData>>,
148 mut transparent_render_phases: ResMut<ViewSortedRenderPhases<Transparent3d>>,
149 views: Query<&ExtractedView>,
150 view_key_cache: Res<ViewKeyCache>,
151) {
152 let draw_custom = transparent_3d_draw_functions.read().id::<DrawCustom>();
153
154 for view in &views {
155 let Some(transparent_phase) = transparent_render_phases.get_mut(&view.retained_view_entity)
156 else {
157 continue;
158 };
159
160 let Some(&view_key) = view_key_cache.get(&view.retained_view_entity) else {
161 continue;
162 };
163
164 for (entity, main_entity) in &material_meshes {
165 let Some(mesh_instance) = render_mesh_instances.render_mesh_queue_data(*main_entity)
166 else {
167 continue;
168 };
169 let Some(mesh) = meshes.get(mesh_instance.mesh_asset_id()) else {
170 continue;
171 };
172 let key = view_key
173 | MeshPipelineKey::from_primitive_topology_and_strip_index(
174 mesh.primitive_topology(),
175 mesh.index_format(),
176 );
177 let pipeline = pipelines
178 .specialize(&pipeline_cache, &custom_pipeline, key, &mesh.layout)
179 .unwrap();
180 transparent_phase.add_retained(Transparent3d {
181 sorting_info: TransparentSortingInfo3d::Sorted {
182 mesh_center: pbr::get_mesh_instance_world_from_local(
183 *main_entity,
184 mesh_instance.current_uniform_index,
185 &render_mesh_instances,
186 maybe_batched_instance_buffers.as_deref(),
187 )
188 .transform_point3(
189 meshes
190 .get(mesh_instance.mesh_asset_id())
191 .unwrap()
192 .aabb_center,
193 ),
194 depth_bias: 0.0,
195 },
196 entity: (entity, *main_entity),
197 pipeline,
198 draw_function: draw_custom,
199 distance: 0.0,
200 batch_range: 0..1,
201 extra_index: PhaseItemExtraIndex::None,
202 indexed: true,
203 });
204 }
205 }
206}
207
208#[derive(Component)]
209struct InstanceBuffer {
210 buffer: Buffer,
211 length: usize,
212}
213
214fn prepare_instance_buffers(
215 mut commands: Commands,
216 query: Query<(Entity, &InstanceMaterialData)>,
217 render_device: Res<RenderDevice>,
218) {
219 for (entity, instance_data) in &query {
220 let buffer = render_device.create_buffer_with_data(&BufferInitDescriptor {
221 label: Some("instance data buffer"),
222 contents: bytemuck::cast_slice(instance_data.as_slice()),
223 usage: BufferUsages::VERTEX | BufferUsages::COPY_DST,
224 });
225 commands.entity(entity).insert(InstanceBuffer {
226 buffer,
227 length: instance_data.len(),
228 });
229 }
230}
231
232#[derive(Resource)]
233struct CustomPipeline {
234 shader: Handle<Shader>,
235 mesh_pipeline: MeshPipeline,
236}
237
238fn init_custom_pipeline(
239 mut commands: Commands,
240 asset_server: Res<AssetServer>,
241 mesh_pipeline: Res<MeshPipeline>,
242) {
243 commands.insert_resource(CustomPipeline {
244 shader: asset_server.load(SHADER_ASSET_PATH),
245 mesh_pipeline: mesh_pipeline.clone(),
246 });
247}
248
249impl SpecializedMeshPipeline for CustomPipeline {
250 type Key = MeshPipelineKey;
251
252 fn specialize(
253 &self,
254 key: Self::Key,
255 layout: &MeshVertexBufferLayoutRef,
256 ) -> Result<RenderPipelineDescriptor, SpecializedMeshPipelineError> {
257 let mut descriptor = self.mesh_pipeline.specialize(key, layout)?;
258
259 descriptor.vertex.shader = self.shader.clone();
260 descriptor.vertex.buffers.push(VertexBufferLayout {
261 array_stride: size_of::<InstanceData>() as u64,
262 step_mode: VertexStepMode::Instance,
263 attributes: vec![
264 VertexAttribute {
265 format: VertexFormat::Float32x4,
266 offset: 0,
267 shader_location: 3, },
269 VertexAttribute {
270 format: VertexFormat::Float32x4,
271 offset: VertexFormat::Float32x4.size(),
272 shader_location: 4,
273 },
274 ],
275 });
276 descriptor.fragment.as_mut().unwrap().shader = self.shader.clone();
277 Ok(descriptor)
278 }
279}
280
281type DrawCustom = (
282 SetItemPipeline,
283 SetMeshViewBindGroup<0>,
284 SetMeshViewBindingArrayBindGroup<1>,
285 SetMeshBindGroup<2>,
286 DrawMeshInstanced,
287);
288
289struct DrawMeshInstanced;
290
291impl<P: PhaseItem> RenderCommand<P> for DrawMeshInstanced {
292 type Param = (
293 SRes<RenderAssets<RenderMesh>>,
294 SRes<RenderMeshInstances>,
295 SRes<MeshAllocator>,
296 );
297 type ViewQuery = ();
298 type ItemQuery = Read<InstanceBuffer>;
299
300 #[inline]
301 fn render<'w>(
302 item: &P,
303 _view: (),
304 instance_buffer: Option<&'w InstanceBuffer>,
305 (meshes, render_mesh_instances, mesh_allocator): SystemParamItem<'w, '_, Self::Param>,
306 pass: &mut TrackedRenderPass<'w>,
307 ) -> RenderCommandResult {
308 let mesh_allocator = mesh_allocator.into_inner();
310
311 let Some(mesh_instance) = render_mesh_instances.render_mesh_queue_data(item.main_entity())
312 else {
313 return RenderCommandResult::Skip;
314 };
315 let Some(gpu_mesh) = meshes.into_inner().get(mesh_instance.mesh_asset_id()) else {
316 return RenderCommandResult::Skip;
317 };
318 let Some(instance_buffer) = instance_buffer else {
319 return RenderCommandResult::Skip;
320 };
321 let Some(vertex_buffer_slice) =
322 mesh_allocator.mesh_vertex_slice(&mesh_instance.mesh_asset_id())
323 else {
324 return RenderCommandResult::Skip;
325 };
326
327 pass.set_vertex_buffer(0, vertex_buffer_slice.buffer.slice(..));
328 pass.set_vertex_buffer(1, instance_buffer.buffer.slice(..));
329
330 match &gpu_mesh.buffer_info {
331 RenderMeshBufferInfo::Indexed {
332 index_format,
333 count,
334 } => {
335 let Some(index_buffer_slice) =
336 mesh_allocator.mesh_index_slice(&mesh_instance.mesh_asset_id())
337 else {
338 return RenderCommandResult::Skip;
339 };
340
341 pass.set_index_buffer(index_buffer_slice.buffer.slice(..), *index_format);
342 pass.draw_indexed(
343 index_buffer_slice.range.start..(index_buffer_slice.range.start + count),
344 vertex_buffer_slice.range.start as i32,
345 0..instance_buffer.length as u32,
346 );
347 }
348 RenderMeshBufferInfo::NonIndexed => {
349 pass.draw(vertex_buffer_slice.range, 0..instance_buffer.length as u32);
350 }
351 }
352 RenderCommandResult::Success
353 }
354}