rullama 0.2.0

Browser-resident Gemma 4 inference: pure Rust → WebAssembly + WebGPU. Loads Ollama's on-disk GGUF blobs and runs the forward pass on the local GPU via hand-written WGSL.
Documentation
// 2D NeoX RoPE for the Gemma 4 vision tower.
//
// Per Ollama (model_vision.go lines 191–210), each head's first half of head_dim
// is rotated by NeoX RoPE using the patch's X position; the second half is
// rotated by NeoX RoPE using the patch's Y position. Single dispatch handles
// both halves and all heads in parallel.
//
// Layout: x is [head_dim, n_heads, n_patches] flat, fastest-varying axis = head_dim.
// We expect head_dim % 2 == 0 (true for Gemma 4 vision: head_dim = 768/12 = 64).
//
// The per-head element at (h, p, d) is at index `(p * n_heads + h) * head_dim + d`.
//
// rope_theta = 100.0 (NOT the text model's 10000).

struct Params {
    head_dim: u32,
    n_heads:  u32,
    n_patches: u32,
    base:     f32,
}

@group(0) @binding(0) var<uniform>             params:  Params;
@group(0) @binding(1) var<storage, read_write> x:       array<f32>;
@group(0) @binding(2) var<storage, read>       pos_x:   array<u32>;
@group(0) @binding(3) var<storage, read>       pos_y:   array<u32>;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    // One thread per (patch, head, half-pair). The 2D RoPE rotates pairs (i, i+halfDim/2)
    // within each half of head_dim. NeoX rotation: pair (low, high) where low = first
    // half of the chosen half-axis and high = second half. So in a 64-d head split into
    // two 32-d halves, the X half is dims [0..32) and rotates pairs (0,16),(1,17),...,(15,31);
    // the Y half is dims [32..64) and rotates pairs (32,48),(33,49),...,(47,63).
    let half_dim: u32 = params.head_dim / 2u;          // 32 for ViT
    let quarter:  u32 = half_dim / 2u;                 // 16 for ViT (each NeoX block-half)
    let total: u32 = params.n_patches * params.n_heads * half_dim;
    let i: u32 = gid.x;
    if (i >= total) { return; }

    // i decomposes as: (patch * n_heads + head) * half_dim + slot
    let slot: u32 = i % half_dim;
    let ph: u32 = i / half_dim;
    let head: u32 = ph % params.n_heads;
    let pidx: u32 = ph / params.n_heads;

    let row_base: u32 = (pidx * params.n_heads + head) * params.head_dim;

    // Are we in the X half (first 32 dims) or the Y half (last 32 dims)?
    // We dispatch one thread per slot in 0..half_dim, but we need to do two rotations
    // (one per half) per dispatch position. Easier: each thread handles both halves
    // for its slot.
    if (slot >= quarter) { return; }

    let pos_xv: f32 = f32(pos_x[pidx]);
    let pos_yv: f32 = f32(pos_y[pidx]);

    // theta_i = base ^ (-2i / half_dim), i in 0..quarter
    let exponent: f32 = -2.0 * f32(slot) / f32(half_dim);
    let theta: f32 = pow(params.base, exponent);
    let cos_x: f32 = cos(pos_xv * theta);
    let sin_x: f32 = sin(pos_xv * theta);
    let cos_y: f32 = cos(pos_yv * theta);
    let sin_y: f32 = sin(pos_yv * theta);

    // X half: dims [0..half_dim). NeoX pairs (slot, slot + quarter).
    let lo_x: u32 = row_base + slot;
    let hi_x: u32 = row_base + slot + quarter;
    let a_x: f32 = x[lo_x];
    let b_x: f32 = x[hi_x];
    x[lo_x] = a_x * cos_x - b_x * sin_x;
    x[hi_x] = a_x * sin_x + b_x * cos_x;

    // Y half: dims [half_dim..head_dim). NeoX pairs (half_dim + slot, half_dim + slot + quarter).
    let lo_y: u32 = row_base + half_dim + slot;
    let hi_y: u32 = row_base + half_dim + slot + quarter;
    let a_y: f32 = x[lo_y];
    let b_y: f32 = x[hi_y];
    x[lo_y] = a_y * cos_y - b_y * sin_y;
    x[hi_y] = a_y * sin_y + b_y * cos_y;
}