numr 0.5.2

High-performance numerical computing with multi-backend GPU acceleration (CPU/CUDA/WebGPU)
Documentation
// Auto-generated masked_select operations for f32

const WORKGROUP_SIZE: u32 = 256u;

// Phase 1: Count masked elements
struct CountParams {
    numel: u32,
}

@group(0) @binding(0) var<storage, read_write> count_mask: array<u32>;
@group(0) @binding(1) var<storage, read_write> count_result: atomic<u32>;
@group(0) @binding(2) var<uniform> count_params: CountParams;

var<workgroup> shared_count: atomic<u32>;

@compute @workgroup_size(256)
fn masked_count(@builtin(global_invocation_id) gid: vec3<u32>,
                @builtin(local_invocation_id) lid: vec3<u32>) {
    if (lid.x == 0u) {
        atomicStore(&shared_count, 0u);
    }
    workgroupBarrier();

    var local_count: u32 = 0u;
    var i = gid.x;
    while (i < count_params.numel) {
        if (count_mask[i] != 0u) {
            local_count = local_count + 1u;
        }
        i = i + 256u * 256u; // Grid stride
    }

    atomicAdd(&shared_count, local_count);
    workgroupBarrier();

    if (lid.x == 0u) {
        atomicAdd(&count_result, atomicLoad(&shared_count));
    }
}

// Phase 2: Compute prefix sum (sequential - for small arrays)
struct PrefixSumParams {
    numel: u32,
}

@group(0) @binding(0) var<storage, read_write> prefix_mask: array<u32>;
@group(0) @binding(1) var<storage, read_write> prefix_sum: array<u32>;
@group(0) @binding(2) var<uniform> prefix_params: PrefixSumParams;

@compute @workgroup_size(1)
fn masked_prefix_sum(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (gid.x != 0u) {
        return;
    }

    var sum: u32 = 0u;
    for (var i: u32 = 0u; i < prefix_params.numel; i = i + 1u) {
        prefix_sum[i] = sum;
        if (prefix_mask[i] != 0u) {
            sum = sum + 1u;
        }
    }
}

// Phase 3: Gather selected elements
struct SelectParams {
    numel: u32,
}

@group(0) @binding(0) var<storage, read_write> select_input: array<f32>;
@group(0) @binding(1) var<storage, read_write> select_mask: array<u32>;
@group(0) @binding(2) var<storage, read_write> select_prefix: array<u32>;
@group(0) @binding(3) var<storage, read_write> select_output: array<f32>;
@group(0) @binding(4) var<uniform> select_params: SelectParams;

@compute @workgroup_size(256)
fn masked_select_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= select_params.numel) {
        return;
    }

    if (select_mask[idx] != 0u) {
        let out_idx = select_prefix[idx];
        select_output[out_idx] = select_input[idx];
    }
}