// hanzo-ml 0.10.2
//
// Minimalist ML framework.
// Documentation
#version 450
// GPT-NeoX style rotary embedding (hanzo-ml "rotary-emb").
//
// Buffers (float32, contiguous):
//   src, dst   : [b, h, t, d]
//   cosb, sinb : [t, d/2]; indexed as [b, t, d/2] when the per-batch offset is enabled
//                (see `unbatched` and the NOTE below).
//
// Each work item is one (bh, i_t, i_d) triple with i_d < d/2. It rotates the element pair taken
// from the two halves of the head dimension:
//   i1   = (bh*t + i_t)*d + i_d,   i2 = i1 + d/2
//   i_cs = i_t*(d/2) + i_d   (+ b_i*t*(d/2) when unbatched != 0, with b_i = bh / h)
//   dst[i1] = src[i1]*cos - src[i2]*sin
//   dst[i2] = src[i1]*sin + src[i2]*cos
//
// Dispatch: 1-D along x, local size 64. Invocations walk the b*h*t*(d/2) pairs with a stride of
// gl_NumWorkGroups.x * 64, so any x workgroup count >= 1 covers the whole tensor.
// Dispatching ceil(total/64) workgroups still gives exactly one pair per invocation. A smaller
// count can be used when total/64 exceeds maxComputeWorkGroupCount[0].
//
// Preconditions (not checked here):
//   - d is even; for odd d, the last element of every row is never written.
//   - b*h*t*d fits in 32 bits (all index math is uint).
//
// NOTE(review): when `unbatched != 0` the cos/sin index gets a per-batch offset, i.e. cos/sin are
// read as [b, t, d/2]. That is the opposite of what the flag's name suggests ([t, d/2] shared
// across the batch). The behavior is kept as-is; confirm the flag's meaning against the host code.
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout(set = 0, binding = 0) readonly  buffer S { float src[]; };
layout(set = 0, binding = 1) readonly  buffer C { float cosb[]; };
layout(set = 0, binding = 2) readonly  buffer N { float sinb[]; };
layout(set = 0, binding = 3) writeonly buffer O { float dst[]; };
// b, h, t, d: dimensions of src/dst. unbatched: nonzero => add b_i*t*(d/2) to the cos/sin index.
layout(push_constant) uniform Pc { uint b; uint h; uint t; uint d; uint unbatched; };
void main() {
    uint hd = d / 2u;               // rotation pairs per row of length d
    uint per_bh = t * hd;           // (i_t, i_d) pairs per [b,h] slice
    uint total = b * h * per_bh;    // pairs in the whole tensor
    // Total invocations along x; >= 64 for any dispatch that runs, so the loop always advances.
    uint stride = gl_NumWorkGroups.x * gl_WorkGroupSize.x;

    uint gid = gl_GlobalInvocationID.x;
    // total == 0 (any zero dim or d < 2) exits before any division by per_bh or hd.
    while (gid < total) {
        uint bh_i = gid / per_bh;           // which [b,h] slice
        uint rem  = gid % per_bh;
        uint i_t  = rem / hd;
        uint i_d  = rem % hd;
        uint sbase = bh_i * t * d;          // start of this [b,h] slice in src/dst
        uint i1 = sbase + i_t * d + i_d;    // element in the first half of the row
        uint i2 = i1 + hd;                  // partner element in the second half
        uint i_cs = i_t * hd + i_d;
        if (unbatched != 0u) { i_cs += (bh_i / h) * per_bh; }
        float c = cosb[i_cs];
        float s = sinb[i_cs];
        // Read both inputs before writing, so this also works if src and dst alias.
        float x1 = src[i1];
        float x2 = src[i2];
        dst[i1] = x1 * c - x2 * s;
        dst[i2] = x1 * s + x2 * c;
        // Stop when the next index would be >= total. Testing this before adding also keeps
        // gid + stride from wrapping past 2^32 and re-entering the loop.
        if (total - gid <= stride) { break; }
        gid += stride;
    }
}