boostr 0.1.0

ML framework built on numr - attention, quantization, model architectures
Documentation
//! YaRN Rotary Position Embedding shader
//! Input: x [B, H, S, D], cos_cache [S, D/2], sin_cache [S, D/2]
//! Output: x_rotated [B, H, S, D]
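//! Uses the same split-half pairing as standard RoPE (element i is paired with element i + D/2),
//! then multiplies each rotated value by the attention scale `attn_scale`.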
//! Each thread handles one dimension pair (2 elements)

// Uniform parameters for the YaRN RoPE kernel.
// Eight 4-byte scalars; the field order defines the buffer layout, so it must
// not be rearranged.
struct YaRNParams {
    // Extent of the batch axis of `x` / `out`.
    batch_size: u32,
    // Extent of the head axis of `x` / `out`.
    num_heads: u32,
    // Extent of the sequence axis; also the row count used to index the cos/sin caches.
    seq_len: u32,
    // Elements per head vector; the kernel processes head_dim / 2 pairs per vector.
    head_dim: u32,
    // Factor multiplied into both rotated outputs.
    attn_scale: f32,
    // Not read by the kernel. Presumably padding that rounds the struct up to
    // 32 bytes to match the host-side layout -- TODO confirm against the host struct.
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}

// Resource bindings (all in bind group 0).
// Input tensor, read-only; indexed as a flattened [B, H, S, D] array.
@group(0) @binding(0) var<storage, read> x: array<f32>;
// Cosine table, read-only; indexed as s * (head_dim / 2) + d_pair.
@group(0) @binding(1) var<storage, read> cos_cache: array<f32>;
// Sine table, read-only; indexed the same way as cos_cache.
@group(0) @binding(2) var<storage, read> sin_cache: array<f32>;
// Output tensor; written at the same flat indices that are read from `x`.
@group(0) @binding(3) var<storage, read_write> out: array<f32>;
// Kernel dimensions and attention scale.
@group(0) @binding(4) var<uniform> params: YaRNParams;

// YaRN rotary embedding, f32 variant.
//
// One invocation rotates one (first-half, second-half) element pair of a single
// head vector. `gid.x` is the flat pair index over (batch, head, position, pair),
// with the pair coordinate varying fastest. The pair is rotated with the cos/sin
// values cached for its (position, pair) slot, and both results are multiplied
// by `params.attn_scale` before being stored to `out`.
@compute @workgroup_size(256)
fn rope_yarn_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let pairs_per_head = params.head_dim / 2u;
    let flat = gid.x;

    // Invocations past the last pair have nothing to do. This also covers the
    // case where any dimension is zero, since the bound is then 0.
    if flat >= params.batch_size * params.num_heads * params.seq_len * pairs_per_head {
        return;
    }

    // Peel coordinates off the flat index, fastest-varying first:
    // pair -> position -> head -> batch.
    let pair = flat % pairs_per_head;
    let outer = flat / pairs_per_head;
    let pos = outer % params.seq_len;
    let batch_head = outer / params.seq_len;
    let head = batch_head % params.num_heads;
    let batch = batch_head / params.num_heads;

    // Flat offsets of the two paired elements: `lo` lies in the first half of
    // the head vector, `hi` is the matching slot in the second half.
    let row_start = ((batch * params.num_heads + head) * params.seq_len + pos) * params.head_dim;
    let lo = row_start + pair;
    let hi = lo + pairs_per_head;

    let a = x[lo];
    let b = x[hi];

    // Rotation coefficients for this (position, pair) slot.
    let slot = pos * pairs_per_head + pair;
    let c = cos_cache[slot];
    let sn = sin_cache[slot];

    // 2-D rotation of (a, b), followed by the attention scale.
    out[lo] = (a * c - b * sn) * params.attn_scale;
    out[hi] = (a * sn + b * c) * params.attn_scale;
}