rullama 0.2.0

Browser-resident Gemma 4 inference: pure Rust → WebAssembly + WebGPU. Loads Ollama's on-disk GGUF blobs and runs the forward pass on the local GPU via hand-written WGSL.
Documentation
// Attention backward — pass 1: produces `d_scores` (staged for pass 2) and `d_q`.
//
// One workgroup per query head `h`. Each WG, sharing one `d_probs`/`d_scores`
// staging array via workgroup memory:
//
//   1. d_probs[j]   = Σ_d  d_out[h, d] · v_hist[j, kv, d]
//   2. sum_pd       = Σ_j  probs[h, j] · d_probs[j]
//   3. d_scores[j]  = probs[h, j] · (d_probs[j] - sum_pd)
//   4. d_q[h, d]    = Σ_j  d_scores[j] · k_hist[j, kv, d]
//
// `d_scores` is also written to a global `[n_heads, history_len]` buffer
// so pass 2 (`attention_backward_dkv.wgsl`) can consume it without redoing
// the math. Forward attention's score is the raw dot product (Gemma 4
// absorbs the inverse-√d factor into the q RMSNorm), so the backward
// arithmetic is un-scaled as well.

struct Params {
    head_dim:     u32,
    n_heads:      u32,
    n_kv_heads:   u32,
    heads_per_kv: u32,
    history_len:  u32,
    _pad0:        u32,
    _pad1:        u32,
    _pad2:        u32,
}

// Pass 1 doesn't need `q` — d_q depends only on d_scores and k_hist, and
// d_scores depends only on probs, v_hist, and d_out. `q` lives on the
// pass-2 bind group.
@group(0) @binding(0) var<uniform>             params:   Params;
@group(0) @binding(1) var<storage, read>       k_hist:   array<f32>;
@group(0) @binding(2) var<storage, read>       v_hist:   array<f32>;
@group(0) @binding(3) var<storage, read>       probs:    array<f32>;
@group(0) @binding(4) var<storage, read>       d_out:    array<f32>;
@group(0) @binding(5) var<storage, read_write> d_scores: array<f32>;
@group(0) @binding(6) var<storage, read_write> d_q:      array<f32>;

const WG: u32 = 64u;
// Mirrors the forward attention kernel's MAX_HISTORY — keeps workgroup
// memory at 16 KB (4096 × 4 B) and accommodates the longest sequence
// rullama currently supports.
const MAX_HISTORY: u32 = 4096u;

var<workgroup> ds_buf: array<f32, MAX_HISTORY>;
var<workgroup> rbuf:   array<f32, WG>;

fn block_sum_reduce(tid: u32) -> f32 {
    var stride: u32 = WG / 2u;
    loop {
        if (stride == 0u) { break; }
        if (tid < stride) {
            rbuf[tid] = rbuf[tid] + rbuf[tid + stride];
        }
        workgroupBarrier();
        stride = stride / 2u;
    }
    return rbuf[0];
}

@compute @workgroup_size(64)
fn main(
    @builtin(workgroup_id) wid: vec3<u32>,
    @builtin(local_invocation_index) tid: u32,
) {
    let h = wid.x;
    if (h >= params.n_heads) { return; }
    let kv = h / params.heads_per_kv;
    let head_dim = params.head_dim;
    let history_len = params.history_len;
    let q_off = h * head_dim;
    let p_off = h * history_len;

    // ---- Phase A: d_probs[j] into ds_buf, accumulate partial sum_pd ----
    var local_sum_pd: f32 = 0.0;
    var j_a: u32 = tid;
    loop {
        if (j_a >= history_len) { break; }
        let v_off = (j_a * params.n_kv_heads + kv) * head_dim;
        var dp: f32 = 0.0;
        for (var d: u32 = 0u; d < head_dim; d = d + 1u) {
            dp = dp + d_out[q_off + d] * v_hist[v_off + d];
        }
        ds_buf[j_a] = dp;
        local_sum_pd = local_sum_pd + probs[p_off + j_a] * dp;
        j_a = j_a + WG;
    }
    rbuf[tid] = local_sum_pd;
    workgroupBarrier();
    let sum_pd = block_sum_reduce(tid);
    workgroupBarrier();

    // ---- Phase B: d_scores[j] = probs[h, j] · (d_probs[j] - sum_pd) ----
    var j_b: u32 = tid;
    loop {
        if (j_b >= history_len) { break; }
        let dp = ds_buf[j_b];
        let ds = probs[p_off + j_b] * (dp - sum_pd);
        ds_buf[j_b] = ds;
        d_scores[p_off + j_b] = ds;
        j_b = j_b + WG;
    }
    workgroupBarrier();

    // ---- Phase C: d_q[h, d] = Σ_j d_scores[j] · k_hist[j, kv, d] ----
    var d_c: u32 = tid;
    loop {
        if (d_c >= head_dim) { break; }
        var acc: f32 = 0.0;
        for (var j: u32 = 0u; j < history_len; j = j + 1u) {
            let k_off = (j * params.n_kv_heads + kv) * head_dim;
            acc = acc + ds_buf[j] * k_hist[k_off + d_c];
        }
        d_q[q_off + d_c] = acc;
        d_c = d_c + WG;
    }
}