Skip to main content

mlx_native/
kernel_registry.rs

1//! [`KernelRegistry`] — lazy compilation and caching of Metal compute pipelines.
2//!
3//! MSL shader source is embedded at compile time via `include_str!`.  On first
4//! access, the source is compiled into a Metal library, the named function is
5//! extracted, and a `ComputePipelineState` is created and cached.  Subsequent
6//! calls return the cached pipeline.
7
8use std::collections::HashMap;
9
10use metal::{ComputePipelineDescriptor, ComputePipelineState, FunctionConstantValues, MTLDataType};
11
12use crate::error::{MlxError, Result};
13
14// MTLDataType numeric values (from metal-rs argument.rs, confirmed in Apple Metal spec):
15//   Int  = 29
16//   Bool = 53
17// These are used when calling set_constant_value_at_index so the Metal runtime
18// knows how wide each constant value is.
19
20/// Registry of embedded MSL kernels: maps each kernel function name to its
21/// shader source and caches the compute pipeline built from that source.
22///
23/// # Usage
24///
25/// ```ignore
26/// let mut registry = KernelRegistry::new();
27/// let pipeline = registry.get_pipeline("elementwise_add_f32", device.metal_device())?;
28/// encoder.encode(&pipeline, &buffers, grid, tg);
29/// ```
30///
31/// # Thread Safety
32///
33/// `get_pipeline` takes `&mut self` (a cache miss inserts into the map), so a
34/// shared `&KernelRegistry` cannot look up pipelines.  For concurrent access,
35/// wrap it in a `Mutex` or use one registry per thread.
36pub struct KernelRegistry {
37    /// Pipelines cached after compilation, keyed by kernel function name.
38    cache: HashMap<String, ComputePipelineState>,
39    /// Embedded MSL source text, keyed by kernel function name.
40    ///
41    /// Filled by `new()`; kernels defined in the same shader file share one source string.
42    sources: HashMap<String, &'static str>,
43}
44
45impl KernelRegistry {
46    /// Create a new registry with all embedded shader sources pre-registered.
47    ///
48    /// No compilation happens here — shaders are compiled lazily on first use.
49    pub fn new() -> Self {
50        let mut sources = HashMap::new();
51
52        // Register embedded shader sources.
53        sources.insert(
54            "placeholder".into(),
55            include_str!("shaders/placeholder.metal"),
56        );
57        sources.insert(
58            "quantized_matmul".into(),
59            include_str!("shaders/quantized_matmul.metal"),
60        );
61        sources.insert(
62            "quantized_matmul_simd".into(),
63            include_str!("shaders/quantized_matmul.metal"),
64        );
65        sources.insert(
66            "quantized_matmul_simd_bf16".into(),
67            include_str!("shaders/quantized_matmul.metal"),
68        );
69        sources.insert(
70            "quantized_matmul_simd_bf16_expert".into(),
71            include_str!("shaders/quantized_matmul.metal"),
72        );
73
74        // GGML block-format quantized mat-vec kernels (ADR-006 Phase 3)
75        let ggml_src: &'static str =
76            include_str!("shaders/quantized_matmul_ggml.metal");
77        sources.insert("kernel_mul_mv_q4_0_f32".into(), ggml_src);
78        sources.insert("kernel_mul_mv_q8_0_f32".into(), ggml_src);
79        sources.insert("kernel_mul_mv_q6_K_f32".into(), ggml_src);
80        // ADR-022 Phase 1 — Q5_1 / IQ4_NL dense mat-vec.
81        sources.insert("kernel_mul_mv_q5_1_f32".into(), ggml_src);
82        sources.insert("kernel_mul_mv_iq4_nl_f32".into(), ggml_src);
83        // ADR-013 P7 — Q4_K dense decode mat-vec (port of llama.cpp's
84        // kernel_mul_mv_q4_K_f32 at ggml-metal.metal:7715-7821).
85        sources.insert("kernel_mul_mv_q4_K_f32".into(), ggml_src);
86        // ADR-022 Phase 2 — Q5_K dense mv kernel.
87        sources.insert("kernel_mul_mv_q5_K_f32".into(), ggml_src);
88
89        // GGML block-format quantized matrix-matrix kernels
90        // (ADR-011 Phase 3 Wave P3a: port of llama.cpp's kernel_mul_mm_<q>_f32).
91        // Used at prefill m > 8 to reuse each weight tile across a 32-row
92        // block via threadgroup-staged simdgroup MMA, instead of re-reading
93        // every block per prompt-token as the mv kernel does.
94        let ggml_mm_src: &'static str =
95            include_str!("shaders/quantized_matmul_mm.metal");
96        sources.insert("kernel_mul_mm_q4_0_f32".into(), ggml_mm_src);
97        sources.insert("kernel_mul_mm_q8_0_f32".into(), ggml_mm_src);
98        sources.insert("kernel_mul_mm_q6_K_f32".into(), ggml_mm_src);
99        // ADR-022 Phase 1 — dense Q5_1 / IQ4_NL mm.
100        sources.insert("kernel_mul_mm_q5_1_f32".into(), ggml_mm_src);
101        sources.insert("kernel_mul_mm_iq4_nl_f32".into(), ggml_mm_src);
102        // ADR-022 Phase 2 — dense Q5_K mm.
103        sources.insert("kernel_mul_mm_q5_K_f32".into(), ggml_mm_src);
104        // ADR-022 Phase 3 — dense Q4_K mm.
105        sources.insert("kernel_mul_mm_q4_K_f32".into(), ggml_mm_src);
106
107        // GGML block-format quantized matrix-matrix kernels — tensor API
108        // variant (ADR-011 Phase 3 Wave P3b-tensor: port of llama.cpp's
109        // kernel_mul_mm_impl `#ifdef GGML_METAL_HAS_TENSOR` branch).
110        // Uses Apple's MetalPerformancePrimitives `tensor_ops::matmul2d`
111        // primitive which on M3+ dispatches to hardware tensor cores for
112        // 2-3x the effective FLOP throughput vs the simdgroup MMA path.
113        // Only compiled on devices where the tensor API is available; the
114        // kernel_registry's runtime-probe (see MlxDevice::has_tensor) gates
115        // compilation so non-tensor devices transparently fall back to the
116        // non-tensor `kernel_mul_mm_<q>_f32` kernels.
117        let ggml_mm_tensor_src: &'static str =
118            include_str!("shaders/quantized_matmul_mm_tensor.metal");
119        sources.insert("kernel_mul_mm_q4_0_tensor_f32".into(), ggml_mm_tensor_src);
120        sources.insert("kernel_mul_mm_q4_0_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
121        sources.insert("kernel_mul_mm_q6_K_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
122        sources.insert("kernel_mul_mm_q8_0_tensor_f32".into(), ggml_mm_tensor_src);
123        sources.insert("kernel_mul_mm_q6_K_tensor_f32".into(), ggml_mm_tensor_src);
124        // ADR-022 Phase 1 — Q5_1 / IQ4_NL tensor mm.
125        sources.insert("kernel_mul_mm_q5_1_tensor_f32".into(), ggml_mm_tensor_src);
126        sources.insert("kernel_mul_mm_iq4_nl_tensor_f32".into(), ggml_mm_tensor_src);
127        // ADR-022 Phase 2 — Q5_K tensor mm.
128        sources.insert("kernel_mul_mm_q5_K_tensor_f32".into(), ggml_mm_tensor_src);
129        // ADR-022 Phase 3 — Q4_K tensor mm + Q8_0 perm021.
130        sources.insert("kernel_mul_mm_q4_K_tensor_f32".into(), ggml_mm_tensor_src);
131        sources.insert("kernel_mul_mm_q8_0_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
132
133        // ADR-022 Phase 1 P1.7 — Q5_1 / IQ4_NL mul_mv_ext r1 family.
134        // Eight instantiations (2 types × 4 r1ptg widths). Each PSO is
135        // additionally specialized at PSO-compile time with FC_mul_mv_nsg
136        // (function_constant 600) and FC_mul_mv_nxpsg (function_constant 601).
137        let mul_mv_ext_src: &'static str = include_str!("shaders/mul_mv_ext.metal");
138        sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_2".into(), mul_mv_ext_src);
139        sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_3".into(), mul_mv_ext_src);
140        sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_4".into(), mul_mv_ext_src);
141        sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_5".into(), mul_mv_ext_src);
142        sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_2".into(), mul_mv_ext_src);
143        sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_3".into(), mul_mv_ext_src);
144        sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_4".into(), mul_mv_ext_src);
145        sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_5".into(), mul_mv_ext_src);
146        // ADR-022 Phase 4 — Q4_0 / Q8_0 / Q4_K / Q5_K / Q6_K mv_ext.
147        // 5 types × 4 r1ptg widths = 20 instantiations.
148        for r1 in [2, 3, 4, 5].iter() {
149            for ty in ["q4_0", "q8_0", "q4_K", "q5_K", "q6_K"].iter() {
150                let name = format!("kernel_mul_mv_ext_{ty}_f32_r1_{r1}");
151                sources.insert(name, mul_mv_ext_src);
152            }
153        }
154
155        // Dense bf16×f32 → f32 tensor-API matmul (non-flash-attention
156        // prefill Q@K^T and scores@V, modeled on llama.cpp's
157        // kernel_mul_mm_bf16_f32 with the GGML_METAL_HAS_TENSOR branch
158        // active).  Tile geometry and write-back identical to the
159        // quantized tensor kernel; only the A-stage copy (bfloat →
160        // bfloat, no dequantize) differs.
161        let dense_mm_bf16_tensor_src: &'static str =
162            include_str!("shaders/dense_mm_bf16_tensor.metal");
163        sources.insert("hf2q_dense_mm_bf16_f32_tensor".into(), dense_mm_bf16_tensor_src);
164
165        // Dense f32×f32 → f32 tensor-API matmul (F32-everywhere
166        // sibling of dense_mm_bf16_tensor).  Used by hf2q's ADR-005
167        // iter-118 BF16-vs-F32 ViT attention A/B diagnostic to remove
168        // the BF16 K-stage cast as a confounding variable.  Port of
169        // llama.cpp's kernel_mul_mm_f32_f32 specialization
170        // (ggml-metal.metal:10098) on the GGML_METAL_HAS_TENSOR
171        // branch.  Same tile geometry (NR0=64 NR1=32 NK=32) but
172        // float-everywhere shmem staging.
173        let dense_mm_f32_f32_tensor_src: &'static str =
174            include_str!("shaders/dense_mm_f32_f32.metal");
175        sources.insert("hf2q_dense_mm_f32_f32_tensor".into(), dense_mm_f32_f32_tensor_src);
176
177        // Dense f16×f32 → f32 tensor-API matmul (F16-staging sibling
178        // of dense_mm_bf16_tensor).  Used by hf2q's ADR-005 Phase 2c
179        // iter-128 gemma4v ViT precision-parity path: every mmproj
180        // weight is stored as F16 in GGUF, peer's `kernel_mul_mm_f16_f32`
181        // (`ggml-metal.metal:10099`) stages BOTH A and B as `half` in
182        // shmem and computes on `simdgroup_half8x8`.  Matches peer
183        // per-element rounding budget exactly (10-bit mantissa vs
184        // BF16's 7-bit), closing the 1.16x/block cascade compound that
185        // iter-127 numerically bisected to BF16 staging.  Same tile
186        // geometry as the BF16 sibling (NR0=64 NR1=32 NK=32, 8 KB
187        // shmem) — half and bfloat share 16-bit storage.
188        let dense_mm_f16_tensor_src: &'static str =
189            include_str!("shaders/dense_mm_f16_tensor.metal");
190        sources.insert("hf2q_dense_mm_f16_f32_tensor".into(), dense_mm_f16_tensor_src);
191
192        // Dense bf16×f32 → f32 GEMV (matrix-vector multiply) — optimized
193        // for M=1 single-token decode.  Port of llama.cpp's
194        // kernel_mul_mv_bf16_f32_4 (bfloat4-vectorized GEMV kernel).
195        // Used in apply_linear_projection_f32 when seq_len=1 and the
196        // weight matrix is BF16, replacing the MM kernel (~2× faster for
197        // M=1 due to better memory bandwidth utilization per thread).
198        let dense_gemv_bf16_src: &'static str =
199            include_str!("shaders/dense_gemv_bf16.metal");
200        sources.insert("hf2q_dense_gemv_bf16_f32_4".into(), dense_gemv_bf16_src);
201
202        // Fused scale-mask-softmax for the non-flash-attention prefill
203        // path.  One row-local threadgroup per (head, query) pair
204        // replaces three separate dispatches (scale, mask-add, softmax);
205        // reads a bf16 mask (-INF at masked positions, matching
206        // flash_attn_prefill_mask.metal) that is shared across heads.
207        let scale_mask_softmax_src: &'static str =
208            include_str!("shaders/scale_mask_softmax.metal");
209        sources.insert("scale_mask_softmax_f32".into(), scale_mask_softmax_src);
210
211        // Expert-routed (MoE) quantized matmul kernel (Story 2.1)
212        sources.insert(
213            "quantized_matmul_id".into(),
214            include_str!("shaders/quantized_matmul_id.metal"),
215        );
216
217        // Expert-routed (MoE) GGML block-format quantized matmul kernels
218        let ggml_id_src: &'static str =
219            include_str!("shaders/quantized_matmul_id_ggml.metal");
220        sources.insert("kernel_mul_mv_id_q4_0_f32".into(), ggml_id_src);
221        sources.insert("kernel_mul_mv_id_q8_0_f32".into(), ggml_id_src);
222        // ADR-013 P7 — Q4_K MoE expert-routed mat-vec (port of
223        // llama.cpp's kernel_mul_mv_id_q4_K_f32 at ggml-metal.metal:10349).
224        sources.insert("kernel_mul_mv_id_q4_K_f32".into(), ggml_id_src);
225        sources.insert("kernel_mul_mv_id_q5_K_f32".into(), ggml_id_src);
226        sources.insert("kernel_mul_mv_id_q6_K_f32".into(), ggml_id_src);
227        // ADR-022 Phase 1 — Q5_1 / IQ4_NL MoE expert-routed mat-vec.
228        sources.insert("kernel_mul_mv_id_q5_1_f32".into(), ggml_id_src);
229        sources.insert("kernel_mul_mv_id_iq4_nl_f32".into(), ggml_id_src);
230        // Fused-SwiGLU mv_id variants (ADR-012 §Optimize / Task #15):
231        // computes y[r][n] = sum_k(dequant(W[expert][n][k]) * silu(gate[r][k]) * up[r][k])
232        // in one dispatch — replaces silu_mul + expert_down sequence.
233        sources.insert("kernel_mul_mv_id_q4_0_f32_swiglu".into(), ggml_id_src);
234
235        // Expert-routed (MoE) GGML block-format QUANTIZED MATRIX-MATRIX kernels
236        // (ADR-011 Phase 3 Wave P3a: port of llama.cpp's
237        // `kernel_mul_mm_id_map0_ne20_N` + `kernel_mul_mm_id_<q>_f32`).
238        // Two-stage dispatch: map0 regroups the token-to-expert table into
239        // per-expert routed-token lists, then mm_id stages a 64x32 expert
240        // weight tile into threadgroup shmem and reuses it across a 32-row
241        // block of that expert's routed tokens.
242        let ggml_id_mm_src: &'static str =
243            include_str!("shaders/quantized_matmul_id_mm.metal");
244        sources.insert("kernel_mul_mm_id_map0_ne20_1".into(), ggml_id_mm_src);
245        sources.insert("kernel_mul_mm_id_map0_ne20_8".into(), ggml_id_mm_src);
246        sources.insert("kernel_mul_mm_id_q4_0_f32".into(), ggml_id_mm_src);
247        sources.insert("kernel_mul_mm_id_q8_0_f32".into(), ggml_id_mm_src);
248        sources.insert("kernel_mul_mm_id_q6_K_f32".into(), ggml_id_mm_src);
249        // ADR-013 P16 — Q4_K mm_id (port of llama.cpp ggml-metal.metal:10169).
250        sources.insert("kernel_mul_mm_id_q4_K_f32".into(), ggml_id_mm_src);
251        // ADR-022 Phase 1 P1.6 — Q5_1 / IQ4_NL mm_id template instantiations.
252        sources.insert("kernel_mul_mm_id_q5_1_f32".into(), ggml_id_mm_src);
253        sources.insert("kernel_mul_mm_id_iq4_nl_f32".into(), ggml_id_mm_src);
254        // ADR-022 Phase 2 — Q5_K mm_id template instantiation.
255        sources.insert("kernel_mul_mm_id_q5_K_f32".into(), ggml_id_mm_src);
256
257        // MoE-routed quantized matrix-matrix kernels — tensor API variant
258        // (ADR-011 Phase 3 Wave P3b-tensor).  Uses the MPP tensor_ops
259        // matmul2d primitive for hardware-tensor-core MMA on M3+.  Only
260        // the mm_id kernel is ported — map0 is a short pre-pass (not
261        // matmul) and continues to use the simdgroup version.
262        let ggml_id_mm_tensor_src: &'static str =
263            include_str!("shaders/quantized_matmul_id_mm_tensor.metal");
264        sources.insert("kernel_mul_mm_id_q4_0_tensor_f32".into(), ggml_id_mm_tensor_src);
265        sources.insert("kernel_mul_mm_id_q8_0_tensor_f32".into(), ggml_id_mm_tensor_src);
266        sources.insert("kernel_mul_mm_id_q6_K_tensor_f32".into(), ggml_id_mm_tensor_src);
267        // ADR-013 P16 — Q4_K tensor-API mm_id.
268        sources.insert("kernel_mul_mm_id_q4_K_tensor_f32".into(), ggml_id_mm_tensor_src);
269        // ADR-022 Phase 1 P1.6 — Q5_1 / IQ4_NL tensor-API mm_id.
270        sources.insert("kernel_mul_mm_id_q5_1_tensor_f32".into(), ggml_id_mm_tensor_src);
271        sources.insert("kernel_mul_mm_id_iq4_nl_tensor_f32".into(), ggml_id_mm_tensor_src);
272        // ADR-022 Phase 2 — Q5_K tensor-API mm_id.
273        sources.insert("kernel_mul_mm_id_q5_K_tensor_f32".into(), ggml_id_mm_tensor_src);
274
275        // Embedding kernels (Story 1.5)
276        let embedding_src: &'static str = include_str!("shaders/embedding.metal");
277        sources.insert("embedding_gather_4bit".into(), embedding_src);
278        sources.insert("embedding_gather_6bit".into(), embedding_src);
279
280        // MoE gate kernel (Story 1.5)
281        let moe_gate_src: &'static str = include_str!("shaders/moe_gate.metal");
282        sources.insert("moe_gate".into(), moe_gate_src);
283
284        // MoE dispatch kernels (Story 1.5)
285        let moe_dispatch_src: &'static str = include_str!("shaders/moe_dispatch.metal");
286        sources.insert("fused_gelu_mul".into(), moe_dispatch_src);
287        sources.insert("moe_swiglu_fused".into(), moe_dispatch_src);
288        sources.insert("moe_swiglu_batch".into(), moe_dispatch_src);
289        sources.insert("moe_swiglu_seq".into(), moe_dispatch_src);
290        sources.insert("moe_accumulate".into(), moe_dispatch_src);
291        sources.insert("moe_weighted_sum".into(), moe_dispatch_src);
292        sources.insert("moe_weighted_sum_seq".into(), moe_dispatch_src);
293        sources.insert("zero_buffer".into(), moe_dispatch_src);
294        sources.insert("naive_matvec_f32".into(), moe_dispatch_src);
295        sources.insert("moe_gather_topk_weights".into(), moe_dispatch_src);
296        // bf16 variants (Phase 2 bf16 activation path)
297        sources.insert("fused_gelu_mul_bf16".into(), moe_dispatch_src);
298        sources.insert("moe_swiglu_seq_bf16".into(), moe_dispatch_src);
299        sources.insert("moe_weighted_sum_seq_bf16_input".into(), moe_dispatch_src);
300        // ADR-020 iter-11h-e3a: backward kernels for moe_weighted_sum_seq.
301        sources.insert(
302            "moe_weighted_sum_seq_backward_outputs_f32".into(),
303            moe_dispatch_src,
304        );
305        sources.insert(
306            "moe_weighted_sum_seq_backward_weights_f32".into(),
307            moe_dispatch_src,
308        );
309        // ADR-020 iter-11h-e3b: fused backward kernel for moe_swiglu_seq.
310        sources.insert(
311            "moe_swiglu_seq_backward_f32".into(),
312            moe_dispatch_src,
313        );
314
315        // Batched KV cache copy kernels
316        let kv_cache_src: &'static str = include_str!("shaders/kv_cache_copy.metal");
317        sources.insert("kv_cache_copy_batch_f32".into(), kv_cache_src);
318        sources.insert("kv_cache_copy_batch_f32_to_f16".into(), kv_cache_src);
319        sources.insert("kv_cache_copy_seq_f32".into(), kv_cache_src);
320        sources.insert("kv_cache_copy_seq_f32_to_f16".into(), kv_cache_src);
321        // Wave P4.11 — fused K+V copy variants
322        sources.insert("kv_cache_copy_seq_f32_kv_dual".into(), kv_cache_src);
323        sources.insert("kv_cache_copy_seq_f32_to_f16_kv_dual".into(), kv_cache_src);
324        // bf16-source KV cache copy (Phase 2 bf16 activation path)
325        sources.insert("kv_cache_copy_seq_bf16".into(), kv_cache_src);
326
327        // Elementwise and transpose kernels (Story 1.5)
328        let elementwise_src: &'static str = include_str!("shaders/elementwise.metal");
329        sources.insert("elementwise_add_f32".into(), elementwise_src);
330        sources.insert("elementwise_add_f16".into(), elementwise_src);
331        sources.insert("elementwise_mul_f32".into(), elementwise_src);
332        sources.insert("elementwise_mul_f16".into(), elementwise_src);
333        sources.insert("elementwise_add_bf16".into(), elementwise_src);
334        sources.insert("elementwise_mul_bf16".into(), elementwise_src);
335        sources.insert("cast_f16_to_f32".into(), elementwise_src);
336        sources.insert("cast_f32_to_f16".into(), elementwise_src);
337        sources.insert("cast_bf16_to_f32".into(), elementwise_src);
338        sources.insert("cast_f32_to_bf16".into(), elementwise_src);
339        sources.insert("scalar_mul_bf16".into(), elementwise_src);
340        sources.insert("scalar_mul_f32".into(), elementwise_src);
341        sources.insert("embedding_gather_scale_f32".into(), elementwise_src);
342        sources.insert("embedding_gather_scale_batch_f32".into(), elementwise_src);
343        sources.insert("permute_021_bf16".into(), elementwise_src);
344        sources.insert("transpose_last2_bf16".into(), elementwise_src);
345        sources.insert("transpose_last2_f16".into(), elementwise_src);
346        sources.insert("permute_021_f32".into(), elementwise_src);
347        sources.insert("permute_021_bf16_to_f32".into(), elementwise_src);
348        sources.insert("transpose_2d_f32".into(), elementwise_src);
349        sources.insert("transpose_2d_f16".into(), elementwise_src);
350
351        // Attention kernels (Story 1.3)
352        let sdpa_src: &'static str = include_str!("shaders/sdpa.metal");
353        sources.insert("sdpa".into(), sdpa_src);
354        sources.insert("sdpa_bf16".into(), sdpa_src);
355        let sdpa_sliding_src: &'static str = include_str!("shaders/sdpa_sliding.metal");
356        sources.insert("sdpa_sliding".into(), sdpa_sliding_src);
357        sources.insert("sdpa_sliding_bf16".into(), sdpa_sliding_src);
358
359        // Flash-attention tiled prefill kernel (ADR-011 Phase 1).
360        // Ten entry points; all backed by the same shader source.
361        // Pipelines are compiled with function constants via
362        // `get_pipeline_with_bool_constants` — not `get_pipeline`.
363        let flash_attn_prefill_src: &'static str =
364            include_str!("shaders/flash_attn_prefill.metal");
365        // D=256 variants (BQ=32, BK=16, WM=4, WN=1 — 128 threads/threadgroup)
366        sources.insert(
367            "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskfloat32".into(),
368            flash_attn_prefill_src,
369        );
370        sources.insert(
371            "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
372            flash_attn_prefill_src,
373        );
374        sources.insert(
375            "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbfloat16".into(),
376            flash_attn_prefill_src,
377        );
378        sources.insert(
379            "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
380            flash_attn_prefill_src,
381        );
382        sources.insert(
383            "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskfloat16".into(),
384            flash_attn_prefill_src,
385        );
386        sources.insert(
387            "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
388            flash_attn_prefill_src,
389        );
390        // D=512 variants (BQ=8, BK=8, WM=1, WN=1 — 32 threads/threadgroup)
391        // NOTE: f32 at D=512 is NOT instantiated — threadgroup memory exceeds
392        // the 32 KB Metal limit (candle sdpa.rs:86-94).
393        sources.insert(
394            "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbfloat16".into(),
395            flash_attn_prefill_src,
396        );
397        sources.insert(
398            "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
399            flash_attn_prefill_src,
400        );
401        sources.insert(
402            "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskfloat16".into(),
403            flash_attn_prefill_src,
404        );
405        sources.insert(
406            "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
407            flash_attn_prefill_src,
408        );
409
410        // Flash attention vector kernels — SIMD-vectorized decode-path SDPA
411        // (ported from llama.cpp flash_attn_ext_vec)
412        let flash_attn_vec_src: &'static str =
413            include_str!("shaders/flash_attn_vec.metal");
414        sources.insert("flash_attn_vec_dk256".into(), flash_attn_vec_src);
415        sources.insert("flash_attn_vec_dk512".into(), flash_attn_vec_src);
416        sources.insert("flash_attn_vec_reduce_dk256".into(), flash_attn_vec_src);
417        sources.insert("flash_attn_vec_reduce_dk512".into(), flash_attn_vec_src);
418        // F16 KV variants (Phase 4a)
419        sources.insert("flash_attn_vec_f16kv_dk256".into(), flash_attn_vec_src);
420        sources.insert("flash_attn_vec_f16kv_dk512".into(), flash_attn_vec_src);
421
422        // RoPE, normalization, activation kernels (Story 1.4)
423        let rope_src: &'static str = include_str!("shaders/rope.metal");
424        sources.insert("rope_f32".into(), rope_src);
425        sources.insert("rope_f16".into(), rope_src);
426        sources.insert("rope_bf16".into(), rope_src);
427        sources.insert("rope_neox_bf16".into(), rope_src);
428        sources.insert("rope_neox_f32".into(), rope_src);
429        let rms_norm_src: &'static str = include_str!("shaders/rms_norm.metal");
430        sources.insert("rms_norm_f32".into(), rms_norm_src);
431        sources.insert("rms_norm_f16".into(), rms_norm_src);
432        sources.insert("rms_norm_bf16".into(), rms_norm_src);
433        sources.insert("rms_norm_no_scale_bf16".into(), rms_norm_src);
434        sources.insert("rms_norm_no_scale_f32".into(), rms_norm_src);
435        sources.insert("rms_norm_no_scale_f32_dual".into(), rms_norm_src);
436        sources.insert("rms_norm_f32_triple".into(), rms_norm_src);
437        sources.insert("fused_post_attn_triple_norm_f32".into(), rms_norm_src);
438        sources.insert("rms_norm_no_scale_f32_dual_perm".into(), rms_norm_src);
439        // Fused RMS norm + elementwise multiply kernels (Phase 4e.2)
440        sources.insert("rms_norm_mul_f32".into(), rms_norm_src);
441        sources.insert("rms_norm_mul_f16".into(), rms_norm_src);
442        sources.insert("rms_norm_mul_bf16".into(), rms_norm_src);
443        // L2 norm kernels (ADR-013 Decision 3 — Gated DeltaNet Q/K norm)
444        let l2_norm_src: &'static str = include_str!("shaders/l2_norm.metal");
445        sources.insert("l2_norm_f32".into(), l2_norm_src);
446        sources.insert("l2_norm_f16".into(), l2_norm_src);
447        sources.insert("l2_norm_bf16".into(), l2_norm_src);
448        // Cumulative-sum kernels (ADR-013 Decision 4 — DeltaNet decay-mask base)
449        let cumsum_src: &'static str = include_str!("shaders/cumsum.metal");
450        sources.insert("cumsum_f32".into(), cumsum_src);
451        sources.insert("cumsum_bf16".into(), cumsum_src);
452        // SSM conv kernels (ADR-013 Decision 7 — DeltaNet 1D causal conv + SiLU)
453        let ssm_conv_src: &'static str = include_str!("shaders/ssm_conv.metal");
454        sources.insert("ssm_conv_forward_f32".into(), ssm_conv_src);
455        sources.insert("ssm_conv_forward_bf16".into(), ssm_conv_src);
456        sources.insert("ssm_conv_state_update_f32".into(), ssm_conv_src);
457        sources.insert("ssm_conv_state_update_bf16".into(), ssm_conv_src);
458        // Tri-solve kernels (ADR-013 Decision 5 — chunked DeltaNet debug path)
459        let tri_solve_src: &'static str = include_str!("shaders/tri_solve.metal");
460        sources.insert("tri_solve_lower_unit_f32".into(), tri_solve_src);
461        sources.insert("tri_solve_lower_unit_bf16".into(), tri_solve_src);
462        // Rope-multi kernels (ADR-013 Decision 10 — IMROPE for Qwen3.5)
463        let rope_multi_src: &'static str = include_str!("shaders/rope_multi.metal");
464        sources.insert("rope_multi_f32".into(), rope_multi_src);
465        sources.insert("rope_multi_bf16".into(), rope_multi_src);
466        // Gated DeltaNet fused kernel (ADR-013 Decision 6 — centerpiece)
467        let gdn_src: &'static str = include_str!("shaders/gated_delta_net.metal");
468        sources.insert("gated_delta_net_f32".into(), gdn_src);
469        // ADR-015 iter56 — decode-only `simd_sum` variant. Three NSG-templated
470        // host names share the same source; selection is by D_k via
471        // `dispatch_gated_delta_net_decode`. Drop-in for the fused kernel
472        // above when n_tokens=1.
473        let gdn_decode_src: &'static str =
474            include_str!("shaders/gated_delta_net_decode.metal");
475        sources.insert("gated_delta_net_decode_f32_1".into(), gdn_decode_src);
476        sources.insert("gated_delta_net_decode_f32_2".into(), gdn_decode_src);
477        sources.insert("gated_delta_net_decode_f32_4".into(), gdn_decode_src);
478        // Wave 5b — chunk-parallel inter-chunk state-recurrence kernel
479        // (the one new kernel in the chunk-parallel pipeline; spec source:
480        // arXiv 2412.06464 §4 + FLA chunk_delta_h.py:43-298).
481        let gdn_chunk_src: &'static str =
482            include_str!("shaders/gated_delta_net_chunk.metal");
483        sources.insert(
484            "gated_delta_net_chunk_inter_state_bf16".into(),
485            gdn_chunk_src,
486        );
487        // Wave 5b.1 iter 2 — chunk_scaled_dot_kkt kernel (input-side of
488        // the chunk pipeline; spec source: FLA chunk_scaled_dot_kkt.py:36-99).
489        let gdn_kkt_src: &'static str =
490            include_str!("shaders/gated_delta_net_kkt.metal");
491        sources.insert("gated_delta_net_kkt_bf16".into(), gdn_kkt_src);
492        // Wave 5b.1 iter 2 — recompute_w_u_fwd kernel (applies post-solve A
493        // to (β·v) and (β·k·exp(g)) to produce w and u; spec source: FLA
494        // wy_fast.py:29-117).
495        let gdn_recompute_wu_src: &'static str =
496            include_str!("shaders/gated_delta_net_recompute_wu.metal");
497        sources.insert(
498            "gated_delta_net_recompute_wu_bf16".into(),
499            gdn_recompute_wu_src,
500        );
501        // Wave 5b.1 iter 3 — chunk_fwd_o kernel (per-chunk output: closes
502        // the chunk pipeline; spec source: FLA chunk_o.py:42-138).
503        let gdn_chunk_o_src: &'static str =
504            include_str!("shaders/gated_delta_net_chunk_o.metal");
505        sources.insert("gated_delta_net_chunk_o_bf16".into(), gdn_chunk_o_src);
506        // Wave 5b.1 iter 4 — orchestrator helper kernels:
507        //   chunk_local_cumsum_g_f32      — per-chunk prefix sum on g [B, T, H]
508        //   chunk_tri_solve_invert_f32    — per-chunk-block (I + A_strict)^-1
509        //                                   on FLA's [B, T, H, BT] layout.
510        let chunk_local_cumsum_g_src: &'static str =
511            include_str!("shaders/chunk_local_cumsum_g.metal");
512        sources.insert(
513            "chunk_local_cumsum_g_f32".into(),
514            chunk_local_cumsum_g_src,
515        );
516        let chunk_tri_solve_invert_src: &'static str =
517            include_str!("shaders/chunk_gated_delta_rule_tri_solve_invert.metal");
518        sources.insert(
519            "chunk_tri_solve_invert_f32".into(),
520            chunk_tri_solve_invert_src,
521        );
522        // Sigmoid-gated elementwise multiply (ADR-013 Decision 9 — full-attn output gate)
523        let sigmoid_mul_src: &'static str = include_str!("shaders/sigmoid_mul.metal");
524        sources.insert("sigmoid_mul_f32".into(), sigmoid_mul_src);
525        sources.insert("sigmoid_mul_bf16".into(), sigmoid_mul_src);
526        let silu_mul_src: &'static str = include_str!("shaders/silu_mul.metal");
527        sources.insert("silu_mul_f32".into(), silu_mul_src);
528        let compute_g_beta_src: &'static str = include_str!("shaders/compute_g_beta.metal");
529        sources.insert("compute_g_beta_f32".into(), compute_g_beta_src);
530        let ssm_norm_gate_src: &'static str = include_str!("shaders/ssm_norm_gate.metal");
531        sources.insert("ssm_norm_gate_f32".into(), ssm_norm_gate_src);
532        let gelu_src: &'static str = include_str!("shaders/gelu.metal");
533        sources.insert("gelu_f32".into(), gelu_src);
534        sources.insert("gelu_f16".into(), gelu_src);
535        sources.insert("gelu_bf16".into(), gelu_src);
536        let softmax_src: &'static str = include_str!("shaders/softmax.metal");
537        sources.insert("softmax_f32".into(), softmax_src);
538        sources.insert("softmax_f16".into(), softmax_src);
539        sources.insert("softmax_bf16".into(), softmax_src);
540        let softmax_backward_src: &'static str =
541            include_str!("shaders/softmax_backward.metal");
542        sources.insert("softmax_backward_f32".into(), softmax_backward_src);
543        let log_elementwise_src: &'static str =
544            include_str!("shaders/log_elementwise.metal");
545        sources.insert("log_f32".into(), log_elementwise_src);
546        sources.insert("log_backward_f32".into(), log_elementwise_src);
547        let row_sum_src: &'static str = include_str!("shaders/row_sum.metal");
548        sources.insert("row_sum_f32".into(), row_sum_src);
549        sources.insert("row_sum_backward_f32".into(), row_sum_src);
550        // ADR-020 iter-10a: GGUF-legacy quantize-dequantize round-trip kernels
551        // (Q4_0 + Q8_0).  Used by hf2q's dynamic_quant Track 1 to produce
552        // W_low / W_high for the gradient-Taylor sensitivity formula.
553        let qdq_legacy_src: &'static str = include_str!("shaders/qdq_legacy.metal");
554        sources.insert("qdq_q4_0_f32".into(), qdq_legacy_src);
555        sources.insert("qdq_q8_0_f32".into(), qdq_legacy_src);
556        // ADR-020 iter-10b: RMSNorm reverse-mode autograd kernels.
557        // r_inv helper is reused by both backward kernels; dx and dw cover
558        // the full backward identity for `y = x * rsqrt(mean(x²) + eps) * w`.
559        let rms_norm_backward_src: &'static str =
560            include_str!("shaders/rms_norm_backward.metal");
561        sources.insert(
562            "rms_norm_compute_rms_inv_f32".into(),
563            rms_norm_backward_src,
564        );
565        sources.insert("rms_norm_backward_dx_f32".into(), rms_norm_backward_src);
566        sources.insert("rms_norm_backward_dw_f32".into(), rms_norm_backward_src);
567        // ADR-020 iter-11a: 2-D row-major slice + concat-by-column kernels.
568        // Used by hf2q's multi-head SDPA on GpuTape (slice Q/K/V into
569        // per-head views, run per-head SDPA, concat per-head contexts
570        // back to full attention output).
571        let slice_concat_2d_src: &'static str =
572            include_str!("shaders/slice_concat_2d.metal");
573        sources.insert("slice_2d_cols_f32".into(), slice_concat_2d_src);
574        sources.insert("copy_2d_cols_into_f32".into(), slice_concat_2d_src);
575        // ADR-020 iter-11b: SiLU forward + backward kernels for GpuTape
576        // SwiGLU FFN composition.
577        let silu_backward_src: &'static str =
578            include_str!("shaders/silu_backward.metal");
579        sources.insert("silu_f32".into(), silu_backward_src);
580        sources.insert("silu_backward_f32".into(), silu_backward_src);
581        // ADR-020 iter-11d: FP32 embedding lookup + scatter-add backward.
582        let embedding_autograd_src: &'static str =
583            include_str!("shaders/embedding_autograd.metal");
584        sources.insert("embedding_lookup_f32".into(), embedding_autograd_src);
585        sources.insert(
586            "embedding_scatter_add_f32".into(),
587            embedding_autograd_src,
588        );
589        // ADR-020 iter-13a: Adam optimizer step kernel for Track 2
590        // DWQ-proper training loop.
591        let adam_update_src: &'static str =
592            include_str!("shaders/adam_update.metal");
593        sources.insert("adam_update_f32".into(), adam_update_src);
594        // ADR-020 iter-13b: differentiable affine qdq kernels for the
595        // DWQ-proper training loop.  Init + forward + backward (scales,
596        // biases) — q_int is FROZEN, scales+biases learnable.
597        let qdq_affine_src: &'static str =
598            include_str!("shaders/qdq_affine.metal");
599        sources.insert("qdq_affine_init_f32".into(), qdq_affine_src);
600        sources.insert("qdq_affine_forward_f32".into(), qdq_affine_src);
601        sources.insert(
602            "qdq_affine_backward_scales_f32".into(),
603            qdq_affine_src,
604        );
605        sources.insert(
606            "qdq_affine_backward_biases_f32".into(),
607            qdq_affine_src,
608        );
609        // ADR-020 iter-15: fused affine quantized matmul for DWQ inference.
610        // Per-element kernel; one thread per (m, n) output element.
611        // Tiled + simdgroup-MMA variant lands in iter-15b.
612        let qmm_affine_src: &'static str =
613            include_str!("shaders/qmm_affine.metal");
614        sources.insert("qmm_affine_t_f32".into(), qmm_affine_src);
615        // ADR-020 iter-15b: tiled variant — 16x16 thread block with
616        // cooperative-load X/W tiles in threadgroup-shared memory for
617        // 2-5x speedup over the per-element kernel.
618        let qmm_affine_tiled_src: &'static str =
619            include_str!("shaders/qmm_affine_tiled.metal");
620        sources.insert(
621            "qmm_affine_t_f32_tiled".into(),
622            qmm_affine_tiled_src,
623        );
624        // ADR-020 iter-15c: simdgroup-MMA variant — uses Apple GPU
625        // hardware `simdgroup_matrix<float, 8, 8>` MMA for the inner
626        // reduction.  Per-tile algorithmic 8× over scalar tiled, lands
627        // as ~3-4× wall after launch / load amortization.
628        let qmm_affine_simd_src: &'static str =
629            include_str!("shaders/qmm_affine_simd.metal");
630        sources.insert(
631            "qmm_affine_t_f32_simd".into(),
632            qmm_affine_simd_src,
633        );
634        // ADR-020 iter-15c-2: 4-simdgroup-per-TG variant — 32×32
635        // output tile, 4 simdgroups arranged as 2×2 grid each owning
636        // a 16×16 sub-tile = 4 simdgroup_matrix accumulators.  Same
637        // math as 15c-1, fuller warp-pool exploitation.
638        let qmm_affine_simd4_src: &'static str =
639            include_str!("shaders/qmm_affine_simd4.metal");
640        sources.insert(
641            "qmm_affine_t_f32_simd4".into(),
642            qmm_affine_simd4_src,
643        );
644        // ADR-020 iter-15c-2b: gs=64 variant (mlx-lm dynamic_quant
645        // canonical default).  Same 4-simdgroup geometry, BK=64
646        // instead of 32 (= 8 sub-K-tiles per K-step instead of 4).
647        let qmm_affine_simd4_gs64_src: &'static str =
648            include_str!("shaders/qmm_affine_simd4_gs64.metal");
649        sources.insert(
650            "qmm_affine_t_f32_simd4_gs64".into(),
651            qmm_affine_simd4_gs64_src,
652        );
653        // ADR-020 AC#5 Iter A: packed-U32 dense affine matmul (bits=4,
654        // gs=32) — production decode/prefill kernel for serving DWQ
655        // safetensors directly without a load-time unpack pass.
656        let qmm_affine_t_packed_simd4_b4_src: &'static str =
657            include_str!("shaders/qmm_affine_t_packed_simd4_b4.metal");
658        sources.insert(
659            "qmm_affine_t_packed_simd4_b4".into(),
660            qmm_affine_t_packed_simd4_b4_src,
661        );
662        // ADR-020 iter-11h-b: training-mode causal depthwise 1D
663        // convolution (forward + backward dx + backward dw).  Used by
664        // GpuTape autograd for differentiable Qwen3.5MoE forward
665        // (GatedDeltaNet's conv1d step).
666        let conv1d_dwc_src: &'static str =
667            include_str!("shaders/conv1d_depthwise_causal.metal");
668        sources.insert(
669            "conv1d_depthwise_causal_forward_f32".into(),
670            conv1d_dwc_src,
671        );
672        sources.insert(
673            "conv1d_depthwise_causal_backward_dx_f32".into(),
674            conv1d_dwc_src,
675        );
676        sources.insert(
677            "conv1d_depthwise_causal_backward_dw_f32".into(),
678            conv1d_dwc_src,
679        );
680        // ADR-020 iter-11h-c1: elementwise exp forward + backward.
681        // Building block for GatedDeltaNet's alpha = exp(-g) state-decay.
682        let exp_src: &'static str =
683            include_str!("shaders/exp_elementwise.metal");
684        sources.insert("exp_f32".into(), exp_src);
685        sources.insert("exp_backward_f32".into(), exp_src);
686        // ADR-020 iter-11h-c2: vector outer product (forward + dlhs +
687        // drhs).  Building block for gated_delta_update's
688        // outer(delta, k) state-update term.
689        let outer_src: &'static str =
690            include_str!("shaders/outer_product.metal");
691        sources.insert("outer_product_f32".into(), outer_src);
692        sources.insert("outer_product_backward_lhs_f32".into(), outer_src);
693        sources.insert("outer_product_backward_rhs_f32".into(), outer_src);
694        // ADR-020 iter-11h-e1: take_along_axis (gather) + scatter-backward.
695        // Building block for MoE router on GpuTape.
696        let taa_src: &'static str =
697            include_str!("shaders/take_along_axis.metal");
698        sources.insert("take_along_axis_f32".into(), taa_src);
699        sources.insert("take_along_axis_backward_f32".into(), taa_src);
700        // ADR-020 iter-11h-misc-1: elementwise divide forward + backward.
701        let div_src: &'static str =
702            include_str!("shaders/divide_elementwise.metal");
703        sources.insert("divide_f32".into(), div_src);
704        sources.insert("divide_backward_f32".into(), div_src);
705        // ADR-020 iter-11h-misc-3: elementwise sqrt forward + backward.
706        let sqrt_src: &'static str =
707            include_str!("shaders/sqrt_elementwise.metal");
708        sources.insert("sqrt_f32".into(), sqrt_src);
709        sources.insert("sqrt_backward_f32".into(), sqrt_src);
710        let softcap_src: &'static str = include_str!("shaders/softcap.metal");
711        sources.insert("softcap_f32".into(), softcap_src);
712        sources.insert("softcap_f16".into(), softcap_src);
713        sources.insert("softcap_bf16".into(), softcap_src);
714
715        // Fused norm-add kernels — Gemma4 post-attention / post-FFN ordering:
716        //   normed = rms_norm(input, weight, eps);  output = residual + normed
717        let fused_norm_add_src: &'static str =
718            include_str!("shaders/fused_norm_add_bf16.metal");
719        sources.insert("fused_norm_add_bf16".into(), fused_norm_add_src);
720        sources.insert("fused_norm_add_no_weight_bf16".into(), fused_norm_add_src);
721
722        // Fused head-norm + RoPE f32 kernel — replaces separate rms_norm + rope_neox_f32
723        let fused_hnr_f32_src: &'static str =
724            include_str!("shaders/fused_head_norm_rope_f32.metal");
725        sources.insert("fused_head_norm_rope_f32".into(), fused_hnr_f32_src);
726
727        // Fused head-norm + RoPE bf16 kernels (single-token + batch prefill)
728        // Both entry points live in the same .metal file.
729        let fused_hnr_bf16_src: &'static str =
730            include_str!("shaders/fused_head_norm_rope_bf16.metal");
731        sources.insert("fused_head_norm_rope_bf16".into(), fused_hnr_bf16_src);
732        sources.insert("fused_head_norm_rope_batch_bf16".into(), fused_hnr_bf16_src);
733
734        // Fused norm-add f32 kernels — post-attention / post-FFN / end-of-layer
735        let fused_norm_add_f32_src: &'static str =
736            include_str!("shaders/fused_norm_add_f32.metal");
737        sources.insert("fused_norm_add_f32".into(), fused_norm_add_f32_src);
738        sources.insert("fused_residual_norm_f32".into(), fused_norm_add_f32_src);
739        sources.insert("fused_residual_norm_scalar_f32".into(), fused_norm_add_f32_src);
740        sources.insert("fused_moe_routing_f32".into(), fused_norm_add_f32_src);
741        sources.insert("fused_moe_routing_batch_f32".into(), fused_norm_add_f32_src);
742        sources.insert("fused_norm_add_scalar_f32".into(), fused_norm_add_f32_src);
743        sources.insert("fused_moe_wsum_norm_add_f32".into(), fused_norm_add_f32_src);
744        sources.insert("fused_moe_wsum_dnorm_add_f32".into(), fused_norm_add_f32_src);
745
746        // Argsort kernel (Story 2.3) — MoE top-K routing
747        let argsort_src: &'static str = include_str!("shaders/argsort.metal");
748        sources.insert("argsort_desc_f32".into(), argsort_src);
749
750        // Gather / index_select kernel (Story 2.4)
751        let gather_src: &'static str = include_str!("shaders/gather.metal");
752        sources.insert("gather_f32".into(), gather_src);
753
754        // F32 KV cache copy kernel (Session merge S1+S2)
755        let kv_cache_copy_src: &'static str =
756            include_str!("shaders/kv_cache_copy.metal");
757        sources.insert("kv_cache_copy".into(), kv_cache_copy_src);
758        sources.insert("kv_cache_copy_f32".into(), kv_cache_copy_src);
759
760        // Strided copy kernel (Story 2.5)
761        let copy_src: &'static str = include_str!("shaders/copy.metal");
762        sources.insert("strided_copy_f32".into(), copy_src);
763        sources.insert("offset_copy_f32".into(), copy_src);
764
765        // Fused-QKV split kernel (ADR-005 W-5b.18 — replaces hf2q CPU
766        // download → triple-loop split → 3× upload round-trip in
767        // gpu_delta_net::layer_qkv_deinterleave).
768        let qkv_split_src: &'static str = include_str!("shaders/qkv_split.metal");
769        sources.insert("qkv_split_f32".into(), qkv_split_src);
770
771        // Tiled-GQA broadcast kernel (ADR-005 W-5b.19 — replaces hf2q CPU
772        // tiled-replicate at gpu_delta_net::apply_gated_delta_net_chunk
773        // GQA pre-expansion, ~497 ms / 10.4 ms-per-layer at PP4106).
774        let repeat_tiled_src: &'static str =
775            include_str!("shaders/repeat_tiled.metal");
776        sources.insert("repeat_tiled_f32".into(), repeat_tiled_src);
777
778        // Dense F16 GEMM kernel (Story 2.6) — lm_head projection
779        let dense_gemm_src: &'static str = include_str!("shaders/dense_gemm.metal");
780        sources.insert("dense_gemm_f16".into(), dense_gemm_src);
781        sources.insert("dense_matvec_f16".into(), dense_gemm_src);
782        sources.insert("dense_matvec_f16w_f32io".into(), dense_gemm_src);
783        // BF16-weight mat-vec: BF16 weights × F32 input → F32 output (decode lm_head)
784        sources.insert("dense_matvec_bf16w_f32io".into(), dense_gemm_src);
785        // Pure F32 mat-vec: F32 weights × F32 input → F32 output (decode lm_head)
786        sources.insert("dense_matvec_f32".into(), dense_gemm_src);
787
788        // Standalone FWHT for TurboQuant pre/post-rotation (SIMD shuffle, zero barriers)
789        let fwht_src: &'static str = include_str!("shaders/fwht_standalone.metal");
790        sources.insert("fwht_standalone_f32_d256".into(), fwht_src);
791        sources.insert("fwht_standalone_f32_d512".into(), fwht_src);
792        // ADR-007 iter-14 D1 SRHT variants: sign pre-mult (for Q) + sign undo (for output)
793        sources.insert("fwht_sign_premult_f32_d256".into(), fwht_src);
794        sources.insert("fwht_sign_premult_f32_d512".into(), fwht_src);
795        sources.insert("fwht_sign_undo_f32_d256".into(), fwht_src);
796        sources.insert("fwht_sign_undo_f32_d512".into(), fwht_src);
797
798        // Fast Hadamard quantize (SIMD shuffle, zero barriers)
799        let hq_fast_src: &'static str = include_str!("shaders/hadamard_quantize_kv_fast.metal");
800        sources.insert("hadamard_quantize_kv_fast_d256".into(), hq_fast_src);
801        sources.insert("hadamard_quantize_kv_fast_d512".into(), hq_fast_src);
802        // Track B (iter-21): higher-bit (5/6-bit) quantize kernels (byte-packed)
803        sources.insert("hadamard_quantize_kv_hb_d256".into(), hq_fast_src);
804        sources.insert("hadamard_quantize_kv_hb_d512".into(), hq_fast_src);
805
806        // iter-20 Leg F: TQ KV dequantize kernel (nibbles+norms → F32)
807        let tq_dq_src: &'static str = include_str!("shaders/tq_dequantize_kv.metal");
808        sources.insert("tq_dequantize_kv".into(), tq_dq_src);
809        // Track B (iter-21): higher-bit dequantize kernel (byte-packed indices)
810        sources.insert("tq_dequantize_hb_kv".into(), tq_dq_src);
811
812        // iter-24: native higher-bit (5/6/8-bit) TQ SDPA kernel (byte-packed K/V)
813        let tq_hb_src: &'static str = include_str!("shaders/flash_attn_vec_tq_hb.metal");
814        sources.insert("flash_attn_vec_tq_hb_dk256".into(), tq_hb_src);
815        sources.insert("flash_attn_vec_tq_hb_dk512".into(), tq_hb_src);
816
817        // GPU sampling kernels — eliminate logits readback (Phase 6)
818        let argmax_src: &'static str = include_str!("shaders/argmax.metal");
819        sources.insert("argmax_f32".into(), argmax_src);
820        let softmax_sample_src: &'static str =
821            include_str!("shaders/softmax_sample.metal");
822        sources.insert("softmax_sample_f32".into(), softmax_sample_src);
823        // Top-K kernel for Q8 rerank: avoids full-logits readback.
824        let top_k_src: &'static str = include_str!("shaders/top_k.metal");
825        sources.insert("top_k_f32".into(), top_k_src);
826
827        // MoE GPU routing + weighted reduce (ADR-013 P13.3 perf).
828        // Replaces CPU softmax+topk round-trip and CPU weighted accumulate.
829        let moe_stk_src: &'static str =
830            include_str!("shaders/moe_softmax_topk.metal");
831        sources.insert("moe_softmax_topk_f32".into(), moe_stk_src);
832        let moe_wr_src: &'static str =
833            include_str!("shaders/moe_weighted_reduce.metal");
834        sources.insert("moe_weighted_reduce_f32".into(), moe_wr_src);
835        let sdpa_decode_src: &'static str =
836            include_str!("shaders/sdpa_decode.metal");
837        sources.insert("sdpa_decode".into(), sdpa_decode_src);
838
839        Self {
840            cache: HashMap::new(),
841            sources,
842        }
843    }
844
845    /// Register a shader source at runtime (useful for testing and dynamic
846    /// kernel generation).
847    pub fn register_source(&mut self, name: impl Into<String>, source: &'static str) {
848        let name = name.into();
849        // Invalidate any cached pipeline for this name since the source changed.
850        self.cache.remove(&name);
851        self.sources.insert(name, source);
852    }
853
854    /// Get a compiled compute pipeline for the named kernel function.
855    ///
856    /// On first call for a given name, this compiles the MSL source into a
857    /// Metal library, extracts the named function, and creates a
858    /// `ComputePipelineState`.  Subsequent calls return the cached pipeline.
859    ///
860    /// # Errors
861    ///
862    /// * `MlxError::KernelNotFound` — no source registered for this name.
863    /// * `MlxError::ShaderCompilationError` — MSL compilation or pipeline
864    ///   creation failed.
865    pub fn get_pipeline(
866        &mut self,
867        name: &str,
868        device: &metal::DeviceRef,
869    ) -> Result<&ComputePipelineState> {
870        if !self.cache.contains_key(name) {
871            // Slow path: compile the shader.
872            let source = self.sources.get(name).ok_or_else(|| {
873                MlxError::KernelNotFound(name.to_string())
874            })?;
875
876            let compile_opts = metal::CompileOptions::new();
877            let library = device
878                .new_library_with_source(source, &compile_opts)
879                .map_err(|msg| MlxError::ShaderCompilationError {
880                    name: name.to_string(),
881                    message: msg,
882                })?;
883
884            let function = library
885                .get_function(name, None)
886                .map_err(|msg| MlxError::ShaderCompilationError {
887                    name: name.to_string(),
888                    message: msg,
889                })?;
890
891            // Build the pipeline through a descriptor so we can attach a
892            // human-readable label.  The label propagates into Instruments /
893            // xctrace Metal System Trace as the per-pipeline identifier
894            // (`metal-object-label` schema), giving us per-kernel attribution
895            // instead of the generic "Compute Command 0" placeholder.
896            //
897            // `MTLComputePipelineState.label` is read-only after creation per
898            // the Apple Metal spec; the only supported way to set it is via
899            // the descriptor before pipeline creation.  ADR-015 iter9b.
900            let descriptor = ComputePipelineDescriptor::new();
901            descriptor.set_compute_function(Some(&function));
902            descriptor.set_label(name);
903
904            let pipeline = device
905                .new_compute_pipeline_state(&descriptor)
906                .map_err(|msg| MlxError::ShaderCompilationError {
907                    name: name.to_string(),
908                    message: msg,
909                })?;
910
911            self.cache.insert(name.to_string(), pipeline);
912        }
913
914        // At this point the pipeline is guaranteed to be in the cache.
915        // We use `ok_or_else` instead of `expect` to satisfy the no-panic policy.
916        self.cache.get(name).ok_or_else(|| {
917            MlxError::KernelNotFound(name.to_string())
918        })
919    }
920
921    /// Get a compiled compute pipeline for the named kernel, specialized with
922    /// Metal function constants (both bool and i32 in one call).
923    ///
924    /// `bool_constants` contains `(index, value)` pairs mapping to
925    /// `[[function_constant(index)]]` bool declarations in the MSL shader.
926    /// `int_constants` contains `(index, value)` pairs mapping to
927    /// `[[function_constant(index)]]` int (int32_t) declarations in the MSL
928    /// shader.
929    ///
930    /// Pipelines are cached by a composite key:
931    /// `"<name>|<index>:b<0|1>|...|<index>:i<value>|..."`.  The 'b' prefix
932    /// marks bool entries and the 'i' prefix marks i32 entries, making the
933    /// format unambiguous regardless of constant ordering.  Distinct
934    /// `(name, constants)` combinations each compile to a separate pipeline;
935    /// the slow compilation path runs at most once per unique combination.
936    ///
937    /// # Errors
938    ///
939    /// * `MlxError::KernelNotFound` — no source registered for this name.
940    /// * `MlxError::ShaderCompilationError` — MSL compilation, function
941    ///   specialisation, or pipeline creation failed.
942    pub fn get_pipeline_with_constants(
943        &mut self,
944        name: &str,
945        device: &metal::DeviceRef,
946        bool_constants: &[(usize, bool)],
947        int_constants: &[(usize, i32)],
948    ) -> Result<&ComputePipelineState> {
949        // Build a composite cache key so distinct constant combinations each
950        // compile to their own pipeline.  Bool entries use the 'b' type marker
951        // and i32 entries use 'i'; this prevents a collision between, e.g.,
952        // bool index 5 value 1 and int index 5 value 1.
953        let mut cache_key = name.to_string();
954        for &(index, value) in bool_constants {
955            cache_key.push('|');
956            cache_key.push_str(&index.to_string());
957            cache_key.push_str(if value { ":b1" } else { ":b0" });
958        }
959        for &(index, value) in int_constants {
960            cache_key.push('|');
961            cache_key.push_str(&index.to_string());
962            cache_key.push(':');
963            cache_key.push('i');
964            cache_key.push_str(&value.to_string());
965        }
966
967        if !self.cache.contains_key(&cache_key) {
968            // Slow path: compile the shader with function constant specialisation.
969            let source = self.sources.get(name).ok_or_else(|| {
970                MlxError::KernelNotFound(name.to_string())
971            })?;
972
973            let compile_opts = metal::CompileOptions::new();
974            let library = device
975                .new_library_with_source(source, &compile_opts)
976                .map_err(|msg| MlxError::ShaderCompilationError {
977                    name: name.to_string(),
978                    message: msg,
979                })?;
980
981            // Build the FunctionConstantValues object with all bool and i32
982            // constants.  Metal's set_constant_value_at_index reads the value
983            // through a raw pointer; the pointed-to bytes must match the size
984            // declared in the MSL shader (1 byte for bool, 4 bytes for int).
985            let fcv = FunctionConstantValues::new();
986
987            for &(index, value) in bool_constants {
988                // MTLDataType::Bool = 53 (metal-rs argument.rs).
989                // The Metal runtime reads it as an Objective-C BOOL (uint8_t).
990                let v: u8 = if value { 1 } else { 0 };
991                fcv.set_constant_value_at_index(
992                    (&v as *const u8).cast::<std::ffi::c_void>(),
993                    MTLDataType::Bool,
994                    index as u64,
995                );
996            }
997
998            for &(index, value) in int_constants {
999                // MTLDataType::Int = 29 (metal-rs argument.rs).
1000                // The Metal runtime reads 4 bytes as a signed 32-bit integer,
1001                // matching the Metal shader type `constant int`.
1002                fcv.set_constant_value_at_index(
1003                    (&value as *const i32).cast::<std::ffi::c_void>(),
1004                    MTLDataType::Int,
1005                    index as u64,
1006                );
1007            }
1008
1009            let function = library
1010                .get_function(name, Some(fcv))
1011                .map_err(|msg| MlxError::ShaderCompilationError {
1012                    name: name.to_string(),
1013                    message: msg,
1014                })?;
1015
1016            // Label this specialisation with the full composite cache key
1017            // (e.g. `kernel_mul_mv_q4_0_f32|0:b1|3:i32`) so xctrace Metal
1018            // System Trace shows each function-constant variant as a distinct
1019            // pipeline.  Without this, all specialisations share a generic
1020            // "Compute Command 0" identifier and we cannot attribute µs/token
1021            // to a specific (kernel, constants) combination.  ADR-015 iter9b.
1022            let descriptor = ComputePipelineDescriptor::new();
1023            descriptor.set_compute_function(Some(&function));
1024            descriptor.set_label(&cache_key);
1025
1026            let pipeline = device
1027                .new_compute_pipeline_state(&descriptor)
1028                .map_err(|msg| MlxError::ShaderCompilationError {
1029                    name: name.to_string(),
1030                    message: msg,
1031                })?;
1032
1033            self.cache.insert(cache_key.clone(), pipeline);
1034        }
1035
1036        self.cache.get(&cache_key).ok_or_else(|| {
1037            MlxError::KernelNotFound(name.to_string())
1038        })
1039    }
1040
1041    /// Get a compiled compute pipeline for the named kernel, specialized with
1042    /// Metal bool function constants.
1043    ///
1044    /// The `bool_constants` slice contains `(index, value)` pairs.  Each pair
1045    /// maps to a `[[function_constant(index)]]` declaration in the MSL shader.
1046    ///
1047    /// This is a thin wrapper around [`get_pipeline_with_constants`] that
1048    /// passes an empty `int_constants` slice.  Existing callers continue to
1049    /// work without modification; the cache-key format for pure-bool pipelines
1050    /// is compatible (bool entries carry the 'b' type marker, which is the
1051    /// only format ever written by this wrapper).
1052    ///
1053    /// # Errors
1054    ///
1055    /// * `MlxError::KernelNotFound` — no source registered for this name.
1056    /// * `MlxError::ShaderCompilationError` — MSL compilation, function
1057    ///   specialisation, or pipeline creation failed.
1058    pub fn get_pipeline_with_bool_constants(
1059        &mut self,
1060        name: &str,
1061        device: &metal::DeviceRef,
1062        bool_constants: &[(usize, bool)],
1063    ) -> Result<&ComputePipelineState> {
1064        self.get_pipeline_with_constants(name, device, bool_constants, &[])
1065    }
1066
1067    /// Check if a pipeline for the given name is already compiled and cached.
1068    pub fn is_cached(&self, name: &str) -> bool {
1069        self.cache.contains_key(name)
1070    }
1071
1072    /// Number of compiled pipelines currently in the cache.
1073    pub fn cached_count(&self) -> usize {
1074        self.cache.len()
1075    }
1076
1077    /// Number of registered shader sources.
1078    pub fn source_count(&self) -> usize {
1079        self.sources.len()
1080    }
1081}
1082
1083impl Default for KernelRegistry {
1084    fn default() -> Self {
1085        Self::new()
1086    }
1087}
1088
1089#[cfg(test)]
1090mod tests {
1091    use super::*;
1092
1093    /// Minimal Metal shader that uses a single int function constant.
1094    ///
1095    /// The kernel writes the constant value N into the first element of the
1096    /// output buffer, allowing the test to verify that the Metal compiler
1097    /// actually sees distinct specialisations for N=4 and N=8.
1098    ///
1099    /// The shader is intentionally trivial — we only need it to *compile* with
1100    /// an int function constant; correctness of the kernel logic is not under
1101    /// test here.
1102    const INT_FC_TEST_SHADER: &str = r#"
1103#include <metal_stdlib>
1104using namespace metal;
1105
1106constant int test_N [[function_constant(100)]];
1107
1108kernel void int_fc_test_kernel(
1109    device int* out [[buffer(0)]],
1110    uint tid [[thread_position_in_grid]])
1111{
1112    if (tid == 0) {
1113        out[0] = test_N;
1114    }
1115}
1116"#;
1117
1118    /// Verify that `get_pipeline_with_constants` produces distinct cached
1119    /// pipelines for different i32 function-constant values, and that
1120    /// `get_pipeline_with_bool_constants` (the backward-compat wrapper) still
1121    /// works correctly with the new 'b'-prefixed cache-key format.
1122    ///
    /// This test requires a real Metal device.  It carries no `#[ignore]`
    /// attribute, so it always runs and panics on hosts without one.
1125    #[test]
1126    fn test_int_fc_distinct_pipelines_and_bool_compat() {
1127        let device = metal::Device::system_default()
1128            .expect("no Metal device — run on Apple Silicon or x86 Mac with Metal support");
1129
1130        let mut registry = KernelRegistry::new();
1131
1132        // Register the inline test shader under a name that cannot collide with
1133        // any production kernel.
1134        registry.register_source("int_fc_test_kernel", INT_FC_TEST_SHADER);
1135
1136        // Compile with N=4.
1137        let p4_ptr = registry
1138            .get_pipeline_with_constants(
1139                "int_fc_test_kernel",
1140                &device,
1141                &[],                  // no bool constants
1142                &[(100, 4_i32)],      // int constant index 100 = 4
1143            )
1144            .expect("pipeline N=4 should compile") as *const _;
1145
1146        // Cache must now have exactly 1 entry for this kernel.
1147        // (Other production kernels may already be in cache from new(); here
1148        // we check that the N=4 key was inserted.)
1149        let count_after_n4 = registry.cached_count();
1150
1151        // Compile with N=8 — must produce a SEPARATE pipeline.
1152        let p8_ptr = registry
1153            .get_pipeline_with_constants(
1154                "int_fc_test_kernel",
1155                &device,
1156                &[],
1157                &[(100, 8_i32)],
1158            )
1159            .expect("pipeline N=8 should compile") as *const _;
1160
1161        // Cache must have grown by exactly 1.
1162        assert_eq!(
1163            registry.cached_count(),
1164            count_after_n4 + 1,
1165            "N=8 must produce a new cache entry"
1166        );
1167
1168        // The two pipelines must be distinct objects in the cache.
1169        assert_ne!(
1170            p4_ptr, p8_ptr,
1171            "N=4 and N=8 specialisations must be separate ComputePipelineState objects"
1172        );
1173
1174        // A second call with N=4 must return the SAME pipeline (cache hit, no
1175        // new compilation).
1176        let p4_again_ptr = registry
1177            .get_pipeline_with_constants(
1178                "int_fc_test_kernel",
1179                &device,
1180                &[],
1181                &[(100, 4_i32)],
1182            )
1183            .expect("pipeline N=4 cache hit should succeed") as *const _;
1184
1185        assert_eq!(
1186            registry.cached_count(),
1187            count_after_n4 + 1,
1188            "repeated N=4 call must be a cache hit, not a new entry"
1189        );
1190        assert_eq!(
1191            p4_ptr, p4_again_ptr,
1192            "repeated N=4 call must return the same pipeline pointer"
1193        );
1194
1195        // Verify backward compatibility: get_pipeline_with_bool_constants must
1196        // still route through get_pipeline_with_constants and produce a cached
1197        // pipeline without panicking.
1198        //
1199        // We register a separate bool-constant shader that does NOT use a bool
1200        // function constant (so the Metal compiler ignores missing FCs for
1201        // this trivial case) — but the call path and cache-key format are what
1202        // matter here.  We reuse the int_fc_test_kernel source; the bool FC is
1203        // simply unused by the shader (Metal allows unused FCs when the shader
1204        // declares them with `function_constant` but the value is never read).
1205        //
1206        // To avoid a Metal compiler error for an undeclared function constant,
1207        // we register a separate bare-kernel shader for the bool wrapper test.
1208        const BARE_SHADER: &str = r#"
1209#include <metal_stdlib>
1210using namespace metal;
1211kernel void bare_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
1212    if (tid == 0) { out[0] = 42; }
1213}
1214"#;
1215        registry.register_source("bare_kernel", BARE_SHADER);
1216
1217        let count_before_bool = registry.cached_count();
1218        let _bool_pipeline = registry
1219            .get_pipeline_with_bool_constants("bare_kernel", &device, &[])
1220            .expect("bool-constants wrapper with empty slice must succeed");
1221
1222        assert_eq!(
1223            registry.cached_count(),
1224            count_before_bool + 1,
1225            "bool-constants wrapper must insert one new cache entry"
1226        );
1227    }
1228
/// Verify that the `MTLComputePipelineState.label` produced by
/// `get_pipeline` and `get_pipeline_with_constants` actually propagates
/// from the descriptor to the resulting pipeline state.
///
/// This is the in-process smoke check for ADR-015 iter9b: we cannot
/// reach into xctrace from Rust, but we can read back the same `label`
/// property xctrace consumes via `ComputePipelineStateRef::label()`.
/// If labels are missing or wrong here, the MST trace will also show
/// generic identifiers — so this test gates the iter9 retry's
/// per-Q4_0-kernel attribution.
///
/// Requires a real Metal device: the test is ignored on non-Apple targets and
/// panics on Apple ones if no device is found.
#[test]
#[cfg_attr(not(target_vendor = "apple"), ignore = "requires a Metal device")]
fn test_pipeline_labels_propagate_for_mst() {
    // Trivial kernel without function constants for the plain
    // `get_pipeline` path.
    const BARE_SHADER_LABEL_TEST: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void label_smoke_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
    if (tid == 0) { out[0] = 7; }
}
"#;

    let device = metal::Device::system_default()
        .expect("no Metal device — run on Apple Silicon or x86 Mac with Metal support");

    let mut registry = KernelRegistry::new();

    // The int-FC shader is shared with the function-constant cache test.
    registry.register_source("int_fc_test_kernel", INT_FC_TEST_SHADER);
    registry.register_source("label_smoke_kernel", BARE_SHADER_LABEL_TEST);

    // Plain get_pipeline path — label must equal the kernel name.
    // The label is copied into an owned String so the borrow of the
    // registry ends before the next lookup.
    let plain_label = registry
        .get_pipeline("label_smoke_kernel", &device)
        .expect("plain pipeline must compile")
        .label()
        .to_string();
    assert_eq!(
        plain_label, "label_smoke_kernel",
        "get_pipeline must label the pipeline with the kernel name (xctrace MST attribution)"
    );

    // Constants path — specialise `int_fc_test_kernel` with `value` at
    // function-constant index 100 and return the pipeline's label as an
    // owned String (again releasing the registry borrow between lookups).
    let mut specialised_label = |value: i32, on_failure: &str| -> String {
        registry
            .get_pipeline_with_constants("int_fc_test_kernel", &device, &[], &[(100, value)])
            .expect(on_failure)
            .label()
            .to_string()
    };

    // The label must equal the composite cache key so each
    // function-constant variant is individually attributable in MST.
    let label_v7 = specialised_label(7_i32, "specialised pipeline must compile");
    assert_eq!(
        label_v7, "int_fc_test_kernel|100:i7",
        "get_pipeline_with_constants must label with the cache_key so each \
         specialisation is distinct in xctrace MST"
    );

    // A second specialisation must produce a different label.
    let label_v13 = specialised_label(13_i32, "second specialised pipeline must compile");
    assert_eq!(label_v13, "int_fc_test_kernel|100:i13");
    assert_ne!(
        label_v7, label_v13,
        "distinct constant values must yield distinct pipeline labels"
    );
}
1308}