SSM depthwise causal 1D conv + SiLU GPU dispatch.
Used by Qwen3.5 Gated DeltaNet linear-attention layers to apply a causal conv1d with kernel width 4 (K = 4) across the QKV projection’s output (ADR-013 Decision 7).
§Operation
ssm_conv(x, kernel_w, state) -> (y, new_state)
x: [channels, n_tokens, n_seqs]
kernel_w: [K, channels] (K = 4 for Qwen3.5)
state: [K-1, channels, n_seqs] (previous (K-1) conv inputs per seq)
extended(c, t_ext, s) = state(t_ext, c, s)          if t_ext < K - 1
                        x(c, t_ext - (K-1), s)      otherwise
y(c, t, s) = silu( sum_{k=0..K} kernel_w(k, c) * extended(c, t + k, s) )
new_state(i, c, s) = extended(c, n_tokens + i, s) for i in 0..K-1
§Memory layout (column-major, innermost-first)
- x[c, t, s] at offset s * n_tokens * channels + t * channels + c
- y[c, t, s] same shape and layout as x
- state[i, c, s] at offset s * (K-1) * channels + c * (K-1) + i
- kernel_w[k, c] at offset c * K + k
The per-(c, s) state row of K-1 values is contiguous in memory, matching
the expected ring-buffer slice that callers view as state[:, c, s].
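The operation and layout above can be checked against a small CPU reference. The following is a minimal sketch, not the crate's implementation: RefShape, silu, extended and ssm_conv_forward_ref are hypothetical names introduced for illustration, and the ranges 0..K and 0..K-1 are assumed to be half-open (Rust-style), so the sum covers exactly K taps.

```rust
/// Hypothetical shape bundle for this sketch (not the crate's SsmConvParams).
#[derive(Clone, Copy)]
struct RefShape {
    channels: usize,
    n_tokens: usize,
    n_seqs: usize,
    k: usize, // kernel width K (4 for Qwen3.5)
}

impl RefShape {
    /// Flat offset of x[c, t, s] and y[c, t, s]: channel is the innermost index.
    fn x_offset(&self, c: usize, t: usize, s: usize) -> usize {
        s * self.n_tokens * self.channels + t * self.channels + c
    }

    /// Flat offset of state[i, c, s]: the K-1 values of one (c, s) pair are contiguous.
    fn state_offset(&self, i: usize, c: usize, s: usize) -> usize {
        s * (self.k - 1) * self.channels + c * (self.k - 1) + i
    }

    /// Flat offset of kernel_w[k, c]: the K taps of one channel are contiguous.
    fn kernel_offset(&self, tap: usize, c: usize) -> usize {
        c * self.k + tap
    }
}

fn silu(v: f32) -> f32 {
    v / (1.0 + (-v).exp())
}

/// extended(c, t_ext, s): old state for t_ext < K-1, otherwise x shifted by K-1.
fn extended(sh: RefShape, x: &[f32], state: &[f32], c: usize, t_ext: usize, s: usize) -> f32 {
    if t_ext < sh.k - 1 {
        state[sh.state_offset(t_ext, c, s)]
    } else {
        x[sh.x_offset(c, t_ext - (sh.k - 1), s)]
    }
}

/// CPU reference for the forward pass only: y = silu(depthwise causal conv).
fn ssm_conv_forward_ref(sh: RefShape, x: &[f32], kernel_w: &[f32], state: &[f32]) -> Vec<f32> {
    let mut y = vec![0.0f32; sh.channels * sh.n_tokens * sh.n_seqs];
    for s in 0..sh.n_seqs {
        for t in 0..sh.n_tokens {
            for c in 0..sh.channels {
                let mut acc = 0.0f32;
                for tap in 0..sh.k {
                    acc += kernel_w[sh.kernel_offset(tap, c)]
                        * extended(sh, x, state, c, t + tap, s);
                }
                y[sh.x_offset(c, t, s)] = silu(acc);
            }
        }
    }
    y
}
```

For example, with channels = 1, n_seqs = 1, K = 4, state = [1, 2, 3], x = [4, 5] and all taps equal to 1, the first output is silu(1 + 2 + 3 + 4) = silu(10) and the second is silu(2 + 3 + 4 + 5) = silu(14).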
§Two-pass design
The forward and state-update kernels are separate dispatches because:
- When n_tokens + i < K - 1, the state update reads from the old state; this would alias the output if written in place.
- The state update is a small O(K × channels × n_seqs) pass whose arithmetic is different from the main conv; fusing them would waste threads.
Callers must provide separate old_state and new_state buffers. The
dispatch_ssm_conv helper below accepts both in a single call and encodes
both kernels back-to-back.
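The state-update pass can be sketched on the CPU as follows. This is an illustrative reference under stated assumptions, not the crate's kernel: ssm_conv_state_update_ref is a hypothetical name, the offsets follow the memory layout documented in this module, and the range 0..K-1 is assumed to be half-open. Taking old_state as a shared slice and new_state as a mutable slice mirrors the separate-buffer requirement, since Rust's borrow rules prevent passing the same buffer for both.

```rust
/// CPU reference for the state-update pass:
/// new_state(i, c, s) = extended(c, n_tokens + i, s) for i in 0..K-1.
/// Hypothetical helper for illustration; buffers are flat slices in the
/// documented layout.
fn ssm_conv_state_update_ref(
    x: &[f32],
    old_state: &[f32],
    new_state: &mut [f32],
    channels: usize,
    n_tokens: usize,
    n_seqs: usize,
    k: usize,
) {
    let taps = k - 1; // number of stored conv inputs per (c, s) pair
    for s in 0..n_seqs {
        for c in 0..channels {
            // The K-1 values of one (c, s) pair are contiguous.
            let row = s * taps * channels + c * taps;
            for i in 0..taps {
                let t_ext = n_tokens + i;
                new_state[row + i] = if t_ext < taps {
                    // Short sequences (n_tokens < K-1) still need old entries,
                    // which is why the output must not overwrite old_state.
                    old_state[row + t_ext]
                } else {
                    // Otherwise the value comes from x at token t_ext - (K-1).
                    x[s * n_tokens * channels + (t_ext - taps) * channels + c]
                };
            }
        }
    }
}
```

With channels = 1, n_seqs = 1 and K = 4, a single new token x = [4] applied to old_state = [1, 2, 3] gives new_state = [2, 3, 4]: two entries are carried over from the old state and one comes from x. With five tokens x = [4, 5, 6, 7, 8], the new state is simply the last three inputs, [6, 7, 8].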
Structs§
- SsmConvParams: Shape parameters for an ssm_conv dispatch.
Statics§
Functions§
- dispatch_ssm_conv: Dispatch a fused depthwise causal 1D conv + SiLU plus state update.
- register: Register SSM conv shader sources with the given kernel registry.