#version 450
// GPT-NeoX style rotary embedding (hanzo-ml "rotary-emb").
//
// Layouts (all f32, contiguous):
//   src, dst : [b, h, t, d]
//   cos, sin : [t, d/2]     when unbatched != 0 (one table shared by every batch entry)
//              [b, t, d/2]  when unbatched == 0 (one table per batch entry)
//
// Each work item handles one (bh, i_t, i_d) triple with i_d < d/2 and rotates
// the element pair (src[i1], src[i2]):
//   dst[i1] = src[i1]*cos - src[i2]*sin
//   dst[i2] = src[i1]*sin + src[i2]*cos
// with i1   = bh*t*d + i_t*d + i_d,  i2 = i1 + d/2,
//      i_cs = i_t*(d/2) + i_d  (+ b_i*t*(d/2) for batched tables, b_i = bh / h).
//
// Preconditions (not checked here): d is even; b*h*t*d fits in 32 bits.
// Dispatch: 1-D along x with any number of workgroups. Each invocation walks the
// work items with a stride equal to the number of invocations launched in x, so
// a dispatch of ceil(b*h*t*(d/2) / 64) groups processes one item per invocation,
// and a smaller (device-limit capped) dispatch still covers every item.
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout(set = 0, binding = 0) readonly buffer S { float src[]; };
layout(set = 0, binding = 1) readonly buffer C { float cosb[]; };
layout(set = 0, binding = 2) readonly buffer N { float sinb[]; };
layout(set = 0, binding = 3) writeonly buffer O { float dst[]; };
layout(push_constant) uniform Pc { uint b; uint h; uint t; uint d; uint unbatched; };
void main() {
    uint hd = d / 2u;             // rotated pairs per row
    uint per_bh = t * hd;         // (i_t, i_d) pairs per [b,h] slice
    uint total = b * h * per_bh;  // total work items; 0 when any dim is 0, so no division below runs
    // Number of invocations launched along x. It is used as the loop stride so
    // the shader does not depend on the dispatch exactly covering `total`.
    uint stride = gl_NumWorkGroups.x * gl_WorkGroupSize.x;

    for (uint gid = gl_GlobalInvocationID.x; gid < total; ) {
        uint bh_i = gid / per_bh;  // which [b,h] slice
        uint rem = gid % per_bh;
        uint i_t = rem / hd;
        uint i_d = rem % hd;

        uint i1 = bh_i * t * d + i_t * d + i_d;  // first half of the row
        uint i2 = i1 + hd;                        // matching element in the second half

        uint i_cs = i_t * hd + i_d;
        // Batched tables are [b,t,d/2]: step to this batch entry's table.
        // Unbatched tables are [t,d/2] and must not be offset, or batch entries
        // > 0 would read past the end of cosb/sinb.
        if (unbatched == 0u) { i_cs += (bh_i / h) * per_bh; }

        float c = cosb[i_cs];
        float s = sinb[i_cs];
        float x1 = src[i1];
        float x2 = src[i2];
        dst[i1] = x1 * c - x2 * s;
        dst[i2] = x1 * s + x2 * c;

        // Exit before `gid += stride` could wrap around 2^32 - 1.
        if (total - gid <= stride) { break; }
        gid += stride;
    }
}