rullama 0.5.0

Browser-resident Gemma 4 inference: pure Rust → WebAssembly + WebGPU. Loads Ollama's on-disk GGUF blobs and runs the forward pass on the local GPU via hand-written WGSL.
Documentation
// Batched Gemma 4 MoE router: one workgroup per token POSITION, each routing
// that position independently (vs the single-position moe_router.wgsl). For the
// DiffusionGemma 256-token canvas this is one dispatch instead of 256.
//
// Per position p: unweighted RMSNorm(x[p]) → ×1/√d → ×optional scale → expert
// scores (routerᵀ·x) → softmax over all experts → top-k → renorm. Writes
// out_ids[p*top_k + s] (u32) + out_w[p*top_k + s] (f32). Identical math to
// moe_router.wgsl (parity-tested there); this just adds the position index +
// per-position input/output offsets.

struct Params {
    n_pos:     u32,
    d_model:   u32,
    n_experts: u32,
    top_k:     u32,
    eps:       f32,
    has_scale: u32,
    _pad0:     u32,
    _pad1:     u32,
}

@group(0) @binding(0) var<uniform>             params:  Params;
@group(0) @binding(1) var<storage, read>       x:       array<f32>; // [n_pos, d_model]
@group(0) @binding(2) var<storage, read>       scale:   array<f32>; // [d_model] (shared)
@group(0) @binding(3) var<storage, read>       router:  array<f32>; // [n_experts, d_model]
@group(0) @binding(4) var<storage, read_write> out_ids: array<u32>; // [n_pos, top_k]
@group(0) @binding(5) var<storage, read_write> out_w:   array<f32>; // [n_pos, top_k]

const WG: u32 = 64u;
const MAX_EXPERTS: u32 = 256u;
const MAX_TOPK: u32 = 16u;

var<workgroup> tile:   array<f32, WG>;
var<workgroup> scores: array<f32, MAX_EXPERTS>;
var<workgroup> inv_rms_shared: f32;

@compute @workgroup_size(64)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_index) tid: u32) {
    let pos = wid.x;
    if (pos >= params.n_pos) { return; }
    let n = params.d_model;
    let n_exp = params.n_experts;
    let x_off = pos * n;

    // 1. sum-of-squares for the unweighted RMSNorm.
    var local_sumsq: f32 = 0.0;
    var i: u32 = tid;
    loop {
        if (i >= n) { break; }
        let val = x[x_off + i];
        local_sumsq = local_sumsq + val * val;
        i = i + WG;
    }
    tile[tid] = local_sumsq;
    workgroupBarrier();
    var stride: u32 = WG / 2u;
    loop {
        if (stride == 0u) { break; }
        if (tid < stride) { tile[tid] = tile[tid] + tile[tid + stride]; }
        workgroupBarrier();
        stride = stride / 2u;
    }
    if (tid == 0u) {
        inv_rms_shared = 1.0 / sqrt(tile[0] / f32(n) + params.eps);
    }
    workgroupBarrier();

    // 2+3. per-expert dot of the normed/scaled hidden state.
    let inv_rms = inv_rms_shared;
    let inv_sqrt_d = 1.0 / sqrt(f32(n));
    var e: u32 = tid;
    loop {
        if (e >= n_exp) { break; }
        var acc: f32 = 0.0;
        let row = e * n;
        for (var j: u32 = 0u; j < n; j = j + 1u) {
            var val = x[x_off + j] * inv_rms * inv_sqrt_d;
            if (params.has_scale != 0u) { val = val * scale[j]; }
            acc = acc + val * router[row + j];
        }
        scores[e] = acc;
        e = e + WG;
    }
    workgroupBarrier();

    // 4. softmax → top-k → renorm (thread 0; n_exp ≤ 256, k ≤ 16).
    if (tid == 0u) {
        var m: f32 = scores[0];
        for (var vv: u32 = 1u; vv < n_exp; vv = vv + 1u) { m = max(m, scores[vv]); }
        var z: f32 = 0.0;
        for (var vv: u32 = 0u; vv < n_exp; vv = vv + 1u) {
            let pr = exp(scores[vv] - m);
            scores[vv] = pr;
            z = z + pr;
        }
        for (var vv: u32 = 0u; vv < n_exp; vv = vv + 1u) { scores[vv] = scores[vv] / z; }
        let k = min(params.top_k, min(n_exp, MAX_TOPK));
        let o_off = pos * params.top_k;
        var sum_sel: f32 = 0.0;
        for (var s: u32 = 0u; s < k; s = s + 1u) {
            var best: u32 = 0u;
            var best_p: f32 = -1.0;
            for (var vv: u32 = 0u; vv < n_exp; vv = vv + 1u) {
                if (scores[vv] > best_p) { best_p = scores[vv]; best = vv; }
            }
            out_ids[o_off + s] = best;
            out_w[o_off + s] = best_p;
            sum_sel = sum_sel + best_p;
            scores[best] = -2.0;
        }
        for (var s: u32 = 0u; s < k; s = s + 1u) {
            out_w[o_off + s] = out_w[o_off + s] / sum_sel;
        }
    }
}