Skip to main content

mlx_native/
kernel_registry.rs

1//! [`KernelRegistry`] — lazy compilation and caching of Metal compute pipelines.
2//!
3//! MSL shader source is embedded at compile time via `include_str!`.  On first
4//! access, the source is compiled into a Metal library, the named function is
5//! extracted, and a `ComputePipelineState` is created and cached.  Subsequent
6//! calls return the cached pipeline.
7
8use std::collections::HashMap;
9
10use metal::{ComputePipelineDescriptor, ComputePipelineState, FunctionConstantValues, MTLDataType};
11
12use crate::error::{MlxError, Result};
13
14// MTLDataType numeric values (from metal-rs argument.rs, confirmed in Apple Metal spec):
15//   Int  = 29
16//   Bool = 53
17// These are used when calling set_constant_value_at_index so the Metal runtime
18// knows how wide each constant value is.
19
20/// Registry that lazily compiles and caches Metal compute pipelines from
21/// embedded MSL source.
22///
23/// # Usage
24///
25/// ```ignore
26/// let mut registry = KernelRegistry::new();
27/// let pipeline = registry.get_pipeline("elementwise_add", device.metal_device())?;
28/// encoder.encode(&pipeline, &buffers, grid, tg);
29/// ```
30///
31/// # Thread Safety
32///
33/// `KernelRegistry` is **not** `Sync` by default (it uses `&mut self` for
34/// `get_pipeline` to allow mutable cache insertion).  If you need concurrent
35/// access, wrap it in a `Mutex` or use one registry per thread.
36pub struct KernelRegistry {
37    /// Cached pipelines keyed by kernel function name.
38    cache: HashMap<String, ComputePipelineState>,
39    /// MSL source text keyed by kernel function name.
40    ///
41    /// Populated at construction time with all embedded shader sources.
42    sources: HashMap<String, &'static str>,
43}
44
45impl KernelRegistry {
46    /// Create a new registry with all embedded shader sources pre-registered.
47    ///
48    /// No compilation happens here — shaders are compiled lazily on first use.
49    pub fn new() -> Self {
50        let mut sources = HashMap::new();
51
52        // Register embedded shader sources.
53        sources.insert(
54            "placeholder".into(),
55            include_str!("shaders/placeholder.metal"),
56        );
57        sources.insert(
58            "quantized_matmul".into(),
59            include_str!("shaders/quantized_matmul.metal"),
60        );
61        sources.insert(
62            "quantized_matmul_simd".into(),
63            include_str!("shaders/quantized_matmul.metal"),
64        );
65        sources.insert(
66            "quantized_matmul_simd_bf16".into(),
67            include_str!("shaders/quantized_matmul.metal"),
68        );
69        sources.insert(
70            "quantized_matmul_simd_bf16_expert".into(),
71            include_str!("shaders/quantized_matmul.metal"),
72        );
73
74        // GGML block-format quantized mat-vec kernels (ADR-006 Phase 3)
75        let ggml_src: &'static str =
76            include_str!("shaders/quantized_matmul_ggml.metal");
77        sources.insert("kernel_mul_mv_q4_0_f32".into(), ggml_src);
78        sources.insert("kernel_mul_mv_q8_0_f32".into(), ggml_src);
79        // ADR-028 iter-368: peer-style NSG=4 NR=2 variant (128 threads/TG).
80        sources.insert("kernel_mul_mv_q8_0_f32_nr2".into(), ggml_src);
81        sources.insert("kernel_mul_mv_q6_K_f32".into(), ggml_src);
82        // ADR-028 iter-309 — q6_K mat-vec with nr0=2 + cached yl[16]
83        // (peer-pattern port of llama.cpp's `kernel_mul_mv_q6_K_f32_impl`
84        // with N_R0_Q6_K=2; 4 rows/TG vs baseline's 2).  Env-gated via
85        // `HF2Q_Q6K_MV_NR2=1` in the dispatcher.
86        sources.insert("kernel_mul_mv_q6_K_f32_nr2".into(), ggml_src);
87        // ADR-022 Phase 1 — Q5_1 / IQ4_NL dense mat-vec.
88        sources.insert("kernel_mul_mv_q5_1_f32".into(), ggml_src);
89        sources.insert("kernel_mul_mv_iq4_nl_f32".into(), ggml_src);
90        // ADR-013 P7 — Q4_K dense decode mat-vec (port of llama.cpp's
91        // kernel_mul_mv_q4_K_f32 at ggml-metal.metal:7715-7821).
92        sources.insert("kernel_mul_mv_q4_K_f32".into(), ggml_src);
93        // ADR-022 Phase 2 — Q5_K dense mv kernel.
94        sources.insert("kernel_mul_mv_q5_K_f32".into(), ggml_src);
95
96        // GGML block-format quantized matrix-matrix kernels
97        // (ADR-011 Phase 3 Wave P3a: port of llama.cpp's kernel_mul_mm_<q>_f32).
98        // Used at prefill m > 8 to reuse each weight tile across a 32-row
99        // block via threadgroup-staged simdgroup MMA, instead of re-reading
100        // every block per prompt-token as the mv kernel does.
101        let ggml_mm_src: &'static str =
102            include_str!("shaders/quantized_matmul_mm.metal");
103        sources.insert("kernel_mul_mm_q4_0_f32".into(), ggml_mm_src);
104        sources.insert("kernel_mul_mm_q8_0_f32".into(), ggml_mm_src);
105        sources.insert("kernel_mul_mm_q6_K_f32".into(), ggml_mm_src);
106        // ADR-022 Phase 1 — dense Q5_1 / IQ4_NL mm.
107        sources.insert("kernel_mul_mm_q5_1_f32".into(), ggml_mm_src);
108        sources.insert("kernel_mul_mm_iq4_nl_f32".into(), ggml_mm_src);
109        // ADR-022 Phase 2 — dense Q5_K mm.
110        sources.insert("kernel_mul_mm_q5_K_f32".into(), ggml_mm_src);
111        // ADR-022 Phase 3 — dense Q4_K mm.
112        sources.insert("kernel_mul_mm_q4_K_f32".into(), ggml_mm_src);
113
114        // GGML block-format quantized matrix-matrix kernels — tensor API
115        // variant (ADR-011 Phase 3 Wave P3b-tensor: port of llama.cpp's
116        // kernel_mul_mm_impl `#ifdef GGML_METAL_HAS_TENSOR` branch).
117        // Uses Apple's MetalPerformancePrimitives `tensor_ops::matmul2d`
118        // primitive which on M3+ dispatches to hardware tensor cores for
119        // 2-3x the effective FLOP throughput vs the simdgroup MMA path.
120        // Only compiled on devices where the tensor API is available; the
121        // kernel_registry's runtime-probe (see MlxDevice::has_tensor) gates
122        // compilation so non-tensor devices transparently fall back to the
123        // non-tensor `kernel_mul_mm_<q>_f32` kernels.
124        let ggml_mm_tensor_src: &'static str =
125            include_str!("shaders/quantized_matmul_mm_tensor.metal");
126        sources.insert("kernel_mul_mm_q4_0_tensor_f32".into(), ggml_mm_tensor_src);
127        sources.insert("kernel_mul_mm_q4_0_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
128        sources.insert("kernel_mul_mm_q6_K_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
129        sources.insert("kernel_mul_mm_q8_0_tensor_f32".into(), ggml_mm_tensor_src);
130        sources.insert("kernel_mul_mm_q6_K_tensor_f32".into(), ggml_mm_tensor_src);
131        // ADR-022 Phase 1 — Q5_1 / IQ4_NL tensor mm.
132        sources.insert("kernel_mul_mm_q5_1_tensor_f32".into(), ggml_mm_tensor_src);
133        sources.insert("kernel_mul_mm_iq4_nl_tensor_f32".into(), ggml_mm_tensor_src);
134        // ADR-022 Phase 2 — Q5_K tensor mm.
135        sources.insert("kernel_mul_mm_q5_K_tensor_f32".into(), ggml_mm_tensor_src);
136        // ADR-022 Phase 3 — Q4_K tensor mm + Q8_0 perm021.
137        sources.insert("kernel_mul_mm_q4_K_tensor_f32".into(), ggml_mm_tensor_src);
138        sources.insert("kernel_mul_mm_q8_0_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
139        // ADR-029 iter-30 H29-speed — F16-weight V2 large-tile mm.
140        // Same source file as the V2 quantized variants; reads F16 weight
141        // directly from device memory (no per-call dequant).  Used when
142        // MlxQWeight.f16_shadow is populated and m > MM_ROUTING_THRESHOLD.
143        sources.insert("hf2q_mul_mm_tensor_v2_f16".into(), ggml_mm_tensor_src);
144        // ADR-029 iter-36 H28-D — F16-weight perm021 mm for O-projection.
145        // Same source file; reads F16 weight from MlxQWeight.f16_shadow when
146        // populated, bypassing the per-call quantized dequant.  B-stage
147        // (bfloat permuted [n_heads, seq_len, head_dim] input) is byte-
148        // identical to the quantized variant.
149        sources.insert("kernel_mul_mm_f16_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
150        // ADR-029 iter-23 H28-A — V2 large-tile tensor mm (NRA=64 M, NRB=128 N).
151        // Same source file as V1 tensor mm; distinct kernel host names so the
152        // dispatcher can pick V1 vs V2 at runtime via HF2Q_LARGE_TILE_MM.
153        sources.insert("kernel_mul_mm_q4_0_tensor_v2_f32".into(), ggml_mm_tensor_src);
154        sources.insert("kernel_mul_mm_q8_0_tensor_v2_f32".into(), ggml_mm_tensor_src);
155        sources.insert("kernel_mul_mm_q6_K_tensor_v2_f32".into(), ggml_mm_tensor_src);
156        sources.insert("kernel_mul_mm_q5_1_tensor_v2_f32".into(), ggml_mm_tensor_src);
157        sources.insert("kernel_mul_mm_iq4_nl_tensor_v2_f32".into(), ggml_mm_tensor_src);
158        sources.insert("kernel_mul_mm_q5_K_tensor_v2_f32".into(), ggml_mm_tensor_src);
159        sources.insert("kernel_mul_mm_q4_K_tensor_v2_f32".into(), ggml_mm_tensor_src);
160        // ADR-029 iter-28 H29 — whole-tensor dequant from block_q → F16.
161        // Used at model load to materialize an F16 shadow of attn/dense MLP
162        // weights so the runtime dispatch can use kernel_mul_mm_f16_f32_*
163        // (peer's gemma4 pattern).  Trades ~1 GB resident memory for 2-3×
164        // faster per-call dense matmul at prefill.
165        let dequant_to_f16_src: &'static str =
166            include_str!("shaders/dequant_to_f16.metal");
167        sources.insert("hf2q_dequant_q4_0_to_f16".into(), dequant_to_f16_src);
168        sources.insert("hf2q_dequant_q8_0_to_f16".into(), dequant_to_f16_src);
169        sources.insert("hf2q_dequant_q5_1_to_f16".into(), dequant_to_f16_src);
170        sources.insert("hf2q_dequant_iq4_nl_to_f16".into(), dequant_to_f16_src);
171        sources.insert("hf2q_dequant_q4_K_to_f16".into(), dequant_to_f16_src);
172        sources.insert("hf2q_dequant_q5_K_to_f16".into(), dequant_to_f16_src);
173        sources.insert("hf2q_dequant_q6_K_to_f16".into(), dequant_to_f16_src);
174
175        // ADR-022 Phase 1 P1.7 — Q5_1 / IQ4_NL mul_mv_ext r1 family.
176        // Eight instantiations (2 types × 4 r1ptg widths). Each PSO is
177        // additionally specialized at PSO-compile time with FC_mul_mv_nsg
178        // (function_constant 600) and FC_mul_mv_nxpsg (function_constant 601).
179        let mul_mv_ext_src: &'static str = include_str!("shaders/mul_mv_ext.metal");
180        sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_2".into(), mul_mv_ext_src);
181        sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_3".into(), mul_mv_ext_src);
182        sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_4".into(), mul_mv_ext_src);
183        sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_5".into(), mul_mv_ext_src);
184        sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_2".into(), mul_mv_ext_src);
185        sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_3".into(), mul_mv_ext_src);
186        sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_4".into(), mul_mv_ext_src);
187        sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_5".into(), mul_mv_ext_src);
188        // ADR-022 Phase 4 — Q4_0 / Q8_0 / Q4_K / Q5_K / Q6_K mv_ext.
189        // 5 types × 4 r1ptg widths = 20 instantiations.
190        for r1 in [2, 3, 4, 5].iter() {
191            for ty in ["q4_0", "q8_0", "q4_K", "q5_K", "q6_K"].iter() {
192                let name = format!("kernel_mul_mv_ext_{ty}_f32_r1_{r1}");
193                sources.insert(name, mul_mv_ext_src);
194            }
195        }
196
197        // Dense bf16×f32 → f32 tensor-API matmul (non-flash-attention
198        // prefill Q@K^T and scores@V, modeled on llama.cpp's
199        // kernel_mul_mm_bf16_f32 with the GGML_METAL_HAS_TENSOR branch
200        // active).  Tile geometry and write-back identical to the
201        // quantized tensor kernel; only the A-stage copy (bfloat →
202        // bfloat, no dequantize) differs.
203        let dense_mm_bf16_tensor_src: &'static str =
204            include_str!("shaders/dense_mm_bf16_tensor.metal");
205        sources.insert("hf2q_dense_mm_bf16_f32_tensor".into(), dense_mm_bf16_tensor_src);
206        // ADR-029 iter-80 H60: V2 large-tile variant (NRA=64, NRB=128).
207        // Same source file (`dense_mm_bf16_tensor.metal`) — second host_name
208        // entry resolves to the V2 kernel appended at the bottom of that
209        // file. Picked at dispatch time when HF2Q_LARGE_TILE_MM=1.
210        sources.insert("hf2q_dense_mm_bf16_f32_tensor_v2".into(), dense_mm_bf16_tensor_src);
211
212        // Dense f32×f32 → f32 tensor-API matmul (F32-everywhere
213        // sibling of dense_mm_bf16_tensor).  Used by hf2q's ADR-005
214        // iter-118 BF16-vs-F32 ViT attention A/B diagnostic to remove
215        // the BF16 K-stage cast as a confounding variable.  Port of
216        // llama.cpp's kernel_mul_mm_f32_f32 specialization
217        // (ggml-metal.metal:10098) on the GGML_METAL_HAS_TENSOR
218        // branch.  Same tile geometry (NR0=64 NR1=32 NK=32) but
219        // float-everywhere shmem staging.
220        let dense_mm_f32_f32_tensor_src: &'static str =
221            include_str!("shaders/dense_mm_f32_f32.metal");
222        sources.insert("hf2q_dense_mm_f32_f32_tensor".into(), dense_mm_f32_f32_tensor_src);
223
224        // Dense f16×f32 → f32 tensor-API matmul (F16-staging sibling
225        // of dense_mm_bf16_tensor).  Used by hf2q's ADR-005 Phase 2c
226        // iter-128 gemma4v ViT precision-parity path: every mmproj
227        // weight is stored as F16 in GGUF, peer's `kernel_mul_mm_f16_f32`
228        // (`ggml-metal.metal:10099`) stages BOTH A and B as `half` in
229        // shmem and computes on `simdgroup_half8x8`.  Matches peer
230        // per-element rounding budget exactly (10-bit mantissa vs
231        // BF16's 7-bit), closing the 1.16x/block cascade compound that
232        // iter-127 numerically bisected to BF16 staging.  Same tile
233        // geometry as the BF16 sibling (NR0=64 NR1=32 NK=32, 8 KB
234        // shmem) — half and bfloat share 16-bit storage.
235        let dense_mm_f16_tensor_src: &'static str =
236            include_str!("shaders/dense_mm_f16_tensor.metal");
237        sources.insert("hf2q_dense_mm_f16_f32_tensor".into(), dense_mm_f16_tensor_src);
238
239        // Dense bf16×f32 → f32 GEMV (matrix-vector multiply) — optimized
240        // for M=1 single-token decode.  Port of llama.cpp's
241        // kernel_mul_mv_bf16_f32_4 (bfloat4-vectorized GEMV kernel).
242        // Used in apply_linear_projection_f32 when seq_len=1 and the
243        // weight matrix is BF16, replacing the MM kernel (~2× faster for
244        // M=1 due to better memory bandwidth utilization per thread).
245        let dense_gemv_bf16_src: &'static str =
246            include_str!("shaders/dense_gemv_bf16.metal");
247        sources.insert("hf2q_dense_gemv_bf16_f32_4".into(), dense_gemv_bf16_src);
248
249        // Fused scale-mask-softmax for the non-flash-attention prefill
250        // path.  One row-local threadgroup per (head, query) pair
251        // replaces three separate dispatches (scale, mask-add, softmax);
252        // reads a bf16 mask (-INF at masked positions, matching
253        // flash_attn_prefill_mask.metal) that is shared across heads.
254        let scale_mask_softmax_src: &'static str =
255            include_str!("shaders/scale_mask_softmax.metal");
256        sources.insert("scale_mask_softmax_f32".into(), scale_mask_softmax_src);
257        // ADR-029 iter-93 H71: float4-vectorized variant for peer parity
258        // with kernel_soft_max_f32_4. Same source file; v4 host_name resolves
259        // to the second kernel appended at the bottom of scale_mask_softmax.metal.
260        sources.insert("scale_mask_softmax_f32_v4".into(), scale_mask_softmax_src);
261
262        // Expert-routed (MoE) quantized matmul kernel (Story 2.1)
263        sources.insert(
264            "quantized_matmul_id".into(),
265            include_str!("shaders/quantized_matmul_id.metal"),
266        );
267
268        // Expert-routed (MoE) GGML block-format quantized matmul kernels
269        let ggml_id_src: &'static str =
270            include_str!("shaders/quantized_matmul_id_ggml.metal");
271        sources.insert("kernel_mul_mv_id_q4_0_f32".into(), ggml_id_src);
272        sources.insert("kernel_mul_mv_id_q8_0_f32".into(), ggml_id_src);
273        // ADR-013 P7 — Q4_K MoE expert-routed mat-vec (port of
274        // llama.cpp's kernel_mul_mv_id_q4_K_f32 at ggml-metal.metal:10349).
275        sources.insert("kernel_mul_mv_id_q4_K_f32".into(), ggml_id_src);
276        sources.insert("kernel_mul_mv_id_q5_K_f32".into(), ggml_id_src);
277        sources.insert("kernel_mul_mv_id_q6_K_f32".into(), ggml_id_src);
278        // ADR-028 iter-321 — q6_K _id with nr0=2 + cached yl[16]
279        // (peer-pattern port mirroring iter-309's non-_id variant).
280        // Env-gated via HF2Q_Q6K_ID_MV_NR2=1 in dispatch_id_mv.
281        sources.insert("kernel_mul_mv_id_q6_K_f32_nr2".into(), ggml_id_src);
282        // ADR-029 iter-6 — q8_0 _id with nr0=2 + nsg=4 cross-SG reduce
283        // (peer-pattern port; peer N_R0_Q8_0=2 + N_SG_Q8_0=4 in
284        //  /opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h:27,40).
285        // Env-gated via HF2Q_Q8_0_ID_MV_NR2=1 in dispatch_id_mv.
286        sources.insert("kernel_mul_mv_id_q8_0_f32_nr2".into(), ggml_id_src);
287        // ADR-022 Phase 1 — Q5_1 / IQ4_NL MoE expert-routed mat-vec.
288        sources.insert("kernel_mul_mv_id_q5_1_f32".into(), ggml_id_src);
289        sources.insert("kernel_mul_mv_id_iq4_nl_f32".into(), ggml_id_src);
290        // Fused-SwiGLU mv_id variants (ADR-012 §Optimize / Task #15):
291        // computes y[r][n] = sum_k(dequant(W[expert][n][k]) * silu(gate[r][k]) * up[r][k])
292        // in one dispatch — replaces silu_mul + expert_down sequence.
293        sources.insert("kernel_mul_mv_id_q4_0_f32_swiglu".into(), ggml_id_src);
294
295        // Expert-routed (MoE) GGML block-format QUANTIZED MATRIX-MATRIX kernels
296        // (ADR-011 Phase 3 Wave P3a: port of llama.cpp's
297        // `kernel_mul_mm_id_map0_ne20_N` + `kernel_mul_mm_id_<q>_f32`).
298        // Two-stage dispatch: map0 regroups the token-to-expert table into
299        // per-expert routed-token lists, then mm_id stages a 64x32 expert
300        // weight tile into threadgroup shmem and reuses it across a 32-row
301        // block of that expert's routed tokens.
302        let ggml_id_mm_src: &'static str =
303            include_str!("shaders/quantized_matmul_id_mm.metal");
304        sources.insert("kernel_mul_mm_id_map0_ne20_1".into(), ggml_id_mm_src);
305        sources.insert("kernel_mul_mm_id_map0_ne20_8".into(), ggml_id_mm_src);
306        sources.insert("kernel_mul_mm_id_q4_0_f32".into(), ggml_id_mm_src);
307        sources.insert("kernel_mul_mm_id_q8_0_f32".into(), ggml_id_mm_src);
308        sources.insert("kernel_mul_mm_id_q6_K_f32".into(), ggml_id_mm_src);
309        // ADR-013 P16 — Q4_K mm_id (port of llama.cpp ggml-metal.metal:10169).
310        sources.insert("kernel_mul_mm_id_q4_K_f32".into(), ggml_id_mm_src);
311        // ADR-022 Phase 1 P1.6 — Q5_1 / IQ4_NL mm_id template instantiations.
312        sources.insert("kernel_mul_mm_id_q5_1_f32".into(), ggml_id_mm_src);
313        sources.insert("kernel_mul_mm_id_iq4_nl_f32".into(), ggml_id_mm_src);
314        // ADR-022 Phase 2 — Q5_K mm_id template instantiation.
315        sources.insert("kernel_mul_mm_id_q5_K_f32".into(), ggml_id_mm_src);
316
317        // MoE-routed quantized matrix-matrix kernels — tensor API variant
318        // (ADR-011 Phase 3 Wave P3b-tensor).  Uses the MPP tensor_ops
319        // matmul2d primitive for hardware-tensor-core MMA on M3+.  Only
320        // the mm_id kernel is ported — map0 is a short pre-pass (not
321        // matmul) and continues to use the simdgroup version.
322        let ggml_id_mm_tensor_src: &'static str =
323            include_str!("shaders/quantized_matmul_id_mm_tensor.metal");
324        sources.insert("kernel_mul_mm_id_q4_0_tensor_f32".into(), ggml_id_mm_tensor_src);
325        sources.insert("kernel_mul_mm_id_q8_0_tensor_f32".into(), ggml_id_mm_tensor_src);
326        sources.insert("kernel_mul_mm_id_q6_K_tensor_f32".into(), ggml_id_mm_tensor_src);
327        // ADR-013 P16 — Q4_K tensor-API mm_id.
328        sources.insert("kernel_mul_mm_id_q4_K_tensor_f32".into(), ggml_id_mm_tensor_src);
329        // ADR-022 Phase 1 P1.6 — Q5_1 / IQ4_NL tensor-API mm_id.
330        sources.insert("kernel_mul_mm_id_q5_1_tensor_f32".into(), ggml_id_mm_tensor_src);
331        sources.insert("kernel_mul_mm_id_iq4_nl_tensor_f32".into(), ggml_id_mm_tensor_src);
332        // ADR-022 Phase 2 — Q5_K tensor-API mm_id.
333        sources.insert("kernel_mul_mm_id_q5_K_tensor_f32".into(), ggml_id_mm_tensor_src);
334
335        // Embedding kernels (Story 1.5)
336        let embedding_src: &'static str = include_str!("shaders/embedding.metal");
337        sources.insert("embedding_gather_4bit".into(), embedding_src);
338        sources.insert("embedding_gather_6bit".into(), embedding_src);
339
340        // MoE gate kernel (Story 1.5)
341        let moe_gate_src: &'static str = include_str!("shaders/moe_gate.metal");
342        sources.insert("moe_gate".into(), moe_gate_src);
343
344        // MoE dispatch kernels (Story 1.5)
345        let moe_dispatch_src: &'static str = include_str!("shaders/moe_dispatch.metal");
346        sources.insert("fused_gelu_mul".into(), moe_dispatch_src);
347        sources.insert("moe_swiglu_fused".into(), moe_dispatch_src);
348        sources.insert("moe_swiglu_batch".into(), moe_dispatch_src);
349        sources.insert("moe_swiglu_seq".into(), moe_dispatch_src);
350        sources.insert("moe_accumulate".into(), moe_dispatch_src);
351        sources.insert("moe_weighted_sum".into(), moe_dispatch_src);
352        sources.insert("moe_weighted_sum_seq".into(), moe_dispatch_src);
353        sources.insert("zero_buffer".into(), moe_dispatch_src);
354        sources.insert("naive_matvec_f32".into(), moe_dispatch_src);
355        sources.insert("moe_gather_topk_weights".into(), moe_dispatch_src);
356        // bf16 variants (Phase 2 bf16 activation path)
357        sources.insert("fused_gelu_mul_bf16".into(), moe_dispatch_src);
358        sources.insert("moe_swiglu_seq_bf16".into(), moe_dispatch_src);
359        sources.insert("moe_weighted_sum_seq_bf16_input".into(), moe_dispatch_src);
360        // ADR-020 iter-11h-e3a: backward kernels for moe_weighted_sum_seq.
361        sources.insert(
362            "moe_weighted_sum_seq_backward_outputs_f32".into(),
363            moe_dispatch_src,
364        );
365        sources.insert(
366            "moe_weighted_sum_seq_backward_weights_f32".into(),
367            moe_dispatch_src,
368        );
369        // ADR-020 iter-11h-e3b: fused backward kernel for moe_swiglu_seq.
370        sources.insert(
371            "moe_swiglu_seq_backward_f32".into(),
372            moe_dispatch_src,
373        );
374
375        // Batched KV cache copy kernels
376        let kv_cache_src: &'static str = include_str!("shaders/kv_cache_copy.metal");
377        sources.insert("kv_cache_copy_batch_f32".into(), kv_cache_src);
378        sources.insert("kv_cache_copy_batch_f32_to_f16".into(), kv_cache_src);
379        sources.insert("kv_cache_copy_seq_f32".into(), kv_cache_src);
380        sources.insert("kv_cache_copy_seq_f32_to_f16".into(), kv_cache_src);
381        // Wave P4.11 — fused K+V copy variants
382        sources.insert("kv_cache_copy_seq_f32_kv_dual".into(), kv_cache_src);
383        sources.insert("kv_cache_copy_seq_f32_to_f16_kv_dual".into(), kv_cache_src);
384        // ADR-028 iter-145 — fused single-position K+V copy variants (decode shape)
385        sources.insert("kv_cache_copy_batch_f32_kv_dual".into(), kv_cache_src);
386        sources.insert("kv_cache_copy_batch_f32_to_f16_kv_dual".into(), kv_cache_src);
387        // bf16-source KV cache copy (Phase 2 bf16 activation path)
388        sources.insert("kv_cache_copy_seq_bf16".into(), kv_cache_src);
389
390        // Elementwise and transpose kernels (Story 1.5)
391        let elementwise_src: &'static str = include_str!("shaders/elementwise.metal");
392        sources.insert("elementwise_add_f32".into(), elementwise_src);
393        sources.insert("elementwise_add_f16".into(), elementwise_src);
394        sources.insert("elementwise_mul_f32".into(), elementwise_src);
395        sources.insert("elementwise_mul_f16".into(), elementwise_src);
396        sources.insert("elementwise_add_bf16".into(), elementwise_src);
397        sources.insert("elementwise_mul_bf16".into(), elementwise_src);
398        sources.insert("cast_f16_to_f32".into(), elementwise_src);
399        sources.insert("cast_f32_to_f16".into(), elementwise_src);
400        sources.insert("cast_bf16_to_f32".into(), elementwise_src);
401        sources.insert("cast_f32_to_bf16".into(), elementwise_src);
402        sources.insert("scalar_mul_bf16".into(), elementwise_src);
403        sources.insert("scalar_mul_f32".into(), elementwise_src);
404        sources.insert("embedding_gather_scale_f32".into(), elementwise_src);
405        sources.insert("embedding_gather_scale_batch_f32".into(), elementwise_src);
406        sources.insert("permute_021_bf16".into(), elementwise_src);
407        sources.insert("transpose_last2_bf16".into(), elementwise_src);
408        sources.insert("transpose_last2_f16".into(), elementwise_src);
409        sources.insert("permute_021_f32".into(), elementwise_src);
410        sources.insert("permute_021_bf16_to_f32".into(), elementwise_src);
411        sources.insert("transpose_2d_f32".into(), elementwise_src);
412        sources.insert("transpose_2d_f16".into(), elementwise_src);
413
414        // Attention kernels (Story 1.3)
415        let sdpa_src: &'static str = include_str!("shaders/sdpa.metal");
416        sources.insert("sdpa".into(), sdpa_src);
417        sources.insert("sdpa_bf16".into(), sdpa_src);
418        let sdpa_sliding_src: &'static str = include_str!("shaders/sdpa_sliding.metal");
419        sources.insert("sdpa_sliding".into(), sdpa_sliding_src);
420        sources.insert("sdpa_sliding_bf16".into(), sdpa_sliding_src);
421
422        // Flash-attention tiled prefill kernel (ADR-011 Phase 1).
423        // Ten entry points; all backed by the same shader source.
424        // Pipelines are compiled with function constants via
425        // `get_pipeline_with_bool_constants` — not `get_pipeline`.
426        let flash_attn_prefill_src: &'static str =
427            include_str!("shaders/flash_attn_prefill.metal");
428        // D=256 variants (BQ=32, BK=16, WM=4, WN=1 — 128 threads/threadgroup)
429        sources.insert(
430            "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskfloat32".into(),
431            flash_attn_prefill_src,
432        );
433        sources.insert(
434            "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
435            flash_attn_prefill_src,
436        );
437        sources.insert(
438            "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbfloat16".into(),
439            flash_attn_prefill_src,
440        );
441        sources.insert(
442            "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
443            flash_attn_prefill_src,
444        );
445        sources.insert(
446            "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskfloat16".into(),
447            flash_attn_prefill_src,
448        );
449        sources.insert(
450            "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
451            flash_attn_prefill_src,
452        );
453        // D=512 variants (BQ=8, BK=8, WM=1, WN=1 — 32 threads/threadgroup)
454        // NOTE: f32 at D=512 is NOT instantiated — threadgroup memory exceeds
455        // the 32 KB Metal limit (candle sdpa.rs:86-94).
456        sources.insert(
457            "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbfloat16".into(),
458            flash_attn_prefill_src,
459        );
460        sources.insert(
461            "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
462            flash_attn_prefill_src,
463        );
464        sources.insert(
465            "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskfloat16".into(),
466            flash_attn_prefill_src,
467        );
468        sources.insert(
469            "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
470            flash_attn_prefill_src,
471        );
472
473        // Flash attention vector kernels — SIMD-vectorized decode-path SDPA
474        // (ported from llama.cpp flash_attn_ext_vec)
475        let flash_attn_vec_src: &'static str =
476            include_str!("shaders/flash_attn_vec.metal");
477        sources.insert("flash_attn_vec_dk256".into(), flash_attn_vec_src);
478        sources.insert("flash_attn_vec_dk512".into(), flash_attn_vec_src);
479        sources.insert("flash_attn_vec_reduce_dk256".into(), flash_attn_vec_src);
480        sources.insert("flash_attn_vec_reduce_dk512".into(), flash_attn_vec_src);
481        // F16 KV variants (Phase 4a)
482        sources.insert("flash_attn_vec_f16kv_dk256".into(), flash_attn_vec_src);
483        sources.insert("flash_attn_vec_f16kv_dk512".into(), flash_attn_vec_src);
484
485        // RoPE, normalization, activation kernels (Story 1.4)
486        let rope_src: &'static str = include_str!("shaders/rope.metal");
487        sources.insert("rope_f32".into(), rope_src);
488        sources.insert("rope_f16".into(), rope_src);
489        sources.insert("rope_bf16".into(), rope_src);
490        sources.insert("rope_neox_bf16".into(), rope_src);
491        sources.insert("rope_neox_f32".into(), rope_src);
492        let rms_norm_src: &'static str = include_str!("shaders/rms_norm.metal");
493        sources.insert("rms_norm_f32".into(), rms_norm_src);
494        // ADR-028 iter-310 — float4 + simd_sum variants (peer-pattern,
495        // ported from llama.cpp kernel_rms_norm_fuse_impl<float4, 1>).
496        // Env-gated via HF2Q_RMS_NORM_V2=1 in the dispatchers.
497        sources.insert("rms_norm_f32_v2".into(), rms_norm_src);
498        sources.insert("rms_norm_no_scale_f32_v2".into(), rms_norm_src);
499        sources.insert("rms_norm_f16".into(), rms_norm_src);
500        sources.insert("rms_norm_bf16".into(), rms_norm_src);
501        sources.insert("rms_norm_no_scale_bf16".into(), rms_norm_src);
502        sources.insert("rms_norm_no_scale_f32".into(), rms_norm_src);
503        sources.insert("rms_norm_no_scale_f32_dual".into(), rms_norm_src);
504        sources.insert("rms_norm_f32_triple".into(), rms_norm_src);
505        sources.insert("fused_post_attn_triple_norm_f32".into(), rms_norm_src);
506        // ADR-028 iter-370: V2 (float4 + simd_sum) variant of triple_norm.
507        sources.insert("fused_post_attn_triple_norm_f32_v2".into(), rms_norm_src);
508        // ADR-028 iter-217: fused post-FF norm 2 + end-of-layer FINAL
509        // (combines 2 sequential fused_norm_add dispatches into 1 kernel).
510        sources.insert("fused_post_ff_norm2_endlayer_f32".into(), rms_norm_src);
511        // ADR-028 iter-362: V2 (float4 + simd_sum) variant of the above.
512        // Same math, 75% fewer barriers per dispatch (4 vs 16 at tg=256).
513        sources.insert("fused_post_ff_norm2_endlayer_f32_v2".into(), rms_norm_src);
514        // ADR-028 iter-367: V2 fusion of moe_weighted_sum INTO Path A end-of-layer.
515        // Eliminates 1 dispatch + moe_accum round-trip from gemma4 decode default.
516        sources.insert("fused_moe_wsum_post_ff_norm2_endlayer_f32_v2".into(), rms_norm_src);
517        sources.insert("rms_norm_no_scale_f32_dual_perm".into(), rms_norm_src);
518        // Fused RMS norm + elementwise multiply kernels (Phase 4e.2)
519        sources.insert("rms_norm_mul_f32".into(), rms_norm_src);
520        sources.insert("rms_norm_mul_f16".into(), rms_norm_src);
521        sources.insert("rms_norm_mul_bf16".into(), rms_norm_src);
522        // L2 norm kernels (ADR-013 Decision 3 — Gated DeltaNet Q/K norm)
523        let l2_norm_src: &'static str = include_str!("shaders/l2_norm.metal");
524        sources.insert("l2_norm_f32".into(), l2_norm_src);
525        sources.insert("l2_norm_f16".into(), l2_norm_src);
526        sources.insert("l2_norm_bf16".into(), l2_norm_src);
527        // ADR-015 iter59a — fused L2 norm + scalar multiply (DN q-path).
528        sources.insert("l2_norm_scale_f32".into(), l2_norm_src);
529        // Cumulative-sum kernels (ADR-013 Decision 4 — DeltaNet decay-mask base)
530        let cumsum_src: &'static str = include_str!("shaders/cumsum.metal");
531        sources.insert("cumsum_f32".into(), cumsum_src);
532        sources.insert("cumsum_bf16".into(), cumsum_src);
533        // SSM conv kernels (ADR-013 Decision 7 — DeltaNet 1D causal conv + SiLU)
534        let ssm_conv_src: &'static str = include_str!("shaders/ssm_conv.metal");
535        sources.insert("ssm_conv_forward_f32".into(), ssm_conv_src);
536        sources.insert("ssm_conv_forward_bf16".into(), ssm_conv_src);
537        sources.insert("ssm_conv_state_update_f32".into(), ssm_conv_src);
538        sources.insert("ssm_conv_state_update_bf16".into(), ssm_conv_src);
539        // Tri-solve kernels (ADR-013 Decision 5 — chunked DeltaNet debug path)
540        let tri_solve_src: &'static str = include_str!("shaders/tri_solve.metal");
541        sources.insert("tri_solve_lower_unit_f32".into(), tri_solve_src);
542        sources.insert("tri_solve_lower_unit_bf16".into(), tri_solve_src);
543        // Rope-multi kernels (ADR-013 Decision 10 — IMROPE for Qwen3.5)
544        let rope_multi_src: &'static str = include_str!("shaders/rope_multi.metal");
545        sources.insert("rope_multi_f32".into(), rope_multi_src);
546        sources.insert("rope_multi_bf16".into(), rope_multi_src);
547        // Gated DeltaNet fused kernel (ADR-013 Decision 6 — centerpiece)
548        let gdn_src: &'static str = include_str!("shaders/gated_delta_net.metal");
549        sources.insert("gated_delta_net_f32".into(), gdn_src);
550        // ADR-015 iter56 — decode-only `simd_sum` variant. Three NSG-templated
551        // host names share the same source; selection is by D_k via
552        // `dispatch_gated_delta_net_decode`. Drop-in for the fused kernel
553        // above when n_tokens=1.
554        let gdn_decode_src: &'static str =
555            include_str!("shaders/gated_delta_net_decode.metal");
556        sources.insert("gated_delta_net_decode_f32_1".into(), gdn_decode_src);
557        sources.insert("gated_delta_net_decode_f32_2".into(), gdn_decode_src);
558        sources.insert("gated_delta_net_decode_f32_4".into(), gdn_decode_src);
559        // Wave 5b — chunk-parallel inter-chunk state-recurrence kernel
560        // (the one new kernel in the chunk-parallel pipeline; spec source:
561        // arXiv 2412.06464 §4 + FLA chunk_delta_h.py:43-298).
562        let gdn_chunk_src: &'static str =
563            include_str!("shaders/gated_delta_net_chunk.metal");
564        sources.insert(
565            "gated_delta_net_chunk_inter_state_bf16".into(),
566            gdn_chunk_src,
567        );
568        // Wave 5b.1 iter 2 — chunk_scaled_dot_kkt kernel (input-side of
569        // the chunk pipeline; spec source: FLA chunk_scaled_dot_kkt.py:36-99).
570        let gdn_kkt_src: &'static str =
571            include_str!("shaders/gated_delta_net_kkt.metal");
572        sources.insert("gated_delta_net_kkt_bf16".into(), gdn_kkt_src);
573        // Wave 5b.1 iter 2 — recompute_w_u_fwd kernel (applies post-solve A
574        // to (β·v) and (β·k·exp(g)) to produce w and u; spec source: FLA
575        // wy_fast.py:29-117).
576        let gdn_recompute_wu_src: &'static str =
577            include_str!("shaders/gated_delta_net_recompute_wu.metal");
578        sources.insert(
579            "gated_delta_net_recompute_wu_bf16".into(),
580            gdn_recompute_wu_src,
581        );
582        // Wave 5b.1 iter 3 — chunk_fwd_o kernel (per-chunk output: closes
583        // the chunk pipeline; spec source: FLA chunk_o.py:42-138).
584        let gdn_chunk_o_src: &'static str =
585            include_str!("shaders/gated_delta_net_chunk_o.metal");
586        sources.insert("gated_delta_net_chunk_o_bf16".into(), gdn_chunk_o_src);
587        // Wave 5b.1 iter 4 — orchestrator helper kernels:
588        //   chunk_local_cumsum_g_f32      — per-chunk prefix sum on g [B, T, H]
589        //   chunk_tri_solve_invert_f32    — per-chunk-block (I + A_strict)^-1
590        //                                   on FLA's [B, T, H, BT] layout.
591        let chunk_local_cumsum_g_src: &'static str =
592            include_str!("shaders/chunk_local_cumsum_g.metal");
593        sources.insert(
594            "chunk_local_cumsum_g_f32".into(),
595            chunk_local_cumsum_g_src,
596        );
597        let chunk_tri_solve_invert_src: &'static str =
598            include_str!("shaders/chunk_gated_delta_rule_tri_solve_invert.metal");
599        sources.insert(
600            "chunk_tri_solve_invert_f32".into(),
601            chunk_tri_solve_invert_src,
602        );
603        // Sigmoid-gated elementwise multiply (ADR-013 Decision 9 — full-attn output gate)
604        let sigmoid_mul_src: &'static str = include_str!("shaders/sigmoid_mul.metal");
605        sources.insert("sigmoid_mul_f32".into(), sigmoid_mul_src);
606        sources.insert("sigmoid_mul_bf16".into(), sigmoid_mul_src);
607        let silu_mul_src: &'static str = include_str!("shaders/silu_mul.metal");
608        sources.insert("silu_mul_f32".into(), silu_mul_src);
609        let compute_g_beta_src: &'static str = include_str!("shaders/compute_g_beta.metal");
610        sources.insert("compute_g_beta_f32".into(), compute_g_beta_src);
611        let ssm_norm_gate_src: &'static str = include_str!("shaders/ssm_norm_gate.metal");
612        sources.insert("ssm_norm_gate_f32".into(), ssm_norm_gate_src);
613        let gelu_src: &'static str = include_str!("shaders/gelu.metal");
614        sources.insert("gelu_f32".into(), gelu_src);
615        sources.insert("gelu_f16".into(), gelu_src);
616        sources.insert("gelu_bf16".into(), gelu_src);
617        let softmax_src: &'static str = include_str!("shaders/softmax.metal");
618        sources.insert("softmax_f32".into(), softmax_src);
619        sources.insert("softmax_f16".into(), softmax_src);
620        sources.insert("softmax_bf16".into(), softmax_src);
621        let softmax_backward_src: &'static str =
622            include_str!("shaders/softmax_backward.metal");
623        sources.insert("softmax_backward_f32".into(), softmax_backward_src);
624        let log_elementwise_src: &'static str =
625            include_str!("shaders/log_elementwise.metal");
626        sources.insert("log_f32".into(), log_elementwise_src);
627        sources.insert("log_backward_f32".into(), log_elementwise_src);
628        let row_sum_src: &'static str = include_str!("shaders/row_sum.metal");
629        sources.insert("row_sum_f32".into(), row_sum_src);
630        sources.insert("row_sum_backward_f32".into(), row_sum_src);
631        // ADR-020 iter-10a: GGUF-legacy quantize-dequantize round-trip kernels
632        // (Q4_0 + Q8_0).  Used by hf2q's dynamic_quant Track 1 to produce
633        // W_low / W_high for the gradient-Taylor sensitivity formula.
634        let qdq_legacy_src: &'static str = include_str!("shaders/qdq_legacy.metal");
635        sources.insert("qdq_q4_0_f32".into(), qdq_legacy_src);
636        sources.insert("qdq_q8_0_f32".into(), qdq_legacy_src);
637        // ADR-020 iter-10b: RMSNorm reverse-mode autograd kernels.
638        // r_inv helper is reused by both backward kernels; dx and dw cover
639        // the full backward identity for `y = x * rsqrt(mean(x²) + eps) * w`.
640        let rms_norm_backward_src: &'static str =
641            include_str!("shaders/rms_norm_backward.metal");
642        sources.insert(
643            "rms_norm_compute_rms_inv_f32".into(),
644            rms_norm_backward_src,
645        );
646        sources.insert("rms_norm_backward_dx_f32".into(), rms_norm_backward_src);
647        sources.insert("rms_norm_backward_dw_f32".into(), rms_norm_backward_src);
648        // ADR-020 iter-11a: 2-D row-major slice + concat-by-column kernels.
649        // Used by hf2q's multi-head SDPA on GpuTape (slice Q/K/V into
650        // per-head views, run per-head SDPA, concat per-head contexts
651        // back to full attention output).
652        let slice_concat_2d_src: &'static str =
653            include_str!("shaders/slice_concat_2d.metal");
654        sources.insert("slice_2d_cols_f32".into(), slice_concat_2d_src);
655        sources.insert("copy_2d_cols_into_f32".into(), slice_concat_2d_src);
656        // ADR-020 iter-11b: SiLU forward + backward kernels for GpuTape
657        // SwiGLU FFN composition.
658        let silu_backward_src: &'static str =
659            include_str!("shaders/silu_backward.metal");
660        sources.insert("silu_f32".into(), silu_backward_src);
661        sources.insert("silu_backward_f32".into(), silu_backward_src);
662        // ADR-020 iter-11d: FP32 embedding lookup + scatter-add backward.
663        let embedding_autograd_src: &'static str =
664            include_str!("shaders/embedding_autograd.metal");
665        sources.insert("embedding_lookup_f32".into(), embedding_autograd_src);
666        sources.insert(
667            "embedding_scatter_add_f32".into(),
668            embedding_autograd_src,
669        );
670        // ADR-020 iter-13a: Adam optimizer step kernel for Track 2
671        // DWQ-proper training loop.
672        let adam_update_src: &'static str =
673            include_str!("shaders/adam_update.metal");
674        sources.insert("adam_update_f32".into(), adam_update_src);
675        // ADR-020 iter-13b: differentiable affine qdq kernels for the
676        // DWQ-proper training loop.  Init + forward + backward (scales,
677        // biases) — q_int is FROZEN, scales+biases learnable.
678        let qdq_affine_src: &'static str =
679            include_str!("shaders/qdq_affine.metal");
680        sources.insert("qdq_affine_init_f32".into(), qdq_affine_src);
681        sources.insert("qdq_affine_forward_f32".into(), qdq_affine_src);
682        sources.insert(
683            "qdq_affine_backward_scales_f32".into(),
684            qdq_affine_src,
685        );
686        sources.insert(
687            "qdq_affine_backward_biases_f32".into(),
688            qdq_affine_src,
689        );
690        // ADR-020 iter-15: fused affine quantized matmul for DWQ inference.
691        // Per-element kernel; one thread per (m, n) output element.
692        // Tiled + simdgroup-MMA variant lands in iter-15b.
693        let qmm_affine_src: &'static str =
694            include_str!("shaders/qmm_affine.metal");
695        sources.insert("qmm_affine_t_f32".into(), qmm_affine_src);
696        // ADR-020 iter-15b: tiled variant — 16x16 thread block with
697        // cooperative-load X/W tiles in threadgroup-shared memory for
698        // 2-5x speedup over the per-element kernel.
699        let qmm_affine_tiled_src: &'static str =
700            include_str!("shaders/qmm_affine_tiled.metal");
701        sources.insert(
702            "qmm_affine_t_f32_tiled".into(),
703            qmm_affine_tiled_src,
704        );
705        // ADR-020 iter-15c: simdgroup-MMA variant — uses Apple GPU
706        // hardware `simdgroup_matrix<float, 8, 8>` MMA for the inner
707        // reduction.  Per-tile algorithmic 8× over scalar tiled, lands
708        // as ~3-4× wall after launch / load amortization.
709        let qmm_affine_simd_src: &'static str =
710            include_str!("shaders/qmm_affine_simd.metal");
711        sources.insert(
712            "qmm_affine_t_f32_simd".into(),
713            qmm_affine_simd_src,
714        );
715        // ADR-020 iter-15c-2: 4-simdgroup-per-TG variant — 32×32
716        // output tile, 4 simdgroups arranged as 2×2 grid each owning
717        // a 16×16 sub-tile = 4 simdgroup_matrix accumulators.  Same
718        // math as 15c-1, fuller warp-pool exploitation.
719        let qmm_affine_simd4_src: &'static str =
720            include_str!("shaders/qmm_affine_simd4.metal");
721        sources.insert(
722            "qmm_affine_t_f32_simd4".into(),
723            qmm_affine_simd4_src,
724        );
725        // ADR-020 iter-15c-2b: gs=64 variant (mlx-lm dynamic_quant
726        // canonical default).  Same 4-simdgroup geometry, BK=64
727        // instead of 32 (= 8 sub-K-tiles per K-step instead of 4).
728        let qmm_affine_simd4_gs64_src: &'static str =
729            include_str!("shaders/qmm_affine_simd4_gs64.metal");
730        sources.insert(
731            "qmm_affine_t_f32_simd4_gs64".into(),
732            qmm_affine_simd4_gs64_src,
733        );
734        // ADR-020 AC#5 Iter A: packed-U32 dense affine matmul (bits=4,
735        // gs=32) — production decode/prefill kernel for serving DWQ
736        // safetensors directly without a load-time unpack pass.
737        let qmm_affine_t_packed_simd4_b4_src: &'static str =
738            include_str!("shaders/qmm_affine_t_packed_simd4_b4.metal");
739        sources.insert(
740            "qmm_affine_t_packed_simd4_b4".into(),
741            qmm_affine_t_packed_simd4_b4_src,
742        );
743        // ADR-020 iter-11h-b: training-mode causal depthwise 1D
744        // convolution (forward + backward dx + backward dw).  Used by
745        // GpuTape autograd for differentiable Qwen3.5MoE forward
746        // (GatedDeltaNet's conv1d step).
747        let conv1d_dwc_src: &'static str =
748            include_str!("shaders/conv1d_depthwise_causal.metal");
749        sources.insert(
750            "conv1d_depthwise_causal_forward_f32".into(),
751            conv1d_dwc_src,
752        );
753        sources.insert(
754            "conv1d_depthwise_causal_backward_dx_f32".into(),
755            conv1d_dwc_src,
756        );
757        sources.insert(
758            "conv1d_depthwise_causal_backward_dw_f32".into(),
759            conv1d_dwc_src,
760        );
761        // ADR-020 iter-11h-c1: elementwise exp forward + backward.
762        // Building block for GatedDeltaNet's alpha = exp(-g) state-decay.
763        let exp_src: &'static str =
764            include_str!("shaders/exp_elementwise.metal");
765        sources.insert("exp_f32".into(), exp_src);
766        sources.insert("exp_backward_f32".into(), exp_src);
767        // ADR-020 iter-11h-c2: vector outer product (forward + dlhs +
768        // drhs).  Building block for gated_delta_update's
769        // outer(delta, k) state-update term.
770        let outer_src: &'static str =
771            include_str!("shaders/outer_product.metal");
772        sources.insert("outer_product_f32".into(), outer_src);
773        sources.insert("outer_product_backward_lhs_f32".into(), outer_src);
774        sources.insert("outer_product_backward_rhs_f32".into(), outer_src);
775        // ADR-020 iter-11h-e1: take_along_axis (gather) + scatter-backward.
776        // Building block for MoE router on GpuTape.
777        let taa_src: &'static str =
778            include_str!("shaders/take_along_axis.metal");
779        sources.insert("take_along_axis_f32".into(), taa_src);
780        sources.insert("take_along_axis_backward_f32".into(), taa_src);
781        // ADR-020 iter-11h-misc-1: elementwise divide forward + backward.
782        let div_src: &'static str =
783            include_str!("shaders/divide_elementwise.metal");
784        sources.insert("divide_f32".into(), div_src);
785        sources.insert("divide_backward_f32".into(), div_src);
786        // ADR-020 iter-11h-misc-3: elementwise sqrt forward + backward.
787        let sqrt_src: &'static str =
788            include_str!("shaders/sqrt_elementwise.metal");
789        sources.insert("sqrt_f32".into(), sqrt_src);
790        sources.insert("sqrt_backward_f32".into(), sqrt_src);
791        let softcap_src: &'static str = include_str!("shaders/softcap.metal");
792        sources.insert("softcap_f32".into(), softcap_src);
793        sources.insert("softcap_f16".into(), softcap_src);
794        sources.insert("softcap_bf16".into(), softcap_src);
795
796        // Fused norm-add kernels — Gemma4 post-attention / post-FFN ordering:
797        //   normed = rms_norm(input, weight, eps);  output = residual + normed
798        let fused_norm_add_src: &'static str =
799            include_str!("shaders/fused_norm_add_bf16.metal");
800        sources.insert("fused_norm_add_bf16".into(), fused_norm_add_src);
801        sources.insert("fused_norm_add_no_weight_bf16".into(), fused_norm_add_src);
802
803        // Fused head-norm + RoPE f32 kernel — replaces separate rms_norm + rope_neox_f32
804        let fused_hnr_f32_src: &'static str =
805            include_str!("shaders/fused_head_norm_rope_f32.metal");
806        sources.insert("fused_head_norm_rope_f32".into(), fused_hnr_f32_src);
807        // ADR-028 iter-337 — float4 + simd_sum Phase 1 variant.  Phases
808        // 2-4 byte-identical to v1; race-fix barrier preserved.  Env-gated
809        // via HF2Q_FUSED_HEAD_NORM_ROPE_V2 (default ON, opt-out via =0).
810        sources.insert("fused_head_norm_rope_f32_v2".into(), fused_hnr_f32_src);
811
812        // Fused head-norm + RoPE bf16 kernels (single-token + batch prefill)
813        // Both entry points live in the same .metal file.
814        let fused_hnr_bf16_src: &'static str =
815            include_str!("shaders/fused_head_norm_rope_bf16.metal");
816        sources.insert("fused_head_norm_rope_bf16".into(), fused_hnr_bf16_src);
817        sources.insert("fused_head_norm_rope_batch_bf16".into(), fused_hnr_bf16_src);
818
819        // Fused norm-add f32 kernels — post-attention / post-FFN / end-of-layer
820        let fused_norm_add_f32_src: &'static str =
821            include_str!("shaders/fused_norm_add_f32.metal");
822        sources.insert("fused_norm_add_f32".into(), fused_norm_add_f32_src);
823        // ADR-028 iter-331 — float4 + simd_sum variant (peer-pattern,
824        // ported from llama.cpp kernel_rms_norm_fuse_impl<float4, 3>).
825        // Env-gated via HF2Q_FUSED_NORM_ADD_V2=1 in the dispatcher
826        // (default ON since iter-331; opt-out via =0/false/off).
827        sources.insert("fused_norm_add_f32_v2".into(), fused_norm_add_f32_src);
828        sources.insert("fused_residual_norm_f32".into(), fused_norm_add_f32_src);
829        sources.insert("fused_residual_norm_scalar_f32".into(), fused_norm_add_f32_src);
830        sources.insert("fused_moe_routing_f32".into(), fused_norm_add_f32_src);
831        // ADR-028 iter-363: V2 (simd_max + simd_sum) variant of MoE routing.
832        sources.insert("fused_moe_routing_f32_v2".into(), fused_norm_add_f32_src);
833        sources.insert("fused_moe_routing_batch_f32".into(), fused_norm_add_f32_src);
834        sources.insert("fused_norm_add_scalar_f32".into(), fused_norm_add_f32_src);
835        sources.insert("fused_moe_wsum_norm_add_f32".into(), fused_norm_add_f32_src);
836        sources.insert("fused_moe_wsum_dnorm_add_f32".into(), fused_norm_add_f32_src);
837
838        // Argsort kernel (Story 2.3) — MoE top-K routing
839        let argsort_src: &'static str = include_str!("shaders/argsort.metal");
840        sources.insert("argsort_desc_f32".into(), argsort_src);
841
842        // Gather / index_select kernel (Story 2.4)
843        let gather_src: &'static str = include_str!("shaders/gather.metal");
844        sources.insert("gather_f32".into(), gather_src);
845
846        // F32 KV cache copy kernel (Session merge S1+S2)
847        let kv_cache_copy_src: &'static str =
848            include_str!("shaders/kv_cache_copy.metal");
849        sources.insert("kv_cache_copy".into(), kv_cache_copy_src);
850        sources.insert("kv_cache_copy_f32".into(), kv_cache_copy_src);
851
852        // Strided copy kernel (Story 2.5)
853        let copy_src: &'static str = include_str!("shaders/copy.metal");
854        sources.insert("strided_copy_f32".into(), copy_src);
855        sources.insert("offset_copy_f32".into(), copy_src);
856
857        // Fused-QKV split kernel (ADR-005 W-5b.18 — replaces hf2q CPU
858        // download → triple-loop split → 3× upload round-trip in
859        // gpu_delta_net::layer_qkv_deinterleave).
860        let qkv_split_src: &'static str = include_str!("shaders/qkv_split.metal");
861        sources.insert("qkv_split_f32".into(), qkv_split_src);
862
863        // Tiled-GQA broadcast kernel (ADR-005 W-5b.19 — replaces hf2q CPU
864        // tiled-replicate at gpu_delta_net::apply_gated_delta_net_chunk
865        // GQA pre-expansion, ~497 ms / 10.4 ms-per-layer at PP4106).
866        let repeat_tiled_src: &'static str =
867            include_str!("shaders/repeat_tiled.metal");
868        sources.insert("repeat_tiled_f32".into(), repeat_tiled_src);
869
870        // Dense F16 GEMM kernel (Story 2.6) — lm_head projection
871        let dense_gemm_src: &'static str = include_str!("shaders/dense_gemm.metal");
872        sources.insert("dense_gemm_f16".into(), dense_gemm_src);
873        sources.insert("dense_matvec_f16".into(), dense_gemm_src);
874        sources.insert("dense_matvec_f16w_f32io".into(), dense_gemm_src);
875        // BF16-weight mat-vec: BF16 weights × F32 input → F32 output (decode lm_head)
876        sources.insert("dense_matvec_bf16w_f32io".into(), dense_gemm_src);
877        // Pure F32 mat-vec: F32 weights × F32 input → F32 output (decode lm_head)
878        sources.insert("dense_matvec_f32".into(), dense_gemm_src);
879
880        // Standalone FWHT for TurboQuant pre/post-rotation (SIMD shuffle, zero barriers)
881        let fwht_src: &'static str = include_str!("shaders/fwht_standalone.metal");
882        sources.insert("fwht_standalone_f32_d256".into(), fwht_src);
883        sources.insert("fwht_standalone_f32_d512".into(), fwht_src);
884        // ADR-007 iter-14 D1 SRHT variants: sign pre-mult (for Q) + sign undo (for output)
885        sources.insert("fwht_sign_premult_f32_d256".into(), fwht_src);
886        sources.insert("fwht_sign_premult_f32_d512".into(), fwht_src);
887        sources.insert("fwht_sign_undo_f32_d256".into(), fwht_src);
888        sources.insert("fwht_sign_undo_f32_d512".into(), fwht_src);
889
890        // Fast Hadamard quantize (SIMD shuffle, zero barriers)
891        let hq_fast_src: &'static str = include_str!("shaders/hadamard_quantize_kv_fast.metal");
892        sources.insert("hadamard_quantize_kv_fast_d256".into(), hq_fast_src);
893        sources.insert("hadamard_quantize_kv_fast_d512".into(), hq_fast_src);
894        // ADR-028 iter-485 (Phase 7d / H4): fused K+V single-position 4-bit encoder.
895        sources.insert("hadamard_quantize_kv_fast_dual_d256".into(), hq_fast_src);
896        sources.insert("hadamard_quantize_kv_fast_dual_d512".into(), hq_fast_src);
897        // Track B (iter-21): higher-bit (5/6-bit) quantize kernels (byte-packed)
898        sources.insert("hadamard_quantize_kv_hb_d256".into(), hq_fast_src);
899        sources.insert("hadamard_quantize_kv_hb_d512".into(), hq_fast_src);
900        // ADR-028 iter-148: fused K+V single-position HB encoder
901        sources.insert("hadamard_quantize_kv_hb_dual_d256".into(), hq_fast_src);
902        sources.insert("hadamard_quantize_kv_hb_dual_d512".into(), hq_fast_src);
903        // ADR-028 Phase 10e.5 (iter-351): no-FWHT V quantize for hybrid path.
904        // Same byte-packed Lloyd-Max codebook output, but skips the Hadamard
905        // rotation so dequant in SDPA recovers raw V (no FWHT-undo needed).
906        sources.insert("kv_quantize_v_no_fwht_d256".into(), hq_fast_src);
907        sources.insert("kv_quantize_v_no_fwht_d512".into(), hq_fast_src);
908        // ADR-028 Phase 10c.5 (iter-354): fused F16-K-copy + V-no-FWHT-encode.
909        // Saves 30 KV-write dispatches/decode-token at gemma4 30L by combining
910        // the per-layer K-cast and V-encode into a single dispatch (Z-dim).
911        sources.insert("kv_copy_kf16_quantize_v_no_fwht_d256".into(), hq_fast_src);
912        sources.insert("kv_copy_kf16_quantize_v_no_fwht_d512".into(), hq_fast_src);
913
914        // iter-20 Leg F: TQ KV dequantize kernel (nibbles+norms → F32)
915        let tq_dq_src: &'static str = include_str!("shaders/tq_dequantize_kv.metal");
916        sources.insert("tq_dequantize_kv".into(), tq_dq_src);
917        // Track B (iter-21): higher-bit dequantize kernel (byte-packed indices)
918        sources.insert("tq_dequantize_hb_kv".into(), tq_dq_src);
919        // ADR-027 Phase B iter-30 (hf2q sub-sub-iter 23c-β.1): sequence-batch
920        // dequant variant. Same MSL source; new kernel entry point
921        // `tq_dequantize_hb_kv_seq` reads positions [start_pos..start_pos+n_tokens)
922        // in one dispatch (one threadgroup per (kv_head, position)). Unblocks
923        // hf2q's TQ-aware prefill SDPA path (current per-position kernel
924        // requires cur_len separate dispatches).
925        sources.insert("tq_dequantize_hb_kv_seq".into(), tq_dq_src);
926
927        // iter-24: native higher-bit (5/6/8-bit) TQ SDPA kernel (byte-packed K/V)
928        let tq_hb_src: &'static str = include_str!("shaders/flash_attn_vec_tq_hb.metal");
929        sources.insert("flash_attn_vec_tq_hb_dk256".into(), tq_hb_src);
930        sources.insert("flash_attn_vec_tq_hb_dk512".into(), tq_hb_src);
931
932        // ADR-028 §iter-485 (Phase 7d H3): fused TQ-HB reduce + FWHT-sign-undo.
933        // Combines flash_attn_vec_reduce + fwht_sign_undo_f32 into a single
934        // dispatch, saving 1 dispatch + 1 forced barrier per layer per decode
935        // token. Gated by env flag `HF2Q_TQ_HB_OUT_FUSED=1` in forward_mlx.rs.
936        let reduce_undo_src: &'static str = include_str!("shaders/flash_attn_vec_reduce_tq_hb_undo.metal");
937        sources.insert("flash_attn_vec_reduce_tq_hb_undo_dk256".into(), reduce_undo_src);
938        sources.insert("flash_attn_vec_reduce_tq_hb_undo_dk512".into(), reduce_undo_src);
939
940        // ADR-028 Phase 10d (iter-349): hybrid F16-K + TQ-HB-V SDPA kernel.
941        // Same V-side codebook as flash_attn_vec_tq_hb (5/6/8-bit Lloyd-Max);
942        // K-side reads F16 dense — peer-equivalent layout, no codebook lookup.
943        let hybrid_src: &'static str = include_str!("shaders/flash_attn_vec_hybrid.metal");
944        sources.insert("flash_attn_vec_hybrid_dk256".into(), hybrid_src);
945        sources.insert("flash_attn_vec_hybrid_dk512".into(), hybrid_src);
946
947        // ADR-029 CFA cfa-20260512-fa-peer-port (iter-122): verbatim llama.cpp peer port.
948        // F16-K + F16-V, DK=DV=256, NWG=1, NSG=1, NE=1. No function constants — baked.
949        let peer_port_src: &'static str = include_str!("shaders/flash_attn_vec_peer_port_f16.metal");
950        sources.insert("flash_attn_vec_peer_port_f16_dk256_dv256".into(), peer_port_src);
951
952        // ADR-029 iter-134: peer reduce kernel (verbatim port of ggml-metal.metal 7235-7275).
953        // Pairs with the NWG=32 vec kernel to match peer's actual runtime dispatch.
954        let peer_port_reduce_src: &'static str =
955            include_str!("shaders/flash_attn_vec_peer_port_f16_reduce.metal");
956        sources.insert(
957            "flash_attn_vec_peer_port_f16_reduce_dv256_nwg32".into(),
958            peer_port_reduce_src,
959        );
960
961        // ADR-029 iter-135: NWG=32 variant of the verbatim peer port. Same body as
962        // flash_attn_vec_peer_port_f16.metal with NWG=1→32. Pairs with iter-134 reduce kernel.
963        let peer_port_nwg32_src: &'static str =
964            include_str!("shaders/flash_attn_vec_peer_port_f16_nwg32.metal");
965        sources.insert(
966            "flash_attn_vec_peer_port_f16_nwg32_dk256_dv256".into(),
967            peer_port_nwg32_src,
968        );
969
970        // GPU sampling kernels — eliminate logits readback (Phase 6)
971        let argmax_src: &'static str = include_str!("shaders/argmax.metal");
972        sources.insert("argmax_f32".into(), argmax_src);
973        let softmax_sample_src: &'static str =
974            include_str!("shaders/softmax_sample.metal");
975        sources.insert("softmax_sample_f32".into(), softmax_sample_src);
976        // Top-K kernel for Q8 rerank: avoids full-logits readback.
977        let top_k_src: &'static str = include_str!("shaders/top_k.metal");
978        sources.insert("top_k_f32".into(), top_k_src);
979
980        // MoE GPU routing + weighted reduce (ADR-013 P13.3 perf).
981        // Replaces CPU softmax+topk round-trip and CPU weighted accumulate.
982        let moe_stk_src: &'static str =
983            include_str!("shaders/moe_softmax_topk.metal");
984        sources.insert("moe_softmax_topk_f32".into(), moe_stk_src);
985        let moe_wr_src: &'static str =
986            include_str!("shaders/moe_weighted_reduce.metal");
987        sources.insert("moe_weighted_reduce_f32".into(), moe_wr_src);
988        let sdpa_decode_src: &'static str =
989            include_str!("shaders/sdpa_decode.metal");
990        sources.insert("sdpa_decode".into(), sdpa_decode_src);
991
992        Self {
993            cache: HashMap::new(),
994            sources,
995        }
996    }
997
998    /// Register a shader source at runtime (useful for testing and dynamic
999    /// kernel generation).
1000    pub fn register_source(&mut self, name: impl Into<String>, source: &'static str) {
1001        let name = name.into();
1002        // Invalidate any cached pipeline for this name since the source changed.
1003        self.cache.remove(&name);
1004        self.sources.insert(name, source);
1005    }
1006
1007    /// Get a compiled compute pipeline for the named kernel function.
1008    ///
1009    /// On first call for a given name, this compiles the MSL source into a
1010    /// Metal library, extracts the named function, and creates a
1011    /// `ComputePipelineState`.  Subsequent calls return the cached pipeline.
1012    ///
1013    /// # Errors
1014    ///
1015    /// * `MlxError::KernelNotFound` — no source registered for this name.
1016    /// * `MlxError::ShaderCompilationError` — MSL compilation or pipeline
1017    ///   creation failed.
1018    pub fn get_pipeline(
1019        &mut self,
1020        name: &str,
1021        device: &metal::DeviceRef,
1022    ) -> Result<&ComputePipelineState> {
1023        if !self.cache.contains_key(name) {
1024            // Slow path: compile the shader.
1025            let source = self.sources.get(name).ok_or_else(|| {
1026                MlxError::KernelNotFound(name.to_string())
1027            })?;
1028
1029            let compile_opts = metal::CompileOptions::new();
1030            let library = device
1031                .new_library_with_source(source, &compile_opts)
1032                .map_err(|msg| MlxError::ShaderCompilationError {
1033                    name: name.to_string(),
1034                    message: msg,
1035                })?;
1036
1037            let function = library
1038                .get_function(name, None)
1039                .map_err(|msg| MlxError::ShaderCompilationError {
1040                    name: name.to_string(),
1041                    message: msg,
1042                })?;
1043
1044            // Build the pipeline through a descriptor so we can attach a
1045            // human-readable label.  The label propagates into Instruments /
1046            // xctrace Metal System Trace as the per-pipeline identifier
1047            // (`metal-object-label` schema), giving us per-kernel attribution
1048            // instead of the generic "Compute Command 0" placeholder.
1049            //
1050            // `MTLComputePipelineState.label` is read-only after creation per
1051            // the Apple Metal spec; the only supported way to set it is via
1052            // the descriptor before pipeline creation.  ADR-015 iter9b.
1053            let descriptor = ComputePipelineDescriptor::new();
1054            descriptor.set_compute_function(Some(&function));
1055            descriptor.set_label(name);
1056            // ADR-028 iter-376: threadGroupSizeIsMultipleOfThreadExecutionWidth
1057            // hint allows the Metal compiler to skip bounds checks and use more
1058            // aggressive codegen.  Opt-in via HF2Q_PIPELINE_TG_MULT_HINT=1.
1059            // SAFETY: every dispatched threadgroup MUST be a multiple of 32 at
1060            // runtime — Apple specifies undefined behavior otherwise.  Our hot
1061            // kernels use tg_size ∈ {32, 64, 256, 1024} (all multiples of 32).
1062            if std::env::var("HF2Q_PIPELINE_TG_MULT_HINT").ok().as_deref() == Some("1") {
1063                descriptor.set_thread_group_size_is_multiple_of_thread_execution_width(true);
1064            }
1065
1066            let pipeline = device
1067                .new_compute_pipeline_state(&descriptor)
1068                .map_err(|msg| MlxError::ShaderCompilationError {
1069                    name: name.to_string(),
1070                    message: msg,
1071                })?;
1072
1073            self.cache.insert(name.to_string(), pipeline);
1074        }
1075
1076        // At this point the pipeline is guaranteed to be in the cache.
1077        // We use `ok_or_else` instead of `expect` to satisfy the no-panic policy.
1078        self.cache.get(name).ok_or_else(|| {
1079            MlxError::KernelNotFound(name.to_string())
1080        })
1081    }
1082
1083    /// Get a compiled compute pipeline for the named kernel, specialized with
1084    /// Metal function constants (both bool and i32 in one call).
1085    ///
1086    /// `bool_constants` contains `(index, value)` pairs mapping to
1087    /// `[[function_constant(index)]]` bool declarations in the MSL shader.
1088    /// `int_constants` contains `(index, value)` pairs mapping to
1089    /// `[[function_constant(index)]]` int (int32_t) declarations in the MSL
1090    /// shader.
1091    ///
1092    /// Pipelines are cached by a composite key:
1093    /// `"<name>|<index>:b<0|1>|...|<index>:i<value>|..."`.  The 'b' prefix
1094    /// marks bool entries and the 'i' prefix marks i32 entries, making the
1095    /// format unambiguous regardless of constant ordering.  Distinct
1096    /// `(name, constants)` combinations each compile to a separate pipeline;
1097    /// the slow compilation path runs at most once per unique combination.
1098    ///
1099    /// # Errors
1100    ///
1101    /// * `MlxError::KernelNotFound` — no source registered for this name.
1102    /// * `MlxError::ShaderCompilationError` — MSL compilation, function
1103    ///   specialisation, or pipeline creation failed.
1104    pub fn get_pipeline_with_constants(
1105        &mut self,
1106        name: &str,
1107        device: &metal::DeviceRef,
1108        bool_constants: &[(usize, bool)],
1109        int_constants: &[(usize, i32)],
1110    ) -> Result<&ComputePipelineState> {
1111        // Build a composite cache key so distinct constant combinations each
1112        // compile to their own pipeline.  Bool entries use the 'b' type marker
1113        // and i32 entries use 'i'; this prevents a collision between, e.g.,
1114        // bool index 5 value 1 and int index 5 value 1.
1115        let mut cache_key = name.to_string();
1116        for &(index, value) in bool_constants {
1117            cache_key.push('|');
1118            cache_key.push_str(&index.to_string());
1119            cache_key.push_str(if value { ":b1" } else { ":b0" });
1120        }
1121        for &(index, value) in int_constants {
1122            cache_key.push('|');
1123            cache_key.push_str(&index.to_string());
1124            cache_key.push(':');
1125            cache_key.push('i');
1126            cache_key.push_str(&value.to_string());
1127        }
1128
1129        if !self.cache.contains_key(&cache_key) {
1130            // Slow path: compile the shader with function constant specialisation.
1131            let source = self.sources.get(name).ok_or_else(|| {
1132                MlxError::KernelNotFound(name.to_string())
1133            })?;
1134
1135            let compile_opts = metal::CompileOptions::new();
1136            let library = device
1137                .new_library_with_source(source, &compile_opts)
1138                .map_err(|msg| MlxError::ShaderCompilationError {
1139                    name: name.to_string(),
1140                    message: msg,
1141                })?;
1142
1143            // Build the FunctionConstantValues object with all bool and i32
1144            // constants.  Metal's set_constant_value_at_index reads the value
1145            // through a raw pointer; the pointed-to bytes must match the size
1146            // declared in the MSL shader (1 byte for bool, 4 bytes for int).
1147            let fcv = FunctionConstantValues::new();
1148
1149            for &(index, value) in bool_constants {
1150                // MTLDataType::Bool = 53 (metal-rs argument.rs).
1151                // The Metal runtime reads it as an Objective-C BOOL (uint8_t).
1152                let v: u8 = if value { 1 } else { 0 };
1153                fcv.set_constant_value_at_index(
1154                    (&v as *const u8).cast::<std::ffi::c_void>(),
1155                    MTLDataType::Bool,
1156                    index as u64,
1157                );
1158            }
1159
1160            for &(index, value) in int_constants {
1161                // MTLDataType::Int = 29 (metal-rs argument.rs).
1162                // The Metal runtime reads 4 bytes as a signed 32-bit integer,
1163                // matching the Metal shader type `constant int`.
1164                fcv.set_constant_value_at_index(
1165                    (&value as *const i32).cast::<std::ffi::c_void>(),
1166                    MTLDataType::Int,
1167                    index as u64,
1168                );
1169            }
1170
1171            let function = library
1172                .get_function(name, Some(fcv))
1173                .map_err(|msg| MlxError::ShaderCompilationError {
1174                    name: name.to_string(),
1175                    message: msg,
1176                })?;
1177
1178            // Label this specialisation with the full composite cache key
1179            // (e.g. `kernel_mul_mv_q4_0_f32|0:b1|3:i32`) so xctrace Metal
1180            // System Trace shows each function-constant variant as a distinct
1181            // pipeline.  Without this, all specialisations share a generic
1182            // "Compute Command 0" identifier and we cannot attribute µs/token
1183            // to a specific (kernel, constants) combination.  ADR-015 iter9b.
1184            let descriptor = ComputePipelineDescriptor::new();
1185            descriptor.set_compute_function(Some(&function));
1186            descriptor.set_label(&cache_key);
1187            // ADR-028 iter-376: same hint as primary pipeline path.
1188            if std::env::var("HF2Q_PIPELINE_TG_MULT_HINT").ok().as_deref() == Some("1") {
1189                descriptor.set_thread_group_size_is_multiple_of_thread_execution_width(true);
1190            }
1191
1192            let pipeline = device
1193                .new_compute_pipeline_state(&descriptor)
1194                .map_err(|msg| MlxError::ShaderCompilationError {
1195                    name: name.to_string(),
1196                    message: msg,
1197                })?;
1198
1199            self.cache.insert(cache_key.clone(), pipeline);
1200        }
1201
1202        self.cache.get(&cache_key).ok_or_else(|| {
1203            MlxError::KernelNotFound(name.to_string())
1204        })
1205    }
1206
1207    /// Get a compiled compute pipeline for the named kernel, specialized with
1208    /// Metal bool function constants.
1209    ///
1210    /// The `bool_constants` slice contains `(index, value)` pairs.  Each pair
1211    /// maps to a `[[function_constant(index)]]` declaration in the MSL shader.
1212    ///
1213    /// This is a thin wrapper around [`get_pipeline_with_constants`] that
1214    /// passes an empty `int_constants` slice.  Existing callers continue to
1215    /// work without modification; the cache-key format for pure-bool pipelines
1216    /// is compatible (bool entries carry the 'b' type marker, which is the
1217    /// only format ever written by this wrapper).
1218    ///
1219    /// # Errors
1220    ///
1221    /// * `MlxError::KernelNotFound` — no source registered for this name.
1222    /// * `MlxError::ShaderCompilationError` — MSL compilation, function
1223    ///   specialisation, or pipeline creation failed.
1224    pub fn get_pipeline_with_bool_constants(
1225        &mut self,
1226        name: &str,
1227        device: &metal::DeviceRef,
1228        bool_constants: &[(usize, bool)],
1229    ) -> Result<&ComputePipelineState> {
1230        self.get_pipeline_with_constants(name, device, bool_constants, &[])
1231    }
1232
    /// Check if a pipeline for the given name is already compiled and cached.
    ///
    /// Note: `name` is matched against cache keys verbatim, so this reports
    /// only on the plain pipeline from [`get_pipeline`].  Constant-specialized
    /// pipelines are cached under composite keys
    /// (`"<name>|<index>:b…"`) and are not found by the bare kernel name.
    pub fn is_cached(&self, name: &str) -> bool {
        self.cache.contains_key(name)
    }
1237
    /// Number of compiled pipelines currently in the cache.
    ///
    /// Each function-constant specialisation is cached under its own composite
    /// key, so every distinct `(name, constants)` combination counts as a
    /// separate entry.
    pub fn cached_count(&self) -> usize {
        self.cache.len()
    }
1242
    /// Number of registered shader sources.
    ///
    /// Counts MSL source registrations, not compiled pipelines; a single
    /// source may back many cached pipeline specialisations.
    pub fn source_count(&self) -> usize {
        self.sources.len()
    }
1247}
1248
1249impl Default for KernelRegistry {
1250    fn default() -> Self {
1251        Self::new()
1252    }
1253}
1254
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal Metal shader that uses a single int function constant.
    ///
    /// The kernel writes the constant value N into the first element of the
    /// output buffer, allowing the test to verify that the Metal compiler
    /// actually sees distinct specialisations for N=4 and N=8.
    ///
    /// The shader is intentionally trivial — we only need it to *compile* with
    /// an int function constant; correctness of the kernel logic is not under
    /// test here.
    const INT_FC_TEST_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;

constant int test_N [[function_constant(100)]];

kernel void int_fc_test_kernel(
    device int* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]])
{
    if (tid == 0) {
        out[0] = test_N;
    }
}
"#;

    /// Verify that `get_pipeline_with_constants` produces distinct cached
    /// pipelines for different i32 function-constant values, and that
    /// `get_pipeline_with_bool_constants` (the backward-compat wrapper) still
    /// works correctly with the new 'b'-prefixed cache-key format.
    ///
    /// This test requires a real Metal device; if none is available (e.g. a
    /// non-Apple platform or a headless CI runner) it panics via `expect`.
    /// NOTE(review): this doc previously claimed the test was "marked
    /// `#[ignore]` on non-Apple platforms", but no `#[ignore]` or `#[cfg]`
    /// attribute is present — add one if the test should be skippable
    /// off-macOS.
    #[test]
    fn test_int_fc_distinct_pipelines_and_bool_compat() {
        let device = metal::Device::system_default()
            .expect("no Metal device — run on Apple Silicon or x86 Mac with Metal support");

        let mut registry = KernelRegistry::new();

        // Register the inline test shader under a name that cannot collide with
        // any production kernel.
        registry.register_source("int_fc_test_kernel", INT_FC_TEST_SHADER);

        // Compile with N=4.
        let p4_ptr = registry
            .get_pipeline_with_constants(
                "int_fc_test_kernel",
                &device,
                &[],                  // no bool constants
                &[(100, 4_i32)],      // int constant index 100 = 4
            )
            .expect("pipeline N=4 should compile") as *const _;

        // Cache must now have exactly 1 entry for this kernel.
        // (Other production kernels may already be in cache from new(); here
        // we check that the N=4 key was inserted.)
        let count_after_n4 = registry.cached_count();

        // Compile with N=8 — must produce a SEPARATE pipeline.
        let p8_ptr = registry
            .get_pipeline_with_constants(
                "int_fc_test_kernel",
                &device,
                &[],
                &[(100, 8_i32)],
            )
            .expect("pipeline N=8 should compile") as *const _;

        // Cache must have grown by exactly 1.
        assert_eq!(
            registry.cached_count(),
            count_after_n4 + 1,
            "N=8 must produce a new cache entry"
        );

        // The two pipelines must be distinct objects in the cache.
        assert_ne!(
            p4_ptr, p8_ptr,
            "N=4 and N=8 specialisations must be separate ComputePipelineState objects"
        );

        // A second call with N=4 must return the SAME pipeline (cache hit, no
        // new compilation).
        let p4_again_ptr = registry
            .get_pipeline_with_constants(
                "int_fc_test_kernel",
                &device,
                &[],
                &[(100, 4_i32)],
            )
            .expect("pipeline N=4 cache hit should succeed") as *const _;

        assert_eq!(
            registry.cached_count(),
            count_after_n4 + 1,
            "repeated N=4 call must be a cache hit, not a new entry"
        );
        assert_eq!(
            p4_ptr, p4_again_ptr,
            "repeated N=4 call must return the same pipeline pointer"
        );

        // Verify backward compatibility: get_pipeline_with_bool_constants must
        // still route through get_pipeline_with_constants and produce a cached
        // pipeline without panicking.
        //
        // We register a separate bool-constant shader that does NOT use a bool
        // function constant (so the Metal compiler ignores missing FCs for
        // this trivial case) — but the call path and cache-key format are what
        // matter here.  We reuse the int_fc_test_kernel source; the bool FC is
        // simply unused by the shader (Metal allows unused FCs when the shader
        // declares them with `function_constant` but the value is never read).
        //
        // To avoid a Metal compiler error for an undeclared function constant,
        // we register a separate bare-kernel shader for the bool wrapper test.
        const BARE_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void bare_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
    if (tid == 0) { out[0] = 42; }
}
"#;
        registry.register_source("bare_kernel", BARE_SHADER);

        let count_before_bool = registry.cached_count();
        let _bool_pipeline = registry
            .get_pipeline_with_bool_constants("bare_kernel", &device, &[])
            .expect("bool-constants wrapper with empty slice must succeed");

        assert_eq!(
            registry.cached_count(),
            count_before_bool + 1,
            "bool-constants wrapper must insert one new cache entry"
        );
    }

    /// Verify that the `MTLComputePipelineState.label` produced by
    /// `get_pipeline` and `get_pipeline_with_constants` actually propagates
    /// from the descriptor to the resulting pipeline state.
    ///
    /// This is the in-process smoke check for ADR-015 iter9b: we cannot
    /// reach into xctrace from Rust, but we can read back the same `label`
    /// property xctrace consumes via `ComputePipelineStateRef::label()`.
    /// If labels are missing or wrong here, the MST trace will also show
    /// generic identifiers — so this test gates the iter9 retry's
    /// per-Q4_0-kernel attribution.
    ///
    /// Like the test above, this requires a real Metal device and panics via
    /// `expect` when none is available.
    #[test]
    fn test_pipeline_labels_propagate_for_mst() {
        let device = metal::Device::system_default()
            .expect("no Metal device — run on Apple Silicon or x86 Mac with Metal support");

        let mut registry = KernelRegistry::new();

        // Reuse the same trivial shaders as the int-FC test.
        registry.register_source("int_fc_test_kernel", INT_FC_TEST_SHADER);

        const BARE_SHADER_LABEL_TEST: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void label_smoke_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
    if (tid == 0) { out[0] = 7; }
}
"#;
        registry.register_source("label_smoke_kernel", BARE_SHADER_LABEL_TEST);

        // Plain get_pipeline path — label must equal the kernel name.
        // Capture as owned String so the cache borrow is released before
        // the next get_pipeline_with_constants call below.
        let plain_label = registry
            .get_pipeline("label_smoke_kernel", &device)
            .expect("plain pipeline must compile")
            .label()
            .to_string();
        assert_eq!(
            plain_label, "label_smoke_kernel",
            "get_pipeline must label the pipeline with the kernel name (xctrace MST attribution)"
        );

        // Constants path — label must equal the composite cache key so each
        // function-constant variant is individually attributable in MST.
        // We capture the label as an owned String to release the borrow on
        // the cache before fetching the next specialisation.
        let label_v7 = registry
            .get_pipeline_with_constants(
                "int_fc_test_kernel",
                &device,
                &[],
                &[(100, 7_i32)],
            )
            .expect("specialised pipeline must compile")
            .label()
            .to_string();
        assert_eq!(
            label_v7, "int_fc_test_kernel|100:i7",
            "get_pipeline_with_constants must label with the cache_key so each \
             specialisation is distinct in xctrace MST"
        );

        // A second specialisation must produce a different label.
        let label_v13 = registry
            .get_pipeline_with_constants(
                "int_fc_test_kernel",
                &device,
                &[],
                &[(100, 13_i32)],
            )
            .expect("second specialised pipeline must compile")
            .label()
            .to_string();
        assert_eq!(label_v13, "int_fc_test_kernel|100:i13");
        assert_ne!(
            label_v7, label_v13,
            "distinct constant values must yield distinct pipeline labels"
        );
    }
}