#version 450
// Reverse (flip) along the axes whose flag is set. Output is contiguous; the
// input index is the same coordinate with flagged axes mirrored. rank ≤ 6.
// Workgroup size: 256 invocations along x. Each invocation handles one
// output element (see main), so invocations with index >= pc.n do nothing.
layout(local_size_x = 256) in;
// Single float storage buffer at binding 0. Both the source and the
// destination tensor live in this one array, addressed by the element
// offsets pc.in_off and pc.out_off.
layout(std430, binding = 0) buffer Arena { float data[]; };
// Per-dispatch parameters.
layout(push_constant) uniform PC {
// Number of output elements; invocations with index >= n return immediately.
uint n;
// Number of axes in use; only the first `rank` shape/flip entries are read.
uint rank;
// Element offset into data[] where the input tensor starts.
uint in_off;
// Element offset into data[] where the output tensor starts.
uint out_off;
uint s0; uint s1; uint s2; uint s3; uint s4; uint s5; // shape: extent of axis 0..5
uint f0; uint f1; uint f2; uint f3; uint f4; uint f5; // flip flag per axis (nonzero = reverse that axis)
} pc;
// Entry point: each invocation copies exactly one element.
//
// Invocation i treats i as a linear index into the contiguous output, where
// the last axis varies fastest. It peels off one coordinate per axis (last
// axis first), mirrors the coordinate on every axis whose flip flag is
// nonzero, and re-accumulates the linear index of the matching input element
// using the same shape. It then writes
//   data[out_off + i] = data[in_off + in_idx].
//
// Assumes every extent among the first `rank` shape entries is nonzero
// whenever n > 0; a zero extent would divide by zero below.
// NOTE(review): input and output share one buffer and nothing here orders
// reads against writes of other invocations, so the two ranges presumably
// must not overlap; confirm against the host code.
void main() {
    uint i = gl_GlobalInvocationID.x;
    // Skip invocations past the element count. Also refuse ranks above 6:
    // the axis tables below have only six slots, and indexing past them is
    // undefined behaviour.
    if (i >= pc.n || pc.rank > 6u) { return; }

    uint shape[6] = uint[6](pc.s0, pc.s1, pc.s2, pc.s3, pc.s4, pc.s5);
    uint flip[6] = uint[6](pc.f0, pc.f1, pc.f2, pc.f3, pc.f4, pc.f5);

    uint rem = i;       // part of the output index not yet decomposed
    uint in_idx = 0u;   // linear index of the source element, built up per axis
    uint stride = 1u;   // element stride of the axis currently being processed

    // Walk the axes from last (fastest varying) to first. Decomposing and
    // recomposing happen in the same pass, so no per-axis coordinate array
    // is needed.
    for (uint left = pc.rank; left > 0u; left--) {
        uint ax = left - 1u;
        uint extent = shape[ax];
        uint c = rem % extent;
        rem /= extent;
        // Mirror the coordinate on flipped axes.
        if (flip[ax] != 0u) { c = extent - 1u - c; }
        in_idx += c * stride;
        stride *= extent;
    }

    data[pc.out_off + i] = data[pc.in_off + in_idx];
}