// 3D Gaussian Splatting renderer. Mirrors the sparkjs splatVertex.glsl
// covariance projection and splatFragment.glsl falloff function.
#import bevy_render::view::View
// Packed 32-byte layout. See splats.rs::GpuSplat.
// WGSL layout: `center` is a vec3<f32> (12 bytes, 16-byte aligned), so the
// following u32 packs into bytes 12..16 of the same 16-byte slot. Total struct
// size is 32 bytes with 16-byte alignment; the Rust side must match exactly.
struct Splat {
center: vec3<f32>, // offset 0..12 (16-byte aligned; color_alpha fills bytes 12..16)
color_alpha: u32, // offset 12..16 (RGBA8 unorm via unpack4x8unorm_bits; alpha stored at half scale)
scales01: u32, // offset 16..20 (two f16: low half = sx, high half = sy)
scales23: u32, // offset 20..24 (two f16: low half = sz, high half unused)
rotation: u32, // offset 24..28 (quaternion x,y,z,w as snorm8 bytes, w in top byte; normalized before use)
_pad: u32, // offset 28..32 (padding to 32 bytes)
};
// Packed spherical-harmonics coefficients for one splat: 12 words = 48 bytes,
// four bytes per word with the lowest byte first. Each byte is decoded as a
// signed 8-bit value by sh_coeff. Degree 3 uses the first 45 bytes
// (15 basis functions x RGB); the remaining bytes are not read.
struct SplatSh {
words: array<u32, 12>,
};
// Per-view data supplied by Bevy. This shader reads `world_position`,
// `view_from_world` and `clip_from_view`.
@group(0) @binding(0) var<uniform> view: View;
// Per-cloud bind group
// Two storage backends are selected at shader-compile time:
// - SPARK_TEXTURE_BACKEND: splat records, sort order and SH live in u32 data
//   textures tiled row-major with `cloud.texture_width` (the WebGL2 path).
// - otherwise: the same data as read-only storage buffers.
// In both cases the sorted-index resource maps instance id -> splat index.
#ifdef SPARK_TEXTURE_BACKEND
@group(1) @binding(0) var splats_texture: texture_2d<u32>;
@group(1) @binding(1) var sorted_indices_texture: texture_2d<u32>;
#else
@group(1) @binding(0) var<storage, read> splats_buffer: array<Splat>;
@group(1) @binding(1) var<storage, read> sorted_indices_buffer: array<u32>;
#endif
// Cloud parameters shared by both backends. (CloudUniforms is declared below;
// WGSL permits module-scope declarations in any order.)
@group(1) @binding(2) var<uniform> cloud: CloudUniforms;
// Per-splat SH coefficients (see SplatSh), in the form matching the backend.
#ifdef SPARK_TEXTURE_BACKEND
@group(1) @binding(3) var splat_sh_texture: texture_2d<u32>;
#else
@group(1) @binding(3) var<storage, read> splat_sh_buffer: array<SplatSh>;
#endif
// Per-cloud uniform block. Laid out as 128 bytes (offsets in comments); the
// Rust-side struct must match this layout exactly.
struct CloudUniforms {
// 4x4 column-major model -> world. Offset 0..64.
model: mat4x4<f32>,
// Total number of splats (also serves as a guard for out-of-range indices). Offset 64.
num_splats: u32,
// Nonzero if the source SPZ was anti-aliased. Selects how 0.3 is added to the
// projected 2D covariance: anti-aliased -> blur with opacity compensation,
// otherwise -> pre-blur with no opacity compensation. Offset 68.
anti_aliased: u32,
// SPZ spherical harmonics degree, 0..3 (0 disables SH; values > 3 act as 3). Offset 72.
sh_degree: u32,
_pad0: u32,
// Render target size in pixels (x, y). Offset 80..88.
render_size: vec2<f32>,
// Tunable: max stddev for the splat extent (sparkjs default = sqrt(8)). Offset 88.
max_stddev: f32,
// Min alpha cut: splats below it are dropped in the vertex stage (before and
// after blur compensation) and fragments below it are discarded. Offset 92.
min_alpha: f32,
// Pixel radius limits: each quad half-axis is capped at max_pixel_radius; a
// splat whose two half-axes are both below min_pixel_radius is culled.
// Offsets 96 and 100.
min_pixel_radius: f32,
max_pixel_radius: f32,
// 0 = SparkJS reference falloff, 1 = legacy edge-normalized falloff. Offset 104.
falloff_profile: u32,
// 0 = SparkJS reference high-alpha stretch, 1 = legacy bounded expansion. Offset 108.
high_alpha_profile: u32,
// Width used to tile WebGL2 data textures. Unused by storage-buffer path. Offset 112.
texture_width: u32,
_pad1: u32,
_pad2: u32,
_pad3: u32,
};
// Vertex -> fragment interface. All non-position values are constant across a
// splat's quad except `splat_uv`.
struct VsOut {
@builtin(position) clip_pos: vec4<f32>,
// Quad-local position in standard deviations (corner * adjusted_stddev).
@location(0) splat_uv: vec2<f32>,
// Linear RGB plus opacity; opacity may exceed 1.0 (high-alpha splats).
@location(1) v_rgba: vec4<f32>,
// Cutoff radius in standard deviations; fragments beyond it are discarded.
@location(2) adjusted_stddev: f32,
};
// Quad corners shared by every splat. Order matches indexed mesh.
// Indexed by @builtin(vertex_index) in vs_main; x scales along the major
// eigen-axis and y along the minor one.
// NOTE(review): assumes the draw supplies vertex indices 0..3 only; confirm
// against the mesh/index buffer setup.
const QUAD_CORNERS = array<vec2<f32>, 4>(
vec2<f32>(-1.0, -1.0),
vec2<f32>( 1.0, -1.0),
vec2<f32>(-1.0, 1.0),
vec2<f32>( 1.0, 1.0),
);
// `cloud.falloff_profile` value selecting the legacy edge-normalized falloff
// (any other value uses the SparkJS reference falloff).
const FALLOFF_PROFILE_EDGE_NORMALIZED: u32 = 1u;
// `cloud.high_alpha_profile` value selecting the legacy bounded quad expansion
// (any other value uses the SparkJS reference alpha stretch).
const HIGH_ALPHA_PROFILE_BOUNDED: u32 = 1u;
// Builds the column-major 3x3 rotation matrix of quaternion `q`, stored as
// (x, y, z, w) with the scalar part in `w`. `q` is expected to be unit length;
// it is not renormalized here.
fn quat_to_mat3(q: vec4<f32>) -> mat3x3<f32> {
    // Squares (xx, yy, zz), mixed products (xy, xz, yz), and scalar products (wx, wy, wz).
    let sq = q.xyz * q.xyz;
    let mixed = vec3<f32>(q.x * q.y, q.x * q.z, q.y * q.z);
    let scaled = q.w * q.xyz;

    let col0 = vec3<f32>(
        1.0 - 2.0 * (sq.y + sq.z),
        2.0 * (mixed.x + scaled.z),
        2.0 * (mixed.y - scaled.y),
    );
    let col1 = vec3<f32>(
        2.0 * (mixed.x - scaled.z),
        1.0 - 2.0 * (sq.x + sq.z),
        2.0 * (mixed.z + scaled.x),
    );
    let col2 = vec3<f32>(
        2.0 * (mixed.y + scaled.y),
        2.0 * (mixed.z - scaled.x),
        1.0 - 2.0 * (sq.x + sq.y),
    );
    return mat3x3<f32>(col0, col1, col2);
}
// Returns `v` scaled to unit length. Despite the name, a (near-)zero input
// (squared length <= 1e-12) does NOT produce zero: it falls back to +Z,
// (0, 0, 1), so callers always receive a usable unit direction.
fn normalize_or_zero(v: vec3<f32>) -> vec3<f32> {
    let length_sq = dot(v, v);
    let degenerate = length_sq <= 1e-12;
    return select(v * inverseSqrt(length_sq), vec3<f32>(0.0, 0.0, 1.0), degenerate);
}
// Decodes one sRGB-encoded channel to linear light using the piecewise sRGB
// transfer function. Non-positive inputs are returned unchanged.
fn srgb_channel_to_linear(value: f32) -> f32 {
    var result = value;
    if (value > 0.04045) {
        result = pow((value + 0.055) / 1.055, 2.4);
    } else if (value > 0.0) {
        result = value / 12.92;
    }
    return result;
}
// Encodes one linear-light channel with the piecewise sRGB transfer function.
// Non-positive inputs are returned unchanged.
fn linear_channel_to_srgb(value: f32) -> f32 {
    var result = value;
    if (value > 0.0031308) {
        result = 1.055 * pow(value, 1.0 / 2.4) - 0.055;
    } else if (value > 0.0) {
        result = value * 12.92;
    }
    return result;
}
// Component-wise sRGB -> linear conversion of an RGB triple.
fn srgb_to_linear(rgb: vec3<f32>) -> vec3<f32> {
    var converted = rgb;
    for (var i = 0u; i < 3u; i = i + 1u) {
        converted[i] = srgb_channel_to_linear(rgb[i]);
    }
    return converted;
}
// Component-wise linear -> sRGB conversion of an RGB triple.
fn linear_to_srgb(rgb: vec3<f32>) -> vec3<f32> {
    let r = linear_channel_to_srgb(rgb.x);
    let g = linear_channel_to_srgb(rgb.y);
    let b = linear_channel_to_srgb(rgb.z);
    return vec3<f32>(r, g, b);
}
// Splits a u32 into four unsigned bytes (lowest byte -> .x) and maps each to
// [0, 1] by dividing by 255.
fn unpack4x8unorm_bits(word: u32) -> vec4<f32> {
    let shifts = vec4<u32>(0u, 8u, 16u, 24u);
    let bytes = (vec4<u32>(word) >> shifts) & vec4<u32>(0xffu);
    return vec4<f32>(bytes) / 255.0;
}
// Decodes the low 8 bits of `byte` as a signed-normalized value: the byte is
// read as a two's-complement int8, divided by 127, and clamped to -1 so that
// -128 and -127 both map to -1.0 (matching unpack4x8snorm semantics).
// Bits above the low byte are ignored.
fn unpack_snorm8(byte: u32) -> f32 {
    // Shift the byte into the top 8 bits, then arithmetic-shift it back down so
    // bit 7 is replicated into the sign. Previously the sign test looked at the
    // unmasked input while the value used the masked byte, which disagreed for
    // inputs above 0xff.
    let signed_value = bitcast<i32>(byte << 24u) >> 24u;
    return max(f32(signed_value) / 127.0, -1.0);
}
// Splits a u32 into four signed-normalized bytes (lowest byte -> .x), each in
// [-1, 1], using unpack_snorm8 per lane.
fn unpack4x8snorm_bits(word: u32) -> vec4<f32> {
    var lanes: vec4<f32>;
    for (var lane = 0u; lane < 4u; lane = lane + 1u) {
        lanes[lane] = unpack_snorm8((word >> (lane * 8u)) & 0xffu);
    }
    return lanes;
}
// Software decode of an IEEE 754 binary16 value held in the low 16 bits of
// `bits`. Zeros and subnormals are decoded exactly. Infinity and NaN
// (exponent field 31) are both clamped to +/- f32::MAX instead of propagating
// inf/NaN.
fn unpack_f16_bits(bits: u32) -> f32 {
    let negative = (bits & 0x8000u) != 0u;
    let sign = select(1.0, -1.0, negative);
    let biased_exponent = (bits >> 10u) & 0x1fu;
    let fraction = f32(bits & 0x03ffu) / 1024.0;

    var magnitude: f32;
    switch biased_exponent {
        case 0u: {
            // Zero or subnormal: no implicit leading 1.
            magnitude = exp2(-14.0) * fraction;
        }
        case 31u: {
            // Inf / NaN: clamp to the largest finite f32.
            magnitude = 3.4028234663852886e38;
        }
        default: {
            magnitude = exp2(f32(biased_exponent) - 15.0) * (1.0 + fraction);
        }
    }
    return sign * magnitude;
}
// Decodes two packed binary16 values: the low half-word goes to .x and the
// high half-word to .y.
fn unpack2x16float_bits(word: u32) -> vec2<f32> {
    let low_half = word & 0xffffu;
    let high_half = word >> 16u;
    return vec2<f32>(unpack_f16_bits(low_half), unpack_f16_bits(high_half));
}
#ifdef SPARK_TEXTURE_BACKEND
// Converts a linear texel index into (x, y) coordinates of a row-major data
// texture that is `cloud.texture_width` texels wide.
// NOTE(review): assumes texture_width > 0; a zero width would give garbage
// coordinates rather than an error.
fn data_texture_coords(texel_index: u32) -> vec2<u32> {
return vec2<u32>(texel_index % cloud.texture_width, texel_index / cloud.texture_width);
}
// Reassembles a Splat from two consecutive RGBA32Uint texels:
// texel 0 = (center.x, center.y, center.z as f32 bits, color_alpha),
// texel 1 = (scales01, scales23, rotation, _pad).
fn load_splat(index: u32) -> Splat {
let base = index * 2u;
let a = textureLoad(splats_texture, data_texture_coords(base), 0);
let b = textureLoad(splats_texture, data_texture_coords(base + 1u), 0);
var s: Splat;
s.center = vec3<f32>(
bitcast<f32>(a.x),
bitcast<f32>(a.y),
bitcast<f32>(a.z),
);
s.color_alpha = a.w;
s.scales01 = b.x;
s.scales23 = b.y;
s.rotation = b.z;
s._pad = b.w;
return s;
}
// Sort order: one texel per entry, index stored in the red channel.
fn load_sorted_index(index: u32) -> u32 {
return textureLoad(sorted_indices_texture, data_texture_coords(index), 0).x;
}
// SH block: three consecutive texels provide the 12 packed words in order.
fn load_splat_sh(index: u32) -> SplatSh {
let base = index * 3u;
let a = textureLoad(splat_sh_texture, data_texture_coords(base), 0);
let b = textureLoad(splat_sh_texture, data_texture_coords(base + 1u), 0);
let c = textureLoad(splat_sh_texture, data_texture_coords(base + 2u), 0);
return SplatSh(array<u32, 12>(
a.x, a.y, a.z, a.w,
b.x, b.y, b.z, b.w,
c.x, c.y, c.z, c.w,
));
}
#else
// Storage-buffer backend: direct array reads with the same signatures as the
// texture backend above.
fn load_splat(index: u32) -> Splat {
return splats_buffer[index];
}
fn load_sorted_index(index: u32) -> u32 {
return sorted_indices_buffer[index];
}
fn load_splat_sh(index: u32) -> SplatSh {
return splat_sh_buffer[index];
}
#endif
// Returns packed SH coefficient `index` (0..47): byte `index % 4` of word
// `index / 4`, read as a two's-complement int8 and scaled by 1/128.
// NOTE(review): SPZ files store SH bytes with a +128 bias ((b - 128) / 128);
// this decode assumes the upload path already re-biased them to signed int8.
// Confirm against splats.rs.
fn sh_coeff(sh: SplatSh, index: u32) -> f32 {
    let packed = sh.words[index >> 2u];
    let shift = (index & 3u) * 8u;
    // Move the selected byte into the top 8 bits, then arithmetic-shift it back
    // down so its bit 7 becomes the sign.
    let as_i8 = bitcast<i32>(packed << (24u - shift)) >> 24u;
    return f32(as_i8) / 128.0;
}
// Reads the RGB triplet for SH basis function `k` (packed coefficients 3k..3k+2).
fn sh_rgb_triplet(sh: SplatSh, k: u32) -> vec3<f32> {
    let first = 3u * k;
    return vec3<f32>(sh_coeff(sh, first), sh_coeff(sh, first + 1u), sh_coeff(sh, first + 2u));
}

// Evaluates the view-dependent spherical-harmonics colour offset (bands 1..3)
// for one splat. Band 0 (the base colour) is not included; the caller adds the
// result to it. `view_dir` should be unit length. `degree` selects how many
// bands are summed; values above 3 behave like 3. Basis constants and signs
// follow the real-SH convention used by the reference 3D Gaussian Splatting
// rasterizer. Packed layout: band 1 -> triplets 0..2, band 2 -> triplets 3..7,
// band 3 -> triplets 8..14.
fn evaluate_spz_sh(sh: SplatSh, view_dir: vec3<f32>, degree: u32) -> vec3<f32> {
    let x = view_dir.x;
    let y = view_dir.y;
    let z = view_dir.z;
    var rgb = vec3<f32>(0.0);
    if (degree >= 1u) {
        rgb = rgb
            + sh_rgb_triplet(sh, 0u) * (-0.4886025 * y)
            + sh_rgb_triplet(sh, 1u) * ( 0.4886025 * z)
            + sh_rgb_triplet(sh, 2u) * (-0.4886025 * x);
    }
    if (degree >= 2u) {
        // 0.31539157 ~= 0.31539156525; the previous 0.3153915 was truncated
        // rather than rounded.
        rgb = rgb
            + sh_rgb_triplet(sh, 3u) * ( 1.0925484 * x * y)
            + sh_rgb_triplet(sh, 4u) * (-1.0925484 * y * z)
            + sh_rgb_triplet(sh, 5u) * ( 0.31539157 * (2.0 * z * z - x * x - y * y))
            + sh_rgb_triplet(sh, 6u) * (-1.0925484 * x * z)
            + sh_rgb_triplet(sh, 7u) * ( 0.5462742 * (x * x - y * y));
    }
    if (degree >= 3u) {
        let xx = x * x;
        let yy = y * y;
        let zz = z * z;
        let xy = x * y;
        rgb = rgb
            + sh_rgb_triplet(sh, 8u) * (-0.5900436 * y * (3.0 * xx - yy))
            + sh_rgb_triplet(sh, 9u) * ( 2.8906114 * xy * z)
            + sh_rgb_triplet(sh, 10u) * (-0.4570458 * y * (4.0 * zz - xx - yy))
            + sh_rgb_triplet(sh, 11u) * ( 0.3731763 * z * (2.0 * zz - 3.0 * xx - 3.0 * yy))
            + sh_rgb_triplet(sh, 12u) * (-0.4570458 * x * (4.0 * zz - xx - yy))
            + sh_rgb_triplet(sh, 13u) * ( 1.4453057 * z * (xx - yy))
            + sh_rgb_triplet(sh, 14u) * (-0.5900436 * x * (xx - 3.0 * yy));
    }
    return rgb;
}
// Returns inverse(m) * dir for a 3x3 linear map `m`, computed from the adjugate:
// the rows of inverse(m) are the cross products of m's columns divided by det(m).
// For a (near-)singular `m` no meaningful inverse exists, so `dir` is returned
// unchanged in that case.
fn model_inverse_direction(m: mat3x3<f32>, dir: vec3<f32>) -> vec3<f32> {
    let row0 = cross(m[1], m[2]);
    let row1 = cross(m[2], m[0]);
    let row2 = cross(m[0], m[1]);
    let det = dot(m[0], row0);
    if (abs(det) < 1e-30) {
        return dir;
    }
    return vec3<f32>(dot(row0, dir), dot(row1, dir), dot(row2, dir)) / det;
}

// Vertex stage: expands the splat selected by sorted instance `iid` into a quad
// aligned with the eigenvectors of its projected 2D covariance.
// A culled splat keeps the default output, whose clip position (z = 2w) lies
// outside the clip volume. Culling happens when the instance is out of range,
// too transparent, at/behind the camera, outside the depth range or the 1.4x
// widened frustum, degenerate, or below the minimum pixel radius.
@vertex
fn vs_main(@builtin(vertex_index) vid: u32, @builtin(instance_index) iid: u32) -> VsOut {
    var out: VsOut;
    out.clip_pos = vec4<f32>(0.0, 0.0, 2.0, 1.0); // off-screen by default
    out.v_rgba = vec4<f32>(0.0);
    out.splat_uv = vec2<f32>(0.0);
    out.adjusted_stddev = 0.0;
    if (iid >= cloud.num_splats) {
        return out;
    }
    // Instance ids walk the sorted order; map each to its splat record.
    let splat_index = load_sorted_index(iid);
    let s = load_splat(splat_index);
    var rgba = unpack4x8unorm_bits(s.color_alpha);
    // sparkjs splatVertex.glsl: stored alpha is half of visual.
    rgba.a = rgba.a * 2.0;
    if (rgba.a == 0.0 || rgba.a < cloud.min_alpha) {
        return out;
    }
    let s01 = unpack2x16float_bits(s.scales01);
    let s23 = unpack2x16float_bits(s.scales23);
    let scales = vec3<f32>(s01.x, s01.y, s23.x);
    let rotation = normalize(unpack4x8snorm_bits(s.rotation));

    // Quad half-extent in standard deviations; opacities above 1 widen it.
    var adjusted = cloud.max_stddev;
    if (rgba.a > 1.0) {
        if (cloud.high_alpha_profile == HIGH_ALPHA_PROFILE_BOUNDED) {
            adjusted = cloud.max_stddev + 0.4 * (rgba.a - 1.0);
        } else {
            // SparkJS stretches 1..2 alpha to 1..5 before expanding the quad.
            rgba.a = min(rgba.a * 4.0 - 3.0, 5.0);
            adjusted = cloud.max_stddev + 0.7 * (rgba.a - 1.0);
        }
    }

    // Linear part of model -> world. It intentionally carries uniform and
    // non-uniform entity scale as well as rotation.
    let model3 = mat3x3<f32>(cloud.model[0].xyz, cloud.model[1].xyz, cloud.model[2].xyz);
    // World-space center.
    let center_world = (cloud.model * vec4<f32>(s.center, 1.0)).xyz;

    let display_rgb = rgba.rgb;
    if (cloud.sh_degree > 0u) {
        // SH lobes are defined in the cloud's local frame. The camera -> splat
        // direction is therefore mapped back through the model's linear part
        // before evaluation. Evaluating with the world-space direction would pin
        // view-dependent colour to world axes when the entity is rotated.
        let world_dir = center_world - view.world_position;
        let view_dir = normalize_or_zero(model_inverse_direction(model3, world_dir));
        let sh_rgb = evaluate_spz_sh(load_splat_sh(splat_index), view_dir, cloud.sh_degree);
        rgba = vec4<f32>(srgb_to_linear(clamp(display_rgb + sh_rgb, vec3<f32>(0.0), vec3<f32>(1.0))), rgba.a);
    } else {
        rgba = vec4<f32>(srgb_to_linear(display_rgb), rgba.a);
    }

    // View-space center.
    let view_center4 = view.view_from_world * vec4<f32>(center_world, 1.0);
    let view_center = view_center4.xyz;
    // At or behind the camera plane?
    if (view_center.z >= 0.0) {
        return out;
    }
    // Clip-space center. Reject centers outside the depth range or more than
    // 1.4x outside the x/y frustum bounds.
    let clip_center = view.clip_from_view * view_center4;
    if (abs(clip_center.z) >= clip_center.w) { return out; }
    let clip_xy_lim = 1.4 * clip_center.w;
    if (abs(clip_center.x) > clip_xy_lim || abs(clip_center.y) > clip_xy_lim) { return out; }

    // 3D covariance in view space:
    //   cov3D = (view * model_linear * local_rotation * local_scale)
    //         * transpose(view * model_linear * local_rotation * local_scale).
    let view3 = mat3x3<f32>(view.view_from_world[0].xyz, view.view_from_world[1].xyz, view.view_from_world[2].xyz);
    let local_rot = quat_to_mat3(rotation);
    let R = view3 * model3 * local_rot;
    let RS = mat3x3<f32>(R[0] * scales.x, R[1] * scales.y, R[2] * scales.z);
    let cov3d = RS * transpose(RS);

    // Project the 3D covariance to screen space (pixel units) using the
    // projection Jacobian; orthographic projections use a constant Jacobian.
    let proj = view.clip_from_view;
    let focal = 0.5 * cloud.render_size * vec2<f32>(proj[0][0], proj[1][1]);
    var J: mat3x3<f32>;
    if (proj[3][3] == 1.0) {
        J = mat3x3<f32>(
            vec3<f32>(focal.x, 0.0, 0.0),
            vec3<f32>(0.0, focal.y, 0.0),
            vec3<f32>(0.0, 0.0, 0.0),
        );
    } else {
        let inv_z = 1.0 / view_center.z;
        let j1 = focal * inv_z;
        let j2 = -(j1 * view_center.xy) * inv_z;
        J = mat3x3<f32>(
            vec3<f32>(j1.x, 0.0, 0.0),
            vec3<f32>(0.0, j1.y, 0.0),
            vec3<f32>(j2.x, j2.y, 0.0),
        );
    }
    let cov2d_full = J * cov3d * transpose(J);
    var a_ = cov2d_full[0][0];
    var d_ = cov2d_full[1][1];
    let b_ = cov2d_full[0][1];

    // SparkJS convention:
    // - anti-aliased SPZ: blur with opacity compensation
    // - non-anti-aliased SPZ: pre-blur only, no opacity compensation
    let pre_blur = select(0.3, 0.0, cloud.anti_aliased != 0u);
    let blur = select(0.0, 0.3, cloud.anti_aliased != 0u);
    a_ = a_ + pre_blur;
    d_ = d_ + pre_blur;
    let det_orig = max(a_ * d_ - b_ * b_, 0.0);
    a_ = a_ + blur;
    d_ = d_ + blur;
    let det = a_ * d_ - b_ * b_;
    if (det <= 0.0) { return out; }
    let blur_adjust = sqrt(max(0.0, det_orig / det));
    rgba.a = rgba.a * blur_adjust;
    if (rgba.a < cloud.min_alpha) { return out; }

    // Eigen-decomposition of the symmetric 2x2 covariance [[a, b], [b, d]].
    let eigen_avg = 0.5 * (a_ + d_);
    let eigen_delta = sqrt(max(0.0, eigen_avg * eigen_avg - det));
    let eigen1 = eigen_avg + eigen_delta;
    let eigen2 = eigen_avg - eigen_delta;
    var ev1: vec2<f32>;
    if (abs(b_) > 0.001) {
        ev1 = normalize(vec2<f32>(b_, eigen1 - a_));
    } else if (a_ >= d_) {
        ev1 = vec2<f32>(1.0, 0.0);
    } else {
        ev1 = vec2<f32>(0.0, 1.0);
    }
    let ev2 = vec2<f32>(ev1.y, -ev1.x);
    let scale1 = min(cloud.max_pixel_radius, adjusted * sqrt(max(eigen1, 0.0)));
    let scale2 = min(cloud.max_pixel_radius, adjusted * sqrt(max(eigen2, 0.0)));
    if (scale1 < cloud.min_pixel_radius && scale2 < cloud.min_pixel_radius) { return out; }

    let corner = QUAD_CORNERS[vid];
    let pixel_offset = corner.x * ev1 * scale1 + corner.y * ev2 * scale2;
    let ndc_offset = (2.0 / cloud.render_size) * pixel_offset;
    let ndc_center = clip_center.xyz / clip_center.w;
    let ndc = vec3<f32>(ndc_center.xy + ndc_offset, ndc_center.z);
    // Keep every quad corner at the splat center depth so Bevy's reverse-Z
    // depth test and transparent phase ordering use one stable depth per splat.
    out.clip_pos = vec4<f32>(ndc.xy * clip_center.w, clip_center.zw);
    out.splat_uv = corner * adjusted;
    out.v_rgba = vec4<f32>(rgba.rgb, rgba.a);
    out.adjusted_stddev = adjusted;
    return out;
}
// Fragment stage: evaluates the Gaussian falloff at this fragment's position in
// the splat quad and returns premultiplied colour (rgb * alpha, alpha).
// `splat_uv` is measured in standard deviations, so fragments beyond
// `adjusted_stddev` are discarded. Fragments whose final alpha falls below
// `cloud.min_alpha` are discarded too.
@fragment
fn fs_main(in: VsOut) -> @location(0) vec4<f32> {
    let r2 = dot(in.splat_uv, in.splat_uv);
    let cutoff2 = in.adjusted_stddev * in.adjusted_stddev;
    if (r2 > cutoff2) {
        discard;
    }

    let opacity = in.v_rgba.a;
    let gaussian = exp(-0.5 * r2);
    let edge_normalized = cloud.falloff_profile == FALLOFF_PROFILE_EDGE_NORMALIZED;
    // Gaussian value at the cutoff radius and the span used by the legacy
    // profile to remap the falloff so it reaches 0 exactly at the cutoff.
    let gaussian_at_cutoff = exp(-0.5 * cutoff2);
    let span = max(1.0 - gaussian_at_cutoff, 1e-8);

    var alpha: f32;
    if (opacity <= 1.0) {
        // Ordinary splats: opacity scaled by the (optionally remapped) Gaussian.
        if (edge_normalized) {
            alpha = opacity * ((gaussian - gaussian_at_cutoff) / span);
        } else {
            alpha = opacity * gaussian;
        }
    } else {
        // High-alpha splats: an exponent > 1 that grows with opacity flattens
        // the falloff toward full coverage.
        let exponent = exp((opacity * opacity - 1.0) / 2.718281828459045);
        if (edge_normalized) {
            alpha = 1.0 - pow((1.0 - gaussian) / span, exponent);
        } else {
            alpha = 1.0 - pow(1.0 - gaussian, exponent);
        }
    }

    if (alpha < cloud.min_alpha) {
        discard;
    }
    let rgb = clamp(in.v_rgba.rgb, vec3<f32>(0.0), vec3<f32>(1.0));
    return vec4<f32>(rgb * alpha, alpha);
}