use bevy::asset::{AssetServer, Handle, embedded_asset};
use bevy::ecs::system::Commands;
use bevy::prelude::*;
use bevy::render::render_resource::*;
use bevy::render::renderer::{RenderDevice, RenderQueue};
use bevy::render::view::ExtractedView;
use bevy::shader::Shader;
use bytemuck::{Pod, Zeroable};
use crate::render::{
ExtractedSplatCloudSettings, ExtractedSplatTransform, SparkSettings, SplatCloudGpu,
SplatCloudGpuStorage, SplatIndexStorage, SplatPerViewResources, effective_sort_settings,
experimental_gpu_sort_enabled, gpu_compute_sort_supported, storage_buffer_fits_device,
};
/// Asset path of the sort compute shader, registered via `register_sort_assets`.
const SORT_SHADER_PATH: &str = "embedded://bevy_spark/render/sort.wgsl";
/// Number of uniform slots in the per-cloud params buffer: slot 0 feeds the
/// key-computation dispatch, slots 1..=4 carry `pass_index` 0..=3 for the four
/// radix passes (see the slot loop in `run_gpu_sort`).
const SORT_PARAM_SLOTS: u32 = 5;
/// Unaligned byte size of one `SortParams` record; rounded up to the device's
/// uniform-offset alignment by `sort_params_stride`.
const SORT_PARAM_SIZE: u64 = core::mem::size_of::<SortParams>() as u64;
/// Byte stride between `SortParams` slots in the params uniform buffer.
///
/// Rounds `SORT_PARAM_SIZE` up to the device's minimum uniform-buffer offset
/// alignment so each slot can be selected with a dynamic bind-group offset.
/// The `.max(1)` guards against a (theoretical) zero alignment limit.
fn sort_params_stride(render_device: &RenderDevice) -> u64 {
    let alignment = u64::from(
        render_device
            .limits()
            .min_uniform_buffer_offset_alignment
            .max(1),
    );
    SORT_PARAM_SIZE.div_ceil(alignment) * alignment
}
/// Byte offset of uniform slot `slot` within the params buffer, given the
/// aligned per-slot `stride` from `sort_params_stride`.
fn sort_param_offset(slot: u32, stride: u64) -> u64 {
    stride * u64::from(slot)
}
/// Cached compute pipelines and the shared bind group layout for the GPU
/// radix sort. Created once by `init_sort_pipeline`.
#[derive(Resource)]
pub struct SortPipeline {
    /// Layout shared by every sort pass (see `sort_layout_entries`).
    pub bind_layout: BindGroupLayout,
    /// Entry `compute_keys`: derives one 32-bit sort key per splat.
    pub compute_keys: CachedComputePipelineId,
    /// Entry `histogram_pass`: histogram over keys in buffer A.
    pub histogram_a: CachedComputePipelineId,
    /// Entry `histogram_pass_b`: histogram over keys in buffer B.
    pub histogram_b: CachedComputePipelineId,
    /// Entry `prefix_sum`: scan of the histograms; dispatched as one workgroup.
    pub prefix_sum: CachedComputePipelineId,
    /// Entry `scatter`: moves keys/indices from the A buffers into B.
    pub scatter_a_to_b: CachedComputePipelineId,
    /// Entry `scatter_b`: moves keys/indices from the B buffers back into A.
    pub scatter_b_to_a: CachedComputePipelineId,
}
/// Per-dispatch uniform parameters for the sort shader.
///
/// `#[repr(C)]` + `Pod` make this safe to upload verbatim with
/// `bytemuck::cast_slice`; the field layout is assumed to mirror the WGSL
/// struct in `sort.wgsl` — confirm against the shader when editing.
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable, Debug)]
pub struct SortParams {
    /// Total number of splats being sorted.
    pub num_splats: u32,
    /// Radix pass index; 0..=3 for the four passes (slot 0, used by key
    /// computation, also carries 0 — see `slot.saturating_sub(1)`).
    pub pass_index: u32,
    /// Number of 256-wide workgroups covering all splats.
    pub num_blocks: u32,
    /// Padding so `view_pos` starts on a 16-byte boundary.
    pub _pad: u32,
    /// View position in splat-local space; `w` is written as 0.0 and unused.
    pub view_pos: [f32; 4],
}
/// Per-cloud GPU buffers backing the radix sort; recreated by
/// `ensure_sort_buffers` whenever the cloud's upload generation changes.
#[derive(Component)]
pub struct SortBuffers {
    /// `SplatCloudGpu::upload_id` these buffers were sized for; a mismatch
    /// triggers reallocation.
    pub upload_id: u64,
    /// Sort keys, ping-pong buffer A (one u32 per splat).
    pub keys_a: Buffer,
    /// Sort keys, ping-pong buffer B.
    pub keys_b: Buffer,
    /// Splat indices, ping-pong buffer B; the "A" index buffer is the one
    /// held in `SplatPerViewResources` (bound at binding 3 in `run_gpu_sort`).
    pub indices_b: Buffer,
    /// Per-workgroup histogram storage: 256 u32 bins per block.
    pub block_offsets: Buffer,
    /// Global digit offsets: 256 u32 bins.
    pub offsets: Buffer,
    /// Uniform buffer holding `SORT_PARAM_SLOTS` stride-aligned `SortParams`
    /// records, addressed via dynamic offsets.
    pub params: Buffer,
}
fn sort_layout_entries() -> Vec<BindGroupLayoutEntry> {
let mut v = Vec::with_capacity(8);
v.push(BindGroupLayoutEntry {
binding: 0,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Uniform,
has_dynamic_offset: true,
min_binding_size: None,
},
count: None,
});
v.push(BindGroupLayoutEntry {
binding: 1,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
});
for b in 2..=7u32 {
v.push(BindGroupLayoutEntry {
binding: b,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
});
}
v
}
/// Startup system: creates the shared sort bind group layout and queues all
/// six sort compute pipelines, storing them as the [`SortPipeline`] resource.
pub fn init_sort_pipeline(
    mut commands: Commands,
    asset_server: Res<AssetServer>,
    render_device: Res<RenderDevice>,
    pipeline_cache: Res<PipelineCache>,
) {
    let entries = sort_layout_entries();
    let bind_layout = render_device.create_bind_group_layout("spark.sort_layout", &entries);
    let layout_desc = BindGroupLayoutDescriptor::new("spark.sort_layout", &entries);
    let shader: Handle<Shader> = asset_server.load(SORT_SHADER_PATH);
    // Every pipeline shares the same shader module and layout; only the label
    // and WGSL entry point differ.
    let queue = |label: &'static str, entry_point: &'static str| {
        pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
            label: Some(label.into()),
            layout: vec![layout_desc.clone()],
            push_constant_ranges: vec![],
            shader: shader.clone(),
            shader_defs: vec![],
            entry_point: Some(entry_point.into()),
            zero_initialize_workgroup_memory: false,
        })
    };
    commands.insert_resource(SortPipeline {
        bind_layout,
        compute_keys: queue("spark.sort.compute_keys", "compute_keys"),
        histogram_a: queue("spark.sort.hist_a", "histogram_pass"),
        histogram_b: queue("spark.sort.hist_b", "histogram_pass_b"),
        prefix_sum: queue("spark.sort.prefix", "prefix_sum"),
        scatter_a_to_b: queue("spark.sort.scatter_ab", "scatter"),
        scatter_b_to_a: queue("spark.sort.scatter_ba", "scatter_b"),
    });
}
pub fn ensure_sort_buffers(
mut commands: Commands,
render_device: Res<RenderDevice>,
settings: Res<SparkSettings>,
clouds: Query<(
Entity,
&SplatCloudGpu,
Option<&SortBuffers>,
Option<&ExtractedSplatCloudSettings>,
)>,
) {
for (entity, gpu, maybe_sort, cloud_settings) in &clouds {
let sort_settings = effective_sort_settings(&settings, cloud_settings);
if !experimental_gpu_sort_enabled(sort_settings)
|| !gpu_compute_sort_supported(&render_device)
{
continue;
}
if maybe_sort.is_some_and(|sort| sort.upload_id == gpu.upload_id) {
continue;
}
let n = gpu.num_splats as u64;
let blocks = n.div_ceil(256);
let params_stride = sort_params_stride(&render_device);
let keys_size = n * 4;
let block_offsets_size = blocks * 256 * 4;
if !storage_buffer_fits_device(&render_device, "GPU sort key buffer", keys_size)
|| !storage_buffer_fits_device(&render_device, "GPU sort index buffer", keys_size)
|| !storage_buffer_fits_device(
&render_device,
"GPU sort block-offset buffer",
block_offsets_size,
)
{
continue;
}
let keys_a = render_device.create_buffer(&BufferDescriptor {
label: Some("spark.sort.keys_a"),
size: keys_size,
usage: BufferUsages::STORAGE | BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let keys_b = render_device.create_buffer(&BufferDescriptor {
label: Some("spark.sort.keys_b"),
size: keys_size,
usage: BufferUsages::STORAGE | BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let indices_b = render_device.create_buffer(&BufferDescriptor {
label: Some("spark.sort.indices_b"),
size: keys_size,
usage: BufferUsages::STORAGE | BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let block_offsets = render_device.create_buffer(&BufferDescriptor {
label: Some("spark.sort.block_offsets"),
size: block_offsets_size,
usage: BufferUsages::STORAGE | BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let offsets = render_device.create_buffer(&BufferDescriptor {
label: Some("spark.sort.offsets"),
size: 256 * 4,
usage: BufferUsages::STORAGE | BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let params = render_device.create_buffer(&BufferDescriptor {
label: Some("spark.sort.params"),
size: params_stride * u64::from(SORT_PARAM_SLOTS),
usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST,
mapped_at_creation: false,
});
commands.entity(entity).insert(SortBuffers {
upload_id: gpu.upload_id,
keys_a,
keys_b,
indices_b,
block_offsets,
offsets,
params,
});
}
}
/// Records and submits the GPU radix sort for every sortable cloud/view pair.
///
/// Per pair this records one command encoder containing: (1) a key-computation
/// dispatch, then (2) four radix passes, each a histogram → prefix-sum →
/// scatter sequence, ping-ponging keys/indices between the A and B buffers.
/// Per-pass `SortParams` are selected with dynamic uniform offsets into the
/// shared params buffer. Does nothing until all six pipelines have compiled.
pub fn run_gpu_sort(
    pipeline: Res<SortPipeline>,
    pipeline_cache: Res<PipelineCache>,
    render_device: Res<RenderDevice>,
    render_queue: Res<RenderQueue>,
    settings: Res<SparkSettings>,
    views: Query<(Entity, &ExtractedView)>,
    mut clouds: Query<(
        &SplatCloudGpu,
        &SortBuffers,
        &ExtractedSplatTransform,
        Option<&ExtractedSplatCloudSettings>,
        &mut SplatPerViewResources,
    )>,
) {
    // Shader compilation is asynchronous; skip whole frames until every sort
    // pipeline has reached the Ok state.
    let pipes_ready = [
        pipeline.compute_keys,
        pipeline.histogram_a,
        pipeline.histogram_b,
        pipeline.prefix_sum,
        pipeline.scatter_a_to_b,
        pipeline.scatter_b_to_a,
    ]
    .into_iter()
    .all(|p| {
        matches!(
            pipeline_cache.get_compute_pipeline_state(p),
            CachedPipelineState::Ok(_)
        )
    });
    if !pipes_ready {
        return;
    }
    // The readiness check above guarantees these lookups succeed.
    let compute_keys_pl = pipeline_cache
        .get_compute_pipeline(pipeline.compute_keys)
        .unwrap();
    let prefix_sum_pl = pipeline_cache
        .get_compute_pipeline(pipeline.prefix_sum)
        .unwrap();
    let hist_a_pl = pipeline_cache
        .get_compute_pipeline(pipeline.histogram_a)
        .unwrap();
    let hist_b_pl = pipeline_cache
        .get_compute_pipeline(pipeline.histogram_b)
        .unwrap();
    let scatter_ab_pl = pipeline_cache
        .get_compute_pipeline(pipeline.scatter_a_to_b)
        .unwrap();
    let scatter_ba_pl = pipeline_cache
        .get_compute_pipeline(pipeline.scatter_b_to_a)
        .unwrap();
    let params_stride = sort_params_stride(&render_device);
    for (gpu, sort, xf, cloud_settings, mut resources) in &mut clouds {
        // The sort shader reads splats from a storage buffer; clouds stored
        // any other way are skipped.
        let SplatCloudGpuStorage::StorageBuffers { splats_buffer, .. } = &gpu.storage else {
            continue;
        };
        let sort_settings = effective_sort_settings(&settings, cloud_settings);
        if !experimental_gpu_sort_enabled(sort_settings)
            || !gpu_compute_sort_supported(&render_device)
        {
            continue;
        }
        let n = gpu.num_splats;
        if n == 0 {
            continue;
        }
        // Inverse of the cloud transform: lets us move the camera into
        // splat-local space once instead of transforming every splat.
        let inv = xf.0.inverse();
        // 256 threads per workgroup — assumed to match sort.wgsl's workgroup
        // size; confirm against the shader when changing.
        let workgroups = n.div_ceil(256);
        for (view_entity, view) in &views {
            // Only sort for views that have per-view resources for this cloud.
            let Some(entry) = resources.entries.get_mut(&view_entity) else {
                continue;
            };
            // The view's index storage is the "A" index buffer of the sort.
            let SplatIndexStorage::Buffer(indices_buffer) = &entry.indices else {
                continue;
            };
            // GPU sort path reports every splat as visible/selected here.
            entry.visible_count = n;
            entry.lod_selected_count = n;
            entry.lod_total_count = n;
            let view_world = view.world_from_view.translation();
            let view_local = inv.transform_point3(view_world);
            let params = SortParams {
                num_splats: n,
                pass_index: 0,
                num_blocks: workgroups,
                _pad: 0,
                view_pos: [view_local.x, view_local.y, view_local.z, 0.0],
            };
            // Fill all uniform slots: slot 0 (key computation) gets
            // pass_index 0 via saturating_sub; slots 1..=4 get passes 0..=3.
            for slot in 0..SORT_PARAM_SLOTS {
                let mut slot_params = params;
                slot_params.pass_index = slot.saturating_sub(1);
                render_queue.write_buffer(
                    &sort.params,
                    sort_param_offset(slot, params_stride),
                    bytemuck::cast_slice(&[slot_params]),
                );
            }
            // One bind group serves all dispatches; only the dynamic offset
            // into the params buffer (binding 0) changes between passes.
            let bg = render_device.create_bind_group(
                "spark.sort_bg",
                &pipeline.bind_layout,
                &[
                    // binding 0: one SortParams record, dynamic-offset addressed.
                    BindGroupEntry {
                        binding: 0,
                        resource: BindingResource::Buffer(BufferBinding {
                            buffer: &sort.params,
                            offset: 0,
                            size: BufferSize::new(SORT_PARAM_SIZE),
                        }),
                    },
                    // binding 1: read-only splat data.
                    BindGroupEntry {
                        binding: 1,
                        resource: BindingResource::Buffer(BufferBinding {
                            buffer: splats_buffer,
                            offset: 0,
                            size: None,
                        }),
                    },
                    // binding 2: keys A.
                    BindGroupEntry {
                        binding: 2,
                        resource: BindingResource::Buffer(BufferBinding {
                            buffer: &sort.keys_a,
                            offset: 0,
                            size: None,
                        }),
                    },
                    // binding 3: indices A (the view's draw index buffer).
                    BindGroupEntry {
                        binding: 3,
                        resource: BindingResource::Buffer(BufferBinding {
                            buffer: indices_buffer,
                            offset: 0,
                            size: None,
                        }),
                    },
                    // binding 4: keys B.
                    BindGroupEntry {
                        binding: 4,
                        resource: BindingResource::Buffer(BufferBinding {
                            buffer: &sort.keys_b,
                            offset: 0,
                            size: None,
                        }),
                    },
                    // binding 5: indices B.
                    BindGroupEntry {
                        binding: 5,
                        resource: BindingResource::Buffer(BufferBinding {
                            buffer: &sort.indices_b,
                            offset: 0,
                            size: None,
                        }),
                    },
                    // binding 6: per-block histograms.
                    BindGroupEntry {
                        binding: 6,
                        resource: BindingResource::Buffer(BufferBinding {
                            buffer: &sort.block_offsets,
                            offset: 0,
                            size: None,
                        }),
                    },
                    // binding 7: global digit offsets.
                    BindGroupEntry {
                        binding: 7,
                        resource: BindingResource::Buffer(BufferBinding {
                            buffer: &sort.offsets,
                            offset: 0,
                            size: None,
                        }),
                    },
                ],
            );
            let mut encoder = render_device.create_command_encoder(&CommandEncoderDescriptor {
                label: Some("spark.sort"),
            });
            // Pass 0: compute a 32-bit key per splat (uniform slot 0).
            {
                let mut pass = encoder.begin_compute_pass(&ComputePassDescriptor::default());
                // Dynamic offsets are small multiples of the uniform alignment
                // (at most 4 * stride), so the u32 cast is lossless.
                pass.set_bind_group(0, &bg, &[sort_param_offset(0, params_stride) as u32]);
                pass.set_pipeline(compute_keys_pl);
                pass.dispatch_workgroups(workgroups, 1, 1);
            }
            // Four radix passes. Even passes read the A buffers and scatter
            // into B; odd passes read B and scatter back into A. Relies on
            // WebGPU's implicit ordering between dispatches in one compute
            // pass for histogram -> prefix-sum -> scatter.
            for pass_idx in 0..4u32 {
                let (hist, scatter) = if pass_idx % 2 == 0 {
                    (hist_a_pl, scatter_ab_pl)
                } else {
                    (hist_b_pl, scatter_ba_pl)
                };
                {
                    let mut pass = encoder.begin_compute_pass(&ComputePassDescriptor::default());
                    // Uniform slot (pass_idx + 1) carries pass_index == pass_idx.
                    pass.set_bind_group(
                        0,
                        &bg,
                        &[sort_param_offset(pass_idx + 1, params_stride) as u32],
                    );
                    pass.set_pipeline(hist);
                    pass.dispatch_workgroups(workgroups, 1, 1);
                    // Prefix sum runs as a single workgroup over the 256 bins.
                    pass.set_pipeline(prefix_sum_pl);
                    pass.dispatch_workgroups(1, 1, 1);
                    pass.set_pipeline(scatter);
                    pass.dispatch_workgroups(workgroups, 1, 1);
                }
            }
            // One submit per cloud/view; buffered write_buffer calls above are
            // applied before this submission executes.
            render_queue.submit([encoder.finish()]);
        }
    }
}
/// Registers `sort.wgsl` as an embedded asset so it resolves at
/// `SORT_SHADER_PATH`; call once during plugin build.
pub fn register_sort_assets(app: &mut App) {
    embedded_asset!(app, "sort.wgsl");
}