boostr 0.1.0

ML framework built on numr - attention, quantization, model architectures
Documentation
// Speculative Decoding Shaders
//
// Element-wise acceptance probability and expected token count computation.
// Token verification (verify_speculative_tokens) is handled by impl_generic
// using numr's philox_uniform for reproducible, backend-consistent RNG.
// F32 only (WebGPU limitation).

// -----------------------------------------------------------------
// Element-wise acceptance and residual probabilities
// -----------------------------------------------------------------

// Uniform parameters for compute_acceptance_probs_f32.
// Four u32 fields, 16 bytes in total.
struct AcceptParams {
    // Number of elements to process; invocations whose global x index is
    // at or beyond this value return without reading or writing anything.
    total_elements: u32,
    // Unused by the shader. Presumably padding so the struct size matches
    // a 16-byte host-side uniform layout -- confirm against the host code.
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}

// Resources used by compute_acceptance_probs_f32. All four storage arrays are
// indexed by the same flat element index.
// NOTE(review): these group(0) binding numbers are reused by the expected-tokens
// kernel's resources further down; each entry point references only its own set.

// Draft probabilities (read-only input).
@group(0) @binding(0) var<storage, read> accept_draft: array<f32>;
// Target probabilities (read-only input).
@group(0) @binding(1) var<storage, read> accept_target: array<f32>;
// Output: per-element acceptance probability.
@group(0) @binding(2) var<storage, read_write> acceptance_out: array<f32>;
// Output: per-element residual, max(0, target - draft).
@group(0) @binding(3) var<storage, read_write> residual_out: array<f32>;
// Element count and padding (see AcceptParams).
@group(0) @binding(4) var<uniform> accept_params: AcceptParams;

// Computes, for each element i < accept_params.total_elements:
//   acceptance_out[i] = min(1, target[i] / draft[i])  when draft[i] > 1e-10,
//                       1                             otherwise
//   residual_out[i]   = max(0, target[i] - draft[i])
// One invocation handles one element, selected by the global x index.
@compute @workgroup_size(256)
fn compute_acceptance_probs_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;

    // Invocations past the end of the data do nothing.
    if (i < accept_params.total_elements) {
        let draft_p = accept_draft[i];
        let target_p = accept_target[i];

        // Only divide when the draft probability is safely above zero;
        // otherwise the acceptance probability is taken to be 1.
        if (draft_p > 1e-10) {
            acceptance_out[i] = min(1.0, target_p / draft_p);
        } else {
            acceptance_out[i] = 1.0;
        }

        // Residual keeps only the part where target exceeds draft.
        residual_out[i] = max(0.0, target_p - draft_p);
    }
}

// -----------------------------------------------------------------
// Expected tokens computation
// -----------------------------------------------------------------

// Uniform parameters for compute_expected_tokens_f32.
// Four u32 fields, 16 bytes in total.
struct ExpectedParams {
    // Number of batch rows; invocations whose global x index is at or
    // beyond this value return without reading or writing anything.
    batch_size: u32,
    // Number of acceptance rates read per batch row (the row stride into
    // the rates buffer and the loop trip count).
    max_spec_tokens: u32,
    // Unused by the shader. Presumably padding so the struct size matches
    // a 16-byte host-side uniform layout -- confirm against the host code.
    _pad0: u32,
    _pad1: u32,
}

// Resources used by compute_expected_tokens_f32.
// NOTE(review): these group(0) binding numbers overlap with the acceptance
// kernel's resources declared earlier; each entry point references only its own set.

// Acceptance rates, read as exp_rates[batch_idx * max_spec_tokens + i].
@group(0) @binding(0) var<storage, read> exp_rates: array<f32>;
// Output: one expected-token count per batch row.
@group(0) @binding(1) var<storage, read_write> expected_out: array<f32>;
// Batch size and per-row rate count (see ExpectedParams).
@group(0) @binding(2) var<uniform> expected_params: ExpectedParams;

// Computes the expected token count for each batch row:
//   expected_out[b] = 1 + sum over i in [0, K) of prod over j <= i of rates[b * K + j]
// where K = expected_params.max_spec_tokens. One invocation handles one row,
// selected by the global x index.
@compute @workgroup_size(256)
fn compute_expected_tokens_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let row = gid.x;

    // Invocations past the last batch row do nothing.
    if (row < expected_params.batch_size) {
        let steps = expected_params.max_spec_tokens;
        // Offset of this row's first rate in the flat rates buffer.
        let row_start = row * steps;

        // running_prob is the product of the rates seen so far;
        // prefix_sum accumulates that product after every step.
        var running_prob = 1.0;
        var prefix_sum = 0.0;

        for (var step = 0u; step < steps; step = step + 1u) {
            running_prob = running_prob * exp_rates[row_start + step];
            prefix_sum = prefix_sum + running_prob;
        }

        // The bonus token is added last, after the accumulated sum.
        expected_out[row] = prefix_sum + 1.0;
    }
}