oxillama-gpu 0.1.3

Optional wgpu GPU compute backend for OxiLLaMa
Documentation
// attention_fused_f32.wgsl — fused QK + online softmax + AV in a single dispatch.
//
// Implements Flash-Attention-style online softmax: one workgroup per Q row,
// iterating over KV tiles with numerically stable running max (m) and
// normalisation (l).
//
// Parameters:
//   seq_len_q  — number of Q rows
//   seq_len_kv — number of K/V rows
//   head_dim   — dimension of each Q/K/V vector (≤ 64)
//   scale      — pre-multiplied scale factor (typically 1/sqrt(head_dim))
//   causal     — 0: full attention, 1: causal mask (only attend to j <= q_row)
//
// Workgroup size: 64 threads (one per head_dim element: thread i owns
// dimension i of the Q row, of every K/V row and of the output row).
// Each workgroup handles exactly one Q row.
//
// Shared memory usage:
//   With TILE_K = 64 and head_dim ≤ 64 the tiles would be too large:
//     K_tile[64 × 64]:  4096 × 4 B = 16 KiB — exactly at the 16 KiB limit.
//     V_tile[64 × 64]:  another 16 KiB — exceeds the 16 KiB baseline.
//
// To stay close to the 16 KiB WebGPU baseline, we use TILE_K = 32, head_dim ≤ 64:
//   K_tile[32 × 64]  = 2048 × 4 = 8 KiB
//   V_tile[32 × 64]  = 2048 × 4 = 8 KiB
//   dot_scratch[64]  = 64 × 4   = 256 B (dedicated QK reduction scratchpad)
//   Total ≈ 16.25 KiB — slightly above the 16 KiB baseline, but within
//   desktop Metal/Vulkan limits (32 KiB).
//
// For head_dim > 64, callers must split into multiple passes (not handled here).

// Parameter block read by the kernel through the `params` uniform binding
// (@group(0) @binding(4)). It consists of five 4-byte scalars in declaration
// order; the buffer supplied by the host has to follow the same field order.
struct AttentionParams {
    seq_len_q:  u32,  // number of Q rows; workgroups with index >= this exit immediately
    seq_len_kv: u32,  // number of K/V rows, iterated in tiles of TILE_K rows
    head_dim:   u32,  // elements per Q/K/V/Out row (the kernel is sized for ≤ 64)
    scale:      f32,  // multiplied into every QK dot product (typically 1/sqrt(head_dim))
    causal:     u32,  // 0 = no mask, non-zero = causal (KV rows with index > Q row are skipped)
}

// Resource bindings (all in bind group 0).
// Q, K, V and Out are flat f32 arrays addressed as `row * head_dim + dim`.
// The kernel applies no batch or head offset of its own.
// NOTE(review): presumably the caller binds one head's data per dispatch — confirm in host code.
@group(0) @binding(0) var<storage, read>       Q:      array<f32>;  // [seq_q  × head_dim]
@group(0) @binding(1) var<storage, read>       K:      array<f32>;  // [seq_kv × head_dim]
@group(0) @binding(2) var<storage, read>       V:      array<f32>;  // [seq_kv × head_dim]
@group(0) @binding(3) var<storage, read_write> Out:    array<f32>;  // [seq_q  × head_dim]
@group(0) @binding(4) var<uniform>             params: AttentionParams;

// TILE_K = 32 KV rows per tile iteration.
const TILE_K: u32 = 32u;

// Shared tiles — sized for TILE_K=32, head_dim≤64.
// Rows are packed with a stride of params.head_dim (not 64): the kernel
// addresses element (row, dim) as `row * head_dim + dim`, so the 2048-element
// capacity is only sufficient while head_dim ≤ 64.
var<workgroup> K_tile:       array<f32, 2048>;  // TILE_K × 64 = 8 KiB
var<workgroup> V_tile:       array<f32, 2048>;  // TILE_K × 64 = 8 KiB
// One slot per invocation of the 64-wide workgroup; slot 0 holds the reduced
// QK dot product once the tree reduction has finished.
var<workgroup> dot_scratch:  array<f32, 64>;    // dot-product reduction scratchpad (256 B)
// Total shared memory: 8 KiB + 8 KiB + 256 B = ~16.25 KiB.
// Note: wgpu on Metal allows up to 32 KiB; the 16 KiB WebGPU baseline is for browsers.
// On desktop targets this is within the real limit.
// NOTE(review): 16.25 KiB is above the 16 KiB baseline, so the device presumably
// has to be requested with a raised workgroup-storage limit — confirm in host code.

// Fused attention kernel: one workgroup per Q row, 64 threads per workgroup.
//
// Thread `tid` owns dimension `tid` of the Q row, of each K/V row and of the
// output row. Only head_dim ≤ 64 is handled correctly: for larger values the
// tile indices run past the 2048-element K_tile/V_tile arrays and dimensions
// ≥ 64 are never accumulated or written.
//
// For every visible KV row the workgroup reduces the 64 partial products
// q[tid] * k[tid] through dot_scratch, multiplies the sum by params.scale and
// folds the score into the running softmax state (m, l) and the
// un-normalised output accumulator o. The value written to Out is o / l
// (0 when no KV row was visible).
//
// Every branch and loop that encloses a workgroupBarrier() depends only on
// params, wg_id and loop counters, so all 64 invocations reach the same
// barriers.
@compute @workgroup_size(64)
fn main(
    @builtin(workgroup_id)        wg_id:    vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
) {
    let q_row = wg_id.x;          // which Q row this workgroup handles
    let tid   = local_id.x;       // thread index within workgroup (0..63)
    let hd    = params.head_dim;

    // Workgroups dispatched past the last Q row have nothing to do.
    if q_row >= params.seq_len_q { return; }

    // Number of KV rows this Q row may attend to. Under the causal mask only
    // rows j <= q_row are visible; bounding the loops here means masked rows
    // (and whole masked tiles) are never loaded or reduced. The rows that are
    // processed, and their order, are unchanged.
    let kv_limit = select(params.seq_len_kv, min(params.seq_len_kv, q_row + 1u), params.causal != 0u);

    // True for threads that map to a real head dimension.
    let owns_dim = tid < hd;

    // Each thread keeps its own element Q[q_row, tid] for the whole kernel.
    var q_val: f32 = 0.0;
    if owns_dim {
        q_val = Q[q_row * hd + tid];
    }

    // Online softmax state.
    var m: f32 = -1e38;  // running max
    var l: f32 = 0.0;    // running normalisation (sum of exp)

    // Un-normalised output accumulator for dimension `tid`.
    var o: f32 = 0.0;

    let n_kv_tiles = (kv_limit + TILE_K - 1u) / TILE_K;

    for (var t: u32 = 0u; t < n_kv_tiles; t++) {
        let kv_start = t * TILE_K;
        // kv_start < kv_limit for every t < n_kv_tiles, so this cannot underflow.
        let rows_in_tile = min(TILE_K, kv_limit - kv_start);

        // ── Cooperative load of K_tile and V_tile ─────────────────────────
        // The tile is a contiguous slice of K/V starting at row kv_start, so
        // tile element i maps to global element kv_start * hd + i. Only the
        // rows that will actually be read below are loaded.
        let tile_base = kv_start * hd;
        for (var i: u32 = tid; i < rows_in_tile * hd; i += 64u) {
            K_tile[i] = K[tile_base + i];
            V_tile[i] = V[tile_base + i];
        }
        workgroupBarrier();

        // ── Process each visible KV row in this tile ──────────────────────
        for (var kv_local: u32 = 0u; kv_local < rows_in_tile; kv_local++) {
            // Each thread contributes one product of dot(Q[q_row], K[row]);
            // threads beyond head_dim contribute 0.
            var partial: f32 = 0.0;
            if owns_dim {
                partial = q_val * K_tile[kv_local * hd + tid];
            }

            // Dedicated scratchpad keeps K_tile intact for later rows.
            dot_scratch[tid] = partial;
            workgroupBarrier();

            // Tree reduction of 64 elements into dot_scratch[0]
            // (relies on the workgroup size being 64, a power of two).
            for (var stride: u32 = 32u; stride > 0u; stride = stride / 2u) {
                if tid < stride {
                    dot_scratch[tid] = dot_scratch[tid] + dot_scratch[tid + stride];
                }
                workgroupBarrier();
            }

            let score = dot_scratch[0] * params.scale;

            // All invocations must have read dot_scratch[0] before any of
            // them starts the next row and overwrites the scratchpad;
            // without this barrier invocation 0 could clobber the score
            // that slower invocations have not read yet.
            workgroupBarrier();

            // Online softmax update (Flash-Attention style):
            //   m_new = max(m, score)
            //   P     = exp(score - m_new)
            //   l_new = exp(m - m_new) * l + P
            //   o_new = exp(m - m_new) * o + P * v[tid]
            let m_new   = max(m, score);
            let exp_old = exp(m - m_new);   // rescale factor for old state
            let p       = exp(score - m_new);

            var v_val: f32 = 0.0;
            if owns_dim {
                v_val = V_tile[kv_local * hd + tid];
            }

            o = exp_old * o + p * v_val;
            l = exp_old * l + p;
            m = m_new;
        }

        // Nobody may still be reading this tile when the next one is loaded.
        workgroupBarrier();
    }

    // ── Write output ───────────────────────────────────────────────────────
    // Normalise: o / l. l stays 0 only when no KV row was visible; emit 0 then.
    if owns_dim {
        var result: f32 = 0.0;
        if l > 0.0 {
            result = o / l;
        }
        Out[q_row * hd + tid] = result;
    }
}