mlx_native/kernel_registry.rs
//! [`KernelRegistry`] — lazy compilation and caching of Metal compute pipelines.
//!
//! MSL shader source is embedded at compile time via `include_str!`. On first
//! access, the source is compiled into a Metal library, the named function is
//! extracted, and a `ComputePipelineState` is created and cached. Subsequent
//! calls return the cached pipeline.
//!
//! ## ADR-029 iter-175 Step 1l: precompiled `.metallib` fast path
//!
//! `build.rs` runs `xcrun metal -O3` on every `.metal` file under
//! `src/shaders/` and links the results into a single `default.metallib`
//! placed in `OUT_DIR`. We embed the bytes via `include_bytes!`.
//!
//! Unless the fast path is disabled via `MLX_PRECOMPILED_METALLIB=0` (it is
//! default-ON since Step 1m), `get_pipeline` and `get_pipeline_with_constants`
//! first try to resolve the kernel function against this precompiled library;
//! if the function is found, the pipeline is built from it, skipping Apple's
//! runtime source-compile pass. On any failure (function missing, empty
//! embedded blob, load error) the code transparently falls back to the
//! original source-compile path — byte-identical behavior.
//!
//! Empirical motivation (iter 1k test
//! `tests/iter175_h_e_metallib_perf.rs`): on M5 Max at the gemma4 Q_sliding
//! decode shape, precompiled is **+5.89% faster** than runtime-source
//! (median 18.35 µs vs 19.49 µs) with a tighter p90 (19.54 µs vs 70.11 µs) —
//! Apple's runtime compile chain has occasional jitter outliers.
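//!
//! A minimal sketch of forcing an A/B against the runtime-compile path from a
//! test (illustrative: the flag is latched in a `OnceLock` on first use, so it
//! must be set before the first `get_pipeline*` call):
//!
//! ```ignore
//! std::env::set_var("MLX_PRECOMPILED_METALLIB", "0"); // opt out of the fast path
//! let mut registry = KernelRegistry::new();
//! let pipeline = registry.get_pipeline("gelu_f32", device.metal_device())?;
//! ```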

use std::collections::HashMap;
use std::sync::OnceLock;

use metal::{ComputePipelineDescriptor, ComputePipelineState, FunctionConstantValues, MTLDataType};

use crate::error::{MlxError, Result};

/// Bytes of the precompiled `default.metallib` produced by `build.rs` from
/// every `src/shaders/*.metal` file. Empty when `MLX_NATIVE_SKIP_METALLIB`
/// was set at build time or `xcrun` was unavailable.
const EMBEDDED_METALLIB: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/default.metallib"));

/// Returns `true` when the precompiled `.metallib` fast path is enabled
/// (default-ON since Step 1m validation). Set `MLX_PRECOMPILED_METALLIB=0`
/// (or `false`, `off`) to opt out — useful for diagnosing kernel-compile
/// regressions or A/B benching.
///
/// Step 1m multi-regime validation at HEAD (gemma4-APEX-Q5_K_M, M5 Max):
/// - tg100 alt-pair 2-cycle: decode 95.5 → 95.9 t/s (+0.42%), prefill
///   179 → 185 t/s (+3.3%)
/// - tg2000 alt-pair 1-cycle: decode 93.1 → 92.8 t/s (−0.32%, within σ),
///   prefill 171 → 178.5 t/s (+4.4%)
/// - coherence_smoke: 2/2 PASS both configs
///
/// The Step 1l initial smoke (`tg50 95.4 → 62.1`) did NOT reproduce on
/// rebuild — most likely a cold-PSO-cache or first-run-of-rebuilt-binary
/// artifact rather than a real perf regression.
fn precompiled_enabled() -> bool {
    static FLAG: OnceLock<bool> = OnceLock::new();
    *FLAG.get_or_init(|| {
        match std::env::var("MLX_PRECOMPILED_METALLIB").as_deref() {
            Ok("0") | Ok("false") | Ok("off") => false,
            _ => true,
        }
    })
}

/// Returns `true` when the precompiled `.metallib` is consulted for
/// `get_pipeline_with_constants` (FCV-specialized) kernels. Inherits
/// the master gate [`precompiled_enabled`]; both must be ON for the
/// FCV path to use precompiled. Default-ON since Step 1m validation.
fn precompiled_fcv_enabled() -> bool {
    static FLAG: OnceLock<bool> = OnceLock::new();
    *FLAG.get_or_init(|| {
        match std::env::var("MLX_PRECOMPILED_METALLIB_FCV").as_deref() {
            Ok("0") | Ok("false") | Ok("off") => false,
            _ => true,
        }
    })
}

// MTLDataType numeric values (from metal-rs argument.rs, confirmed in the
// Apple Metal spec):
//   Int  = 29
//   Bool = 53
// These are used when calling set_constant_value_at_index so the Metal
// runtime knows how wide each constant value is.

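// A minimal sketch of how a constant is bound when specializing a PSO (the
// metal-rs call shape is shown for illustration only; index 600 = FC_mul_mv_nsg,
// per the mul_mv_ext registration below):
//
//     let fcv = FunctionConstantValues::new();
//     let nsg: i32 = 4;
//     fcv.set_constant_value_at_index(
//         &nsg as *const i32 as *const std::ffi::c_void,
//         MTLDataType::Int, // = 29
//         600,
//     );
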
/// Registry that lazily compiles and caches Metal compute pipelines from
/// embedded MSL source.
///
/// # Usage
///
/// ```ignore
/// let mut registry = KernelRegistry::new();
/// let pipeline = registry.get_pipeline("elementwise_add_f32", device.metal_device())?;
/// encoder.encode(&pipeline, &buffers, grid, tg);
/// ```
///
/// # Thread Safety
///
/// `KernelRegistry` is **not** `Sync` by default (it uses `&mut self` for
/// `get_pipeline` to allow mutable cache insertion). If you need concurrent
/// access, wrap it in a `Mutex` or use one registry per thread, as sketched
/// below.
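///
/// A minimal sketch of the `Mutex` option (lock scope illustrative):
///
/// ```ignore
/// use std::sync::Mutex;
///
/// let registry = Mutex::new(KernelRegistry::new());
/// let pipeline = registry
///     .lock()
///     .unwrap()
///     .get_pipeline("gelu_f32", device.metal_device())?;
/// ```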
pub struct KernelRegistry {
    /// Cached pipelines keyed by kernel function name.
    cache: HashMap<String, ComputePipelineState>,
    /// MSL source text keyed by kernel function name.
    ///
    /// Populated at construction time with all embedded shader sources.
    sources: HashMap<String, &'static str>,
    /// ADR-029 iter-175 Step 1l: precompiled `default.metallib` (lazy).
    ///
    /// `None` initially. On the first `get_pipeline*` call with the fast
    /// path enabled, populated via
    /// `device.new_library_with_data(EMBEDDED_METALLIB)` and set to
    /// `Some(library)` on success; on a load failure it stays `None` and
    /// `precompiled_load_attempted` ensures we don't retry.
    ///
    /// Lazily filled because we need a `&metal::DeviceRef` to load it
    /// and `KernelRegistry::new()` does not have one.
    precompiled_lib: Option<metal::Library>,
    /// Whether we've already attempted to load the precompiled library.
    /// Prevents repeated load attempts on failure.
    precompiled_load_attempted: bool,
}

impl KernelRegistry {
    /// Create a new registry with all embedded shader sources pre-registered.
    ///
    /// No compilation happens here — shaders are compiled lazily on first use.
    pub fn new() -> Self {
        let mut sources = HashMap::new();

        // Register embedded shader sources.
        sources.insert(
            "placeholder".into(),
            include_str!("shaders/placeholder.metal"),
        );
        sources.insert(
            "quantized_matmul".into(),
            include_str!("shaders/quantized_matmul.metal"),
        );
        sources.insert(
            "quantized_matmul_simd".into(),
            include_str!("shaders/quantized_matmul.metal"),
        );
        sources.insert(
            "quantized_matmul_simd_bf16".into(),
            include_str!("shaders/quantized_matmul.metal"),
        );
        sources.insert(
            "quantized_matmul_simd_bf16_expert".into(),
            include_str!("shaders/quantized_matmul.metal"),
        );

        // GGML block-format quantized mat-vec kernels (ADR-006 Phase 3)
        let ggml_src: &'static str =
            include_str!("shaders/quantized_matmul_ggml.metal");
        sources.insert("kernel_mul_mv_q4_0_f32".into(), ggml_src);
        sources.insert("kernel_mul_mv_q8_0_f32".into(), ggml_src);
        // ADR-028 iter-368: peer-style NSG=4 NR=2 variant (128 threads/TG).
        sources.insert("kernel_mul_mv_q8_0_f32_nr2".into(), ggml_src);
        sources.insert("kernel_mul_mv_q6_K_f32".into(), ggml_src);
        // ADR-028 iter-309 — q6_K mat-vec with nr0=2 + cached yl[16]
        // (peer-pattern port of llama.cpp's `kernel_mul_mv_q6_K_f32_impl`
        // with N_R0_Q6_K=2; 4 rows/TG vs baseline's 2). Env-gated via
        // `HF2Q_Q6K_MV_NR2=1` in the dispatcher.
        sources.insert("kernel_mul_mv_q6_K_f32_nr2".into(), ggml_src);
        // ADR-022 Phase 1 — Q5_1 / IQ4_NL dense mat-vec.
        sources.insert("kernel_mul_mv_q5_1_f32".into(), ggml_src);
        sources.insert("kernel_mul_mv_iq4_nl_f32".into(), ggml_src);
        // ADR-013 P7 — Q4_K dense decode mat-vec (port of llama.cpp's
        // kernel_mul_mv_q4_K_f32 at ggml-metal.metal:7715-7821).
        sources.insert("kernel_mul_mv_q4_K_f32".into(), ggml_src);
        // ADR-022 Phase 2 — Q5_K dense mv kernel.
        sources.insert("kernel_mul_mv_q5_K_f32".into(), ggml_src);

        // GGML block-format quantized matrix-matrix kernels
        // (ADR-011 Phase 3 Wave P3a: port of llama.cpp's kernel_mul_mm_<q>_f32).
        // Used at prefill m > 8 to reuse each weight tile across a 32-row
        // block via threadgroup-staged simdgroup MMA, instead of re-reading
        // every block per prompt-token as the mv kernel does.
        let ggml_mm_src: &'static str =
            include_str!("shaders/quantized_matmul_mm.metal");
        sources.insert("kernel_mul_mm_q4_0_f32".into(), ggml_mm_src);
        sources.insert("kernel_mul_mm_q8_0_f32".into(), ggml_mm_src);
        sources.insert("kernel_mul_mm_q6_K_f32".into(), ggml_mm_src);
        // ADR-022 Phase 1 — dense Q5_1 / IQ4_NL mm.
        sources.insert("kernel_mul_mm_q5_1_f32".into(), ggml_mm_src);
        sources.insert("kernel_mul_mm_iq4_nl_f32".into(), ggml_mm_src);
        // ADR-022 Phase 2 — dense Q5_K mm.
        sources.insert("kernel_mul_mm_q5_K_f32".into(), ggml_mm_src);
        // ADR-022 Phase 3 — dense Q4_K mm.
        sources.insert("kernel_mul_mm_q4_K_f32".into(), ggml_mm_src);

        // GGML block-format quantized matrix-matrix kernels — tensor API
        // variant (ADR-011 Phase 3 Wave P3b-tensor: port of llama.cpp's
        // kernel_mul_mm_impl `#ifdef GGML_METAL_HAS_TENSOR` branch).
        // Uses Apple's MetalPerformancePrimitives `tensor_ops::matmul2d`
        // primitive which on M3+ dispatches to hardware tensor cores for
        // 2-3x the effective FLOP throughput vs the simdgroup MMA path.
        // Only compiled on devices where the tensor API is available; the
        // kernel_registry's runtime-probe (see MlxDevice::has_tensor) gates
        // compilation so non-tensor devices transparently fall back to the
        // non-tensor `kernel_mul_mm_<q>_f32` kernels.
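        // Selection sketch (illustrative; the real routing lives in the
        // dispatcher, keyed off the MlxDevice::has_tensor probe):
        //     let name = if device.has_tensor() {
        //         "kernel_mul_mm_q4_0_tensor_f32" // M3+ hardware tensor cores
        //     } else {
        //         "kernel_mul_mm_q4_0_f32" // simdgroup-MMA fallback
        //     };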
        let ggml_mm_tensor_src: &'static str =
            include_str!("shaders/quantized_matmul_mm_tensor.metal");
        sources.insert("kernel_mul_mm_q4_0_tensor_f32".into(), ggml_mm_tensor_src);
        sources.insert("kernel_mul_mm_q4_0_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
        sources.insert("kernel_mul_mm_q6_K_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
        sources.insert("kernel_mul_mm_q8_0_tensor_f32".into(), ggml_mm_tensor_src);
        sources.insert("kernel_mul_mm_q6_K_tensor_f32".into(), ggml_mm_tensor_src);
        // ADR-022 Phase 1 — Q5_1 / IQ4_NL tensor mm.
        sources.insert("kernel_mul_mm_q5_1_tensor_f32".into(), ggml_mm_tensor_src);
        sources.insert("kernel_mul_mm_iq4_nl_tensor_f32".into(), ggml_mm_tensor_src);
        // ADR-022 Phase 2 — Q5_K tensor mm.
        sources.insert("kernel_mul_mm_q5_K_tensor_f32".into(), ggml_mm_tensor_src);
        // ADR-022 Phase 3 — Q4_K tensor mm + Q8_0 perm021.
        sources.insert("kernel_mul_mm_q4_K_tensor_f32".into(), ggml_mm_tensor_src);
        sources.insert("kernel_mul_mm_q8_0_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
        // ADR-029 iter-30 H29-speed — F16-weight V2 large-tile mm.
        // Same source file as the V2 quantized variants; reads F16 weight
        // directly from device memory (no per-call dequant). Used when
        // MlxQWeight.f16_shadow is populated and m > MM_ROUTING_THRESHOLD.
        sources.insert("hf2q_mul_mm_tensor_v2_f16".into(), ggml_mm_tensor_src);
        // ADR-029 iter-36 H28-D — F16-weight perm021 mm for O-projection.
        // Same source file; reads F16 weight from MlxQWeight.f16_shadow when
        // populated, bypassing the per-call quantized dequant. B-stage
        // (bfloat permuted [n_heads, seq_len, head_dim] input) is byte-
        // identical to the quantized variant.
        sources.insert("kernel_mul_mm_f16_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
        // ADR-029 iter-23 H28-A — V2 large-tile tensor mm (NRA=64 M, NRB=128 N).
        // Same source file as V1 tensor mm; distinct kernel host names so the
        // dispatcher can pick V1 vs V2 at runtime via HF2Q_LARGE_TILE_MM.
        sources.insert("kernel_mul_mm_q4_0_tensor_v2_f32".into(), ggml_mm_tensor_src);
        sources.insert("kernel_mul_mm_q8_0_tensor_v2_f32".into(), ggml_mm_tensor_src);
        sources.insert("kernel_mul_mm_q6_K_tensor_v2_f32".into(), ggml_mm_tensor_src);
        sources.insert("kernel_mul_mm_q5_1_tensor_v2_f32".into(), ggml_mm_tensor_src);
        sources.insert("kernel_mul_mm_iq4_nl_tensor_v2_f32".into(), ggml_mm_tensor_src);
        sources.insert("kernel_mul_mm_q5_K_tensor_v2_f32".into(), ggml_mm_tensor_src);
        sources.insert("kernel_mul_mm_q4_K_tensor_v2_f32".into(), ggml_mm_tensor_src);
        // ADR-029 iter-28 H29 — whole-tensor dequant from block_q → F16.
        // Used at model load to materialize an F16 shadow of attn/dense MLP
        // weights so the runtime dispatch can use kernel_mul_mm_f16_f32_*
        // (peer's gemma4 pattern). Trades ~1 GB resident memory for 2-3×
        // faster per-call dense matmul at prefill.
        let dequant_to_f16_src: &'static str =
            include_str!("shaders/dequant_to_f16.metal");
        sources.insert("hf2q_dequant_q4_0_to_f16".into(), dequant_to_f16_src);
        sources.insert("hf2q_dequant_q8_0_to_f16".into(), dequant_to_f16_src);
        sources.insert("hf2q_dequant_q5_1_to_f16".into(), dequant_to_f16_src);
        sources.insert("hf2q_dequant_iq4_nl_to_f16".into(), dequant_to_f16_src);
        sources.insert("hf2q_dequant_q4_K_to_f16".into(), dequant_to_f16_src);
        sources.insert("hf2q_dequant_q5_K_to_f16".into(), dequant_to_f16_src);
        sources.insert("hf2q_dequant_q6_K_to_f16".into(), dequant_to_f16_src);

        // ADR-022 Phase 1 P1.7 — Q5_1 / IQ4_NL mul_mv_ext r1 family.
        // Eight instantiations (2 types × 4 r1ptg widths). Each PSO is
        // additionally specialized at PSO-compile time with FC_mul_mv_nsg
        // (function_constant 600) and FC_mul_mv_nxpsg (function_constant 601).
        let mul_mv_ext_src: &'static str = include_str!("shaders/mul_mv_ext.metal");
        sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_2".into(), mul_mv_ext_src);
        sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_3".into(), mul_mv_ext_src);
        sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_4".into(), mul_mv_ext_src);
        sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_5".into(), mul_mv_ext_src);
        sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_2".into(), mul_mv_ext_src);
        sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_3".into(), mul_mv_ext_src);
        sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_4".into(), mul_mv_ext_src);
        sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_5".into(), mul_mv_ext_src);
        // ADR-022 Phase 4 — Q4_0 / Q8_0 / Q4_K / Q5_K / Q6_K mv_ext.
        // 5 types × 4 r1ptg widths = 20 instantiations.
        for r1 in [2, 3, 4, 5].iter() {
            for ty in ["q4_0", "q8_0", "q4_K", "q5_K", "q6_K"].iter() {
                let name = format!("kernel_mul_mv_ext_{ty}_f32_r1_{r1}");
                sources.insert(name, mul_mv_ext_src);
            }
        }

        // Dense bf16×f32 → f32 tensor-API matmul (non-flash-attention
        // prefill Q@K^T and scores@V, modeled on llama.cpp's
        // kernel_mul_mm_bf16_f32 with the GGML_METAL_HAS_TENSOR branch
        // active). Tile geometry and write-back identical to the
        // quantized tensor kernel; only the A-stage copy (bfloat →
        // bfloat, no dequantize) differs.
        let dense_mm_bf16_tensor_src: &'static str =
            include_str!("shaders/dense_mm_bf16_tensor.metal");
        sources.insert("hf2q_dense_mm_bf16_f32_tensor".into(), dense_mm_bf16_tensor_src);
        // ADR-029 iter-80 H60: V2 large-tile variant (NRA=64, NRB=128).
        // Same source file (`dense_mm_bf16_tensor.metal`) — the second
        // host_name entry resolves to the V2 kernel appended at the bottom
        // of that file. Picked at dispatch time when HF2Q_LARGE_TILE_MM=1.
        sources.insert("hf2q_dense_mm_bf16_f32_tensor_v2".into(), dense_mm_bf16_tensor_src);

        // Dense f32×f32 → f32 tensor-API matmul (F32-everywhere
        // sibling of dense_mm_bf16_tensor). Used by hf2q's ADR-005
        // iter-118 BF16-vs-F32 ViT attention A/B diagnostic to remove
        // the BF16 K-stage cast as a confounding variable. Port of
        // llama.cpp's kernel_mul_mm_f32_f32 specialization
        // (ggml-metal.metal:10098) on the GGML_METAL_HAS_TENSOR
        // branch. Same tile geometry (NR0=64 NR1=32 NK=32) but
        // float-everywhere shmem staging.
        let dense_mm_f32_f32_tensor_src: &'static str =
            include_str!("shaders/dense_mm_f32_f32.metal");
        sources.insert("hf2q_dense_mm_f32_f32_tensor".into(), dense_mm_f32_f32_tensor_src);

        // Dense f16×f32 → f32 tensor-API matmul (F16-staging sibling
        // of dense_mm_bf16_tensor). Used by hf2q's ADR-005 Phase 2c
        // iter-128 gemma4v ViT precision-parity path: every mmproj
        // weight is stored as F16 in GGUF, and peer's `kernel_mul_mm_f16_f32`
        // (`ggml-metal.metal:10099`) stages BOTH A and B as `half` in
        // shmem and computes on `simdgroup_half8x8`. Matches peer's
        // per-element rounding budget exactly (10-bit mantissa vs
        // BF16's 7-bit), closing the 1.16x/block cascade compound that
        // iter-127 numerically bisected to BF16 staging. Same tile
        // geometry as the BF16 sibling (NR0=64 NR1=32 NK=32, 8 KB
        // shmem) — half and bfloat share 16-bit storage.
        let dense_mm_f16_tensor_src: &'static str =
            include_str!("shaders/dense_mm_f16_tensor.metal");
        sources.insert("hf2q_dense_mm_f16_f32_tensor".into(), dense_mm_f16_tensor_src);

        // Dense bf16×f32 → f32 GEMV (matrix-vector multiply) — optimized
        // for M=1 single-token decode. Port of llama.cpp's
        // kernel_mul_mv_bf16_f32_4 (bfloat4-vectorized GEMV kernel).
        // Used in apply_linear_projection_f32 when seq_len=1 and the
        // weight matrix is BF16, replacing the MM kernel (~2× faster for
        // M=1 due to better memory bandwidth utilization per thread).
        let dense_gemv_bf16_src: &'static str =
            include_str!("shaders/dense_gemv_bf16.metal");
        sources.insert("hf2q_dense_gemv_bf16_f32_4".into(), dense_gemv_bf16_src);

        // Fused scale-mask-softmax for the non-flash-attention prefill
        // path. One row-local threadgroup per (head, query) pair
        // replaces three separate dispatches (scale, mask-add, softmax);
        // reads a bf16 mask (-INF at masked positions, matching
        // flash_attn_prefill_mask.metal) that is shared across heads.
        let scale_mask_softmax_src: &'static str =
            include_str!("shaders/scale_mask_softmax.metal");
        sources.insert("scale_mask_softmax_f32".into(), scale_mask_softmax_src);
        // ADR-029 iter-93 H71: float4-vectorized variant for peer parity
        // with kernel_soft_max_f32_4. Same source file; the v4 host_name
        // resolves to the second kernel appended at the bottom of
        // scale_mask_softmax.metal.
        sources.insert("scale_mask_softmax_f32_v4".into(), scale_mask_softmax_src);

        // Expert-routed (MoE) quantized matmul kernel (Story 2.1)
        sources.insert(
            "quantized_matmul_id".into(),
            include_str!("shaders/quantized_matmul_id.metal"),
        );

        // Expert-routed (MoE) GGML block-format quantized matmul kernels
        let ggml_id_src: &'static str =
            include_str!("shaders/quantized_matmul_id_ggml.metal");
        sources.insert("kernel_mul_mv_id_q4_0_f32".into(), ggml_id_src);
        sources.insert("kernel_mul_mv_id_q8_0_f32".into(), ggml_id_src);
        // ADR-013 P7 — Q4_K MoE expert-routed mat-vec (port of
        // llama.cpp's kernel_mul_mv_id_q4_K_f32 at ggml-metal.metal:10349).
        sources.insert("kernel_mul_mv_id_q4_K_f32".into(), ggml_id_src);
        sources.insert("kernel_mul_mv_id_q5_K_f32".into(), ggml_id_src);
        sources.insert("kernel_mul_mv_id_q6_K_f32".into(), ggml_id_src);
        // ADR-028 iter-321 — q6_K _id with nr0=2 + cached yl[16]
        // (peer-pattern port mirroring iter-309's non-_id variant).
        // Env-gated via HF2Q_Q6K_ID_MV_NR2=1 in dispatch_id_mv.
        sources.insert("kernel_mul_mv_id_q6_K_f32_nr2".into(), ggml_id_src);
        // ADR-029 iter-6 — q8_0 _id with nr0=2 + nsg=4 cross-SG reduce
        // (peer-pattern port; peer N_R0_Q8_0=2 + N_SG_Q8_0=4 in
        // /opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h:27,40).
        // Env-gated via HF2Q_Q8_0_ID_MV_NR2=1 in dispatch_id_mv.
        sources.insert("kernel_mul_mv_id_q8_0_f32_nr2".into(), ggml_id_src);
        // ADR-022 Phase 1 — Q5_1 / IQ4_NL MoE expert-routed mat-vec.
        sources.insert("kernel_mul_mv_id_q5_1_f32".into(), ggml_id_src);
        sources.insert("kernel_mul_mv_id_iq4_nl_f32".into(), ggml_id_src);
        // Fused-SwiGLU mv_id variants (ADR-012 §Optimize / Task #15):
        // computes y[r][n] = sum_k(dequant(W[expert][n][k]) * silu(gate[r][k]) * up[r][k])
        // in one dispatch — replaces the silu_mul + expert_down sequence.
        sources.insert("kernel_mul_mv_id_q4_0_f32_swiglu".into(), ggml_id_src);

        // Expert-routed (MoE) GGML block-format QUANTIZED MATRIX-MATRIX kernels
        // (ADR-011 Phase 3 Wave P3a: port of llama.cpp's
        // `kernel_mul_mm_id_map0_ne20_N` + `kernel_mul_mm_id_<q>_f32`).
        // Two-stage dispatch: map0 regroups the token-to-expert table into
        // per-expert routed-token lists, then mm_id stages a 64x32 expert
        // weight tile into threadgroup shmem and reuses it across a 32-row
        // block of that expert's routed tokens.
        let ggml_id_mm_src: &'static str =
            include_str!("shaders/quantized_matmul_id_mm.metal");
        sources.insert("kernel_mul_mm_id_map0_ne20_1".into(), ggml_id_mm_src);
        sources.insert("kernel_mul_mm_id_map0_ne20_8".into(), ggml_id_mm_src);
        sources.insert("kernel_mul_mm_id_q4_0_f32".into(), ggml_id_mm_src);
        sources.insert("kernel_mul_mm_id_q8_0_f32".into(), ggml_id_mm_src);
        sources.insert("kernel_mul_mm_id_q6_K_f32".into(), ggml_id_mm_src);
        // ADR-013 P16 — Q4_K mm_id (port of llama.cpp ggml-metal.metal:10169).
        sources.insert("kernel_mul_mm_id_q4_K_f32".into(), ggml_id_mm_src);
        // ADR-022 Phase 1 P1.6 — Q5_1 / IQ4_NL mm_id template instantiations.
        sources.insert("kernel_mul_mm_id_q5_1_f32".into(), ggml_id_mm_src);
        sources.insert("kernel_mul_mm_id_iq4_nl_f32".into(), ggml_id_mm_src);
        // ADR-022 Phase 2 — Q5_K mm_id template instantiation.
        sources.insert("kernel_mul_mm_id_q5_K_f32".into(), ggml_id_mm_src);

        // MoE-routed quantized matrix-matrix kernels — tensor API variant
        // (ADR-011 Phase 3 Wave P3b-tensor). Uses the MPP tensor_ops
        // matmul2d primitive for hardware-tensor-core MMA on M3+. Only
        // the mm_id kernel is ported — map0 is a short pre-pass (not
        // matmul) and continues to use the simdgroup version.
        let ggml_id_mm_tensor_src: &'static str =
            include_str!("shaders/quantized_matmul_id_mm_tensor.metal");
        sources.insert("kernel_mul_mm_id_q4_0_tensor_f32".into(), ggml_id_mm_tensor_src);
        sources.insert("kernel_mul_mm_id_q8_0_tensor_f32".into(), ggml_id_mm_tensor_src);
        sources.insert("kernel_mul_mm_id_q6_K_tensor_f32".into(), ggml_id_mm_tensor_src);
        // ADR-013 P16 — Q4_K tensor-API mm_id.
        sources.insert("kernel_mul_mm_id_q4_K_tensor_f32".into(), ggml_id_mm_tensor_src);
        // ADR-022 Phase 1 P1.6 — Q5_1 / IQ4_NL tensor-API mm_id.
        sources.insert("kernel_mul_mm_id_q5_1_tensor_f32".into(), ggml_id_mm_tensor_src);
        sources.insert("kernel_mul_mm_id_iq4_nl_tensor_f32".into(), ggml_id_mm_tensor_src);
        // ADR-022 Phase 2 — Q5_K tensor-API mm_id.
        sources.insert("kernel_mul_mm_id_q5_K_tensor_f32".into(), ggml_id_mm_tensor_src);

        // Embedding kernels (Story 1.5)
        let embedding_src: &'static str = include_str!("shaders/embedding.metal");
        sources.insert("embedding_gather_4bit".into(), embedding_src);
        sources.insert("embedding_gather_6bit".into(), embedding_src);

        // MoE gate kernel (Story 1.5)
        let moe_gate_src: &'static str = include_str!("shaders/moe_gate.metal");
        sources.insert("moe_gate".into(), moe_gate_src);

        // MoE dispatch kernels (Story 1.5)
        let moe_dispatch_src: &'static str = include_str!("shaders/moe_dispatch.metal");
        sources.insert("fused_gelu_mul".into(), moe_dispatch_src);
        sources.insert("moe_swiglu_fused".into(), moe_dispatch_src);
        sources.insert("moe_swiglu_batch".into(), moe_dispatch_src);
        sources.insert("moe_swiglu_seq".into(), moe_dispatch_src);
        sources.insert("moe_accumulate".into(), moe_dispatch_src);
        sources.insert("moe_weighted_sum".into(), moe_dispatch_src);
        sources.insert("moe_weighted_sum_seq".into(), moe_dispatch_src);
        sources.insert("zero_buffer".into(), moe_dispatch_src);
        sources.insert("naive_matvec_f32".into(), moe_dispatch_src);
        sources.insert("moe_gather_topk_weights".into(), moe_dispatch_src);
        // bf16 variants (Phase 2 bf16 activation path)
        sources.insert("fused_gelu_mul_bf16".into(), moe_dispatch_src);
        sources.insert("moe_swiglu_seq_bf16".into(), moe_dispatch_src);
        sources.insert("moe_weighted_sum_seq_bf16_input".into(), moe_dispatch_src);
        // ADR-020 iter-11h-e3a: backward kernels for moe_weighted_sum_seq.
        sources.insert(
            "moe_weighted_sum_seq_backward_outputs_f32".into(),
            moe_dispatch_src,
        );
        sources.insert(
            "moe_weighted_sum_seq_backward_weights_f32".into(),
            moe_dispatch_src,
        );
        // ADR-020 iter-11h-e3b: fused backward kernel for moe_swiglu_seq.
        sources.insert(
            "moe_swiglu_seq_backward_f32".into(),
            moe_dispatch_src,
        );

        // Batched KV cache copy kernels
        let kv_cache_src: &'static str = include_str!("shaders/kv_cache_copy.metal");
        sources.insert("kv_cache_copy_batch_f32".into(), kv_cache_src);
        sources.insert("kv_cache_copy_batch_f32_to_f16".into(), kv_cache_src);
        sources.insert("kv_cache_copy_seq_f32".into(), kv_cache_src);
        sources.insert("kv_cache_copy_seq_f32_to_f16".into(), kv_cache_src);
        // Wave P4.11 — fused K+V copy variants
        sources.insert("kv_cache_copy_seq_f32_kv_dual".into(), kv_cache_src);
        sources.insert("kv_cache_copy_seq_f32_to_f16_kv_dual".into(), kv_cache_src);
        // ADR-028 iter-145 — fused single-position K+V copy variants (decode shape)
        sources.insert("kv_cache_copy_batch_f32_kv_dual".into(), kv_cache_src);
        sources.insert("kv_cache_copy_batch_f32_to_f16_kv_dual".into(), kv_cache_src);
        // bf16-source KV cache copy (Phase 2 bf16 activation path)
        sources.insert("kv_cache_copy_seq_bf16".into(), kv_cache_src);
        // ADR-030 iter-95 — bit-exact BF16→BF16 head-major cache copy for
        // Option A xlen verify (avoids F16 round-trip precision drift).
        sources.insert("kv_cache_copy_seq_bf16_to_bf16_head_major".into(), kv_cache_src);

        // Elementwise and transpose kernels (Story 1.5)
        let elementwise_src: &'static str = include_str!("shaders/elementwise.metal");
        sources.insert("elementwise_add_f32".into(), elementwise_src);
        sources.insert("elementwise_add_f16".into(), elementwise_src);
        sources.insert("elementwise_mul_f32".into(), elementwise_src);
        sources.insert("elementwise_mul_f16".into(), elementwise_src);
        sources.insert("elementwise_add_bf16".into(), elementwise_src);
        sources.insert("elementwise_mul_bf16".into(), elementwise_src);
        sources.insert("cast_f16_to_f32".into(), elementwise_src);
        sources.insert("cast_f32_to_f16".into(), elementwise_src);
        sources.insert("cast_bf16_to_f32".into(), elementwise_src);
        sources.insert("cast_f32_to_bf16".into(), elementwise_src);
        sources.insert("scalar_mul_bf16".into(), elementwise_src);
        sources.insert("scalar_mul_f32".into(), elementwise_src);
        sources.insert("embedding_gather_scale_f32".into(), elementwise_src);
        sources.insert("embedding_gather_scale_batch_f32".into(), elementwise_src);
        sources.insert("permute_021_bf16".into(), elementwise_src);
        sources.insert("transpose_last2_bf16".into(), elementwise_src);
        sources.insert("transpose_last2_f16".into(), elementwise_src);
        sources.insert("permute_021_f32".into(), elementwise_src);
        sources.insert("permute_021_bf16_to_f32".into(), elementwise_src);
        sources.insert("transpose_2d_f32".into(), elementwise_src);
        sources.insert("transpose_2d_f16".into(), elementwise_src);

        // Attention kernels (Story 1.3)
        let sdpa_src: &'static str = include_str!("shaders/sdpa.metal");
        sources.insert("sdpa".into(), sdpa_src);
        sources.insert("sdpa_bf16".into(), sdpa_src);
        let sdpa_sliding_src: &'static str = include_str!("shaders/sdpa_sliding.metal");
        sources.insert("sdpa_sliding".into(), sdpa_sliding_src);
        sources.insert("sdpa_sliding_bf16".into(), sdpa_sliding_src);

        // Flash-attention tiled prefill kernel (ADR-011 Phase 1).
        // Ten entry points; all backed by the same shader source.
        // Pipelines are compiled with function constants via
        // `get_pipeline_with_bool_constants` — not `get_pipeline`.
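        // Host-name decoding example (grounded in the entries below):
        //     steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbool_
        //     → dtype=bfloat16, BQ=32, BK=16, head dim D=256, WM=4, WN=1,
        //       boolean mask variant (trailing underscore comes from "bool_").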
        let flash_attn_prefill_src: &'static str =
            include_str!("shaders/flash_attn_prefill.metal");
        // D=256 variants (BQ=32, BK=16, WM=4, WN=1 — 128 threads/threadgroup)
        sources.insert(
            "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskfloat32".into(),
            flash_attn_prefill_src,
        );
        sources.insert(
            "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
            flash_attn_prefill_src,
        );
        sources.insert(
            "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbfloat16".into(),
            flash_attn_prefill_src,
        );
        sources.insert(
            "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
            flash_attn_prefill_src,
        );
        sources.insert(
            "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskfloat16".into(),
            flash_attn_prefill_src,
        );
        sources.insert(
            "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
            flash_attn_prefill_src,
        );
        // D=512 variants (BQ=8, BK=8, WM=1, WN=1 — 32 threads/threadgroup)
        // NOTE: f32 at D=512 is NOT instantiated — threadgroup memory exceeds
        // the 32 KB Metal limit (candle sdpa.rs:86-94).
        sources.insert(
            "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbfloat16".into(),
            flash_attn_prefill_src,
        );
        sources.insert(
            "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
            flash_attn_prefill_src,
        );
        sources.insert(
            "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskfloat16".into(),
            flash_attn_prefill_src,
        );
        sources.insert(
            "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
            flash_attn_prefill_src,
        );

        // Flash attention vector kernels — SIMD-vectorized decode-path SDPA
        // (ported from llama.cpp flash_attn_ext_vec).
        let flash_attn_vec_src: &'static str =
            include_str!("shaders/flash_attn_vec.metal");
        sources.insert("flash_attn_vec_dk256".into(), flash_attn_vec_src);
        sources.insert("flash_attn_vec_dk512".into(), flash_attn_vec_src);
        sources.insert("flash_attn_vec_reduce_dk256".into(), flash_attn_vec_src);
        sources.insert("flash_attn_vec_reduce_dk512".into(), flash_attn_vec_src);
        // F16 KV variants (Phase 4a)
        sources.insert("flash_attn_vec_f16kv_dk256".into(), flash_attn_vec_src);
        sources.insert("flash_attn_vec_f16kv_dk512".into(), flash_attn_vec_src);

        // RoPE, normalization, activation kernels (Story 1.4)
        let rope_src: &'static str = include_str!("shaders/rope.metal");
        sources.insert("rope_f32".into(), rope_src);
        sources.insert("rope_f16".into(), rope_src);
        sources.insert("rope_bf16".into(), rope_src);
        sources.insert("rope_neox_bf16".into(), rope_src);
        sources.insert("rope_neox_f32".into(), rope_src);
        let rms_norm_src: &'static str = include_str!("shaders/rms_norm.metal");
        sources.insert("rms_norm_f32".into(), rms_norm_src);
        // ADR-028 iter-310 — float4 + simd_sum variants (peer-pattern,
        // ported from llama.cpp kernel_rms_norm_fuse_impl<float4, 1>).
        // Env-gated via HF2Q_RMS_NORM_V2=1 in the dispatchers.
        sources.insert("rms_norm_f32_v2".into(), rms_norm_src);
        sources.insert("rms_norm_no_scale_f32_v2".into(), rms_norm_src);
        sources.insert("rms_norm_f16".into(), rms_norm_src);
        sources.insert("rms_norm_bf16".into(), rms_norm_src);
        sources.insert("rms_norm_no_scale_bf16".into(), rms_norm_src);
        sources.insert("rms_norm_no_scale_f32".into(), rms_norm_src);
        sources.insert("rms_norm_no_scale_f32_dual".into(), rms_norm_src);
        sources.insert("rms_norm_f32_triple".into(), rms_norm_src);
        sources.insert("fused_post_attn_triple_norm_f32".into(), rms_norm_src);
        // ADR-028 iter-370: V2 (float4 + simd_sum) variant of triple_norm.
        sources.insert("fused_post_attn_triple_norm_f32_v2".into(), rms_norm_src);
        // ADR-028 iter-217: fused post-FF norm 2 + end-of-layer FINAL
        // (combines 2 sequential fused_norm_add dispatches into 1 kernel).
        sources.insert("fused_post_ff_norm2_endlayer_f32".into(), rms_norm_src);
        // ADR-028 iter-362: V2 (float4 + simd_sum) variant of the above.
        // Same math, 75% fewer barriers per dispatch (4 vs 16 at tg=256).
        sources.insert("fused_post_ff_norm2_endlayer_f32_v2".into(), rms_norm_src);
        // ADR-028 iter-367: V2 fusion of moe_weighted_sum INTO Path A end-of-layer.
        // Eliminates 1 dispatch + moe_accum round-trip from the gemma4 decode default.
        sources.insert("fused_moe_wsum_post_ff_norm2_endlayer_f32_v2".into(), rms_norm_src);
        sources.insert("rms_norm_no_scale_f32_dual_perm".into(), rms_norm_src);
        // Fused RMS norm + elementwise multiply kernels (Phase 4e.2)
        sources.insert("rms_norm_mul_f32".into(), rms_norm_src);
        sources.insert("rms_norm_mul_f16".into(), rms_norm_src);
        sources.insert("rms_norm_mul_bf16".into(), rms_norm_src);
        // L2 norm kernels (ADR-013 Decision 3 — Gated DeltaNet Q/K norm)
        let l2_norm_src: &'static str = include_str!("shaders/l2_norm.metal");
        sources.insert("l2_norm_f32".into(), l2_norm_src);
        sources.insert("l2_norm_f16".into(), l2_norm_src);
        sources.insert("l2_norm_bf16".into(), l2_norm_src);
        // ADR-015 iter59a — fused L2 norm + scalar multiply (DN q-path).
        sources.insert("l2_norm_scale_f32".into(), l2_norm_src);
        // Cumulative-sum kernels (ADR-013 Decision 4 — DeltaNet decay-mask base)
        let cumsum_src: &'static str = include_str!("shaders/cumsum.metal");
        sources.insert("cumsum_f32".into(), cumsum_src);
        sources.insert("cumsum_bf16".into(), cumsum_src);
        // SSM conv kernels (ADR-013 Decision 7 — DeltaNet 1D causal conv + SiLU)
        let ssm_conv_src: &'static str = include_str!("shaders/ssm_conv.metal");
        sources.insert("ssm_conv_forward_f32".into(), ssm_conv_src);
        sources.insert("ssm_conv_forward_bf16".into(), ssm_conv_src);
        sources.insert("ssm_conv_state_update_f32".into(), ssm_conv_src);
        sources.insert("ssm_conv_state_update_bf16".into(), ssm_conv_src);
        // Tri-solve kernels (ADR-013 Decision 5 — chunked DeltaNet debug path)
        let tri_solve_src: &'static str = include_str!("shaders/tri_solve.metal");
        sources.insert("tri_solve_lower_unit_f32".into(), tri_solve_src);
        sources.insert("tri_solve_lower_unit_bf16".into(), tri_solve_src);
        // Rope-multi kernels (ADR-013 Decision 10 — IMROPE for Qwen3.5)
        let rope_multi_src: &'static str = include_str!("shaders/rope_multi.metal");
        sources.insert("rope_multi_f32".into(), rope_multi_src);
        sources.insert("rope_multi_bf16".into(), rope_multi_src);
        // Gated DeltaNet fused kernel (ADR-013 Decision 6 — centerpiece)
        let gdn_src: &'static str = include_str!("shaders/gated_delta_net.metal");
        sources.insert("gated_delta_net_f32".into(), gdn_src);
        // ADR-015 iter56 — decode-only `simd_sum` variant. Three NSG-templated
        // host names share the same source; selection is by D_k via
        // `dispatch_gated_delta_net_decode`. Drop-in for the fused kernel
        // above when n_tokens=1.
        let gdn_decode_src: &'static str =
            include_str!("shaders/gated_delta_net_decode.metal");
        sources.insert("gated_delta_net_decode_f32_1".into(), gdn_decode_src);
        sources.insert("gated_delta_net_decode_f32_2".into(), gdn_decode_src);
        sources.insert("gated_delta_net_decode_f32_4".into(), gdn_decode_src);
        // Wave 5b — chunk-parallel inter-chunk state-recurrence kernel
        // (the one new kernel in the chunk-parallel pipeline; spec source:
        // arXiv 2412.06464 §4 + FLA chunk_delta_h.py:43-298).
        let gdn_chunk_src: &'static str =
            include_str!("shaders/gated_delta_net_chunk.metal");
        sources.insert(
            "gated_delta_net_chunk_inter_state_bf16".into(),
            gdn_chunk_src,
        );
        // Wave 5b.1 iter 2 — chunk_scaled_dot_kkt kernel (input-side of
        // the chunk pipeline; spec source: FLA chunk_scaled_dot_kkt.py:36-99).
        let gdn_kkt_src: &'static str =
            include_str!("shaders/gated_delta_net_kkt.metal");
        sources.insert("gated_delta_net_kkt_bf16".into(), gdn_kkt_src);
        // Wave 5b.1 iter 2 — recompute_w_u_fwd kernel (applies post-solve A
        // to (β·v) and (β·k·exp(g)) to produce w and u; spec source: FLA
        // wy_fast.py:29-117).
        let gdn_recompute_wu_src: &'static str =
            include_str!("shaders/gated_delta_net_recompute_wu.metal");
        sources.insert(
            "gated_delta_net_recompute_wu_bf16".into(),
            gdn_recompute_wu_src,
        );
        // Wave 5b.1 iter 3 — chunk_fwd_o kernel (per-chunk output: closes
        // the chunk pipeline; spec source: FLA chunk_o.py:42-138).
        let gdn_chunk_o_src: &'static str =
            include_str!("shaders/gated_delta_net_chunk_o.metal");
        sources.insert("gated_delta_net_chunk_o_bf16".into(), gdn_chunk_o_src);
        // Wave 5b.1 iter 4 — orchestrator helper kernels:
        //   chunk_local_cumsum_g_f32 — per-chunk prefix sum on g [B, T, H]
        //   chunk_tri_solve_invert_f32 — per-chunk-block (I + A_strict)^-1
        //     on FLA's [B, T, H, BT] layout.
        let chunk_local_cumsum_g_src: &'static str =
            include_str!("shaders/chunk_local_cumsum_g.metal");
        sources.insert(
            "chunk_local_cumsum_g_f32".into(),
            chunk_local_cumsum_g_src,
        );
        let chunk_tri_solve_invert_src: &'static str =
            include_str!("shaders/chunk_gated_delta_rule_tri_solve_invert.metal");
        sources.insert(
            "chunk_tri_solve_invert_f32".into(),
            chunk_tri_solve_invert_src,
        );
        // Sigmoid-gated elementwise multiply (ADR-013 Decision 9 — full-attn output gate)
        let sigmoid_mul_src: &'static str = include_str!("shaders/sigmoid_mul.metal");
        sources.insert("sigmoid_mul_f32".into(), sigmoid_mul_src);
        sources.insert("sigmoid_mul_bf16".into(), sigmoid_mul_src);
        let silu_mul_src: &'static str = include_str!("shaders/silu_mul.metal");
        sources.insert("silu_mul_f32".into(), silu_mul_src);
        let compute_g_beta_src: &'static str = include_str!("shaders/compute_g_beta.metal");
        sources.insert("compute_g_beta_f32".into(), compute_g_beta_src);
        let ssm_norm_gate_src: &'static str = include_str!("shaders/ssm_norm_gate.metal");
        sources.insert("ssm_norm_gate_f32".into(), ssm_norm_gate_src);
        let gelu_src: &'static str = include_str!("shaders/gelu.metal");
        sources.insert("gelu_f32".into(), gelu_src);
        sources.insert("gelu_f16".into(), gelu_src);
        sources.insert("gelu_bf16".into(), gelu_src);
        let softmax_src: &'static str = include_str!("shaders/softmax.metal");
        sources.insert("softmax_f32".into(), softmax_src);
        sources.insert("softmax_f16".into(), softmax_src);
        sources.insert("softmax_bf16".into(), softmax_src);
        let softmax_backward_src: &'static str =
            include_str!("shaders/softmax_backward.metal");
        sources.insert("softmax_backward_f32".into(), softmax_backward_src);
        let log_elementwise_src: &'static str =
            include_str!("shaders/log_elementwise.metal");
        sources.insert("log_f32".into(), log_elementwise_src);
        sources.insert("log_backward_f32".into(), log_elementwise_src);
        let row_sum_src: &'static str = include_str!("shaders/row_sum.metal");
        sources.insert("row_sum_f32".into(), row_sum_src);
        sources.insert("row_sum_backward_f32".into(), row_sum_src);
        // ADR-020 iter-10a: GGUF-legacy quantize-dequantize round-trip kernels
        // (Q4_0 + Q8_0). Used by hf2q's dynamic_quant Track 1 to produce
        // W_low / W_high for the gradient-Taylor sensitivity formula.
        let qdq_legacy_src: &'static str = include_str!("shaders/qdq_legacy.metal");
        sources.insert("qdq_q4_0_f32".into(), qdq_legacy_src);
        sources.insert("qdq_q8_0_f32".into(), qdq_legacy_src);
        // ADR-020 iter-10b: RMSNorm reverse-mode autograd kernels.
        // The r_inv helper is reused by both backward kernels; dx and dw cover
        // the full backward identity for `y = x * rsqrt(mean(x²) + eps) * w`.
        let rms_norm_backward_src: &'static str =
            include_str!("shaders/rms_norm_backward.metal");
        sources.insert(
            "rms_norm_compute_rms_inv_f32".into(),
            rms_norm_backward_src,
        );
        sources.insert("rms_norm_backward_dx_f32".into(), rms_norm_backward_src);
        sources.insert("rms_norm_backward_dw_f32".into(), rms_norm_backward_src);
        // ADR-020 iter-11a: 2-D row-major slice + concat-by-column kernels.
        // Used by hf2q's multi-head SDPA on GpuTape (slice Q/K/V into
        // per-head views, run per-head SDPA, concat per-head contexts
        // back to full attention output).
        let slice_concat_2d_src: &'static str =
            include_str!("shaders/slice_concat_2d.metal");
        sources.insert("slice_2d_cols_f32".into(), slice_concat_2d_src);
        sources.insert("copy_2d_cols_into_f32".into(), slice_concat_2d_src);
        // ADR-020 iter-11b: SiLU forward + backward kernels for GpuTape
        // SwiGLU FFN composition.
        let silu_backward_src: &'static str =
            include_str!("shaders/silu_backward.metal");
        sources.insert("silu_f32".into(), silu_backward_src);
        sources.insert("silu_backward_f32".into(), silu_backward_src);
        // ADR-020 iter-11d: FP32 embedding lookup + scatter-add backward.
        let embedding_autograd_src: &'static str =
            include_str!("shaders/embedding_autograd.metal");
        sources.insert("embedding_lookup_f32".into(), embedding_autograd_src);
        sources.insert(
            "embedding_scatter_add_f32".into(),
            embedding_autograd_src,
        );
        // ADR-020 iter-13a: Adam optimizer step kernel for the Track 2
        // DWQ-proper training loop.
        let adam_update_src: &'static str =
            include_str!("shaders/adam_update.metal");
        sources.insert("adam_update_f32".into(), adam_update_src);
        // ADR-020 iter-13b: differentiable affine qdq kernels for the
        // DWQ-proper training loop. Init + forward + backward (scales,
        // biases) — q_int is FROZEN, scales+biases learnable.
        let qdq_affine_src: &'static str =
            include_str!("shaders/qdq_affine.metal");
        sources.insert("qdq_affine_init_f32".into(), qdq_affine_src);
        sources.insert("qdq_affine_forward_f32".into(), qdq_affine_src);
        sources.insert(
            "qdq_affine_backward_scales_f32".into(),
            qdq_affine_src,
        );
        sources.insert(
            "qdq_affine_backward_biases_f32".into(),
            qdq_affine_src,
        );
        // ADR-020 iter-15: fused affine quantized matmul for DWQ inference.
        // Per-element kernel; one thread per (m, n) output element.
        // The tiled + simdgroup-MMA variants land in iter-15b/15c.
        let qmm_affine_src: &'static str =
            include_str!("shaders/qmm_affine.metal");
        sources.insert("qmm_affine_t_f32".into(), qmm_affine_src);
        // ADR-020 iter-15b: tiled variant — 16x16 thread block with
        // cooperative-load X/W tiles in threadgroup-shared memory for
        // a 2-5x speedup over the per-element kernel.
        let qmm_affine_tiled_src: &'static str =
            include_str!("shaders/qmm_affine_tiled.metal");
        sources.insert(
            "qmm_affine_t_f32_tiled".into(),
            qmm_affine_tiled_src,
        );
        // ADR-020 iter-15c: simdgroup-MMA variant — uses Apple GPU
        // hardware `simdgroup_matrix<float, 8, 8>` MMA for the inner
        // reduction. Algorithmically 8× per tile over the scalar tiled
        // kernel; lands as ~3-4× wall-clock after launch / load amortization.
        let qmm_affine_simd_src: &'static str =
            include_str!("shaders/qmm_affine_simd.metal");
        sources.insert(
            "qmm_affine_t_f32_simd".into(),
            qmm_affine_simd_src,
        );
        // ADR-020 iter-15c-2: 4-simdgroup-per-TG variant — 32×32
        // output tile, 4 simdgroups arranged as a 2×2 grid each owning
        // a 16×16 sub-tile = 4 simdgroup_matrix accumulators. Same
        // math as 15c-1, with fuller warp-pool exploitation.
        let qmm_affine_simd4_src: &'static str =
            include_str!("shaders/qmm_affine_simd4.metal");
        sources.insert(
            "qmm_affine_t_f32_simd4".into(),
            qmm_affine_simd4_src,
        );
        // ADR-020 iter-15c-2b: gs=64 variant (mlx-lm dynamic_quant
        // canonical default). Same 4-simdgroup geometry, BK=64
        // instead of 32 (= 8 sub-K-tiles per K-step instead of 4).
        let qmm_affine_simd4_gs64_src: &'static str =
            include_str!("shaders/qmm_affine_simd4_gs64.metal");
        sources.insert(
            "qmm_affine_t_f32_simd4_gs64".into(),
            qmm_affine_simd4_gs64_src,
        );
        // ADR-020 AC#5 Iter A: packed-U32 dense affine matmul (bits=4,
        // gs=32) — production decode/prefill kernel for serving DWQ
        // safetensors directly without a load-time unpack pass.
        let qmm_affine_t_packed_simd4_b4_src: &'static str =
            include_str!("shaders/qmm_affine_t_packed_simd4_b4.metal");
        sources.insert(
            "qmm_affine_t_packed_simd4_b4".into(),
            qmm_affine_t_packed_simd4_b4_src,
        );
        // ADR-020 iter-11h-b: training-mode causal depthwise 1D
        // convolution (forward + backward dx + backward dw). Used by
        // GpuTape autograd for the differentiable Qwen3.5MoE forward
        // (GatedDeltaNet's conv1d step).
        let conv1d_dwc_src: &'static str =
            include_str!("shaders/conv1d_depthwise_causal.metal");
        sources.insert(
            "conv1d_depthwise_causal_forward_f32".into(),
            conv1d_dwc_src,
        );
        sources.insert(
            "conv1d_depthwise_causal_backward_dx_f32".into(),
            conv1d_dwc_src,
        );
        sources.insert(
            "conv1d_depthwise_causal_backward_dw_f32".into(),
            conv1d_dwc_src,
        );
        // ADR-020 iter-11h-c1: elementwise exp forward + backward.
        // Building block for GatedDeltaNet's alpha = exp(-g) state-decay.
        let exp_src: &'static str =
            include_str!("shaders/exp_elementwise.metal");
        sources.insert("exp_f32".into(), exp_src);
        sources.insert("exp_backward_f32".into(), exp_src);
        // ADR-020 iter-11h-c2: vector outer product (forward + dlhs +
        // drhs). Building block for gated_delta_update's
        // outer(delta, k) state-update term.
        let outer_src: &'static str =
            include_str!("shaders/outer_product.metal");
        sources.insert("outer_product_f32".into(), outer_src);
        sources.insert("outer_product_backward_lhs_f32".into(), outer_src);
        sources.insert("outer_product_backward_rhs_f32".into(), outer_src);
        // ADR-020 iter-11h-e1: take_along_axis (gather) + scatter-backward.
        // Building block for the MoE router on GpuTape.
        let taa_src: &'static str =
            include_str!("shaders/take_along_axis.metal");
        sources.insert("take_along_axis_f32".into(), taa_src);
        sources.insert("take_along_axis_backward_f32".into(), taa_src);
        // ADR-020 iter-11h-misc-1: elementwise divide forward + backward.
        let div_src: &'static str =
            include_str!("shaders/divide_elementwise.metal");
        sources.insert("divide_f32".into(), div_src);
        sources.insert("divide_backward_f32".into(), div_src);
        // ADR-020 iter-11h-misc-3: elementwise sqrt forward + backward.
        let sqrt_src: &'static str =
            include_str!("shaders/sqrt_elementwise.metal");
        sources.insert("sqrt_f32".into(), sqrt_src);
        sources.insert("sqrt_backward_f32".into(), sqrt_src);
        let softcap_src: &'static str = include_str!("shaders/softcap.metal");
        sources.insert("softcap_f32".into(), softcap_src);
        sources.insert("softcap_f16".into(), softcap_src);
        sources.insert("softcap_bf16".into(), softcap_src);

        // Fused norm-add kernels — Gemma4 post-attention / post-FFN ordering:
        //   normed = rms_norm(input, weight, eps); output = residual + normed
        let fused_norm_add_src: &'static str =
            include_str!("shaders/fused_norm_add_bf16.metal");
        sources.insert("fused_norm_add_bf16".into(), fused_norm_add_src);
        sources.insert("fused_norm_add_no_weight_bf16".into(), fused_norm_add_src);

        // Fused head-norm + RoPE f32 kernel — replaces separate rms_norm + rope_neox_f32
        let fused_hnr_f32_src: &'static str =
            include_str!("shaders/fused_head_norm_rope_f32.metal");
        sources.insert("fused_head_norm_rope_f32".into(), fused_hnr_f32_src);
        // ADR-028 iter-337 — float4 + simd_sum Phase 1 variant. Phases
        // 2-4 byte-identical to v1; race-fix barrier preserved. Env-gated
        // via HF2Q_FUSED_HEAD_NORM_ROPE_V2 (default ON, opt-out via =0).
        sources.insert("fused_head_norm_rope_f32_v2".into(), fused_hnr_f32_src);

        // Fused head-norm + RoPE bf16 kernels (single-token + batch prefill).
        // Both entry points live in the same .metal file.
        let fused_hnr_bf16_src: &'static str =
            include_str!("shaders/fused_head_norm_rope_bf16.metal");
        sources.insert("fused_head_norm_rope_bf16".into(), fused_hnr_bf16_src);
        sources.insert("fused_head_norm_rope_batch_bf16".into(), fused_hnr_bf16_src);

        // Fused norm-add f32 kernels — post-attention / post-FFN / end-of-layer
        let fused_norm_add_f32_src: &'static str =
            include_str!("shaders/fused_norm_add_f32.metal");
        sources.insert("fused_norm_add_f32".into(), fused_norm_add_f32_src);
        // ADR-028 iter-331 — float4 + simd_sum variant (peer-pattern,
        // ported from llama.cpp kernel_rms_norm_fuse_impl<float4, 3>).
        // Env-gated via HF2Q_FUSED_NORM_ADD_V2=1 in the dispatcher
        // (default ON since iter-331; opt-out via =0/false/off).
        sources.insert("fused_norm_add_f32_v2".into(), fused_norm_add_f32_src);
        sources.insert("fused_residual_norm_f32".into(), fused_norm_add_f32_src);
        sources.insert("fused_residual_norm_scalar_f32".into(), fused_norm_add_f32_src);
        sources.insert("fused_moe_routing_f32".into(), fused_norm_add_f32_src);
        // ADR-028 iter-363: V2 (simd_max + simd_sum) variant of MoE routing.
        sources.insert("fused_moe_routing_f32_v2".into(), fused_norm_add_f32_src);
        // ADR-029 iter-175 Step 1i: V3 = V2 + parallel SG-tournament top-K
        // (replaces V2's single-thread serial scan for k = 0..top_k).
        sources.insert("fused_moe_routing_f32_v3".into(), fused_norm_add_f32_src);
        sources.insert("fused_moe_routing_batch_f32".into(), fused_norm_add_f32_src);
        // ADR-029 iter-175 Step 1j: batched-prefill V3 (same parallel
        // SG-tournament top-K as fused_moe_routing_f32_v3, applied per-token
        // within each TG of the batched dispatch).
        sources.insert("fused_moe_routing_batch_f32_v3".into(), fused_norm_add_f32_src);
        sources.insert("fused_norm_add_scalar_f32".into(), fused_norm_add_f32_src);
        sources.insert("fused_moe_wsum_norm_add_f32".into(), fused_norm_add_f32_src);
        sources.insert("fused_moe_wsum_dnorm_add_f32".into(), fused_norm_add_f32_src);

        // Argsort kernel (Story 2.3) — MoE top-K routing
        let argsort_src: &'static str = include_str!("shaders/argsort.metal");
        sources.insert("argsort_desc_f32".into(), argsort_src);

        // Gather / index_select kernel (Story 2.4)
        let gather_src: &'static str = include_str!("shaders/gather.metal");
        sources.insert("gather_f32".into(), gather_src);

        // F32 KV cache copy kernel (Session merge S1+S2)
        let kv_cache_copy_src: &'static str =
            include_str!("shaders/kv_cache_copy.metal");
        sources.insert("kv_cache_copy".into(), kv_cache_copy_src);
        sources.insert("kv_cache_copy_f32".into(), kv_cache_copy_src);

        // Strided copy kernel (Story 2.5)
        let copy_src: &'static str = include_str!("shaders/copy.metal");
        sources.insert("strided_copy_f32".into(), copy_src);
        sources.insert("offset_copy_f32".into(), copy_src);

        // Fused-QKV split kernel (ADR-005 W-5b.18 — replaces hf2q's CPU
        // download → triple-loop split → 3× upload round-trip in
        // gpu_delta_net::layer_qkv_deinterleave).
        let qkv_split_src: &'static str = include_str!("shaders/qkv_split.metal");
        sources.insert("qkv_split_f32".into(), qkv_split_src);

        // Tiled-GQA broadcast kernel (ADR-005 W-5b.19 — replaces hf2q's CPU
        // tiled-replicate at gpu_delta_net::apply_gated_delta_net_chunk
        // GQA pre-expansion, ~497 ms / 10.4 ms-per-layer at PP4106).
        let repeat_tiled_src: &'static str =
            include_str!("shaders/repeat_tiled.metal");
        sources.insert("repeat_tiled_f32".into(), repeat_tiled_src);

        // Dense F16 GEMM kernel (Story 2.6) — lm_head projection
        let dense_gemm_src: &'static str = include_str!("shaders/dense_gemm.metal");
        sources.insert("dense_gemm_f16".into(), dense_gemm_src);
        sources.insert("dense_matvec_f16".into(), dense_gemm_src);
        sources.insert("dense_matvec_f16w_f32io".into(), dense_gemm_src);
        // BF16-weight mat-vec: BF16 weights × F32 input → F32 output (decode lm_head)
        sources.insert("dense_matvec_bf16w_f32io".into(), dense_gemm_src);
        // Pure F32 mat-vec: F32 weights × F32 input → F32 output (decode lm_head)
        sources.insert("dense_matvec_f32".into(), dense_gemm_src);

        // Standalone FWHT for TurboQuant pre/post-rotation (SIMD shuffle, zero barriers)
        let fwht_src: &'static str = include_str!("shaders/fwht_standalone.metal");
        sources.insert("fwht_standalone_f32_d256".into(), fwht_src);
        sources.insert("fwht_standalone_f32_d512".into(), fwht_src);
        // ADR-007 iter-14 D1 SRHT variants: sign pre-mult (for Q) + sign undo (for output)
        sources.insert("fwht_sign_premult_f32_d256".into(), fwht_src);
        sources.insert("fwht_sign_premult_f32_d512".into(), fwht_src);
        sources.insert("fwht_sign_undo_f32_d256".into(), fwht_src);
        sources.insert("fwht_sign_undo_f32_d512".into(), fwht_src);

        // Fast Hadamard quantize (SIMD shuffle, zero barriers)
        let hq_fast_src: &'static str = include_str!("shaders/hadamard_quantize_kv_fast.metal");
        sources.insert("hadamard_quantize_kv_fast_d256".into(), hq_fast_src);
        sources.insert("hadamard_quantize_kv_fast_d512".into(), hq_fast_src);
        // ADR-028 iter-485 (Phase 7d / H4): fused K+V single-position 4-bit encoder.
        sources.insert("hadamard_quantize_kv_fast_dual_d256".into(), hq_fast_src);
        sources.insert("hadamard_quantize_kv_fast_dual_d512".into(), hq_fast_src);
        // Track B (iter-21): higher-bit (5/6-bit) quantize kernels (byte-packed)
        sources.insert("hadamard_quantize_kv_hb_d256".into(), hq_fast_src);
        sources.insert("hadamard_quantize_kv_hb_d512".into(), hq_fast_src);
        // ADR-028 iter-148: fused K+V single-position HB encoder
        sources.insert("hadamard_quantize_kv_hb_dual_d256".into(), hq_fast_src);
        sources.insert("hadamard_quantize_kv_hb_dual_d512".into(), hq_fast_src);
        // ADR-028 Phase 10e.5 (iter-351): no-FWHT V quantize for the hybrid path.
        // Same byte-packed Lloyd-Max codebook output, but skips the Hadamard
        // rotation so dequant in SDPA recovers raw V (no FWHT-undo needed).
        sources.insert("kv_quantize_v_no_fwht_d256".into(), hq_fast_src);
        sources.insert("kv_quantize_v_no_fwht_d512".into(), hq_fast_src);
        // ADR-028 Phase 10c.5 (iter-354): fused F16-K-copy + V-no-FWHT-encode.
        // Saves 30 KV-write dispatches per decode token at gemma4 30L by
        // combining the per-layer K-cast and V-encode into a single dispatch
        // (Z-dim).
        sources.insert("kv_copy_kf16_quantize_v_no_fwht_d256".into(), hq_fast_src);
        sources.insert("kv_copy_kf16_quantize_v_no_fwht_d512".into(), hq_fast_src);

        // iter-20 Leg F: TQ KV dequantize kernel (nibbles+norms → F32)
        let tq_dq_src: &'static str = include_str!("shaders/tq_dequantize_kv.metal");
        sources.insert("tq_dequantize_kv".into(), tq_dq_src);
        // Track B (iter-21): higher-bit dequantize kernel (byte-packed indices)
        sources.insert("tq_dequantize_hb_kv".into(), tq_dq_src);
        // ADR-027 Phase B iter-30 (hf2q sub-sub-iter 23c-β.1): sequence-batch
        // dequant variant. Same MSL source; the new kernel entry point
        // `tq_dequantize_hb_kv_seq` reads positions [start_pos..start_pos+n_tokens)
        // in one dispatch (one threadgroup per (kv_head, position)). Unblocks
        // hf2q's TQ-aware prefill SDPA path (the per-position kernel would
        // require cur_len separate dispatches).
        sources.insert("tq_dequantize_hb_kv_seq".into(), tq_dq_src);

        // iter-24: native higher-bit (5/6/8-bit) TQ SDPA kernel (byte-packed K/V)
        let tq_hb_src: &'static str = include_str!("shaders/flash_attn_vec_tq_hb.metal");
        sources.insert("flash_attn_vec_tq_hb_dk256".into(), tq_hb_src);
        sources.insert("flash_attn_vec_tq_hb_dk512".into(), tq_hb_src);

        // ADR-028 §iter-485 (Phase 7d H3): fused TQ-HB reduce + FWHT-sign-undo.
        // Combines flash_attn_vec_reduce + fwht_sign_undo_f32 into a single
        // dispatch, saving 1 dispatch + 1 forced barrier per layer per decode
        // token. Gated by the env flag `HF2Q_TQ_HB_OUT_FUSED=1` in forward_mlx.rs.
        let reduce_undo_src: &'static str =
            include_str!("shaders/flash_attn_vec_reduce_tq_hb_undo.metal");
        sources.insert("flash_attn_vec_reduce_tq_hb_undo_dk256".into(), reduce_undo_src);
        sources.insert("flash_attn_vec_reduce_tq_hb_undo_dk512".into(), reduce_undo_src);

        // ADR-028 Phase 10d (iter-349): hybrid F16-K + TQ-HB-V SDPA kernel.
        // Same V-side codebook as flash_attn_vec_tq_hb (5/6/8-bit Lloyd-Max);
        // K-side reads F16 dense — peer-equivalent layout, no codebook lookup.
        let hybrid_src: &'static str = include_str!("shaders/flash_attn_vec_hybrid.metal");
        sources.insert("flash_attn_vec_hybrid_dk256".into(), hybrid_src);
        sources.insert("flash_attn_vec_hybrid_dk512".into(), hybrid_src);

        // ADR-029 CFA cfa-20260512-fa-peer-port (iter-122): verbatim llama.cpp peer port.
        // F16-K + F16-V, DK=DV=256, NWG=1, NSG=1, NE=1. No function constants — baked.
        let peer_port_src: &'static str = include_str!("shaders/flash_attn_vec_peer_port_f16.metal");
        sources.insert("flash_attn_vec_peer_port_f16_dk256_dv256".into(), peer_port_src);

        // ADR-029 iter-134: peer reduce kernel (verbatim port of ggml-metal.metal 7235-7275).
        // Pairs with the NWG=32 vec kernel to match peer's actual runtime dispatch.
        let peer_port_reduce_src: &'static str =
            include_str!("shaders/flash_attn_vec_peer_port_f16_reduce.metal");
        sources.insert(
            "flash_attn_vec_peer_port_f16_reduce_dv256_nwg32".into(),
            peer_port_reduce_src,
        );

        // ADR-029 iter-135: NWG=32 variant of the verbatim peer port. Same body as
        // flash_attn_vec_peer_port_f16.metal with NWG=1→32. Pairs with the
        // iter-134 reduce kernel.
        let peer_port_nwg32_src: &'static str =
            include_str!("shaders/flash_attn_vec_peer_port_f16_nwg32.metal");
        sources.insert(
            "flash_attn_vec_peer_port_f16_nwg32_dk256_dv256".into(),
            peer_port_nwg32_src,
        );

        // GPU sampling kernels — eliminate the logits readback (Phase 6)
        let argmax_src: &'static str = include_str!("shaders/argmax.metal");
        sources.insert("argmax_f32".into(), argmax_src);
        let softmax_sample_src: &'static str =
            include_str!("shaders/softmax_sample.metal");
        sources.insert("softmax_sample_f32".into(), softmax_sample_src);
        // Top-K kernel for Q8 rerank: avoids the full-logits readback.
        let top_k_src: &'static str = include_str!("shaders/top_k.metal");
        sources.insert("top_k_f32".into(), top_k_src);

        // MoE GPU routing + weighted reduce (ADR-013 P13.3 perf).
        // Replaces the CPU softmax+topk round-trip and CPU weighted accumulate.
        let moe_stk_src: &'static str =
            include_str!("shaders/moe_softmax_topk.metal");
        sources.insert("moe_softmax_topk_f32".into(), moe_stk_src);
        let moe_wr_src: &'static str =
            include_str!("shaders/moe_weighted_reduce.metal");
        sources.insert("moe_weighted_reduce_f32".into(), moe_wr_src);
        let sdpa_decode_src: &'static str =
            include_str!("shaders/sdpa_decode.metal");
        sources.insert("sdpa_decode".into(), sdpa_decode_src);

        Self {
            cache: HashMap::new(),
            sources,
            precompiled_lib: None,
            precompiled_load_attempted: false,
        }
    }

    /// Try to obtain the precompiled `default.metallib` library, loading it
    /// lazily on first call. Returns `None` when:
    /// - the fast path is opted out via `MLX_PRECOMPILED_METALLIB=0` (it is
    ///   default-ON; see [`precompiled_enabled`])
    /// - the embedded blob is empty (build.rs skipped the metallib build)
    /// - `device.new_library_with_data` failed on the first load attempt
    ///   (failures are cached; there is no retry)
    fn try_precompiled_lib(
        &mut self,
        device: &metal::DeviceRef,
    ) -> Option<&metal::LibraryRef> {
        if !precompiled_enabled() {
            return None;
        }
        if !self.precompiled_load_attempted {
            self.precompiled_load_attempted = true;
            if EMBEDDED_METALLIB.is_empty() {
                return None;
            }
            // Apple's `newLibraryWithData:` expects a dispatch_data_t;
            // metal-rs wraps this via `new_library_with_data`, which takes
            // a `&[u8]`.
            match device.new_library_with_data(EMBEDDED_METALLIB) {
                Ok(lib) => self.precompiled_lib = Some(lib),
                Err(_) => self.precompiled_lib = None,
            }
        }
        self.precompiled_lib.as_deref()
    }

    /// Register a shader source at runtime (useful for testing and dynamic
    /// kernel generation). Re-registering a name invalidates any pipeline
    /// already cached under it.
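    ///
    /// A minimal usage sketch (the kernel name and MSL body here are
    /// illustrative, not production kernels; a real Metal device is needed):
    ///
    /// ```ignore
    /// let device = metal::Device::system_default().expect("Metal device");
    /// let mut registry = KernelRegistry::new();
    /// registry.register_source(
    ///     "double_f32",
    ///     r#"
    /// #include <metal_stdlib>
    /// using namespace metal;
    /// kernel void double_f32(device float* buf [[buffer(0)]],
    ///                        uint tid [[thread_position_in_grid]]) {
    ///     buf[tid] *= 2.0f;
    /// }
    /// "#,
    /// );
    /// // First call compiles and caches; later calls return the cached PSO.
    /// let _pso = registry.get_pipeline("double_f32", &device).expect("compile");
    /// ```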
    pub fn register_source(&mut self, name: impl Into<String>, source: &'static str) {
        let name = name.into();
        // Invalidate any cached pipeline for this name since the source changed.
        self.cache.remove(&name);
        self.sources.insert(name, source);
    }

    /// Get a compiled compute pipeline for the named kernel function.
    ///
    /// On first call for a given name, this compiles the MSL source into a
    /// Metal library, extracts the named function, and creates a
    /// `ComputePipelineState`. Subsequent calls return the cached pipeline.
    ///
    /// # Errors
    ///
    /// * `MlxError::KernelNotFound` — no source registered for this name.
    /// * `MlxError::ShaderCompilationError` — MSL compilation or pipeline
    ///   creation failed.
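    ///
    /// A lookup sketch (`argmax_f32` is one of the sources registered in
    /// `new()`; a real Metal device is required):
    ///
    /// ```ignore
    /// let device = metal::Device::system_default().expect("Metal device");
    /// let mut registry = KernelRegistry::new();
    /// let _pso = registry.get_pipeline("argmax_f32", &device).expect("compile");
    /// assert!(registry.is_cached("argmax_f32")); // second call is a cache hit
    /// ```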
    pub fn get_pipeline(
        &mut self,
        name: &str,
        device: &metal::DeviceRef,
    ) -> Result<&ComputePipelineState> {
        if !self.cache.contains_key(name) {
            // ADR-029 iter-175 Step 1l: precompiled .metallib fast path.
            // When the fast path is enabled (default-ON; see
            // `precompiled_enabled`) AND the kernel exists in the embedded
            // library, use it. Otherwise fall through to runtime source
            // compile. Empirically ~+5.89% faster (iter 1k bench; see the
            // module docs for the measured shape).
            let precompiled_function = self
                .try_precompiled_lib(device)
                .and_then(|lib| lib.get_function(name, None).ok());

            let function = match precompiled_function {
                Some(f) => f,
                None => {
                    // Slow path: compile the shader from source.
                    let source = self.sources.get(name).ok_or_else(|| {
                        MlxError::KernelNotFound(name.to_string())
                    })?;

                    let compile_opts = metal::CompileOptions::new();
                    let library = device
                        .new_library_with_source(source, &compile_opts)
                        .map_err(|msg| MlxError::ShaderCompilationError {
                            name: name.to_string(),
                            message: msg,
                        })?;

                    library
                        .get_function(name, None)
                        .map_err(|msg| MlxError::ShaderCompilationError {
                            name: name.to_string(),
                            message: msg,
                        })?
                }
            };

            // Build the pipeline through a descriptor so we can attach a
            // human-readable label. The label propagates into Instruments /
            // xctrace Metal System Trace as the per-pipeline identifier
            // (`metal-object-label` schema), giving us per-kernel attribution
            // instead of the generic "Compute Command 0" placeholder.
            //
            // `MTLComputePipelineState.label` is read-only after creation per
            // the Apple Metal spec; the only supported way to set it is via
            // the descriptor before pipeline creation. ADR-015 iter9b.
            let descriptor = ComputePipelineDescriptor::new();
            descriptor.set_compute_function(Some(&function));
            descriptor.set_label(name);
            // ADR-028 iter-376: the threadGroupSizeIsMultipleOfThreadExecutionWidth
            // hint lets the Metal compiler skip bounds checks and use more
            // aggressive codegen. Opt-in via HF2Q_PIPELINE_TG_MULT_HINT=1.
            // SAFETY: every dispatched threadgroup MUST then be a multiple of 32
            // at runtime — Apple specifies undefined behavior otherwise. Our hot
            // kernels use tg_size ∈ {32, 64, 256, 1024} (all multiples of 32).
            if std::env::var("HF2Q_PIPELINE_TG_MULT_HINT").ok().as_deref() == Some("1") {
                descriptor.set_thread_group_size_is_multiple_of_thread_execution_width(true);
            }
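
            // Illustrative contract at a hypothetical call site (not code in
            // this file): with the hint set, a dispatch such as
            //     encoder.dispatch_thread_groups(grid, MTLSize::new(256, 1, 1));
            // is fine, while a threads-per-threadgroup width like 48 (not a
            // multiple of the 32-wide execution width) would be undefined
            // behavior.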

            let pipeline = device
                .new_compute_pipeline_state(&descriptor)
                .map_err(|msg| MlxError::ShaderCompilationError {
                    name: name.to_string(),
                    message: msg,
                })?;

            self.cache.insert(name.to_string(), pipeline);
        }

        // At this point the pipeline is guaranteed to be in the cache.
        // We use `ok_or_else` instead of `expect` to satisfy the no-panic policy.
        self.cache.get(name).ok_or_else(|| {
            MlxError::KernelNotFound(name.to_string())
        })
    }

    /// Get a compiled compute pipeline for the named kernel, specialized with
    /// Metal function constants (both bool and i32 in one call).
    ///
    /// `bool_constants` contains `(index, value)` pairs mapping to
    /// `[[function_constant(index)]]` bool declarations in the MSL shader.
    /// `int_constants` contains `(index, value)` pairs mapping to
    /// `[[function_constant(index)]]` int (`int32_t`) declarations in the MSL
    /// shader.
    ///
    /// Pipelines are cached by a composite key:
    /// `"<name>|<index>:b<0|1>|...|<index>:i<value>|..."`. The 'b' prefix
    /// marks bool entries and the 'i' prefix marks i32 entries, making the
    /// format unambiguous regardless of constant ordering. Distinct
    /// `(name, constants)` combinations each compile to a separate pipeline;
    /// the slow compilation path runs at most once per unique combination.
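    ///
    /// For example (constants illustrative), the call
    /// `get_pipeline_with_constants("kernel_mul_mv_q4_0_f32", dev, &[(0, true)], &[(3, 32)])`
    /// caches under the key `"kernel_mul_mv_q4_0_f32|0:b1|3:i32"`.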
    ///
    /// # Errors
    ///
    /// * `MlxError::KernelNotFound` — no source registered for this name.
    /// * `MlxError::ShaderCompilationError` — MSL compilation, function
    ///   specialisation, or pipeline creation failed.
    pub fn get_pipeline_with_constants(
        &mut self,
        name: &str,
        device: &metal::DeviceRef,
        bool_constants: &[(usize, bool)],
        int_constants: &[(usize, i32)],
    ) -> Result<&ComputePipelineState> {
        // Build a composite cache key so distinct constant combinations each
        // compile to their own pipeline. Bool entries use the 'b' type marker
        // and i32 entries use 'i'; this prevents a collision between, e.g.,
        // bool index 5 value 1 and int index 5 value 1.
        let mut cache_key = name.to_string();
        for &(index, value) in bool_constants {
            cache_key.push('|');
            cache_key.push_str(&index.to_string());
            cache_key.push_str(if value { ":b1" } else { ":b0" });
        }
        for &(index, value) in int_constants {
            cache_key.push('|');
            cache_key.push_str(&index.to_string());
            cache_key.push(':');
            cache_key.push('i');
            cache_key.push_str(&value.to_string());
        }

        if !self.cache.contains_key(&cache_key) {
            // Build the FunctionConstantValues object with all bool and i32
            // constants. Metal's set_constant_value_at_index reads the value
            // through a raw pointer; the pointed-to bytes must match the size
            // declared in the MSL shader (1 byte for bool, 4 bytes for int).
            let fcv = FunctionConstantValues::new();

            for &(index, value) in bool_constants {
                // MTLDataType::Bool = 53 (metal-rs argument.rs).
                // The Metal runtime reads it as an Objective-C BOOL (uint8_t).
                let v: u8 = if value { 1 } else { 0 };
                fcv.set_constant_value_at_index(
                    (&v as *const u8).cast::<std::ffi::c_void>(),
                    MTLDataType::Bool,
                    index as u64,
                );
            }

            for &(index, value) in int_constants {
                // MTLDataType::Int = 29 (metal-rs argument.rs).
                // The Metal runtime reads 4 bytes as a signed 32-bit integer,
                // matching the Metal shader type `constant int`.
                fcv.set_constant_value_at_index(
                    (&value as *const i32).cast::<std::ffi::c_void>(),
                    MTLDataType::Int,
                    index as u64,
                );
            }

            // ADR-029 iter-175 Step 1l: try the precompiled .metallib first.
            // Step 1m: gated separately via MLX_PRECOMPILED_METALLIB_FCV
            // (default-ON; set =0 to opt out) so we can isolate whether a
            // regression like the one suspected at Step 1l (tg50 95.4 → 62.1)
            // lives in this FCV path vs the no-FCV path in `get_pipeline`.
            //
            // get_function with FCV takes ownership of the FCV, so we build
            // a separate one for the precompiled probe.
            let precompiled_function = if precompiled_fcv_enabled() {
                let probe_fcv = FunctionConstantValues::new();
                for &(index, value) in bool_constants {
                    let v: u8 = if value { 1 } else { 0 };
                    probe_fcv.set_constant_value_at_index(
                        (&v as *const u8).cast::<std::ffi::c_void>(),
                        MTLDataType::Bool,
                        index as u64,
                    );
                }
                for &(index, value) in int_constants {
                    probe_fcv.set_constant_value_at_index(
                        (&value as *const i32).cast::<std::ffi::c_void>(),
                        MTLDataType::Int,
                        index as u64,
                    );
                }
                self.try_precompiled_lib(device)
                    .and_then(|lib| lib.get_function(name, Some(probe_fcv)).ok())
            } else {
                None
            };

            let function = match precompiled_function {
                Some(f) => f,
                None => {
                    // Slow path: compile the shader with function-constant
                    // specialisation.
                    let source = self.sources.get(name).ok_or_else(|| {
                        MlxError::KernelNotFound(name.to_string())
                    })?;

                    let compile_opts = metal::CompileOptions::new();
                    let library = device
                        .new_library_with_source(source, &compile_opts)
                        .map_err(|msg| MlxError::ShaderCompilationError {
                            name: name.to_string(),
                            message: msg,
                        })?;

                    library
                        .get_function(name, Some(fcv))
                        .map_err(|msg| MlxError::ShaderCompilationError {
                            name: name.to_string(),
                            message: msg,
                        })?
                }
            };

            // Label this specialisation with the full composite cache key
            // (e.g. `kernel_mul_mv_q4_0_f32|0:b1|3:i32`) so xctrace Metal
            // System Trace shows each function-constant variant as a distinct
            // pipeline. Without this, all specialisations share a generic
            // "Compute Command 0" identifier and we cannot attribute µs/token
            // to a specific (kernel, constants) combination. ADR-015 iter9b.
            let descriptor = ComputePipelineDescriptor::new();
            descriptor.set_compute_function(Some(&function));
            descriptor.set_label(&cache_key);
            // ADR-028 iter-376: same hint as the primary pipeline path.
            if std::env::var("HF2Q_PIPELINE_TG_MULT_HINT").ok().as_deref() == Some("1") {
                descriptor.set_thread_group_size_is_multiple_of_thread_execution_width(true);
            }

            let pipeline = device
                .new_compute_pipeline_state(&descriptor)
                .map_err(|msg| MlxError::ShaderCompilationError {
                    name: name.to_string(),
                    message: msg,
                })?;

            self.cache.insert(cache_key.clone(), pipeline);
        }

        self.cache.get(&cache_key).ok_or_else(|| {
            MlxError::KernelNotFound(name.to_string())
        })
    }

    /// Get a compiled compute pipeline for the named kernel, specialized with
    /// Metal bool function constants.
    ///
    /// The `bool_constants` slice contains `(index, value)` pairs. Each pair
    /// maps to a `[[function_constant(index)]]` declaration in the MSL shader.
    ///
    /// This is a thin wrapper around [`get_pipeline_with_constants`] that
    /// passes an empty `int_constants` slice. Existing callers continue to
    /// work without modification; the cache-key format for pure-bool pipelines
    /// is compatible (bool entries carry the 'b' type marker, which is the
    /// only format ever written by this wrapper).
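    ///
    /// For example (kernel name illustrative), the call
    /// `get_pipeline_with_bool_constants("my_kernel", dev, &[(0, true)])`
    /// is equivalent to
    /// `get_pipeline_with_constants("my_kernel", dev, &[(0, true)], &[])`
    /// and caches under the key `"my_kernel|0:b1"`.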
    ///
    /// # Errors
    ///
    /// * `MlxError::KernelNotFound` — no source registered for this name.
    /// * `MlxError::ShaderCompilationError` — MSL compilation, function
    ///   specialisation, or pipeline creation failed.
    pub fn get_pipeline_with_bool_constants(
        &mut self,
        name: &str,
        device: &metal::DeviceRef,
        bool_constants: &[(usize, bool)],
    ) -> Result<&ComputePipelineState> {
        self.get_pipeline_with_constants(name, device, bool_constants, &[])
    }

    /// Check if a pipeline for the given name is already compiled and cached.
    pub fn is_cached(&self, name: &str) -> bool {
        self.cache.contains_key(name)
    }

    /// Number of compiled pipelines currently in the cache.
    pub fn cached_count(&self) -> usize {
        self.cache.len()
    }

    /// Number of registered shader sources.
    pub fn source_count(&self) -> usize {
        self.sources.len()
    }
}

impl Default for KernelRegistry {
    fn default() -> Self {
        Self::new()
    }
}

1424#[cfg(test)]
1425mod tests {
1426 use super::*;
1427
1428 /// Minimal Metal shader that uses a single int function constant.
1429 ///
1430 /// The kernel writes the constant value N into the first element of the
1431 /// output buffer, allowing the test to verify that the Metal compiler
1432 /// actually sees distinct specialisations for N=4 and N=8.
1433 ///
1434 /// The shader is intentionally trivial — we only need it to *compile* with
1435 /// an int function constant; correctness of the kernel logic is not under
1436 /// test here.
1437 const INT_FC_TEST_SHADER: &str = r#"
1438#include <metal_stdlib>
1439using namespace metal;
1440
1441constant int test_N [[function_constant(100)]];
1442
1443kernel void int_fc_test_kernel(
1444 device int* out [[buffer(0)]],
1445 uint tid [[thread_position_in_grid]])
1446{
1447 if (tid == 0) {
1448 out[0] = test_N;
1449 }
1450}
1451"#;

    /// Verify that `get_pipeline_with_constants` produces distinct cached
    /// pipelines for different i32 function-constant values, and that
    /// `get_pipeline_with_bool_constants` (the backward-compat wrapper) still
    /// works correctly with the 'b'-prefixed cache-key format.
    ///
    /// This test requires a real Metal device; on a machine without one it
    /// fails fast at the `system_default()` `expect` below.
    #[test]
    fn test_int_fc_distinct_pipelines_and_bool_compat() {
        let device = metal::Device::system_default()
            .expect("no Metal device — run on Apple Silicon or x86 Mac with Metal support");

        let mut registry = KernelRegistry::new();

        // Register the inline test shader under a name that cannot collide
        // with any production kernel.
        registry.register_source("int_fc_test_kernel", INT_FC_TEST_SHADER);

        // Compile with N=4.
        let p4_ptr = registry
            .get_pipeline_with_constants(
                "int_fc_test_kernel",
                &device,
                &[],             // no bool constants
                &[(100, 4_i32)], // int constant index 100 = 4
            )
            .expect("pipeline N=4 should compile") as *const _;

        // Record the cache size after the N=4 compile as a baseline; the
        // assertions below check growth relative to it rather than absolute
        // counts.
        let count_after_n4 = registry.cached_count();

        // Compile with N=8 — must produce a SEPARATE pipeline.
        let p8_ptr = registry
            .get_pipeline_with_constants(
                "int_fc_test_kernel",
                &device,
                &[],
                &[(100, 8_i32)],
            )
            .expect("pipeline N=8 should compile") as *const _;

        // Cache must have grown by exactly 1.
        assert_eq!(
            registry.cached_count(),
            count_after_n4 + 1,
            "N=8 must produce a new cache entry"
        );

        // The two pipelines must be distinct objects in the cache.
        assert_ne!(
            p4_ptr, p8_ptr,
            "N=4 and N=8 specialisations must be separate ComputePipelineState objects"
        );

        // A second call with N=4 must return the SAME pipeline (cache hit, no
        // new compilation).
        let p4_again_ptr = registry
            .get_pipeline_with_constants(
                "int_fc_test_kernel",
                &device,
                &[],
                &[(100, 4_i32)],
            )
            .expect("pipeline N=4 cache hit should succeed") as *const _;

        assert_eq!(
            registry.cached_count(),
            count_after_n4 + 1,
            "repeated N=4 call must be a cache hit, not a new entry"
        );
        assert_eq!(
            p4_ptr, p4_again_ptr,
            "repeated N=4 call must return the same pipeline pointer"
        );

        // Verify backward compatibility: get_pipeline_with_bool_constants must
        // route through get_pipeline_with_constants and produce a cached
        // pipeline without panicking. The call path and cache-key format are
        // what matter here, so we register a separate bare kernel that
        // declares no function constants at all — reusing int_fc_test_kernel
        // would trip a Metal error for its unset int function constant.
        const BARE_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void bare_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
    if (tid == 0) { out[0] = 42; }
}
"#;
        registry.register_source("bare_kernel", BARE_SHADER);

        let count_before_bool = registry.cached_count();
        let _bool_pipeline = registry
            .get_pipeline_with_bool_constants("bare_kernel", &device, &[])
            .expect("bool-constants wrapper with empty slice must succeed");

        assert_eq!(
            registry.cached_count(),
            count_before_bool + 1,
            "bool-constants wrapper must insert one new cache entry"
        );
    }

    /// Verify that the `MTLComputePipelineState.label` produced by
    /// `get_pipeline` and `get_pipeline_with_constants` actually propagates
    /// from the descriptor to the resulting pipeline state.
    ///
    /// This is the in-process smoke check for ADR-015 iter9b: we cannot
    /// reach into xctrace from Rust, but we can read back the same `label`
    /// property xctrace consumes via `ComputePipelineStateRef::label()`.
    /// If labels are missing or wrong here, the MST trace will also show
    /// generic identifiers — so this test gates the iter9 retry's
    /// per-Q4_0-kernel attribution.
    #[test]
    fn test_pipeline_labels_propagate_for_mst() {
        let device = metal::Device::system_default()
            .expect("no Metal device — run on Apple Silicon or x86 Mac with Metal support");

        let mut registry = KernelRegistry::new();

        // Reuse the same trivial shaders as the int-FC test.
        registry.register_source("int_fc_test_kernel", INT_FC_TEST_SHADER);

        const BARE_SHADER_LABEL_TEST: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void label_smoke_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
    if (tid == 0) { out[0] = 7; }
}
"#;
        registry.register_source("label_smoke_kernel", BARE_SHADER_LABEL_TEST);

        // Plain get_pipeline path — the label must equal the kernel name.
        // Capture it as an owned String so the cache borrow is released
        // before the get_pipeline_with_constants call below.
        let plain_label = registry
            .get_pipeline("label_smoke_kernel", &device)
            .expect("plain pipeline must compile")
            .label()
            .to_string();
        assert_eq!(
            plain_label, "label_smoke_kernel",
            "get_pipeline must label the pipeline with the kernel name (xctrace MST attribution)"
        );

        // Constants path — the label must equal the composite cache key so
        // each function-constant variant is individually attributable in MST.
        // Again captured as an owned String to release the cache borrow
        // before fetching the next specialisation.
        let label_v7 = registry
            .get_pipeline_with_constants(
                "int_fc_test_kernel",
                &device,
                &[],
                &[(100, 7_i32)],
            )
            .expect("specialised pipeline must compile")
            .label()
            .to_string();
        assert_eq!(
            label_v7, "int_fc_test_kernel|100:i7",
            "get_pipeline_with_constants must label with the cache_key so each \
             specialisation is distinct in xctrace MST"
        );

        // A second specialisation must produce a different label.
        let label_v13 = registry
            .get_pipeline_with_constants(
                "int_fc_test_kernel",
                &device,
                &[],
                &[(100, 13_i32)],
            )
            .expect("second specialised pipeline must compile")
            .label()
            .to_string();
        assert_eq!(label_v13, "int_fc_test_kernel|100:i13");
        assert_ne!(
            label_v7, label_v13,
            "distinct constant values must yield distinct pipeline labels"
        );
    }
}