// rlx-wgpu 0.2.6
//
// Cross-platform GPU backend for RLX via wgpu (Metal/Vulkan/DX12/WebGPU).
// Documentation
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
// RLX native Gaussian splat tile raster — matches rlx-metal/src/splat.msl.

// Numerator of the per-splat support scale: ray_splat_intersection_alpha
// divides this by the sigma radius at which a splat's opacity falls to the
// alpha cutoff.
const GAUSSIAN_SUPPORT_SIGMA_RADIUS: f32 = 3.0;
// Exponent applied per channel to the composited color before it is written
// to `dst` (pow(color, OUTPUT_GAMMA)).
const OUTPUT_GAMMA: f32 = 2.2;

// Uniform parameters for gaussian_splat_rasterize.
// Field order and types define the buffer layout the host must match; do not
// reorder.
struct SplatRasterParams {
    // Output image size in pixels; invocations outside it exit immediately.
    width: u32,
    height: u32,
    // Edge length of a tile in pixels (pixel coordinate / tile_size = tile coordinate).
    tile_size: u32,
    // Number of tiles per tile row; used to flatten (tile_x, tile_y) to an index.
    tile_width: u32,
    // Splats whose opacity or per-ray alpha is below this value are skipped.
    alpha_cutoff: f32,
    // Compositing stops once the remaining transmittance drops below this value.
    transmittance_threshold: f32,
    // Background color blended behind the splats.
    bg_r: f32,
    bg_g: f32,
    bg_b: f32,
    // Offset, in f32 elements, of the first output pixel inside `dst`.
    dst_base: u32,
}

// Output image: 4 f32 per pixel (r, g, b, a), row-major, starting at params.dst_base.
@group(0) @binding(0) var<storage, read_write> dst: array<f32>;
// Per splat: 4 f32 — r, g, b, opacity (opacity is clamped to [0, 1] on read).
@group(0) @binding(1) var<storage, read> color_alpha: array<f32>;
// Per splat: one u32 flag; 0 means the splat is skipped.
@group(0) @binding(2) var<storage, read> valid: array<u32>;
// Per splat: 3 f32 used as the ray origin in the splat's local frame.
@group(0) @binding(3) var<storage, read> pos_local: array<f32>;
// Per splat: 3 f32 multiplied component-wise into the rotated ray direction.
@group(0) @binding(4) var<storage, read> inv_scale: array<f32>;
// Per splat: 4 f32 quaternion stored scalar-first (w, x, y, z).
@group(0) @binding(5) var<storage, read> quat: array<f32>;
// Splat ids, indexed by the [start, end) ranges in tile_ranges.
@group(0) @binding(6) var<storage, read> sorted_values: array<u32>;
// Per tile: 2 u32 — start and end (exclusive) into sorted_values.
// start == 0xFFFFFFFF or end <= start is treated as an empty tile.
@group(0) @binding(7) var<storage, read> tile_ranges: array<u32>;
// Per pixel: 3 f32 ray direction, row-major.
@group(0) @binding(8) var<storage, read> rays: array<f32>;
// Rasterization parameters (see SplatRasterParams).
@group(0) @binding(9) var<uniform> params: SplatRasterParams;

// Rotates `v` by the quaternion `q_wxyz` (scalar-first: x = w, yzw = vector
// part) using the two-cross-product form
//     v + 2 * cross(cross(v, u) + w * v, u).
// With the operands in this order the result is, for a unit quaternion, the
// rotation by its conjugate (inverse rotation). The quaternion is not
// normalized here.
fn quat_rotate(v: vec3f, q_wxyz: vec4f) -> vec3f {
    let axis = q_wxyz.yzw;
    let scalar = q_wxyz.x;
    let inner = cross(v, axis) + scalar * v;
    return v + 2.0 * cross(inner, axis);
}

// Evaluates the alpha one splat contributes along a ray.
//
// The ray is expressed in the splat's local frame (origin from `pos_local`,
// direction rotated by `quat` and multiplied by `inv_scale`), the point of the
// ray closest to the local origin is found, and the splat's opacity is
// attenuated by a Gaussian of that point's squared distance.
//
// Returns 0.0 when the splat's opacity is below `alpha_cutoff`, when the
// support radius or the local direction is degenerate, or when the closest
// point lies at or behind the ray origin.
//
// NOTE(review): with alpha_cutoff <= 0 the log argument below is 0 — confirm
// callers always pass a positive cutoff.
fn ray_splat_intersection_alpha(
    splat_id: u32,
    ray_direction: vec3f,
    alpha_cutoff: f32,
) -> f32 {
    // Element offsets for the 4-wide (color_alpha, quat) and 3-wide
    // (pos_local, inv_scale) per-splat arrays.
    let base4 = splat_id * 4u;
    let base3 = splat_id * 3u;

    let opacity = clamp(color_alpha[base4 + 3u], 0.0, 1.0);
    if (opacity < alpha_cutoff) {
        return 0.0;
    }

    // Sigma distance at which opacity * exp(-0.5 * sigma^2) equals alpha_cutoff.
    let cutoff_ratio = alpha_cutoff / max(opacity, alpha_cutoff);
    let sigma_radius = sqrt(max(0.0, -2.0 * log(cutoff_ratio)));
    if (sigma_radius <= 1e-10) {
        return 0.0;
    }
    let support_scale = GAUSSIAN_SUPPORT_SIGMA_RADIUS / sigma_radius;

    let origin = vec3f(pos_local[base3], pos_local[base3 + 1u], pos_local[base3 + 2u]) * support_scale;
    let rotation = vec4f(quat[base4], quat[base4 + 1u], quat[base4 + 2u], quat[base4 + 3u]);
    let inv = vec3f(inv_scale[base3], inv_scale[base3 + 1u], inv_scale[base3 + 2u]);
    let direction = (quat_rotate(ray_direction, rotation) * inv) * support_scale;

    let direction_len2 = dot(direction, direction);
    if (direction_len2 <= 1e-10) {
        return 0.0;
    }

    // Ray parameter of the point nearest the local origin; non-positive means
    // that point is not in front of the ray origin.
    let t_nearest = -dot(direction, origin) / direction_len2;
    if (t_nearest <= 0.0) {
        return 0.0;
    }

    let nearest = origin + direction * t_nearest;
    let rho2 = max(0.0, dot(nearest, nearest));
    return opacity * exp(-0.5 * sigma_radius * sigma_radius * rho2);
}

// Tile-based Gaussian splat rasterizer; each invocation shades one pixel.
//
// For pixel (gid.x, gid.y) the kernel looks up the [start, end) range of its
// tile in `tile_ranges`, walks the splat ids in `sorted_values` in that order,
// and alpha-composites each valid splat whose per-ray alpha reaches
// params.alpha_cutoff. The walk stops early once the remaining transmittance
// falls below params.transmittance_threshold. The background is blended in
// with the remaining transmittance, the color is clamped to >= 0 and raised to
// OUTPUT_GAMMA, and the pixel is written to `dst` as (r, g, b, 1 - transmittance).
//
// Empty tiles (start == 0xFFFFFFFF or end <= start) go through the same
// composite/write path as covered tiles with zero contributing splats, so an
// uncovered pixel gets the same color and alpha (0) on both sides of a tile
// boundary.
// NOTE(review): if rlx-metal/src/splat.msl writes the raw background with
// alpha 1.0 for empty tiles, apply the same change there to keep the two
// backends in agreement.
@compute @workgroup_size(8, 8, 1)
fn gaussian_splat_rasterize(@builtin(global_invocation_id) gid: vec3<u32>) {
    let px = gid.x;
    let py = gid.y;
    // Invocations that fall outside the image have nothing to write.
    if (px >= params.width || py >= params.height) {
        return;
    }

    // Row-major pixel index shared by the output and ray lookups.
    let pixel = py * params.width + px;
    let out_base = params.dst_base + pixel * 4u;

    let tile_y = py / params.tile_size;
    let tile = tile_y * params.tile_width + px / params.tile_size;
    let range_base = tile * 2u;
    let start = tile_ranges[range_base];
    let end = tile_ranges[range_base + 1u];

    var accum = vec3f(0.0);
    var trans = 1.0;

    // Only walk the splat list when the tile actually has one.
    if (start != 0xFFFFFFFFu && end > start) {
        let ray_base = pixel * 3u;
        let ray = vec3f(rays[ray_base], rays[ray_base + 1u], rays[ray_base + 2u]);
        for (var i = start; i < end; i++) {
            let splat_id = sorted_values[i];
            if (valid[splat_id] == 0u) {
                continue;
            }
            let alpha = ray_splat_intersection_alpha(splat_id, ray, params.alpha_cutoff);
            if (alpha < params.alpha_cutoff) {
                continue;
            }
            let rgb_base = splat_id * 4u;
            accum += trans * alpha * vec3f(
                color_alpha[rgb_base],
                color_alpha[rgb_base + 1u],
                color_alpha[rgb_base + 2u],
            );
            trans *= 1.0 - alpha;
            // Remaining splats can no longer contribute visibly.
            if (trans < params.transmittance_threshold) {
                break;
            }
        }
    }

    let bg = vec3f(params.bg_r, params.bg_g, params.bg_b);
    // Clamp before pow: a negative base has no defined result.
    let composed = max(vec3f(0.0), accum + trans * bg);
    dst[out_base] = pow(composed.x, OUTPUT_GAMMA);
    dst[out_base + 1u] = pow(composed.y, OUTPUT_GAMMA);
    dst[out_base + 2u] = pow(composed.z, OUTPUT_GAMMA);
    dst[out_base + 3u] = 1.0 - trans;
}