numr 0.5.2

High-performance numerical computing with multi-backend GPU acceleration (CPU/CUDA/WebGPU)
Documentation
// Strided log-sum-exp shader for f32

// Uniform parameters describing the strided log-sum-exp reduction.
// The shader treats the input as a row-major (outer, reduce, inner) array:
//   input[o * reduce_size * inner_size + r * inner_size + i]
// and writes one result per (outer, inner) pair:
//   output[o * inner_size + i]
// NOTE(review): field order and types must match the uniform buffer written
// by the host-side dispatch code (not visible here) — confirm there before
// editing.
struct LogsumexpStridedParams {
    // Number of elements folded into each output value (loop bound).
    reduce_size: u32,
    // Number of independent outer slices.
    outer_size: u32,
    // Element distance between consecutive reduce-axis entries; also the
    // number of outputs per outer slice.
    inner_size: u32,
}

// Source values; within this shader they are only read, never written.
// NOTE(review): declared read_write although read-only here; switching to
// `read` changes the bind group layout the host must provide — confirm the
// pipeline setup before changing the access mode.
@group(0) @binding(0) var<storage, read_write> input: array<f32>;
// Destination: one f32 per (outer, inner) pair, at outer * inner_size + inner.
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
// Reduction shape (reduce/outer/inner sizes).
@group(0) @binding(2) var<uniform> params: LogsumexpStridedParams;

// Computes log(sum(exp(x))) along the middle axis of a row-major
// (outer, reduce, inner) f32 array, one invocation per output element.
//
// For idx = outer * inner_size + inner, the reduced values are
//   input[outer * reduce_size * inner_size + r * inner_size + inner],
//   r = 0 .. reduce_size - 1,
// and the result is stored in output[idx].
//
// Uses the max-shift identity log(sum(exp(x))) = m + log(sum(exp(x - m)))
// with m = max(x), so exp() cannot overflow for finite inputs.
//
// Edge cases:
// * Max seed is a large finite negative value (not the first element), so a
//   row made entirely of -inf keeps a finite m and step 2 never evaluates
//   (-inf) - (-inf). NOTE(review): the result for such rows is m + log(0),
//   which relies on log(0) yielding -inf on the target; WGSL does not
//   guarantee inf handling — confirm on the backends in use.
// * A row containing +inf yields +inf instead of exp(inf - inf) = NaN.
//
// Invocations with idx >= outer_size * inner_size do nothing (presumably the
// host rounds the dispatch up to a multiple of the workgroup size).
// NOTE(review): only global_id.x is used, so one dispatch covers at most
// maxComputeWorkgroupsPerDimension * 256 outputs.
@compute @workgroup_size(256)
fn logsumexp_strided_f32(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let total_inner = params.outer_size * params.inner_size;
    if (idx >= total_inner) {
        return;
    }

    let outer_idx = idx / params.inner_size;
    let inner_idx = idx % params.inner_size;
    let reduce_size = params.reduce_size;
    let stride = params.inner_size;
    // Offset of element r = 0 of this row; loop-invariant, so computed once.
    let base = outer_idx * reduce_size * stride + inner_idx;

    // Step 1: maximum along the reduce axis.
    var max_val: f32 = -3.402823e+38;
    for (var r: u32 = 0u; r < reduce_size; r = r + 1u) {
        max_val = max(max_val, input[base + r * stride]);
    }

    // A +inf element makes the whole log-sum-exp +inf; shifting by it would
    // compute exp(inf - inf) = NaN. Only +inf compares greater than the
    // largest finite f32 (bit pattern 0x7f7fffff).
    if (max_val > bitcast<f32>(0x7f7fffffu)) {
        output[idx] = max_val;
        return;
    }

    // Step 2: sum of exp(x - max); each term is at most 1.
    var sum_exp: f32 = 0.0;
    for (var r: u32 = 0u; r < reduce_size; r = r + 1u) {
        sum_exp = sum_exp + exp(input[base + r * stride] - max_val);
    }

    // Step 3: write result (outer_idx * inner_size + inner_idx == idx).
    output[idx] = max_val + log(sum_exp);
}