// mlx_native/kernel_registry.rs

//! [`KernelRegistry`] — lazy compilation and caching of Metal compute pipelines.
//!
//! MSL shader source is embedded at compile time via `include_str!`.  On first
//! access, the source is compiled into a Metal library, the named function is
//! extracted, and a `ComputePipelineState` is created and cached.  Subsequent
//! calls return the cached pipeline.

use std::collections::HashMap;

use metal::{ComputePipelineState, FunctionConstantValues, MTLDataType};

use crate::error::{MlxError, Result};

// MTLDataType numeric values (from metal-rs argument.rs, confirmed in Apple Metal spec):
//   Int  = 29
//   Bool = 53
// These are used when calling set_constant_value_at_index so the Metal runtime
// knows how wide each constant value is.

/// Registry that lazily compiles and caches Metal compute pipelines from
/// embedded MSL source.
///
/// # Usage
///
/// ```ignore
/// let mut registry = KernelRegistry::new();
/// let pipeline = registry.get_pipeline("elementwise_add_f32", device.metal_device())?;
/// encoder.encode(&pipeline, &buffers, grid, tg);
/// ```
///
/// # Thread Safety
///
/// `KernelRegistry` is **not** safe to share between threads as-is:
/// `get_pipeline` takes `&mut self` so it can insert into the cache, which
/// rules out shared (`&`) concurrent access.  If you need concurrent access,
/// wrap the registry in a `Mutex` or use one registry per thread.
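///
/// A minimal sketch of `Mutex`-guarded sharing (assuming the same hypothetical
/// `device` / `encoder` context as the usage example above):
///
/// ```ignore
/// use std::sync::Mutex;
///
/// let registry = Mutex::new(KernelRegistry::new());
/// // The returned pipeline borrows from the registry, so keep the guard
/// // alive while the pipeline is in use.
/// let mut guard = registry.lock().expect("registry mutex poisoned");
/// let pipeline = guard.get_pipeline("elementwise_add_f32", device.metal_device())?;
/// encoder.encode(&pipeline, &buffers, grid, tg);
/// ```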
pub struct KernelRegistry {
    /// Cached pipelines keyed by kernel function name.
    cache: HashMap<String, ComputePipelineState>,
    /// MSL source text keyed by kernel function name.
    ///
    /// Populated at construction time with all embedded shader sources.
    sources: HashMap<String, &'static str>,
}

impl KernelRegistry {
    /// Create a new registry with all embedded shader sources pre-registered.
    ///
    /// No compilation happens here — shaders are compiled lazily on first use.
    pub fn new() -> Self {
        let mut sources = HashMap::new();

        // Register embedded shader sources.
        sources.insert(
            "placeholder".into(),
            include_str!("shaders/placeholder.metal"),
        );
        sources.insert(
            "quantized_matmul".into(),
            include_str!("shaders/quantized_matmul.metal"),
        );
        sources.insert(
            "quantized_matmul_simd".into(),
            include_str!("shaders/quantized_matmul.metal"),
        );
        sources.insert(
            "quantized_matmul_simd_bf16".into(),
            include_str!("shaders/quantized_matmul.metal"),
        );
        sources.insert(
            "quantized_matmul_simd_bf16_expert".into(),
            include_str!("shaders/quantized_matmul.metal"),
        );

        // GGML block-format quantized mat-vec kernels (ADR-006 Phase 3)
        let ggml_src: &'static str =
            include_str!("shaders/quantized_matmul_ggml.metal");
        sources.insert("kernel_mul_mv_q4_0_f32".into(), ggml_src);
        sources.insert("kernel_mul_mv_q8_0_f32".into(), ggml_src);
        sources.insert("kernel_mul_mv_q6_K_f32".into(), ggml_src);

        // GGML block-format quantized matrix-matrix kernels
        // (ADR-011 Phase 3 Wave P3a: port of llama.cpp's kernel_mul_mm_<q>_f32).
        // Used at prefill m > 8 to reuse each weight tile across a 32-row
        // block via threadgroup-staged simdgroup MMA, instead of re-reading
        // every block per prompt-token as the mv kernel does.
        let ggml_mm_src: &'static str =
            include_str!("shaders/quantized_matmul_mm.metal");
        sources.insert("kernel_mul_mm_q4_0_f32".into(), ggml_mm_src);
        sources.insert("kernel_mul_mm_q8_0_f32".into(), ggml_mm_src);
        sources.insert("kernel_mul_mm_q6_K_f32".into(), ggml_mm_src);

        // GGML block-format quantized matrix-matrix kernels — tensor API
        // variant (ADR-011 Phase 3 Wave P3b-tensor: port of llama.cpp's
        // kernel_mul_mm_impl `#ifdef GGML_METAL_HAS_TENSOR` branch).
        // Uses Apple's MetalPerformancePrimitives `tensor_ops::matmul2d`
        // primitive which on M3+ dispatches to hardware tensor cores for
        // 2-3x the effective FLOP throughput vs the simdgroup MMA path.
        // Only compiled on devices where the tensor API is available; the
        // kernel_registry's runtime probe (see MlxDevice::has_tensor) gates
        // compilation so non-tensor devices transparently fall back to the
        // non-tensor `kernel_mul_mm_<q>_f32` kernels.
        let ggml_mm_tensor_src: &'static str =
            include_str!("shaders/quantized_matmul_mm_tensor.metal");
        sources.insert("kernel_mul_mm_q4_0_tensor_f32".into(), ggml_mm_tensor_src);
        sources.insert("kernel_mul_mm_q4_0_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
        sources.insert("kernel_mul_mm_q6_K_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
        sources.insert("kernel_mul_mm_q8_0_tensor_f32".into(), ggml_mm_tensor_src);
        sources.insert("kernel_mul_mm_q6_K_tensor_f32".into(), ggml_mm_tensor_src);

        // Dense bf16×f32 → f32 tensor-API matmul (non-flash-attention
        // prefill Q@K^T and scores@V, modeled on llama.cpp's
        // kernel_mul_mm_bf16_f32 with the GGML_METAL_HAS_TENSOR branch
        // active).  Tile geometry and write-back identical to the
        // quantized tensor kernel; only the A-stage copy (bfloat →
        // bfloat, no dequantize) differs.
        let dense_mm_bf16_tensor_src: &'static str =
            include_str!("shaders/dense_mm_bf16_tensor.metal");
        sources.insert("hf2q_dense_mm_bf16_f32_tensor".into(), dense_mm_bf16_tensor_src);

        // Dense f32×f32 → f32 tensor-API matmul (F32-everywhere
        // sibling of dense_mm_bf16_tensor).  Used by hf2q's ADR-005
        // iter-118 BF16-vs-F32 ViT attention A/B diagnostic to remove
        // the BF16 K-stage cast as a confounding variable.  Port of
        // llama.cpp's kernel_mul_mm_f32_f32 specialization
        // (ggml-metal.metal:10098) on the GGML_METAL_HAS_TENSOR
        // branch.  Same tile geometry (NR0=64 NR1=32 NK=32) but
        // float-everywhere shmem staging.
        let dense_mm_f32_f32_tensor_src: &'static str =
            include_str!("shaders/dense_mm_f32_f32.metal");
        sources.insert("hf2q_dense_mm_f32_f32_tensor".into(), dense_mm_f32_f32_tensor_src);

        // Dense f16×f32 → f32 tensor-API matmul (F16-staging sibling
        // of dense_mm_bf16_tensor).  Used by hf2q's ADR-005 Phase 2c
        // iter-128 gemma4v ViT precision-parity path: every mmproj
        // weight is stored as F16 in GGUF, peer's `kernel_mul_mm_f16_f32`
        // (`ggml-metal.metal:10099`) stages BOTH A and B as `half` in
        // shmem and computes on `simdgroup_half8x8`.  Matches peer
        // per-element rounding budget exactly (10-bit mantissa vs
        // BF16's 7-bit), closing the 1.16x/block cascade compound that
        // iter-127 numerically bisected to BF16 staging.  Same tile
        // geometry as the BF16 sibling (NR0=64 NR1=32 NK=32, 8 KB
        // shmem) — half and bfloat share 16-bit storage.
        let dense_mm_f16_tensor_src: &'static str =
            include_str!("shaders/dense_mm_f16_tensor.metal");
        sources.insert("hf2q_dense_mm_f16_f32_tensor".into(), dense_mm_f16_tensor_src);

        // Dense bf16×f32 → f32 GEMV (matrix-vector multiply) — optimized
        // for M=1 single-token decode.  Port of llama.cpp's
        // kernel_mul_mv_bf16_f32_4 (bfloat4-vectorized GEMV kernel).
        // Used in apply_linear_projection_f32 when seq_len=1 and the
        // weight matrix is BF16, replacing the MM kernel (~2× faster for
        // M=1 due to better memory bandwidth utilization per thread).
        let dense_gemv_bf16_src: &'static str =
            include_str!("shaders/dense_gemv_bf16.metal");
        sources.insert("hf2q_dense_gemv_bf16_f32_4".into(), dense_gemv_bf16_src);

        // Fused scale-mask-softmax for the non-flash-attention prefill
        // path.  One row-local threadgroup per (head, query) pair
        // replaces three separate dispatches (scale, mask-add, softmax);
        // reads a bf16 mask (-INF at masked positions, matching
        // flash_attn_prefill_mask.metal) that is shared across heads.
        let scale_mask_softmax_src: &'static str =
            include_str!("shaders/scale_mask_softmax.metal");
        sources.insert("scale_mask_softmax_f32".into(), scale_mask_softmax_src);

        // Expert-routed (MoE) quantized matmul kernel (Story 2.1)
        sources.insert(
            "quantized_matmul_id".into(),
            include_str!("shaders/quantized_matmul_id.metal"),
        );

        // Expert-routed (MoE) GGML block-format quantized matmul kernels
        let ggml_id_src: &'static str =
            include_str!("shaders/quantized_matmul_id_ggml.metal");
        sources.insert("kernel_mul_mv_id_q4_0_f32".into(), ggml_id_src);
        sources.insert("kernel_mul_mv_id_q8_0_f32".into(), ggml_id_src);
        sources.insert("kernel_mul_mv_id_q5_K_f32".into(), ggml_id_src);
        sources.insert("kernel_mul_mv_id_q6_K_f32".into(), ggml_id_src);

        // Expert-routed (MoE) GGML block-format QUANTIZED MATRIX-MATRIX kernels
        // (ADR-011 Phase 3 Wave P3a: port of llama.cpp's
        // `kernel_mul_mm_id_map0_ne20_N` + `kernel_mul_mm_id_<q>_f32`).
        // Two-stage dispatch: map0 regroups the token-to-expert table into
        // per-expert routed-token lists, then mm_id stages a 64x32 expert
        // weight tile into threadgroup shmem and reuses it across a 32-row
        // block of that expert's routed tokens.
        let ggml_id_mm_src: &'static str =
            include_str!("shaders/quantized_matmul_id_mm.metal");
        sources.insert("kernel_mul_mm_id_map0_ne20_1".into(), ggml_id_mm_src);
        sources.insert("kernel_mul_mm_id_map0_ne20_8".into(), ggml_id_mm_src);
        sources.insert("kernel_mul_mm_id_q4_0_f32".into(), ggml_id_mm_src);
        sources.insert("kernel_mul_mm_id_q8_0_f32".into(), ggml_id_mm_src);
        sources.insert("kernel_mul_mm_id_q6_K_f32".into(), ggml_id_mm_src);

        // MoE-routed quantized matrix-matrix kernels — tensor API variant
        // (ADR-011 Phase 3 Wave P3b-tensor).  Uses the MPP tensor_ops
        // matmul2d primitive for hardware-tensor-core MMA on M3+.  Only
        // the mm_id kernel is ported — map0 is a short pre-pass (not
        // matmul) and continues to use the simdgroup version.
        let ggml_id_mm_tensor_src: &'static str =
            include_str!("shaders/quantized_matmul_id_mm_tensor.metal");
        sources.insert("kernel_mul_mm_id_q4_0_tensor_f32".into(), ggml_id_mm_tensor_src);
        sources.insert("kernel_mul_mm_id_q8_0_tensor_f32".into(), ggml_id_mm_tensor_src);
        sources.insert("kernel_mul_mm_id_q6_K_tensor_f32".into(), ggml_id_mm_tensor_src);

        // Embedding kernels (Story 1.5)
        let embedding_src: &'static str = include_str!("shaders/embedding.metal");
        sources.insert("embedding_gather_4bit".into(), embedding_src);
        sources.insert("embedding_gather_6bit".into(), embedding_src);

        // MoE gate kernel (Story 1.5)
        let moe_gate_src: &'static str = include_str!("shaders/moe_gate.metal");
        sources.insert("moe_gate".into(), moe_gate_src);

        // MoE dispatch kernels (Story 1.5)
        let moe_dispatch_src: &'static str = include_str!("shaders/moe_dispatch.metal");
        sources.insert("fused_gelu_mul".into(), moe_dispatch_src);
        sources.insert("moe_swiglu_fused".into(), moe_dispatch_src);
        sources.insert("moe_swiglu_batch".into(), moe_dispatch_src);
        sources.insert("moe_swiglu_seq".into(), moe_dispatch_src);
        sources.insert("moe_accumulate".into(), moe_dispatch_src);
        sources.insert("moe_weighted_sum".into(), moe_dispatch_src);
        sources.insert("moe_weighted_sum_seq".into(), moe_dispatch_src);
        sources.insert("zero_buffer".into(), moe_dispatch_src);
        sources.insert("naive_matvec_f32".into(), moe_dispatch_src);
        sources.insert("moe_gather_topk_weights".into(), moe_dispatch_src);
        // bf16 variants (Phase 2 bf16 activation path)
        sources.insert("fused_gelu_mul_bf16".into(), moe_dispatch_src);
        sources.insert("moe_swiglu_seq_bf16".into(), moe_dispatch_src);
        sources.insert("moe_weighted_sum_seq_bf16_input".into(), moe_dispatch_src);

        // Batched KV cache copy kernels
        let kv_cache_src: &'static str = include_str!("shaders/kv_cache_copy.metal");
        sources.insert("kv_cache_copy_batch_f32".into(), kv_cache_src);
        sources.insert("kv_cache_copy_batch_f32_to_f16".into(), kv_cache_src);
        sources.insert("kv_cache_copy_seq_f32".into(), kv_cache_src);
        sources.insert("kv_cache_copy_seq_f32_to_f16".into(), kv_cache_src);
        // Wave P4.11 — fused K+V copy variants
        sources.insert("kv_cache_copy_seq_f32_kv_dual".into(), kv_cache_src);
        sources.insert("kv_cache_copy_seq_f32_to_f16_kv_dual".into(), kv_cache_src);
        // bf16-source KV cache copy (Phase 2 bf16 activation path)
        sources.insert("kv_cache_copy_seq_bf16".into(), kv_cache_src);

        // Elementwise and transpose kernels (Story 1.5)
        let elementwise_src: &'static str = include_str!("shaders/elementwise.metal");
        sources.insert("elementwise_add_f32".into(), elementwise_src);
        sources.insert("elementwise_add_f16".into(), elementwise_src);
        sources.insert("elementwise_mul_f32".into(), elementwise_src);
        sources.insert("elementwise_mul_f16".into(), elementwise_src);
        sources.insert("elementwise_add_bf16".into(), elementwise_src);
        sources.insert("elementwise_mul_bf16".into(), elementwise_src);
        sources.insert("cast_f16_to_f32".into(), elementwise_src);
        sources.insert("cast_f32_to_f16".into(), elementwise_src);
        sources.insert("cast_bf16_to_f32".into(), elementwise_src);
        sources.insert("cast_f32_to_bf16".into(), elementwise_src);
        sources.insert("scalar_mul_bf16".into(), elementwise_src);
        sources.insert("scalar_mul_f32".into(), elementwise_src);
        sources.insert("embedding_gather_scale_f32".into(), elementwise_src);
        sources.insert("embedding_gather_scale_batch_f32".into(), elementwise_src);
        sources.insert("permute_021_bf16".into(), elementwise_src);
        sources.insert("transpose_last2_bf16".into(), elementwise_src);
        sources.insert("transpose_last2_f16".into(), elementwise_src);
        sources.insert("permute_021_f32".into(), elementwise_src);
        sources.insert("permute_021_bf16_to_f32".into(), elementwise_src);
        sources.insert("transpose_2d_f32".into(), elementwise_src);
        sources.insert("transpose_2d_f16".into(), elementwise_src);

        // Attention kernels (Story 1.3)
        let sdpa_src: &'static str = include_str!("shaders/sdpa.metal");
        sources.insert("sdpa".into(), sdpa_src);
        sources.insert("sdpa_bf16".into(), sdpa_src);
        let sdpa_sliding_src: &'static str = include_str!("shaders/sdpa_sliding.metal");
        sources.insert("sdpa_sliding".into(), sdpa_sliding_src);
        sources.insert("sdpa_sliding_bf16".into(), sdpa_sliding_src);

        // Flash-attention tiled prefill kernel (ADR-011 Phase 1).
        // Ten entry points; all backed by the same shader source.
        // Pipelines are compiled with function constants via
        // `get_pipeline_with_bool_constants` — not `get_pipeline`.
        let flash_attn_prefill_src: &'static str =
            include_str!("shaders/flash_attn_prefill.metal");
        // D=256 variants (BQ=32, BK=16, WM=4, WN=1 — 128 threads/threadgroup)
        sources.insert(
            "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskfloat32".into(),
            flash_attn_prefill_src,
        );
        sources.insert(
            "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
            flash_attn_prefill_src,
        );
        sources.insert(
            "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbfloat16".into(),
            flash_attn_prefill_src,
        );
        sources.insert(
            "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
            flash_attn_prefill_src,
        );
        sources.insert(
            "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskfloat16".into(),
            flash_attn_prefill_src,
        );
        sources.insert(
            "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
            flash_attn_prefill_src,
        );
        // D=512 variants (BQ=8, BK=8, WM=1, WN=1 — 32 threads/threadgroup)
        // NOTE: f32 at D=512 is NOT instantiated — threadgroup memory exceeds
        // the 32 KB Metal limit (candle sdpa.rs:86-94).
        sources.insert(
            "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbfloat16".into(),
            flash_attn_prefill_src,
        );
        sources.insert(
            "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
            flash_attn_prefill_src,
        );
        sources.insert(
            "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskfloat16".into(),
            flash_attn_prefill_src,
        );
        sources.insert(
            "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
            flash_attn_prefill_src,
        );

        // Flash attention vector kernels — SIMD-vectorized decode-path SDPA
        // (ported from llama.cpp flash_attn_ext_vec)
        let flash_attn_vec_src: &'static str =
            include_str!("shaders/flash_attn_vec.metal");
        sources.insert("flash_attn_vec_dk256".into(), flash_attn_vec_src);
        sources.insert("flash_attn_vec_dk512".into(), flash_attn_vec_src);
        sources.insert("flash_attn_vec_reduce_dk256".into(), flash_attn_vec_src);
        sources.insert("flash_attn_vec_reduce_dk512".into(), flash_attn_vec_src);
        // F16 KV variants (Phase 4a)
        sources.insert("flash_attn_vec_f16kv_dk256".into(), flash_attn_vec_src);
        sources.insert("flash_attn_vec_f16kv_dk512".into(), flash_attn_vec_src);

        // RoPE, normalization, activation kernels (Story 1.4)
        let rope_src: &'static str = include_str!("shaders/rope.metal");
        sources.insert("rope_f32".into(), rope_src);
        sources.insert("rope_f16".into(), rope_src);
        sources.insert("rope_bf16".into(), rope_src);
        sources.insert("rope_neox_bf16".into(), rope_src);
        sources.insert("rope_neox_f32".into(), rope_src);
        let rms_norm_src: &'static str = include_str!("shaders/rms_norm.metal");
        sources.insert("rms_norm_f32".into(), rms_norm_src);
        sources.insert("rms_norm_f16".into(), rms_norm_src);
        sources.insert("rms_norm_bf16".into(), rms_norm_src);
        sources.insert("rms_norm_no_scale_bf16".into(), rms_norm_src);
        sources.insert("rms_norm_no_scale_f32".into(), rms_norm_src);
        sources.insert("rms_norm_no_scale_f32_dual".into(), rms_norm_src);
        sources.insert("rms_norm_f32_triple".into(), rms_norm_src);
        sources.insert("fused_post_attn_triple_norm_f32".into(), rms_norm_src);
        sources.insert("rms_norm_no_scale_f32_dual_perm".into(), rms_norm_src);
        // Fused RMS norm + elementwise multiply kernels (Phase 4e.2)
        sources.insert("rms_norm_mul_f32".into(), rms_norm_src);
        sources.insert("rms_norm_mul_f16".into(), rms_norm_src);
        sources.insert("rms_norm_mul_bf16".into(), rms_norm_src);
        // L2 norm kernels (ADR-013 Decision 3 — Gated DeltaNet Q/K norm)
        let l2_norm_src: &'static str = include_str!("shaders/l2_norm.metal");
        sources.insert("l2_norm_f32".into(), l2_norm_src);
        sources.insert("l2_norm_f16".into(), l2_norm_src);
        sources.insert("l2_norm_bf16".into(), l2_norm_src);
        // Cumulative-sum kernels (ADR-013 Decision 4 — DeltaNet decay-mask base)
        let cumsum_src: &'static str = include_str!("shaders/cumsum.metal");
        sources.insert("cumsum_f32".into(), cumsum_src);
        sources.insert("cumsum_bf16".into(), cumsum_src);
        // SSM conv kernels (ADR-013 Decision 7 — DeltaNet 1D causal conv + SiLU)
        let ssm_conv_src: &'static str = include_str!("shaders/ssm_conv.metal");
        sources.insert("ssm_conv_forward_f32".into(), ssm_conv_src);
        sources.insert("ssm_conv_forward_bf16".into(), ssm_conv_src);
        sources.insert("ssm_conv_state_update_f32".into(), ssm_conv_src);
        sources.insert("ssm_conv_state_update_bf16".into(), ssm_conv_src);
        // Tri-solve kernels (ADR-013 Decision 5 — chunked DeltaNet debug path)
        let tri_solve_src: &'static str = include_str!("shaders/tri_solve.metal");
        sources.insert("tri_solve_lower_unit_f32".into(), tri_solve_src);
        sources.insert("tri_solve_lower_unit_bf16".into(), tri_solve_src);
        // Rope-multi kernels (ADR-013 Decision 10 — IMROPE for Qwen3.5)
        let rope_multi_src: &'static str = include_str!("shaders/rope_multi.metal");
        sources.insert("rope_multi_f32".into(), rope_multi_src);
        sources.insert("rope_multi_bf16".into(), rope_multi_src);
        // Gated DeltaNet fused kernel (ADR-013 Decision 6 — centerpiece)
        let gdn_src: &'static str = include_str!("shaders/gated_delta_net.metal");
        sources.insert("gated_delta_net_f32".into(), gdn_src);
        // Sigmoid-gated elementwise multiply (ADR-013 Decision 9 — full-attn output gate)
        let sigmoid_mul_src: &'static str = include_str!("shaders/sigmoid_mul.metal");
        sources.insert("sigmoid_mul_f32".into(), sigmoid_mul_src);
        sources.insert("sigmoid_mul_bf16".into(), sigmoid_mul_src);
        let silu_mul_src: &'static str = include_str!("shaders/silu_mul.metal");
        sources.insert("silu_mul_f32".into(), silu_mul_src);
        let compute_g_beta_src: &'static str = include_str!("shaders/compute_g_beta.metal");
        sources.insert("compute_g_beta_f32".into(), compute_g_beta_src);
        let ssm_norm_gate_src: &'static str = include_str!("shaders/ssm_norm_gate.metal");
        sources.insert("ssm_norm_gate_f32".into(), ssm_norm_gate_src);
        let gelu_src: &'static str = include_str!("shaders/gelu.metal");
        sources.insert("gelu_f32".into(), gelu_src);
        sources.insert("gelu_f16".into(), gelu_src);
        sources.insert("gelu_bf16".into(), gelu_src);
        let softmax_src: &'static str = include_str!("shaders/softmax.metal");
        sources.insert("softmax_f32".into(), softmax_src);
        sources.insert("softmax_f16".into(), softmax_src);
        sources.insert("softmax_bf16".into(), softmax_src);
        let softcap_src: &'static str = include_str!("shaders/softcap.metal");
        sources.insert("softcap_f32".into(), softcap_src);
        sources.insert("softcap_f16".into(), softcap_src);
        sources.insert("softcap_bf16".into(), softcap_src);

        // Fused norm-add kernels — Gemma4 post-attention / post-FFN ordering:
        //   normed = rms_norm(input, weight, eps);  output = residual + normed
        let fused_norm_add_src: &'static str =
            include_str!("shaders/fused_norm_add_bf16.metal");
        sources.insert("fused_norm_add_bf16".into(), fused_norm_add_src);
        sources.insert("fused_norm_add_no_weight_bf16".into(), fused_norm_add_src);

        // Fused head-norm + RoPE f32 kernel — replaces separate rms_norm + rope_neox_f32
        let fused_hnr_f32_src: &'static str =
            include_str!("shaders/fused_head_norm_rope_f32.metal");
        sources.insert("fused_head_norm_rope_f32".into(), fused_hnr_f32_src);

        // Fused head-norm + RoPE bf16 kernels (single-token + batch prefill)
        // Both entry points live in the same .metal file.
        let fused_hnr_bf16_src: &'static str =
            include_str!("shaders/fused_head_norm_rope_bf16.metal");
        sources.insert("fused_head_norm_rope_bf16".into(), fused_hnr_bf16_src);
        sources.insert("fused_head_norm_rope_batch_bf16".into(), fused_hnr_bf16_src);

        // Fused norm-add f32 kernels — post-attention / post-FFN / end-of-layer
        let fused_norm_add_f32_src: &'static str =
            include_str!("shaders/fused_norm_add_f32.metal");
        sources.insert("fused_norm_add_f32".into(), fused_norm_add_f32_src);
        sources.insert("fused_residual_norm_f32".into(), fused_norm_add_f32_src);
        sources.insert("fused_residual_norm_scalar_f32".into(), fused_norm_add_f32_src);
        sources.insert("fused_moe_routing_f32".into(), fused_norm_add_f32_src);
        sources.insert("fused_moe_routing_batch_f32".into(), fused_norm_add_f32_src);
        sources.insert("fused_norm_add_scalar_f32".into(), fused_norm_add_f32_src);
        sources.insert("fused_moe_wsum_norm_add_f32".into(), fused_norm_add_f32_src);
        sources.insert("fused_moe_wsum_dnorm_add_f32".into(), fused_norm_add_f32_src);

        // Argsort kernel (Story 2.3) — MoE top-K routing
        let argsort_src: &'static str = include_str!("shaders/argsort.metal");
        sources.insert("argsort_desc_f32".into(), argsort_src);

        // Gather / index_select kernel (Story 2.4)
        let gather_src: &'static str = include_str!("shaders/gather.metal");
        sources.insert("gather_f32".into(), gather_src);

        // F32 KV cache copy kernel (Session merge S1+S2)
        let kv_cache_copy_src: &'static str =
            include_str!("shaders/kv_cache_copy.metal");
        sources.insert("kv_cache_copy".into(), kv_cache_copy_src);
        sources.insert("kv_cache_copy_f32".into(), kv_cache_copy_src);

        // Strided copy kernel (Story 2.5)
        let copy_src: &'static str = include_str!("shaders/copy.metal");
        sources.insert("strided_copy_f32".into(), copy_src);
        sources.insert("offset_copy_f32".into(), copy_src);

        // Dense F16 GEMM kernel (Story 2.6) — lm_head projection
        let dense_gemm_src: &'static str = include_str!("shaders/dense_gemm.metal");
        sources.insert("dense_gemm_f16".into(), dense_gemm_src);
        sources.insert("dense_matvec_f16".into(), dense_gemm_src);
        sources.insert("dense_matvec_f16w_f32io".into(), dense_gemm_src);
        // BF16-weight mat-vec: BF16 weights × F32 input → F32 output (decode lm_head)
        sources.insert("dense_matvec_bf16w_f32io".into(), dense_gemm_src);
        // Pure F32 mat-vec: F32 weights × F32 input → F32 output (decode lm_head)
        sources.insert("dense_matvec_f32".into(), dense_gemm_src);

        // Standalone FWHT for TurboQuant pre/post-rotation (SIMD shuffle, zero barriers)
        let fwht_src: &'static str = include_str!("shaders/fwht_standalone.metal");
        sources.insert("fwht_standalone_f32_d256".into(), fwht_src);
        sources.insert("fwht_standalone_f32_d512".into(), fwht_src);
        // ADR-007 iter-14 D1 SRHT variants: sign pre-mult (for Q) + sign undo (for output)
        sources.insert("fwht_sign_premult_f32_d256".into(), fwht_src);
        sources.insert("fwht_sign_premult_f32_d512".into(), fwht_src);
        sources.insert("fwht_sign_undo_f32_d256".into(), fwht_src);
        sources.insert("fwht_sign_undo_f32_d512".into(), fwht_src);

        // Fast Hadamard quantize (SIMD shuffle, zero barriers)
        let hq_fast_src: &'static str = include_str!("shaders/hadamard_quantize_kv_fast.metal");
        sources.insert("hadamard_quantize_kv_fast_d256".into(), hq_fast_src);
        sources.insert("hadamard_quantize_kv_fast_d512".into(), hq_fast_src);
        // Track B (iter-21): higher-bit (5/6-bit) quantize kernels (byte-packed)
        sources.insert("hadamard_quantize_kv_hb_d256".into(), hq_fast_src);
        sources.insert("hadamard_quantize_kv_hb_d512".into(), hq_fast_src);

        // iter-20 Leg F: TQ KV dequantize kernel (nibbles+norms → F32)
        let tq_dq_src: &'static str = include_str!("shaders/tq_dequantize_kv.metal");
        sources.insert("tq_dequantize_kv".into(), tq_dq_src);
        // Track B (iter-21): higher-bit dequantize kernel (byte-packed indices)
        sources.insert("tq_dequantize_hb_kv".into(), tq_dq_src);

        // iter-24: native higher-bit (5/6/8-bit) TQ SDPA kernel (byte-packed K/V)
        let tq_hb_src: &'static str = include_str!("shaders/flash_attn_vec_tq_hb.metal");
        sources.insert("flash_attn_vec_tq_hb_dk256".into(), tq_hb_src);
        sources.insert("flash_attn_vec_tq_hb_dk512".into(), tq_hb_src);

        // GPU sampling kernels — eliminate logits readback (Phase 6)
        let argmax_src: &'static str = include_str!("shaders/argmax.metal");
        sources.insert("argmax_f32".into(), argmax_src);
        let softmax_sample_src: &'static str =
            include_str!("shaders/softmax_sample.metal");
        sources.insert("softmax_sample_f32".into(), softmax_sample_src);
        // Top-K kernel for Q8 rerank: avoids full-logits readback.
        let top_k_src: &'static str = include_str!("shaders/top_k.metal");
        sources.insert("top_k_f32".into(), top_k_src);

        // MoE GPU routing + weighted reduce (ADR-013 P13.3 perf).
        // Replaces CPU softmax+topk round-trip and CPU weighted accumulate.
        let moe_stk_src: &'static str =
            include_str!("shaders/moe_softmax_topk.metal");
        sources.insert("moe_softmax_topk_f32".into(), moe_stk_src);
        let moe_wr_src: &'static str =
            include_str!("shaders/moe_weighted_reduce.metal");
        sources.insert("moe_weighted_reduce_f32".into(), moe_wr_src);
        let sdpa_decode_src: &'static str =
            include_str!("shaders/sdpa_decode.metal");
        sources.insert("sdpa_decode".into(), sdpa_decode_src);

        Self {
            cache: HashMap::new(),
            sources,
        }
    }

    /// Register a shader source at runtime (useful for testing and dynamic
    /// kernel generation).
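    ///
    /// A minimal sketch (the kernel name and MSL source here are hypothetical):
    ///
    /// ```ignore
    /// // The registered name must match the kernel function name in the MSL
    /// // source, since `get_pipeline` looks the function up by this name.
    /// registry.register_source("my_test_kernel", MY_TEST_MSL_SRC);
    /// let pipeline = registry.get_pipeline("my_test_kernel", device.metal_device())?;
    /// ```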
    pub fn register_source(&mut self, name: impl Into<String>, source: &'static str) {
        let name = name.into();
        // Invalidate any cached pipeline for this name since the source changed.
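        // Note: pipelines specialised with function constants are cached under
        // composite keys ("<name>|..."), which this exact-name removal does not
        // touch, so previously specialised variants are not invalidated here.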
        self.cache.remove(&name);
        self.sources.insert(name, source);
    }

    /// Get a compiled compute pipeline for the named kernel function.
    ///
    /// On first call for a given name, this compiles the MSL source into a
    /// Metal library, extracts the named function, and creates a
    /// `ComputePipelineState`.  Subsequent calls return the cached pipeline.
    ///
    /// # Errors
    ///
    /// * `MlxError::KernelNotFound` — no source registered for this name.
    /// * `MlxError::ShaderCompilationError` — MSL compilation or pipeline
    ///   creation failed.
    pub fn get_pipeline(
        &mut self,
        name: &str,
        device: &metal::DeviceRef,
    ) -> Result<&ComputePipelineState> {
        if !self.cache.contains_key(name) {
            // Slow path: compile the shader.
            let source = self.sources.get(name).ok_or_else(|| {
                MlxError::KernelNotFound(name.to_string())
            })?;

            let compile_opts = metal::CompileOptions::new();
            let library = device
                .new_library_with_source(source, &compile_opts)
                .map_err(|msg| MlxError::ShaderCompilationError {
                    name: name.to_string(),
                    message: msg,
                })?;

            let function = library
                .get_function(name, None)
                .map_err(|msg| MlxError::ShaderCompilationError {
                    name: name.to_string(),
                    message: msg,
                })?;

            let pipeline = device
                .new_compute_pipeline_state_with_function(&function)
                .map_err(|msg| MlxError::ShaderCompilationError {
                    name: name.to_string(),
                    message: msg,
                })?;

            self.cache.insert(name.to_string(), pipeline);
        }

        // At this point the pipeline is guaranteed to be in the cache.
        // We use `ok_or_else` instead of `expect` to satisfy the no-panic policy.
        self.cache.get(name).ok_or_else(|| {
            MlxError::KernelNotFound(name.to_string())
        })
    }

    /// Get a compiled compute pipeline for the named kernel, specialized with
    /// Metal function constants (both bool and i32 in one call).
    ///
    /// `bool_constants` contains `(index, value)` pairs mapping to
    /// `[[function_constant(index)]]` bool declarations in the MSL shader.
    /// `int_constants` contains `(index, value)` pairs mapping to
    /// `[[function_constant(index)]]` int (int32_t) declarations in the MSL
    /// shader.
    ///
    /// Pipelines are cached by a composite key:
    /// `"<name>|<index>:b<0|1>|...|<index>:i<value>|..."`.  The 'b' prefix
    /// marks bool entries and the 'i' prefix marks i32 entries, making the
    /// format unambiguous regardless of constant ordering.  Distinct
    /// `(name, constants)` combinations each compile to a separate pipeline;
    /// the slow compilation path runs at most once per unique combination.
    ///
    /// # Errors
    ///
    /// * `MlxError::KernelNotFound` — no source registered for this name.
    /// * `MlxError::ShaderCompilationError` — MSL compilation, function
    ///   specialisation, or pipeline creation failed.
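    ///
    /// # Example key
    ///
    /// A hypothetical kernel `my_kernel` specialized with bool constant
    /// `(3, true)` and int constant `(100, 4)` is cached under:
    ///
    /// ```text
    /// my_kernel|3:b1|100:i4
    /// ```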
    pub fn get_pipeline_with_constants(
        &mut self,
        name: &str,
        device: &metal::DeviceRef,
        bool_constants: &[(usize, bool)],
        int_constants: &[(usize, i32)],
    ) -> Result<&ComputePipelineState> {
        // Build a composite cache key so distinct constant combinations each
        // compile to their own pipeline.  Bool entries use the 'b' type marker
        // and i32 entries use 'i'; this prevents a collision between, e.g.,
        // bool index 5 value 1 and int index 5 value 1.
        let mut cache_key = name.to_string();
        for &(index, value) in bool_constants {
            cache_key.push('|');
            cache_key.push_str(&index.to_string());
            cache_key.push_str(if value { ":b1" } else { ":b0" });
        }
        for &(index, value) in int_constants {
            cache_key.push('|');
            cache_key.push_str(&index.to_string());
            cache_key.push(':');
            cache_key.push('i');
            cache_key.push_str(&value.to_string());
        }

        if !self.cache.contains_key(&cache_key) {
            // Slow path: compile the shader with function constant specialisation.
            let source = self.sources.get(name).ok_or_else(|| {
                MlxError::KernelNotFound(name.to_string())
            })?;

            let compile_opts = metal::CompileOptions::new();
            let library = device
                .new_library_with_source(source, &compile_opts)
                .map_err(|msg| MlxError::ShaderCompilationError {
                    name: name.to_string(),
                    message: msg,
                })?;

            // Build the FunctionConstantValues object with all bool and i32
            // constants.  Metal's set_constant_value_at_index reads the value
            // through a raw pointer; the pointed-to bytes must match the size
            // declared in the MSL shader (1 byte for bool, 4 bytes for int).
            let fcv = FunctionConstantValues::new();

            for &(index, value) in bool_constants {
                // MTLDataType::Bool = 53 (metal-rs argument.rs).
                // The Metal runtime reads it as an Objective-C BOOL (uint8_t).
                let v: u8 = if value { 1 } else { 0 };
                fcv.set_constant_value_at_index(
                    (&v as *const u8).cast::<std::ffi::c_void>(),
                    MTLDataType::Bool,
                    index as u64,
                );
            }

            for &(index, value) in int_constants {
                // MTLDataType::Int = 29 (metal-rs argument.rs).
                // The Metal runtime reads 4 bytes as a signed 32-bit integer,
                // matching the Metal shader type `constant int`.
                fcv.set_constant_value_at_index(
                    (&value as *const i32).cast::<std::ffi::c_void>(),
                    MTLDataType::Int,
                    index as u64,
                );
            }

            let function = library
                .get_function(name, Some(fcv))
                .map_err(|msg| MlxError::ShaderCompilationError {
                    name: name.to_string(),
                    message: msg,
                })?;

            let pipeline = device
                .new_compute_pipeline_state_with_function(&function)
                .map_err(|msg| MlxError::ShaderCompilationError {
                    name: name.to_string(),
                    message: msg,
                })?;

            self.cache.insert(cache_key.clone(), pipeline);
        }

        self.cache.get(&cache_key).ok_or_else(|| {
            MlxError::KernelNotFound(name.to_string())
        })
    }

    /// Get a compiled compute pipeline for the named kernel, specialized with
    /// Metal bool function constants.
    ///
    /// The `bool_constants` slice contains `(index, value)` pairs.  Each pair
    /// maps to a `[[function_constant(index)]]` declaration in the MSL shader.
    ///
    /// This is a thin wrapper around [`get_pipeline_with_constants`] that
    /// passes an empty `int_constants` slice.  Existing callers continue to
    /// work without modification; the cache-key format for pure-bool pipelines
    /// is compatible (bool entries carry the 'b' type marker, which is the
    /// only format ever written by this wrapper).
    ///
    /// # Errors
    ///
    /// * `MlxError::KernelNotFound` — no source registered for this name.
    /// * `MlxError::ShaderCompilationError` — MSL compilation, function
    ///   specialisation, or pipeline creation failed.
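    ///
    /// A minimal usage sketch (the constant index `20` is a hypothetical
    /// example; real indices come from the `[[function_constant(...)]]`
    /// declarations in the target shader):
    ///
    /// ```ignore
    /// let pipeline = registry.get_pipeline_with_bool_constants(
    ///     "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskfloat32",
    ///     device.metal_device(),
    ///     &[(20, true)],
    /// )?;
    /// ```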
    pub fn get_pipeline_with_bool_constants(
        &mut self,
        name: &str,
        device: &metal::DeviceRef,
        bool_constants: &[(usize, bool)],
    ) -> Result<&ComputePipelineState> {
        self.get_pipeline_with_constants(name, device, bool_constants, &[])
    }

    /// Check if a pipeline for the given name is already compiled and cached.
    pub fn is_cached(&self, name: &str) -> bool {
        self.cache.contains_key(name)
    }

    /// Number of compiled pipelines currently in the cache.
    pub fn cached_count(&self) -> usize {
        self.cache.len()
    }

    /// Number of registered shader sources.
    pub fn source_count(&self) -> usize {
        self.sources.len()
    }
}

impl Default for KernelRegistry {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal Metal shader that uses a single int function constant.
    ///
    /// The kernel writes the constant value N into the first element of the
    /// output buffer, allowing the test to verify that the Metal compiler
    /// actually sees distinct specialisations for N=4 and N=8.
    ///
    /// The shader is intentionally trivial — we only need it to *compile* with
    /// an int function constant; correctness of the kernel logic is not under
    /// test here.
    const INT_FC_TEST_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;

constant int test_N [[function_constant(100)]];

kernel void int_fc_test_kernel(
    device int* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]])
{
    if (tid == 0) {
        out[0] = test_N;
    }
}
"#;

    /// Verify that `get_pipeline_with_constants` produces distinct cached
    /// pipelines for different i32 function-constant values, and that
    /// `get_pipeline_with_bool_constants` (the backward-compat wrapper) still
    /// works correctly with the new 'b'-prefixed cache-key format.
    ///
    /// This test requires a real Metal device; if none is available (for
    /// example on a non-Apple platform), the initial `expect` fails the test.
    #[test]
    fn test_int_fc_distinct_pipelines_and_bool_compat() {
        let device = metal::Device::system_default()
            .expect("no Metal device — run on Apple Silicon or x86 Mac with Metal support");

        let mut registry = KernelRegistry::new();

        // Register the inline test shader under a name that cannot collide with
        // any production kernel.
        registry.register_source("int_fc_test_kernel", INT_FC_TEST_SHADER);

        // Compile with N=4.
        let p4_ptr = registry
            .get_pipeline_with_constants(
                "int_fc_test_kernel",
                &device,
                &[],                  // no bool constants
                &[(100, 4_i32)],      // int constant index 100 = 4
            )
            .expect("pipeline N=4 should compile") as *const _;

        // `new()` registers sources but compiles nothing, so the cache holds
        // exactly the pipelines compiled so far in this test.  Record the
        // count after the N=4 compile for the growth assertions below.
        let count_after_n4 = registry.cached_count();

        // Compile with N=8 — must produce a SEPARATE pipeline.
        let p8_ptr = registry
            .get_pipeline_with_constants(
                "int_fc_test_kernel",
                &device,
                &[],
                &[(100, 8_i32)],
            )
            .expect("pipeline N=8 should compile") as *const _;

        // Cache must have grown by exactly 1.
        assert_eq!(
            registry.cached_count(),
            count_after_n4 + 1,
            "N=8 must produce a new cache entry"
        );

        // The two pipelines must be distinct objects in the cache.
        assert_ne!(
            p4_ptr, p8_ptr,
            "N=4 and N=8 specialisations must be separate ComputePipelineState objects"
        );

        // A second call with N=4 must return the SAME pipeline (cache hit, no
        // new compilation).
        let p4_again_ptr = registry
            .get_pipeline_with_constants(
                "int_fc_test_kernel",
                &device,
                &[],
                &[(100, 4_i32)],
            )
            .expect("pipeline N=4 cache hit should succeed") as *const _;

        assert_eq!(
            registry.cached_count(),
            count_after_n4 + 1,
            "repeated N=4 call must be a cache hit, not a new entry"
        );
        assert_eq!(
            p4_ptr, p4_again_ptr,
            "repeated N=4 call must return the same pipeline pointer"
        );

        // Verify backward compatibility: get_pipeline_with_bool_constants must
        // still route through get_pipeline_with_constants and produce a cached
        // pipeline without panicking.
        //
        // The wrapper is called with an empty bool slice against a bare shader
        // that declares no function constants, so only the call path and the
        // cache-key format are exercised here.
        const BARE_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void bare_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
    if (tid == 0) { out[0] = 42; }
}
"#;
        registry.register_source("bare_kernel", BARE_SHADER);

        let count_before_bool = registry.cached_count();
        let _bool_pipeline = registry
            .get_pipeline_with_bool_constants("bare_kernel", &device, &[])
            .expect("bool-constants wrapper with empty slice must succeed");

        assert_eq!(
            registry.cached_count(),
            count_before_bool + 1,
            "bool-constants wrapper must insert one new cache entry"
        );
    }
}