custom_shader_instancing/
custom_shader_instancing.rs1use bevy::pbr::SetMeshViewBindingArrayBindGroup;
11use bevy::{
12 camera::visibility::NoFrustumCulling,
13 core_pipeline::core_3d::Transparent3d,
14 ecs::{
15 query::QueryItem,
16 system::{lifetimeless::*, SystemParamItem},
17 },
18 mesh::{MeshVertexBufferLayoutRef, VertexBufferLayout},
19 pbr::{
20 MeshPipeline, MeshPipelineKey, RenderMeshInstances, SetMeshBindGroup, SetMeshViewBindGroup,
21 },
22 prelude::*,
23 render::{
24 extract_component::{ExtractComponent, ExtractComponentPlugin},
25 mesh::{allocator::MeshAllocator, RenderMesh, RenderMeshBufferInfo},
26 render_asset::RenderAssets,
27 render_phase::{
28 AddRenderCommand, DrawFunctions, PhaseItem, PhaseItemExtraIndex, RenderCommand,
29 RenderCommandResult, SetItemPipeline, TrackedRenderPass, ViewSortedRenderPhases,
30 },
31 render_resource::*,
32 renderer::RenderDevice,
33 sync_world::MainEntity,
34 view::{ExtractedView, NoIndirectDrawing},
35 Render, RenderApp, RenderStartup, RenderSystems,
36 },
37};
38use bytemuck::{Pod, Zeroable};
39
/// Asset path of the WGSL shader used for both the vertex and fragment stages of the
/// instancing pipeline; it is loaded through the `AssetServer` in `init_custom_pipeline`.
const SHADER_ASSET_PATH: &str = "shaders/instancing.wgsl";

43fn main() {
44 App::new()
45 .add_plugins((DefaultPlugins, CustomMaterialPlugin))
46 .add_systems(Startup, setup)
47 .run();
48}
49
50fn setup(mut commands: Commands, mut meshes: ResMut<Assets<Mesh>>) {
51 commands.spawn((
52 Mesh3d(meshes.add(Cuboid::new(0.5, 0.5, 0.5))),
53 InstanceMaterialData(
54 (1..=10)
55 .flat_map(|x| (1..=10).map(move |y| (x as f32 / 10.0, y as f32 / 10.0)))
56 .map(|(x, y)| InstanceData {
57 position: Vec3::new(x * 10.0 - 5.0, y * 10.0 - 5.0, 0.0),
58 scale: 1.0,
59 color: LinearRgba::from(Color::hsla(x * 360., y, 0.5, 1.0)).to_f32_array(),
60 })
61 .collect(),
62 ),
63 NoFrustumCulling,
71 ));
72
73 commands.spawn((
75 Camera3d::default(),
76 Transform::from_xyz(0.0, 0.0, 15.0).looking_at(Vec3::ZERO, Vec3::Y),
77 NoIndirectDrawing,
81 ));
82}
83
84#[derive(Component, Deref)]
85struct InstanceMaterialData(Vec<InstanceData>);
86
87impl ExtractComponent for InstanceMaterialData {
88 type QueryData = &'static InstanceMaterialData;
89 type QueryFilter = ();
90 type Out = Self;
91
92 fn extract_component(item: QueryItem<'_, '_, Self::QueryData>) -> Option<Self> {
93 Some(InstanceMaterialData(item.0.clone()))
94 }
95}
96
97struct CustomMaterialPlugin;
98
99impl Plugin for CustomMaterialPlugin {
100 fn build(&self, app: &mut App) {
101 app.add_plugins(ExtractComponentPlugin::<InstanceMaterialData>::default());
102 app.sub_app_mut(RenderApp)
103 .add_render_command::<Transparent3d, DrawCustom>()
104 .init_resource::<SpecializedMeshPipelines<CustomPipeline>>()
105 .add_systems(RenderStartup, init_custom_pipeline)
106 .add_systems(
107 Render,
108 (
109 queue_custom.in_set(RenderSystems::QueueMeshes),
110 prepare_instance_buffers.in_set(RenderSystems::PrepareResources),
111 ),
112 );
113 }
114}
115
116#[derive(Clone, Copy, Pod, Zeroable)]
117#[repr(C)]
118struct InstanceData {
119 position: Vec3,
120 scale: f32,
121 color: [f32; 4],
122}
123
124fn queue_custom(
125 transparent_3d_draw_functions: Res<DrawFunctions<Transparent3d>>,
126 custom_pipeline: Res<CustomPipeline>,
127 mut pipelines: ResMut<SpecializedMeshPipelines<CustomPipeline>>,
128 pipeline_cache: Res<PipelineCache>,
129 meshes: Res<RenderAssets<RenderMesh>>,
130 render_mesh_instances: Res<RenderMeshInstances>,
131 material_meshes: Query<(Entity, &MainEntity), With<InstanceMaterialData>>,
132 mut transparent_render_phases: ResMut<ViewSortedRenderPhases<Transparent3d>>,
133 views: Query<(&ExtractedView, &Msaa)>,
134) {
135 let draw_custom = transparent_3d_draw_functions.read().id::<DrawCustom>();
136
137 for (view, msaa) in &views {
138 let Some(transparent_phase) = transparent_render_phases.get_mut(&view.retained_view_entity)
139 else {
140 continue;
141 };
142
143 let msaa_key = MeshPipelineKey::from_msaa_samples(msaa.samples());
144
145 let view_key = msaa_key | MeshPipelineKey::from_hdr(view.hdr);
146 let rangefinder = view.rangefinder3d();
147 for (entity, main_entity) in &material_meshes {
148 let Some(mesh_instance) = render_mesh_instances.render_mesh_queue_data(*main_entity)
149 else {
150 continue;
151 };
152 let Some(mesh) = meshes.get(mesh_instance.mesh_asset_id) else {
153 continue;
154 };
155 let key =
156 view_key | MeshPipelineKey::from_primitive_topology(mesh.primitive_topology());
157 let pipeline = pipelines
158 .specialize(&pipeline_cache, &custom_pipeline, key, &mesh.layout)
159 .unwrap();
160 transparent_phase.add(Transparent3d {
161 entity: (entity, *main_entity),
162 pipeline,
163 draw_function: draw_custom,
164 distance: rangefinder.distance_translation(&mesh_instance.translation),
165 batch_range: 0..1,
166 extra_index: PhaseItemExtraIndex::None,
167 indexed: true,
168 });
169 }
170 }
171}
172
173#[derive(Component)]
174struct InstanceBuffer {
175 buffer: Buffer,
176 length: usize,
177}
178
179fn prepare_instance_buffers(
180 mut commands: Commands,
181 query: Query<(Entity, &InstanceMaterialData)>,
182 render_device: Res<RenderDevice>,
183) {
184 for (entity, instance_data) in &query {
185 let buffer = render_device.create_buffer_with_data(&BufferInitDescriptor {
186 label: Some("instance data buffer"),
187 contents: bytemuck::cast_slice(instance_data.as_slice()),
188 usage: BufferUsages::VERTEX | BufferUsages::COPY_DST,
189 });
190 commands.entity(entity).insert(InstanceBuffer {
191 buffer,
192 length: instance_data.len(),
193 });
194 }
195}
196
197#[derive(Resource)]
198struct CustomPipeline {
199 shader: Handle<Shader>,
200 mesh_pipeline: MeshPipeline,
201}
202
203fn init_custom_pipeline(
204 mut commands: Commands,
205 asset_server: Res<AssetServer>,
206 mesh_pipeline: Res<MeshPipeline>,
207) {
208 commands.insert_resource(CustomPipeline {
209 shader: asset_server.load(SHADER_ASSET_PATH),
210 mesh_pipeline: mesh_pipeline.clone(),
211 });
212}
213
214impl SpecializedMeshPipeline for CustomPipeline {
215 type Key = MeshPipelineKey;
216
217 fn specialize(
218 &self,
219 key: Self::Key,
220 layout: &MeshVertexBufferLayoutRef,
221 ) -> Result<RenderPipelineDescriptor, SpecializedMeshPipelineError> {
222 let mut descriptor = self.mesh_pipeline.specialize(key, layout)?;
223
224 descriptor.vertex.shader = self.shader.clone();
225 descriptor.vertex.buffers.push(VertexBufferLayout {
226 array_stride: size_of::<InstanceData>() as u64,
227 step_mode: VertexStepMode::Instance,
228 attributes: vec![
229 VertexAttribute {
230 format: VertexFormat::Float32x4,
231 offset: 0,
232 shader_location: 3, },
234 VertexAttribute {
235 format: VertexFormat::Float32x4,
236 offset: VertexFormat::Float32x4.size(),
237 shader_location: 4,
238 },
239 ],
240 });
241 descriptor.fragment.as_mut().unwrap().shader = self.shader.clone();
242 Ok(descriptor)
243 }
244}
245
246type DrawCustom = (
247 SetItemPipeline,
248 SetMeshViewBindGroup<0>,
249 SetMeshViewBindingArrayBindGroup<1>,
250 SetMeshBindGroup<2>,
251 DrawMeshInstanced,
252);
253
254struct DrawMeshInstanced;
255
256impl<P: PhaseItem> RenderCommand<P> for DrawMeshInstanced {
257 type Param = (
258 SRes<RenderAssets<RenderMesh>>,
259 SRes<RenderMeshInstances>,
260 SRes<MeshAllocator>,
261 );
262 type ViewQuery = ();
263 type ItemQuery = Read<InstanceBuffer>;
264
265 #[inline]
266 fn render<'w>(
267 item: &P,
268 _view: (),
269 instance_buffer: Option<&'w InstanceBuffer>,
270 (meshes, render_mesh_instances, mesh_allocator): SystemParamItem<'w, '_, Self::Param>,
271 pass: &mut TrackedRenderPass<'w>,
272 ) -> RenderCommandResult {
273 let mesh_allocator = mesh_allocator.into_inner();
275
276 let Some(mesh_instance) = render_mesh_instances.render_mesh_queue_data(item.main_entity())
277 else {
278 return RenderCommandResult::Skip;
279 };
280 let Some(gpu_mesh) = meshes.into_inner().get(mesh_instance.mesh_asset_id) else {
281 return RenderCommandResult::Skip;
282 };
283 let Some(instance_buffer) = instance_buffer else {
284 return RenderCommandResult::Skip;
285 };
286 let Some(vertex_buffer_slice) =
287 mesh_allocator.mesh_vertex_slice(&mesh_instance.mesh_asset_id)
288 else {
289 return RenderCommandResult::Skip;
290 };
291
292 pass.set_vertex_buffer(0, vertex_buffer_slice.buffer.slice(..));
293 pass.set_vertex_buffer(1, instance_buffer.buffer.slice(..));
294
295 match &gpu_mesh.buffer_info {
296 RenderMeshBufferInfo::Indexed {
297 index_format,
298 count,
299 } => {
300 let Some(index_buffer_slice) =
301 mesh_allocator.mesh_index_slice(&mesh_instance.mesh_asset_id)
302 else {
303 return RenderCommandResult::Skip;
304 };
305
306 pass.set_index_buffer(index_buffer_slice.buffer.slice(..), 0, *index_format);
307 pass.draw_indexed(
308 index_buffer_slice.range.start..(index_buffer_slice.range.start + count),
309 vertex_buffer_slice.range.start as i32,
310 0..instance_buffer.length as u32,
311 );
312 }
313 RenderMeshBufferInfo::NonIndexed => {
314 pass.draw(vertex_buffer_slice.range, 0..instance_buffer.length as u32);
315 }
316 }
317 RenderCommandResult::Success
318 }
319}