Module cumsum


Cumulative sum (inclusive prefix sum) along the last axis.

Computes out[r, i] = sum(x[r, 0..=i]) for every row r.
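The formula above can be pinned down with a small CPU reference. This is a minimal sketch, not part of the crate: the name `cumsum_reference` and the row-major `&[f32]` layout are assumptions made for illustration.

```rust
/// Hypothetical CPU reference (not the crate's API): inclusive prefix sum
/// along the last axis of a row-major `[rows, dim]` buffer.
fn cumsum_reference(x: &[f32], rows: usize, dim: usize) -> Vec<f32> {
    assert_eq!(x.len(), rows * dim, "buffer must hold rows * dim elements");
    let mut out = vec![0.0f32; x.len()];
    for r in 0..rows {
        // Running sum restarts at every row.
        let mut acc = 0.0f32;
        for i in 0..dim {
            acc += x[r * dim + i];
            // out[r, i] = sum(x[r, 0..=i])
            out[r * dim + i] = acc;
        }
    }
    out
}
```

A reference like this is mainly useful as an oracle when testing a GPU kernel against known inputs.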

Used by the chunked Gated DeltaNet path to produce the decay-mask base (ADR-013 Decision 4). Spec derived from the definition of an inclusive prefix scan.

§Algorithm

Hillis-Steele scan in threadgroup shared memory, applied in three phases. Each thread owns a contiguous chunk of the row: it first scans its chunk locally, then the threads exchange chunk totals via a shared-memory scan, and finally each thread adds the exclusive prefix (the sum of all preceding threads' chunk totals) to every element of its chunk.
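The three phases can be simulated sequentially on the CPU. The sketch below is illustrative only and is not the shader: `hillis_steele_inclusive` and `chunked_row_scan` are hypothetical names, the `threads` loop stands in for GPU threads, and the double buffer stands in for the barriers a real threadgroup scan would need.

```rust
/// Inclusive Hillis-Steele scan, simulated sequentially.
/// Each pass adds the element `offset` slots to the left; `offset` doubles
/// every pass. A fresh buffer per pass models the read/write barrier.
fn hillis_steele_inclusive(vals: &[f32]) -> Vec<f32> {
    let n = vals.len();
    let mut cur = vals.to_vec();
    let mut offset = 1;
    while offset < n {
        let mut next = cur.clone();
        for t in offset..n {
            next[t] = cur[t] + cur[t - offset];
        }
        cur = next;
        offset *= 2;
    }
    cur
}

/// Hypothetical CPU model of the per-row algorithm with `threads` workers.
fn chunked_row_scan(row: &[f32], threads: usize) -> Vec<f32> {
    assert!(threads > 0);
    let chunk = (row.len() + threads - 1) / threads; // ceil_div
    let mut out = row.to_vec();
    let mut totals = vec![0.0f32; threads];

    // Phase 1: each "thread" scans its own contiguous chunk.
    for t in 0..threads {
        let start = (t * chunk).min(row.len());
        let end = (start + chunk).min(row.len());
        let mut acc = 0.0f32;
        for v in &mut out[start..end] {
            acc += *v;
            *v = acc;
        }
        totals[t] = acc;
    }

    // Phase 2: scan the chunk totals (shared memory on the GPU).
    let scanned = hillis_steele_inclusive(&totals);

    // Phase 3: add the exclusive prefix, i.e. the inclusive scan value of
    // the previous thread. Thread 0 has no predecessor and adds nothing.
    for t in 1..threads {
        let start = (t * chunk).min(row.len());
        let end = (start + chunk).min(row.len());
        let prefix = scanned[t - 1];
        for v in &mut out[start..end] {
            *v += prefix;
        }
    }
    out
}
```

Splitting the work this way keeps the shared-memory scan at one value per thread rather than one per element, which is why the per-thread chunk size matters for the limits described in the next section.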

§Threadgroup shape

One threadgroup per row. tg_size = min(256, next_power_of_two(dim)), and each thread handles ceil_div(dim, tg_size) elements. The shader caps the per-thread chunk at CUMSUM_MAX_CHUNK = 32, so at the largest tg_size of 256 the maximum dim handled in a single dispatch is 256 × 32 = 8192.
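The sizing rule can be written out as a small helper. This is a sketch under stated assumptions: `launch_shape` and `MAX_TG_SIZE` are hypothetical names, and returning `None` for an unsupported `dim` is an illustrative choice, not a description of how `dispatch_cumsum` reports the condition.

```rust
/// Per-thread chunk cap stated in the module docs.
const CUMSUM_MAX_CHUNK: usize = 32;
/// Hypothetical name for the threadgroup size cap of 256.
const MAX_TG_SIZE: usize = 256;

/// Returns `(tg_size, chunk)` for a row of length `dim`, or `None` when the
/// row is empty or would exceed the per-thread chunk cap.
fn launch_shape(dim: usize) -> Option<(usize, usize)> {
    if dim == 0 {
        return None;
    }
    // tg_size = min(256, next_power_of_two(dim))
    let tg_size = dim.next_power_of_two().min(MAX_TG_SIZE);
    // chunk = ceil_div(dim, tg_size)
    let chunk = (dim + tg_size - 1) / tg_size;
    if chunk > CUMSUM_MAX_CHUNK {
        return None;
    }
    Some((tg_size, chunk))
}
```

For example, `dim = 100` gives a threadgroup of 128 threads with one element each, while `dim = 8192` is the largest row that fits, at 256 threads with 32 elements each.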

Reduction and prefix arithmetic are performed in f32 regardless of input dtype for numerical stability.

Statics§

CUMSUM_SHADER_SOURCE

Functions§

dispatch_cumsum
Dispatch an inclusive prefix sum along the last axis of [rows, dim].
register
Register cumsum shader sources with the given kernel registry.