viewport-lib 0.14.0

3D viewport rendering library
Documentation
// gaussian_splat_sort.wgsl
//
// Two shader modules in one file:
//   - Depth compute: compute view-space depth keys for each splat.
//   - Radix sort: 4-pass 8-bit GPU radix sort (back-to-front order).

// -----------------------------------------------------------------------
// Depth compute
// -----------------------------------------------------------------------

struct DepthUniform {
    model: mat4x4<f32>,
    eye: vec3<f32>,
    count: u32,
}

@group(0) @binding(0) var<uniform> depth_u: DepthUniform;
@group(0) @binding(1) var<storage, read> positions: array<vec4<f32>>;
@group(0) @binding(2) var<storage, read_write> keys_ping_out: array<u32>;

// Write flipped depth bits so ascending sort = back-to-front order.
@compute @workgroup_size(256)
fn compute_depths(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i >= depth_u.count { return; }
    let world_pos = (depth_u.model * positions[i]).xyz;
    let dist = length(world_pos - depth_u.eye);
    // Flip all bits: large distance -> small key -> sorted first (back-to-front).
    keys_ping_out[i] = ~bitcast<u32>(dist);
}

// -----------------------------------------------------------------------
// Radix sort
// -----------------------------------------------------------------------

struct SortUniform {
    shift: u32,    // bit shift for current pass: 0, 8, 16, 24
    count: u32,
    pass_num: u32, // 0-3, selects ping/pong
    _pad: u32,
}

@group(0) @binding(0) var<uniform> sort_u: SortUniform;
@group(0) @binding(1) var<storage, read_write> s_keys_ping: array<u32>;
@group(0) @binding(2) var<storage, read_write> s_keys_pong: array<u32>;
@group(0) @binding(3) var<storage, read_write> s_vals_ping: array<u32>;
@group(0) @binding(4) var<storage, read_write> s_vals_pong: array<u32>;
@group(0) @binding(5) var<storage, read_write> histogram: array<atomic<u32>>;

// Initialize vals_ping[i] = i (identity permutation before first sort pass).
@compute @workgroup_size(256)
fn init_indices(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i >= sort_u.count { return; }
    s_vals_ping[i] = i;
}

// Clear the 256-entry global histogram to zero.
@compute @workgroup_size(256)
fn clear_histogram(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i < 256u {
        atomicStore(&histogram[i], 0u);
    }
}

var<workgroup> local_hist: array<atomic<u32>, 256>;

// Count occurrences of each 8-bit bucket in the current digit position.
@compute @workgroup_size(256)
fn histogram_pass(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>,
) {
    // Zero local histogram.
    atomicStore(&local_hist[lid.x], 0u);
    workgroupBarrier();

    let i = gid.x;
    if i < sort_u.count {
        var key: u32;
        if sort_u.pass_num % 2u == 0u {
            key = s_keys_ping[i];
        } else {
            key = s_keys_pong[i];
        }
        let bucket = (key >> sort_u.shift) & 0xffu;
        atomicAdd(&local_hist[bucket], 1u);
    }

    workgroupBarrier();

    // Merge local histogram into global (one thread per bucket).
    if lid.x < 256u {
        let local_count = atomicLoad(&local_hist[lid.x]);
        if local_count > 0u {
            atomicAdd(&histogram[lid.x], local_count);
        }
    }
}

// Exclusive prefix sum over 256 entries (single workgroup, sequential).
@compute @workgroup_size(1)
fn prefix_sum_pass(@builtin(global_invocation_id) gid: vec3<u32>) {
    var sum: u32 = 0u;
    for (var i = 0u; i < 256u; i++) {
        let cnt = atomicLoad(&histogram[i]);
        atomicStore(&histogram[i], sum);
        sum += cnt;
    }
}

// Scatter elements into sorted order using the prefix-summed histogram.
@compute @workgroup_size(256)
fn scatter_pass(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i >= sort_u.count { return; }

    var key: u32;
    var val: u32;
    if sort_u.pass_num % 2u == 0u {
        key = s_keys_ping[i];
        val = s_vals_ping[i];
    } else {
        key = s_keys_pong[i];
        val = s_vals_pong[i];
    }

    let bucket = (key >> sort_u.shift) & 0xffu;
    let pos = atomicAdd(&histogram[bucket], 1u);

    if sort_u.pass_num % 2u == 0u {
        s_keys_pong[pos] = key;
        s_vals_pong[pos] = val;
    } else {
        s_keys_ping[pos] = key;
        s_vals_ping[pos] = val;
    }
}