// mlx_native/kernel_registry.rs
//! [`KernelRegistry`] — lazy compilation and caching of Metal compute pipelines.
//!
//! MSL shader source is embedded at compile time via `include_str!`.  On first
//! access, the source is compiled into a Metal library, the named function is
//! extracted, and a `ComputePipelineState` is created and cached.  Subsequent
//! calls return the cached pipeline.
use std::collections::HashMap;

use metal::{ComputePipelineDescriptor, ComputePipelineState, FunctionConstantValues, MTLDataType};

use crate::error::{MlxError, Result};

// MTLDataType numeric values (from metal-rs argument.rs, confirmed in Apple Metal spec):
//   Int  = 29
//   Bool = 53
// These are used when calling set_constant_value_at_index so the Metal runtime
// knows how wide each constant value is.
20/// Registry that lazily compiles and caches Metal compute pipelines from
21/// embedded MSL source.
22///
23/// # Usage
24///
25/// ```ignore
26/// let mut registry = KernelRegistry::new();
27/// let pipeline = registry.get_pipeline("elementwise_add", device.metal_device())?;
28/// encoder.encode(&pipeline, &buffers, grid, tg);
29/// ```
30///
31/// # Thread Safety
32///
33/// `KernelRegistry` is **not** `Sync` by default (it uses `&mut self` for
34/// `get_pipeline` to allow mutable cache insertion).  If you need concurrent
35/// access, wrap it in a `Mutex` or use one registry per thread.
36pub struct KernelRegistry {
37    /// Cached pipelines keyed by kernel function name.
38    cache: HashMap<String, ComputePipelineState>,
39    /// MSL source text keyed by kernel function name.
40    ///
41    /// Populated at construction time with all embedded shader sources.
42    sources: HashMap<String, &'static str>,
43}
44
45impl KernelRegistry {
46    /// Create a new registry with all embedded shader sources pre-registered.
47    ///
48    /// No compilation happens here — shaders are compiled lazily on first use.
49    pub fn new() -> Self {
50        let mut sources = HashMap::new();
51
52        // Register embedded shader sources.
53        sources.insert(
54            "placeholder".into(),
55            include_str!("shaders/placeholder.metal"),
56        );
57        sources.insert(
58            "quantized_matmul".into(),
59            include_str!("shaders/quantized_matmul.metal"),
60        );
61        sources.insert(
62            "quantized_matmul_simd".into(),
63            include_str!("shaders/quantized_matmul.metal"),
64        );
65        sources.insert(
66            "quantized_matmul_simd_bf16".into(),
67            include_str!("shaders/quantized_matmul.metal"),
68        );
69        sources.insert(
70            "quantized_matmul_simd_bf16_expert".into(),
71            include_str!("shaders/quantized_matmul.metal"),
72        );
73
74        // GGML block-format quantized mat-vec kernels (ADR-006 Phase 3)
75        let ggml_src: &'static str =
76            include_str!("shaders/quantized_matmul_ggml.metal");
77        sources.insert("kernel_mul_mv_q4_0_f32".into(), ggml_src);
78        sources.insert("kernel_mul_mv_q8_0_f32".into(), ggml_src);
79        sources.insert("kernel_mul_mv_q6_K_f32".into(), ggml_src);
80        // ADR-013 P7 — Q4_K dense decode mat-vec (port of llama.cpp's
81        // kernel_mul_mv_q4_K_f32 at ggml-metal.metal:7715-7821).
82        sources.insert("kernel_mul_mv_q4_K_f32".into(), ggml_src);
83
84        // GGML block-format quantized matrix-matrix kernels
85        // (ADR-011 Phase 3 Wave P3a: port of llama.cpp's kernel_mul_mm_<q>_f32).
86        // Used at prefill m > 8 to reuse each weight tile across a 32-row
87        // block via threadgroup-staged simdgroup MMA, instead of re-reading
88        // every block per prompt-token as the mv kernel does.
89        let ggml_mm_src: &'static str =
90            include_str!("shaders/quantized_matmul_mm.metal");
91        sources.insert("kernel_mul_mm_q4_0_f32".into(), ggml_mm_src);
92        sources.insert("kernel_mul_mm_q8_0_f32".into(), ggml_mm_src);
93        sources.insert("kernel_mul_mm_q6_K_f32".into(), ggml_mm_src);
94
95        // GGML block-format quantized matrix-matrix kernels — tensor API
96        // variant (ADR-011 Phase 3 Wave P3b-tensor: port of llama.cpp's
97        // kernel_mul_mm_impl `#ifdef GGML_METAL_HAS_TENSOR` branch).
98        // Uses Apple's MetalPerformancePrimitives `tensor_ops::matmul2d`
99        // primitive which on M3+ dispatches to hardware tensor cores for
100        // 2-3x the effective FLOP throughput vs the simdgroup MMA path.
101        // Only compiled on devices where the tensor API is available; the
102        // kernel_registry's runtime-probe (see MlxDevice::has_tensor) gates
103        // compilation so non-tensor devices transparently fall back to the
104        // non-tensor `kernel_mul_mm_<q>_f32` kernels.
105        let ggml_mm_tensor_src: &'static str =
106            include_str!("shaders/quantized_matmul_mm_tensor.metal");
107        sources.insert("kernel_mul_mm_q4_0_tensor_f32".into(), ggml_mm_tensor_src);
108        sources.insert("kernel_mul_mm_q4_0_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
109        sources.insert("kernel_mul_mm_q6_K_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
110        sources.insert("kernel_mul_mm_q8_0_tensor_f32".into(), ggml_mm_tensor_src);
111        sources.insert("kernel_mul_mm_q6_K_tensor_f32".into(), ggml_mm_tensor_src);
112
113        // Dense bf16×f32 → f32 tensor-API matmul (non-flash-attention
114        // prefill Q@K^T and scores@V, modeled on llama.cpp's
115        // kernel_mul_mm_bf16_f32 with the GGML_METAL_HAS_TENSOR branch
116        // active).  Tile geometry and write-back identical to the
117        // quantized tensor kernel; only the A-stage copy (bfloat →
118        // bfloat, no dequantize) differs.
119        let dense_mm_bf16_tensor_src: &'static str =
120            include_str!("shaders/dense_mm_bf16_tensor.metal");
121        sources.insert("hf2q_dense_mm_bf16_f32_tensor".into(), dense_mm_bf16_tensor_src);
122
123        // Dense f32×f32 → f32 tensor-API matmul (F32-everywhere
124        // sibling of dense_mm_bf16_tensor).  Used by hf2q's ADR-005
125        // iter-118 BF16-vs-F32 ViT attention A/B diagnostic to remove
126        // the BF16 K-stage cast as a confounding variable.  Port of
127        // llama.cpp's kernel_mul_mm_f32_f32 specialization
128        // (ggml-metal.metal:10098) on the GGML_METAL_HAS_TENSOR
129        // branch.  Same tile geometry (NR0=64 NR1=32 NK=32) but
130        // float-everywhere shmem staging.
131        let dense_mm_f32_f32_tensor_src: &'static str =
132            include_str!("shaders/dense_mm_f32_f32.metal");
133        sources.insert("hf2q_dense_mm_f32_f32_tensor".into(), dense_mm_f32_f32_tensor_src);
134
135        // Dense f16×f32 → f32 tensor-API matmul (F16-staging sibling
136        // of dense_mm_bf16_tensor).  Used by hf2q's ADR-005 Phase 2c
137        // iter-128 gemma4v ViT precision-parity path: every mmproj
138        // weight is stored as F16 in GGUF, peer's `kernel_mul_mm_f16_f32`
139        // (`ggml-metal.metal:10099`) stages BOTH A and B as `half` in
140        // shmem and computes on `simdgroup_half8x8`.  Matches peer
141        // per-element rounding budget exactly (10-bit mantissa vs
142        // BF16's 7-bit), closing the 1.16x/block cascade compound that
143        // iter-127 numerically bisected to BF16 staging.  Same tile
144        // geometry as the BF16 sibling (NR0=64 NR1=32 NK=32, 8 KB
145        // shmem) — half and bfloat share 16-bit storage.
146        let dense_mm_f16_tensor_src: &'static str =
147            include_str!("shaders/dense_mm_f16_tensor.metal");
148        sources.insert("hf2q_dense_mm_f16_f32_tensor".into(), dense_mm_f16_tensor_src);
149
150        // Dense bf16×f32 → f32 GEMV (matrix-vector multiply) — optimized
151        // for M=1 single-token decode.  Port of llama.cpp's
152        // kernel_mul_mv_bf16_f32_4 (bfloat4-vectorized GEMV kernel).
153        // Used in apply_linear_projection_f32 when seq_len=1 and the
154        // weight matrix is BF16, replacing the MM kernel (~2× faster for
155        // M=1 due to better memory bandwidth utilization per thread).
156        let dense_gemv_bf16_src: &'static str =
157            include_str!("shaders/dense_gemv_bf16.metal");
158        sources.insert("hf2q_dense_gemv_bf16_f32_4".into(), dense_gemv_bf16_src);
159
160        // Fused scale-mask-softmax for the non-flash-attention prefill
161        // path.  One row-local threadgroup per (head, query) pair
162        // replaces three separate dispatches (scale, mask-add, softmax);
163        // reads a bf16 mask (-INF at masked positions, matching
164        // flash_attn_prefill_mask.metal) that is shared across heads.
165        let scale_mask_softmax_src: &'static str =
166            include_str!("shaders/scale_mask_softmax.metal");
167        sources.insert("scale_mask_softmax_f32".into(), scale_mask_softmax_src);
168
169        // Expert-routed (MoE) quantized matmul kernel (Story 2.1)
170        sources.insert(
171            "quantized_matmul_id".into(),
172            include_str!("shaders/quantized_matmul_id.metal"),
173        );
174
175        // Expert-routed (MoE) GGML block-format quantized matmul kernels
176        let ggml_id_src: &'static str =
177            include_str!("shaders/quantized_matmul_id_ggml.metal");
178        sources.insert("kernel_mul_mv_id_q4_0_f32".into(), ggml_id_src);
179        sources.insert("kernel_mul_mv_id_q8_0_f32".into(), ggml_id_src);
180        // ADR-013 P7 — Q4_K MoE expert-routed mat-vec (port of
181        // llama.cpp's kernel_mul_mv_id_q4_K_f32 at ggml-metal.metal:10349).
182        sources.insert("kernel_mul_mv_id_q4_K_f32".into(), ggml_id_src);
183        sources.insert("kernel_mul_mv_id_q5_K_f32".into(), ggml_id_src);
184        sources.insert("kernel_mul_mv_id_q6_K_f32".into(), ggml_id_src);
185        // Fused-SwiGLU mv_id variants (ADR-012 §Optimize / Task #15):
186        // computes y[r][n] = sum_k(dequant(W[expert][n][k]) * silu(gate[r][k]) * up[r][k])
187        // in one dispatch — replaces silu_mul + expert_down sequence.
188        sources.insert("kernel_mul_mv_id_q4_0_f32_swiglu".into(), ggml_id_src);
189
190        // Expert-routed (MoE) GGML block-format QUANTIZED MATRIX-MATRIX kernels
191        // (ADR-011 Phase 3 Wave P3a: port of llama.cpp's
192        // `kernel_mul_mm_id_map0_ne20_N` + `kernel_mul_mm_id_<q>_f32`).
193        // Two-stage dispatch: map0 regroups the token-to-expert table into
194        // per-expert routed-token lists, then mm_id stages a 64x32 expert
195        // weight tile into threadgroup shmem and reuses it across a 32-row
196        // block of that expert's routed tokens.
197        let ggml_id_mm_src: &'static str =
198            include_str!("shaders/quantized_matmul_id_mm.metal");
199        sources.insert("kernel_mul_mm_id_map0_ne20_1".into(), ggml_id_mm_src);
200        sources.insert("kernel_mul_mm_id_map0_ne20_8".into(), ggml_id_mm_src);
201        sources.insert("kernel_mul_mm_id_q4_0_f32".into(), ggml_id_mm_src);
202        sources.insert("kernel_mul_mm_id_q8_0_f32".into(), ggml_id_mm_src);
203        sources.insert("kernel_mul_mm_id_q6_K_f32".into(), ggml_id_mm_src);
204        // ADR-013 P16 — Q4_K mm_id (port of llama.cpp ggml-metal.metal:10169).
205        sources.insert("kernel_mul_mm_id_q4_K_f32".into(), ggml_id_mm_src);
206
207        // MoE-routed quantized matrix-matrix kernels — tensor API variant
208        // (ADR-011 Phase 3 Wave P3b-tensor).  Uses the MPP tensor_ops
209        // matmul2d primitive for hardware-tensor-core MMA on M3+.  Only
210        // the mm_id kernel is ported — map0 is a short pre-pass (not
211        // matmul) and continues to use the simdgroup version.
212        let ggml_id_mm_tensor_src: &'static str =
213            include_str!("shaders/quantized_matmul_id_mm_tensor.metal");
214        sources.insert("kernel_mul_mm_id_q4_0_tensor_f32".into(), ggml_id_mm_tensor_src);
215        sources.insert("kernel_mul_mm_id_q8_0_tensor_f32".into(), ggml_id_mm_tensor_src);
216        sources.insert("kernel_mul_mm_id_q6_K_tensor_f32".into(), ggml_id_mm_tensor_src);
217        // ADR-013 P16 — Q4_K tensor-API mm_id.
218        sources.insert("kernel_mul_mm_id_q4_K_tensor_f32".into(), ggml_id_mm_tensor_src);
219
220        // Embedding kernels (Story 1.5)
221        let embedding_src: &'static str = include_str!("shaders/embedding.metal");
222        sources.insert("embedding_gather_4bit".into(), embedding_src);
223        sources.insert("embedding_gather_6bit".into(), embedding_src);
224
225        // MoE gate kernel (Story 1.5)
226        let moe_gate_src: &'static str = include_str!("shaders/moe_gate.metal");
227        sources.insert("moe_gate".into(), moe_gate_src);
228
229        // MoE dispatch kernels (Story 1.5)
230        let moe_dispatch_src: &'static str = include_str!("shaders/moe_dispatch.metal");
231        sources.insert("fused_gelu_mul".into(), moe_dispatch_src);
232        sources.insert("moe_swiglu_fused".into(), moe_dispatch_src);
233        sources.insert("moe_swiglu_batch".into(), moe_dispatch_src);
234        sources.insert("moe_swiglu_seq".into(), moe_dispatch_src);
235        sources.insert("moe_accumulate".into(), moe_dispatch_src);
236        sources.insert("moe_weighted_sum".into(), moe_dispatch_src);
237        sources.insert("moe_weighted_sum_seq".into(), moe_dispatch_src);
238        sources.insert("zero_buffer".into(), moe_dispatch_src);
239        sources.insert("naive_matvec_f32".into(), moe_dispatch_src);
240        sources.insert("moe_gather_topk_weights".into(), moe_dispatch_src);
241        // bf16 variants (Phase 2 bf16 activation path)
242        sources.insert("fused_gelu_mul_bf16".into(), moe_dispatch_src);
243        sources.insert("moe_swiglu_seq_bf16".into(), moe_dispatch_src);
244        sources.insert("moe_weighted_sum_seq_bf16_input".into(), moe_dispatch_src);
245
246        // Batched KV cache copy kernels
247        let kv_cache_src: &'static str = include_str!("shaders/kv_cache_copy.metal");
248        sources.insert("kv_cache_copy_batch_f32".into(), kv_cache_src);
249        sources.insert("kv_cache_copy_batch_f32_to_f16".into(), kv_cache_src);
250        sources.insert("kv_cache_copy_seq_f32".into(), kv_cache_src);
251        sources.insert("kv_cache_copy_seq_f32_to_f16".into(), kv_cache_src);
252        // Wave P4.11 — fused K+V copy variants
253        sources.insert("kv_cache_copy_seq_f32_kv_dual".into(), kv_cache_src);
254        sources.insert("kv_cache_copy_seq_f32_to_f16_kv_dual".into(), kv_cache_src);
255        // bf16-source KV cache copy (Phase 2 bf16 activation path)
256        sources.insert("kv_cache_copy_seq_bf16".into(), kv_cache_src);
257
258        // Elementwise and transpose kernels (Story 1.5)
259        let elementwise_src: &'static str = include_str!("shaders/elementwise.metal");
260        sources.insert("elementwise_add_f32".into(), elementwise_src);
261        sources.insert("elementwise_add_f16".into(), elementwise_src);
262        sources.insert("elementwise_mul_f32".into(), elementwise_src);
263        sources.insert("elementwise_mul_f16".into(), elementwise_src);
264        sources.insert("elementwise_add_bf16".into(), elementwise_src);
265        sources.insert("elementwise_mul_bf16".into(), elementwise_src);
266        sources.insert("cast_f16_to_f32".into(), elementwise_src);
267        sources.insert("cast_f32_to_f16".into(), elementwise_src);
268        sources.insert("cast_bf16_to_f32".into(), elementwise_src);
269        sources.insert("cast_f32_to_bf16".into(), elementwise_src);
270        sources.insert("scalar_mul_bf16".into(), elementwise_src);
271        sources.insert("scalar_mul_f32".into(), elementwise_src);
272        sources.insert("embedding_gather_scale_f32".into(), elementwise_src);
273        sources.insert("embedding_gather_scale_batch_f32".into(), elementwise_src);
274        sources.insert("permute_021_bf16".into(), elementwise_src);
275        sources.insert("transpose_last2_bf16".into(), elementwise_src);
276        sources.insert("transpose_last2_f16".into(), elementwise_src);
277        sources.insert("permute_021_f32".into(), elementwise_src);
278        sources.insert("permute_021_bf16_to_f32".into(), elementwise_src);
279        sources.insert("transpose_2d_f32".into(), elementwise_src);
280        sources.insert("transpose_2d_f16".into(), elementwise_src);
281
282        // Attention kernels (Story 1.3)
283        let sdpa_src: &'static str = include_str!("shaders/sdpa.metal");
284        sources.insert("sdpa".into(), sdpa_src);
285        sources.insert("sdpa_bf16".into(), sdpa_src);
286        let sdpa_sliding_src: &'static str = include_str!("shaders/sdpa_sliding.metal");
287        sources.insert("sdpa_sliding".into(), sdpa_sliding_src);
288        sources.insert("sdpa_sliding_bf16".into(), sdpa_sliding_src);
289
290        // Flash-attention tiled prefill kernel (ADR-011 Phase 1).
291        // Ten entry points; all backed by the same shader source.
292        // Pipelines are compiled with function constants via
293        // `get_pipeline_with_bool_constants` — not `get_pipeline`.
294        let flash_attn_prefill_src: &'static str =
295            include_str!("shaders/flash_attn_prefill.metal");
296        // D=256 variants (BQ=32, BK=16, WM=4, WN=1 — 128 threads/threadgroup)
297        sources.insert(
298            "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskfloat32".into(),
299            flash_attn_prefill_src,
300        );
301        sources.insert(
302            "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
303            flash_attn_prefill_src,
304        );
305        sources.insert(
306            "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbfloat16".into(),
307            flash_attn_prefill_src,
308        );
309        sources.insert(
310            "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
311            flash_attn_prefill_src,
312        );
313        sources.insert(
314            "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskfloat16".into(),
315            flash_attn_prefill_src,
316        );
317        sources.insert(
318            "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
319            flash_attn_prefill_src,
320        );
321        // D=512 variants (BQ=8, BK=8, WM=1, WN=1 — 32 threads/threadgroup)
322        // NOTE: f32 at D=512 is NOT instantiated — threadgroup memory exceeds
323        // the 32 KB Metal limit (candle sdpa.rs:86-94).
324        sources.insert(
325            "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbfloat16".into(),
326            flash_attn_prefill_src,
327        );
328        sources.insert(
329            "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
330            flash_attn_prefill_src,
331        );
332        sources.insert(
333            "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskfloat16".into(),
334            flash_attn_prefill_src,
335        );
336        sources.insert(
337            "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
338            flash_attn_prefill_src,
339        );
340
341        // Flash attention vector kernels — SIMD-vectorized decode-path SDPA
342        // (ported from llama.cpp flash_attn_ext_vec)
343        let flash_attn_vec_src: &'static str =
344            include_str!("shaders/flash_attn_vec.metal");
345        sources.insert("flash_attn_vec_dk256".into(), flash_attn_vec_src);
346        sources.insert("flash_attn_vec_dk512".into(), flash_attn_vec_src);
347        sources.insert("flash_attn_vec_reduce_dk256".into(), flash_attn_vec_src);
348        sources.insert("flash_attn_vec_reduce_dk512".into(), flash_attn_vec_src);
349        // F16 KV variants (Phase 4a)
350        sources.insert("flash_attn_vec_f16kv_dk256".into(), flash_attn_vec_src);
351        sources.insert("flash_attn_vec_f16kv_dk512".into(), flash_attn_vec_src);
352
353        // RoPE, normalization, activation kernels (Story 1.4)
354        let rope_src: &'static str = include_str!("shaders/rope.metal");
355        sources.insert("rope_f32".into(), rope_src);
356        sources.insert("rope_f16".into(), rope_src);
357        sources.insert("rope_bf16".into(), rope_src);
358        sources.insert("rope_neox_bf16".into(), rope_src);
359        sources.insert("rope_neox_f32".into(), rope_src);
360        let rms_norm_src: &'static str = include_str!("shaders/rms_norm.metal");
361        sources.insert("rms_norm_f32".into(), rms_norm_src);
362        sources.insert("rms_norm_f16".into(), rms_norm_src);
363        sources.insert("rms_norm_bf16".into(), rms_norm_src);
364        sources.insert("rms_norm_no_scale_bf16".into(), rms_norm_src);
365        sources.insert("rms_norm_no_scale_f32".into(), rms_norm_src);
366        sources.insert("rms_norm_no_scale_f32_dual".into(), rms_norm_src);
367        sources.insert("rms_norm_f32_triple".into(), rms_norm_src);
368        sources.insert("fused_post_attn_triple_norm_f32".into(), rms_norm_src);
369        sources.insert("rms_norm_no_scale_f32_dual_perm".into(), rms_norm_src);
370        // Fused RMS norm + elementwise multiply kernels (Phase 4e.2)
371        sources.insert("rms_norm_mul_f32".into(), rms_norm_src);
372        sources.insert("rms_norm_mul_f16".into(), rms_norm_src);
373        sources.insert("rms_norm_mul_bf16".into(), rms_norm_src);
374        // L2 norm kernels (ADR-013 Decision 3 — Gated DeltaNet Q/K norm)
375        let l2_norm_src: &'static str = include_str!("shaders/l2_norm.metal");
376        sources.insert("l2_norm_f32".into(), l2_norm_src);
377        sources.insert("l2_norm_f16".into(), l2_norm_src);
378        sources.insert("l2_norm_bf16".into(), l2_norm_src);
379        // Cumulative-sum kernels (ADR-013 Decision 4 — DeltaNet decay-mask base)
380        let cumsum_src: &'static str = include_str!("shaders/cumsum.metal");
381        sources.insert("cumsum_f32".into(), cumsum_src);
382        sources.insert("cumsum_bf16".into(), cumsum_src);
383        // SSM conv kernels (ADR-013 Decision 7 — DeltaNet 1D causal conv + SiLU)
384        let ssm_conv_src: &'static str = include_str!("shaders/ssm_conv.metal");
385        sources.insert("ssm_conv_forward_f32".into(), ssm_conv_src);
386        sources.insert("ssm_conv_forward_bf16".into(), ssm_conv_src);
387        sources.insert("ssm_conv_state_update_f32".into(), ssm_conv_src);
388        sources.insert("ssm_conv_state_update_bf16".into(), ssm_conv_src);
389        // Tri-solve kernels (ADR-013 Decision 5 — chunked DeltaNet debug path)
390        let tri_solve_src: &'static str = include_str!("shaders/tri_solve.metal");
391        sources.insert("tri_solve_lower_unit_f32".into(), tri_solve_src);
392        sources.insert("tri_solve_lower_unit_bf16".into(), tri_solve_src);
393        // Rope-multi kernels (ADR-013 Decision 10 — IMROPE for Qwen3.5)
394        let rope_multi_src: &'static str = include_str!("shaders/rope_multi.metal");
395        sources.insert("rope_multi_f32".into(), rope_multi_src);
396        sources.insert("rope_multi_bf16".into(), rope_multi_src);
397        // Gated DeltaNet fused kernel (ADR-013 Decision 6 — centerpiece)
398        let gdn_src: &'static str = include_str!("shaders/gated_delta_net.metal");
399        sources.insert("gated_delta_net_f32".into(), gdn_src);
400        // ADR-015 iter56 — decode-only `simd_sum` variant. Three NSG-templated
401        // host names share the same source; selection is by D_k via
402        // `dispatch_gated_delta_net_decode`. Drop-in for the fused kernel
403        // above when n_tokens=1.
404        let gdn_decode_src: &'static str =
405            include_str!("shaders/gated_delta_net_decode.metal");
406        sources.insert("gated_delta_net_decode_f32_1".into(), gdn_decode_src);
407        sources.insert("gated_delta_net_decode_f32_2".into(), gdn_decode_src);
408        sources.insert("gated_delta_net_decode_f32_4".into(), gdn_decode_src);
409        // Wave 5b — chunk-parallel inter-chunk state-recurrence kernel
410        // (the one new kernel in the chunk-parallel pipeline; spec source:
411        // arXiv 2412.06464 §4 + FLA chunk_delta_h.py:43-298).
412        let gdn_chunk_src: &'static str =
413            include_str!("shaders/gated_delta_net_chunk.metal");
414        sources.insert(
415            "gated_delta_net_chunk_inter_state_bf16".into(),
416            gdn_chunk_src,
417        );
418        // Wave 5b.1 iter 2 — chunk_scaled_dot_kkt kernel (input-side of
419        // the chunk pipeline; spec source: FLA chunk_scaled_dot_kkt.py:36-99).
420        let gdn_kkt_src: &'static str =
421            include_str!("shaders/gated_delta_net_kkt.metal");
422        sources.insert("gated_delta_net_kkt_bf16".into(), gdn_kkt_src);
423        // Wave 5b.1 iter 2 — recompute_w_u_fwd kernel (applies post-solve A
424        // to (β·v) and (β·k·exp(g)) to produce w and u; spec source: FLA
425        // wy_fast.py:29-117).
426        let gdn_recompute_wu_src: &'static str =
427            include_str!("shaders/gated_delta_net_recompute_wu.metal");
428        sources.insert(
429            "gated_delta_net_recompute_wu_bf16".into(),
430            gdn_recompute_wu_src,
431        );
432        // Wave 5b.1 iter 3 — chunk_fwd_o kernel (per-chunk output: closes
433        // the chunk pipeline; spec source: FLA chunk_o.py:42-138).
434        let gdn_chunk_o_src: &'static str =
435            include_str!("shaders/gated_delta_net_chunk_o.metal");
436        sources.insert("gated_delta_net_chunk_o_bf16".into(), gdn_chunk_o_src);
437        // Wave 5b.1 iter 4 — orchestrator helper kernels:
438        //   chunk_local_cumsum_g_f32      — per-chunk prefix sum on g [B, T, H]
439        //   chunk_tri_solve_invert_f32    — per-chunk-block (I + A_strict)^-1
440        //                                   on FLA's [B, T, H, BT] layout.
441        let chunk_local_cumsum_g_src: &'static str =
442            include_str!("shaders/chunk_local_cumsum_g.metal");
443        sources.insert(
444            "chunk_local_cumsum_g_f32".into(),
445            chunk_local_cumsum_g_src,
446        );
447        let chunk_tri_solve_invert_src: &'static str =
448            include_str!("shaders/chunk_gated_delta_rule_tri_solve_invert.metal");
449        sources.insert(
450            "chunk_tri_solve_invert_f32".into(),
451            chunk_tri_solve_invert_src,
452        );
453        // Sigmoid-gated elementwise multiply (ADR-013 Decision 9 — full-attn output gate)
454        let sigmoid_mul_src: &'static str = include_str!("shaders/sigmoid_mul.metal");
455        sources.insert("sigmoid_mul_f32".into(), sigmoid_mul_src);
456        sources.insert("sigmoid_mul_bf16".into(), sigmoid_mul_src);
457        let silu_mul_src: &'static str = include_str!("shaders/silu_mul.metal");
458        sources.insert("silu_mul_f32".into(), silu_mul_src);
459        let compute_g_beta_src: &'static str = include_str!("shaders/compute_g_beta.metal");
460        sources.insert("compute_g_beta_f32".into(), compute_g_beta_src);
461        let ssm_norm_gate_src: &'static str = include_str!("shaders/ssm_norm_gate.metal");
462        sources.insert("ssm_norm_gate_f32".into(), ssm_norm_gate_src);
463        let gelu_src: &'static str = include_str!("shaders/gelu.metal");
464        sources.insert("gelu_f32".into(), gelu_src);
465        sources.insert("gelu_f16".into(), gelu_src);
466        sources.insert("gelu_bf16".into(), gelu_src);
467        let softmax_src: &'static str = include_str!("shaders/softmax.metal");
468        sources.insert("softmax_f32".into(), softmax_src);
469        sources.insert("softmax_f16".into(), softmax_src);
470        sources.insert("softmax_bf16".into(), softmax_src);
471        let softcap_src: &'static str = include_str!("shaders/softcap.metal");
472        sources.insert("softcap_f32".into(), softcap_src);
473        sources.insert("softcap_f16".into(), softcap_src);
474        sources.insert("softcap_bf16".into(), softcap_src);
475
476        // Fused norm-add kernels — Gemma4 post-attention / post-FFN ordering:
477        //   normed = rms_norm(input, weight, eps);  output = residual + normed
478        let fused_norm_add_src: &'static str =
479            include_str!("shaders/fused_norm_add_bf16.metal");
480        sources.insert("fused_norm_add_bf16".into(), fused_norm_add_src);
481        sources.insert("fused_norm_add_no_weight_bf16".into(), fused_norm_add_src);
482
483        // Fused head-norm + RoPE f32 kernel — replaces separate rms_norm + rope_neox_f32
484        let fused_hnr_f32_src: &'static str =
485            include_str!("shaders/fused_head_norm_rope_f32.metal");
486        sources.insert("fused_head_norm_rope_f32".into(), fused_hnr_f32_src);
487
488        // Fused head-norm + RoPE bf16 kernels (single-token + batch prefill)
489        // Both entry points live in the same .metal file.
490        let fused_hnr_bf16_src: &'static str =
491            include_str!("shaders/fused_head_norm_rope_bf16.metal");
492        sources.insert("fused_head_norm_rope_bf16".into(), fused_hnr_bf16_src);
493        sources.insert("fused_head_norm_rope_batch_bf16".into(), fused_hnr_bf16_src);
494
495        // Fused norm-add f32 kernels — post-attention / post-FFN / end-of-layer
496        let fused_norm_add_f32_src: &'static str =
497            include_str!("shaders/fused_norm_add_f32.metal");
498        sources.insert("fused_norm_add_f32".into(), fused_norm_add_f32_src);
499        sources.insert("fused_residual_norm_f32".into(), fused_norm_add_f32_src);
500        sources.insert("fused_residual_norm_scalar_f32".into(), fused_norm_add_f32_src);
501        sources.insert("fused_moe_routing_f32".into(), fused_norm_add_f32_src);
502        sources.insert("fused_moe_routing_batch_f32".into(), fused_norm_add_f32_src);
503        sources.insert("fused_norm_add_scalar_f32".into(), fused_norm_add_f32_src);
504        sources.insert("fused_moe_wsum_norm_add_f32".into(), fused_norm_add_f32_src);
505        sources.insert("fused_moe_wsum_dnorm_add_f32".into(), fused_norm_add_f32_src);
506
507        // Argsort kernel (Story 2.3) — MoE top-K routing
508        let argsort_src: &'static str = include_str!("shaders/argsort.metal");
509        sources.insert("argsort_desc_f32".into(), argsort_src);
510
511        // Gather / index_select kernel (Story 2.4)
512        let gather_src: &'static str = include_str!("shaders/gather.metal");
513        sources.insert("gather_f32".into(), gather_src);
514
515        // F32 KV cache copy kernel (Session merge S1+S2)
516        let kv_cache_copy_src: &'static str =
517            include_str!("shaders/kv_cache_copy.metal");
518        sources.insert("kv_cache_copy".into(), kv_cache_copy_src);
519        sources.insert("kv_cache_copy_f32".into(), kv_cache_copy_src);
520
521        // Strided copy kernel (Story 2.5)
522        let copy_src: &'static str = include_str!("shaders/copy.metal");
523        sources.insert("strided_copy_f32".into(), copy_src);
524        sources.insert("offset_copy_f32".into(), copy_src);
525
526        // Fused-QKV split kernel (ADR-005 W-5b.18 — replaces hf2q CPU
527        // download → triple-loop split → 3× upload round-trip in
528        // gpu_delta_net::layer_qkv_deinterleave).
529        let qkv_split_src: &'static str = include_str!("shaders/qkv_split.metal");
530        sources.insert("qkv_split_f32".into(), qkv_split_src);
531
532        // Tiled-GQA broadcast kernel (ADR-005 W-5b.19 — replaces hf2q CPU
533        // tiled-replicate at gpu_delta_net::apply_gated_delta_net_chunk
534        // GQA pre-expansion, ~497 ms / 10.4 ms-per-layer at PP4106).
535        let repeat_tiled_src: &'static str =
536            include_str!("shaders/repeat_tiled.metal");
537        sources.insert("repeat_tiled_f32".into(), repeat_tiled_src);
538
539        // Dense F16 GEMM kernel (Story 2.6) — lm_head projection
540        let dense_gemm_src: &'static str = include_str!("shaders/dense_gemm.metal");
541        sources.insert("dense_gemm_f16".into(), dense_gemm_src);
542        sources.insert("dense_matvec_f16".into(), dense_gemm_src);
543        sources.insert("dense_matvec_f16w_f32io".into(), dense_gemm_src);
544        // BF16-weight mat-vec: BF16 weights × F32 input → F32 output (decode lm_head)
545        sources.insert("dense_matvec_bf16w_f32io".into(), dense_gemm_src);
546        // Pure F32 mat-vec: F32 weights × F32 input → F32 output (decode lm_head)
547        sources.insert("dense_matvec_f32".into(), dense_gemm_src);
548
549        // Standalone FWHT for TurboQuant pre/post-rotation (SIMD shuffle, zero barriers)
550        let fwht_src: &'static str = include_str!("shaders/fwht_standalone.metal");
551        sources.insert("fwht_standalone_f32_d256".into(), fwht_src);
552        sources.insert("fwht_standalone_f32_d512".into(), fwht_src);
553        // ADR-007 iter-14 D1 SRHT variants: sign pre-mult (for Q) + sign undo (for output)
554        sources.insert("fwht_sign_premult_f32_d256".into(), fwht_src);
555        sources.insert("fwht_sign_premult_f32_d512".into(), fwht_src);
556        sources.insert("fwht_sign_undo_f32_d256".into(), fwht_src);
557        sources.insert("fwht_sign_undo_f32_d512".into(), fwht_src);
558
559        // Fast Hadamard quantize (SIMD shuffle, zero barriers)
560        let hq_fast_src: &'static str = include_str!("shaders/hadamard_quantize_kv_fast.metal");
561        sources.insert("hadamard_quantize_kv_fast_d256".into(), hq_fast_src);
562        sources.insert("hadamard_quantize_kv_fast_d512".into(), hq_fast_src);
563        // Track B (iter-21): higher-bit (5/6-bit) quantize kernels (byte-packed)
564        sources.insert("hadamard_quantize_kv_hb_d256".into(), hq_fast_src);
565        sources.insert("hadamard_quantize_kv_hb_d512".into(), hq_fast_src);
566
567        // iter-20 Leg F: TQ KV dequantize kernel (nibbles+norms → F32)
568        let tq_dq_src: &'static str = include_str!("shaders/tq_dequantize_kv.metal");
569        sources.insert("tq_dequantize_kv".into(), tq_dq_src);
570        // Track B (iter-21): higher-bit dequantize kernel (byte-packed indices)
571        sources.insert("tq_dequantize_hb_kv".into(), tq_dq_src);
572
573        // iter-24: native higher-bit (5/6/8-bit) TQ SDPA kernel (byte-packed K/V)
574        let tq_hb_src: &'static str = include_str!("shaders/flash_attn_vec_tq_hb.metal");
575        sources.insert("flash_attn_vec_tq_hb_dk256".into(), tq_hb_src);
576        sources.insert("flash_attn_vec_tq_hb_dk512".into(), tq_hb_src);
577
578        // GPU sampling kernels — eliminate logits readback (Phase 6)
579        let argmax_src: &'static str = include_str!("shaders/argmax.metal");
580        sources.insert("argmax_f32".into(), argmax_src);
581        let softmax_sample_src: &'static str =
582            include_str!("shaders/softmax_sample.metal");
583        sources.insert("softmax_sample_f32".into(), softmax_sample_src);
584        // Top-K kernel for Q8 rerank: avoids full-logits readback.
585        let top_k_src: &'static str = include_str!("shaders/top_k.metal");
586        sources.insert("top_k_f32".into(), top_k_src);
587
588        // MoE GPU routing + weighted reduce (ADR-013 P13.3 perf).
589        // Replaces CPU softmax+topk round-trip and CPU weighted accumulate.
590        let moe_stk_src: &'static str =
591            include_str!("shaders/moe_softmax_topk.metal");
592        sources.insert("moe_softmax_topk_f32".into(), moe_stk_src);
593        let moe_wr_src: &'static str =
594            include_str!("shaders/moe_weighted_reduce.metal");
595        sources.insert("moe_weighted_reduce_f32".into(), moe_wr_src);
596        let sdpa_decode_src: &'static str =
597            include_str!("shaders/sdpa_decode.metal");
598        sources.insert("sdpa_decode".into(), sdpa_decode_src);
599
600        Self {
601            cache: HashMap::new(),
602            sources,
603        }
604    }
605
606    /// Register a shader source at runtime (useful for testing and dynamic
607    /// kernel generation).
608    pub fn register_source(&mut self, name: impl Into<String>, source: &'static str) {
609        let name = name.into();
610        // Invalidate any cached pipeline for this name since the source changed.
611        self.cache.remove(&name);
612        self.sources.insert(name, source);
613    }
614
615    /// Get a compiled compute pipeline for the named kernel function.
616    ///
617    /// On first call for a given name, this compiles the MSL source into a
618    /// Metal library, extracts the named function, and creates a
619    /// `ComputePipelineState`.  Subsequent calls return the cached pipeline.
620    ///
621    /// # Errors
622    ///
623    /// * `MlxError::KernelNotFound` — no source registered for this name.
624    /// * `MlxError::ShaderCompilationError` — MSL compilation or pipeline
625    ///   creation failed.
626    pub fn get_pipeline(
627        &mut self,
628        name: &str,
629        device: &metal::DeviceRef,
630    ) -> Result<&ComputePipelineState> {
631        if !self.cache.contains_key(name) {
632            // Slow path: compile the shader.
633            let source = self.sources.get(name).ok_or_else(|| {
634                MlxError::KernelNotFound(name.to_string())
635            })?;
636
637            let compile_opts = metal::CompileOptions::new();
638            let library = device
639                .new_library_with_source(source, &compile_opts)
640                .map_err(|msg| MlxError::ShaderCompilationError {
641                    name: name.to_string(),
642                    message: msg,
643                })?;
644
645            let function = library
646                .get_function(name, None)
647                .map_err(|msg| MlxError::ShaderCompilationError {
648                    name: name.to_string(),
649                    message: msg,
650                })?;
651
652            // Build the pipeline through a descriptor so we can attach a
653            // human-readable label.  The label propagates into Instruments /
654            // xctrace Metal System Trace as the per-pipeline identifier
655            // (`metal-object-label` schema), giving us per-kernel attribution
656            // instead of the generic "Compute Command 0" placeholder.
657            //
658            // `MTLComputePipelineState.label` is read-only after creation per
659            // the Apple Metal spec; the only supported way to set it is via
660            // the descriptor before pipeline creation.  ADR-015 iter9b.
661            let descriptor = ComputePipelineDescriptor::new();
662            descriptor.set_compute_function(Some(&function));
663            descriptor.set_label(name);
664
665            let pipeline = device
666                .new_compute_pipeline_state(&descriptor)
667                .map_err(|msg| MlxError::ShaderCompilationError {
668                    name: name.to_string(),
669                    message: msg,
670                })?;
671
672            self.cache.insert(name.to_string(), pipeline);
673        }
674
675        // At this point the pipeline is guaranteed to be in the cache.
676        // We use `ok_or_else` instead of `expect` to satisfy the no-panic policy.
677        self.cache.get(name).ok_or_else(|| {
678            MlxError::KernelNotFound(name.to_string())
679        })
680    }
681
682    /// Get a compiled compute pipeline for the named kernel, specialized with
683    /// Metal function constants (both bool and i32 in one call).
684    ///
685    /// `bool_constants` contains `(index, value)` pairs mapping to
686    /// `[[function_constant(index)]]` bool declarations in the MSL shader.
687    /// `int_constants` contains `(index, value)` pairs mapping to
688    /// `[[function_constant(index)]]` int (int32_t) declarations in the MSL
689    /// shader.
690    ///
691    /// Pipelines are cached by a composite key:
692    /// `"<name>|<index>:b<0|1>|...|<index>:i<value>|..."`.  The 'b' prefix
693    /// marks bool entries and the 'i' prefix marks i32 entries, making the
694    /// format unambiguous regardless of constant ordering.  Distinct
695    /// `(name, constants)` combinations each compile to a separate pipeline;
696    /// the slow compilation path runs at most once per unique combination.
697    ///
698    /// # Errors
699    ///
700    /// * `MlxError::KernelNotFound` — no source registered for this name.
701    /// * `MlxError::ShaderCompilationError` — MSL compilation, function
702    ///   specialisation, or pipeline creation failed.
703    pub fn get_pipeline_with_constants(
704        &mut self,
705        name: &str,
706        device: &metal::DeviceRef,
707        bool_constants: &[(usize, bool)],
708        int_constants: &[(usize, i32)],
709    ) -> Result<&ComputePipelineState> {
710        // Build a composite cache key so distinct constant combinations each
711        // compile to their own pipeline.  Bool entries use the 'b' type marker
712        // and i32 entries use 'i'; this prevents a collision between, e.g.,
713        // bool index 5 value 1 and int index 5 value 1.
714        let mut cache_key = name.to_string();
715        for &(index, value) in bool_constants {
716            cache_key.push('|');
717            cache_key.push_str(&index.to_string());
718            cache_key.push_str(if value { ":b1" } else { ":b0" });
719        }
720        for &(index, value) in int_constants {
721            cache_key.push('|');
722            cache_key.push_str(&index.to_string());
723            cache_key.push(':');
724            cache_key.push('i');
725            cache_key.push_str(&value.to_string());
726        }
727
728        if !self.cache.contains_key(&cache_key) {
729            // Slow path: compile the shader with function constant specialisation.
730            let source = self.sources.get(name).ok_or_else(|| {
731                MlxError::KernelNotFound(name.to_string())
732            })?;
733
734            let compile_opts = metal::CompileOptions::new();
735            let library = device
736                .new_library_with_source(source, &compile_opts)
737                .map_err(|msg| MlxError::ShaderCompilationError {
738                    name: name.to_string(),
739                    message: msg,
740                })?;
741
742            // Build the FunctionConstantValues object with all bool and i32
743            // constants.  Metal's set_constant_value_at_index reads the value
744            // through a raw pointer; the pointed-to bytes must match the size
745            // declared in the MSL shader (1 byte for bool, 4 bytes for int).
746            let fcv = FunctionConstantValues::new();
747
748            for &(index, value) in bool_constants {
749                // MTLDataType::Bool = 53 (metal-rs argument.rs).
750                // The Metal runtime reads it as an Objective-C BOOL (uint8_t).
751                let v: u8 = if value { 1 } else { 0 };
752                fcv.set_constant_value_at_index(
753                    (&v as *const u8).cast::<std::ffi::c_void>(),
754                    MTLDataType::Bool,
755                    index as u64,
756                );
757            }
758
759            for &(index, value) in int_constants {
760                // MTLDataType::Int = 29 (metal-rs argument.rs).
761                // The Metal runtime reads 4 bytes as a signed 32-bit integer,
762                // matching the Metal shader type `constant int`.
763                fcv.set_constant_value_at_index(
764                    (&value as *const i32).cast::<std::ffi::c_void>(),
765                    MTLDataType::Int,
766                    index as u64,
767                );
768            }
769
770            let function = library
771                .get_function(name, Some(fcv))
772                .map_err(|msg| MlxError::ShaderCompilationError {
773                    name: name.to_string(),
774                    message: msg,
775                })?;
776
777            // Label this specialisation with the full composite cache key
778            // (e.g. `kernel_mul_mv_q4_0_f32|0:b1|3:i32`) so xctrace Metal
779            // System Trace shows each function-constant variant as a distinct
780            // pipeline.  Without this, all specialisations share a generic
781            // "Compute Command 0" identifier and we cannot attribute µs/token
782            // to a specific (kernel, constants) combination.  ADR-015 iter9b.
783            let descriptor = ComputePipelineDescriptor::new();
784            descriptor.set_compute_function(Some(&function));
785            descriptor.set_label(&cache_key);
786
787            let pipeline = device
788                .new_compute_pipeline_state(&descriptor)
789                .map_err(|msg| MlxError::ShaderCompilationError {
790                    name: name.to_string(),
791                    message: msg,
792                })?;
793
794            self.cache.insert(cache_key.clone(), pipeline);
795        }
796
797        self.cache.get(&cache_key).ok_or_else(|| {
798            MlxError::KernelNotFound(name.to_string())
799        })
800    }
801
802    /// Get a compiled compute pipeline for the named kernel, specialized with
803    /// Metal bool function constants.
804    ///
805    /// The `bool_constants` slice contains `(index, value)` pairs.  Each pair
806    /// maps to a `[[function_constant(index)]]` declaration in the MSL shader.
807    ///
808    /// This is a thin wrapper around [`get_pipeline_with_constants`] that
809    /// passes an empty `int_constants` slice.  Existing callers continue to
810    /// work without modification; the cache-key format for pure-bool pipelines
811    /// is compatible (bool entries carry the 'b' type marker, which is the
812    /// only format ever written by this wrapper).
813    ///
814    /// # Errors
815    ///
816    /// * `MlxError::KernelNotFound` — no source registered for this name.
817    /// * `MlxError::ShaderCompilationError` — MSL compilation, function
818    ///   specialisation, or pipeline creation failed.
819    pub fn get_pipeline_with_bool_constants(
820        &mut self,
821        name: &str,
822        device: &metal::DeviceRef,
823        bool_constants: &[(usize, bool)],
824    ) -> Result<&ComputePipelineState> {
825        self.get_pipeline_with_constants(name, device, bool_constants, &[])
826    }
827
828    /// Check if a pipeline for the given name is already compiled and cached.
829    pub fn is_cached(&self, name: &str) -> bool {
830        self.cache.contains_key(name)
831    }
832
833    /// Number of compiled pipelines currently in the cache.
834    pub fn cached_count(&self) -> usize {
835        self.cache.len()
836    }
837
838    /// Number of registered shader sources.
839    pub fn source_count(&self) -> usize {
840        self.sources.len()
841    }
842}
843
844impl Default for KernelRegistry {
845    fn default() -> Self {
846        Self::new()
847    }
848}
849
850#[cfg(test)]
851mod tests {
852    use super::*;
853
    /// Minimal Metal shader that uses a single int function constant.
    ///
    /// The kernel writes the constant `test_N` (function-constant index 100)
    /// into the first element of the output buffer, allowing a test to verify
    /// that the Metal compiler actually sees distinct specialisations for
    /// different N values.
    ///
    /// The shader is intentionally trivial — the tests only need it to
    /// *compile* with an int function constant; correctness of the kernel
    /// logic is not under test here.
    const INT_FC_TEST_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;

constant int test_N [[function_constant(100)]];

kernel void int_fc_test_kernel(
    device int* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]])
{
    if (tid == 0) {
        out[0] = test_N;
    }
}
"#;
878
    /// Verify that `get_pipeline_with_constants` produces distinct cached
    /// pipelines for different i32 function-constant values, and that
    /// `get_pipeline_with_bool_constants` (the backward-compat wrapper) still
    /// works correctly with the 'b'-prefixed cache-key format.
    ///
    /// This test requires a real Metal device.  There is no `#[ignore]`
    /// gate: on a machine without Metal it fails at the `system_default()`
    /// expect below, so run it on Apple hardware.
    #[test]
    fn test_int_fc_distinct_pipelines_and_bool_compat() {
        let device = metal::Device::system_default()
            .expect("no Metal device — run on Apple Silicon or x86 Mac with Metal support");

        let mut registry = KernelRegistry::new();

        // Register the inline test shader under a name that cannot collide with
        // any production kernel.
        registry.register_source("int_fc_test_kernel", INT_FC_TEST_SHADER);

        // Compile with N=4.
        //
        // NOTE(review): `as *const _` captures the address of the
        // ComputePipelineState struct stored inside the registry's HashMap.
        // A later insert can rehash the map and move entries, which would
        // make the address comparisons below fragile — consider comparing
        // the underlying Metal object pointers instead.  TODO confirm.
        let p4_ptr = registry
            .get_pipeline_with_constants(
                "int_fc_test_kernel",
                &device,
                &[],                  // no bool constants
                &[(100, 4_i32)],      // int constant index 100 = 4
            )
            .expect("pipeline N=4 should compile") as *const _;

        // Snapshot the cache size after the first compile.  We assert
        // relative growth below rather than an absolute count, so this test
        // does not depend on what else may already be cached.
        let count_after_n4 = registry.cached_count();

        // Compile with N=8 — must produce a SEPARATE pipeline.
        let p8_ptr = registry
            .get_pipeline_with_constants(
                "int_fc_test_kernel",
                &device,
                &[],
                &[(100, 8_i32)],
            )
            .expect("pipeline N=8 should compile") as *const _;

        // Cache must have grown by exactly 1.
        assert_eq!(
            registry.cached_count(),
            count_after_n4 + 1,
            "N=8 must produce a new cache entry"
        );

        // The two pipelines must be distinct objects in the cache.
        assert_ne!(
            p4_ptr, p8_ptr,
            "N=4 and N=8 specialisations must be separate ComputePipelineState objects"
        );

        // A second call with N=4 must return the SAME pipeline (cache hit, no
        // new compilation).
        let p4_again_ptr = registry
            .get_pipeline_with_constants(
                "int_fc_test_kernel",
                &device,
                &[],
                &[(100, 4_i32)],
            )
            .expect("pipeline N=4 cache hit should succeed") as *const _;

        assert_eq!(
            registry.cached_count(),
            count_after_n4 + 1,
            "repeated N=4 call must be a cache hit, not a new entry"
        );
        assert_eq!(
            p4_ptr, p4_again_ptr,
            "repeated N=4 call must return the same pipeline pointer"
        );

        // Verify backward compatibility: get_pipeline_with_bool_constants must
        // still route through get_pipeline_with_constants and produce a cached
        // pipeline without panicking.
        //
        // The call path and cache-key format are what matter here, not the
        // kernel itself.  To avoid any Metal compiler complaint about function
        // constants, we register a separate bare-kernel shader (no function
        // constants at all) for the bool-wrapper test and pass an empty
        // constants slice.
        const BARE_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void bare_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
    if (tid == 0) { out[0] = 42; }
}
"#;
        registry.register_source("bare_kernel", BARE_SHADER);

        let count_before_bool = registry.cached_count();
        let _bool_pipeline = registry
            .get_pipeline_with_bool_constants("bare_kernel", &device, &[])
            .expect("bool-constants wrapper with empty slice must succeed");

        assert_eq!(
            registry.cached_count(),
            count_before_bool + 1,
            "bool-constants wrapper must insert one new cache entry"
        );
    }
989
    /// Verify that the `MTLComputePipelineState.label` produced by
    /// `get_pipeline` and `get_pipeline_with_constants` actually propagates
    /// from the descriptor to the resulting pipeline state.
    ///
    /// This is the in-process smoke check for ADR-015 iter9b: we cannot
    /// reach into xctrace from Rust, but we can read back the same `label`
    /// property xctrace consumes via `ComputePipelineStateRef::label()`.
    /// If labels are missing or wrong here, the MST trace will also show
    /// generic identifiers — so this test gates the iter9 retry's
    /// per-Q4_0-kernel attribution.  Requires a real Metal device (no
    /// `#[ignore]` gate; fails at the `system_default()` expect otherwise).
    #[test]
    fn test_pipeline_labels_propagate_for_mst() {
        let device = metal::Device::system_default()
            .expect("no Metal device — run on Apple Silicon or x86 Mac with Metal support");

        let mut registry = KernelRegistry::new();

        // Reuse the trivial int-FC shader from the test above for the
        // constants path.
        registry.register_source("int_fc_test_kernel", INT_FC_TEST_SHADER);

        // Separate bare kernel (no function constants) for the plain path.
        const BARE_SHADER_LABEL_TEST: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void label_smoke_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
    if (tid == 0) { out[0] = 7; }
}
"#;
        registry.register_source("label_smoke_kernel", BARE_SHADER_LABEL_TEST);

        // Plain get_pipeline path — label must equal the kernel name.
        // Capture as owned String so the cache borrow is released before
        // the next get_pipeline_with_constants call below.
        let plain_label = registry
            .get_pipeline("label_smoke_kernel", &device)
            .expect("plain pipeline must compile")
            .label()
            .to_string();
        assert_eq!(
            plain_label, "label_smoke_kernel",
            "get_pipeline must label the pipeline with the kernel name (xctrace MST attribution)"
        );

        // Constants path — label must equal the composite cache key so each
        // function-constant variant is individually attributable in MST.
        // Again captured as an owned String to release the cache borrow
        // before fetching the next specialisation.
        let label_v7 = registry
            .get_pipeline_with_constants(
                "int_fc_test_kernel",
                &device,
                &[],
                &[(100, 7_i32)],
            )
            .expect("specialised pipeline must compile")
            .label()
            .to_string();
        assert_eq!(
            label_v7, "int_fc_test_kernel|100:i7",
            "get_pipeline_with_constants must label with the cache_key so each \
             specialisation is distinct in xctrace MST"
        );

        // A second specialisation must produce a different label.
        let label_v13 = registry
            .get_pipeline_with_constants(
                "int_fc_test_kernel",
                &device,
                &[],
                &[(100, 13_i32)],
            )
            .expect("second specialised pipeline must compile")
            .label()
            .to_string();
        assert_eq!(label_v13, "int_fc_test_kernel|100:i13");
        assert_ne!(
            label_v7, label_v13,
            "distinct constant values must yield distinct pipeline labels"
        );
    }
1069}