#version 450
// Mamba selective-scan (plan #15), matching the CPU reference recurrence:
// state[c,n] = exp(Δ·A[c,n])·state[c,n] + Δ·B[s,n]·x[s,c]
// y[s,c] = Σ_n C[s,n]·state[c,n]
// Inputs: x,Δ [B,S,H], A [H,N], B,C [B,S,N]. One invocation per (batch, channel),
// walking the sequence sequentially while channels/batches run in parallel.
// Workgroup size: 64 invocations along x. main() maps each global x index to
// one (batch, channel) pair.
layout(local_size_x = 64) in;
// Single float storage buffer that holds every tensor this shader touches
// (inputs and output). Each tensor is located through one of the *_off push
// constants below.
layout(std430, binding = 0) buffer Arena { float data[]; };
// Capacity of the per-invocation state array declared in main(). main()
// returns without writing any output when pc.nn exceeds this value.
const uint MAX_N = 256u;
// Dispatch parameters. Every *_off field is an index into data[] (counted in
// float elements, not bytes). Layouts below are the ones main() indexes with.
layout(push_constant) uniform PC {
uint bb; // batch count
uint ss; // sequence length (steps walked sequentially per invocation)
uint hh; // hidden channels
uint nn; // state size per channel; must be <= MAX_N for main() to run
uint x_off; // start of x, indexed as (b * ss + s) * hh + c
uint delta_off; // start of delta, indexed like x
uint a_off; // start of A, indexed as c * nn + n
uint b_off; // start of B, indexed as (b * ss + s) * nn + n
uint c_off; // start of C, indexed like B
uint out_off; // start of the output y, indexed like x
} pc;
// Entry point. Each invocation owns one (batch, channel) pair and scans the
// sequence front to back, carrying that channel's nn-element hidden state in
// a local array:
//   h[k]   <- exp(dt * A[c,k]) * h[k] + dt * B[s,k] * x[s,c]
//   y[s,c] <- sum over k of C[s,k] * h[k]
// Invocations with index >= bb * hh, and every invocation when nn > MAX_N,
// return without writing anything.
void main() {
    uint lane = gl_GlobalInvocationID.x;

    // Guard clauses. The range check comes first so that hh == 0 (zero work
    // items) exits before the division below.
    if (lane >= pc.bb * pc.hh) { return; }
    if (pc.nn > MAX_N) { return; }

    uint batch = lane / pc.hh;
    uint chan = lane % pc.hh;

    // Index bases that do not change across the sequence walk.
    uint step0 = batch * pc.ss;             // first (batch, step) row of this batch
    uint a_row = pc.a_off + chan * pc.nn;   // start of A[chan, :]

    // Hidden state starts at zero; only the first nn entries are ever used.
    float h[MAX_N];
    for (uint k = 0u; k < pc.nn; ++k) {
        h[k] = 0.0;
    }

    // NOTE(review): inputs and output live in the same buffer; this presumably
    // relies on the host keeping the out_off region disjoint from the inputs.
    for (uint t = 0u; t < pc.ss; ++t) {
        uint row = step0 + t;
        uint xc_idx = row * pc.hh + chan;   // element [batch, t, chan] of x / delta / out
        uint bc_row = row * pc.nn;          // start of B[batch, t, :] and C[batch, t, :]

        float dt = data[pc.delta_off + xc_idx];
        float xin = data[pc.x_off + xc_idx];

        float y = 0.0;
        for (uint k = 0u; k < pc.nn; ++k) {
            float decay = exp(dt * data[a_row + k]);
            float updated = decay * h[k] + dt * data[pc.b_off + bc_row + k] * xin;
            h[k] = updated;
            y += data[pc.c_off + bc_row + k] * updated;
        }

        data[pc.out_off + xc_idx] = y;
    }
}