1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
//! SSM depthwise causal 1D conv + SiLU GPU dispatch.
//!
//! Used by Qwen3.5 Gated DeltaNet linear-attention layers to apply a
//! causal conv1d with kernel width 4 to the QKV projection's output
//! (ADR-013 Decision 7).
//!
//! # Operation
//!
//! ```text
//! ssm_conv(x, kernel_w, state) -> (y, new_state)
//! x: [channels, n_tokens, n_seqs]
//! kernel_w: [K, channels] (K = 4 for Qwen3.5)
//! state: [K-1, channels, n_seqs] (previous (K-1) conv inputs per seq)
//!
//! extended(c, t_ext, s) = state(t_ext, c, s) if t_ext < K - 1
//! x(c, t_ext - (K-1), s) otherwise
//! y(c, t, s) = silu( sum_{k=0..K} kernel_w(k, c) * extended(c, t + k, s) )
//! new_state(i, c, s) = extended(c, n_tokens + i, s) for i in 0..K-1
//! ```
//!
//! # Memory layout (column-major, innermost-first)
//!
//! * `x[c, t, s]` at offset `s * n_tokens * channels + t * channels + c`
//! * `y[c, t, s]` same shape and layout as `x`
//! * `state[i, c, s]` at offset `s * (K-1) * channels + c * (K-1) + i`
//! * `kernel_w[k, c]` at offset `c * K + k`
//!
//! The per-(c, s) state row of K-1 values is contiguous in memory, matching
//! the expected ring-buffer slice that callers view as `state[:, c, s]`.
//!
//! # Two-pass design
//!
//! The forward and state-update kernels are separate dispatches because:
//! 1. When `n_tokens + i < K - 1` the state-update reads from the old state;
//! this would alias the output if written in place.
//! 2. The state update is a small O(K × channels × n_seqs) pass whose
//! arithmetic is different from the main conv; fusing them would waste
//! threads.
//!
//! Callers must provide separate `old_state` and `new_state` buffers. The
//! `dispatch_ssm_conv` helper below accepts both in a single call and encodes
//! both kernels back-to-back.
use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Metal source text for the SSM conv kernels, embedded at compile time from
/// `../shaders/ssm_conv.metal` (path relative to this source file).
pub static SSM_CONV_SHADER_SOURCE: &str = include_str!("../shaders/ssm_conv.metal");
/// Register the SSM conv shader source with `registry` under every kernel
/// name this module dispatches.
///
/// The forward and state-update kernels, each in an `f32` and a `bf16`
/// flavour, are all registered against the same source string
/// ([`SSM_CONV_SHADER_SOURCE`]).
pub fn register(registry: &mut KernelRegistry) {
    const KERNEL_NAMES: [&str; 4] = [
        "ssm_conv_forward_f32",
        "ssm_conv_forward_bf16",
        "ssm_conv_state_update_f32",
        "ssm_conv_state_update_bf16",
    ];
    for kernel_name in KERNEL_NAMES {
        registry.register_source(kernel_name, SSM_CONV_SHADER_SOURCE);
    }
}
/// Shape parameters for an ssm_conv dispatch.
///
/// These values size the Metal dispatch grids and are checked against the
/// buffer element counts before anything is encoded.
#[derive(Debug, Clone, Copy)]
pub struct SsmConvParams {
    /// Number of channels; must be > 0.
    pub channels: u32,
    /// Number of tokens per sequence in `x` / `y`; must be > 0.
    pub n_tokens: u32,
    /// Number of sequences; must be > 0.
    pub n_seqs: u32,
    /// Convolution kernel width K. Typically 4; ADR-013 forbids K <= 1, and
    /// validation rejects anything below 2 because the carried state holds
    /// K-1 values per (channel, sequence).
    pub k_width: u32,
}
/// Check shape parameters, buffer element counts and dtypes for an ssm_conv
/// dispatch before anything is encoded.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] when:
///
/// * any of `channels`, `n_tokens`, `n_seqs` is zero;
/// * `k_width < 2`;
/// * any expected element count overflows `usize`;
/// * `x` or `y` does not hold `channels * n_tokens * n_seqs` elements;
/// * `kernel_w` does not hold `k_width * channels` elements;
/// * `old_state` or `new_state` does not hold
///   `(k_width - 1) * channels * n_seqs` elements;
/// * `kernel_w`, `old_state`, `new_state` or `y` has a dtype different from
///   that of `x`.
fn validate(
    params: &SsmConvParams,
    x: &MlxBuffer,
    kernel_w: &MlxBuffer,
    old_state: &MlxBuffer,
    new_state: &MlxBuffer,
    y: &MlxBuffer,
) -> Result<()> {
    // Product of the given dimensions with overflow detection. Every expected
    // element count goes through this helper so that an absurd shape yields an
    // error instead of a debug-build panic or a silently wrapped size.
    fn checked_elems(dims: &[u32]) -> Result<usize> {
        dims.iter()
            .try_fold(1usize, |acc, &d| acc.checked_mul(d as usize))
            .ok_or_else(|| MlxError::InvalidArgument("ssm_conv: shape overflow".into()))
    }

    if params.channels == 0 || params.n_tokens == 0 || params.n_seqs == 0 {
        return Err(MlxError::InvalidArgument(
            "ssm_conv: channels, n_tokens, n_seqs must all be > 0".into(),
        ));
    }
    if params.k_width < 2 {
        return Err(MlxError::InvalidArgument(
            "ssm_conv: k_width must be >= 2 (K=1 has empty state)".into(),
        ));
    }

    let x_elems = checked_elems(&[params.channels, params.n_tokens, params.n_seqs])?;
    let w_elems = checked_elems(&[params.k_width, params.channels])?;
    // `k_width >= 2` was established above, so the subtraction cannot underflow.
    let s_elems = checked_elems(&[params.k_width - 1, params.channels, params.n_seqs])?;

    if x.element_count() != x_elems {
        return Err(MlxError::InvalidArgument(format!(
            "ssm_conv: x element count {} != channels({}) * n_tokens({}) * n_seqs({})",
            x.element_count(),
            params.channels,
            params.n_tokens,
            params.n_seqs
        )));
    }
    if y.element_count() != x_elems {
        return Err(MlxError::InvalidArgument(format!(
            "ssm_conv: y element count {} != expected {}",
            y.element_count(),
            x_elems
        )));
    }
    if kernel_w.element_count() != w_elems {
        return Err(MlxError::InvalidArgument(format!(
            "ssm_conv: kernel_w element count {} != K({}) * channels({})",
            kernel_w.element_count(),
            params.k_width,
            params.channels
        )));
    }
    if old_state.element_count() != s_elems || new_state.element_count() != s_elems {
        return Err(MlxError::InvalidArgument(format!(
            "ssm_conv: state element count mismatch; old={} new={} expected {}",
            old_state.element_count(),
            new_state.element_count(),
            s_elems
        )));
    }

    // Every other buffer must share the dtype of `x`; the kernel variant is
    // later selected from `x.dtype()` alone.
    let dt = x.dtype();
    for (name, buf) in [
        ("kernel_w", kernel_w),
        ("old_state", old_state),
        ("new_state", new_state),
        ("y", y),
    ] {
        if buf.dtype() != dt {
            return Err(MlxError::InvalidArgument(format!(
                "ssm_conv: dtype mismatch — x is {}, {} is {}",
                dt,
                name,
                buf.dtype()
            )));
        }
    }
    Ok(())
}
/// Dispatch a fused depthwise causal 1D conv + SiLU plus state update.
///
/// Two kernels are encoded back-to-back on `encoder`: the forward conv
/// produces `y`, and a small state update writes the last K-1 tokens of the
/// extended stream into `new_state`. Callers may point `old_state` and
/// `new_state` at the same backing buffer if-and-only-if
/// `n_tokens >= k_width - 1` (the state-update never reads from `old_state`
/// in that regime, so aliasing is safe). For decode with `n_tokens < K - 1` a
/// separate buffer is mandatory.
///
/// # Arguments
///
/// * `encoder` - command encoder that receives both dispatches.
/// * `registry` - kernel registry used to look up the two pipelines by name.
/// * `device` - Metal device handed to the registry's pipeline lookup.
/// * `x` - conv input; bound at index 0 of both kernels.
/// * `kernel_w` - conv weights; bound at index 1 of the forward kernel.
/// * `old_state` - previous conv state; bound at index 2 of the forward
///   kernel and index 1 of the state-update kernel.
/// * `new_state` - destination for the updated state; bound at index 2 of the
///   state-update kernel.
/// * `y` - conv output; bound at index 3 of the forward kernel.
/// * `params_buf` - buffer of 4 u32 `[channels, n_tokens, n_seqs, k_width]`.
///   This function does not inspect it; presumably it must agree with
///   `params` — the caller is responsible for that.
/// * `params` - shape parameters used for validation and grid sizing.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] if the shape parameters, buffer
/// element counts or buffer dtypes fail validation, or if the dtype of `x` is
/// neither `F32` nor `BF16`. Errors from the registry's pipeline lookup are
/// propagated unchanged.
pub fn dispatch_ssm_conv(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    x: &MlxBuffer,
    kernel_w: &MlxBuffer,
    old_state: &MlxBuffer,
    new_state: &MlxBuffer,
    y: &MlxBuffer,
    params_buf: &MlxBuffer,
    params: SsmConvParams,
) -> Result<()> {
    // Greatest common divisor (Euclid); used below to size the state-update
    // threadgroup.
    fn gcd_u32(mut a: u32, mut b: u32) -> u32 {
        while b != 0 {
            let t = b;
            b = a % b;
            a = t;
        }
        a
    }

    validate(&params, x, kernel_w, old_state, new_state, y)?;

    let (fwd_name, state_name) = match x.dtype() {
        DType::F32 => ("ssm_conv_forward_f32", "ssm_conv_state_update_f32"),
        DType::BF16 => ("ssm_conv_forward_bf16", "ssm_conv_state_update_bf16"),
        other => {
            return Err(MlxError::InvalidArgument(format!(
                "ssm_conv: unsupported dtype {}",
                other
            )))
        }
    };

    // Forward: one thread per (c, t, s).
    let fwd_pipeline = registry.get_pipeline(fwd_name, device)?;
    let fwd_grid = MTLSize::new(
        params.channels as u64,
        params.n_tokens as u64,
        params.n_seqs as u64,
    );
    // Threadgroup: keep total <= 256, prefer packing along the channels axis.
    // NOTE(review): unlike the state-update threadgroup below, this size is
    // not rounded to a multiple of 32 (e.g. channels = 85 gives up to
    // 85 * 3 = 255 threads) — confirm whether the forward kernel needs the
    // same treatment when `channels` is not a multiple of 32.
    let tg_c = std::cmp::min(params.channels, 256).max(1);
    let remain = 256u32 / tg_c;
    let tg_t = std::cmp::min(params.n_tokens, remain).max(1);
    let remain2 = (256u32 / (tg_c * tg_t)).max(1);
    let tg_s = std::cmp::min(params.n_seqs, remain2).max(1);
    let fwd_tg = MTLSize::new(tg_c as u64, tg_t as u64, tg_s as u64);
    encoder.encode(
        fwd_pipeline,
        &[
            (0, x),
            (1, kernel_w),
            (2, old_state),
            (3, y),
            (4, params_buf),
        ],
        fwd_grid,
        fwd_tg,
    );

    // State update: one thread per (i, c, s), i in 0..K-1.
    let state_pipeline = registry.get_pipeline(state_name, device)?;
    let state_grid = MTLSize::new(
        (params.k_width - 1) as u64,
        params.channels as u64,
        params.n_seqs as u64,
    );
    let su_tg_i = (params.k_width - 1).max(1);
    let su_remain = (256u32 / su_tg_i).max(1);
    let su_tg_c_raw = std::cmp::min(params.channels, su_remain).max(1);
    // ADR-029 iter-175 Step 1r: ensure su_tg_i * su_tg_c is a multiple of
    // 32 (Apple's threadExecutionWidth). Pre-fix this site computed
    // (3, 85, 1) → 255 threads per TG when k_width=4 and channels≥85,
    // which is NOT a multiple of 32 and causes UB under
    // HF2Q_PIPELINE_TG_MULT_HINT=1. Coherence_smoke
    // apex-q5km/{the-quick-brown-fox, what-is-22, hello-my-name-is}
    // caught this via the Step 1q safety assertion.
    //
    // Strategy: shrink su_tg_c to the largest multiple of
    // (32 / gcd(su_tg_i, 32)) that still respects the channel/256 limit.
    // gcd(3, 32) = 1 → step = 32, so su_tg_c rounds DOWN to a multiple of 32
    // (85 → 64). gcd(2, 32) = 2 → step = 16; gcd(4, 32) = 4 → step = 8;
    // gcd(8, 32) = 8 → step = 4; etc. When su_tg_i is itself a multiple of 32
    // (rare for SSM conv), step = 1 and rounding down to a multiple of 1
    // leaves the original value untouched.
    let step = 32u32 / gcd_u32(su_tg_i, 32);
    let su_tg_c = if su_tg_c_raw >= step {
        (su_tg_c_raw / step) * step // round DOWN to multiple of step
    } else {
        // su_tg_c_raw < step: can't satisfy multiple-of-32 with this su_tg_i.
        // Leave su_tg_c_raw alone; safety check will fire under
        // HF2Q_PIPELINE_TG_MULT_HINT=1 (correct behavior — this case
        // means the kernel inherently can't satisfy the constraint).
        su_tg_c_raw
    };
    let su_remain2 = (256u32 / (su_tg_i * su_tg_c)).max(1);
    let su_tg_s = std::cmp::min(params.n_seqs, su_remain2).max(1);
    let state_tg = MTLSize::new(su_tg_i as u64, su_tg_c as u64, su_tg_s as u64);
    encoder.encode(
        state_pipeline,
        &[
            (0, x),
            (1, old_state),
            (2, new_state),
            (3, params_buf),
        ],
        state_grid,
        state_tg,
    );
    Ok(())
}