bevy_spark 0.2.0

Gaussian splat renderer for Bevy with SPZ support
Documentation
// GPU radix sort for splat indices.
//
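// 4 passes × 3 dispatches (block histogram, prefix sum, stable scatter), each
// pass sorting on one 8-bit digit of the 32-bit distance key. The passes
// ping-pong between the two buffer pairs: pass 0 reads keys_a/indices_a and
// writes keys_b/indices_b (histogram_pass + scatter); the next pass runs the
// "_b" entry points in the opposite direction, and so on, so after 4 passes
// the sorted result lands back in keys_a/indices_a.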

// Per-dispatch sort parameters (bound as the uniform at group 0, binding 0).
struct Params {
    // Number of key/index entries to sort; every key/index access in the
    // entry points is guarded by `index < num_splats`.
    num_splats: u32,
    // Selects which 8-bit digit of the key the current pass works on
    // (bits [8 * pass_index, 8 * pass_index + 7], see digit_of).
    pass_index: u32,
    // Number of 256-entry blocks (workgroups); loop bound in prefix_sum.
    num_blocks: u32,
    // Explicit padding: vec4<f32> is 16-byte aligned in WGSL, so view_pos sits
    // at byte offset 16 either way.
    // NOTE(review): presumably mirrors the host-side struct layout — confirm.
    _pad: u32,
    view_pos: vec4<f32>, // .xyz = splat-local view position
};

// Mirrors the 32-byte layout in splats.rs::GpuSplat. Only `.center` is used.
// Under WGSL layout rules vec3<f32> has size 12 and alignment 16, so
// `color_alpha` packs into the 4 bytes right after `center` and the struct
// comes out at exactly 32 bytes.
struct Splat {
    // Byte offset 0. Splat position; compute_keys measures its squared
    // distance to params.view_pos.xyz.
    center: vec3<f32>,
    // Byte offsets 12, 16, 20, 24.
    // NOTE(review): names suggest packed colour/scale/rotation data; their
    // encoding is not visible here — confirm against splats.rs.
    color_alpha: u32,
    scales01: u32,
    scales23: u32,
    rotation: u32,
    // Byte offset 28. Trailing padding up to 32 bytes.
    _pad: u32,
};

// Bind group 0: sort parameters, splat data, the two key/index buffer pairs
// that the passes ping-pong between, and the histogram/offset scratch buffers.
@group(0) @binding(0) var<uniform> params: Params;
// Splat records; compute_keys reads `.center` from here.
@group(0) @binding(1) var<storage, read> splats: array<Splat>;
// "A" pair: filled by compute_keys, read by histogram_pass/scatter, and
// written by scatter_b.
@group(0) @binding(2) var<storage, read_write> keys_a: array<u32>;
@group(0) @binding(3) var<storage, read_write> indices_a: array<u32>;
// "B" pair: written by scatter, read by histogram_pass_b/scatter_b.
@group(0) @binding(4) var<storage, read_write> keys_b: array<u32>;
@group(0) @binding(5) var<storage, read_write> indices_b: array<u32>;
// Indexed as [workgroup * 256 + digit]. The histogram passes store per-block
// digit counts here; prefix_sum then replaces each count in place with the
// number of same-digit keys in earlier blocks.
@group(0) @binding(6) var<storage, read_write> block_offsets: array<u32>;
// Written by prefix_sum: for each digit, the total number of keys with a
// smaller digit, i.e. the first output slot for that digit.
@group(0) @binding(7) var<storage, read_write> digit_offsets: array<u32, 256>;

// Invocations per workgroup for the key, histogram and scatter entry points.
// The histogram passes clear and publish exactly one of the 256 digit bins per
// invocation (indexed by local_invocation_index), so this must stay equal to
// the number of digit values (256).
const WG: u32 = 256u;

// Per-workgroup digit counters used by histogram_pass / histogram_pass_b.
var<workgroup> local_histogram: array<atomic<u32>, 256>;
// Digit held by each invocation of the workgroup, used by scatter / scatter_b
// to rank an element among earlier invocations with the same digit.
var<workgroup> local_digits: array<u32, 256>;

// Returns the 8-bit digit of `key` selected by `pass_idx`: the key is shifted
// right by 8 * pass_idx bits and the low byte is kept.
fn digit_of(key: u32, pass_idx: u32) -> u32 {
    let shift_bits = 8u * pass_idx;
    let shifted = key >> shift_bits;
    return shifted & 255u;
}

// Seeds the sort: one invocation per splat writes its distance key to keys_a
// and its own index to indices_a. Invocations past num_splats do nothing.
@compute @workgroup_size(WG)
fn compute_keys(@builtin(global_invocation_id) gid: vec3<u32>) {
    let splat_id = gid.x;
    if (splat_id < params.num_splats) {
        // Squared distance from the view position to the splat centre.
        let to_splat = splats[splat_id].center - params.view_pos.xyz;
        let dist_sq = dot(to_splat, to_splat);
        // The bit pattern of a non-negative float orders like the float itself;
        // complementing it (same as 0xffffffff - bits) makes the largest
        // distance the smallest key, giving back-to-front order.
        keys_a[splat_id] = ~bitcast<u32>(dist_sq);
        indices_a[splat_id] = splat_id;
    }
}

// Block histogram over keys_a: counts, per workgroup, how many of its keys
// carry each value of the current pass's digit, then stores the 256 counts at
// block_offsets[workgroup * 256 + digit].
@compute @workgroup_size(WG)
fn histogram_pass(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(local_invocation_index) lid: u32,
    @builtin(workgroup_id) wid: vec3<u32>,
) {
    // Each invocation owns one bin: clear it before anyone starts counting.
    atomicStore(&local_histogram[lid], 0u);
    workgroupBarrier();

    // Invocations past the end of the key array contribute nothing.
    let key_index = gid.x;
    if (key_index < params.num_splats) {
        let bin = digit_of(keys_a[key_index], params.pass_index);
        atomicAdd(&local_histogram[bin], 1u);
    }
    workgroupBarrier();

    // All counts are in; publish the bin this invocation owns.
    let bin_count = atomicLoad(&local_histogram[lid]);
    let out_slot = wid.x * 256u + lid;
    block_offsets[out_slot] = bin_count;
}

// Serial exclusive scan, run by invocation 0 only. For every digit it records
// the number of keys with a smaller digit in digit_offsets, and rewrites each
// block's count in block_offsets into the number of same-digit keys held by
// earlier blocks. Kept deliberately simple; performance is tracked separately.
@compute @workgroup_size(1)
fn prefix_sum(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (gid.x == 0u) {
        var keys_before_digit = 0u;
        for (var digit = 0u; digit < 256u; digit += 1u) {
            digit_offsets[digit] = keys_before_digit;

            // Walk the blocks in order, swapping each count for the running
            // total of the blocks that came before it.
            var keys_in_digit = 0u;
            for (var blk = 0u; blk < params.num_blocks; blk += 1u) {
                let slot = blk * 256u + digit;
                let block_count = block_offsets[slot];
                block_offsets[slot] = keys_in_digit;
                keys_in_digit += block_count;
            }

            keys_before_digit += keys_in_digit;
        }
    }
}

// Stable scatter A -> B: moves each key/index pair from keys_a/indices_a to
// its sorted position for the current digit in keys_b/indices_b. The position
// is digit start + same-digit keys in earlier blocks + same-digit keys held by
// earlier invocations of this workgroup, which preserves input order among
// equal digits.
@compute @workgroup_size(WG)
fn scatter(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(local_invocation_index) lid: u32,
    @builtin(workgroup_id) wid: vec3<u32>,
) {
    let src = gid.x;
    let in_range = src < params.num_splats;

    // Out-of-range invocations keep a sentinel digit that no real 8-bit digit
    // can equal, so they never count towards anyone's rank.
    var src_key = 0u;
    var src_index = 0u;
    var digit = 0xffffffffu;
    if (in_range) {
        src_key = keys_a[src];
        src_index = indices_a[src];
        digit = digit_of(src_key, params.pass_index);
    }

    // Every invocation must reach the barrier, in range or not.
    local_digits[lid] = digit;
    workgroupBarrier();

    if (in_range) {
        // Rank among earlier invocations of this workgroup with the same digit.
        var rank = 0u;
        for (var lane = 0u; lane < lid; lane += 1u) {
            rank += select(0u, 1u, local_digits[lane] == digit);
        }

        let dst = digit_offsets[digit] + block_offsets[wid.x * 256u + digit] + rank;
        keys_b[dst] = src_key;
        indices_b[dst] = src_index;
    }
}

// "B" variants for ping-pong: histogram from B, scatter B → A.
// Block histogram over keys_b: counts, per workgroup, how many of its keys
// carry each value of the current pass's digit, then stores the 256 counts at
// block_offsets[workgroup * 256 + digit].
@compute @workgroup_size(WG)
fn histogram_pass_b(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(local_invocation_index) lid: u32,
    @builtin(workgroup_id) wid: vec3<u32>,
) {
    // Each invocation owns one bin: clear it before anyone starts counting.
    atomicStore(&local_histogram[lid], 0u);
    workgroupBarrier();

    // Invocations past the end of the key array contribute nothing.
    let key_index = gid.x;
    if (key_index < params.num_splats) {
        let bin = digit_of(keys_b[key_index], params.pass_index);
        atomicAdd(&local_histogram[bin], 1u);
    }
    workgroupBarrier();

    // All counts are in; publish the bin this invocation owns.
    let bin_count = atomicLoad(&local_histogram[lid]);
    let out_slot = wid.x * 256u + lid;
    block_offsets[out_slot] = bin_count;
}

// Stable scatter B -> A: moves each key/index pair from keys_b/indices_b to
// its sorted position for the current digit in keys_a/indices_a. The position
// is digit start + same-digit keys in earlier blocks + same-digit keys held by
// earlier invocations of this workgroup, which preserves input order among
// equal digits.
@compute @workgroup_size(WG)
fn scatter_b(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(local_invocation_index) lid: u32,
    @builtin(workgroup_id) wid: vec3<u32>,
) {
    let src = gid.x;
    let in_range = src < params.num_splats;

    // Out-of-range invocations keep a sentinel digit that no real 8-bit digit
    // can equal, so they never count towards anyone's rank.
    var src_key = 0u;
    var src_index = 0u;
    var digit = 0xffffffffu;
    if (in_range) {
        src_key = keys_b[src];
        src_index = indices_b[src];
        digit = digit_of(src_key, params.pass_index);
    }

    // Every invocation must reach the barrier, in range or not.
    local_digits[lid] = digit;
    workgroupBarrier();

    if (in_range) {
        // Rank among earlier invocations of this workgroup with the same digit.
        var rank = 0u;
        for (var lane = 0u; lane < lid; lane += 1u) {
            rank += select(0u, 1u, local_digits[lane] == digit);
        }

        let dst = digit_offsets[digit] + block_offsets[wid.x * 256u + digit] + rank;
        keys_a[dst] = src_key;
        indices_a[dst] = src_index;
    }
}