// GPU radix sort for splat indices.
//
// 4 passes × 3 dispatches (block histogram, prefix sum, stable scatter) over
// one 8-bit digit of a 32-bit distance key per pass. Even passes read
// keys_a/indices_a and write keys_b/indices_b; odd passes use the `_b` entry
// points to go the other way, so after 4 passes the sorted result lands back
// in keys_a/indices_a.
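//
// Expected host-side dispatch sequence (a sketch; the exact driver code lives
// host-side and the workgroup counts here are inferred from the entry points
// below):
//   compute_keys                          ceil(num_splats / WG) workgroups
//   for pass_index in 0..4:
//     histogram_pass / histogram_pass_b   ceil(num_splats / WG) workgroups
//     prefix_sum                          1 workgroup (single thread)
//     scatter / scatter_b                 ceil(num_splats / WG) workgroups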
struct Params {
num_splats: u32,
pass_index: u32, // 0..3: which 8-bit digit of the key this pass sorts on
num_blocks: u32, // dispatched workgroups; expected ceil(num_splats / WG)
_pad: u32, // keeps the 16-byte offset of view_pos explicit for the CPU side
view_pos: vec4<f32>, // .xyz = splat-local view position
};
// Mirrors the 32-byte layout in splats.rs::GpuSplat. Only `.center` is used.
struct Splat {
center: vec3<f32>,
color_alpha: u32,
scales01: u32,
scales23: u32,
rotation: u32,
_pad: u32,
};
@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> splats: array<Splat>;
@group(0) @binding(2) var<storage, read_write> keys_a: array<u32>;
@group(0) @binding(3) var<storage, read_write> indices_a: array<u32>;
@group(0) @binding(4) var<storage, read_write> keys_b: array<u32>;
@group(0) @binding(5) var<storage, read_write> indices_b: array<u32>;
@group(0) @binding(6) var<storage, read_write> block_offsets: array<u32>;
@group(0) @binding(7) var<storage, read_write> digit_offsets: array<u32, 256>;
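// block_offsets holds num_blocks * 256 entries (one 256-bucket histogram per
// workgroup); digit_offsets holds the 256 global bucket base offsets.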
const WG: u32 = 256u; // must stay 256: lid doubles as the digit/bucket index
var<workgroup> local_histogram: array<atomic<u32>, 256>;
var<workgroup> local_digits: array<u32, 256>; // per-lane digit for the scatter rank loop
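// Extracts the pass's byte, e.g. digit_of(0xAABBCCDDu, 2u) == 0xBBu.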
fn digit_of(key: u32, pass_idx: u32) -> u32 {
return (key >> (pass_idx * 8u)) & 0xffu;
}
@compute @workgroup_size(WG)
fn compute_keys(@builtin(global_invocation_id) gid: vec3<u32>) {
let i = gid.x;
if (i >= params.num_splats) { return; }
let c = splats[i].center;
let d = c - params.view_pos.xyz;
let d2 = dot(d, d);
// d2 >= 0, and for non-negative f32 values the IEEE-754 bit pattern is
// monotonic in the value, so the bitcast u32 orders the same as the float.
// Invert so the largest distance maps to the smallest key (back-to-front).
let key = 0xffffffffu - bitcast<u32>(d2);
keys_a[i] = key;
indices_a[i] = i;
}
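// Dispatch 1 of each pass: each of the num_blocks workgroups counts the
// digits of its up-to-WG keys into shared memory, then writes the 256 counts
// to its slice of block_offsets.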
@compute @workgroup_size(WG)
fn histogram_pass(
@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(local_invocation_index) lid: u32,
@builtin(workgroup_id) wid: vec3<u32>,
) {
atomicStore(&local_histogram[lid], 0u);
workgroupBarrier();
let i = gid.x;
if (i < params.num_splats) {
let key = keys_a[i];
atomicAdd(&local_histogram[digit_of(key, params.pass_index)], 1u);
}
workgroupBarrier();
block_offsets[wid.x * 256u + lid] = atomicLoad(&local_histogram[lid]);
}
// Single-thread exclusive prefix over the digit buckets and the per-block
// histograms: block_offsets[b * 256 + d] becomes the offset of block b within
// digit d's bucket, and digit_offsets[d] becomes that bucket's global base.
// Intentionally simple (O(256 * num_blocks) on one thread); performance is
// tracked separately.
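//
// Worked example (illustrative numbers): with num_blocks = 2 and only digits
// 0 and 1 occurring, input counts block0 = {d0: 3, d1: 1} and
// block1 = {d0: 2, d1: 4} are rewritten in place to block0 = {d0: 0, d1: 0}
// and block1 = {d0: 3, d1: 1}, and digit_offsets becomes {0, 5, 10, 10, ...}
// (10 = total element count).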
@compute @workgroup_size(1)
fn prefix_sum(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x != 0u) { return; }
var digit_base: u32 = 0u;
for (var d: u32 = 0u; d < 256u; d = d + 1u) {
digit_offsets[d] = digit_base;
var block_base: u32 = 0u;
for (var b: u32 = 0u; b < params.num_blocks; b = b + 1u) {
let idx = b * 256u + d;
let count = block_offsets[idx];
block_offsets[idx] = block_base;
block_base = block_base + count;
}
digit_base = digit_base + block_base;
}
}
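// Dispatch 3 of each pass: stable scatter. Each thread ranks itself among
// same-digit lanes with a lower index in its workgroup and writes to
// digit base + block offset + local rank. Because blocks are offset in input
// order and lanes rank in input order, equal keys keep their relative order.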
@compute @workgroup_size(WG)
fn scatter(
@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(local_invocation_index) lid: u32,
@builtin(workgroup_id) wid: vec3<u32>,
) {
let i = gid.x;
var key: u32 = 0u;
var idx: u32 = 0u;
var d: u32 = 0xffffffffu; // sentinel digit for inactive lanes; never matches
let is_active = i < params.num_splats;
if (is_active) {
key = keys_a[i];
idx = indices_a[i];
d = digit_of(key, params.pass_index);
}
local_digits[lid] = d;
workgroupBarrier();
if (!is_active) { return; }
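// O(WG) local rank: count earlier lanes in this workgroup with the same digit.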
var local_rank: u32 = 0u;
for (var j: u32 = 0u; j < lid; j = j + 1u) {
if (local_digits[j] == d) {
local_rank = local_rank + 1u;
}
}
let dst = digit_offsets[d] + block_offsets[wid.x * 256u + d] + local_rank;
keys_b[dst] = key;
indices_b[dst] = idx;
}
// "B" variants for ping-pong: histogram from B, scatter B → A.
@compute @workgroup_size(WG)
fn histogram_pass_b(
@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(local_invocation_index) lid: u32,
@builtin(workgroup_id) wid: vec3<u32>,
) {
atomicStore(&local_histogram[lid], 0u);
workgroupBarrier();
let i = gid.x;
if (i < params.num_splats) {
let key = keys_b[i];
atomicAdd(&local_histogram[digit_of(key, params.pass_index)], 1u);
}
workgroupBarrier();
block_offsets[wid.x * 256u + lid] = atomicLoad(&local_histogram[lid]);
}
@compute @workgroup_size(WG)
fn scatter_b(
@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(local_invocation_index) lid: u32,
@builtin(workgroup_id) wid: vec3<u32>,
) {
let i = gid.x;
var key: u32 = 0u;
var idx: u32 = 0u;
var d: u32 = 0xffffffffu;
let is_active = i < params.num_splats;
if (is_active) {
key = keys_b[i];
idx = indices_b[i];
d = digit_of(key, params.pass_index);
}
local_digits[lid] = d;
workgroupBarrier();
if (!is_active) { return; }
var local_rank: u32 = 0u;
for (var j: u32 = 0u; j < lid; j = j + 1u) {
if (local_digits[j] == d) {
local_rank = local_rank + 1u;
}
}
let dst = digit_offsets[d] + block_offsets[wid.x * 256u + d] + local_rank;
keys_a[dst] = key;
indices_a[dst] = idx;
}
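// After the final pass, indices_a holds splat indices ordered back-to-front
// (largest distance first), the order needed for alpha-blended rendering.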