Skip to main content

mlx_native/
kernel_registry.rs

1//! [`KernelRegistry`] — lazy compilation and caching of Metal compute pipelines.
2//!
3//! MSL shader source is embedded at compile time via `include_str!`.  On first
4//! access, the source is compiled into a Metal library, the named function is
5//! extracted, and a `ComputePipelineState` is created and cached.  Subsequent
6//! calls return the cached pipeline.
7
8use std::collections::HashMap;
9
10use metal::{ComputePipelineState, FunctionConstantValues, MTLDataType};
11
12use crate::error::{MlxError, Result};
13
14// MTLDataType numeric values (from metal-rs argument.rs, confirmed in Apple Metal spec):
15//   Int  = 29
16//   Bool = 53
17// These are used when calling set_constant_value_at_index so the Metal runtime
18// knows how wide each constant value is.
19
/// Registry that lazily compiles and caches Metal compute pipelines from
/// embedded MSL source.
///
/// Construction only records shader source text; a kernel is compiled the
/// first time its pipeline is requested and the resulting
/// `ComputePipelineState` is kept for later requests.
///
/// # Usage
///
/// ```ignore
/// let mut registry = KernelRegistry::new();
/// // Kernel names are the exact MSL entry-point names registered in `new()`.
/// let pipeline = registry.get_pipeline("elementwise_add_f32", device.metal_device())?;
/// encoder.encode(&pipeline, &buffers, grid, tg);
/// ```
///
/// # Thread Safety
///
/// `get_pipeline` takes `&mut self` because a cache miss inserts into the
/// pipeline cache, so a single registry cannot be used from several threads
/// at once without external synchronization.  If you need concurrent
/// access, wrap it in a `Mutex` or use one registry per thread.
pub struct KernelRegistry {
    /// Compiled pipelines keyed by kernel function name.
    ///
    /// Per the `get_pipeline_with_constants` documentation, specialised
    /// pipelines are stored under a composite
    /// `"<name>|<constants...>"` key.
    cache: HashMap<String, ComputePipelineState>,
    /// MSL source text keyed by kernel function name.
    ///
    /// Populated at construction time with all embedded shader sources;
    /// `register_source` can add or replace entries afterwards.  Several
    /// kernel names may map to the same source text when one `.metal` file
    /// defines multiple entry points.
    sources: HashMap<String, &'static str>,
}
44
45impl KernelRegistry {
46    /// Create a new registry with all embedded shader sources pre-registered.
47    ///
48    /// No compilation happens here — shaders are compiled lazily on first use.
49    pub fn new() -> Self {
50        let mut sources = HashMap::new();
51
52        // Register embedded shader sources.
53        sources.insert(
54            "placeholder".into(),
55            include_str!("shaders/placeholder.metal"),
56        );
57        sources.insert(
58            "quantized_matmul".into(),
59            include_str!("shaders/quantized_matmul.metal"),
60        );
61        sources.insert(
62            "quantized_matmul_simd".into(),
63            include_str!("shaders/quantized_matmul.metal"),
64        );
65        sources.insert(
66            "quantized_matmul_simd_bf16".into(),
67            include_str!("shaders/quantized_matmul.metal"),
68        );
69        sources.insert(
70            "quantized_matmul_simd_bf16_expert".into(),
71            include_str!("shaders/quantized_matmul.metal"),
72        );
73
74        // GGML block-format quantized mat-vec kernels (ADR-006 Phase 3)
75        let ggml_src: &'static str =
76            include_str!("shaders/quantized_matmul_ggml.metal");
77        sources.insert("kernel_mul_mv_q4_0_f32".into(), ggml_src);
78        sources.insert("kernel_mul_mv_q8_0_f32".into(), ggml_src);
79        sources.insert("kernel_mul_mv_q6_K_f32".into(), ggml_src);
80
81        // GGML block-format quantized matrix-matrix kernels
82        // (ADR-011 Phase 3 Wave P3a: port of llama.cpp's kernel_mul_mm_<q>_f32).
83        // Used at prefill m > 8 to reuse each weight tile across a 32-row
84        // block via threadgroup-staged simdgroup MMA, instead of re-reading
85        // every block per prompt-token as the mv kernel does.
86        let ggml_mm_src: &'static str =
87            include_str!("shaders/quantized_matmul_mm.metal");
88        sources.insert("kernel_mul_mm_q4_0_f32".into(), ggml_mm_src);
89        sources.insert("kernel_mul_mm_q8_0_f32".into(), ggml_mm_src);
90        sources.insert("kernel_mul_mm_q6_K_f32".into(), ggml_mm_src);
91
92        // GGML block-format quantized matrix-matrix kernels — tensor API
93        // variant (ADR-011 Phase 3 Wave P3b-tensor: port of llama.cpp's
94        // kernel_mul_mm_impl `#ifdef GGML_METAL_HAS_TENSOR` branch).
95        // Uses Apple's MetalPerformancePrimitives `tensor_ops::matmul2d`
96        // primitive which on M3+ dispatches to hardware tensor cores for
97        // 2-3x the effective FLOP throughput vs the simdgroup MMA path.
98        // Only compiled on devices where the tensor API is available; the
99        // kernel_registry's runtime-probe (see MlxDevice::has_tensor) gates
100        // compilation so non-tensor devices transparently fall back to the
101        // non-tensor `kernel_mul_mm_<q>_f32` kernels.
102        let ggml_mm_tensor_src: &'static str =
103            include_str!("shaders/quantized_matmul_mm_tensor.metal");
104        sources.insert("kernel_mul_mm_q4_0_tensor_f32".into(), ggml_mm_tensor_src);
105        sources.insert("kernel_mul_mm_q4_0_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
106        sources.insert("kernel_mul_mm_q6_K_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
107        sources.insert("kernel_mul_mm_q8_0_tensor_f32".into(), ggml_mm_tensor_src);
108        sources.insert("kernel_mul_mm_q6_K_tensor_f32".into(), ggml_mm_tensor_src);
109
110        // Dense bf16×f32 → f32 tensor-API matmul (non-flash-attention
111        // prefill Q@K^T and scores@V, modeled on llama.cpp's
112        // kernel_mul_mm_bf16_f32 with the GGML_METAL_HAS_TENSOR branch
113        // active).  Tile geometry and write-back identical to the
114        // quantized tensor kernel; only the A-stage copy (bfloat →
115        // bfloat, no dequantize) differs.
116        let dense_mm_bf16_tensor_src: &'static str =
117            include_str!("shaders/dense_mm_bf16_tensor.metal");
118        sources.insert("hf2q_dense_mm_bf16_f32_tensor".into(), dense_mm_bf16_tensor_src);
119
120        // Dense f32×f32 → f32 tensor-API matmul (F32-everywhere
121        // sibling of dense_mm_bf16_tensor).  Used by hf2q's ADR-005
122        // iter-118 BF16-vs-F32 ViT attention A/B diagnostic to remove
123        // the BF16 K-stage cast as a confounding variable.  Port of
124        // llama.cpp's kernel_mul_mm_f32_f32 specialization
125        // (ggml-metal.metal:10098) on the GGML_METAL_HAS_TENSOR
126        // branch.  Same tile geometry (NR0=64 NR1=32 NK=32) but
127        // float-everywhere shmem staging.
128        let dense_mm_f32_f32_tensor_src: &'static str =
129            include_str!("shaders/dense_mm_f32_f32.metal");
130        sources.insert("hf2q_dense_mm_f32_f32_tensor".into(), dense_mm_f32_f32_tensor_src);
131
132        // Dense bf16×f32 → f32 GEMV (matrix-vector multiply) — optimized
133        // for M=1 single-token decode.  Port of llama.cpp's
134        // kernel_mul_mv_bf16_f32_4 (bfloat4-vectorized GEMV kernel).
135        // Used in apply_linear_projection_f32 when seq_len=1 and the
136        // weight matrix is BF16, replacing the MM kernel (~2× faster for
137        // M=1 due to better memory bandwidth utilization per thread).
138        let dense_gemv_bf16_src: &'static str =
139            include_str!("shaders/dense_gemv_bf16.metal");
140        sources.insert("hf2q_dense_gemv_bf16_f32_4".into(), dense_gemv_bf16_src);
141
142        // Fused scale-mask-softmax for the non-flash-attention prefill
143        // path.  One row-local threadgroup per (head, query) pair
144        // replaces three separate dispatches (scale, mask-add, softmax);
145        // reads a bf16 mask (-INF at masked positions, matching
146        // flash_attn_prefill_mask.metal) that is shared across heads.
147        let scale_mask_softmax_src: &'static str =
148            include_str!("shaders/scale_mask_softmax.metal");
149        sources.insert("scale_mask_softmax_f32".into(), scale_mask_softmax_src);
150
151        // Expert-routed (MoE) quantized matmul kernel (Story 2.1)
152        sources.insert(
153            "quantized_matmul_id".into(),
154            include_str!("shaders/quantized_matmul_id.metal"),
155        );
156
157        // Expert-routed (MoE) GGML block-format quantized matmul kernels
158        let ggml_id_src: &'static str =
159            include_str!("shaders/quantized_matmul_id_ggml.metal");
160        sources.insert("kernel_mul_mv_id_q4_0_f32".into(), ggml_id_src);
161        sources.insert("kernel_mul_mv_id_q8_0_f32".into(), ggml_id_src);
162        sources.insert("kernel_mul_mv_id_q5_K_f32".into(), ggml_id_src);
163        sources.insert("kernel_mul_mv_id_q6_K_f32".into(), ggml_id_src);
164
165        // Expert-routed (MoE) GGML block-format QUANTIZED MATRIX-MATRIX kernels
166        // (ADR-011 Phase 3 Wave P3a: port of llama.cpp's
167        // `kernel_mul_mm_id_map0_ne20_N` + `kernel_mul_mm_id_<q>_f32`).
168        // Two-stage dispatch: map0 regroups the token-to-expert table into
169        // per-expert routed-token lists, then mm_id stages a 64x32 expert
170        // weight tile into threadgroup shmem and reuses it across a 32-row
171        // block of that expert's routed tokens.
172        let ggml_id_mm_src: &'static str =
173            include_str!("shaders/quantized_matmul_id_mm.metal");
174        sources.insert("kernel_mul_mm_id_map0_ne20_1".into(), ggml_id_mm_src);
175        sources.insert("kernel_mul_mm_id_map0_ne20_8".into(), ggml_id_mm_src);
176        sources.insert("kernel_mul_mm_id_q4_0_f32".into(), ggml_id_mm_src);
177        sources.insert("kernel_mul_mm_id_q8_0_f32".into(), ggml_id_mm_src);
178        sources.insert("kernel_mul_mm_id_q6_K_f32".into(), ggml_id_mm_src);
179
180        // MoE-routed quantized matrix-matrix kernels — tensor API variant
181        // (ADR-011 Phase 3 Wave P3b-tensor).  Uses the MPP tensor_ops
182        // matmul2d primitive for hardware-tensor-core MMA on M3+.  Only
183        // the mm_id kernel is ported — map0 is a short pre-pass (not
184        // matmul) and continues to use the simdgroup version.
185        let ggml_id_mm_tensor_src: &'static str =
186            include_str!("shaders/quantized_matmul_id_mm_tensor.metal");
187        sources.insert("kernel_mul_mm_id_q4_0_tensor_f32".into(), ggml_id_mm_tensor_src);
188        sources.insert("kernel_mul_mm_id_q8_0_tensor_f32".into(), ggml_id_mm_tensor_src);
189        sources.insert("kernel_mul_mm_id_q6_K_tensor_f32".into(), ggml_id_mm_tensor_src);
190
191        // Embedding kernels (Story 1.5)
192        let embedding_src: &'static str = include_str!("shaders/embedding.metal");
193        sources.insert("embedding_gather_4bit".into(), embedding_src);
194        sources.insert("embedding_gather_6bit".into(), embedding_src);
195
196        // MoE gate kernel (Story 1.5)
197        let moe_gate_src: &'static str = include_str!("shaders/moe_gate.metal");
198        sources.insert("moe_gate".into(), moe_gate_src);
199
200        // MoE dispatch kernels (Story 1.5)
201        let moe_dispatch_src: &'static str = include_str!("shaders/moe_dispatch.metal");
202        sources.insert("fused_gelu_mul".into(), moe_dispatch_src);
203        sources.insert("moe_swiglu_fused".into(), moe_dispatch_src);
204        sources.insert("moe_swiglu_batch".into(), moe_dispatch_src);
205        sources.insert("moe_swiglu_seq".into(), moe_dispatch_src);
206        sources.insert("moe_accumulate".into(), moe_dispatch_src);
207        sources.insert("moe_weighted_sum".into(), moe_dispatch_src);
208        sources.insert("moe_weighted_sum_seq".into(), moe_dispatch_src);
209        sources.insert("zero_buffer".into(), moe_dispatch_src);
210        sources.insert("naive_matvec_f32".into(), moe_dispatch_src);
211        sources.insert("moe_gather_topk_weights".into(), moe_dispatch_src);
212        // bf16 variants (Phase 2 bf16 activation path)
213        sources.insert("fused_gelu_mul_bf16".into(), moe_dispatch_src);
214        sources.insert("moe_swiglu_seq_bf16".into(), moe_dispatch_src);
215        sources.insert("moe_weighted_sum_seq_bf16_input".into(), moe_dispatch_src);
216
217        // Batched KV cache copy kernels
218        let kv_cache_src: &'static str = include_str!("shaders/kv_cache_copy.metal");
219        sources.insert("kv_cache_copy_batch_f32".into(), kv_cache_src);
220        sources.insert("kv_cache_copy_batch_f32_to_f16".into(), kv_cache_src);
221        sources.insert("kv_cache_copy_seq_f32".into(), kv_cache_src);
222        sources.insert("kv_cache_copy_seq_f32_to_f16".into(), kv_cache_src);
223        // Wave P4.11 — fused K+V copy variants
224        sources.insert("kv_cache_copy_seq_f32_kv_dual".into(), kv_cache_src);
225        sources.insert("kv_cache_copy_seq_f32_to_f16_kv_dual".into(), kv_cache_src);
226        // bf16-source KV cache copy (Phase 2 bf16 activation path)
227        sources.insert("kv_cache_copy_seq_bf16".into(), kv_cache_src);
228
229        // Elementwise and transpose kernels (Story 1.5)
230        let elementwise_src: &'static str = include_str!("shaders/elementwise.metal");
231        sources.insert("elementwise_add_f32".into(), elementwise_src);
232        sources.insert("elementwise_add_f16".into(), elementwise_src);
233        sources.insert("elementwise_mul_f32".into(), elementwise_src);
234        sources.insert("elementwise_mul_f16".into(), elementwise_src);
235        sources.insert("elementwise_add_bf16".into(), elementwise_src);
236        sources.insert("elementwise_mul_bf16".into(), elementwise_src);
237        sources.insert("cast_f16_to_f32".into(), elementwise_src);
238        sources.insert("cast_f32_to_f16".into(), elementwise_src);
239        sources.insert("cast_bf16_to_f32".into(), elementwise_src);
240        sources.insert("cast_f32_to_bf16".into(), elementwise_src);
241        sources.insert("scalar_mul_bf16".into(), elementwise_src);
242        sources.insert("scalar_mul_f32".into(), elementwise_src);
243        sources.insert("embedding_gather_scale_f32".into(), elementwise_src);
244        sources.insert("embedding_gather_scale_batch_f32".into(), elementwise_src);
245        sources.insert("permute_021_bf16".into(), elementwise_src);
246        sources.insert("transpose_last2_bf16".into(), elementwise_src);
247        sources.insert("permute_021_f32".into(), elementwise_src);
248        sources.insert("permute_021_bf16_to_f32".into(), elementwise_src);
249        sources.insert("transpose_2d_f32".into(), elementwise_src);
250        sources.insert("transpose_2d_f16".into(), elementwise_src);
251
252        // Attention kernels (Story 1.3)
253        let sdpa_src: &'static str = include_str!("shaders/sdpa.metal");
254        sources.insert("sdpa".into(), sdpa_src);
255        sources.insert("sdpa_bf16".into(), sdpa_src);
256        let sdpa_sliding_src: &'static str = include_str!("shaders/sdpa_sliding.metal");
257        sources.insert("sdpa_sliding".into(), sdpa_sliding_src);
258        sources.insert("sdpa_sliding_bf16".into(), sdpa_sliding_src);
259
260        // Flash-attention tiled prefill kernel (ADR-011 Phase 1).
261        // Ten entry points; all backed by the same shader source.
262        // Pipelines are compiled with function constants via
263        // `get_pipeline_with_bool_constants` — not `get_pipeline`.
264        let flash_attn_prefill_src: &'static str =
265            include_str!("shaders/flash_attn_prefill.metal");
266        // D=256 variants (BQ=32, BK=16, WM=4, WN=1 — 128 threads/threadgroup)
267        sources.insert(
268            "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskfloat32".into(),
269            flash_attn_prefill_src,
270        );
271        sources.insert(
272            "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
273            flash_attn_prefill_src,
274        );
275        sources.insert(
276            "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbfloat16".into(),
277            flash_attn_prefill_src,
278        );
279        sources.insert(
280            "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
281            flash_attn_prefill_src,
282        );
283        sources.insert(
284            "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskfloat16".into(),
285            flash_attn_prefill_src,
286        );
287        sources.insert(
288            "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
289            flash_attn_prefill_src,
290        );
291        // D=512 variants (BQ=8, BK=8, WM=1, WN=1 — 32 threads/threadgroup)
292        // NOTE: f32 at D=512 is NOT instantiated — threadgroup memory exceeds
293        // the 32 KB Metal limit (candle sdpa.rs:86-94).
294        sources.insert(
295            "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbfloat16".into(),
296            flash_attn_prefill_src,
297        );
298        sources.insert(
299            "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
300            flash_attn_prefill_src,
301        );
302        sources.insert(
303            "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskfloat16".into(),
304            flash_attn_prefill_src,
305        );
306        sources.insert(
307            "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
308            flash_attn_prefill_src,
309        );
310
311        // Flash attention vector kernels — SIMD-vectorized decode-path SDPA
312        // (ported from llama.cpp flash_attn_ext_vec)
313        let flash_attn_vec_src: &'static str =
314            include_str!("shaders/flash_attn_vec.metal");
315        sources.insert("flash_attn_vec_dk256".into(), flash_attn_vec_src);
316        sources.insert("flash_attn_vec_dk512".into(), flash_attn_vec_src);
317        sources.insert("flash_attn_vec_reduce_dk256".into(), flash_attn_vec_src);
318        sources.insert("flash_attn_vec_reduce_dk512".into(), flash_attn_vec_src);
319        // F16 KV variants (Phase 4a)
320        sources.insert("flash_attn_vec_f16kv_dk256".into(), flash_attn_vec_src);
321        sources.insert("flash_attn_vec_f16kv_dk512".into(), flash_attn_vec_src);
322
323        // RoPE, normalization, activation kernels (Story 1.4)
324        let rope_src: &'static str = include_str!("shaders/rope.metal");
325        sources.insert("rope_f32".into(), rope_src);
326        sources.insert("rope_f16".into(), rope_src);
327        sources.insert("rope_bf16".into(), rope_src);
328        sources.insert("rope_neox_bf16".into(), rope_src);
329        sources.insert("rope_neox_f32".into(), rope_src);
330        let rms_norm_src: &'static str = include_str!("shaders/rms_norm.metal");
331        sources.insert("rms_norm_f32".into(), rms_norm_src);
332        sources.insert("rms_norm_f16".into(), rms_norm_src);
333        sources.insert("rms_norm_bf16".into(), rms_norm_src);
334        sources.insert("rms_norm_no_scale_bf16".into(), rms_norm_src);
335        sources.insert("rms_norm_no_scale_f32".into(), rms_norm_src);
336        sources.insert("rms_norm_no_scale_f32_dual".into(), rms_norm_src);
337        sources.insert("rms_norm_f32_triple".into(), rms_norm_src);
338        sources.insert("fused_post_attn_triple_norm_f32".into(), rms_norm_src);
339        sources.insert("rms_norm_no_scale_f32_dual_perm".into(), rms_norm_src);
340        // Fused RMS norm + elementwise multiply kernels (Phase 4e.2)
341        sources.insert("rms_norm_mul_f32".into(), rms_norm_src);
342        sources.insert("rms_norm_mul_f16".into(), rms_norm_src);
343        sources.insert("rms_norm_mul_bf16".into(), rms_norm_src);
344        // L2 norm kernels (ADR-013 Decision 3 — Gated DeltaNet Q/K norm)
345        let l2_norm_src: &'static str = include_str!("shaders/l2_norm.metal");
346        sources.insert("l2_norm_f32".into(), l2_norm_src);
347        sources.insert("l2_norm_f16".into(), l2_norm_src);
348        sources.insert("l2_norm_bf16".into(), l2_norm_src);
349        // Cumulative-sum kernels (ADR-013 Decision 4 — DeltaNet decay-mask base)
350        let cumsum_src: &'static str = include_str!("shaders/cumsum.metal");
351        sources.insert("cumsum_f32".into(), cumsum_src);
352        sources.insert("cumsum_bf16".into(), cumsum_src);
353        // SSM conv kernels (ADR-013 Decision 7 — DeltaNet 1D causal conv + SiLU)
354        let ssm_conv_src: &'static str = include_str!("shaders/ssm_conv.metal");
355        sources.insert("ssm_conv_forward_f32".into(), ssm_conv_src);
356        sources.insert("ssm_conv_forward_bf16".into(), ssm_conv_src);
357        sources.insert("ssm_conv_state_update_f32".into(), ssm_conv_src);
358        sources.insert("ssm_conv_state_update_bf16".into(), ssm_conv_src);
359        // Tri-solve kernels (ADR-013 Decision 5 — chunked DeltaNet debug path)
360        let tri_solve_src: &'static str = include_str!("shaders/tri_solve.metal");
361        sources.insert("tri_solve_lower_unit_f32".into(), tri_solve_src);
362        sources.insert("tri_solve_lower_unit_bf16".into(), tri_solve_src);
363        // Rope-multi kernels (ADR-013 Decision 10 — IMROPE for Qwen3.5)
364        let rope_multi_src: &'static str = include_str!("shaders/rope_multi.metal");
365        sources.insert("rope_multi_f32".into(), rope_multi_src);
366        sources.insert("rope_multi_bf16".into(), rope_multi_src);
367        // Gated DeltaNet fused kernel (ADR-013 Decision 6 — centerpiece)
368        let gdn_src: &'static str = include_str!("shaders/gated_delta_net.metal");
369        sources.insert("gated_delta_net_f32".into(), gdn_src);
370        // Sigmoid-gated elementwise multiply (ADR-013 Decision 9 — full-attn output gate)
371        let sigmoid_mul_src: &'static str = include_str!("shaders/sigmoid_mul.metal");
372        sources.insert("sigmoid_mul_f32".into(), sigmoid_mul_src);
373        sources.insert("sigmoid_mul_bf16".into(), sigmoid_mul_src);
374        let silu_mul_src: &'static str = include_str!("shaders/silu_mul.metal");
375        sources.insert("silu_mul_f32".into(), silu_mul_src);
376        let compute_g_beta_src: &'static str = include_str!("shaders/compute_g_beta.metal");
377        sources.insert("compute_g_beta_f32".into(), compute_g_beta_src);
378        let ssm_norm_gate_src: &'static str = include_str!("shaders/ssm_norm_gate.metal");
379        sources.insert("ssm_norm_gate_f32".into(), ssm_norm_gate_src);
380        let gelu_src: &'static str = include_str!("shaders/gelu.metal");
381        sources.insert("gelu_f32".into(), gelu_src);
382        sources.insert("gelu_f16".into(), gelu_src);
383        sources.insert("gelu_bf16".into(), gelu_src);
384        let softmax_src: &'static str = include_str!("shaders/softmax.metal");
385        sources.insert("softmax_f32".into(), softmax_src);
386        sources.insert("softmax_f16".into(), softmax_src);
387        sources.insert("softmax_bf16".into(), softmax_src);
388        let softcap_src: &'static str = include_str!("shaders/softcap.metal");
389        sources.insert("softcap_f32".into(), softcap_src);
390        sources.insert("softcap_f16".into(), softcap_src);
391        sources.insert("softcap_bf16".into(), softcap_src);
392
393        // Fused norm-add kernels — Gemma4 post-attention / post-FFN ordering:
394        //   normed = rms_norm(input, weight, eps);  output = residual + normed
395        let fused_norm_add_src: &'static str =
396            include_str!("shaders/fused_norm_add_bf16.metal");
397        sources.insert("fused_norm_add_bf16".into(), fused_norm_add_src);
398        sources.insert("fused_norm_add_no_weight_bf16".into(), fused_norm_add_src);
399
400        // Fused head-norm + RoPE f32 kernel — replaces separate rms_norm + rope_neox_f32
401        let fused_hnr_f32_src: &'static str =
402            include_str!("shaders/fused_head_norm_rope_f32.metal");
403        sources.insert("fused_head_norm_rope_f32".into(), fused_hnr_f32_src);
404
405        // Fused head-norm + RoPE bf16 kernels (single-token + batch prefill)
406        // Both entry points live in the same .metal file.
407        let fused_hnr_bf16_src: &'static str =
408            include_str!("shaders/fused_head_norm_rope_bf16.metal");
409        sources.insert("fused_head_norm_rope_bf16".into(), fused_hnr_bf16_src);
410        sources.insert("fused_head_norm_rope_batch_bf16".into(), fused_hnr_bf16_src);
411
412        // Fused norm-add f32 kernels — post-attention / post-FFN / end-of-layer
413        let fused_norm_add_f32_src: &'static str =
414            include_str!("shaders/fused_norm_add_f32.metal");
415        sources.insert("fused_norm_add_f32".into(), fused_norm_add_f32_src);
416        sources.insert("fused_residual_norm_f32".into(), fused_norm_add_f32_src);
417        sources.insert("fused_residual_norm_scalar_f32".into(), fused_norm_add_f32_src);
418        sources.insert("fused_moe_routing_f32".into(), fused_norm_add_f32_src);
419        sources.insert("fused_moe_routing_batch_f32".into(), fused_norm_add_f32_src);
420        sources.insert("fused_norm_add_scalar_f32".into(), fused_norm_add_f32_src);
421        sources.insert("fused_moe_wsum_norm_add_f32".into(), fused_norm_add_f32_src);
422        sources.insert("fused_moe_wsum_dnorm_add_f32".into(), fused_norm_add_f32_src);
423
424        // Argsort kernel (Story 2.3) — MoE top-K routing
425        let argsort_src: &'static str = include_str!("shaders/argsort.metal");
426        sources.insert("argsort_desc_f32".into(), argsort_src);
427
428        // Gather / index_select kernel (Story 2.4)
429        let gather_src: &'static str = include_str!("shaders/gather.metal");
430        sources.insert("gather_f32".into(), gather_src);
431
432        // F32 KV cache copy kernel (Session merge S1+S2)
433        let kv_cache_copy_src: &'static str =
434            include_str!("shaders/kv_cache_copy.metal");
435        sources.insert("kv_cache_copy".into(), kv_cache_copy_src);
436        sources.insert("kv_cache_copy_f32".into(), kv_cache_copy_src);
437
438        // Strided copy kernel (Story 2.5)
439        let copy_src: &'static str = include_str!("shaders/copy.metal");
440        sources.insert("strided_copy_f32".into(), copy_src);
441        sources.insert("offset_copy_f32".into(), copy_src);
442
443        // Dense F16 GEMM kernel (Story 2.6) — lm_head projection
444        let dense_gemm_src: &'static str = include_str!("shaders/dense_gemm.metal");
445        sources.insert("dense_gemm_f16".into(), dense_gemm_src);
446        sources.insert("dense_matvec_f16".into(), dense_gemm_src);
447        sources.insert("dense_matvec_f16w_f32io".into(), dense_gemm_src);
448        // BF16-weight mat-vec: BF16 weights × F32 input → F32 output (decode lm_head)
449        sources.insert("dense_matvec_bf16w_f32io".into(), dense_gemm_src);
450        // Pure F32 mat-vec: F32 weights × F32 input → F32 output (decode lm_head)
451        sources.insert("dense_matvec_f32".into(), dense_gemm_src);
452
453        // Standalone FWHT for TurboQuant pre/post-rotation (SIMD shuffle, zero barriers)
454        let fwht_src: &'static str = include_str!("shaders/fwht_standalone.metal");
455        sources.insert("fwht_standalone_f32_d256".into(), fwht_src);
456        sources.insert("fwht_standalone_f32_d512".into(), fwht_src);
457        // ADR-007 iter-14 D1 SRHT variants: sign pre-mult (for Q) + sign undo (for output)
458        sources.insert("fwht_sign_premult_f32_d256".into(), fwht_src);
459        sources.insert("fwht_sign_premult_f32_d512".into(), fwht_src);
460        sources.insert("fwht_sign_undo_f32_d256".into(), fwht_src);
461        sources.insert("fwht_sign_undo_f32_d512".into(), fwht_src);
462
463        // Fast Hadamard quantize (SIMD shuffle, zero barriers)
464        let hq_fast_src: &'static str = include_str!("shaders/hadamard_quantize_kv_fast.metal");
465        sources.insert("hadamard_quantize_kv_fast_d256".into(), hq_fast_src);
466        sources.insert("hadamard_quantize_kv_fast_d512".into(), hq_fast_src);
467        // Track B (iter-21): higher-bit (5/6-bit) quantize kernels (byte-packed)
468        sources.insert("hadamard_quantize_kv_hb_d256".into(), hq_fast_src);
469        sources.insert("hadamard_quantize_kv_hb_d512".into(), hq_fast_src);
470
471        // iter-20 Leg F: TQ KV dequantize kernel (nibbles+norms → F32)
472        let tq_dq_src: &'static str = include_str!("shaders/tq_dequantize_kv.metal");
473        sources.insert("tq_dequantize_kv".into(), tq_dq_src);
474        // Track B (iter-21): higher-bit dequantize kernel (byte-packed indices)
475        sources.insert("tq_dequantize_hb_kv".into(), tq_dq_src);
476
477        // iter-24: native higher-bit (5/6/8-bit) TQ SDPA kernel (byte-packed K/V)
478        let tq_hb_src: &'static str = include_str!("shaders/flash_attn_vec_tq_hb.metal");
479        sources.insert("flash_attn_vec_tq_hb_dk256".into(), tq_hb_src);
480        sources.insert("flash_attn_vec_tq_hb_dk512".into(), tq_hb_src);
481
482        // GPU sampling kernels — eliminate logits readback (Phase 6)
483        let argmax_src: &'static str = include_str!("shaders/argmax.metal");
484        sources.insert("argmax_f32".into(), argmax_src);
485        let softmax_sample_src: &'static str =
486            include_str!("shaders/softmax_sample.metal");
487        sources.insert("softmax_sample_f32".into(), softmax_sample_src);
488        // Top-K kernel for Q8 rerank: avoids full-logits readback.
489        let top_k_src: &'static str = include_str!("shaders/top_k.metal");
490        sources.insert("top_k_f32".into(), top_k_src);
491
492        // MoE GPU routing + weighted reduce (ADR-013 P13.3 perf).
493        // Replaces CPU softmax+topk round-trip and CPU weighted accumulate.
494        let moe_stk_src: &'static str =
495            include_str!("shaders/moe_softmax_topk.metal");
496        sources.insert("moe_softmax_topk_f32".into(), moe_stk_src);
497        let moe_wr_src: &'static str =
498            include_str!("shaders/moe_weighted_reduce.metal");
499        sources.insert("moe_weighted_reduce_f32".into(), moe_wr_src);
500        let sdpa_decode_src: &'static str =
501            include_str!("shaders/sdpa_decode.metal");
502        sources.insert("sdpa_decode".into(), sdpa_decode_src);
503
504        Self {
505            cache: HashMap::new(),
506            sources,
507        }
508    }
509
510    /// Register a shader source at runtime (useful for testing and dynamic
511    /// kernel generation).
512    pub fn register_source(&mut self, name: impl Into<String>, source: &'static str) {
513        let name = name.into();
514        // Invalidate any cached pipeline for this name since the source changed.
515        self.cache.remove(&name);
516        self.sources.insert(name, source);
517    }
518
519    /// Get a compiled compute pipeline for the named kernel function.
520    ///
521    /// On first call for a given name, this compiles the MSL source into a
522    /// Metal library, extracts the named function, and creates a
523    /// `ComputePipelineState`.  Subsequent calls return the cached pipeline.
524    ///
525    /// # Errors
526    ///
527    /// * `MlxError::KernelNotFound` — no source registered for this name.
528    /// * `MlxError::ShaderCompilationError` — MSL compilation or pipeline
529    ///   creation failed.
530    pub fn get_pipeline(
531        &mut self,
532        name: &str,
533        device: &metal::DeviceRef,
534    ) -> Result<&ComputePipelineState> {
535        if !self.cache.contains_key(name) {
536            // Slow path: compile the shader.
537            let source = self.sources.get(name).ok_or_else(|| {
538                MlxError::KernelNotFound(name.to_string())
539            })?;
540
541            let compile_opts = metal::CompileOptions::new();
542            let library = device
543                .new_library_with_source(source, &compile_opts)
544                .map_err(|msg| MlxError::ShaderCompilationError {
545                    name: name.to_string(),
546                    message: msg,
547                })?;
548
549            let function = library
550                .get_function(name, None)
551                .map_err(|msg| MlxError::ShaderCompilationError {
552                    name: name.to_string(),
553                    message: msg,
554                })?;
555
556            let pipeline = device
557                .new_compute_pipeline_state_with_function(&function)
558                .map_err(|msg| MlxError::ShaderCompilationError {
559                    name: name.to_string(),
560                    message: msg,
561                })?;
562
563            self.cache.insert(name.to_string(), pipeline);
564        }
565
566        // At this point the pipeline is guaranteed to be in the cache.
567        // We use `ok_or_else` instead of `expect` to satisfy the no-panic policy.
568        self.cache.get(name).ok_or_else(|| {
569            MlxError::KernelNotFound(name.to_string())
570        })
571    }
572
573    /// Get a compiled compute pipeline for the named kernel, specialized with
574    /// Metal function constants (both bool and i32 in one call).
575    ///
576    /// `bool_constants` contains `(index, value)` pairs mapping to
577    /// `[[function_constant(index)]]` bool declarations in the MSL shader.
578    /// `int_constants` contains `(index, value)` pairs mapping to
579    /// `[[function_constant(index)]]` int (int32_t) declarations in the MSL
580    /// shader.
581    ///
582    /// Pipelines are cached by a composite key:
583    /// `"<name>|<index>:b<0|1>|...|<index>:i<value>|..."`.  The 'b' prefix
584    /// marks bool entries and the 'i' prefix marks i32 entries, making the
585    /// format unambiguous regardless of constant ordering.  Distinct
586    /// `(name, constants)` combinations each compile to a separate pipeline;
587    /// the slow compilation path runs at most once per unique combination.
588    ///
589    /// # Errors
590    ///
591    /// * `MlxError::KernelNotFound` — no source registered for this name.
592    /// * `MlxError::ShaderCompilationError` — MSL compilation, function
593    ///   specialisation, or pipeline creation failed.
594    pub fn get_pipeline_with_constants(
595        &mut self,
596        name: &str,
597        device: &metal::DeviceRef,
598        bool_constants: &[(usize, bool)],
599        int_constants: &[(usize, i32)],
600    ) -> Result<&ComputePipelineState> {
601        // Build a composite cache key so distinct constant combinations each
602        // compile to their own pipeline.  Bool entries use the 'b' type marker
603        // and i32 entries use 'i'; this prevents a collision between, e.g.,
604        // bool index 5 value 1 and int index 5 value 1.
605        let mut cache_key = name.to_string();
606        for &(index, value) in bool_constants {
607            cache_key.push('|');
608            cache_key.push_str(&index.to_string());
609            cache_key.push_str(if value { ":b1" } else { ":b0" });
610        }
611        for &(index, value) in int_constants {
612            cache_key.push('|');
613            cache_key.push_str(&index.to_string());
614            cache_key.push(':');
615            cache_key.push('i');
616            cache_key.push_str(&value.to_string());
617        }
618
619        if !self.cache.contains_key(&cache_key) {
620            // Slow path: compile the shader with function constant specialisation.
621            let source = self.sources.get(name).ok_or_else(|| {
622                MlxError::KernelNotFound(name.to_string())
623            })?;
624
625            let compile_opts = metal::CompileOptions::new();
626            let library = device
627                .new_library_with_source(source, &compile_opts)
628                .map_err(|msg| MlxError::ShaderCompilationError {
629                    name: name.to_string(),
630                    message: msg,
631                })?;
632
633            // Build the FunctionConstantValues object with all bool and i32
634            // constants.  Metal's set_constant_value_at_index reads the value
635            // through a raw pointer; the pointed-to bytes must match the size
636            // declared in the MSL shader (1 byte for bool, 4 bytes for int).
637            let fcv = FunctionConstantValues::new();
638
639            for &(index, value) in bool_constants {
640                // MTLDataType::Bool = 53 (metal-rs argument.rs).
641                // The Metal runtime reads it as an Objective-C BOOL (uint8_t).
642                let v: u8 = if value { 1 } else { 0 };
643                fcv.set_constant_value_at_index(
644                    (&v as *const u8).cast::<std::ffi::c_void>(),
645                    MTLDataType::Bool,
646                    index as u64,
647                );
648            }
649
650            for &(index, value) in int_constants {
651                // MTLDataType::Int = 29 (metal-rs argument.rs).
652                // The Metal runtime reads 4 bytes as a signed 32-bit integer,
653                // matching the Metal shader type `constant int`.
654                fcv.set_constant_value_at_index(
655                    (&value as *const i32).cast::<std::ffi::c_void>(),
656                    MTLDataType::Int,
657                    index as u64,
658                );
659            }
660
661            let function = library
662                .get_function(name, Some(fcv))
663                .map_err(|msg| MlxError::ShaderCompilationError {
664                    name: name.to_string(),
665                    message: msg,
666                })?;
667
668            let pipeline = device
669                .new_compute_pipeline_state_with_function(&function)
670                .map_err(|msg| MlxError::ShaderCompilationError {
671                    name: name.to_string(),
672                    message: msg,
673                })?;
674
675            self.cache.insert(cache_key.clone(), pipeline);
676        }
677
678        self.cache.get(&cache_key).ok_or_else(|| {
679            MlxError::KernelNotFound(name.to_string())
680        })
681    }
682
683    /// Get a compiled compute pipeline for the named kernel, specialized with
684    /// Metal bool function constants.
685    ///
686    /// The `bool_constants` slice contains `(index, value)` pairs.  Each pair
687    /// maps to a `[[function_constant(index)]]` declaration in the MSL shader.
688    ///
689    /// This is a thin wrapper around [`get_pipeline_with_constants`] that
690    /// passes an empty `int_constants` slice.  Existing callers continue to
691    /// work without modification; the cache-key format for pure-bool pipelines
692    /// is compatible (bool entries carry the 'b' type marker, which is the
693    /// only format ever written by this wrapper).
694    ///
695    /// # Errors
696    ///
697    /// * `MlxError::KernelNotFound` — no source registered for this name.
698    /// * `MlxError::ShaderCompilationError` — MSL compilation, function
699    ///   specialisation, or pipeline creation failed.
700    pub fn get_pipeline_with_bool_constants(
701        &mut self,
702        name: &str,
703        device: &metal::DeviceRef,
704        bool_constants: &[(usize, bool)],
705    ) -> Result<&ComputePipelineState> {
706        self.get_pipeline_with_constants(name, device, bool_constants, &[])
707    }
708
709    /// Check if a pipeline for the given name is already compiled and cached.
710    pub fn is_cached(&self, name: &str) -> bool {
711        self.cache.contains_key(name)
712    }
713
714    /// Number of compiled pipelines currently in the cache.
715    pub fn cached_count(&self) -> usize {
716        self.cache.len()
717    }
718
719    /// Number of registered shader sources.
720    pub fn source_count(&self) -> usize {
721        self.sources.len()
722    }
723}
724
725impl Default for KernelRegistry {
726    fn default() -> Self {
727        Self::new()
728    }
729}
730
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal Metal shader that uses a single int function constant.
    ///
    /// The kernel writes the constant value N into the first element of the
    /// output buffer.  The shader is intentionally trivial — the test only
    /// needs it to *compile* with an int function constant; the kernel is
    /// never dispatched.
    const INT_FC_TEST_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;

constant int test_N [[function_constant(100)]];

kernel void int_fc_test_kernel(
    device int* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]])
{
    if (tid == 0) {
        out[0] = test_N;
    }
}
"#;

    /// Address of the object a cached pipeline handle dereferences to.
    ///
    /// Comparing `&ComputePipelineState` addresses would compare `HashMap`
    /// value slots: slots of different keys are always different, and a slot
    /// may move when the map grows.  Dereferencing to
    /// `ComputePipelineStateRef` first yields the address of the underlying
    /// pipeline object instead, which identifies the pipeline itself.
    fn metal_object_addr(
        pipeline: &ComputePipelineState,
    ) -> *const metal::ComputePipelineStateRef {
        let object: &metal::ComputePipelineStateRef = pipeline;
        object
    }

    /// Verify that `get_pipeline_with_constants` produces distinct cached
    /// pipelines for different i32 function-constant values, that a repeated
    /// request is served from the cache, and that
    /// `get_pipeline_with_bool_constants` (the backward-compat wrapper) still
    /// works.
    ///
    /// This test needs a real Metal device: it panics (and therefore fails)
    /// when `Device::system_default()` returns `None`.
    #[test]
    fn test_int_fc_distinct_pipelines_and_bool_compat() {
        let device = metal::Device::system_default()
            .expect("no Metal device — run on Apple Silicon or x86 Mac with Metal support");

        let mut registry = KernelRegistry::new();

        // Register the inline test shader under a name that cannot collide with
        // any production kernel.
        registry.register_source("int_fc_test_kernel", INT_FC_TEST_SHADER);

        // Compile with N=4.
        let p4 = metal_object_addr(
            registry
                .get_pipeline_with_constants(
                    "int_fc_test_kernel",
                    &device,
                    &[],             // no bool constants
                    &[(100, 4_i32)], // int constant index 100 = 4
                )
                .expect("pipeline N=4 should compile"),
        );
        assert!(
            registry.is_cached("int_fc_test_kernel|100:i4"),
            "N=4 pipeline must be cached under its composite key"
        );

        // Baseline for the relative cache-size checks below.
        let count_after_n4 = registry.cached_count();

        // Compile with N=8 — must produce a SEPARATE pipeline.
        let p8 = metal_object_addr(
            registry
                .get_pipeline_with_constants(
                    "int_fc_test_kernel",
                    &device,
                    &[],
                    &[(100, 8_i32)],
                )
                .expect("pipeline N=8 should compile"),
        );

        // Cache must have grown by exactly 1.
        assert_eq!(
            registry.cached_count(),
            count_after_n4 + 1,
            "N=8 must produce a new cache entry"
        );
        assert!(
            registry.is_cached("int_fc_test_kernel|100:i8"),
            "N=8 pipeline must be cached under its composite key"
        );

        // The two specialisations must be distinct pipeline objects.
        assert_ne!(
            p4, p8,
            "N=4 and N=8 specialisations must be separate ComputePipelineState objects"
        );

        // A second call with N=4 must return the SAME pipeline (cache hit, no
        // new compilation).
        let p4_again = metal_object_addr(
            registry
                .get_pipeline_with_constants(
                    "int_fc_test_kernel",
                    &device,
                    &[],
                    &[(100, 4_i32)],
                )
                .expect("pipeline N=4 cache hit should succeed"),
        );

        assert_eq!(
            registry.cached_count(),
            count_after_n4 + 1,
            "repeated N=4 call must be a cache hit, not a new entry"
        );
        assert_eq!(
            p4, p4_again,
            "repeated N=4 call must return the same pipeline pointer"
        );

        // Backward compatibility: get_pipeline_with_bool_constants must still
        // route through get_pipeline_with_constants and insert a cached
        // pipeline.  A separate kernel that declares no function constants is
        // used so the call can pass an empty constant slice; only the call
        // path and the cache bookkeeping are under test here.
        const BARE_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void bare_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
    if (tid == 0) { out[0] = 42; }
}
"#;
        registry.register_source("bare_kernel", BARE_SHADER);

        let count_before_bool = registry.cached_count();
        let _bool_pipeline = registry
            .get_pipeline_with_bool_constants("bare_kernel", &device, &[])
            .expect("bool-constants wrapper with empty slice must succeed");

        assert_eq!(
            registry.cached_count(),
            count_before_bool + 1,
            "bool-constants wrapper must insert one new cache entry"
        );
    }
}