Skip to main content

mlx_native/ops/
ssm_conv.rs

1//! SSM depthwise causal 1D conv + SiLU GPU dispatch.
2//!
3//! Used by Qwen3.5 Gated DeltaNet linear-attention layers to apply a
4//! 4-kernel-wide causal conv1d across the QKV projection's output
5//! (ADR-013 Decision 7).
6//!
7//! # Operation
8//!
9//! ```text
10//! ssm_conv(x, kernel_w, state) -> (y, new_state)
11//!   x:        [channels, n_tokens, n_seqs]
12//!   kernel_w: [K, channels]              (K = 4 for Qwen3.5)
13//!   state:    [K-1, channels, n_seqs]    (previous (K-1) conv inputs per seq)
14//!
15//! extended(c, t_ext, s) = state(t_ext, c, s)            if t_ext < K - 1
16//!                         x(c, t_ext - (K-1), s)        otherwise
17//! y(c, t, s) = silu( sum_{k=0..K} kernel_w(k, c) * extended(c, t + k, s) )
18//! new_state(i, c, s) = extended(c, n_tokens + i, s)  for i in 0..K-1
19//! ```
20//!
21//! # Memory layout (column-major, innermost-first)
22//!
23//! * `x[c, t, s]`        at offset `s * n_tokens * channels + t * channels + c`
24//! * `y[c, t, s]`        same shape and layout as `x`
25//! * `state[i, c, s]`    at offset `s * (K-1) * channels + c * (K-1) + i`
26//! * `kernel_w[k, c]`    at offset `c * K + k`
27//!
28//! The per-(c, s) state row of K-1 values is contiguous in memory, matching
29//! the expected ring-buffer slice that callers view as `state[:, c, s]`.
30//!
31//! # Two-pass design
32//!
33//! The forward and state-update kernels are separate dispatches because:
34//! 1. When `n_tokens + i < K - 1` the state-update reads from the old state;
35//!    this would alias the output if written in place.
36//! 2. The state update is a small O(K × channels × n_seqs) pass whose
37//!    arithmetic is different from the main conv; fusing them would waste
38//!    threads.
39//!
//! Callers should normally provide separate `old_state` and `new_state`
//! buffers; the `dispatch_ssm_conv` documentation describes the one case in
//! which the two may alias (`n_tokens >= K - 1`). The helper accepts both in
//! a single call and encodes both kernels back-to-back.
43use metal::MTLSize;
44
45use crate::buffer::MlxBuffer;
46use crate::dtypes::DType;
47use crate::encoder::CommandEncoder;
48use crate::error::{MlxError, Result};
49use crate::kernel_registry::KernelRegistry;
50
/// Metal source text for the SSM conv kernels, embedded at compile time from
/// `../shaders/ssm_conv.metal` (path relative to this file).
pub static SSM_CONV_SHADER_SOURCE: &str = include_str!("../shaders/ssm_conv.metal");
52
53/// Register SSM conv shader sources with the given kernel registry.
54pub fn register(registry: &mut KernelRegistry) {
55    registry.register_source("ssm_conv_forward_f32", SSM_CONV_SHADER_SOURCE);
56    registry.register_source("ssm_conv_forward_bf16", SSM_CONV_SHADER_SOURCE);
57    registry.register_source("ssm_conv_state_update_f32", SSM_CONV_SHADER_SOURCE);
58    registry.register_source("ssm_conv_state_update_bf16", SSM_CONV_SHADER_SOURCE);
59}
60
/// Shape parameters for an ssm_conv dispatch.
///
/// These are the host-side values used for validation and for sizing the
/// dispatch grids; `dispatch_ssm_conv` rejects any zero `channels`,
/// `n_tokens` or `n_seqs`, and any `k_width` below 2.
#[derive(Debug, Clone, Copy)]
pub struct SsmConvParams {
    /// Number of channels (innermost axis of `x` and `y`).
    pub channels: u32,
    /// Number of tokens per sequence in `x`.
    pub n_tokens: u32,
    /// Number of sequences (outermost axis of `x`, `y` and the state).
    pub n_seqs: u32,
    /// Convolution kernel width K; the state holds K-1 entries per
    /// (channel, sequence).
    pub k_width: u32, // typically 4; ADR-013 forbids K <= 1
}
69
70fn validate(
71    params: &SsmConvParams,
72    x: &MlxBuffer,
73    kernel_w: &MlxBuffer,
74    old_state: &MlxBuffer,
75    new_state: &MlxBuffer,
76    y: &MlxBuffer,
77) -> Result<()> {
78    if params.channels == 0 || params.n_tokens == 0 || params.n_seqs == 0 {
79        return Err(MlxError::InvalidArgument(
80            "ssm_conv: channels, n_tokens, n_seqs must all be > 0".into(),
81        ));
82    }
83    if params.k_width < 2 {
84        return Err(MlxError::InvalidArgument(
85            "ssm_conv: k_width must be >= 2 (K=1 has empty state)".into(),
86        ));
87    }
88    let x_elems = (params.channels as usize)
89        .checked_mul(params.n_tokens as usize)
90        .and_then(|v| v.checked_mul(params.n_seqs as usize))
91        .ok_or_else(|| MlxError::InvalidArgument("ssm_conv: shape overflow".into()))?;
92    let w_elems = (params.k_width as usize) * (params.channels as usize);
93    let s_elems = ((params.k_width - 1) as usize)
94        * (params.channels as usize)
95        * (params.n_seqs as usize);
96
97    if x.element_count() != x_elems {
98        return Err(MlxError::InvalidArgument(format!(
99            "ssm_conv: x element count {} != channels({}) * n_tokens({}) * n_seqs({})",
100            x.element_count(),
101            params.channels,
102            params.n_tokens,
103            params.n_seqs
104        )));
105    }
106    if y.element_count() != x_elems {
107        return Err(MlxError::InvalidArgument(format!(
108            "ssm_conv: y element count {} != expected {}",
109            y.element_count(),
110            x_elems
111        )));
112    }
113    if kernel_w.element_count() != w_elems {
114        return Err(MlxError::InvalidArgument(format!(
115            "ssm_conv: kernel_w element count {} != K({}) * channels({})",
116            kernel_w.element_count(),
117            params.k_width,
118            params.channels
119        )));
120    }
121    if old_state.element_count() != s_elems || new_state.element_count() != s_elems {
122        return Err(MlxError::InvalidArgument(format!(
123            "ssm_conv: state element count mismatch; old={} new={} expected {}",
124            old_state.element_count(),
125            new_state.element_count(),
126            s_elems
127        )));
128    }
129
130    let dt = x.dtype();
131    for (name, buf) in [
132        ("kernel_w", kernel_w),
133        ("old_state", old_state),
134        ("new_state", new_state),
135        ("y", y),
136    ] {
137        if buf.dtype() != dt {
138            return Err(MlxError::InvalidArgument(format!(
139                "ssm_conv: dtype mismatch — x is {}, {} is {}",
140                dt,
141                name,
142                buf.dtype()
143            )));
144        }
145    }
146    Ok(())
147}
148
149/// Dispatch a fused depthwise causal 1D conv + SiLU plus state update.
150///
151/// Two kernels are encoded back-to-back: the forward conv produces `y`, and a
152/// small state update writes the last K-1 tokens of the extended stream into
153/// `new_state`. Callers may point `old_state` and `new_state` at the same
154/// backing buffer if-and-only-if `n_tokens >= k_width - 1` (the state-update
155/// never reads from `old_state` in that regime, so aliasing is safe). For
156/// decode with `n_tokens < K - 1` a separate buffer is mandatory.
157///
158/// # Arguments
159///
160/// * `params_buf` - buffer of 4 u32 `[channels, n_tokens, n_seqs, k_width]`.
161///
162/// # Errors
163///
164/// See [`validate`] for the full list.
165pub fn dispatch_ssm_conv(
166    encoder: &mut CommandEncoder,
167    registry: &mut KernelRegistry,
168    device: &metal::DeviceRef,
169    x: &MlxBuffer,
170    kernel_w: &MlxBuffer,
171    old_state: &MlxBuffer,
172    new_state: &MlxBuffer,
173    y: &MlxBuffer,
174    params_buf: &MlxBuffer,
175    params: SsmConvParams,
176) -> Result<()> {
177    validate(&params, x, kernel_w, old_state, new_state, y)?;
178
179    let (fwd_name, state_name) = match x.dtype() {
180        DType::F32 => ("ssm_conv_forward_f32", "ssm_conv_state_update_f32"),
181        DType::BF16 => ("ssm_conv_forward_bf16", "ssm_conv_state_update_bf16"),
182        other => {
183            return Err(MlxError::InvalidArgument(format!(
184                "ssm_conv: unsupported dtype {}",
185                other
186            )))
187        }
188    };
189
190    // Forward: one thread per (c, t, s).
191    let fwd_pipeline = registry.get_pipeline(fwd_name, device)?;
192    let fwd_grid = MTLSize::new(
193        params.channels as u64,
194        params.n_tokens as u64,
195        params.n_seqs as u64,
196    );
197    // Threadgroup: keep total <= 256, prefer packing along the channels axis.
198    let tg_c = std::cmp::min(params.channels, 256).max(1);
199    let remain = 256u32 / tg_c;
200    let tg_t = std::cmp::min(params.n_tokens, remain).max(1);
201    let remain2 = (256u32 / (tg_c * tg_t)).max(1);
202    let tg_s = std::cmp::min(params.n_seqs, remain2).max(1);
203    let fwd_tg = MTLSize::new(tg_c as u64, tg_t as u64, tg_s as u64);
204
205    encoder.encode(
206        fwd_pipeline,
207        &[
208            (0, x),
209            (1, kernel_w),
210            (2, old_state),
211            (3, y),
212            (4, params_buf),
213        ],
214        fwd_grid,
215        fwd_tg,
216    );
217
218    // State update: one thread per (i, c, s), i in 0..K-1.
219    let state_pipeline = registry.get_pipeline(state_name, device)?;
220    let state_grid = MTLSize::new(
221        (params.k_width - 1) as u64,
222        params.channels as u64,
223        params.n_seqs as u64,
224    );
225    let su_tg_i = (params.k_width - 1).max(1);
226    let su_remain = (256u32 / su_tg_i).max(1);
227    let su_tg_c_raw = std::cmp::min(params.channels, su_remain).max(1);
228    // ADR-029 iter-175 Step 1r: ensure su_tg_i * su_tg_c is a multiple of
229    // 32 (Apple's threadExecutionWidth).  Pre-fix this site computed
230    // (3, 85, 1) → 255 threads per TG when k_width=4 and channels≥85,
231    // which is NOT a multiple of 32 and causes UB under
232    // HF2Q_PIPELINE_TG_MULT_HINT=1.  Coherence_smoke
233    // apex-q5km/{the-quick-brown-fox, what-is-22, hello-my-name-is}
234    // caught this via the Step 1q safety assertion.
235    //
236    // Strategy: shrink su_tg_c to the largest multiple of (32 / gcd(su_tg_i, 32))
237    // that still respects the channel/256 limit.  gcd(3, 32) = 1 → step = 32,
238    // so su_tg_c rounds DOWN to a multiple of 32 (85 → 64).  gcd(2, 32) = 2 →
239    // step = 16; gcd(4, 32) = 4 → step = 8; gcd(8, 32) = 8 → step = 4; etc.
240    // When su_tg_i is itself a multiple of 32 (rare for SSM conv), step = 1
241    // and any su_tg_c works — leaves the original value.
242    fn gcd_u32(mut a: u32, mut b: u32) -> u32 {
243        while b != 0 { let t = b; b = a % b; a = t; }
244        a
245    }
246    let step = 32u32 / gcd_u32(su_tg_i, 32);
247    let su_tg_c = if step <= 1 {
248        su_tg_c_raw
249    } else if su_tg_c_raw >= step {
250        (su_tg_c_raw / step) * step  // round DOWN to multiple of step
251    } else {
252        // su_tg_c_raw < step: can't satisfy multiple-of-32 with this su_tg_i.
253        // Leave su_tg_c_raw alone; safety check will fire under
254        // HF2Q_PIPELINE_TG_MULT_HINT=1 (correct behavior — this case
255        // means the kernel inherently can't satisfy the constraint).
256        su_tg_c_raw
257    };
258    let su_remain2 = (256u32 / (su_tg_i * su_tg_c)).max(1);
259    let su_tg_s = std::cmp::min(params.n_seqs, su_remain2).max(1);
260    let state_tg = MTLSize::new(su_tg_i as u64, su_tg_c as u64, su_tg_s as u64);
261
262    encoder.encode(
263        state_pipeline,
264        &[
265            (0, x),
266            (1, old_state),
267            (2, new_state),
268            (3, params_buf),
269        ],
270        state_grid,
271        state_tg,
272    );
273
274    Ok(())
275}