#version 450
// Rotary position embedding (RoPE), honoring the pairing flavor (RopeStyle):
// style 0 = NeoX (HF rotate-half): dim i pairs with dim i + n_rot/2.
// style 1 = GptJ (llama.cpp NORM / interleaved): dims (2i, 2i+1) form a pair.
// Both styles read cos/sin at element row*tab_half + i, where row is the global
// token index (per_token != 0) or the sequence position, and i in [0, n_rot/2)
// is the frequency index. Dims [n_rot, head_dim) are copied through unchanged
// (partial rotary). One invocation per (token, head). Input rows are
// src_row_stride floats apart; output rows are dense (`hidden` floats apart).
// Workgroup size along x. main() maps one invocation to one (token, head) pair,
// so the host must dispatch at least ceil(batch*seq*nh / 64) workgroups in x.
layout(local_size_x = 64) in;
// Single float arena shared by input, output and the cos/sin tables. Every
// *_off push constant below is an element (float) offset into data[], not bytes.
layout(std430, binding = 0) buffer Arena { float data[]; };
// Push constants. Member order must match the host-side struct that fills them.
layout(push_constant) uniform PC {
uint batch;
uint seq;
uint hidden; // output row stride in floats (output rows are dense)
uint head_dim;
uint n_rot; // rotated dims per head; NOTE(review): assumed even and <= head_dim — confirm host guarantees
uint nh; // heads = hidden / head_dim
uint tab_half; // head_dim / 2; floats per cos/sin table row (must be >= n_rot/2)
uint src_row_stride; // input row stride in floats (presumably allows a strided view of x — confirm)
uint per_token; // 1 ⇒ cos/sin row indexed by global token, else by seq pos
uint style; // 0 = NeoX, 1 = GptJ
uint x_off; // element offset of the input tensor in data[]
uint cos_off; // element offset of the cos table in data[]
uint sin_off; // element offset of the sin table in data[]
uint out_off; // element offset of the output tensor in data[]
} pc;
// Applies RoPE to one (token, head) slice of x and writes it to the dense
// output. Invocation gid maps to the global token idx = gid / nh
// (== bi*seq + si) and the head hi = gid % nh. Surplus invocations in the
// last workgroup exit early.
void main() {
    uint gid = gl_GlobalInvocationID.x;
    uint total = pc.batch * pc.seq * pc.nh;
    if (gid >= total) { return; }

    uint idx = gid / pc.nh; // global token index = bi*seq + si
    uint hi = gid % pc.nh;  // head within the token
    uint si = idx % pc.seq; // position within the sequence

    // cos/sin row: tab_half floats per row, selected by global token or seq pos.
    uint tab_off = (pc.per_token != 0u ? idx : si) * pc.tab_half;
    // bi*seq*stride + si*stride == idx*stride, since idx == bi*seq + si.
    uint src_base = pc.x_off + idx * pc.src_row_stride + hi * pc.head_dim;
    uint dst_base = pc.out_off + idx * pc.hidden + hi * pc.head_dim;
    uint cos_base = pc.cos_off + tab_off;
    uint sin_base = pc.sin_off + tab_off;

    // Clamp the rotated span to this head so a misconfigured n_rot > head_dim
    // cannot write into the neighbouring head's slice (owned by another
    // invocation). Round down to an even count so every rotated dim has a partner.
    uint rot_half = min(pc.n_rot, pc.head_dim) / 2u;
    uint rot_dims = 2u * rot_half;
    // NeoX pairs (i, i + rot_half); GptJ pairs (2i, 2i + 1). pc.style is a push
    // constant, so this choice is the same for every invocation.
    bool neox = (pc.style == 0u);

    for (uint i = 0u; i < rot_half; i++) {
        float cv = data[cos_base + i];
        float sv = data[sin_base + i];
        uint a = neox ? i : 2u * i;
        uint b = neox ? i + rot_half : 2u * i + 1u;
        // Both inputs are read before either output is written, so an in-place
        // call (same offsets and strides for x and out) is still correct.
        float x1 = data[src_base + a];
        float x2 = data[src_base + b];
        data[dst_base + a] = x1 * cv - x2 * sv;
        data[dst_base + b] = x2 * cv + x1 * sv;
    }

    // Pass-through dims (partial rotary). Starting at rot_dims rather than
    // n_rot also copies the trailing dim when n_rot is odd. That dim is not
    // rotated and would otherwise never be written.
    for (uint j = rot_dims; j < pc.head_dim; j++) {
        data[dst_base + j] = data[src_base + j];
    }
}