
Module ssm_conv


SSM depthwise causal 1D conv + SiLU GPU dispatch.

Used by Qwen3.5 Gated DeltaNet linear-attention layers to apply a causal conv1d with kernel width 4 to the output of the QKV projection (ADR-013 Decision 7).

§Operation

ssm_conv(x, kernel_w, state) -> (y, new_state)
  x:        [channels, n_tokens, n_seqs]
  kernel_w: [K, channels]              (K = 4 for Qwen3.5)
  state:    [K-1, channels, n_seqs]    (previous (K-1) conv inputs per seq)

extended(c, t_ext, s) = state(t_ext, c, s)            if t_ext < K - 1
                        x(c, t_ext - (K-1), s)        otherwise
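y(c, t, s) = silu( sum_{k in 0..K} kernel_w(k, c) * extended(c, t + k, s) )
new_state(i, c, s) = extended(c, n_tokens + i, s)  for i in 0..K-1

Ranges written a..b are half-open (b excluded), as in Rust, so the sum has exactly K terms and new_state holds the last K-1 entries of extended.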
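The formulas above can be checked on the CPU. The following is a minimal reference sketch, assuming `f32` buffers stored in the flat layout given under "Memory layout"; `ssm_conv_ref` and `silu` are hypothetical helpers for illustration, not part of this module, and they say nothing about how the shader itself is organized.

```rust
/// SiLU activation: v * sigmoid(v).
fn silu(v: f32) -> f32 {
    v / (1.0 + (-v).exp())
}

/// Hypothetical CPU reference for ssm_conv (not this module's API).
/// Returns (y, new_state); `state` is only read, never written.
fn ssm_conv_ref(
    x: &[f32],        // [channels, n_tokens, n_seqs]
    kernel_w: &[f32], // [K, channels]
    state: &[f32],    // [K-1, channels, n_seqs]
    k: usize,         // kernel width K
    channels: usize,
    n_tokens: usize,
    n_seqs: usize,
) -> (Vec<f32>, Vec<f32>) {
    assert!(k >= 1);
    let km1 = k - 1;
    assert_eq!(x.len(), channels * n_tokens * n_seqs);
    assert_eq!(kernel_w.len(), k * channels);
    assert_eq!(state.len(), km1 * channels * n_seqs);

    // extended(c, t_ext, s): the first K-1 positions come from the old state,
    // the remaining n_tokens positions come from x.
    let extended = |c: usize, t_ext: usize, s: usize| -> f32 {
        if t_ext < km1 {
            state[s * km1 * channels + c * km1 + t_ext]
        } else {
            x[s * n_tokens * channels + (t_ext - km1) * channels + c]
        }
    };

    let mut y = vec![0.0f32; x.len()];
    let mut new_state = vec![0.0f32; state.len()];
    for s in 0..n_seqs {
        for c in 0..channels {
            for t in 0..n_tokens {
                // K taps, ki in 0..K.
                let mut acc = 0.0f32;
                for ki in 0..k {
                    acc += kernel_w[c * k + ki] * extended(c, t + ki, s);
                }
                y[s * n_tokens * channels + t * channels + c] = silu(acc);
            }
            // The last K-1 extended inputs become the next state.
            for i in 0..km1 {
                new_state[s * km1 * channels + c * km1 + i] = extended(c, n_tokens + i, s);
            }
        }
    }
    (y, new_state)
}
```

A GPU result can then be compared element-wise against `ssm_conv_ref` with a small tolerance.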

§Memory layout (column-major, innermost-first)

  • x[c, t, s] at offset s * n_tokens * channels + t * channels + c
  • y[c, t, s] same shape and layout as x
  • state[i, c, s] at offset s * (K-1) * channels + c * (K-1) + i
  • kernel_w[k, c] at offset c * K + k

The per-(c, s) state row of K-1 values is contiguous in memory, matching the expected ring-buffer slice that callers view as state[:, c, s].
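The offsets above translate directly into index helpers. This sketch assumes flat `f32` slices and uses hypothetical function names; it also shows that `state[:, c, s]` is a plain contiguous subslice, because `i` is the innermost index.

```rust
/// Offset of x[c, t, s] (y uses the same layout). Hypothetical helper.
fn x_offset(c: usize, t: usize, s: usize, channels: usize, n_tokens: usize) -> usize {
    s * n_tokens * channels + t * channels + c
}

/// Offset of state[i, c, s] for kernel width k. Hypothetical helper.
fn state_offset(i: usize, c: usize, s: usize, k: usize, channels: usize) -> usize {
    s * (k - 1) * channels + c * (k - 1) + i
}

/// Offset of kernel_w[k_idx, c] for kernel width k. Hypothetical helper.
fn kernel_offset(k_idx: usize, c: usize, k: usize) -> usize {
    c * k + k_idx
}

/// The K-1 state values for one (c, s) pair, i.e. state[:, c, s].
/// Because i is innermost, they form one contiguous subslice.
fn state_row(state: &[f32], c: usize, s: usize, k: usize, channels: usize) -> &[f32] {
    let start = state_offset(0, c, s, k, channels);
    &state[start..start + (k - 1)]
}
```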

§Two-pass design

The forward and state-update kernels are separate dispatches because:

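  1. When n_tokens + i < K - 1, new_state(i, c, s) is copied from the old state (entry n_tokens + i), so writing new_state in place over the old state would let one thread overwrite a value that another thread still needs to read.
  2. The state update is a small O(K × channels × n_seqs) pass that only copies values, with no multiply-accumulate or SiLU, so its work differs from the main conv; fusing them would waste threads.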
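Point 1 can be made concrete with a toy model of the state-update pass for one channel and one sequence. The sketch below is an illustration under simplifying assumptions, not the shader: each loop iteration stands in for one GPU thread, and the hypothetical `order` parameter stands in for the unspecified order in which threads happen to run. With separate buffers the result is the same for any order; in place, it depends on the order.

```rust
/// Toy state update with separate buffers:
/// out[i] = extended[n_tokens + i], where extended = [old..., x...].
/// `order` simulates an arbitrary thread execution order.
fn state_update(old: &[f32], x: &[f32], out: &mut [f32], order: &[usize]) {
    let km1 = old.len();
    assert_eq!(out.len(), km1);
    for &i in order {
        let src = x.len() + i; // index into extended
        out[i] = if src < km1 { old[src] } else { x[src - km1] };
    }
}

/// The same update written in place: reads and writes share one buffer,
/// so a slot may be overwritten before another "thread" reads it.
fn state_update_in_place(state: &mut [f32], x: &[f32], order: &[usize]) {
    let km1 = state.len();
    for &i in order {
        let src = x.len() + i;
        state[i] = if src < km1 { state[src] } else { x[src - km1] };
    }
}
```

With K = 4 and n_tokens = 1, the correct new state for old state [1, 2, 3] and input [9] is [2, 3, 9]. The separate-buffer version produces it in any order; the in-place version produces it only when slots are written in ascending order, and [9, 9, 9] when they are written in descending order.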

Callers must provide separate old_state and new_state buffers. The dispatch_ssm_conv helper below accepts both in a single call and encodes both kernels back-to-back.
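One way a caller might satisfy the separate-buffer requirement across decode steps is to keep two state buffers and swap their roles after every call. The sketch below is a hypothetical CPU model of that pattern for a single channel and sequence; `ssm_conv_1ch` and `PingPongState` are illustrative names, not this module's API.

```rust
/// Single-channel, single-sequence ssm_conv (hypothetical helper).
/// Reads `old_state`, writes `y` and `new_state`; the two state buffers differ.
fn ssm_conv_1ch(x: &[f32], w: &[f32], old_state: &[f32], y: &mut [f32], new_state: &mut [f32]) {
    let k = w.len();
    let km1 = k - 1;
    assert_eq!(old_state.len(), km1);
    assert_eq!(new_state.len(), km1);
    assert_eq!(y.len(), x.len());
    // extended = [old_state..., x...]
    let ext = |t: usize| -> f32 { if t < km1 { old_state[t] } else { x[t - km1] } };
    for (t, yt) in y.iter_mut().enumerate() {
        let acc: f32 = (0..k).map(|j| w[j] * ext(t + j)).sum();
        *yt = acc / (1.0 + (-acc).exp()); // SiLU
    }
    for (i, ns) in new_state.iter_mut().enumerate() {
        *ns = ext(x.len() + i);
    }
}

/// Hypothetical caller-side double buffer: `cur` holds the current state,
/// the other buffer receives the next state, then the roles swap.
struct PingPongState {
    bufs: [Vec<f32>; 2],
    cur: usize,
}

impl PingPongState {
    fn new(km1: usize) -> Self {
        Self { bufs: [vec![0.0f32; km1], vec![0.0f32; km1]], cur: 0 }
    }

    fn step(&mut self, x: &[f32], w: &[f32]) -> Vec<f32> {
        let mut y = vec![0.0f32; x.len()];
        let (a, b) = self.bufs.split_at_mut(1);
        let (old, new) = if self.cur == 0 { (&a[0], &mut b[0]) } else { (&b[0], &mut a[0]) };
        ssm_conv_1ch(x, w, old, &mut y, new);
        self.cur ^= 1; // the freshly written buffer is now current
        y
    }
}
```

Feeding a sequence one token at a time through `PingPongState::step` should reproduce the outputs of a single call over the whole sequence, which is a convenient check that the state is carried correctly.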

Structs§

SsmConvParams
Shape parameters for an ssm_conv dispatch.
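The buffer sizes a dispatch needs follow from the shapes given under "Operation". The sketch below assumes a hypothetical `ConvShape` struct; the actual fields of `SsmConvParams` are not shown on this page.

```rust
/// Hypothetical stand-in for the shape of one ssm_conv dispatch.
/// Field names are assumptions, not SsmConvParams' real fields.
#[derive(Clone, Copy, Debug)]
struct ConvShape {
    k: usize, // kernel width K (4 for Qwen3.5)
    channels: usize,
    n_tokens: usize,
    n_seqs: usize,
}

impl ConvShape {
    /// Elements in x, and in y, which has the same shape.
    fn x_len(&self) -> usize {
        self.channels * self.n_tokens * self.n_seqs
    }

    /// Elements in kernel_w: [K, channels].
    fn kernel_len(&self) -> usize {
        self.k * self.channels
    }

    /// Elements in each of old_state and new_state: [K-1, channels, n_seqs].
    fn state_len(&self) -> usize {
        (self.k - 1) * self.channels * self.n_seqs
    }
}
```

For example, with K = 4, 8 channels, 5 tokens and 2 sequences, x and y hold 80 elements each, kernel_w holds 32, and each state buffer holds 48.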

Statics§

SSM_CONV_SHADER_SOURCE

Functions§

dispatch_ssm_conv
Dispatch a fused depthwise causal 1D conv + SiLU, followed by the separate state-update kernel.
register
Register SSM conv shader sources with the given kernel registry.