use std::collections::VecDeque;
use bevy::math::{Mat4, Vec3, Vec3A};
use bevy::prelude::Component;
use bevy::reflect::TypePath;
use bytemuck::{Pod, Zeroable};
use serde::{Deserialize, Serialize};
// Per-splat spherical-harmonics coefficient counts. The values are
// consistent with 3 color channels × (2l + 1) basis functions for band l:
// band 1 adds 9 (3×3), band 2 adds 15 (3×5), band 3 adds 21 (3×7).
// Each constant is the *additional* coefficients contributed by its band;
// see `sh_coefficients_per_splat`, which sums them cumulatively.
pub const SH_COEFFS_DEGREE_1: usize = 9;
pub const SH_COEFFS_DEGREE_2: usize = 15;
pub const SH_COEFFS_DEGREE_3: usize = 21;
// Total coefficients stored for the maximum supported degree (3): 45.
pub const SH_COEFFS_MAX: usize = SH_COEFFS_DEGREE_1 + SH_COEFFS_DEGREE_2 + SH_COEFFS_DEGREE_3;
// Number of u32 words needed to hold SH_COEFFS_MAX one-byte coefficients,
// packed four per word (see `GpuSplatSh::from_coefficients`).
pub const SH_PACKED_WORDS: usize = SH_COEFFS_MAX.div_ceil(4);
/// A single Gaussian splat in unpacked CPU-side form.
#[derive(Default, Clone, Copy, Debug)]
pub struct RawSplat {
    // Position of the splat center.
    pub center: [f32; 3],
    // Opacity; clamped to [0, 1] and quantized to a byte by `Splats::to_gpu`.
    pub alpha: f32,
    // Base color; `Splats::to_gpu` quantizes it linearly without applying
    // a transfer curve, so it is presumably already display-encoded.
    pub color: [f32; 3],
    // Per-axis Gaussian scale; converted to f16 on upload.
    pub scale: [f32; 3],
    // Orientation quaternion, stored [x, y, z, w] judging by the identity
    // value [0, 0, 0, 1] used for aggregate splats — TODO confirm.
    pub quat: [f32; 4],
}
/// Coordinate convention the splat data was authored in; used by
/// `local_from_splat` to build the transform into the asset's local space.
#[derive(Component, Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
pub enum SplatCoordinateConvention {
    /// Data already matches Bevy's Y-up convention; no correction applied.
    #[default]
    BevyYUp,
    /// Y axis points down in the source data; corrected by a 180° rotation
    /// about the X axis.
    YDown,
}
impl SplatCoordinateConvention {
    /// Transform mapping splat-file coordinates into the asset's local space.
    ///
    /// `BevyYUp` data passes through unchanged; `YDown` data is flipped by
    /// a 180° rotation around X so that +Y points up again.
    pub(crate) fn local_from_splat(self) -> Mat4 {
        if matches!(self, Self::YDown) {
            Mat4::from_rotation_x(std::f32::consts::PI)
        } else {
            Mat4::IDENTITY
        }
    }
}
/// A loaded splat asset: the raw splat list plus optional SH coefficients
/// and an optional LOD tree layout over the same array.
#[derive(bevy::asset::Asset, TypePath, Clone)]
pub struct Splats {
    // Per-splat geometry and appearance data.
    pub splats: Vec<RawSplat>,
    // Flag carried with the asset; its semantics are not visible in this
    // module — presumably consumed by the renderer. TODO confirm.
    pub anti_aliased: bool,
    // SH degree; values 0..=3 select the coefficient stride, anything
    // else behaves like 0 (see `sh_coefficients_per_splat`).
    pub sh_degree: u32,
    // Quantized SH coefficients, `sh_coefficients_per_splat()` values per
    // splat, stored contiguously in splat order.
    pub sh_coefficients: Vec<i8>,
    // Version of the source file header; not interpreted in this module.
    pub header_version: u32,
    // True once `splats` is laid out as a BFS LOD tree (root first); set
    // by `generate_quick_lod` or by the loader.
    pub lod: bool,
    // Per-splat number of LOD children (0 for leaves). Parallel to `splats`
    // when `lod` is true.
    pub lod_child_counts: Vec<u16>,
    // Per-splat index of the first LOD child; a node's children occupy a
    // contiguous run starting there.
    pub lod_child_starts: Vec<u32>,
    // Source coordinate convention, used to orient the asset.
    pub coordinate_convention: SplatCoordinateConvention,
}
impl Splats {
    /// Number of splats stored in the asset.
    pub fn len(&self) -> usize {
        self.splats.len()
    }

    /// True when the asset holds no splats.
    pub fn is_empty(&self) -> bool {
        self.splats.is_empty()
    }

    /// Axis-aligned bounding box over all splat centers, as `(min, max)`.
    ///
    /// NOTE(review): for an empty asset this returns the degenerate box
    /// `(+INF, -INF)`; callers must tolerate that.
    pub fn aabb(&self) -> (Vec3, Vec3) {
        let mut mn = Vec3::splat(f32::INFINITY);
        let mut mx = Vec3::splat(f32::NEG_INFINITY);
        for s in &self.splats {
            let p = Vec3::from_array(s.center);
            mn = mn.min(p);
            mx = mx.max(p);
        }
        (mn, mx)
    }

    /// Converts every splat to the packed 32-byte GPU layout: color and
    /// alpha quantized to one byte per channel, scales as f16 bit
    /// patterns, and the rotation quaternion quantized to four signed
    /// bytes.
    pub fn to_gpu(&self) -> Vec<GpuSplat> {
        use half::f16;
        // Quantize a [0, 1] channel to a byte. NOTE(review): despite the
        // name this is a plain linear round(v * 255) — no sRGB transfer
        // curve is applied (the test below pins this), so the input is
        // presumably already display-encoded.
        let to_srgb_byte = |v: f32| -> u32 {
            (v.clamp(0.0, 1.0) * 255.0).round() as u32
        };
        self.splats
            .iter()
            .map(|s| {
                let r = to_srgb_byte(s.color[0]);
                let g = to_srgb_byte(s.color[1]);
                let b = to_srgb_byte(s.color[2]);
                let a = (s.alpha.clamp(0.0, 1.0) * 255.0).round() as u32;
                // RGBA packed little-endian within the word: r low byte.
                let color_alpha = r | (g << 8) | (b << 16) | (a << 24);
                let s0 = f16::from_f32(s.scale[0]).to_bits() as u32;
                let s1 = f16::from_f32(s.scale[1]).to_bits() as u32;
                let s2 = f16::from_f32(s.scale[2]).to_bits() as u32;
                let scales01 = s0 | (s1 << 16);
                // z scale in the low half word; the high half stays zero.
                let scales23 = s2;
                // Quantize a quaternion component in [-1, 1] to a signed
                // byte, then reinterpret it as an unsigned byte for packing.
                let qi = |v: f32| -> u32 {
                    let c = (v.clamp(-1.0, 1.0) * 127.0).round() as i32;
                    (c as i8) as u8 as u32
                };
                let rotation = qi(s.quat[0])
                    | (qi(s.quat[1]) << 8)
                    | (qi(s.quat[2]) << 16)
                    | (qi(s.quat[3]) << 24);
                GpuSplat {
                    center: s.center,
                    color_alpha,
                    scales01,
                    scales23,
                    rotation,
                    _pad: 0,
                }
            })
            .collect()
    }

    /// SH coefficient stride per splat for this asset's `sh_degree`.
    pub fn sh_coefficients_per_splat(&self) -> usize {
        sh_coefficients_per_splat(self.sh_degree)
    }

    /// Borrows the SH coefficients of the splat at `index`.
    ///
    /// Panics if the slice `index * stride .. (index + 1) * stride` falls
    /// outside `sh_coefficients` (unlike `to_gpu_sh`, which tolerates a
    /// short buffer). With stride 0 it returns an empty slice.
    pub fn sh_for_splat(&self, index: usize) -> &[i8] {
        let stride = self.sh_coefficients_per_splat();
        let start = index * stride;
        let end = start + stride;
        &self.sh_coefficients[start..end]
    }

    /// Packs per-splat SH coefficients into the GPU word layout.
    ///
    /// With stride 0 (degree 0) this still returns a single default
    /// element rather than an empty vec — presumably so a non-empty GPU
    /// buffer can always be bound; confirm against the render path.
    /// Splats whose coefficient range is missing get all-zero words.
    pub fn to_gpu_sh(&self) -> Vec<GpuSplatSh> {
        let stride = self.sh_coefficients_per_splat();
        if stride == 0 {
            return vec![GpuSplatSh::default()];
        }
        (0..self.splats.len())
            .map(|i| {
                let start = i * stride;
                let end = start + stride;
                self.sh_coefficients
                    .get(start..end)
                    .map(GpuSplatSh::from_coefficients)
                    .unwrap_or_default()
            })
            .collect()
    }

    /// Builds an in-place LOD hierarchy over the current splats.
    ///
    /// Leaves are the original splats; interior nodes are aggregate
    /// splats standing in for their subtree. The tree is re-laid-out in
    /// breadth-first order (root at index 0), each node's children
    /// occupying a contiguous run described by `lod_child_counts` /
    /// `lod_child_starts`. SH coefficients are remapped to the new order.
    ///
    /// Returns `false` (leaving the asset untouched) when the asset is
    /// already an LOD tree or has too few splats to benefit.
    pub fn generate_quick_lod(&mut self, leaf_size: usize, branch_factor: usize) -> bool {
        if self.lod || self.splats.len() <= 1 {
            return false;
        }
        let leaf_size = leaf_size.max(1);
        // Child counts are stored as u16, so the fan-out is capped.
        let branch_factor = branch_factor.clamp(2, u16::MAX as usize);
        if self.splats.len() <= leaf_size {
            return false;
        }
        let original_splats = self.splats.clone();
        let stride = self.sh_coefficients_per_splat();
        let original_sh = self.sh_coefficients.clone();
        let original_indices = (0..original_splats.len()).collect::<Vec<_>>();
        let mut nodes =
            Vec::with_capacity(original_splats.len() + original_splats.len() / leaf_size);
        let root = build_quick_lod_node(
            &original_splats,
            &original_sh,
            stride,
            &original_indices,
            leaf_size,
            branch_factor,
            &mut nodes,
        );
        // BFS over the arena to compute the output order: parents come
        // before children, so coarser LODs sit at the front of the array.
        let mut order = Vec::with_capacity(nodes.len());
        let mut node_to_output = vec![usize::MAX; nodes.len()];
        let mut queue = VecDeque::from([root]);
        while let Some(node_index) = queue.pop_front() {
            // Defensive: in a well-formed tree each node is reached once.
            if node_to_output[node_index] != usize::MAX {
                continue;
            }
            node_to_output[node_index] = order.len();
            order.push(node_index);
            for &child in &nodes[node_index].children {
                queue.push_back(child);
            }
        }
        let mut new_splats = Vec::with_capacity(order.len());
        let mut new_sh = if stride == 0 {
            Vec::new()
        } else {
            Vec::with_capacity(order.len() * stride)
        };
        let mut child_counts = vec![0u16; order.len()];
        let mut child_starts = vec![0u32; order.len()];
        for (out_index, &node_index) in order.iter().enumerate() {
            let node = &nodes[node_index];
            new_splats.push(node.splat);
            if stride != 0 {
                new_sh.extend_from_slice(&node.sh);
            }
            // A node's children are enqueued consecutively during the BFS,
            // so they are contiguous in the output: first-child index plus
            // a count fully describes them.
            if let Some(&first_child) = node.children.first() {
                child_counts[out_index] = node.children.len() as u16;
                child_starts[out_index] = node_to_output[first_child] as u32;
            }
        }
        self.splats = new_splats;
        self.sh_coefficients = new_sh;
        self.lod = true;
        self.lod_child_counts = child_counts;
        self.lod_child_starts = child_starts;
        true
    }
}
/// Build-time tree node used while constructing the quick LOD hierarchy.
struct QuickLodNode {
    // Representative splat: the original splat for leaves, an aggregate
    // (averaged color/alpha, bounding-radius scale) for interior nodes.
    splat: RawSplat,
    // SH coefficients accompanying the representative splat.
    sh: Vec<i8>,
    // Indices of child nodes within the build-time node arena.
    children: Vec<usize>,
}
/// Recursively builds an LOD subtree for `indices`, appending nodes into
/// the `nodes` arena (children are pushed before their parent), and
/// returns the arena index of the subtree root.
fn build_quick_lod_node(
    splats: &[RawSplat],
    sh_coefficients: &[i8],
    sh_stride: usize,
    indices: &[usize],
    leaf_size: usize,
    branch_factor: usize,
    nodes: &mut Vec<QuickLodNode>,
) -> usize {
    // Base case: a single splat becomes a leaf carrying the original data.
    if indices.len() == 1 {
        let index = indices[0];
        let sh = sh_for_original(sh_coefficients, sh_stride, index).to_vec();
        let node_index = nodes.len();
        nodes.push(QuickLodNode {
            splat: splats[index],
            sh,
            children: Vec::new(),
        });
        return node_index;
    }
    // Groups small enough to be a leaf cluster fan out one child per
    // splat; larger groups are partitioned spatially along the dominant
    // axis of their bounding box.
    let child_groups = if indices.len() <= leaf_size {
        indices.iter().map(|&index| vec![index]).collect::<Vec<_>>()
    } else {
        split_quick_lod_groups(splats, indices, leaf_size, branch_factor)
    };
    let children = child_groups
        .iter()
        .map(|group| {
            build_quick_lod_node(
                splats,
                sh_coefficients,
                sh_stride,
                group,
                leaf_size,
                branch_factor,
                nodes,
            )
        })
        .collect::<Vec<_>>();
    // The interior node aggregates all splats beneath it so it can stand
    // in for the whole subtree at coarser LOD levels.
    let node_index = nodes.len();
    nodes.push(QuickLodNode {
        splat: aggregate_quick_lod_splat(splats, indices),
        sh: aggregate_quick_lod_sh(sh_coefficients, sh_stride, indices),
        children,
    });
    node_index
}
/// Partitions `indices` into up to `branch_factor` spatially coherent
/// groups: sorts the splats along the dominant axis of their bounding
/// box and slices the sorted order into near-equal runs.
fn split_quick_lod_groups(
    splats: &[RawSplat],
    indices: &[usize],
    leaf_size: usize,
    branch_factor: usize,
) -> Vec<Vec<usize>> {
    // Bounding box over all member centers.
    let (min, max) = indices.iter().fold(
        (Vec3A::splat(f32::INFINITY), Vec3A::splat(f32::NEG_INFINITY)),
        |(lo, hi), &index| {
            let p = Vec3A::from_array(splats[index].center);
            (lo.min(p), hi.max(p))
        },
    );
    let extent = max - min;
    // Pick the longest axis; ties resolve toward x, then y.
    let axis = if extent.x >= extent.y && extent.x >= extent.z {
        0
    } else if extent.y >= extent.z {
        1
    } else {
        2
    };
    let mut sorted = indices.to_vec();
    sorted.sort_by(|&a, &b| splats[a].center[axis].total_cmp(&splats[b].center[axis]));
    // At least 2 groups, never more than one group per index, and no more
    // than needed to keep each group at or under the leaf size.
    let group_count = branch_factor
        .min(indices.len().div_ceil(leaf_size))
        .max(2)
        .min(indices.len());
    (0..group_count)
        .filter_map(|group_index| {
            // Even integer slicing; a slice can come out empty and is
            // then skipped.
            let start = group_index * sorted.len() / group_count;
            let end = (group_index + 1) * sorted.len() / group_count;
            (start < end).then(|| sorted[start..end].to_vec())
        })
        .collect()
}
/// Collapses a group of splats into one representative: center, color and
/// alpha are averaged, orientation resets to identity, and the scale
/// becomes an isotropic radius large enough to cover every member splat.
fn aggregate_quick_lod_splat(splats: &[RawSplat], indices: &[usize]) -> RawSplat {
    let inv_count = 1.0 / indices.len() as f32;
    // First pass: accumulate the sums used for the averages.
    let (center_total, color_total, alpha_total) = indices.iter().fold(
        (Vec3A::ZERO, Vec3A::ZERO, 0.0_f32),
        |(centers, colors, alphas), &index| {
            let member = splats[index];
            (
                centers + Vec3A::from_array(member.center),
                colors + Vec3A::from_array(member.color),
                alphas + member.alpha,
            )
        },
    );
    let center = center_total * inv_count;
    // Second pass: bounding radius around the averaged center, padding
    // each member by its own largest axis scale. The floor keeps the
    // scale strictly positive even for a degenerate group.
    let radius = indices
        .iter()
        .fold(0.0_f32, |best, &index| {
            let member = splats[index];
            let position = Vec3A::from_array(member.center);
            let member_radius = member.scale[0].max(member.scale[1]).max(member.scale[2]);
            best.max((position - center).length() + member_radius)
        })
        .max(1e-4);
    RawSplat {
        center: center.to_array(),
        alpha: (alpha_total * inv_count).clamp(0.0, 1.0),
        color: (color_total * inv_count).to_array(),
        scale: [radius, radius, radius],
        quat: [0.0, 0.0, 0.0, 1.0],
    }
}
/// Averages the SH coefficients of the indexed splats coefficient-by-
/// coefficient, rounding to the nearest value and clamping to the i8
/// range. Returns an empty vec when the asset stores no SH data.
fn aggregate_quick_lod_sh(sh_coefficients: &[i8], stride: usize, indices: &[usize]) -> Vec<i8> {
    if stride == 0 {
        return Vec::new();
    }
    let inv_count = 1.0 / indices.len() as f32;
    (0..stride)
        .map(|coeff_index| {
            let sum: i32 = indices
                .iter()
                .map(|&index| {
                    // Splats whose coefficient range is missing yield an
                    // empty slice (and thus panic on indexing, as before).
                    let start = index * stride;
                    sh_coefficients.get(start..start + stride).unwrap_or(&[])[coeff_index] as i32
                })
                .sum();
            ((sum as f32 * inv_count).round() as i32).clamp(-128, 127) as i8
        })
        .collect()
}
/// Borrows the `stride` coefficients belonging to original splat `index`.
/// Returns an empty slice when the asset stores no SH data or when the
/// requested range falls outside the buffer.
fn sh_for_original(sh_coefficients: &[i8], stride: usize, index: usize) -> &[i8] {
    if stride == 0 {
        return &[];
    }
    let start = index * stride;
    sh_coefficients.get(start..start + stride).unwrap_or(&[])
}
/// Number of SH coefficients stored per splat for a given degree.
///
/// The counts are cumulative over bands 1..=degree; degree 0 and any
/// unsupported degree (> 3) store no coefficients.
pub fn sh_coefficients_per_splat(degree: u32) -> usize {
    match degree {
        1 => SH_COEFFS_DEGREE_1,
        2 => SH_COEFFS_DEGREE_1 + SH_COEFFS_DEGREE_2,
        3 => SH_COEFFS_MAX,
        _ => 0,
    }
}
/// GPU-side packed splat. `#[repr(C)]` plus the size assertion below pin
/// the layout to 32 bytes so it can be uploaded verbatim.
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable, Debug)]
pub struct GpuSplat {
    // Splat center position (unquantized).
    pub center: [f32; 3],
    // RGBA packed one byte per channel, red in the low byte.
    pub color_alpha: u32,
    // x and y scales as f16 bit patterns (x in the low half word).
    pub scales01: u32,
    // z scale as an f16 bit pattern in the low half word; high half unused.
    pub scales23: u32,
    // Quaternion packed as four signed bytes (two's complement).
    pub rotation: u32,
    // Explicit padding keeping the struct at 32 bytes / 16-byte multiples.
    pub _pad: u32,
}
// Compile-time assertion: GpuSplat must be exactly 32 bytes.
const _: [(); 32] = [(); core::mem::size_of::<GpuSplat>()];
/// GPU-side packed SH coefficients: one byte per coefficient, four
/// coefficients per 32-bit word (little-endian within the word).
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable, Debug, Default)]
pub struct GpuSplatSh {
    pub words: [u32; SH_PACKED_WORDS],
}
impl GpuSplatSh {
    /// Packs signed SH coefficients into the word array: one byte per
    /// coefficient, four per word, lowest index in the low byte. Extra
    /// coefficients beyond `SH_COEFFS_MAX` are ignored; missing trailing
    /// coefficients stay zero.
    pub fn from_coefficients(coefficients: &[i8]) -> Self {
        let mut packed = Self::default();
        for (slot, &coefficient) in coefficients.iter().take(SH_COEFFS_MAX).enumerate() {
            // Reinterpret the i8 as its two's-complement byte, then shift
            // it into position within its word.
            let byte = u32::from(coefficient as u8);
            packed.words[slot / 4] |= byte << ((slot % 4) * 8);
        }
        packed
    }
}
// Compile-time assertion: GpuSplatSh must be exactly 48 bytes (12 words).
const _: [(); 48] = [(); core::mem::size_of::<GpuSplatSh>()];
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal opaque gray splat at (x, 0, 0) for fixture construction.
    fn test_splat(x: f32) -> RawSplat {
        RawSplat {
            center: [x, 0.0, 0.0],
            alpha: 1.0,
            color: [0.5, 0.5, 0.5],
            scale: [0.1, 0.1, 0.1],
            quat: [0.0, 0.0, 0.0, 1.0],
        }
    }

    // Coefficients are stored as two's-complement bytes, packed
    // little-endian four per word.
    #[test]
    fn packs_signed_sh_coefficients_as_bytes() {
        let coeffs = [-128, -1, 0, 1, 127];
        let packed = GpuSplatSh::from_coefficients(&coeffs);
        assert_eq!(packed.words[0], 0x0100_ff80);
        assert_eq!(packed.words[1] & 0xff, 0x7f);
    }

    // Pins that to_gpu quantizes colors linearly (round(v * 255)) with no
    // sRGB transfer curve applied: 0.5 -> 128, 0.25 -> 64.
    #[test]
    fn packs_display_srgb_color_without_linear_quantization() {
        let mut splat = test_splat(0.0);
        splat.color = [0.5, 0.25, 0.0];
        splat.alpha = 0.5;
        let splats = Splats {
            splats: vec![splat],
            anti_aliased: false,
            sh_degree: 0,
            sh_coefficients: Vec::new(),
            header_version: 3,
            lod: false,
            lod_child_counts: Vec::new(),
            lod_child_starts: Vec::new(),
            coordinate_convention: SplatCoordinateConvention::BevyYUp,
        };
        let color_alpha = splats.to_gpu()[0].color_alpha;
        assert_eq!(color_alpha & 0xff, 128);
        assert_eq!((color_alpha >> 8) & 0xff, 64);
        assert_eq!((color_alpha >> 16) & 0xff, 0);
        assert_eq!((color_alpha >> 24) & 0xff, 128);
    }

    // 4 leaves with leaf_size 2 / branch_factor 2 yield root + 2 interior
    // + 4 leaves = 7 nodes, BFS-ordered with contiguous child runs.
    #[test]
    fn generates_reusable_quick_lod_tree() {
        let mut splats = Splats {
            splats: (0..4).map(|i| test_splat(i as f32)).collect(),
            anti_aliased: false,
            sh_degree: 0,
            sh_coefficients: Vec::new(),
            header_version: 3,
            lod: false,
            lod_child_counts: Vec::new(),
            lod_child_starts: Vec::new(),
            coordinate_convention: SplatCoordinateConvention::BevyYUp,
        };
        assert!(splats.generate_quick_lod(2, 2));
        assert!(splats.lod);
        assert_eq!(splats.splats.len(), 7);
        assert_eq!(splats.lod_child_counts, vec![2, 2, 2, 0, 0, 0, 0]);
        assert_eq!(splats.lod_child_starts, vec![1, 3, 5, 0, 0, 0, 0]);
    }

    // The BFS root comes first and averages its leaves' coefficients
    // ((10 + 14) / 2 = 12); leaves keep their original values.
    #[test]
    fn generated_quick_lod_remaps_sh_coefficients() {
        let stride = sh_coefficients_per_splat(1);
        let mut splats = Splats {
            splats: vec![test_splat(0.0), test_splat(1.0)],
            anti_aliased: false,
            sh_degree: 1,
            sh_coefficients: [vec![10; stride], vec![14; stride]].concat(),
            header_version: 3,
            lod: false,
            lod_child_counts: Vec::new(),
            lod_child_starts: Vec::new(),
            coordinate_convention: SplatCoordinateConvention::BevyYUp,
        };
        assert!(splats.generate_quick_lod(1, 2));
        assert_eq!(splats.splats.len(), 3);
        assert_eq!(splats.sh_coefficients.len(), splats.splats.len() * stride);
        assert_eq!(&splats.sh_coefficients[0..stride], vec![12; stride]);
        assert_eq!(
            &splats.sh_coefficients[stride..stride * 2],
            vec![10; stride]
        );
        assert_eq!(
            &splats.sh_coefficients[stride * 2..stride * 3],
            vec![14; stride]
        );
    }
}