#version 450
// Cumulative (prefix) sum along the last axis (cols) of a row-major
// rows x cols matrix. One invocation per row; the columns of a row are
// accumulated sequentially within that invocation.
//   exclusive != 0 => out[j] = sum_{i<j}  in[i]   (so out[0] = 0)
//   exclusive == 0 => out[j] = sum_{i<=j} in[i]   (inclusive)
// 64 invocations per workgroup in x; the host presumably dispatches
// ceil(rows / 64) groups -- TODO confirm against the dispatch site.
layout(local_size_x = 64) in;
// Storage buffer at binding 0 viewed as a flat float array; input and output
// are located inside it by element offsets (in_off / out_off below).
layout(std430, binding = 0) buffer Arena { float data[]; };
// Member order and types define the push-constant layout the host must match.
layout(push_constant) uniform PC {
uint rows;      // number of rows; invocations with id >= rows do nothing
uint cols;      // elements per row; also the row stride in elements
uint in_off;    // element offset of the input matrix in data[]
uint out_off;   // element offset of the output matrix in data[]
uint exclusive; // nonzero selects the exclusive scan, zero the inclusive one
} pc;
// Prefix-sums one row of the row-major (pc.rows x pc.cols) matrix.
// Invocation gl_GlobalInvocationID.x handles that row; ids >= pc.rows exit.
// Element (r, c) is read from data[pc.in_off + r*pc.cols + c] and the result
// is written to data[pc.out_off + r*pc.cols + c].
// Each input element is read before its output slot is written, so running
// in place (pc.in_off == pc.out_off) is safe in both exclusive and inclusive
// mode. Other partial overlaps of the input and output ranges are not
// supported (rows are handled by different invocations without synchronization).
void main() {
    uint row = gl_GlobalInvocationID.x;
    if (row >= pc.rows) { return; }

    uint in_base = pc.in_off + row * pc.cols;
    uint out_base = pc.out_off + row * pc.cols;
    // pc.exclusive is constant for the whole dispatch: test it once.
    bool excl = pc.exclusive != 0u;

    float acc = 0.0;
    for (uint j = 0u; j < pc.cols; j++) {
        // Read before writing: with in_off == out_off, writing the exclusive
        // result first would overwrite in[j] before it is accumulated.
        float v = data[in_base + j];
        if (excl) {
            data[out_base + j] = acc;
            acc += v;
        } else {
            acc += v;
            data[out_base + j] = acc;
        }
    }
}