// boostr 0.1.0
//
// ML framework built on numr: attention, quantization, and model architectures.
// Documentation
//! Paged Attention - vLLM style block-table indirection
//! Q           [B, num_heads, S_q, head_dim]
//! K_blocks    [num_blocks, block_size, num_kv_heads, head_dim]
//! V_blocks    [num_blocks, block_size, num_kv_heads, head_dim]
//! block_table [B, max_num_blocks] (i32)
//! Out         [B, num_heads, S_q, head_dim]
//! LSE         [B, num_heads, S_q]
//!
//! Maps logical token position to physical block via block_table.
//! Token t in sequence is in block t/block_size, offset t%block_size.

// Uniform parameters for the paged-attention kernels (bound at @binding(6)).
// WGSL lays struct members out in declaration order, so this order is the
// uniform buffer layout. NOTE(review): presumably mirrored by a host-side
// struct; keep the two in sync.
struct PagedParams {
    batch_size: u32,      // B: number of sequences in the batch
    num_heads: u32,       // query heads per sequence
    num_kv_heads: u32,    // K/V heads; query head h reads KV head (h * num_kv_heads) / num_heads
    seq_len_q: u32,       // query positions per (sequence, head)
    seq_len_k: u32,       // logical key positions scanned per sequence
    head_dim: u32,        // elements per head vector; the forward kernel's accumulator holds at most 512
    block_size: u32,      // tokens stored per physical KV block
    max_num_blocks: u32,  // entries per sequence row of block_table
    scale: f32,           // multiplier applied to every Q.K dot product
    causal: u32,          // nonzero enables causal masking
}

// Resource bindings, all in bind group 0. Every buffer is a flat array; the
// logical layouts below follow the index arithmetic in paged_attention_fwd_f32.
// Q: [batch_size, num_heads, seq_len_q, head_dim]
@group(0) @binding(0) var<storage, read> q: array<f32>;
// Paged K cache: [num_physical_blocks, block_size, num_kv_heads, head_dim]
@group(0) @binding(1) var<storage, read> k_blocks: array<f32>;
// Paged V cache: same layout as k_blocks
@group(0) @binding(2) var<storage, read> v_blocks: array<f32>;
// Per-sequence logical->physical block map: [batch_size, max_num_blocks]
@group(0) @binding(3) var<storage, read> block_table: array<i32>;
// Attention output: same layout as q
@group(0) @binding(4) var<storage, read_write> out: array<f32>;
// Log-sum-exp of the scaled scores per query row: [batch_size, num_heads, seq_len_q]
@group(0) @binding(5) var<storage, read_write> lse: array<f32>;
@group(0) @binding(6) var<uniform> params: PagedParams;

// Location of one logical key position inside the paged K/V caches.
struct PagedKvRow {
    // False when the position could not be resolved to an in-bounds row.
    valid: bool,
    // Flat index of element 0 of the head_dim-long row in k_blocks / v_blocks.
    base: u32,
}

// Resolve logical key position `j` of sequence `b`, KV head `h_kv`, to its
// physical row in k_blocks / v_blocks via block_table.
//
// Physical layout: [num_physical_blocks, block_size, num_kv_heads, head_dim].
// Returns valid = false when:
//   * the logical block lies past this sequence's max_num_blocks table row
//     (otherwise the lookup would read the next sequence's row),
//   * the table entry is negative, e.g. an unused/padding slot (a plain u32()
//     cast would turn it into a huge index),
//   * the physical block is not fully contained in BOTH k_blocks and v_blocks.
fn paged_kv_row(b: u32, j: u32, h_kv: u32) -> PagedKvRow {
    let invalid = PagedKvRow(false, 0u);

    let block_idx_logical = j / params.block_size;
    if block_idx_logical >= params.max_num_blocks {
        return invalid;
    }

    let bt_idx = b * params.max_num_blocks + block_idx_logical;
    if bt_idx >= arrayLength(&block_table) {
        return invalid;
    }

    let entry = block_table[bt_idx];
    if entry < 0 {
        return invalid;
    }
    let block_idx_physical = u32(entry);

    // Bound the physical block index rather than the final element offset:
    // once the block is known to exist in both caches, the offset arithmetic
    // below cannot wrap around u32 and land back inside the buffer.
    let block_stride = params.block_size * params.num_kv_heads * params.head_dim;
    let blocks_in_cache = min(arrayLength(&k_blocks), arrayLength(&v_blocks)) / block_stride;
    if block_idx_physical >= blocks_in_cache {
        return invalid;
    }

    let offset_in_block = j % params.block_size;
    let base = ((block_idx_physical * params.block_size + offset_in_block) * params.num_kv_heads + h_kv) * params.head_dim;
    return PagedKvRow(true, base);
}

// Scaled dot product of q[q_base .. q_base + head_dim) with
// k_blocks[k_base .. k_base + head_dim).
fn paged_qk_score(q_base: u32, k_base: u32) -> f32 {
    var dot = 0.0f;
    for (var d = 0u; d < params.head_dim; d += 1u) {
        dot += q[q_base + d] * k_blocks[k_base + d];
    }
    return dot * params.scale;
}

// Paged attention forward pass: one invocation per query row (b, h, i).
//
// Writes out[b, h, i, :] = softmax(scale * Q K^T) V over the visible keys of
// sequence b, and lse[b, h, i] = log(sum of exp) + max score. K/V rows are
// fetched through block_table. Key scanning stops at the first position that
// cannot be resolved to a valid, in-bounds row.
//
// Causal masking is bottom-right aligned: the last query row sees all
// seq_len_k keys, so query i sees keys j <= i + (seq_len_k - seq_len_q).
// When seq_len_q == seq_len_k this is the usual j <= i.
//
// NOTE: head_dim must not exceed 512 (size of the private accumulator).
// NOTE(review): presumably enforced by the host — confirm.
@compute @workgroup_size(256)
fn paged_attention_fwd_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let query_idx = gid.x;
    let total_queries = params.batch_size * params.num_heads * params.seq_len_q;
    if query_idx >= total_queries {
        return;
    }

    // Decode the flat invocation index into (b, h, i).
    let i = query_idx % params.seq_len_q;
    let bh = query_idx / params.seq_len_q;
    let h = bh % params.num_heads;
    let b = bh / params.num_heads;

    // GQA: several query heads share one KV head.
    let h_kv = (h * params.num_kv_heads) / params.num_heads;

    // Row index of (b, h, i) in lse; q and out use the same row times head_dim.
    let row = (b * params.num_heads + h) * params.seq_len_q + i;
    let q_base = row * params.head_dim;

    // Number of keys this query may attend to.
    var end_j = params.seq_len_k;
    if params.causal != 0u {
        // Visible key count is i + 1 + seq_len_k - seq_len_q; clamp at 0
        // when there are more queries than keys.
        let horizon = i + 1u + params.seq_len_k;
        if horizon > params.seq_len_q {
            end_j = min(horizon - params.seq_len_q, params.seq_len_k);
        } else {
            end_j = 0u;
        }
    }

    // Pass 1: maximum scaled score, subtracted before exp() for numerical stability.
    var max_score = -1e30f;
    for (var j = 0u; j < end_j; j += 1u) {
        let kv = paged_kv_row(b, j, h_kv);
        if !kv.valid {
            break;
        }
        max_score = max(max_score, paged_qk_score(q_base, kv.base));
    }
    // No visible keys: use 0 so lse below stays finite.
    if max_score == -1e30f {
        max_score = 0.0f;
    }

    // Pass 2: softmax weights and weighted sum of V rows.
    var accum: array<f32, 512>;  // max supported head_dim; function-scope vars start zeroed
    var sum_exp = 0.0f;
    for (var j = 0u; j < end_j; j += 1u) {
        let kv = paged_kv_row(b, j, h_kv);
        if !kv.valid {
            break;
        }
        let weight = exp(paged_qk_score(q_base, kv.base) - max_score);
        sum_exp += weight;
        for (var d = 0u; d < params.head_dim; d += 1u) {
            accum[d] += weight * v_blocks[kv.base + d];
        }
    }

    // Normalize; the floor keeps an empty key range from dividing by zero.
    let denom = max(sum_exp, 1e-10f);
    for (var d = 0u; d < params.head_dim; d += 1u) {
        out[q_base + d] = accum[d] / denom;
    }
    lse[row] = log(denom) + max_score;
}

// Backward pass for paged attention: NOT IMPLEMENTED.
// This entry point has an empty body, so dispatching it writes nothing and
// leaves every bound buffer untouched.
// NOTE(review): callers relying on it would silently receive no gradients;
// confirm it is never dispatched until implemented.
@compute @workgroup_size(256)
fn paged_attention_bwd_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    // Backward pass placeholder
}