Cumulative sum (inclusive prefix sum) along the last axis.
Computes out[r, i] = sum(x[r, 0..=i]) for every row r.
Used by the chunked Gated DeltaNet path to produce the decay-mask base (ADR-013 Decision 4). Spec derived from the definition of an inclusive prefix scan.
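The definition above can be written out as a plain CPU reference, which is handy as a test oracle for the GPU kernel. This is only a sketch: the function name `cumsum_reference` is hypothetical, and it assumes the `[rows, dim]` input is a contiguous row-major `f32` slice, which the original text does not state.

```rust
/// Hypothetical CPU reference (not part of the crate's API).
/// Assumes `x` is a contiguous row-major buffer of shape [rows, dim].
fn cumsum_reference(x: &[f32], rows: usize, dim: usize) -> Vec<f32> {
    assert_eq!(x.len(), rows * dim, "buffer length must equal rows * dim");
    let mut out = vec![0.0f32; x.len()];
    for r in 0..rows {
        // The running sum restarts at every row: out[r, i] = sum(x[r, 0..=i]).
        let mut acc = 0.0f32;
        for i in 0..dim {
            acc += x[r * dim + i];
            out[r * dim + i] = acc;
        }
    }
    out
}
```

For a 2 x 3 input `[1, 2, 3, 4, 5, 6]`, each row is scanned independently, giving `[1, 3, 6]` and `[4, 9, 15]`.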
§Algorithm
Hillis-Steele scan in threadgroup shared memory. Each thread owns a contiguous chunk of the row; it scans locally first, exchanges chunk totals via a shared-memory scan, then adds the exclusive prefix of preceding threads to its chunk.
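The three phases described above can be modelled on the CPU for a single row. The sketch below is an illustration, not the shader source: `chunked_cumsum_row` and its `threads` parameter (standing in for `tg_size`) are hypothetical names, and the sequential loops only imitate what the threads would do in parallel between barriers.

```rust
/// Hypothetical CPU model of the kernel's three phases for one row.
/// `threads` plays the role of tg_size; each "thread" owns one contiguous chunk.
fn chunked_cumsum_row(x: &[f32], threads: usize) -> Vec<f32> {
    assert!(threads > 0, "need at least one thread");
    let dim = x.len();
    let chunk = (dim + threads - 1) / threads; // ceil_div(dim, threads)
    let mut out = x.to_vec();

    // Phase 1: every thread scans its own chunk and records the chunk total.
    let mut totals = vec![0.0f32; threads];
    for t in 0..threads {
        let start = (t * chunk).min(dim);
        let end = (start + chunk).min(dim);
        let mut acc = 0.0f32;
        for i in start..end {
            acc += x[i];
            out[i] = acc;
        }
        totals[t] = acc;
    }

    // Phase 2: Hillis-Steele inclusive scan over the chunk totals.
    // The clone models reading values from before the barrier of each step.
    let mut stride = 1;
    while stride < threads {
        let prev = totals.clone();
        for t in stride..threads {
            totals[t] = prev[t] + prev[t - stride];
        }
        stride *= 2;
    }

    // Phase 3: thread t adds the exclusive prefix, i.e. the inclusive
    // total of threads 0..t, to every element of its chunk.
    for t in 1..threads {
        let offset = totals[t - 1];
        let start = (t * chunk).min(dim);
        let end = (start + chunk).min(dim);
        for v in &mut out[start..end] {
            *v += offset;
        }
    }
    out
}
```

With `x = [1, 2, 3, 4, 5, 6, 7]` and `threads = 4`, the chunk totals are `[3, 7, 11, 7]`, their inclusive scan is `[3, 10, 21, 28]`, and the final row is `[1, 3, 6, 10, 15, 21, 28]`.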
§Threadgroup shape
One threadgroup per row. tg_size = min(256, next_power_of_two(dim)), and each thread handles ceil_div(dim, tg_size) elements. The shader caps the per-thread chunk at CUMSUM_MAX_CHUNK = 32, so for tg_size = 256 the maximum dim handled in a single dispatch is 256 * 32 = 8192.
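The shape arithmetic can be checked with a small host-side helper. This is a sketch under assumptions: `launch_shape` and `MAX_TG_SIZE` are hypothetical names, and returning `None` for an unsupported `dim` is an illustrative choice; how the real dispatcher reports an oversized row is not described here.

```rust
const CUMSUM_MAX_CHUNK: usize = 32; // per-thread cap named in the text above
const MAX_TG_SIZE: usize = 256; // hypothetical name for the 256-thread limit

/// Hypothetical helper: returns (tg_size, per-thread chunk) for one row,
/// or None when dim is zero or exceeds what a single dispatch can cover.
fn launch_shape(dim: usize) -> Option<(usize, usize)> {
    if dim == 0 {
        return None;
    }
    // tg_size = min(256, next_power_of_two(dim))
    let tg_size = dim.next_power_of_two().min(MAX_TG_SIZE);
    // Each thread handles ceil_div(dim, tg_size) elements.
    let chunk = (dim + tg_size - 1) / tg_size;
    if chunk > CUMSUM_MAX_CHUNK {
        None
    } else {
        Some((tg_size, chunk))
    }
}
```

For example, `dim = 100` yields 128 threads with one element each, `dim = 8192` yields 256 threads with 32 elements each, and `dim = 8193` would need a chunk of 33, which exceeds the cap.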
Reduction and prefix arithmetic are performed in f32 regardless of input dtype for numerical stability.
Statics§
Functions§
- dispatch_cumsum - Dispatch an inclusive prefix sum along the last axis of [rows, dim].
- register - Register cumsum shader sources with the given kernel registry.