use crate::render::ShaderModelAsset;
use crate::{
draw::{Draw, DrawCommand, drawing::Drawing, primitive::Primitive},
render::{PreparedShaderModel, ShaderModel, queue_shader_model},
};
use bevy::pbr::{MATERIAL_BIND_GROUP_INDEX, SetMeshViewBindingArrayBindGroup};
use bevy::{
core_pipeline::core_3d::Transparent3d,
ecs::system::{SystemParamItem, lifetimeless::*},
pbr::{RenderMeshInstances, SetMeshBindGroup, SetMeshViewBindGroup},
prelude::*,
render::{
Render, RenderApp, RenderSystems,
extract_component::ExtractComponent,
extract_instances::ExtractedInstances,
mesh::{RenderMesh, RenderMeshBufferInfo, allocator::MeshAllocator},
render_asset::{RenderAssets, prepare_assets},
render_phase::{
AddRenderCommand, PhaseItem, RenderCommand, RenderCommandResult, SetItemPipeline,
TrackedRenderPass,
},
},
};
use std::{hash::Hash, marker::PhantomData, ops::Range};
/// Builder for an instanced draw of a single primitive on a [`Draw`].
///
/// Supply the primitive with [`Instanced::primitive`] and the instance range with
/// [`Instanced::range`]. The draw command is recorded when this value is dropped,
/// and only if both were supplied.
pub struct Instanced<'a> {
    /// The draw whose shared state receives the instanced command.
    draw: &'a Draw,
    /// Instance indices to draw, set by [`Instanced::range`].
    range: Option<Range<u32>>,
    /// Index of the drawing whose primitive is instanced, set by [`Instanced::primitive`].
    primitive_index: Option<usize>,
}
impl Drop for Instanced<'_> {
    /// Records the instanced draw command, provided both a primitive and a range were set.
    ///
    /// Both fields are taken (left as `None`) regardless of whether the command is emitted.
    fn drop(&mut self) {
        let index = self.primitive_index.take();
        let range = self.range.take();
        if let (Some(index), Some(range)) = (index, range) {
            self.insert_instanced_draw_command(index, range);
        }
    }
}
/// Starts an instanced draw on `draw` with no primitive and no instance range attached yet.
pub fn new(draw: &Draw) -> Instanced<'_> {
    let (primitive_index, range) = (None, None);
    Instanced {
        range,
        primitive_index,
        draw,
    }
}
impl<'a> Instanced<'a> {
    /// Uses `drawing`'s primitive as the geometry for this instanced draw.
    ///
    /// The drawing's index is added to the draw state's `ignored_drawings` set and
    /// remembered here. The `Drop` impl later moves the primitive into an instanced
    /// draw command, but only if a range was also set.
    ///
    /// NOTE(review): calling this more than once overwrites the remembered index,
    /// so the earlier drawing stays in `ignored_drawings` and is never turned into
    /// a command by this builder. Confirm that this is intended.
    pub fn primitive<T>(mut self, drawing: Drawing<T>) -> Instanced<'a>
    where
        T: Into<Primitive>,
    {
        // The write guard is a temporary dropped at the end of this statement,
        // i.e. before `drawing` itself is dropped at the end of the function.
        self.draw
            .state
            .write()
            .unwrap()
            .ignored_drawings
            .insert(drawing.index);
        self.primitive_index = Some(drawing.index);
        self
    }

    /// Sets the range of instance indices passed to the instanced draw call.
    pub fn range(mut self, range: Range<u32>) -> Instanced<'a> {
        self.range = Some(range);
        self
    }

    /// Moves the primitive stored under `index` out of the in-progress drawings and
    /// appends it to the draw commands as an instanced command over `range`.
    ///
    /// This is called from `Drop`, so it must not panic: a panic while the thread is
    /// already unwinding aborts the process. A poisoned lock is therefore recovered
    /// instead of unwrapped, and a primitive that is no longer present is skipped.
    fn insert_instanced_draw_command(&self, index: usize, range: Range<u32>) {
        let mut state = self
            .draw
            .state
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let Some(primitive) = state.drawing.remove(&index) else {
            return;
        };
        state
            .draw_commands
            .push(Some(DrawCommand::Instanced(primitive, range)));
    }
}
/// Marker component selecting the entities that
/// `InstancedShaderModelPlugin`'s queue system draws with an instanced draw call.
#[derive(Component, ExtractComponent, Clone, Debug)]
pub struct InstancedMesh;
/// Range of instance indices passed to the instanced draw call for an entity.
///
/// The instanced draw command reads this component from the render world and
/// skips the entity when it is absent.
#[derive(Component, ExtractComponent, Clone, Debug, PartialEq, Eq)]
pub struct InstanceRange(pub Range<u32>);
/// Render-app plugin that draws [`InstancedMesh`] entities with shader model `SM`
/// using an instanced draw call.
pub struct InstancedShaderModelPlugin<SM>(PhantomData<SM>);

// `PhantomData<SM>` is `Default` for every `SM`, so no `SM: Default` bound is needed.
// The bound previously made `Default::default()` unusable for shader models that
// don't themselves implement `Default`.
impl<SM> Default for InstancedShaderModelPlugin<SM> {
    fn default() -> Self {
        Self(PhantomData)
    }
}
impl<SM> Plugin for InstancedShaderModelPlugin<SM>
where
    SM: ShaderModel,
    SM::Data: PartialEq + Eq + Hash + Clone,
{
    /// Registers the instanced draw function for the `Transparent3d` phase in the
    /// render sub-app. Also schedules the queue system for `InstancedMesh` entities.
    fn build(&self, app: &mut App) {
        let render_app = app.sub_app_mut(RenderApp);
        render_app.add_render_command::<Transparent3d, DrawInstancedShaderModel<SM>>();

        // Ordered after `prepare_assets` for the prepared shader models, inside the
        // mesh-queueing set.
        let queue_instanced =
            queue_shader_model::<SM, With<InstancedMesh>, DrawInstancedShaderModel<SM>>
                .after(prepare_assets::<PreparedShaderModel<SM>>)
                .in_set(RenderSystems::QueueMeshes);
        render_app.add_systems(Render, queue_instanced);
    }
}
/// Render-command tuple used to draw an instanced shader-model item.
///
/// Elements, in tuple order:
/// - the item's pipeline;
/// - the mesh view bind group at group 0;
/// - the mesh view binding-array bind group at group 1;
/// - the mesh bind group at group 2;
/// - the shader model's bind group at `MATERIAL_BIND_GROUP_INDEX`;
/// - the instanced draw itself.
type DrawInstancedShaderModel<SM> = (
    SetItemPipeline,
    SetMeshViewBindGroup<0>,
    SetMeshViewBindingArrayBindGroup<1>,
    SetMeshBindGroup<2>,
    SetShaderModelBindGroup<SM, MATERIAL_BIND_GROUP_INDEX>,
    DrawMeshInstanced,
);
/// Render command that binds a prepared shader model's bind group at group `I`.
///
/// It looks up the item's main entity in the extracted `ShaderModelAsset<SM>`
/// instances, then looks up that asset in the prepared `PreparedShaderModel<SM>`
/// assets. The item is skipped when either lookup fails.
struct SetShaderModelBindGroup<SM: ShaderModel, const I: usize>(PhantomData<SM>);
impl<P: PhaseItem, SM: ShaderModel, const I: usize> RenderCommand<P>
    for SetShaderModelBindGroup<SM, I>
{
    type Param = (
        SRes<RenderAssets<PreparedShaderModel<SM>>>,
        SRes<ExtractedInstances<ShaderModelAsset<SM>>>,
    );
    type ViewQuery = ();
    type ItemQuery = ();

    #[inline]
    fn render<'w>(
        item: &P,
        _view: (),
        _item_query: Option<()>,
        (models, instances): SystemParamItem<'w, '_, Self::Param>,
        pass: &mut TrackedRenderPass<'w>,
    ) -> RenderCommandResult {
        let prepared = models.into_inner();
        let bind_group = instances
            .into_inner()
            .get(&item.main_entity())
            .and_then(|model_asset| prepared.get(model_asset.0))
            .map(|shader_model| &shader_model.bind_group);
        match bind_group {
            Some(bind_group) => {
                pass.set_bind_group(I, bind_group, &[]);
                RenderCommandResult::Success
            }
            None => RenderCommandResult::Skip,
        }
    }
}
/// Render command that issues one instanced draw call for a mesh entity.
///
/// The mesh is resolved through [`RenderMeshInstances`], and its slab-allocated
/// vertex/index data through [`MeshAllocator`]. The instance range comes from the
/// entity's [`InstanceRange`] component. If any of these is missing, the item is
/// skipped.
struct DrawMeshInstanced;
impl<P: PhaseItem> RenderCommand<P> for DrawMeshInstanced {
    type Param = (
        SRes<RenderAssets<RenderMesh>>,
        SRes<RenderMeshInstances>,
        SRes<MeshAllocator>,
    );
    type ViewQuery = ();
    type ItemQuery = Read<InstanceRange>;

    #[inline]
    fn render<'w>(
        item: &P,
        _view: (),
        instance_range: Option<&'w InstanceRange>,
        (meshes, render_mesh_instances, mesh_allocator): SystemParamItem<'w, '_, Self::Param>,
        pass: &mut TrackedRenderPass<'w>,
    ) -> RenderCommandResult {
        // Cheapest check first: without an instance range there is nothing to draw.
        let Some(instance_range) = instance_range else {
            return RenderCommandResult::Skip;
        };
        let mesh_allocator = mesh_allocator.into_inner();
        let Some(mesh_instance) = render_mesh_instances.render_mesh_queue_data(item.main_entity())
        else {
            return RenderCommandResult::Skip;
        };
        let mesh_asset_id = mesh_instance.mesh_asset_id();
        let Some(gpu_mesh) = meshes.into_inner().get(mesh_asset_id) else {
            return RenderCommandResult::Skip;
        };
        let Some(vertex_buffer_slice) = mesh_allocator.mesh_vertex_slice(&mesh_asset_id) else {
            return RenderCommandResult::Skip;
        };

        pass.set_vertex_buffer(0, vertex_buffer_slice.buffer.slice(..));
        match &gpu_mesh.buffer_info {
            RenderMeshBufferInfo::Indexed {
                index_format,
                count,
            } => {
                let Some(index_buffer_slice) = mesh_allocator.mesh_index_slice(&mesh_asset_id)
                else {
                    return RenderCommandResult::Skip;
                };
                pass.set_index_buffer(index_buffer_slice.buffer.slice(..), *index_format);
                // The allocator's range is slot-granular and may include padding
                // elements, so draw exactly `count` indices from its start.
                let first_index = index_buffer_slice.range.start;
                pass.draw_indexed(
                    first_index..first_index + *count,
                    // Index values are relative to this mesh's own vertices, but the
                    // allocator may place those vertices at an offset inside a shared
                    // slab. The offset is an element index that fits in `i32` in
                    // practice; Bevy's own `DrawMesh` performs the same cast.
                    vertex_buffer_slice.range.start as i32,
                    instance_range.0.clone(),
                );
            }
            RenderMeshBufferInfo::NonIndexed => {
                pass.draw(vertex_buffer_slice.range, instance_range.0.clone());
            }
        }
        RenderCommandResult::Success
    }
}