// mlx_native/kernel_registry.rs
1//! [`KernelRegistry`] — lazy compilation and caching of Metal compute pipelines.
2//!
3//! MSL shader source is embedded at compile time via `include_str!`.  On first
4//! access, the source is compiled into a Metal library, the named function is
5//! extracted, and a `ComputePipelineState` is created and cached.  Subsequent
6//! calls return the cached pipeline.
7
8use std::collections::HashMap;
9
10use metal::{ComputePipelineState, FunctionConstantValues, MTLDataType};
11
12use crate::error::{MlxError, Result};
13
14// MTLDataType numeric values (from metal-rs argument.rs, confirmed in Apple Metal spec):
15//   Int  = 29
16//   Bool = 53
17// These are used when calling set_constant_value_at_index so the Metal runtime
18// knows how wide each constant value is.
19
/// Registry that lazily compiles and caches Metal compute pipelines from
/// embedded MSL source.
///
/// # Usage
///
/// ```ignore
/// let mut registry = KernelRegistry::new();
/// let pipeline = registry.get_pipeline("elementwise_add", device.metal_device())?;
/// encoder.encode(&pipeline, &buffers, grid, tg);
/// ```
///
/// # Thread Safety
///
/// `KernelRegistry` is **not** `Sync` by default (it uses `&mut self` for
/// `get_pipeline` to allow mutable cache insertion).  If you need concurrent
/// access, wrap it in a `Mutex` or use one registry per thread.
pub struct KernelRegistry {
    /// Cached pipelines.  Keyed by kernel function name for plain pipelines,
    /// or by the composite `"<name>|<index>:<marker><value>|..."` key for
    /// pipelines specialised with function constants (see
    /// `get_pipeline_with_constants`).
    cache: HashMap<String, ComputePipelineState>,
    /// MSL source text keyed by kernel function name.
    ///
    /// Populated at construction time with all embedded shader sources.
    /// Several names may map to the same `&'static str` when multiple entry
    /// points live in one `.metal` file.
    sources: HashMap<String, &'static str>,
}
44
45impl KernelRegistry {
46    /// Create a new registry with all embedded shader sources pre-registered.
47    ///
48    /// No compilation happens here — shaders are compiled lazily on first use.
49    pub fn new() -> Self {
50        let mut sources = HashMap::new();
51
52        // Register embedded shader sources.
53        sources.insert(
54            "placeholder".into(),
55            include_str!("shaders/placeholder.metal"),
56        );
57        sources.insert(
58            "quantized_matmul".into(),
59            include_str!("shaders/quantized_matmul.metal"),
60        );
61        sources.insert(
62            "quantized_matmul_simd".into(),
63            include_str!("shaders/quantized_matmul.metal"),
64        );
65        sources.insert(
66            "quantized_matmul_simd_bf16".into(),
67            include_str!("shaders/quantized_matmul.metal"),
68        );
69        sources.insert(
70            "quantized_matmul_simd_bf16_expert".into(),
71            include_str!("shaders/quantized_matmul.metal"),
72        );
73
74        // GGML block-format quantized mat-vec kernels (ADR-006 Phase 3)
75        let ggml_src: &'static str =
76            include_str!("shaders/quantized_matmul_ggml.metal");
77        sources.insert("kernel_mul_mv_q4_0_f32".into(), ggml_src);
78        sources.insert("kernel_mul_mv_q8_0_f32".into(), ggml_src);
79        sources.insert("kernel_mul_mv_q6_K_f32".into(), ggml_src);
80
81        // GGML block-format quantized matrix-matrix kernels
82        // (ADR-011 Phase 3 Wave P3a: port of llama.cpp's kernel_mul_mm_<q>_f32).
83        // Used at prefill m > 8 to reuse each weight tile across a 32-row
84        // block via threadgroup-staged simdgroup MMA, instead of re-reading
85        // every block per prompt-token as the mv kernel does.
86        let ggml_mm_src: &'static str =
87            include_str!("shaders/quantized_matmul_mm.metal");
88        sources.insert("kernel_mul_mm_q4_0_f32".into(), ggml_mm_src);
89        sources.insert("kernel_mul_mm_q8_0_f32".into(), ggml_mm_src);
90        sources.insert("kernel_mul_mm_q6_K_f32".into(), ggml_mm_src);
91
92        // GGML block-format quantized matrix-matrix kernels — tensor API
93        // variant (ADR-011 Phase 3 Wave P3b-tensor: port of llama.cpp's
94        // kernel_mul_mm_impl `#ifdef GGML_METAL_HAS_TENSOR` branch).
95        // Uses Apple's MetalPerformancePrimitives `tensor_ops::matmul2d`
96        // primitive which on M3+ dispatches to hardware tensor cores for
97        // 2-3x the effective FLOP throughput vs the simdgroup MMA path.
98        // Only compiled on devices where the tensor API is available; the
99        // kernel_registry's runtime-probe (see MlxDevice::has_tensor) gates
100        // compilation so non-tensor devices transparently fall back to the
101        // non-tensor `kernel_mul_mm_<q>_f32` kernels.
102        let ggml_mm_tensor_src: &'static str =
103            include_str!("shaders/quantized_matmul_mm_tensor.metal");
104        sources.insert("kernel_mul_mm_q4_0_tensor_f32".into(), ggml_mm_tensor_src);
105        sources.insert("kernel_mul_mm_q4_0_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
106        sources.insert("kernel_mul_mm_q6_K_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
107        sources.insert("kernel_mul_mm_q8_0_tensor_f32".into(), ggml_mm_tensor_src);
108        sources.insert("kernel_mul_mm_q6_K_tensor_f32".into(), ggml_mm_tensor_src);
109
110        // Dense bf16×f32 → f32 tensor-API matmul (non-flash-attention
111        // prefill Q@K^T and scores@V, modeled on llama.cpp's
112        // kernel_mul_mm_bf16_f32 with the GGML_METAL_HAS_TENSOR branch
113        // active).  Tile geometry and write-back identical to the
114        // quantized tensor kernel; only the A-stage copy (bfloat →
115        // bfloat, no dequantize) differs.
116        let dense_mm_bf16_tensor_src: &'static str =
117            include_str!("shaders/dense_mm_bf16_tensor.metal");
118        sources.insert("hf2q_dense_mm_bf16_f32_tensor".into(), dense_mm_bf16_tensor_src);
119
120        // Fused scale-mask-softmax for the non-flash-attention prefill
121        // path.  One row-local threadgroup per (head, query) pair
122        // replaces three separate dispatches (scale, mask-add, softmax);
123        // reads a bf16 mask (-INF at masked positions, matching
124        // flash_attn_prefill_mask.metal) that is shared across heads.
125        let scale_mask_softmax_src: &'static str =
126            include_str!("shaders/scale_mask_softmax.metal");
127        sources.insert("scale_mask_softmax_f32".into(), scale_mask_softmax_src);
128
129        // Expert-routed (MoE) quantized matmul kernel (Story 2.1)
130        sources.insert(
131            "quantized_matmul_id".into(),
132            include_str!("shaders/quantized_matmul_id.metal"),
133        );
134
135        // Expert-routed (MoE) GGML block-format quantized matmul kernels
136        let ggml_id_src: &'static str =
137            include_str!("shaders/quantized_matmul_id_ggml.metal");
138        sources.insert("kernel_mul_mv_id_q4_0_f32".into(), ggml_id_src);
139        sources.insert("kernel_mul_mv_id_q8_0_f32".into(), ggml_id_src);
140        sources.insert("kernel_mul_mv_id_q6_K_f32".into(), ggml_id_src);
141
142        // Expert-routed (MoE) GGML block-format QUANTIZED MATRIX-MATRIX kernels
143        // (ADR-011 Phase 3 Wave P3a: port of llama.cpp's
144        // `kernel_mul_mm_id_map0_ne20_N` + `kernel_mul_mm_id_<q>_f32`).
145        // Two-stage dispatch: map0 regroups the token-to-expert table into
146        // per-expert routed-token lists, then mm_id stages a 64x32 expert
147        // weight tile into threadgroup shmem and reuses it across a 32-row
148        // block of that expert's routed tokens.
149        let ggml_id_mm_src: &'static str =
150            include_str!("shaders/quantized_matmul_id_mm.metal");
151        sources.insert("kernel_mul_mm_id_map0_ne20_1".into(), ggml_id_mm_src);
152        sources.insert("kernel_mul_mm_id_map0_ne20_8".into(), ggml_id_mm_src);
153        sources.insert("kernel_mul_mm_id_q4_0_f32".into(), ggml_id_mm_src);
154        sources.insert("kernel_mul_mm_id_q8_0_f32".into(), ggml_id_mm_src);
155        sources.insert("kernel_mul_mm_id_q6_K_f32".into(), ggml_id_mm_src);
156
157        // MoE-routed quantized matrix-matrix kernels — tensor API variant
158        // (ADR-011 Phase 3 Wave P3b-tensor).  Uses the MPP tensor_ops
159        // matmul2d primitive for hardware-tensor-core MMA on M3+.  Only
160        // the mm_id kernel is ported — map0 is a short pre-pass (not
161        // matmul) and continues to use the simdgroup version.
162        let ggml_id_mm_tensor_src: &'static str =
163            include_str!("shaders/quantized_matmul_id_mm_tensor.metal");
164        sources.insert("kernel_mul_mm_id_q4_0_tensor_f32".into(), ggml_id_mm_tensor_src);
165        sources.insert("kernel_mul_mm_id_q8_0_tensor_f32".into(), ggml_id_mm_tensor_src);
166        sources.insert("kernel_mul_mm_id_q6_K_tensor_f32".into(), ggml_id_mm_tensor_src);
167
168        // Embedding kernels (Story 1.5)
169        let embedding_src: &'static str = include_str!("shaders/embedding.metal");
170        sources.insert("embedding_gather_4bit".into(), embedding_src);
171        sources.insert("embedding_gather_6bit".into(), embedding_src);
172
173        // MoE gate kernel (Story 1.5)
174        let moe_gate_src: &'static str = include_str!("shaders/moe_gate.metal");
175        sources.insert("moe_gate".into(), moe_gate_src);
176
177        // MoE dispatch kernels (Story 1.5)
178        let moe_dispatch_src: &'static str = include_str!("shaders/moe_dispatch.metal");
179        sources.insert("fused_gelu_mul".into(), moe_dispatch_src);
180        sources.insert("moe_swiglu_fused".into(), moe_dispatch_src);
181        sources.insert("moe_swiglu_batch".into(), moe_dispatch_src);
182        sources.insert("moe_swiglu_seq".into(), moe_dispatch_src);
183        sources.insert("moe_accumulate".into(), moe_dispatch_src);
184        sources.insert("moe_weighted_sum".into(), moe_dispatch_src);
185        sources.insert("moe_weighted_sum_seq".into(), moe_dispatch_src);
186        sources.insert("zero_buffer".into(), moe_dispatch_src);
187        sources.insert("naive_matvec_f32".into(), moe_dispatch_src);
188        sources.insert("moe_gather_topk_weights".into(), moe_dispatch_src);
189        // bf16 variants (Phase 2 bf16 activation path)
190        sources.insert("fused_gelu_mul_bf16".into(), moe_dispatch_src);
191        sources.insert("moe_swiglu_seq_bf16".into(), moe_dispatch_src);
192        sources.insert("moe_weighted_sum_seq_bf16_input".into(), moe_dispatch_src);
193
194        // Batched KV cache copy kernels
195        let kv_cache_src: &'static str = include_str!("shaders/kv_cache_copy.metal");
196        sources.insert("kv_cache_copy_batch_f32".into(), kv_cache_src);
197        sources.insert("kv_cache_copy_batch_f32_to_f16".into(), kv_cache_src);
198        sources.insert("kv_cache_copy_seq_f32".into(), kv_cache_src);
199        sources.insert("kv_cache_copy_seq_f32_to_f16".into(), kv_cache_src);
200        // Wave P4.11 — fused K+V copy variants
201        sources.insert("kv_cache_copy_seq_f32_kv_dual".into(), kv_cache_src);
202        sources.insert("kv_cache_copy_seq_f32_to_f16_kv_dual".into(), kv_cache_src);
203        // bf16-source KV cache copy (Phase 2 bf16 activation path)
204        sources.insert("kv_cache_copy_seq_bf16".into(), kv_cache_src);
205
206        // Elementwise and transpose kernels (Story 1.5)
207        let elementwise_src: &'static str = include_str!("shaders/elementwise.metal");
208        sources.insert("elementwise_add_f32".into(), elementwise_src);
209        sources.insert("elementwise_add_f16".into(), elementwise_src);
210        sources.insert("elementwise_mul_f32".into(), elementwise_src);
211        sources.insert("elementwise_mul_f16".into(), elementwise_src);
212        sources.insert("elementwise_add_bf16".into(), elementwise_src);
213        sources.insert("elementwise_mul_bf16".into(), elementwise_src);
214        sources.insert("cast_f16_to_f32".into(), elementwise_src);
215        sources.insert("cast_f32_to_f16".into(), elementwise_src);
216        sources.insert("cast_bf16_to_f32".into(), elementwise_src);
217        sources.insert("cast_f32_to_bf16".into(), elementwise_src);
218        sources.insert("scalar_mul_bf16".into(), elementwise_src);
219        sources.insert("scalar_mul_f32".into(), elementwise_src);
220        sources.insert("embedding_gather_scale_f32".into(), elementwise_src);
221        sources.insert("embedding_gather_scale_batch_f32".into(), elementwise_src);
222        sources.insert("permute_021_bf16".into(), elementwise_src);
223        sources.insert("transpose_last2_bf16".into(), elementwise_src);
224        sources.insert("permute_021_f32".into(), elementwise_src);
225        sources.insert("permute_021_bf16_to_f32".into(), elementwise_src);
226        sources.insert("transpose_2d_f32".into(), elementwise_src);
227        sources.insert("transpose_2d_f16".into(), elementwise_src);
228
229        // Attention kernels (Story 1.3)
230        let sdpa_src: &'static str = include_str!("shaders/sdpa.metal");
231        sources.insert("sdpa".into(), sdpa_src);
232        sources.insert("sdpa_bf16".into(), sdpa_src);
233        let sdpa_sliding_src: &'static str = include_str!("shaders/sdpa_sliding.metal");
234        sources.insert("sdpa_sliding".into(), sdpa_sliding_src);
235        sources.insert("sdpa_sliding_bf16".into(), sdpa_sliding_src);
236
237        // Flash-attention tiled prefill kernel (ADR-011 Phase 1).
238        // Ten entry points; all backed by the same shader source.
239        // Pipelines are compiled with function constants via
240        // `get_pipeline_with_bool_constants` — not `get_pipeline`.
241        let flash_attn_prefill_src: &'static str =
242            include_str!("shaders/flash_attn_prefill.metal");
243        // D=256 variants (BQ=32, BK=16, WM=4, WN=1 — 128 threads/threadgroup)
244        sources.insert(
245            "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskfloat32".into(),
246            flash_attn_prefill_src,
247        );
248        sources.insert(
249            "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
250            flash_attn_prefill_src,
251        );
252        sources.insert(
253            "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbfloat16".into(),
254            flash_attn_prefill_src,
255        );
256        sources.insert(
257            "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
258            flash_attn_prefill_src,
259        );
260        sources.insert(
261            "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskfloat16".into(),
262            flash_attn_prefill_src,
263        );
264        sources.insert(
265            "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
266            flash_attn_prefill_src,
267        );
268        // D=512 variants (BQ=8, BK=8, WM=1, WN=1 — 32 threads/threadgroup)
269        // NOTE: f32 at D=512 is NOT instantiated — threadgroup memory exceeds
270        // the 32 KB Metal limit (candle sdpa.rs:86-94).
271        sources.insert(
272            "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbfloat16".into(),
273            flash_attn_prefill_src,
274        );
275        sources.insert(
276            "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
277            flash_attn_prefill_src,
278        );
279        sources.insert(
280            "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskfloat16".into(),
281            flash_attn_prefill_src,
282        );
283        sources.insert(
284            "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
285            flash_attn_prefill_src,
286        );
287
288        // Flash attention vector kernels — SIMD-vectorized decode-path SDPA
289        // (ported from llama.cpp flash_attn_ext_vec)
290        let flash_attn_vec_src: &'static str =
291            include_str!("shaders/flash_attn_vec.metal");
292        sources.insert("flash_attn_vec_dk256".into(), flash_attn_vec_src);
293        sources.insert("flash_attn_vec_dk512".into(), flash_attn_vec_src);
294        sources.insert("flash_attn_vec_reduce_dk256".into(), flash_attn_vec_src);
295        sources.insert("flash_attn_vec_reduce_dk512".into(), flash_attn_vec_src);
296        // F16 KV variants (Phase 4a)
297        sources.insert("flash_attn_vec_f16kv_dk256".into(), flash_attn_vec_src);
298        sources.insert("flash_attn_vec_f16kv_dk512".into(), flash_attn_vec_src);
299
300        // RoPE, normalization, activation kernels (Story 1.4)
301        let rope_src: &'static str = include_str!("shaders/rope.metal");
302        sources.insert("rope_f32".into(), rope_src);
303        sources.insert("rope_f16".into(), rope_src);
304        sources.insert("rope_bf16".into(), rope_src);
305        sources.insert("rope_neox_bf16".into(), rope_src);
306        sources.insert("rope_neox_f32".into(), rope_src);
307        let rms_norm_src: &'static str = include_str!("shaders/rms_norm.metal");
308        sources.insert("rms_norm_f32".into(), rms_norm_src);
309        sources.insert("rms_norm_f16".into(), rms_norm_src);
310        sources.insert("rms_norm_bf16".into(), rms_norm_src);
311        sources.insert("rms_norm_no_scale_bf16".into(), rms_norm_src);
312        sources.insert("rms_norm_no_scale_f32".into(), rms_norm_src);
313        sources.insert("rms_norm_no_scale_f32_dual".into(), rms_norm_src);
314        sources.insert("rms_norm_f32_triple".into(), rms_norm_src);
315        sources.insert("fused_post_attn_triple_norm_f32".into(), rms_norm_src);
316        sources.insert("rms_norm_no_scale_f32_dual_perm".into(), rms_norm_src);
317        // Fused RMS norm + elementwise multiply kernels (Phase 4e.2)
318        sources.insert("rms_norm_mul_f32".into(), rms_norm_src);
319        sources.insert("rms_norm_mul_f16".into(), rms_norm_src);
320        sources.insert("rms_norm_mul_bf16".into(), rms_norm_src);
321        // L2 norm kernels (ADR-013 Decision 3 — Gated DeltaNet Q/K norm)
322        let l2_norm_src: &'static str = include_str!("shaders/l2_norm.metal");
323        sources.insert("l2_norm_f32".into(), l2_norm_src);
324        sources.insert("l2_norm_f16".into(), l2_norm_src);
325        sources.insert("l2_norm_bf16".into(), l2_norm_src);
326        // Cumulative-sum kernels (ADR-013 Decision 4 — DeltaNet decay-mask base)
327        let cumsum_src: &'static str = include_str!("shaders/cumsum.metal");
328        sources.insert("cumsum_f32".into(), cumsum_src);
329        sources.insert("cumsum_bf16".into(), cumsum_src);
330        // SSM conv kernels (ADR-013 Decision 7 — DeltaNet 1D causal conv + SiLU)
331        let ssm_conv_src: &'static str = include_str!("shaders/ssm_conv.metal");
332        sources.insert("ssm_conv_forward_f32".into(), ssm_conv_src);
333        sources.insert("ssm_conv_forward_bf16".into(), ssm_conv_src);
334        sources.insert("ssm_conv_state_update_f32".into(), ssm_conv_src);
335        sources.insert("ssm_conv_state_update_bf16".into(), ssm_conv_src);
336        // Tri-solve kernels (ADR-013 Decision 5 — chunked DeltaNet debug path)
337        let tri_solve_src: &'static str = include_str!("shaders/tri_solve.metal");
338        sources.insert("tri_solve_lower_unit_f32".into(), tri_solve_src);
339        sources.insert("tri_solve_lower_unit_bf16".into(), tri_solve_src);
340        // Rope-multi kernels (ADR-013 Decision 10 — IMROPE for Qwen3.5)
341        let rope_multi_src: &'static str = include_str!("shaders/rope_multi.metal");
342        sources.insert("rope_multi_f32".into(), rope_multi_src);
343        sources.insert("rope_multi_bf16".into(), rope_multi_src);
344        // Gated DeltaNet fused kernel (ADR-013 Decision 6 — centerpiece)
345        let gdn_src: &'static str = include_str!("shaders/gated_delta_net.metal");
346        sources.insert("gated_delta_net_f32".into(), gdn_src);
347        // Sigmoid-gated elementwise multiply (ADR-013 Decision 9 — full-attn output gate)
348        let sigmoid_mul_src: &'static str = include_str!("shaders/sigmoid_mul.metal");
349        sources.insert("sigmoid_mul_f32".into(), sigmoid_mul_src);
350        sources.insert("sigmoid_mul_bf16".into(), sigmoid_mul_src);
351        let gelu_src: &'static str = include_str!("shaders/gelu.metal");
352        sources.insert("gelu_f32".into(), gelu_src);
353        sources.insert("gelu_f16".into(), gelu_src);
354        sources.insert("gelu_bf16".into(), gelu_src);
355        let softmax_src: &'static str = include_str!("shaders/softmax.metal");
356        sources.insert("softmax_f32".into(), softmax_src);
357        sources.insert("softmax_f16".into(), softmax_src);
358        sources.insert("softmax_bf16".into(), softmax_src);
359        let softcap_src: &'static str = include_str!("shaders/softcap.metal");
360        sources.insert("softcap_f32".into(), softcap_src);
361        sources.insert("softcap_f16".into(), softcap_src);
362        sources.insert("softcap_bf16".into(), softcap_src);
363
364        // Fused norm-add kernels — Gemma4 post-attention / post-FFN ordering:
365        //   normed = rms_norm(input, weight, eps);  output = residual + normed
366        let fused_norm_add_src: &'static str =
367            include_str!("shaders/fused_norm_add_bf16.metal");
368        sources.insert("fused_norm_add_bf16".into(), fused_norm_add_src);
369        sources.insert("fused_norm_add_no_weight_bf16".into(), fused_norm_add_src);
370
371        // Fused head-norm + RoPE f32 kernel — replaces separate rms_norm + rope_neox_f32
372        let fused_hnr_f32_src: &'static str =
373            include_str!("shaders/fused_head_norm_rope_f32.metal");
374        sources.insert("fused_head_norm_rope_f32".into(), fused_hnr_f32_src);
375
376        // Fused head-norm + RoPE bf16 kernels (single-token + batch prefill)
377        // Both entry points live in the same .metal file.
378        let fused_hnr_bf16_src: &'static str =
379            include_str!("shaders/fused_head_norm_rope_bf16.metal");
380        sources.insert("fused_head_norm_rope_bf16".into(), fused_hnr_bf16_src);
381        sources.insert("fused_head_norm_rope_batch_bf16".into(), fused_hnr_bf16_src);
382
383        // Fused norm-add f32 kernels — post-attention / post-FFN / end-of-layer
384        let fused_norm_add_f32_src: &'static str =
385            include_str!("shaders/fused_norm_add_f32.metal");
386        sources.insert("fused_norm_add_f32".into(), fused_norm_add_f32_src);
387        sources.insert("fused_residual_norm_f32".into(), fused_norm_add_f32_src);
388        sources.insert("fused_residual_norm_scalar_f32".into(), fused_norm_add_f32_src);
389        sources.insert("fused_moe_routing_f32".into(), fused_norm_add_f32_src);
390        sources.insert("fused_moe_routing_batch_f32".into(), fused_norm_add_f32_src);
391        sources.insert("fused_norm_add_scalar_f32".into(), fused_norm_add_f32_src);
392        sources.insert("fused_moe_wsum_norm_add_f32".into(), fused_norm_add_f32_src);
393        sources.insert("fused_moe_wsum_dnorm_add_f32".into(), fused_norm_add_f32_src);
394
395        // Argsort kernel (Story 2.3) — MoE top-K routing
396        let argsort_src: &'static str = include_str!("shaders/argsort.metal");
397        sources.insert("argsort_desc_f32".into(), argsort_src);
398
399        // Gather / index_select kernel (Story 2.4)
400        let gather_src: &'static str = include_str!("shaders/gather.metal");
401        sources.insert("gather_f32".into(), gather_src);
402
403        // F32 KV cache copy kernel (Session merge S1+S2)
404        let kv_cache_copy_src: &'static str =
405            include_str!("shaders/kv_cache_copy.metal");
406        sources.insert("kv_cache_copy".into(), kv_cache_copy_src);
407        sources.insert("kv_cache_copy_f32".into(), kv_cache_copy_src);
408
409        // Strided copy kernel (Story 2.5)
410        let copy_src: &'static str = include_str!("shaders/copy.metal");
411        sources.insert("strided_copy_f32".into(), copy_src);
412        sources.insert("offset_copy_f32".into(), copy_src);
413
414        // Dense F16 GEMM kernel (Story 2.6) — lm_head projection
415        let dense_gemm_src: &'static str = include_str!("shaders/dense_gemm.metal");
416        sources.insert("dense_gemm_f16".into(), dense_gemm_src);
417        sources.insert("dense_matvec_f16".into(), dense_gemm_src);
418        sources.insert("dense_matvec_f16w_f32io".into(), dense_gemm_src);
419
420        // Standalone FWHT for TurboQuant pre/post-rotation (SIMD shuffle, zero barriers)
421        let fwht_src: &'static str = include_str!("shaders/fwht_standalone.metal");
422        sources.insert("fwht_standalone_f32_d256".into(), fwht_src);
423        sources.insert("fwht_standalone_f32_d512".into(), fwht_src);
424        // ADR-007 iter-14 D1 SRHT variants: sign pre-mult (for Q) + sign undo (for output)
425        sources.insert("fwht_sign_premult_f32_d256".into(), fwht_src);
426        sources.insert("fwht_sign_premult_f32_d512".into(), fwht_src);
427        sources.insert("fwht_sign_undo_f32_d256".into(), fwht_src);
428        sources.insert("fwht_sign_undo_f32_d512".into(), fwht_src);
429
430        // Fast Hadamard quantize (SIMD shuffle, zero barriers)
431        let hq_fast_src: &'static str = include_str!("shaders/hadamard_quantize_kv_fast.metal");
432        sources.insert("hadamard_quantize_kv_fast_d256".into(), hq_fast_src);
433        sources.insert("hadamard_quantize_kv_fast_d512".into(), hq_fast_src);
434        // Track B (iter-21): higher-bit (5/6-bit) quantize kernels (byte-packed)
435        sources.insert("hadamard_quantize_kv_hb_d256".into(), hq_fast_src);
436        sources.insert("hadamard_quantize_kv_hb_d512".into(), hq_fast_src);
437
438        // iter-20 Leg F: TQ KV dequantize kernel (nibbles+norms → F32)
439        let tq_dq_src: &'static str = include_str!("shaders/tq_dequantize_kv.metal");
440        sources.insert("tq_dequantize_kv".into(), tq_dq_src);
441        // Track B (iter-21): higher-bit dequantize kernel (byte-packed indices)
442        sources.insert("tq_dequantize_hb_kv".into(), tq_dq_src);
443
444        // iter-24: native higher-bit (5/6/8-bit) TQ SDPA kernel (byte-packed K/V)
445        let tq_hb_src: &'static str = include_str!("shaders/flash_attn_vec_tq_hb.metal");
446        sources.insert("flash_attn_vec_tq_hb_dk256".into(), tq_hb_src);
447        sources.insert("flash_attn_vec_tq_hb_dk512".into(), tq_hb_src);
448
449        // GPU sampling kernels — eliminate logits readback (Phase 6)
450        let argmax_src: &'static str = include_str!("shaders/argmax.metal");
451        sources.insert("argmax_f32".into(), argmax_src);
452        let softmax_sample_src: &'static str =
453            include_str!("shaders/softmax_sample.metal");
454        sources.insert("softmax_sample_f32".into(), softmax_sample_src);
455        // Top-K kernel for Q8 rerank: avoids full-logits readback.
456        let top_k_src: &'static str = include_str!("shaders/top_k.metal");
457        sources.insert("top_k_f32".into(), top_k_src);
458
459        Self {
460            cache: HashMap::new(),
461            sources,
462        }
463    }
464
465    /// Register a shader source at runtime (useful for testing and dynamic
466    /// kernel generation).
467    pub fn register_source(&mut self, name: impl Into<String>, source: &'static str) {
468        let name = name.into();
469        // Invalidate any cached pipeline for this name since the source changed.
470        self.cache.remove(&name);
471        self.sources.insert(name, source);
472    }
473
474    /// Get a compiled compute pipeline for the named kernel function.
475    ///
476    /// On first call for a given name, this compiles the MSL source into a
477    /// Metal library, extracts the named function, and creates a
478    /// `ComputePipelineState`.  Subsequent calls return the cached pipeline.
479    ///
480    /// # Errors
481    ///
482    /// * `MlxError::KernelNotFound` — no source registered for this name.
483    /// * `MlxError::ShaderCompilationError` — MSL compilation or pipeline
484    ///   creation failed.
485    pub fn get_pipeline(
486        &mut self,
487        name: &str,
488        device: &metal::DeviceRef,
489    ) -> Result<&ComputePipelineState> {
490        if !self.cache.contains_key(name) {
491            // Slow path: compile the shader.
492            let source = self.sources.get(name).ok_or_else(|| {
493                MlxError::KernelNotFound(name.to_string())
494            })?;
495
496            let compile_opts = metal::CompileOptions::new();
497            let library = device
498                .new_library_with_source(source, &compile_opts)
499                .map_err(|msg| MlxError::ShaderCompilationError {
500                    name: name.to_string(),
501                    message: msg,
502                })?;
503
504            let function = library
505                .get_function(name, None)
506                .map_err(|msg| MlxError::ShaderCompilationError {
507                    name: name.to_string(),
508                    message: msg,
509                })?;
510
511            let pipeline = device
512                .new_compute_pipeline_state_with_function(&function)
513                .map_err(|msg| MlxError::ShaderCompilationError {
514                    name: name.to_string(),
515                    message: msg,
516                })?;
517
518            self.cache.insert(name.to_string(), pipeline);
519        }
520
521        // At this point the pipeline is guaranteed to be in the cache.
522        // We use `ok_or_else` instead of `expect` to satisfy the no-panic policy.
523        self.cache.get(name).ok_or_else(|| {
524            MlxError::KernelNotFound(name.to_string())
525        })
526    }
527
528    /// Get a compiled compute pipeline for the named kernel, specialized with
529    /// Metal function constants (both bool and i32 in one call).
530    ///
531    /// `bool_constants` contains `(index, value)` pairs mapping to
532    /// `[[function_constant(index)]]` bool declarations in the MSL shader.
533    /// `int_constants` contains `(index, value)` pairs mapping to
534    /// `[[function_constant(index)]]` int (int32_t) declarations in the MSL
535    /// shader.
536    ///
537    /// Pipelines are cached by a composite key:
538    /// `"<name>|<index>:b<0|1>|...|<index>:i<value>|..."`.  The 'b' prefix
539    /// marks bool entries and the 'i' prefix marks i32 entries, making the
540    /// format unambiguous regardless of constant ordering.  Distinct
541    /// `(name, constants)` combinations each compile to a separate pipeline;
542    /// the slow compilation path runs at most once per unique combination.
543    ///
544    /// # Errors
545    ///
546    /// * `MlxError::KernelNotFound` — no source registered for this name.
547    /// * `MlxError::ShaderCompilationError` — MSL compilation, function
548    ///   specialisation, or pipeline creation failed.
549    pub fn get_pipeline_with_constants(
550        &mut self,
551        name: &str,
552        device: &metal::DeviceRef,
553        bool_constants: &[(usize, bool)],
554        int_constants: &[(usize, i32)],
555    ) -> Result<&ComputePipelineState> {
556        // Build a composite cache key so distinct constant combinations each
557        // compile to their own pipeline.  Bool entries use the 'b' type marker
558        // and i32 entries use 'i'; this prevents a collision between, e.g.,
559        // bool index 5 value 1 and int index 5 value 1.
560        let mut cache_key = name.to_string();
561        for &(index, value) in bool_constants {
562            cache_key.push('|');
563            cache_key.push_str(&index.to_string());
564            cache_key.push_str(if value { ":b1" } else { ":b0" });
565        }
566        for &(index, value) in int_constants {
567            cache_key.push('|');
568            cache_key.push_str(&index.to_string());
569            cache_key.push(':');
570            cache_key.push('i');
571            cache_key.push_str(&value.to_string());
572        }
573
574        if !self.cache.contains_key(&cache_key) {
575            // Slow path: compile the shader with function constant specialisation.
576            let source = self.sources.get(name).ok_or_else(|| {
577                MlxError::KernelNotFound(name.to_string())
578            })?;
579
580            let compile_opts = metal::CompileOptions::new();
581            let library = device
582                .new_library_with_source(source, &compile_opts)
583                .map_err(|msg| MlxError::ShaderCompilationError {
584                    name: name.to_string(),
585                    message: msg,
586                })?;
587
588            // Build the FunctionConstantValues object with all bool and i32
589            // constants.  Metal's set_constant_value_at_index reads the value
590            // through a raw pointer; the pointed-to bytes must match the size
591            // declared in the MSL shader (1 byte for bool, 4 bytes for int).
592            let fcv = FunctionConstantValues::new();
593
594            for &(index, value) in bool_constants {
595                // MTLDataType::Bool = 53 (metal-rs argument.rs).
596                // The Metal runtime reads it as an Objective-C BOOL (uint8_t).
597                let v: u8 = if value { 1 } else { 0 };
598                fcv.set_constant_value_at_index(
599                    (&v as *const u8).cast::<std::ffi::c_void>(),
600                    MTLDataType::Bool,
601                    index as u64,
602                );
603            }
604
605            for &(index, value) in int_constants {
606                // MTLDataType::Int = 29 (metal-rs argument.rs).
607                // The Metal runtime reads 4 bytes as a signed 32-bit integer,
608                // matching the Metal shader type `constant int`.
609                fcv.set_constant_value_at_index(
610                    (&value as *const i32).cast::<std::ffi::c_void>(),
611                    MTLDataType::Int,
612                    index as u64,
613                );
614            }
615
616            let function = library
617                .get_function(name, Some(fcv))
618                .map_err(|msg| MlxError::ShaderCompilationError {
619                    name: name.to_string(),
620                    message: msg,
621                })?;
622
623            let pipeline = device
624                .new_compute_pipeline_state_with_function(&function)
625                .map_err(|msg| MlxError::ShaderCompilationError {
626                    name: name.to_string(),
627                    message: msg,
628                })?;
629
630            self.cache.insert(cache_key.clone(), pipeline);
631        }
632
633        self.cache.get(&cache_key).ok_or_else(|| {
634            MlxError::KernelNotFound(name.to_string())
635        })
636    }
637
638    /// Get a compiled compute pipeline for the named kernel, specialized with
639    /// Metal bool function constants.
640    ///
641    /// The `bool_constants` slice contains `(index, value)` pairs.  Each pair
642    /// maps to a `[[function_constant(index)]]` declaration in the MSL shader.
643    ///
644    /// This is a thin wrapper around [`get_pipeline_with_constants`] that
645    /// passes an empty `int_constants` slice.  Existing callers continue to
646    /// work without modification; the cache-key format for pure-bool pipelines
647    /// is compatible (bool entries carry the 'b' type marker, which is the
648    /// only format ever written by this wrapper).
649    ///
650    /// # Errors
651    ///
652    /// * `MlxError::KernelNotFound` — no source registered for this name.
653    /// * `MlxError::ShaderCompilationError` — MSL compilation, function
654    ///   specialisation, or pipeline creation failed.
655    pub fn get_pipeline_with_bool_constants(
656        &mut self,
657        name: &str,
658        device: &metal::DeviceRef,
659        bool_constants: &[(usize, bool)],
660    ) -> Result<&ComputePipelineState> {
661        self.get_pipeline_with_constants(name, device, bool_constants, &[])
662    }
663
664    /// Check if a pipeline for the given name is already compiled and cached.
665    pub fn is_cached(&self, name: &str) -> bool {
666        self.cache.contains_key(name)
667    }
668
    /// Number of compiled pipelines currently in the cache.
    ///
    /// Each distinct `(kernel name, function constants)` specialisation is a
    /// separate cache entry, so one kernel compiled with several constant
    /// combinations contributes several to this count.
    pub fn cached_count(&self) -> usize {
        self.cache.len()
    }
673
    /// Number of registered shader sources.
    ///
    /// Counts every entry in the source map — the embedded sources
    /// registered at construction plus any registered later (e.g. via
    /// `register_source` in tests).  Independent of how many pipelines have
    /// actually been compiled.
    pub fn source_count(&self) -> usize {
        self.sources.len()
    }
678}
679
680impl Default for KernelRegistry {
681    fn default() -> Self {
682        Self::new()
683    }
684}
685
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal Metal shader that uses a single int function constant.
    ///
    /// The kernel writes the constant value N into the first element of the
    /// output buffer.  The shader is intentionally trivial — the test only
    /// needs it to *compile* with an int function constant; the kernel logic
    /// itself is not under test.
    const INT_FC_TEST_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;

constant int test_N [[function_constant(100)]];

kernel void int_fc_test_kernel(
    device int* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]])
{
    if (tid == 0) {
        out[0] = test_N;
    }
}
"#;

    /// Bare kernel with no function constants, used to exercise the
    /// bool-constants wrapper with an empty constants slice (avoids a Metal
    /// compiler error for an undeclared function constant).
    const BARE_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void bare_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
    if (tid == 0) { out[0] = 42; }
}
"#;

    /// Verify that `get_pipeline_with_constants` produces distinct cached
    /// pipelines for different i32 function-constant values, that repeated
    /// calls are cache hits, and that `get_pipeline_with_bool_constants`
    /// (the backward-compat wrapper) still works with the 'b'-prefixed
    /// cache-key format.
    ///
    /// Requires a real Metal device; on a machine without one it fails fast
    /// with a clear `expect` message rather than silently passing.
    ///
    /// NOTE(review): pipeline identity is checked via raw pointers into the
    /// cache.  `HashMap` insertions may rehash and relocate stored values,
    /// so every pointer is captured only AFTER the last compilation, when
    /// all remaining calls are cache hits and no insertion can move values.
    #[test]
    fn test_int_fc_distinct_pipelines_and_bool_compat() {
        let device = metal::Device::system_default()
            .expect("no Metal device — run on Apple Silicon or x86 Mac with Metal support");

        let mut registry = KernelRegistry::new();

        // Register the inline test shader under a name that cannot collide
        // with any production kernel.
        registry.register_source("int_fc_test_kernel", INT_FC_TEST_SHADER);

        // Compile with N=4 — must insert exactly one new cache entry.
        // (Other production kernels may already be cached from new().)
        let count_before = registry.cached_count();
        registry
            .get_pipeline_with_constants(
                "int_fc_test_kernel",
                &device,
                &[],                  // no bool constants
                &[(100, 4_i32)],      // int constant index 100 = 4
            )
            .expect("pipeline N=4 should compile");
        let count_after_n4 = registry.cached_count();
        assert_eq!(
            count_after_n4,
            count_before + 1,
            "N=4 must insert a new cache entry"
        );

        // Compile with N=8 — distinct constants must produce a SEPARATE
        // cache entry.
        registry
            .get_pipeline_with_constants("int_fc_test_kernel", &device, &[], &[(100, 8_i32)])
            .expect("pipeline N=8 should compile");
        assert_eq!(
            registry.cached_count(),
            count_after_n4 + 1,
            "N=8 must produce a new cache entry"
        );

        // All compilations for this kernel are done.  From here on every
        // call is a cache hit, the map no longer mutates, and cached-value
        // addresses are stable — only now is pointer comparison meaningful.
        let p4_ptr = registry
            .get_pipeline_with_constants("int_fc_test_kernel", &device, &[], &[(100, 4_i32)])
            .expect("pipeline N=4 cache hit should succeed")
            as *const ComputePipelineState;
        let p8_ptr = registry
            .get_pipeline_with_constants("int_fc_test_kernel", &device, &[], &[(100, 8_i32)])
            .expect("pipeline N=8 cache hit should succeed")
            as *const ComputePipelineState;
        let p4_again_ptr = registry
            .get_pipeline_with_constants("int_fc_test_kernel", &device, &[], &[(100, 4_i32)])
            .expect("repeated N=4 cache hit should succeed")
            as *const ComputePipelineState;

        assert_eq!(
            registry.cached_count(),
            count_after_n4 + 1,
            "cache hits must not insert new entries"
        );
        assert_ne!(
            p4_ptr, p8_ptr,
            "N=4 and N=8 specialisations must be separate ComputePipelineState objects"
        );
        assert_eq!(
            p4_ptr, p4_again_ptr,
            "repeated N=4 call must return the same cached pipeline"
        );

        // Backward compatibility: the bool-constants wrapper with an empty
        // slice must route through get_pipeline_with_constants and insert
        // exactly one new cache entry.
        registry.register_source("bare_kernel", BARE_SHADER);

        let count_before_bool = registry.cached_count();
        registry
            .get_pipeline_with_bool_constants("bare_kernel", &device, &[])
            .expect("bool-constants wrapper with empty slice must succeed");
        assert_eq!(
            registry.cached_count(),
            count_before_bool + 1,
            "bool-constants wrapper must insert one new cache entry"
        );
    }
}