boostr 0.1.0

ML framework built on numr - attention, quantization, model architectures
Documentation
//! Scaled Dot-Product Attention (SDPA) for MLA
//! Supports different K and V last dimensions
//! Q [B, H, S_q, D_k], K [B, H, S_k, D_k], V [B, H, S_k, D_v]
//! Output [B, H, S_q, D_v]
//!
//! Naive O(S_q · S_k) implementation using a simple loop structure.
//! Each invocation (thread) handles one (batch, head, query_position)
//! combination; invocations are grouped 256 per workgroup.

// Uniform parameters for sdpa_forward_f32.
// Layout: eight 4-byte scalars (32 bytes) in declaration order; the host-side
// struct that fills this uniform buffer must use the same field order.
struct SdpaParams {
    batch_size: u32,   // B
    num_heads: u32,    // H (the same head index is used for Q, K and V)
    seq_len_q: u32,    // S_q: number of query positions
    seq_len_k: u32,    // S_k: number of key/value positions (shared by K and V)
    head_dim_k: u32,   // D_k: last dimension of Q and K
    head_dim_v: u32,   // D_v: last dimension of V and of the output
    scale: f32,        // multiplier applied to every Q·K dot product before softmax
                       // (presumably 1/sqrt(D_k); chosen by the caller — confirm)
    causal: u32,       // nonzero: query i only attends to keys j <= i
}

// Bind group 0 for the SDPA kernel. All tensors are dense, row-major f32
// buffers, indexed by the kernel as flat arrays:
//   q   : [B, H, S_q, D_k]  (read-only)
//   k   : [B, H, S_k, D_k]  (read-only)
//   v   : [B, H, S_k, D_v]  (read-only)
//   out : [B, H, S_q, D_v]  (written, one D_v row per query)
//   params : sizes, softmax scale and causal flag
@group(0) @binding(0) var<storage, read> q: array<f32>;
@group(0) @binding(1) var<storage, read> k: array<f32>;
@group(0) @binding(2) var<storage, read> v: array<f32>;
@group(0) @binding(3) var<storage, read_write> out: array<f32>;
@group(0) @binding(4) var<uniform> params: SdpaParams;

// Capacity of the per-invocation output accumulator. head_dim_v values larger
// than this are handled by processing the V dimension in tiles of this width,
// so the kernel no longer indexes past the end of the accumulator.
const SDPA_MAX_TILE_V: u32 = 512u;

// Scaled dot product of one query row and one key row:
//   params.scale * sum_{d < head_dim_k} q[q_base + d] * k[k_base + d]
fn sdpa_scaled_score(q_base: u32, k_base: u32) -> f32 {
    var score = 0.0f;
    for (var d = 0u; d < params.head_dim_k; d = d + 1u) {
        score += q[q_base + d] * k[k_base + d];
    }
    return score * params.scale;
}

// Computes one output row O[b, h, i, :] per invocation (gid.x enumerates all
// (b, h, i) triples, with i varying fastest).
//
// Two-pass softmax:
//   pass 1 finds the maximum scaled score over the visible keys (numerical
//          stability);
//   pass 2 accumulates exp(score - max) * V[b, h, j, :] and the normalizer.
// Scores are recomputed in pass 2 instead of being stored, so per-invocation
// memory does not depend on seq_len_k. Pass 2 runs once per SDPA_MAX_TILE_V-wide
// tile of D_v (a single tile whenever head_dim_v <= SDPA_MAX_TILE_V).
@compute @workgroup_size(256)
fn sdpa_forward_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let query_idx = gid.x;
    let total_queries = params.batch_size * params.num_heads * params.seq_len_q;
    if query_idx >= total_queries {
        return;
    }

    // Decode the flat query index: i is the query position, bh == b * H + h.
    let i = query_idx % params.seq_len_q;
    let bh = query_idx / params.seq_len_q;

    let q_base = (bh * params.seq_len_q + i) * params.head_dim_k;
    let kv_row0 = bh * params.seq_len_k;  // K/V row of key 0 for this (b, h)
    let out_base = (bh * params.seq_len_q + i) * params.head_dim_v;

    // Causal masking: query i sees keys j <= i only. The mask is top-left
    // aligned (query 0 sees key 0).
    // NOTE(review): with seq_len_q < seq_len_k (e.g. decoding against a KV
    // cache) callers expecting bottom-right alignment should confirm this.
    var kv_len = params.seq_len_k;
    if params.causal != 0u {
        kv_len = min(kv_len, i + 1u);
    }

    // Pass 1: maximum scaled score over the visible keys.
    var max_score = -1e30f;
    for (var j = 0u; j < kv_len; j = j + 1u) {
        let score = sdpa_scaled_score(q_base, (kv_row0 + j) * params.head_dim_k);
        max_score = max(max_score, score);
    }

    // Pass 2: softmax-weighted sum of V rows, one D_v tile at a time.
    var accum: array<f32, SDPA_MAX_TILE_V>;
    for (var d0 = 0u; d0 < params.head_dim_v; d0 = d0 + SDPA_MAX_TILE_V) {
        let tile_len = min(SDPA_MAX_TILE_V, params.head_dim_v - d0);
        for (var d = 0u; d < tile_len; d = d + 1u) {
            accum[d] = 0.0f;
        }

        var sum_exp = 0.0f;
        for (var j = 0u; j < kv_len; j = j + 1u) {
            let score = sdpa_scaled_score(q_base, (kv_row0 + j) * params.head_dim_k);
            let weight = exp(score - max_score);
            sum_exp += weight;

            let v_base = (kv_row0 + j) * params.head_dim_v + d0;
            for (var d = 0u; d < tile_len; d = d + 1u) {
                accum[d] += weight * v[v_base + d];
            }
        }

        // sum_exp is 0 when no key is visible (e.g. seq_len_k == 0); write
        // zeros instead of the 0/0 = NaN the plain division would produce.
        for (var d = 0u; d < tile_len; d = d + 1u) {
            out[out_base + d0 + d] = select(0.0f, accum[d] / sum_exp, sum_exp > 0.0f);
        }
    }
}