use bevy::asset::{load_internal_asset, uuid_handle};
use bevy::mesh::{MeshVertexBufferLayoutRef, VertexBufferLayout};
use bevy::pbr::SetMeshViewBindingArrayBindGroup;
use bevy::render::RenderSystems;
use bevy::{
core_pipeline::core_3d::Transparent3d,
ecs::{
query::QueryItem,
system::{lifetimeless::*, SystemParamItem},
},
pbr::{
MeshPipeline, MeshPipelineKey, RenderMeshInstances, SetMeshBindGroup, SetMeshViewBindGroup,
},
prelude::*,
render::{
extract_component::{ExtractComponent, ExtractComponentPlugin},
mesh::{allocator::MeshAllocator, RenderMesh, RenderMeshBufferInfo},
render_asset::RenderAssets,
render_phase::{
AddRenderCommand, DrawFunctions, PhaseItem, PhaseItemExtraIndex, RenderCommand,
RenderCommandResult, SetItemPipeline, TrackedRenderPass, ViewSortedRenderPhases,
},
render_resource::*,
renderer::RenderDevice,
sync_world::MainEntity,
view::ExtractedView,
Render, RenderApp,
},
};
use bytemuck::{Pod, Zeroable};
/// Component holding the per-instance data to draw for one entity.
///
/// It is copied into the render world through its [`ExtractComponent`]
/// implementation and later uploaded to the GPU as an instance-rate vertex buffer.
#[derive(Component)]
pub struct InstanceMaterialData {
/// One entry per instance to draw; an empty list means no instance buffer is kept.
pub instances: Vec<InstanceData>,
}
impl ExtractComponent for InstanceMaterialData {
    type QueryData = &'static InstanceMaterialData;
    type QueryFilter = ();
    type Out = Self;

    /// Produces the render-world copy of the component by duplicating the
    /// instance list of the main-world component.
    fn extract_component(item: QueryItem<'_, '_, Self::QueryData>) -> Option<Self> {
        let instances = item.instances.to_vec();
        Some(Self { instances })
    }
}
/// Plugin that wires up instanced rendering: component extraction, the custom
/// draw command, the specialized pipeline and the render-world systems.
pub struct VoxelMaterialPlugin;
/// Fixed handle under which the instancing shader is registered by
/// `load_internal_asset!` and through which the pipeline refers to it.
// NOTE(review): this UUID looks like the commonly used documentation example
// value; confirm it cannot collide with another asset handle in the application.
pub const SHADER_HANDLE: Handle<Shader> = uuid_handle!("123e4567-e89b-12d3-a456-426614174000");
impl Plugin for VoxelMaterialPlugin {
    /// Registers the extraction plugins, the render command, the specialized
    /// pipeline cache, the render systems and the embedded shader.
    fn build(&self, app: &mut App) {
        // Both components are mirrored into the render world.
        app.add_plugins((
            ExtractComponentPlugin::<InstanceMaterialData>::default(),
            ExtractComponentPlugin::<CameraPosition>::default(),
        ));

        let render_app = app.sub_app_mut(RenderApp);
        render_app.add_render_command::<Transparent3d, DrawCustom>();
        render_app.init_resource::<SpecializedMeshPipelines<CustomPipeline>>();
        render_app.add_systems(Render, queue_custom.in_set(RenderSystems::QueueMeshes));
        render_app.add_systems(
            Render,
            prepare_instance_buffers.in_set(RenderSystems::PrepareResources),
        );

        load_internal_asset!(
            app,
            SHADER_HANDLE,
            "../assets/shaders/instancing.wgsl",
            Shader::from_wgsl
        );
    }

    /// Creates the [`CustomPipeline`] resource in the render app once all
    /// plugins have been built.
    fn finish(&self, app: &mut App) {
        let render_app = app.sub_app_mut(RenderApp);
        render_app.init_resource::<CustomPipeline>();
    }
}
/// Data for a single instance, uploaded byte-for-byte into the instance vertex buffer.
///
/// The struct is `#[repr(C)]`; the pipeline's instance buffer layout reads the
/// first four floats (`position` followed by `scale`) as one `Float32x4` at
/// shader location 3 and the next four floats (`color`) as a `Float32x4` at
/// shader location 4.
#[derive(Clone, Copy, Pod, Zeroable)]
#[repr(C)]
pub struct InstanceData {
/// Instance position; also used to sort instances by distance to the camera.
pub position: [f32; 3],
/// Scale value passed to the shader together with `position`.
pub scale: f32,
/// Four-component color passed to the shader (presumably RGBA — confirm against the shader).
pub color: [f32; 4],
}
/// Queues one [`Transparent3d`] phase item per instanced entity for every view.
///
/// For each view that has a transparent phase, every entity carrying
/// [`InstanceMaterialData`] whose mesh is available gets a pipeline specialized
/// for the view's MSAA/HDR settings and the mesh topology, and is added to the
/// phase with the `DrawCustom` draw function. Entities whose mesh instance or
/// mesh asset is not available are skipped.
///
/// # Panics
///
/// Panics if pipeline specialization fails for a mesh.
#[allow(clippy::too_many_arguments)]
fn queue_custom(
    transparent_3d_draw_functions: Res<DrawFunctions<Transparent3d>>,
    custom_pipeline: Res<CustomPipeline>,
    mut pipelines: ResMut<SpecializedMeshPipelines<CustomPipeline>>,
    pipeline_cache: Res<PipelineCache>,
    meshes: Res<RenderAssets<RenderMesh>>,
    render_mesh_instances: Res<RenderMeshInstances>,
    material_meshes: Query<(Entity, &MainEntity), With<InstanceMaterialData>>,
    mut transparent_render_phases: ResMut<ViewSortedRenderPhases<Transparent3d>>,
    views: Query<(&ExtractedView, &Msaa)>,
) {
    let draw_function = transparent_3d_draw_functions.read().id::<DrawCustom>();

    for (view, msaa) in views.iter() {
        let phase = match transparent_render_phases.get_mut(&view.retained_view_entity) {
            Some(phase) => phase,
            None => continue,
        };

        // Key bits shared by every mesh drawn into this view.
        let view_key = MeshPipelineKey::from_msaa_samples(msaa.samples())
            | MeshPipelineKey::from_hdr(view.hdr);
        let rangefinder = view.rangefinder3d();

        for (render_entity, main_entity) in material_meshes.iter() {
            let Some(instance) = render_mesh_instances.render_mesh_queue_data(*main_entity)
            else {
                continue;
            };
            let Some(mesh) = meshes.get(instance.mesh_asset_id) else {
                continue;
            };

            let topology_key =
                MeshPipelineKey::from_primitive_topology(mesh.primitive_topology());
            let pipeline = pipelines
                .specialize(
                    &pipeline_cache,
                    &custom_pipeline,
                    view_key | topology_key,
                    &mesh.layout,
                )
                .unwrap();
            let distance = rangefinder.distance(&instance.center);

            phase.add(Transparent3d {
                entity: (render_entity, *main_entity),
                pipeline,
                draw_function,
                distance,
                batch_range: 0..1,
                extra_index: PhaseItemExtraIndex::None,
                indexed: false,
            });
        }
    }
}
/// Render-world component holding the uploaded instance data of an entity.
///
/// Inserted by `prepare_instance_buffers` and consumed by `DrawMeshInstanced`.
#[derive(Component)]
struct InstanceBuffer {
/// GPU buffer containing the instance data, bound as vertex buffer slot 1 when drawing.
buffer: Buffer,
/// Number of instances stored in `buffer`.
length: usize,
}
/// Translation of a `Camera3d` entity's `GlobalTransform`, extracted into the
/// render world so instances can be sorted by distance to the camera.
#[derive(Component, Clone)]
struct CameraPosition(Vec3);
impl ExtractComponent for CameraPosition {
    type QueryData = &'static GlobalTransform;
    type QueryFilter = With<Camera3d>;
    type Out = Self;

    /// Captures the translation of the camera's global transform.
    fn extract_component(transform: QueryItem<'_, '_, Self::QueryData>) -> Option<Self> {
        let translation = transform.translation();
        Some(Self(translation))
    }
}
/// Uploads the instance data of every entity into a GPU vertex buffer.
///
/// Instances are sorted farthest-first relative to the first extracted
/// [`CameraPosition`] (or the origin when there is none) before upload.
/// Entities with an empty instance list have their [`InstanceBuffer`] removed,
/// so nothing is drawn for them.
///
/// The comparison uses `f32::total_cmp`, which is a total order even when a
/// distance is NaN. A comparator that is not a total order may make `sort_by`
/// panic, which would otherwise be possible with a single NaN position.
// NOTE(review): the distance is computed directly between the instance
// position and the camera translation, which assumes both are expressed in
// the same space — confirm against how instances are positioned.
// NOTE(review): a fresh buffer is created on every run of this system even
// though the buffer is created with `COPY_DST`; consider reusing the buffer.
fn prepare_instance_buffers(
    mut commands: Commands,
    query: Query<(Entity, &InstanceMaterialData)>,
    render_device: Res<RenderDevice>,
    camera_query: Query<&CameraPosition>,
) {
    let camera_pos = camera_query
        .iter()
        .next()
        .map_or(Vec3::ZERO, |pos| pos.0);

    for (entity, instance_data) in &query {
        if instance_data.instances.is_empty() {
            commands.entity(entity).remove::<InstanceBuffer>();
            continue;
        }

        // Sort a copy so the extracted component keeps its original order.
        let mut sorted_instances = instance_data.instances.clone();
        sorted_instances.sort_by(|a, b| {
            let dist_a = camera_pos.distance_squared(Vec3::from_slice(&a.position));
            let dist_b = camera_pos.distance_squared(Vec3::from_slice(&b.position));
            // Reversed operands: larger distances come first (back-to-front).
            dist_b.total_cmp(&dist_a)
        });

        let buffer = render_device.create_buffer_with_data(&BufferInitDescriptor {
            label: Some("instance data buffer"),
            contents: bytemuck::cast_slice(sorted_instances.as_slice()),
            usage: BufferUsages::VERTEX | BufferUsages::COPY_DST,
        });
        commands.entity(entity).insert(InstanceBuffer {
            buffer,
            length: sorted_instances.len(),
        });
    }
}
/// Render-world resource describing the instanced drawing pipeline.
#[derive(Resource)]
struct CustomPipeline {
/// Shader used for both the vertex and the fragment stage.
shader: Handle<Shader>,
/// Copy of the standard mesh pipeline, used as the base when specializing.
mesh_pipeline: MeshPipeline,
}
impl FromWorld for CustomPipeline {
    /// Builds the pipeline resource from the instancing shader handle and a
    /// clone of the world's [`MeshPipeline`] resource.
    fn from_world(world: &mut World) -> Self {
        Self {
            shader: SHADER_HANDLE.clone(),
            mesh_pipeline: world.resource::<MeshPipeline>().clone(),
        }
    }
}
impl SpecializedMeshPipeline for CustomPipeline {
    type Key = MeshPipelineKey;

    /// Specializes the standard mesh pipeline for instanced, alpha-blended drawing.
    ///
    /// Starting from the descriptor produced by the wrapped [`MeshPipeline`], this:
    /// - disables depth writes and makes the depth test always pass,
    /// - swaps in the instancing shader for the vertex and fragment stages,
    /// - appends an instance-rate vertex buffer matching [`InstanceData`],
    /// - replaces the first color target with an alpha-blended one.
    ///
    /// The color target keeps the texture format chosen by the base pipeline for
    /// `key`, so the pipeline stays compatible with the view's render target
    /// (for example when the key requests HDR) instead of forcing one format.
    ///
    /// # Errors
    ///
    /// Returns the error of the base mesh pipeline specialization.
    ///
    /// # Panics
    ///
    /// Panics if the base descriptor has no fragment state.
    fn specialize(
        &self,
        key: Self::Key,
        layout: &MeshVertexBufferLayoutRef,
    ) -> Result<RenderPipelineDescriptor, SpecializedMeshPipelineError> {
        let mut descriptor = self.mesh_pipeline.specialize(key, layout)?;

        // Instances are blended without writing depth, and `Always` means they
        // are never rejected by the depth test.
        descriptor.depth_stencil = Some(DepthStencilState {
            format: TextureFormat::Depth32Float,
            depth_compare: CompareFunction::Always,
            stencil: StencilState::default(),
            depth_write_enabled: false,
            bias: DepthBiasState::default(),
        });

        descriptor.vertex.shader = self.shader.clone();
        // Per-instance attributes: `position` + `scale` packed as one vec4 at
        // location 3, followed by `color` as a vec4 at location 4.
        descriptor.vertex.buffers.push(VertexBufferLayout {
            array_stride: size_of::<InstanceData>() as u64,
            step_mode: VertexStepMode::Instance,
            attributes: vec![
                VertexAttribute {
                    format: VertexFormat::Float32x4,
                    offset: 0,
                    shader_location: 3,
                },
                VertexAttribute {
                    format: VertexFormat::Float32x4,
                    offset: VertexFormat::Float32x4.size(),
                    shader_location: 4,
                },
            ],
        });

        let fragment = descriptor
            .fragment
            .as_mut()
            .expect("base mesh pipeline descriptor has no fragment state");
        fragment.shader = self.shader.clone();

        // Reuse the format selected by the base pipeline; fall back to the
        // previous hard-coded format only if no target is present.
        let color_format = fragment
            .targets
            .first()
            .and_then(Option::as_ref)
            .map_or(TextureFormat::Rgba8UnormSrgb, |target| target.format);
        let blend_component = BlendComponent {
            src_factor: BlendFactor::SrcAlpha,
            dst_factor: BlendFactor::OneMinusSrcAlpha,
            operation: BlendOperation::Add,
        };
        let color_target = Some(ColorTargetState {
            format: color_format,
            blend: Some(BlendState {
                color: blend_component,
                alpha: blend_component,
            }),
            write_mask: ColorWrites::ALL,
        });
        match fragment.targets.first_mut() {
            Some(slot) => *slot = color_target,
            None => fragment.targets.push(color_target),
        }

        Ok(descriptor)
    }
}
/// Render command sequence for instanced drawing: set the item's pipeline,
/// bind groups 0, 1 and 2 (mesh view, mesh view binding array, mesh), then
/// issue the instanced draw call.
type DrawCustom = (
SetItemPipeline,
SetMeshViewBindGroup<0>,
SetMeshViewBindingArrayBindGroup<1>,
SetMeshBindGroup<2>,
DrawMeshInstanced,
);
/// Render command that binds the mesh and instance vertex buffers and draws
/// the mesh once per instance stored in the entity's `InstanceBuffer`.
struct DrawMeshInstanced;
impl<P: PhaseItem> RenderCommand<P> for DrawMeshInstanced {
    type Param = (
        SRes<RenderAssets<RenderMesh>>,
        SRes<RenderMeshInstances>,
        SRes<MeshAllocator>,
    );
    type ViewQuery = ();
    type ItemQuery = Read<InstanceBuffer>;

    /// Issues the instanced draw for one phase item.
    ///
    /// Binds the mesh vertex buffer to slot 0 and the instance buffer to slot 1,
    /// then draws indexed or non-indexed depending on the mesh's buffer info.
    /// Returns [`RenderCommandResult::Skip`] when the instance buffer, the mesh
    /// instance, the GPU mesh or one of its buffer slices is unavailable.
    #[inline]
    fn render<'w>(
        item: &P,
        _view: (),
        instance_buffer: Option<&'w InstanceBuffer>,
        (meshes, render_mesh_instances, mesh_allocator): SystemParamItem<'w, '_, Self::Param>,
        pass: &mut TrackedRenderPass<'w>,
    ) -> RenderCommandResult {
        // Unwrap the resource so the slices borrowed from it live for `'w`.
        let allocator = mesh_allocator.into_inner();

        let Some(instances) = instance_buffer else {
            return RenderCommandResult::Skip;
        };
        let Some(queue_data) = render_mesh_instances.render_mesh_queue_data(item.main_entity())
        else {
            return RenderCommandResult::Skip;
        };
        let mesh_id = queue_data.mesh_asset_id;
        let Some(gpu_mesh) = meshes.into_inner().get(mesh_id) else {
            return RenderCommandResult::Skip;
        };
        let Some(vertex_slice) = allocator.mesh_vertex_slice(&mesh_id) else {
            return RenderCommandResult::Skip;
        };

        pass.set_vertex_buffer(0, vertex_slice.buffer.slice(..));
        pass.set_vertex_buffer(1, instances.buffer.slice(..));

        let instance_range = 0..instances.length as u32;
        match &gpu_mesh.buffer_info {
            RenderMeshBufferInfo::NonIndexed => {
                pass.draw(vertex_slice.range, instance_range);
            }
            RenderMeshBufferInfo::Indexed {
                index_format,
                count,
            } => {
                let Some(index_slice) = allocator.mesh_index_slice(&mesh_id) else {
                    return RenderCommandResult::Skip;
                };
                let first_index = index_slice.range.start;

                pass.set_index_buffer(index_slice.buffer.slice(..), *index_format);
                pass.draw_indexed(
                    first_index..(first_index + count),
                    vertex_slice.range.start as i32,
                    instance_range,
                );
            }
        }

        RenderCommandResult::Success
    }
}