#version 450
// Element-wise binary op over the shared f32 arena with per-operand
// trailing-broadcast modulus. op: 0=add 1=sub 2=mul 3=div 4=max 5=min 6=pow.
// Workgroup shape: 256 invocations along x; y and z are left at their default of 1.
layout(local_size_x = 256) in;
// Single storage buffer that holds both operands and the result. Every offset in
// the push-constant block below is an index into data[], counted in floats.
layout(std430, binding = 0) buffer Arena { float data[]; };
// Per-dispatch parameters. Offsets and counts are in f32 elements, not bytes.
layout(push_constant) uniform PC {
uint n; // output element count; invocations whose x index is >= n do nothing
uint a_off; // f32-element offset of lhs
uint b_off; // f32-element offset of rhs
uint c_off; // f32-element offset of output
uint a_mod; // lhs element count for broadcast (0 = no broadcast); lhs is read at i % a_mod
uint b_mod; // rhs element count for broadcast (0 = no broadcast); rhs is read at i % b_mod
uint op; // selector: 0=add 1=sub 2=mul 3=div 4=max 5=min 6=pow; any other value stores 0.0
} pc;
// pow(x, y) with defined results for every base.
//
// The GLSL built-in pow() is undefined when x < 0, and when x == 0 with y <= 0.
// An element-wise tensor op must still give sensible answers there (e.g.
// (-2)^2 == 4, x^0 == 1), so those cases are handled explicitly:
//   x > 0 (or x is NaN)        -> built-in pow(x, y), unchanged
//   x <= 0, y == 0             -> 1
//   x == 0, y > 0              -> built-in pow(x, y) (defined by the spec)
//   x == 0, y < 0              -> +infinity
//   x < 0, y integer-valued    -> pow(-x, y), negated when y is odd
//   anything else              -> NaN
// The sign of a zero base is not tracked, so (-0)^(-odd) yields +infinity.
float pow_defined(float x, float y) {
    // Negated comparison so that a NaN base also takes the built-in path.
    if (!(x <= 0.0)) {
        return pow(x, y);
    }
    if (y == 0.0) {
        return 1.0;
    }
    if (x == 0.0) {
        if (y > 0.0) {
            return pow(x, y);
        }
        // Bit patterns are used instead of 1.0/0.0 and 0.0/0.0 so that no
        // constant division by zero appears in the shader source.
        return (y < 0.0) ? uintBitsToFloat(0x7F800000u)   // +infinity
                         : uintBitsToFloat(0x7FC00000u);  // y is NaN -> quiet NaN
    }
    // From here on x < 0.
    if (y == floor(y)) {
        float magnitude = pow(-x, y);
        // Integer-valued floats beyond 2^24 are all even, and mod() reports
        // 0 for them, so the parity test stays correct for large exponents.
        bool odd = (mod(abs(y), 2.0) == 1.0);
        return odd ? -magnitude : magnitude;
    }
    // Negative base with a fractional (or NaN) exponent has no real result.
    return uintBitsToFloat(0x7FC00000u);
}

// Entry point: one invocation per output element.
//
// Only gl_GlobalInvocationID.x is used as the element index, so the dispatch
// must cover at least pc.n invocations along x. Invocations past pc.n exit
// without touching the arena.
//
// Each operand index is the output index, reduced modulo the operand's
// element count when that count is non-zero (trailing broadcast).
void main() {
    uint i = gl_GlobalInvocationID.x;
    if (i >= pc.n) {
        return;
    }

    // A modulus of 0 means "no broadcast"; it is also what keeps the
    // `%` below from ever dividing by zero.
    uint ai = (pc.a_mod == 0u) ? i : (i % pc.a_mod);
    uint bi = (pc.b_mod == 0u) ? i : (i % pc.b_mod);

    float a = data[pc.a_off + ai];
    float b = data[pc.b_off + bi];

    float c = 0.0;
    switch (pc.op) {
        case 0u: c = a + b; break;
        case 1u: c = a - b; break;
        case 2u: c = a * b; break;
        case 3u: c = a / b; break;
        case 4u: c = max(a, b); break;
        case 5u: c = min(a, b); break;
        // Built-in pow() is undefined for non-positive bases; see pow_defined.
        case 6u: c = pow_defined(a, b); break;
        // Unknown selector: store 0.0 rather than leaving the output unwritten.
        default: c = 0.0; break;
    }

    data[pc.c_off + i] = c;
}