numr 0.5.2

High-performance numerical computing with multi-backend GPU acceleration (CPU/CUDA/WebGPU)
Documentation
// F32 unary operations

// Threads per workgroup for the unary kernels (u32, from the `u` suffix).
// NOTE(review): each kernel below spells its size as the literal
// `@workgroup_size(256)` rather than referencing this constant, so the two
// must be kept in sync by hand (presumably together with host dispatch math).
const WORKGROUP_SIZE = 256u;

// Uniform parameters read by every f32 unary kernel in this file.
struct UnaryParams {
    // Number of elements to process. Each kernel handles element gid.x and
    // does nothing when gid.x >= numel, so the dispatch may cover more
    // invocations than there are elements.
    // NOTE(review): only gid.x is used for indexing, so one 1-D dispatch can
    // reach at most (max workgroups per dimension, 65535 by default in
    // WebGPU) * 256 elements; confirm the host splits larger tensors.
    numel: u32,
}

// Bind group 0, shared by every f32 unary kernel in this file.
// unary_a: input elements. The kernels visible here only read it, yet it is
// declared read_write. NOTE(review): presumably this matches a host-side
// bind group layout; changing it to `read` would need the same host change.
@group(0) @binding(0) var<storage, read_write> unary_a: array<f32>;
// unary_out: output elements, written at the same index each kernel reads.
@group(0) @binding(1) var<storage, read_write> unary_out: array<f32>;
// unary_params: carries the element count (see UnaryParams).
@group(0) @binding(2) var<uniform> unary_params: UnaryParams;

// Element-wise negation: out[i] = -a[i] for every i < numel.
@compute @workgroup_size(256)
fn neg_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Invocations past the last element do nothing.
    if (i >= unary_params.numel) {
        return;
    }
    unary_out[i] = -unary_a[i];
}

// Element-wise absolute value: out[i] = |a[i]| for every i < numel.
@compute @workgroup_size(256)
fn abs_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Invocations past the last element do nothing.
    if (i >= unary_params.numel) {
        return;
    }
    unary_out[i] = abs(unary_a[i]);
}

// Element-wise square root: out[i] = sqrt(a[i]) for every i < numel.
@compute @workgroup_size(256)
fn sqrt_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Invocations past the last element do nothing.
    if (i >= unary_params.numel) {
        return;
    }
    unary_out[i] = sqrt(unary_a[i]);
}

// Element-wise natural exponential: out[i] = e^a[i] for every i < numel.
@compute @workgroup_size(256)
fn exp_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Invocations past the last element do nothing.
    if (i >= unary_params.numel) {
        return;
    }
    unary_out[i] = exp(unary_a[i]);
}

// Element-wise natural logarithm: out[i] = ln(a[i]) for every i < numel.
@compute @workgroup_size(256)
fn log_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Invocations past the last element do nothing.
    if (i >= unary_params.numel) {
        return;
    }
    unary_out[i] = log(unary_a[i]);
}

// Element-wise sine (argument in radians) for every i < numel.
@compute @workgroup_size(256)
fn sin_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Invocations past the last element do nothing.
    if (i >= unary_params.numel) {
        return;
    }
    unary_out[i] = sin(unary_a[i]);
}

// Element-wise cosine (argument in radians) for every i < numel.
@compute @workgroup_size(256)
fn cos_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Invocations past the last element do nothing.
    if (i >= unary_params.numel) {
        return;
    }
    unary_out[i] = cos(unary_a[i]);
}

// Element-wise tangent (argument in radians) for every i < numel.
@compute @workgroup_size(256)
fn tan_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Invocations past the last element do nothing.
    if (i >= unary_params.numel) {
        return;
    }
    unary_out[i] = tan(unary_a[i]);
}

// Element-wise arctangent (result in radians) for every i < numel.
@compute @workgroup_size(256)
fn atan_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Invocations past the last element do nothing.
    if (i >= unary_params.numel) {
        return;
    }
    unary_out[i] = atan(unary_a[i]);
}

// Element-wise hyperbolic tangent for every idx < numel.
// For |x| > 15, tanh(x) rounds to exactly +-1.0 in f32 (1 - tanh(15) is
// about 1.9e-13), so sign(x) is written directly. This keeps large inputs
// finite on backends that evaluate tanh as a ratio of exponentials, which
// overflows to inf/inf = NaN. WGSL only bounds tanh's accuracy as
// "inherited from sinh(x)/cosh(x)", so such backends are conforming.
// A NaN input fails the comparison and still goes through tanh.
@compute @workgroup_size(256)
fn tanh_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx < unary_params.numel) {
        let x = unary_a[idx];
        unary_out[idx] = select(tanh(x), sign(x), abs(x) > 15.0);
    }
}

// Element-wise reciprocal: out[i] = 1 / a[i] for every i < numel.
@compute @workgroup_size(256)
fn recip_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Invocations past the last element do nothing.
    if (i >= unary_params.numel) {
        return;
    }
    unary_out[i] = 1.0 / unary_a[i];
}

// Element-wise floor (round toward -infinity) for every i < numel.
@compute @workgroup_size(256)
fn floor_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Invocations past the last element do nothing.
    if (i >= unary_params.numel) {
        return;
    }
    unary_out[i] = floor(unary_a[i]);
}

// Element-wise ceiling (round toward +infinity) for every i < numel.
@compute @workgroup_size(256)
fn ceil_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Invocations past the last element do nothing.
    if (i >= unary_params.numel) {
        return;
    }
    unary_out[i] = ceil(unary_a[i]);
}

// Element-wise rounding to the nearest integer, ties away from zero
// (2.5 -> 3.0, -2.5 -> -3.0), for every idx < numel.
//
// x is split into t = trunc(x) and fraction x - t. That subtraction is exact
// in f32 (Sterbenz: t and x share sign and 1/2 <= x/t <= 2 when t != 0), so
// comparing |x - t| with 0.5 decides the rounding without the extra rounding
// step of floor(x + 0.5). That form rounds 0.49999997 up to 1.0, and rounds
// odd integers >= 2^23 one step away from zero.
// For NaN or +-inf, x - t is NaN, the comparison is false and t (= x) is
// written. -0.0 now stays -0.0.
@compute @workgroup_size(256)
fn round_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx < unary_params.numel) {
        let x = unary_a[idx];
        let t = trunc(x);
        // Fractional part is at least one half: step one unit away from zero.
        let away = abs(x - t) >= 0.5;
        unary_out[idx] = select(t, t + sign(x), away);
    }
}

// Element-wise truncation (round toward zero) for every i < numel.
@compute @workgroup_size(256)
fn trunc_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Invocations past the last element do nothing.
    if (i >= unary_params.numel) {
        return;
    }
    unary_out[i] = trunc(unary_a[i]);
}

// Element-wise reciprocal square root, 1 / sqrt(a[i]), via WGSL's
// inverseSqrt builtin, for every i < numel.
@compute @workgroup_size(256)
fn rsqrt_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Invocations past the last element do nothing.
    if (i >= unary_params.numel) {
        return;
    }
    unary_out[i] = inverseSqrt(unary_a[i]);
}

// Element-wise real cube root: sign(x) * |x|^(1/3) for every i < numel.
// pow is applied only to |x|: a negative base with a fractional exponent has
// no real result, so the sign is reattached afterwards.
@compute @workgroup_size(256)
fn cbrt_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Invocations past the last element do nothing.
    if (i >= unary_params.numel) {
        return;
    }
    let v = unary_a[i];
    let magnitude = pow(abs(v), 1.0 / 3.0);
    unary_out[i] = sign(v) * magnitude;
}

// Element-wise base-2 exponential: out[i] = 2^a[i] for every i < numel.
@compute @workgroup_size(256)
fn exp2_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Invocations past the last element do nothing.
    if (i >= unary_params.numel) {
        return;
    }
    unary_out[i] = exp2(unary_a[i]);
}

// Element-wise expm1(x) = e^x - 1 for every idx < numel.
// Evaluating exp(x) - 1.0 directly cancels catastrophically for small |x|:
// for x = 1e-8, exp(x) is exactly 1.0 in f32 and the result would be 0.
// For |x| < 0.5 the Taylor series through x^8 is used instead, in Horner
// form. The first omitted term, x^9/9!, is below 1.1e-8 relative there.
// For |x| >= 0.5, the subtraction amplifies exp's relative error by at most
// e^x / |e^x - 1| (about 2.5 at x = 0.5), so the direct form is kept.
@compute @workgroup_size(256)
fn expm1_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx < unary_params.numel) {
        let x = unary_a[idx];
        var r: f32;
        if (abs(x) < 0.5) {
            // x + x^2/2! + x^3/3! + ... + x^8/8!
            r = x * (1.0 + x * (0.5 + x * (1.0 / 6.0 + x * (1.0 / 24.0
                + x * (1.0 / 120.0 + x * (1.0 / 720.0 + x * (1.0 / 5040.0
                + x * (1.0 / 40320.0))))))));
        } else {
            // NaN also lands here and propagates through exp.
            r = exp(x) - 1.0;
        }
        unary_out[idx] = r;
    }
}

// Element-wise base-2 logarithm for every i < numel.
@compute @workgroup_size(256)
fn log2_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Invocations past the last element do nothing.
    if (i >= unary_params.numel) {
        return;
    }
    unary_out[i] = log2(unary_a[i]);
}

// Element-wise base-10 logarithm for every i < numel.
// WGSL has no log10 builtin, so log10(x) = ln(x) * (1 / ln 10) is used.
@compute @workgroup_size(256)
fn log10_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Invocations past the last element do nothing.
    if (i >= unary_params.numel) {
        return;
    }
    let inv_ln10 = 0.4342944819032518;
    unary_out[i] = log(unary_a[i]) * inv_ln10;
}

// Element-wise log1p(x) = ln(1 + x) for every idx < numel.
// Forming 1.0 + x first discards the low bits of a small x. For |x| < 2^-24
// the sum is exactly 1.0 and log returns 0 instead of about x.
// For |x| < 0.1 the alternating Taylor series through x^7 is used instead,
// in Horner form. The first omitted term, x^8/8, is below 1.3e-8 relative
// there. Elsewhere log(1.0 + x) is kept; the rounding of 1 + x then adds at
// most about 6e-8 absolute error.
@compute @workgroup_size(256)
fn log1p_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx < unary_params.numel) {
        let x = unary_a[idx];
        var r: f32;
        if (abs(x) < 0.1) {
            // x - x^2/2 + x^3/3 - x^4/4 + x^5/5 - x^6/6 + x^7/7
            r = x * (1.0 - x * (0.5 - x * (1.0 / 3.0 - x * (0.25
                - x * (0.2 - x * (1.0 / 6.0 - x * (1.0 / 7.0)))))));
        } else {
            // NaN, x <= -1 and large x behave exactly as before.
            r = log(1.0 + x);
        }
        unary_out[idx] = r;
    }
}

// Element-wise arcsine as atan2(x, sqrt(1 - x^2)) for every idx < numel.
// 1 - x^2 is formed as (1 - x) * (1 + x). When |x| is near 1, x * x rounds
// away exactly the low bits that 1 - x*x then exposes. 1 - x is exact there
// (Sterbenz), so the factored form keeps sqrt's argument accurate to a few
// ULP.
// The max(0, .) clamp is unchanged, so |x| > 1 still yields +-pi/2.
// NOTE(review): other backends may return NaN for |x| > 1; confirm whether
// that difference is intended.
@compute @workgroup_size(256)
fn asin_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx < unary_params.numel) {
        let x = unary_a[idx];
        let y = sqrt(max(0.0, (1.0 - x) * (1.0 + x)));
        unary_out[idx] = atan2(x, y);
    }
}

// Element-wise arccosine as atan2(sqrt(1 - x^2), x) for every idx < numel.
// 1 - x^2 is formed as (1 - x) * (1 + x). Near x = +-1, acos(x) is about the
// sqrt term itself, so the rounding of x * x in 1 - x*x showed up as
// hundreds of ULP of relative error. 1 - x is exact there (Sterbenz), which
// keeps the factored form accurate.
// The max(0, .) clamp is unchanged, so |x| > 1 still yields 0 or pi.
// NOTE(review): other backends may return NaN for |x| > 1; confirm whether
// that difference is intended.
@compute @workgroup_size(256)
fn acos_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx < unary_params.numel) {
        let x = unary_a[idx];
        let y = sqrt(max(0.0, (1.0 - x) * (1.0 + x)));
        unary_out[idx] = atan2(y, x);
    }
}

// Element-wise hyperbolic sine for every i < numel.
@compute @workgroup_size(256)
fn sinh_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Invocations past the last element do nothing.
    if (i >= unary_params.numel) {
        return;
    }
    unary_out[i] = sinh(unary_a[i]);
}

// Element-wise hyperbolic cosine for every i < numel.
@compute @workgroup_size(256)
fn cosh_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Invocations past the last element do nothing.
    if (i >= unary_params.numel) {
        return;
    }
    unary_out[i] = cosh(unary_a[i]);
}

// Element-wise inverse hyperbolic sine for every idx < numel.
// WGSL specifies asinh accuracy as inherited from log(x + sqrt(x*x + 1)).
// A backend that evaluates that formula directly cancels for x << 0: once
// x*x + 1 == x*x (|x| >= 4096), x + sqrt(..) is 0 and the result is -inf.
// asinh is odd, so it is evaluated on |x| and the sign is restored.
// For |x| > 1e4, asinh(|x|) equals ln(2|x|) to within 1/(4x^2) (< 2.5e-9),
// and using that also avoids x*x overflowing above ~1.8e19.
// NaN fails both comparisons and propagates. -0.0 maps to +0.0.
@compute @workgroup_size(256)
fn asinh_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx < unary_params.numel) {
        let x = unary_a[idx];
        let ax = abs(x);
        // 0.6931471805599453 = ln 2
        let r = select(asinh(ax), log(ax) + 0.6931471805599453, ax > 1.0e4);
        unary_out[idx] = select(r, -r, x < 0.0);
    }
}

// Element-wise inverse hyperbolic cosine for every i < numel.
@compute @workgroup_size(256)
fn acosh_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Invocations past the last element do nothing.
    if (i >= unary_params.numel) {
        return;
    }
    unary_out[i] = acosh(unary_a[i]);
}

// Element-wise inverse hyperbolic tangent for every i < numel.
@compute @workgroup_size(256)
fn atanh_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Invocations past the last element do nothing.
    if (i >= unary_params.numel) {
        return;
    }
    unary_out[i] = atanh(unary_a[i]);
}

// Element-wise square: out[i] = a[i] * a[i] for every i < numel.
@compute @workgroup_size(256)
fn square_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Invocations past the last element do nothing.
    if (i >= unary_params.numel) {
        return;
    }
    let v = unary_a[i];
    unary_out[i] = v * v;
}

// Element-wise sign via WGSL's sign builtin (-1, 0 or 1) for every i < numel.
@compute @workgroup_size(256)
fn sign_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Invocations past the last element do nothing.
    if (i >= unary_params.numel) {
        return;
    }
    unary_out[i] = sign(unary_a[i]);
}

// Element-wise ReLU: out[i] = max(a[i], 0) for every i < numel.
@compute @workgroup_size(256)
fn relu_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Invocations past the last element do nothing.
    if (i >= unary_params.numel) {
        return;
    }
    unary_out[i] = max(unary_a[i], 0.0);
}

// Element-wise logistic sigmoid: out[i] = 1 / (1 + e^-a[i]) for every i < numel.
@compute @workgroup_size(256)
fn sigmoid_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Invocations past the last element do nothing.
    if (i >= unary_params.numel) {
        return;
    }
    let e = exp(-unary_a[i]);
    unary_out[i] = 1.0 / (1.0 + e);
}

// Element-wise SiLU (x * sigmoid(x)), computed as x / (1 + e^-x), for every
// i < numel.
@compute @workgroup_size(256)
fn silu_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Invocations past the last element do nothing.
    if (i >= unary_params.numel) {
        return;
    }
    let v = unary_a[i];
    let denom = 1.0 + exp(-v);
    unary_out[i] = v / denom;
}

// Element-wise GELU, tanh approximation, for every idx < numel:
//   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// The tanh argument grows like x^3 (about 43 already at x = 10). Backends
// that evaluate tanh as a ratio of exponentials overflow to inf/inf = NaN
// there; WGSL only bounds tanh accuracy as inherited from sinh/cosh.
// For |inner| > 15, tanh rounds to exactly +-1.0 in f32, so sign(inner) is
// used instead. NaN fails the comparison and still goes through tanh.
@compute @workgroup_size(256)
fn gelu_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx < unary_params.numel) {
        let x = unary_a[idx];
        let c = 0.7978845608028654; // sqrt(2 / pi)
        let inner = c * (x + 0.044715 * x * x * x);
        let t = select(tanh(inner), sign(inner), abs(inner) > 15.0);
        unary_out[idx] = 0.5 * x * (1.0 + t);
    }
}

// Writes 1.0 where the input is NaN and 0.0 otherwise, for every i < numel.
// The test inspects the IEEE-754 bit pattern instead of using x != x.
// NOTE(review): presumably because WGSL lets implementations assume NaN is
// absent, so a self-comparison could be folded away.
@compute @workgroup_size(256)
fn isnan_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Invocations past the last element do nothing.
    if (i >= unary_params.numel) {
        return;
    }
    // With the sign bit cleared, a NaN has an all-ones exponent and a
    // non-zero mantissa. Its magnitude bits are therefore strictly greater
    // than the +inf pattern 0x7f800000.
    let magnitude = bitcast<u32>(unary_a[i]) & 0x7fffffffu;
    unary_out[i] = select(0.0, 1.0, magnitude > 0x7f800000u);
}

// Writes 1.0 where the input is +inf or -inf and 0.0 otherwise, for every
// i < numel. The IEEE-754 bit pattern is inspected directly instead of
// comparing against an infinity.
@compute @workgroup_size(256)
fn isinf_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Invocations past the last element do nothing.
    if (i >= unary_params.numel) {
        return;
    }
    // With the sign bit cleared, +-inf is exactly an all-ones exponent with
    // a zero mantissa: 0x7f800000.
    let magnitude = bitcast<u32>(unary_a[i]) & 0x7fffffffu;
    unary_out[i] = select(0.0, 1.0, magnitude == 0x7f800000u);
}