boostr 0.1.0

ML framework built on numr - attention, quantization, model architectures
Documentation
//! Reshape and cache for paged KV storage
//! Maps tokens to physical blocks via slot_mapping.

// Uniform parameters for the reshape-and-cache kernel.
// Field order defines the uniform buffer layout seen by the host; do not reorder.
struct PagedCacheParams {
    // Number of tokens in the source key/value buffers (outermost source index).
    num_tokens: u32,
    // Number of heads per token (middle index of both source and cache).
    num_heads: u32,
    // Elements per head (innermost, contiguous index).
    head_dim: u32,
    // Slots per cache block; a slot is split into block = slot / block_size
    // and offset = slot % block_size.
    block_size: u32,
}

// Binding layout: read-only storage first (bindings 0-2), then read-write
// storage (bindings 3-4), then the uniform parameter block (binding 5).
// Source keys, indexed as (token * num_heads + head) * head_dim + d.
@group(0) @binding(0) var<storage, read> key: array<f32>;
// Source values, indexed the same way as `key`.
@group(0) @binding(1) var<storage, read> value: array<f32>;
// Per-token destination slot in the cache, indexed by token.
@group(0) @binding(2) var<storage, read> slot_mapping: array<i32>;
// Destination key cache, indexed as (slot * num_heads + head) * head_dim + d.
@group(0) @binding(3) var<storage, read_write> key_cache: array<f32>;
// Destination value cache, indexed the same way as `key_cache`.
@group(0) @binding(4) var<storage, read_write> value_cache: array<f32>;
// Kernel dimensions (see PagedCacheParams).
@group(0) @binding(5) var<uniform> paged_params: PagedCacheParams;

// Scatter key/value elements into the paged KV cache.
//
// Each invocation copies one f32 element. `gid.x` is the flat index into the
// source buffers, which are addressed as [token, head, d] with `d` contiguous.
// The destination is addressed as [slot, head, d], where the slot for a token
// comes from `slot_mapping[token]` and is split into a block index and an
// offset within that block using `block_size`.
//
// Invocations past num_tokens * num_heads * head_dim do nothing.
// Tokens whose slot is negative are skipped (presumably padding entries such
// as -1 — confirm against the host code that fills slot_mapping).
@compute @workgroup_size(256)
fn reshape_and_cache_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let elem_idx = gid.x;
    let total_elems = paged_params.num_tokens * paged_params.num_heads * paged_params.head_dim;

    // The dispatch is rounded up to a multiple of the workgroup size.
    if elem_idx >= total_elems {
        return;
    }

    // Decompose the flat index into (token, head, d).
    let d = elem_idx % paged_params.head_dim;
    let remainder = elem_idx / paged_params.head_dim;
    let h = remainder % paged_params.num_heads;
    let t = remainder / paged_params.num_heads;

    // A negative slot must not be converted to u32: it would become a huge
    // index and the resulting out-of-bounds write could land on a valid
    // cache entry.
    let raw_slot = slot_mapping[t];
    if raw_slot < 0 {
        return;
    }

    let slot = u32(raw_slot);
    let block_idx = slot / paged_params.block_size;
    let offset_in_block = slot % paged_params.block_size;

    // key and value share the same [token, head, d] layout, so the source
    // index (t * num_heads + h) * head_dim + d is exactly elem_idx.
    let src_idx = elem_idx;

    let cache_idx = ((block_idx * paged_params.block_size + offset_in_block) * paged_params.num_heads + h) * paged_params.head_dim + d;

    key_cache[cache_idx] = key[src_idx];
    value_cache[cache_idx] = value[src_idx];
}