rlx-vulkan 0.2.10

Native Vulkan compute backend for RLX (raw `ash` + embedded SPIR-V compute kernels)
Documentation
#version 450
// Softmax over `axis_len` along a strided axis. The tensor is viewed as
// [outer, axis_len, inner]: element (o, k, j) lives at
// base + o*axis_len*inner + k*inner + j, where base is pc.in_off for reads
// and pc.out_off for writes. Work is split into outer*inner independent
// (o, j) lanes; each lane's axis_len elements are normalized by a single
// invocation, and distinct lanes never touch the same element.

// 256 invocations per workgroup along x. The host presumably dispatches
// ceil(outer*inner / 256) workgroups in x — TODO confirm against the caller.
layout(local_size_x = 256) in;

// Single storage buffer holding both the input and the output tensor;
// tensors are addressed by float-element offsets into `data`.
layout(std430, binding = 0) buffer Arena { float data[]; };

// Push constants. Member order and types form the host-side push-constant
// layout, so do not reorder them.
// All values are counts/offsets in float elements (indices into `data`), not bytes.
layout(push_constant) uniform PC {
    uint outer;     // product of dimensions before the softmax axis
    uint axis_len;  // length of the softmax axis
    uint inner;     // product of dimensions after the axis (= element stride along it)
    uint in_off;    // element offset of the input tensor in `data`
    uint out_off;   // element offset of the output tensor in `data`
} pc;

// Computes softmax for one (o, j) lane.
// `lane` indexes the flattened (o, j) pairs as o = lane / inner, j = lane % inner.
// The lane's elements are `inner` floats apart and start at
// o*axis_len*inner + j relative to pc.in_off (reads) and pc.out_off (writes).
// The row maximum is subtracted before exp(), so for finite inputs exp() only
// sees arguments <= 0 and cannot overflow.
void softmax_lane(uint lane) {
    uint o = lane / pc.inner;
    uint j = lane % pc.inner;
    uint base = o * pc.axis_len * pc.inner + j;
    // Loop-invariant parts of every address; each access is src/dst + k*inner.
    uint src = pc.in_off + base;
    uint dst = pc.out_off + base;

    // Pass 1: row maximum. A finite sentinel (-FLT_MAX) is used instead of
    // -inf so that an all -inf row gives (-inf) - (-FLT_MAX) = -inf and
    // exp() = 0, rather than NaN from (-inf) - (-inf).
    float m = -3.402823466e38;
    for (uint k = 0u; k < pc.axis_len; k++) {
        m = max(m, data[src + k * pc.inner]);
    }

    // Pass 2: normalizer.
    float sum = 0.0;
    for (uint k = 0u; k < pc.axis_len; k++) {
        sum += exp(data[src + k * pc.inner] - m);
    }

    // If sum is not positive (e.g. every term underflowed to 0), inv falls
    // back to 0 and the row is written as zeros instead of dividing by zero.
    float inv = (sum > 0.0) ? (1.0 / sum) : 0.0;

    // Pass 3: write normalized values. Each element is read before it is
    // written at the same k, so this also works when in_off == out_off.
    for (uint k = 0u; k < pc.axis_len; k++) {
        uint rel = k * pc.inner;
        data[dst + rel] = exp(data[src + rel] - m) * inv;
    }
}

// Entry point: a grid-stride loop over all outer*inner lanes.
// Each invocation starts at lane gl_GlobalInvocationID.x and advances by the
// total number of dispatched invocations in x.
// - With a dispatch of ceil(outer*inner / 256) workgroups, every invocation
//   handles exactly one lane, which matches a one-lane-per-invocation kernel.
// - A smaller dispatch still covers every lane. This matters when outer*inner
//   exceeds maxComputeWorkGroupCount[0] * 256.
void main() {
    uint total = pc.outer * pc.inner;
    uint stride = gl_NumWorkGroups.x * gl_WorkGroupSize.x;

    uint lane = gl_GlobalInvocationID.x;
    while (lane < total) {
        softmax_lane(lane);
        // Stop rather than let `lane + stride` wrap past 2^32 (which would
        // restart at a low lane). Also stop if the dispatch width itself
        // wrapped to 0, which would otherwise loop forever.
        if (stride == 0u || stride >= total - lane) { break; }
        lane += stride;
    }
}