// numr 0.5.1
//
// High-performance numerical computing with multi-backend GPU acceleration (CPU/CUDA/WebGPU)
// Documentation
// Auto-generated special ternary functions for f32

// Threads per workgroup; must agree with the @workgroup_size on the kernels below.
const WORKGROUP_SIZE: u32 = 256u;
const PI: f32 = 3.14159265358979323846;
// sqrt(pi). NOTE(review): unused in this chunk — may be used elsewhere in the generated file.
const SQRT_PI: f32 = 1.7724538509055159;
// Euler–Mascheroni constant. NOTE(review): unused in this chunk.
const EULER_GAMMA: f32 = 0.5772156649015329;
// ln(sqrt(2*pi)), the additive constant in the Lanczos log-gamma formula.
const LN_SQRT_2PI: f32 = 0.9189385332046727;
// Lanczos parameter g; paired with the 9-term coefficient set in lgamma_positive.
const LANCZOS_G: f32 = 7.0;
// Iteration cap for the incomplete-beta continued fraction.
const MAX_ITER: i32 = 100;
// Convergence tolerance for the continued-fraction delta.
const EPSILON: f32 = 1e-6;
// Guard value to avoid division by ~0 in the modified Lentz algorithm.
const TINY: f32 = 1e-30;

// Uniform parameters shared by the ternary element-wise kernels.
struct SpecialTernaryParams {
    numel: u32,
}

// Inputs a, b, x and output buffer; one element per invocation.
// NOTE(review): a/b/x are only read in this chunk; read_write presumably matches
// the host-side bind group layout — confirm before tightening to read.
@group(0) @binding(0) var<storage, read_write> special_a: array<f32>;
@group(0) @binding(1) var<storage, read_write> special_b: array<f32>;
@group(0) @binding(2) var<storage, read_write> special_x: array<f32>;
@group(0) @binding(3) var<storage, read_write> special_out: array<f32>;
@group(0) @binding(4) var<uniform> special_params: SpecialTernaryParams;

// ============================================================================
// Helper Functions (shared lgamma)
// ============================================================================

// Log-gamma via the Lanczos approximation (g = 7, 9 coefficients).
// Valid only for x > 0; negative inputs are handled by lgamma_impl's
// reflection formula. Written without recursion (WGSL forbids it).
fn lgamma_positive(x: f32) -> f32 {
    // Lanczos coefficient set for g = 7, n = 9. Declared as a function-scope
    // var so it can be indexed by the loop counter.
    var coeff = array<f32, 9>(
        0.99999999999980993,
        676.5203681218851,
        -1259.1392167224028,
        771.32342877765313,
        -176.61502916214059,
        12.507343278686905,
        -0.13857109526572012,
        9.9843695780195716e-6,
        1.5056327351493116e-7,
    );

    let shifted = x - 1.0;

    // Accumulate the rational series c0 + sum_k c_k / (z + k), in the same
    // left-to-right order as the unrolled form.
    var series = coeff[0];
    for (var k = 1; k < 9; k = k + 1) {
        series = series + coeff[k] / (shifted + f32(k));
    }

    let base = shifted + LANCZOS_G + 0.5;
    return LN_SQRT_2PI + (shifted + 0.5) * log(base) - base + log(series);
}

// Log of |Gamma(x)| for any real x (non-recursive).
// Positive x goes straight to the Lanczos evaluation. For x <= 0 the
// reflection formula is used:
//   lgamma(x) = log(pi / |sin(pi*x)|) - lgamma(1 - x)
// and since 1 - x > 0 there, lgamma_positive can be called directly.
// Poles at the non-positive integers return a large finite sentinel (1e30).
fn lgamma_impl(x: f32) -> f32 {
    if (x > 0.0) {
        return lgamma_positive(x);
    }

    // Gamma has poles at 0, -1, -2, ...
    if (x == floor(x)) {
        return 1e30;
    }

    let reflected_sin = sin(PI * x);
    if (reflected_sin == 0.0) {
        return 1e30;
    }
    return log(PI / abs(reflected_sin)) - lgamma_positive(1.0 - x);
}

// Regularized incomplete beta I_x(a, b) via the modified Lentz
// continued-fraction scheme (the classic "betacf" formulation).
// Expects 0 < x < 1; betainc_impl routes x past the symmetry point
// through betainc_cf(b, a, 1 - x) for faster convergence.
fn betainc_cf(a: f32, b: f32, x: f32) -> f32 {
    let ab_sum = a + b;
    let a_plus = a + 1.0;
    let a_minus = a - 1.0;

    // Lentz state: frac_c tracks the C ratios, frac_d the (inverted) D ratios,
    // value the running product approximating the continued fraction.
    var frac_c = 1.0;
    var frac_d = 1.0 - ab_sum * x / a_plus;
    if (abs(frac_d) < TINY) {
        frac_d = TINY;
    }
    frac_d = 1.0 / frac_d;
    var value = frac_d;

    for (var step = 1; step < MAX_ITER; step = step + 1) {
        let fm = f32(step);
        let fm2 = f32(2 * step);

        // Even-numbered coefficient: d_{2m} = m(b-m)x / ((a-1+2m)(a+2m)).
        let even_coef = fm * (b - fm) * x / ((a_minus + fm2) * (a + fm2));
        frac_d = 1.0 + even_coef * frac_d;
        if (abs(frac_d) < TINY) {
            frac_d = TINY;
        }
        frac_c = 1.0 + even_coef / frac_c;
        if (abs(frac_c) < TINY) {
            frac_c = TINY;
        }
        frac_d = 1.0 / frac_d;
        value = value * frac_d * frac_c;

        // Odd-numbered coefficient: d_{2m+1} = -(a+m)(a+b+m)x / ((a+2m)(a+1+2m)).
        let odd_coef = -(a + fm) * (ab_sum + fm) * x / ((a + fm2) * (a_plus + fm2));
        frac_d = 1.0 + odd_coef * frac_d;
        if (abs(frac_d) < TINY) {
            frac_d = TINY;
        }
        frac_c = 1.0 + odd_coef / frac_c;
        if (abs(frac_c) < TINY) {
            frac_c = TINY;
        }
        frac_d = 1.0 / frac_d;
        let ratio = frac_d * frac_c;
        value = value * ratio;

        // Converged once the per-step multiplier is within EPSILON of 1.
        if (abs(ratio - 1.0) < EPSILON) {
            break;
        }
    }

    // Prefactor x^a (1-x)^b / (a * B(a,b)), assembled in log space for range.
    let log_beta = lgamma_impl(a) + lgamma_impl(b) - lgamma_impl(a + b);
    return exp(a * log(x) + b * log(1.0 - x) - log_beta) * value / a;
}

// Regularized incomplete beta function I_x(a, b).
// Exact at the boundaries; interior points use the continued fraction on
// whichever side of the crossover converges faster, mirrored via the
// identity I_x(a,b) = 1 - I_{1-x}(b,a) instead of recursing.
fn betainc_impl(a: f32, b: f32, x: f32) -> f32 {
    if (x <= 0.0) {
        return 0.0;
    }
    if (x >= 1.0) {
        return 1.0;
    }

    // Standard crossover point for continued-fraction convergence.
    let crossover = (a + 1.0) / (a + b + 2.0);
    if (x <= crossover) {
        return betainc_cf(a, b, x);
    }
    return 1.0 - betainc_cf(b, a, 1.0 - x);
}

// ============================================================================
// Compute Kernels
// ============================================================================

// Kernel: special_out[i] = I_{x[i]}(a[i], b[i]) for i in [0, numel).
// Fix: use the declared WORKGROUP_SIZE constant instead of a hardcoded 256,
// so the shader cannot silently drift from the constant the file defines
// (and that host code sizing the dispatch is expected to share).
// Trailing invocations past numel exit immediately.
@compute @workgroup_size(WORKGROUP_SIZE)
fn betainc_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= special_params.numel) {
        return;
    }
    special_out[idx] = betainc_impl(special_a[idx], special_b[idx], special_x[idx]);
}