// mlx_native/kernel_registry.rs
1//! [`KernelRegistry`] — lazy compilation and caching of Metal compute pipelines.
2//!
3//! MSL shader source is embedded at compile time via `include_str!`.  On first
4//! access, the source is compiled into a Metal library, the named function is
5//! extracted, and a `ComputePipelineState` is created and cached.  Subsequent
6//! calls return the cached pipeline.
7
8use std::collections::HashMap;
9
10use metal::{ComputePipelineState, FunctionConstantValues, MTLDataType};
11
12use crate::error::{MlxError, Result};
13
14// MTLDataType numeric values (from metal-rs argument.rs, confirmed in Apple Metal spec):
15//   Int  = 29
16//   Bool = 53
17// These are used when calling set_constant_value_at_index so the Metal runtime
18// knows how wide each constant value is.
19
20/// Registry that lazily compiles and caches Metal compute pipelines from
21/// embedded MSL source.
22///
23/// # Usage
24///
25/// ```ignore
26/// let mut registry = KernelRegistry::new();
27/// let pipeline = registry.get_pipeline("elementwise_add", device.metal_device())?;
28/// encoder.encode(&pipeline, &buffers, grid, tg);
29/// ```
30///
31/// # Thread Safety
32///
33/// `KernelRegistry` is **not** `Sync` by default (it uses `&mut self` for
34/// `get_pipeline` to allow mutable cache insertion).  If you need concurrent
35/// access, wrap it in a `Mutex` or use one registry per thread.
36pub struct KernelRegistry {
37    /// Cached pipelines keyed by kernel function name.
38    cache: HashMap<String, ComputePipelineState>,
39    /// MSL source text keyed by kernel function name.
40    ///
41    /// Populated at construction time with all embedded shader sources.
42    sources: HashMap<String, &'static str>,
43}
44
45impl KernelRegistry {
46    /// Create a new registry with all embedded shader sources pre-registered.
47    ///
48    /// No compilation happens here — shaders are compiled lazily on first use.
49    pub fn new() -> Self {
50        let mut sources = HashMap::new();
51
52        // Register embedded shader sources.
53        sources.insert(
54            "placeholder".into(),
55            include_str!("shaders/placeholder.metal"),
56        );
57        sources.insert(
58            "quantized_matmul".into(),
59            include_str!("shaders/quantized_matmul.metal"),
60        );
61        sources.insert(
62            "quantized_matmul_simd".into(),
63            include_str!("shaders/quantized_matmul.metal"),
64        );
65        sources.insert(
66            "quantized_matmul_simd_bf16".into(),
67            include_str!("shaders/quantized_matmul.metal"),
68        );
69        sources.insert(
70            "quantized_matmul_simd_bf16_expert".into(),
71            include_str!("shaders/quantized_matmul.metal"),
72        );
73
74        // GGML block-format quantized mat-vec kernels (ADR-006 Phase 3)
75        let ggml_src: &'static str =
76            include_str!("shaders/quantized_matmul_ggml.metal");
77        sources.insert("kernel_mul_mv_q4_0_f32".into(), ggml_src);
78        sources.insert("kernel_mul_mv_q8_0_f32".into(), ggml_src);
79        sources.insert("kernel_mul_mv_q6_K_f32".into(), ggml_src);
80
81        // GGML block-format quantized matrix-matrix kernels
82        // (ADR-011 Phase 3 Wave P3a: port of llama.cpp's kernel_mul_mm_<q>_f32).
83        // Used at prefill m > 8 to reuse each weight tile across a 32-row
84        // block via threadgroup-staged simdgroup MMA, instead of re-reading
85        // every block per prompt-token as the mv kernel does.
86        let ggml_mm_src: &'static str =
87            include_str!("shaders/quantized_matmul_mm.metal");
88        sources.insert("kernel_mul_mm_q4_0_f32".into(), ggml_mm_src);
89        sources.insert("kernel_mul_mm_q8_0_f32".into(), ggml_mm_src);
90        sources.insert("kernel_mul_mm_q6_K_f32".into(), ggml_mm_src);
91
92        // GGML block-format quantized matrix-matrix kernels — tensor API
93        // variant (ADR-011 Phase 3 Wave P3b-tensor: port of llama.cpp's
94        // kernel_mul_mm_impl `#ifdef GGML_METAL_HAS_TENSOR` branch).
95        // Uses Apple's MetalPerformancePrimitives `tensor_ops::matmul2d`
96        // primitive which on M3+ dispatches to hardware tensor cores for
97        // 2-3x the effective FLOP throughput vs the simdgroup MMA path.
98        // Only compiled on devices where the tensor API is available; the
99        // kernel_registry's runtime-probe (see MlxDevice::has_tensor) gates
100        // compilation so non-tensor devices transparently fall back to the
101        // non-tensor `kernel_mul_mm_<q>_f32` kernels.
102        let ggml_mm_tensor_src: &'static str =
103            include_str!("shaders/quantized_matmul_mm_tensor.metal");
104        sources.insert("kernel_mul_mm_q4_0_tensor_f32".into(), ggml_mm_tensor_src);
105        sources.insert("kernel_mul_mm_q4_0_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
106        sources.insert("kernel_mul_mm_q6_K_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
107        sources.insert("kernel_mul_mm_q8_0_tensor_f32".into(), ggml_mm_tensor_src);
108        sources.insert("kernel_mul_mm_q6_K_tensor_f32".into(), ggml_mm_tensor_src);
109
110        // Dense bf16×f32 → f32 tensor-API matmul (non-flash-attention
111        // prefill Q@K^T and scores@V, modeled on llama.cpp's
112        // kernel_mul_mm_bf16_f32 with the GGML_METAL_HAS_TENSOR branch
113        // active).  Tile geometry and write-back identical to the
114        // quantized tensor kernel; only the A-stage copy (bfloat →
115        // bfloat, no dequantize) differs.
116        let dense_mm_bf16_tensor_src: &'static str =
117            include_str!("shaders/dense_mm_bf16_tensor.metal");
118        sources.insert("hf2q_dense_mm_bf16_f32_tensor".into(), dense_mm_bf16_tensor_src);
119
120        // Dense bf16×f32 → f32 GEMV (matrix-vector multiply) — optimized
121        // for M=1 single-token decode.  Port of llama.cpp's
122        // kernel_mul_mv_bf16_f32_4 (bfloat4-vectorized GEMV kernel).
123        // Used in apply_linear_projection_f32 when seq_len=1 and the
124        // weight matrix is BF16, replacing the MM kernel (~2× faster for
125        // M=1 due to better memory bandwidth utilization per thread).
126        let dense_gemv_bf16_src: &'static str =
127            include_str!("shaders/dense_gemv_bf16.metal");
128        sources.insert("hf2q_dense_gemv_bf16_f32_4".into(), dense_gemv_bf16_src);
129
130        // Fused scale-mask-softmax for the non-flash-attention prefill
131        // path.  One row-local threadgroup per (head, query) pair
132        // replaces three separate dispatches (scale, mask-add, softmax);
133        // reads a bf16 mask (-INF at masked positions, matching
134        // flash_attn_prefill_mask.metal) that is shared across heads.
135        let scale_mask_softmax_src: &'static str =
136            include_str!("shaders/scale_mask_softmax.metal");
137        sources.insert("scale_mask_softmax_f32".into(), scale_mask_softmax_src);
138
139        // Expert-routed (MoE) quantized matmul kernel (Story 2.1)
140        sources.insert(
141            "quantized_matmul_id".into(),
142            include_str!("shaders/quantized_matmul_id.metal"),
143        );
144
145        // Expert-routed (MoE) GGML block-format quantized matmul kernels
146        let ggml_id_src: &'static str =
147            include_str!("shaders/quantized_matmul_id_ggml.metal");
148        sources.insert("kernel_mul_mv_id_q4_0_f32".into(), ggml_id_src);
149        sources.insert("kernel_mul_mv_id_q8_0_f32".into(), ggml_id_src);
150        sources.insert("kernel_mul_mv_id_q5_K_f32".into(), ggml_id_src);
151        sources.insert("kernel_mul_mv_id_q6_K_f32".into(), ggml_id_src);
152
153        // Expert-routed (MoE) GGML block-format QUANTIZED MATRIX-MATRIX kernels
154        // (ADR-011 Phase 3 Wave P3a: port of llama.cpp's
155        // `kernel_mul_mm_id_map0_ne20_N` + `kernel_mul_mm_id_<q>_f32`).
156        // Two-stage dispatch: map0 regroups the token-to-expert table into
157        // per-expert routed-token lists, then mm_id stages a 64x32 expert
158        // weight tile into threadgroup shmem and reuses it across a 32-row
159        // block of that expert's routed tokens.
160        let ggml_id_mm_src: &'static str =
161            include_str!("shaders/quantized_matmul_id_mm.metal");
162        sources.insert("kernel_mul_mm_id_map0_ne20_1".into(), ggml_id_mm_src);
163        sources.insert("kernel_mul_mm_id_map0_ne20_8".into(), ggml_id_mm_src);
164        sources.insert("kernel_mul_mm_id_q4_0_f32".into(), ggml_id_mm_src);
165        sources.insert("kernel_mul_mm_id_q8_0_f32".into(), ggml_id_mm_src);
166        sources.insert("kernel_mul_mm_id_q6_K_f32".into(), ggml_id_mm_src);
167
168        // MoE-routed quantized matrix-matrix kernels — tensor API variant
169        // (ADR-011 Phase 3 Wave P3b-tensor).  Uses the MPP tensor_ops
170        // matmul2d primitive for hardware-tensor-core MMA on M3+.  Only
171        // the mm_id kernel is ported — map0 is a short pre-pass (not
172        // matmul) and continues to use the simdgroup version.
173        let ggml_id_mm_tensor_src: &'static str =
174            include_str!("shaders/quantized_matmul_id_mm_tensor.metal");
175        sources.insert("kernel_mul_mm_id_q4_0_tensor_f32".into(), ggml_id_mm_tensor_src);
176        sources.insert("kernel_mul_mm_id_q8_0_tensor_f32".into(), ggml_id_mm_tensor_src);
177        sources.insert("kernel_mul_mm_id_q6_K_tensor_f32".into(), ggml_id_mm_tensor_src);
178
179        // Embedding kernels (Story 1.5)
180        let embedding_src: &'static str = include_str!("shaders/embedding.metal");
181        sources.insert("embedding_gather_4bit".into(), embedding_src);
182        sources.insert("embedding_gather_6bit".into(), embedding_src);
183
184        // MoE gate kernel (Story 1.5)
185        let moe_gate_src: &'static str = include_str!("shaders/moe_gate.metal");
186        sources.insert("moe_gate".into(), moe_gate_src);
187
188        // MoE dispatch kernels (Story 1.5)
189        let moe_dispatch_src: &'static str = include_str!("shaders/moe_dispatch.metal");
190        sources.insert("fused_gelu_mul".into(), moe_dispatch_src);
191        sources.insert("moe_swiglu_fused".into(), moe_dispatch_src);
192        sources.insert("moe_swiglu_batch".into(), moe_dispatch_src);
193        sources.insert("moe_swiglu_seq".into(), moe_dispatch_src);
194        sources.insert("moe_accumulate".into(), moe_dispatch_src);
195        sources.insert("moe_weighted_sum".into(), moe_dispatch_src);
196        sources.insert("moe_weighted_sum_seq".into(), moe_dispatch_src);
197        sources.insert("zero_buffer".into(), moe_dispatch_src);
198        sources.insert("naive_matvec_f32".into(), moe_dispatch_src);
199        sources.insert("moe_gather_topk_weights".into(), moe_dispatch_src);
200        // bf16 variants (Phase 2 bf16 activation path)
201        sources.insert("fused_gelu_mul_bf16".into(), moe_dispatch_src);
202        sources.insert("moe_swiglu_seq_bf16".into(), moe_dispatch_src);
203        sources.insert("moe_weighted_sum_seq_bf16_input".into(), moe_dispatch_src);
204
205        // Batched KV cache copy kernels
206        let kv_cache_src: &'static str = include_str!("shaders/kv_cache_copy.metal");
207        sources.insert("kv_cache_copy_batch_f32".into(), kv_cache_src);
208        sources.insert("kv_cache_copy_batch_f32_to_f16".into(), kv_cache_src);
209        sources.insert("kv_cache_copy_seq_f32".into(), kv_cache_src);
210        sources.insert("kv_cache_copy_seq_f32_to_f16".into(), kv_cache_src);
211        // Wave P4.11 — fused K+V copy variants
212        sources.insert("kv_cache_copy_seq_f32_kv_dual".into(), kv_cache_src);
213        sources.insert("kv_cache_copy_seq_f32_to_f16_kv_dual".into(), kv_cache_src);
214        // bf16-source KV cache copy (Phase 2 bf16 activation path)
215        sources.insert("kv_cache_copy_seq_bf16".into(), kv_cache_src);
216
217        // Elementwise and transpose kernels (Story 1.5)
218        let elementwise_src: &'static str = include_str!("shaders/elementwise.metal");
219        sources.insert("elementwise_add_f32".into(), elementwise_src);
220        sources.insert("elementwise_add_f16".into(), elementwise_src);
221        sources.insert("elementwise_mul_f32".into(), elementwise_src);
222        sources.insert("elementwise_mul_f16".into(), elementwise_src);
223        sources.insert("elementwise_add_bf16".into(), elementwise_src);
224        sources.insert("elementwise_mul_bf16".into(), elementwise_src);
225        sources.insert("cast_f16_to_f32".into(), elementwise_src);
226        sources.insert("cast_f32_to_f16".into(), elementwise_src);
227        sources.insert("cast_bf16_to_f32".into(), elementwise_src);
228        sources.insert("cast_f32_to_bf16".into(), elementwise_src);
229        sources.insert("scalar_mul_bf16".into(), elementwise_src);
230        sources.insert("scalar_mul_f32".into(), elementwise_src);
231        sources.insert("embedding_gather_scale_f32".into(), elementwise_src);
232        sources.insert("embedding_gather_scale_batch_f32".into(), elementwise_src);
233        sources.insert("permute_021_bf16".into(), elementwise_src);
234        sources.insert("transpose_last2_bf16".into(), elementwise_src);
235        sources.insert("permute_021_f32".into(), elementwise_src);
236        sources.insert("permute_021_bf16_to_f32".into(), elementwise_src);
237        sources.insert("transpose_2d_f32".into(), elementwise_src);
238        sources.insert("transpose_2d_f16".into(), elementwise_src);
239
240        // Attention kernels (Story 1.3)
241        let sdpa_src: &'static str = include_str!("shaders/sdpa.metal");
242        sources.insert("sdpa".into(), sdpa_src);
243        sources.insert("sdpa_bf16".into(), sdpa_src);
244        let sdpa_sliding_src: &'static str = include_str!("shaders/sdpa_sliding.metal");
245        sources.insert("sdpa_sliding".into(), sdpa_sliding_src);
246        sources.insert("sdpa_sliding_bf16".into(), sdpa_sliding_src);
247
248        // Flash-attention tiled prefill kernel (ADR-011 Phase 1).
249        // Ten entry points; all backed by the same shader source.
250        // Pipelines are compiled with function constants via
251        // `get_pipeline_with_bool_constants` — not `get_pipeline`.
252        let flash_attn_prefill_src: &'static str =
253            include_str!("shaders/flash_attn_prefill.metal");
254        // D=256 variants (BQ=32, BK=16, WM=4, WN=1 — 128 threads/threadgroup)
255        sources.insert(
256            "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskfloat32".into(),
257            flash_attn_prefill_src,
258        );
259        sources.insert(
260            "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
261            flash_attn_prefill_src,
262        );
263        sources.insert(
264            "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbfloat16".into(),
265            flash_attn_prefill_src,
266        );
267        sources.insert(
268            "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
269            flash_attn_prefill_src,
270        );
271        sources.insert(
272            "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskfloat16".into(),
273            flash_attn_prefill_src,
274        );
275        sources.insert(
276            "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
277            flash_attn_prefill_src,
278        );
279        // D=512 variants (BQ=8, BK=8, WM=1, WN=1 — 32 threads/threadgroup)
280        // NOTE: f32 at D=512 is NOT instantiated — threadgroup memory exceeds
281        // the 32 KB Metal limit (candle sdpa.rs:86-94).
282        sources.insert(
283            "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbfloat16".into(),
284            flash_attn_prefill_src,
285        );
286        sources.insert(
287            "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
288            flash_attn_prefill_src,
289        );
290        sources.insert(
291            "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskfloat16".into(),
292            flash_attn_prefill_src,
293        );
294        sources.insert(
295            "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
296            flash_attn_prefill_src,
297        );
298
299        // Flash attention vector kernels — SIMD-vectorized decode-path SDPA
300        // (ported from llama.cpp flash_attn_ext_vec)
301        let flash_attn_vec_src: &'static str =
302            include_str!("shaders/flash_attn_vec.metal");
303        sources.insert("flash_attn_vec_dk256".into(), flash_attn_vec_src);
304        sources.insert("flash_attn_vec_dk512".into(), flash_attn_vec_src);
305        sources.insert("flash_attn_vec_reduce_dk256".into(), flash_attn_vec_src);
306        sources.insert("flash_attn_vec_reduce_dk512".into(), flash_attn_vec_src);
307        // F16 KV variants (Phase 4a)
308        sources.insert("flash_attn_vec_f16kv_dk256".into(), flash_attn_vec_src);
309        sources.insert("flash_attn_vec_f16kv_dk512".into(), flash_attn_vec_src);
310
311        // RoPE, normalization, activation kernels (Story 1.4)
312        let rope_src: &'static str = include_str!("shaders/rope.metal");
313        sources.insert("rope_f32".into(), rope_src);
314        sources.insert("rope_f16".into(), rope_src);
315        sources.insert("rope_bf16".into(), rope_src);
316        sources.insert("rope_neox_bf16".into(), rope_src);
317        sources.insert("rope_neox_f32".into(), rope_src);
318        let rms_norm_src: &'static str = include_str!("shaders/rms_norm.metal");
319        sources.insert("rms_norm_f32".into(), rms_norm_src);
320        sources.insert("rms_norm_f16".into(), rms_norm_src);
321        sources.insert("rms_norm_bf16".into(), rms_norm_src);
322        sources.insert("rms_norm_no_scale_bf16".into(), rms_norm_src);
323        sources.insert("rms_norm_no_scale_f32".into(), rms_norm_src);
324        sources.insert("rms_norm_no_scale_f32_dual".into(), rms_norm_src);
325        sources.insert("rms_norm_f32_triple".into(), rms_norm_src);
326        sources.insert("fused_post_attn_triple_norm_f32".into(), rms_norm_src);
327        sources.insert("rms_norm_no_scale_f32_dual_perm".into(), rms_norm_src);
328        // Fused RMS norm + elementwise multiply kernels (Phase 4e.2)
329        sources.insert("rms_norm_mul_f32".into(), rms_norm_src);
330        sources.insert("rms_norm_mul_f16".into(), rms_norm_src);
331        sources.insert("rms_norm_mul_bf16".into(), rms_norm_src);
332        // L2 norm kernels (ADR-013 Decision 3 — Gated DeltaNet Q/K norm)
333        let l2_norm_src: &'static str = include_str!("shaders/l2_norm.metal");
334        sources.insert("l2_norm_f32".into(), l2_norm_src);
335        sources.insert("l2_norm_f16".into(), l2_norm_src);
336        sources.insert("l2_norm_bf16".into(), l2_norm_src);
337        // Cumulative-sum kernels (ADR-013 Decision 4 — DeltaNet decay-mask base)
338        let cumsum_src: &'static str = include_str!("shaders/cumsum.metal");
339        sources.insert("cumsum_f32".into(), cumsum_src);
340        sources.insert("cumsum_bf16".into(), cumsum_src);
341        // SSM conv kernels (ADR-013 Decision 7 — DeltaNet 1D causal conv + SiLU)
342        let ssm_conv_src: &'static str = include_str!("shaders/ssm_conv.metal");
343        sources.insert("ssm_conv_forward_f32".into(), ssm_conv_src);
344        sources.insert("ssm_conv_forward_bf16".into(), ssm_conv_src);
345        sources.insert("ssm_conv_state_update_f32".into(), ssm_conv_src);
346        sources.insert("ssm_conv_state_update_bf16".into(), ssm_conv_src);
347        // Tri-solve kernels (ADR-013 Decision 5 — chunked DeltaNet debug path)
348        let tri_solve_src: &'static str = include_str!("shaders/tri_solve.metal");
349        sources.insert("tri_solve_lower_unit_f32".into(), tri_solve_src);
350        sources.insert("tri_solve_lower_unit_bf16".into(), tri_solve_src);
351        // Rope-multi kernels (ADR-013 Decision 10 — IMROPE for Qwen3.5)
352        let rope_multi_src: &'static str = include_str!("shaders/rope_multi.metal");
353        sources.insert("rope_multi_f32".into(), rope_multi_src);
354        sources.insert("rope_multi_bf16".into(), rope_multi_src);
355        // Gated DeltaNet fused kernel (ADR-013 Decision 6 — centerpiece)
356        let gdn_src: &'static str = include_str!("shaders/gated_delta_net.metal");
357        sources.insert("gated_delta_net_f32".into(), gdn_src);
358        // Sigmoid-gated elementwise multiply (ADR-013 Decision 9 — full-attn output gate)
359        let sigmoid_mul_src: &'static str = include_str!("shaders/sigmoid_mul.metal");
360        sources.insert("sigmoid_mul_f32".into(), sigmoid_mul_src);
361        sources.insert("sigmoid_mul_bf16".into(), sigmoid_mul_src);
362        let silu_mul_src: &'static str = include_str!("shaders/silu_mul.metal");
363        sources.insert("silu_mul_f32".into(), silu_mul_src);
364        let compute_g_beta_src: &'static str = include_str!("shaders/compute_g_beta.metal");
365        sources.insert("compute_g_beta_f32".into(), compute_g_beta_src);
366        let ssm_norm_gate_src: &'static str = include_str!("shaders/ssm_norm_gate.metal");
367        sources.insert("ssm_norm_gate_f32".into(), ssm_norm_gate_src);
368        let gelu_src: &'static str = include_str!("shaders/gelu.metal");
369        sources.insert("gelu_f32".into(), gelu_src);
370        sources.insert("gelu_f16".into(), gelu_src);
371        sources.insert("gelu_bf16".into(), gelu_src);
372        let softmax_src: &'static str = include_str!("shaders/softmax.metal");
373        sources.insert("softmax_f32".into(), softmax_src);
374        sources.insert("softmax_f16".into(), softmax_src);
375        sources.insert("softmax_bf16".into(), softmax_src);
376        let softcap_src: &'static str = include_str!("shaders/softcap.metal");
377        sources.insert("softcap_f32".into(), softcap_src);
378        sources.insert("softcap_f16".into(), softcap_src);
379        sources.insert("softcap_bf16".into(), softcap_src);
380
381        // Fused norm-add kernels — Gemma4 post-attention / post-FFN ordering:
382        //   normed = rms_norm(input, weight, eps);  output = residual + normed
383        let fused_norm_add_src: &'static str =
384            include_str!("shaders/fused_norm_add_bf16.metal");
385        sources.insert("fused_norm_add_bf16".into(), fused_norm_add_src);
386        sources.insert("fused_norm_add_no_weight_bf16".into(), fused_norm_add_src);
387
388        // Fused head-norm + RoPE f32 kernel — replaces separate rms_norm + rope_neox_f32
389        let fused_hnr_f32_src: &'static str =
390            include_str!("shaders/fused_head_norm_rope_f32.metal");
391        sources.insert("fused_head_norm_rope_f32".into(), fused_hnr_f32_src);
392
393        // Fused head-norm + RoPE bf16 kernels (single-token + batch prefill)
394        // Both entry points live in the same .metal file.
395        let fused_hnr_bf16_src: &'static str =
396            include_str!("shaders/fused_head_norm_rope_bf16.metal");
397        sources.insert("fused_head_norm_rope_bf16".into(), fused_hnr_bf16_src);
398        sources.insert("fused_head_norm_rope_batch_bf16".into(), fused_hnr_bf16_src);
399
400        // Fused norm-add f32 kernels — post-attention / post-FFN / end-of-layer
401        let fused_norm_add_f32_src: &'static str =
402            include_str!("shaders/fused_norm_add_f32.metal");
403        sources.insert("fused_norm_add_f32".into(), fused_norm_add_f32_src);
404        sources.insert("fused_residual_norm_f32".into(), fused_norm_add_f32_src);
405        sources.insert("fused_residual_norm_scalar_f32".into(), fused_norm_add_f32_src);
406        sources.insert("fused_moe_routing_f32".into(), fused_norm_add_f32_src);
407        sources.insert("fused_moe_routing_batch_f32".into(), fused_norm_add_f32_src);
408        sources.insert("fused_norm_add_scalar_f32".into(), fused_norm_add_f32_src);
409        sources.insert("fused_moe_wsum_norm_add_f32".into(), fused_norm_add_f32_src);
410        sources.insert("fused_moe_wsum_dnorm_add_f32".into(), fused_norm_add_f32_src);
411
412        // Argsort kernel (Story 2.3) — MoE top-K routing
413        let argsort_src: &'static str = include_str!("shaders/argsort.metal");
414        sources.insert("argsort_desc_f32".into(), argsort_src);
415
416        // Gather / index_select kernel (Story 2.4)
417        let gather_src: &'static str = include_str!("shaders/gather.metal");
418        sources.insert("gather_f32".into(), gather_src);
419
420        // F32 KV cache copy kernel (Session merge S1+S2)
421        let kv_cache_copy_src: &'static str =
422            include_str!("shaders/kv_cache_copy.metal");
423        sources.insert("kv_cache_copy".into(), kv_cache_copy_src);
424        sources.insert("kv_cache_copy_f32".into(), kv_cache_copy_src);
425
426        // Strided copy kernel (Story 2.5)
427        let copy_src: &'static str = include_str!("shaders/copy.metal");
428        sources.insert("strided_copy_f32".into(), copy_src);
429        sources.insert("offset_copy_f32".into(), copy_src);
430
431        // Dense F16 GEMM kernel (Story 2.6) — lm_head projection
432        let dense_gemm_src: &'static str = include_str!("shaders/dense_gemm.metal");
433        sources.insert("dense_gemm_f16".into(), dense_gemm_src);
434        sources.insert("dense_matvec_f16".into(), dense_gemm_src);
435        sources.insert("dense_matvec_f16w_f32io".into(), dense_gemm_src);
436        // BF16-weight mat-vec: BF16 weights × F32 input → F32 output (decode lm_head)
437        sources.insert("dense_matvec_bf16w_f32io".into(), dense_gemm_src);
438        // Pure F32 mat-vec: F32 weights × F32 input → F32 output (decode lm_head)
439        sources.insert("dense_matvec_f32".into(), dense_gemm_src);
440
441        // Standalone FWHT for TurboQuant pre/post-rotation (SIMD shuffle, zero barriers)
442        let fwht_src: &'static str = include_str!("shaders/fwht_standalone.metal");
443        sources.insert("fwht_standalone_f32_d256".into(), fwht_src);
444        sources.insert("fwht_standalone_f32_d512".into(), fwht_src);
445        // ADR-007 iter-14 D1 SRHT variants: sign pre-mult (for Q) + sign undo (for output)
446        sources.insert("fwht_sign_premult_f32_d256".into(), fwht_src);
447        sources.insert("fwht_sign_premult_f32_d512".into(), fwht_src);
448        sources.insert("fwht_sign_undo_f32_d256".into(), fwht_src);
449        sources.insert("fwht_sign_undo_f32_d512".into(), fwht_src);
450
451        // Fast Hadamard quantize (SIMD shuffle, zero barriers)
452        let hq_fast_src: &'static str = include_str!("shaders/hadamard_quantize_kv_fast.metal");
453        sources.insert("hadamard_quantize_kv_fast_d256".into(), hq_fast_src);
454        sources.insert("hadamard_quantize_kv_fast_d512".into(), hq_fast_src);
455        // Track B (iter-21): higher-bit (5/6-bit) quantize kernels (byte-packed)
456        sources.insert("hadamard_quantize_kv_hb_d256".into(), hq_fast_src);
457        sources.insert("hadamard_quantize_kv_hb_d512".into(), hq_fast_src);
458
459        // iter-20 Leg F: TQ KV dequantize kernel (nibbles+norms → F32)
460        let tq_dq_src: &'static str = include_str!("shaders/tq_dequantize_kv.metal");
461        sources.insert("tq_dequantize_kv".into(), tq_dq_src);
462        // Track B (iter-21): higher-bit dequantize kernel (byte-packed indices)
463        sources.insert("tq_dequantize_hb_kv".into(), tq_dq_src);
464
465        // iter-24: native higher-bit (5/6/8-bit) TQ SDPA kernel (byte-packed K/V)
466        let tq_hb_src: &'static str = include_str!("shaders/flash_attn_vec_tq_hb.metal");
467        sources.insert("flash_attn_vec_tq_hb_dk256".into(), tq_hb_src);
468        sources.insert("flash_attn_vec_tq_hb_dk512".into(), tq_hb_src);
469
470        // GPU sampling kernels — eliminate logits readback (Phase 6)
471        let argmax_src: &'static str = include_str!("shaders/argmax.metal");
472        sources.insert("argmax_f32".into(), argmax_src);
473        let softmax_sample_src: &'static str =
474            include_str!("shaders/softmax_sample.metal");
475        sources.insert("softmax_sample_f32".into(), softmax_sample_src);
476        // Top-K kernel for Q8 rerank: avoids full-logits readback.
477        let top_k_src: &'static str = include_str!("shaders/top_k.metal");
478        sources.insert("top_k_f32".into(), top_k_src);
479
480        // MoE GPU routing + weighted reduce (ADR-013 P13.3 perf).
481        // Replaces CPU softmax+topk round-trip and CPU weighted accumulate.
482        let moe_stk_src: &'static str =
483            include_str!("shaders/moe_softmax_topk.metal");
484        sources.insert("moe_softmax_topk_f32".into(), moe_stk_src);
485        let moe_wr_src: &'static str =
486            include_str!("shaders/moe_weighted_reduce.metal");
487        sources.insert("moe_weighted_reduce_f32".into(), moe_wr_src);
488        let sdpa_decode_src: &'static str =
489            include_str!("shaders/sdpa_decode.metal");
490        sources.insert("sdpa_decode".into(), sdpa_decode_src);
491
492        Self {
493            cache: HashMap::new(),
494            sources,
495        }
496    }
497
498    /// Register a shader source at runtime (useful for testing and dynamic
499    /// kernel generation).
500    pub fn register_source(&mut self, name: impl Into<String>, source: &'static str) {
501        let name = name.into();
502        // Invalidate any cached pipeline for this name since the source changed.
503        self.cache.remove(&name);
504        self.sources.insert(name, source);
505    }
506
507    /// Get a compiled compute pipeline for the named kernel function.
508    ///
509    /// On first call for a given name, this compiles the MSL source into a
510    /// Metal library, extracts the named function, and creates a
511    /// `ComputePipelineState`.  Subsequent calls return the cached pipeline.
512    ///
513    /// # Errors
514    ///
515    /// * `MlxError::KernelNotFound` — no source registered for this name.
516    /// * `MlxError::ShaderCompilationError` — MSL compilation or pipeline
517    ///   creation failed.
518    pub fn get_pipeline(
519        &mut self,
520        name: &str,
521        device: &metal::DeviceRef,
522    ) -> Result<&ComputePipelineState> {
523        if !self.cache.contains_key(name) {
524            // Slow path: compile the shader.
525            let source = self.sources.get(name).ok_or_else(|| {
526                MlxError::KernelNotFound(name.to_string())
527            })?;
528
529            let compile_opts = metal::CompileOptions::new();
530            let library = device
531                .new_library_with_source(source, &compile_opts)
532                .map_err(|msg| MlxError::ShaderCompilationError {
533                    name: name.to_string(),
534                    message: msg,
535                })?;
536
537            let function = library
538                .get_function(name, None)
539                .map_err(|msg| MlxError::ShaderCompilationError {
540                    name: name.to_string(),
541                    message: msg,
542                })?;
543
544            let pipeline = device
545                .new_compute_pipeline_state_with_function(&function)
546                .map_err(|msg| MlxError::ShaderCompilationError {
547                    name: name.to_string(),
548                    message: msg,
549                })?;
550
551            self.cache.insert(name.to_string(), pipeline);
552        }
553
554        // At this point the pipeline is guaranteed to be in the cache.
555        // We use `ok_or_else` instead of `expect` to satisfy the no-panic policy.
556        self.cache.get(name).ok_or_else(|| {
557            MlxError::KernelNotFound(name.to_string())
558        })
559    }
560
561    /// Get a compiled compute pipeline for the named kernel, specialized with
562    /// Metal function constants (both bool and i32 in one call).
563    ///
564    /// `bool_constants` contains `(index, value)` pairs mapping to
565    /// `[[function_constant(index)]]` bool declarations in the MSL shader.
566    /// `int_constants` contains `(index, value)` pairs mapping to
567    /// `[[function_constant(index)]]` int (int32_t) declarations in the MSL
568    /// shader.
569    ///
570    /// Pipelines are cached by a composite key:
571    /// `"<name>|<index>:b<0|1>|...|<index>:i<value>|..."`.  The 'b' prefix
572    /// marks bool entries and the 'i' prefix marks i32 entries, making the
573    /// format unambiguous regardless of constant ordering.  Distinct
574    /// `(name, constants)` combinations each compile to a separate pipeline;
575    /// the slow compilation path runs at most once per unique combination.
576    ///
577    /// # Errors
578    ///
579    /// * `MlxError::KernelNotFound` — no source registered for this name.
580    /// * `MlxError::ShaderCompilationError` — MSL compilation, function
581    ///   specialisation, or pipeline creation failed.
582    pub fn get_pipeline_with_constants(
583        &mut self,
584        name: &str,
585        device: &metal::DeviceRef,
586        bool_constants: &[(usize, bool)],
587        int_constants: &[(usize, i32)],
588    ) -> Result<&ComputePipelineState> {
589        // Build a composite cache key so distinct constant combinations each
590        // compile to their own pipeline.  Bool entries use the 'b' type marker
591        // and i32 entries use 'i'; this prevents a collision between, e.g.,
592        // bool index 5 value 1 and int index 5 value 1.
593        let mut cache_key = name.to_string();
594        for &(index, value) in bool_constants {
595            cache_key.push('|');
596            cache_key.push_str(&index.to_string());
597            cache_key.push_str(if value { ":b1" } else { ":b0" });
598        }
599        for &(index, value) in int_constants {
600            cache_key.push('|');
601            cache_key.push_str(&index.to_string());
602            cache_key.push(':');
603            cache_key.push('i');
604            cache_key.push_str(&value.to_string());
605        }
606
607        if !self.cache.contains_key(&cache_key) {
608            // Slow path: compile the shader with function constant specialisation.
609            let source = self.sources.get(name).ok_or_else(|| {
610                MlxError::KernelNotFound(name.to_string())
611            })?;
612
613            let compile_opts = metal::CompileOptions::new();
614            let library = device
615                .new_library_with_source(source, &compile_opts)
616                .map_err(|msg| MlxError::ShaderCompilationError {
617                    name: name.to_string(),
618                    message: msg,
619                })?;
620
621            // Build the FunctionConstantValues object with all bool and i32
622            // constants.  Metal's set_constant_value_at_index reads the value
623            // through a raw pointer; the pointed-to bytes must match the size
624            // declared in the MSL shader (1 byte for bool, 4 bytes for int).
625            let fcv = FunctionConstantValues::new();
626
627            for &(index, value) in bool_constants {
628                // MTLDataType::Bool = 53 (metal-rs argument.rs).
629                // The Metal runtime reads it as an Objective-C BOOL (uint8_t).
630                let v: u8 = if value { 1 } else { 0 };
631                fcv.set_constant_value_at_index(
632                    (&v as *const u8).cast::<std::ffi::c_void>(),
633                    MTLDataType::Bool,
634                    index as u64,
635                );
636            }
637
638            for &(index, value) in int_constants {
639                // MTLDataType::Int = 29 (metal-rs argument.rs).
640                // The Metal runtime reads 4 bytes as a signed 32-bit integer,
641                // matching the Metal shader type `constant int`.
642                fcv.set_constant_value_at_index(
643                    (&value as *const i32).cast::<std::ffi::c_void>(),
644                    MTLDataType::Int,
645                    index as u64,
646                );
647            }
648
649            let function = library
650                .get_function(name, Some(fcv))
651                .map_err(|msg| MlxError::ShaderCompilationError {
652                    name: name.to_string(),
653                    message: msg,
654                })?;
655
656            let pipeline = device
657                .new_compute_pipeline_state_with_function(&function)
658                .map_err(|msg| MlxError::ShaderCompilationError {
659                    name: name.to_string(),
660                    message: msg,
661                })?;
662
663            self.cache.insert(cache_key.clone(), pipeline);
664        }
665
666        self.cache.get(&cache_key).ok_or_else(|| {
667            MlxError::KernelNotFound(name.to_string())
668        })
669    }
670
671    /// Get a compiled compute pipeline for the named kernel, specialized with
672    /// Metal bool function constants.
673    ///
674    /// The `bool_constants` slice contains `(index, value)` pairs.  Each pair
675    /// maps to a `[[function_constant(index)]]` declaration in the MSL shader.
676    ///
677    /// This is a thin wrapper around [`get_pipeline_with_constants`] that
678    /// passes an empty `int_constants` slice.  Existing callers continue to
679    /// work without modification; the cache-key format for pure-bool pipelines
680    /// is compatible (bool entries carry the 'b' type marker, which is the
681    /// only format ever written by this wrapper).
682    ///
683    /// # Errors
684    ///
685    /// * `MlxError::KernelNotFound` — no source registered for this name.
686    /// * `MlxError::ShaderCompilationError` — MSL compilation, function
687    ///   specialisation, or pipeline creation failed.
    pub fn get_pipeline_with_bool_constants(
        &mut self,
        name: &str,
        device: &metal::DeviceRef,
        bool_constants: &[(usize, bool)],
    ) -> Result<&ComputePipelineState> {
        // Delegate with an empty i32-constant slice; pure-bool cache keys
        // contain only 'b'-marked entries, so existing keys are unchanged.
        self.get_pipeline_with_constants(name, device, bool_constants, &[])
    }
696
    /// Check if a pipeline for the given name is already compiled and cached.
    ///
    /// NOTE: pipelines compiled via [`Self::get_pipeline_with_constants`] are
    /// stored under a composite `"<name>|<index>:…"` key, so this returns
    /// `true` only for the plain-`name` entry created by
    /// [`Self::get_pipeline`] (or a constants call with both slices empty),
    /// not for constant-specialised variants of `name`.
    pub fn is_cached(&self, name: &str) -> bool {
        self.cache.contains_key(name)
    }
701
    /// Number of compiled pipelines currently in the cache.
    ///
    /// Each distinct `(name, constants)` specialisation is a separate cache
    /// entry, so this count can exceed the number of registered kernel names.
    pub fn cached_count(&self) -> usize {
        self.cache.len()
    }
706
    /// Number of registered shader sources.
    ///
    /// Counts embedded MSL sources only — compiled pipelines are tracked
    /// separately (see [`Self::cached_count`]).
    pub fn source_count(&self) -> usize {
        self.sources.len()
    }
711}
712
impl Default for KernelRegistry {
    /// Equivalent to [`KernelRegistry::new`]: all embedded shader sources are
    /// registered, and nothing is compiled until first use.
    fn default() -> Self {
        Self::new()
    }
}
718
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal Metal shader that uses a single int function constant.
    ///
    /// The kernel writes the constant value N into the first element of the
    /// output buffer.  The shader is intentionally trivial — we only need it
    /// to *compile* with distinct int-constant specialisations; correctness
    /// of the kernel logic is not under test here.
    const INT_FC_TEST_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;

constant int test_N [[function_constant(100)]];

kernel void int_fc_test_kernel(
    device int* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]])
{
    if (tid == 0) {
        out[0] = test_N;
    }
}
"#;

    /// Verify that `get_pipeline_with_constants` produces distinct cached
    /// pipelines for different i32 function-constant values, that repeated
    /// calls are cache hits, and that `get_pipeline_with_bool_constants`
    /// (the backward-compat wrapper) still works with the 'b'-prefixed
    /// cache-key format.
    ///
    /// Requires a real Metal device: `Device::system_default()` is unwrapped
    /// with `expect`, so the test fails immediately on machines without
    /// Metal support.  (The `metal` crate itself only builds on Apple
    /// platforms, so no `#[ignore]`/cfg gating is done here.)
    #[test]
    fn test_int_fc_distinct_pipelines_and_bool_compat() {
        let device = metal::Device::system_default()
            .expect("no Metal device — run on Apple Silicon or x86 Mac with Metal support");

        let mut registry = KernelRegistry::new();

        // Register the inline test shader under a name that cannot collide
        // with any production kernel.
        registry.register_source("int_fc_test_kernel", INT_FC_TEST_SHADER);

        // Compile with N=4 (slow path — inserts one cache entry).
        registry
            .get_pipeline_with_constants("int_fc_test_kernel", &device, &[], &[(100, 4_i32)])
            .expect("pipeline N=4 should compile");
        let count_after_n4 = registry.cached_count();

        // Compile with N=8 — must insert a SEPARATE cache entry.
        registry
            .get_pipeline_with_constants("int_fc_test_kernel", &device, &[], &[(100, 8_i32)])
            .expect("pipeline N=8 should compile");
        assert_eq!(
            registry.cached_count(),
            count_after_n4 + 1,
            "N=8 must produce a new cache entry"
        );

        // Identity checks.  All three fetches below are cache hits with NO
        // insertions in between, so the HashMap cannot rehash and relocate
        // its stored values while we hold their addresses.  (Capturing a
        // pointer *before* a later insertion — as an earlier version of this
        // test did — is unreliable: an insert may move cached entries, so
        // the stale address could spuriously mismatch or collide.)
        let p4_ptr = registry
            .get_pipeline_with_constants("int_fc_test_kernel", &device, &[], &[(100, 4_i32)])
            .expect("pipeline N=4 cache hit should succeed")
            as *const ComputePipelineState;
        let p8_ptr = registry
            .get_pipeline_with_constants("int_fc_test_kernel", &device, &[], &[(100, 8_i32)])
            .expect("pipeline N=8 cache hit should succeed")
            as *const ComputePipelineState;
        let p4_again_ptr = registry
            .get_pipeline_with_constants("int_fc_test_kernel", &device, &[], &[(100, 4_i32)])
            .expect("pipeline N=4 second cache hit should succeed")
            as *const ComputePipelineState;

        assert_eq!(
            registry.cached_count(),
            count_after_n4 + 1,
            "repeated calls must be cache hits, not new entries"
        );
        assert_ne!(
            p4_ptr, p8_ptr,
            "N=4 and N=8 specialisations must be separate ComputePipelineState objects"
        );
        assert_eq!(
            p4_ptr, p4_again_ptr,
            "repeated N=4 call must return the same cached pipeline"
        );

        // Backward compatibility: the bool-constants wrapper must route
        // through get_pipeline_with_constants and insert exactly one new
        // cache entry.  A bare shader (declaring no function constants)
        // keeps the Metal compiler happy with the empty constant slice.
        const BARE_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void bare_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
    if (tid == 0) { out[0] = 42; }
}
"#;
        registry.register_source("bare_kernel", BARE_SHADER);

        let count_before_bool = registry.cached_count();
        registry
            .get_pipeline_with_bool_constants("bare_kernel", &device, &[])
            .expect("bool-constants wrapper with empty slice must succeed");

        assert_eq!(
            registry.cached_count(),
            count_before_bool + 1,
            "bool-constants wrapper must insert one new cache entry"
        );
    }
}