// rullama 0.5.0
//
// Browser-resident Gemma 4 inference: pure Rust → WebAssembly + WebGPU. Loads
// Ollama's on-disk GGUF blobs and runs the forward pass on the local GPU via
// hand-written WGSL.
// Backward of `y = matmul_q4_0(W, x)` with respect to x.
//
// Forward (q4_0_dequant_matmul.wgsl): y[j] = Σ_i x[i] * dequant(W)[j, i]
// Backward:                            dx[i] = Σ_j dy[j] * dequant(W)[j, i]
//
// W is row-major [n, k] with each row packed into k/32 Q4_0 blocks of 18 bytes
// (a 2-byte f16 scale followed by 16 bytes of packed 4-bit quants; 32 elements
// per block). One workgroup per 32-wide block of the k axis (i.e. 32 dx
// elements); each thread owns one i within that block, loops over
// j ∈ [j_start, j_end), reads its own nibble from Q4_0 block (j, block_row), and
// accumulates.
//
// W is frozen (LoRA convention) — there is no weight gradient. The j_start/j_end/
// accumulate params support the vocab-axis tiling used by the output-proj
// backward; non-tiled callers pass j_start=0, j_end=n, accumulate=0.

// Uniform parameters for one dispatch. The field order IS the uniform-buffer
// layout, so it must stay in sync with whatever the host side writes.
struct Params {
    // Inner dimension of W (columns) and the length of dx. The shader computes
    // k / 32 Q4_0 blocks per row, so k is assumed to be a multiple of 32.
    k: u32,
    // Number of rows of W (the forward's output dimension). Not read by this
    // shader; callers are expected to keep j_end <= n themselves.
    n: u32,
    // Half-open range [j_start, j_end) of rows of W / entries of dy summed over.
    j_start: u32,
    j_end: u32,
    // 0 → dx[i] is overwritten with this dispatch's sum;
    // nonzero → the sum is added onto the existing dx[i].
    accumulate: u32,
    // Padding that brings the struct to 32 bytes; presumably mirrors the
    // host-side struct layout — confirm against the Rust caller.
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}

// params: per-dispatch dimensions, row range and accumulate flag (see Params).
@group(0) @binding(0) var<uniform>             params: Params;
// weight: raw Q4_0 bytes of W viewed as u32 words. read_byte extracts bytes
// little-endian within each word (byte offset % 4 == 0 → bits 0..7).
@group(0) @binding(1) var<storage, read>       weight: array<u32>;
// dy: upstream gradient, indexed by the absolute row index j of W (not by the
// position within the [j_start, j_end) tile).
@group(0) @binding(2) var<storage, read>       dy:     array<f32>;
// dx: gradient w.r.t. x; element i (< k) is written by exactly one thread.
@group(0) @binding(3) var<storage, read_write> dx:     array<f32>;

// Q4_0 block geometry: 32 quantized elements stored in 18 bytes
// (2-byte f16 scale + 16 bytes holding two 4-bit quants each).
const BLOCK_ELEMS: u32 = 32u;
const BLOCK_BYTES: u32 = 18u;

// Return the byte at absolute offset `byte_off` of the u32-packed `weight`
// buffer, zero-extended to u32. Bytes are little-endian inside each word, so
// offset % 4 == 0 selects bits 0..7 and offset % 4 == 3 selects bits 24..31.
fn read_byte(byte_off: u32) -> u32 {
    let word: u32 = weight[byte_off / 4u];
    let bit_offset: u32 = (byte_off % 4u) * 8u;
    return extractBits(word, bit_offset, 8u);
}

// Decode the little-endian IEEE-754 half stored in the two bytes starting at
// `byte_off` and widen it to f32. Going through read_byte means the half may
// sit at any byte offset, including straddling a u32 boundary.
fn read_f16_as_f32(byte_off: u32) -> f32 {
    let half_bits: u32 = (read_byte(byte_off + 1u) << 8u) | read_byte(byte_off);
    // Upper 16 bits are zero, so the half lives in the .x lane.
    return unpack2x16float(half_bits).x;
}

// Entry point. Workgroup `blk` covers dx[blk * 32 .. blk * 32 + 32); lane `lane`
// computes dx[blk * 32 + lane] = Σ_{row ∈ [j_start, j_end)} dy[row] * W[row, i],
// dequantizing its single nibble of W on the fly.
@compute @workgroup_size(32)
fn main(
    @builtin(workgroup_id) wg: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>,
) {
    let blk: u32 = wg.x;
    let lane: u32 = lid.x;
    let out_idx: u32 = blk * BLOCK_ELEMS + lane;
    if (out_idx >= params.k) {
        return;
    }

    // Byte distance between consecutive rows of W (k / 32 blocks per row).
    let stride: u32 = (params.k / BLOCK_ELEMS) * BLOCK_BYTES;

    // Inside a Q4_0 block, bytes 0..2 are the f16 scale and bytes 2..18 are the
    // quants. Lanes 0..15 take the low nibble of quant byte `lane`; lanes
    // 16..31 take the high nibble of quant byte `lane - 16`. Both are fixed per
    // lane for the whole loop.
    let use_high: bool = lane >= 16u;
    let quant_byte: u32 = 2u + (lane & 15u);
    let nibble_shift: u32 = select(0u, 4u, use_high);

    // Byte offset of this workgroup's block within a row; the same for every row.
    let col_off: u32 = blk * BLOCK_BYTES;

    var sum: f32 = 0.0;
    var row: u32 = params.j_start;
    loop {
        if (row >= params.j_end) {
            break;
        }
        let base: u32 = row * stride + col_off;
        let scale: f32 = read_f16_as_f32(base);
        let q: u32 = (read_byte(base + quant_byte) >> nibble_shift) & 0xFu;
        // Q4_0 dequantization: value = (q - 8) * scale.
        let w: f32 = (f32(q) - 8.0) * scale;
        sum += w * dy[row];
        row += 1u;
    }

    // accumulate != 0 lets a caller sum partial results over several
    // [j_start, j_end) tiles; otherwise this dispatch overwrites dx.
    if (params.accumulate != 0u) {
        sum = dx[out_idx] + sum;
    }
    dx[out_idx] = sum;
}