1use std::collections::HashMap;
9
10use metal::{ComputePipelineDescriptor, ComputePipelineState, FunctionConstantValues, MTLDataType};
11
12use crate::error::{MlxError, Result};
13
14pub struct KernelRegistry {
37 cache: HashMap<String, ComputePipelineState>,
39 sources: HashMap<String, &'static str>,
43}
44
45impl KernelRegistry {
46 pub fn new() -> Self {
50 let mut sources = HashMap::new();
51
52 sources.insert(
54 "placeholder".into(),
55 include_str!("shaders/placeholder.metal"),
56 );
57 sources.insert(
58 "quantized_matmul".into(),
59 include_str!("shaders/quantized_matmul.metal"),
60 );
61 sources.insert(
62 "quantized_matmul_simd".into(),
63 include_str!("shaders/quantized_matmul.metal"),
64 );
65 sources.insert(
66 "quantized_matmul_simd_bf16".into(),
67 include_str!("shaders/quantized_matmul.metal"),
68 );
69 sources.insert(
70 "quantized_matmul_simd_bf16_expert".into(),
71 include_str!("shaders/quantized_matmul.metal"),
72 );
73
74 let ggml_src: &'static str =
76 include_str!("shaders/quantized_matmul_ggml.metal");
77 sources.insert("kernel_mul_mv_q4_0_f32".into(), ggml_src);
78 sources.insert("kernel_mul_mv_q8_0_f32".into(), ggml_src);
79 sources.insert("kernel_mul_mv_q8_0_f32_nr2".into(), ggml_src);
81 sources.insert("kernel_mul_mv_q6_K_f32".into(), ggml_src);
82 sources.insert("kernel_mul_mv_q6_K_f32_nr2".into(), ggml_src);
87 sources.insert("kernel_mul_mv_q5_1_f32".into(), ggml_src);
89 sources.insert("kernel_mul_mv_iq4_nl_f32".into(), ggml_src);
90 sources.insert("kernel_mul_mv_q4_K_f32".into(), ggml_src);
93 sources.insert("kernel_mul_mv_q5_K_f32".into(), ggml_src);
95
96 let ggml_mm_src: &'static str =
102 include_str!("shaders/quantized_matmul_mm.metal");
103 sources.insert("kernel_mul_mm_q4_0_f32".into(), ggml_mm_src);
104 sources.insert("kernel_mul_mm_q8_0_f32".into(), ggml_mm_src);
105 sources.insert("kernel_mul_mm_q6_K_f32".into(), ggml_mm_src);
106 sources.insert("kernel_mul_mm_q5_1_f32".into(), ggml_mm_src);
108 sources.insert("kernel_mul_mm_iq4_nl_f32".into(), ggml_mm_src);
109 sources.insert("kernel_mul_mm_q5_K_f32".into(), ggml_mm_src);
111 sources.insert("kernel_mul_mm_q4_K_f32".into(), ggml_mm_src);
113
114 let ggml_mm_tensor_src: &'static str =
125 include_str!("shaders/quantized_matmul_mm_tensor.metal");
126 sources.insert("kernel_mul_mm_q4_0_tensor_f32".into(), ggml_mm_tensor_src);
127 sources.insert("kernel_mul_mm_q4_0_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
128 sources.insert("kernel_mul_mm_q6_K_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
129 sources.insert("kernel_mul_mm_q8_0_tensor_f32".into(), ggml_mm_tensor_src);
130 sources.insert("kernel_mul_mm_q6_K_tensor_f32".into(), ggml_mm_tensor_src);
131 sources.insert("kernel_mul_mm_q5_1_tensor_f32".into(), ggml_mm_tensor_src);
133 sources.insert("kernel_mul_mm_iq4_nl_tensor_f32".into(), ggml_mm_tensor_src);
134 sources.insert("kernel_mul_mm_q5_K_tensor_f32".into(), ggml_mm_tensor_src);
136 sources.insert("kernel_mul_mm_q4_K_tensor_f32".into(), ggml_mm_tensor_src);
138 sources.insert("kernel_mul_mm_q8_0_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
139 sources.insert("hf2q_mul_mm_tensor_v2_f16".into(), ggml_mm_tensor_src);
144 sources.insert("kernel_mul_mm_f16_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
150 sources.insert("kernel_mul_mm_q4_0_tensor_v2_f32".into(), ggml_mm_tensor_src);
154 sources.insert("kernel_mul_mm_q8_0_tensor_v2_f32".into(), ggml_mm_tensor_src);
155 sources.insert("kernel_mul_mm_q6_K_tensor_v2_f32".into(), ggml_mm_tensor_src);
156 sources.insert("kernel_mul_mm_q5_1_tensor_v2_f32".into(), ggml_mm_tensor_src);
157 sources.insert("kernel_mul_mm_iq4_nl_tensor_v2_f32".into(), ggml_mm_tensor_src);
158 sources.insert("kernel_mul_mm_q5_K_tensor_v2_f32".into(), ggml_mm_tensor_src);
159 sources.insert("kernel_mul_mm_q4_K_tensor_v2_f32".into(), ggml_mm_tensor_src);
160 let dequant_to_f16_src: &'static str =
166 include_str!("shaders/dequant_to_f16.metal");
167 sources.insert("hf2q_dequant_q4_0_to_f16".into(), dequant_to_f16_src);
168 sources.insert("hf2q_dequant_q8_0_to_f16".into(), dequant_to_f16_src);
169 sources.insert("hf2q_dequant_q5_1_to_f16".into(), dequant_to_f16_src);
170 sources.insert("hf2q_dequant_iq4_nl_to_f16".into(), dequant_to_f16_src);
171 sources.insert("hf2q_dequant_q4_K_to_f16".into(), dequant_to_f16_src);
172 sources.insert("hf2q_dequant_q5_K_to_f16".into(), dequant_to_f16_src);
173 sources.insert("hf2q_dequant_q6_K_to_f16".into(), dequant_to_f16_src);
174
175 let mul_mv_ext_src: &'static str = include_str!("shaders/mul_mv_ext.metal");
180 sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_2".into(), mul_mv_ext_src);
181 sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_3".into(), mul_mv_ext_src);
182 sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_4".into(), mul_mv_ext_src);
183 sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_5".into(), mul_mv_ext_src);
184 sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_2".into(), mul_mv_ext_src);
185 sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_3".into(), mul_mv_ext_src);
186 sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_4".into(), mul_mv_ext_src);
187 sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_5".into(), mul_mv_ext_src);
188 for r1 in [2, 3, 4, 5].iter() {
191 for ty in ["q4_0", "q8_0", "q4_K", "q5_K", "q6_K"].iter() {
192 let name = format!("kernel_mul_mv_ext_{ty}_f32_r1_{r1}");
193 sources.insert(name, mul_mv_ext_src);
194 }
195 }
196
197 let dense_mm_bf16_tensor_src: &'static str =
204 include_str!("shaders/dense_mm_bf16_tensor.metal");
205 sources.insert("hf2q_dense_mm_bf16_f32_tensor".into(), dense_mm_bf16_tensor_src);
206 sources.insert("hf2q_dense_mm_bf16_f32_tensor_v2".into(), dense_mm_bf16_tensor_src);
211
212 let dense_mm_f32_f32_tensor_src: &'static str =
221 include_str!("shaders/dense_mm_f32_f32.metal");
222 sources.insert("hf2q_dense_mm_f32_f32_tensor".into(), dense_mm_f32_f32_tensor_src);
223
224 let dense_mm_f16_tensor_src: &'static str =
236 include_str!("shaders/dense_mm_f16_tensor.metal");
237 sources.insert("hf2q_dense_mm_f16_f32_tensor".into(), dense_mm_f16_tensor_src);
238
239 let dense_gemv_bf16_src: &'static str =
246 include_str!("shaders/dense_gemv_bf16.metal");
247 sources.insert("hf2q_dense_gemv_bf16_f32_4".into(), dense_gemv_bf16_src);
248
249 let scale_mask_softmax_src: &'static str =
255 include_str!("shaders/scale_mask_softmax.metal");
256 sources.insert("scale_mask_softmax_f32".into(), scale_mask_softmax_src);
257 sources.insert("scale_mask_softmax_f32_v4".into(), scale_mask_softmax_src);
261
262 sources.insert(
264 "quantized_matmul_id".into(),
265 include_str!("shaders/quantized_matmul_id.metal"),
266 );
267
268 let ggml_id_src: &'static str =
270 include_str!("shaders/quantized_matmul_id_ggml.metal");
271 sources.insert("kernel_mul_mv_id_q4_0_f32".into(), ggml_id_src);
272 sources.insert("kernel_mul_mv_id_q8_0_f32".into(), ggml_id_src);
273 sources.insert("kernel_mul_mv_id_q4_K_f32".into(), ggml_id_src);
276 sources.insert("kernel_mul_mv_id_q5_K_f32".into(), ggml_id_src);
277 sources.insert("kernel_mul_mv_id_q6_K_f32".into(), ggml_id_src);
278 sources.insert("kernel_mul_mv_id_q6_K_f32_nr2".into(), ggml_id_src);
282 sources.insert("kernel_mul_mv_id_q8_0_f32_nr2".into(), ggml_id_src);
287 sources.insert("kernel_mul_mv_id_q5_1_f32".into(), ggml_id_src);
289 sources.insert("kernel_mul_mv_id_iq4_nl_f32".into(), ggml_id_src);
290 sources.insert("kernel_mul_mv_id_q4_0_f32_swiglu".into(), ggml_id_src);
294
295 let ggml_id_mm_src: &'static str =
303 include_str!("shaders/quantized_matmul_id_mm.metal");
304 sources.insert("kernel_mul_mm_id_map0_ne20_1".into(), ggml_id_mm_src);
305 sources.insert("kernel_mul_mm_id_map0_ne20_8".into(), ggml_id_mm_src);
306 sources.insert("kernel_mul_mm_id_q4_0_f32".into(), ggml_id_mm_src);
307 sources.insert("kernel_mul_mm_id_q8_0_f32".into(), ggml_id_mm_src);
308 sources.insert("kernel_mul_mm_id_q6_K_f32".into(), ggml_id_mm_src);
309 sources.insert("kernel_mul_mm_id_q4_K_f32".into(), ggml_id_mm_src);
311 sources.insert("kernel_mul_mm_id_q5_1_f32".into(), ggml_id_mm_src);
313 sources.insert("kernel_mul_mm_id_iq4_nl_f32".into(), ggml_id_mm_src);
314 sources.insert("kernel_mul_mm_id_q5_K_f32".into(), ggml_id_mm_src);
316
317 let ggml_id_mm_tensor_src: &'static str =
323 include_str!("shaders/quantized_matmul_id_mm_tensor.metal");
324 sources.insert("kernel_mul_mm_id_q4_0_tensor_f32".into(), ggml_id_mm_tensor_src);
325 sources.insert("kernel_mul_mm_id_q8_0_tensor_f32".into(), ggml_id_mm_tensor_src);
326 sources.insert("kernel_mul_mm_id_q6_K_tensor_f32".into(), ggml_id_mm_tensor_src);
327 sources.insert("kernel_mul_mm_id_q4_K_tensor_f32".into(), ggml_id_mm_tensor_src);
329 sources.insert("kernel_mul_mm_id_q5_1_tensor_f32".into(), ggml_id_mm_tensor_src);
331 sources.insert("kernel_mul_mm_id_iq4_nl_tensor_f32".into(), ggml_id_mm_tensor_src);
332 sources.insert("kernel_mul_mm_id_q5_K_tensor_f32".into(), ggml_id_mm_tensor_src);
334
335 let embedding_src: &'static str = include_str!("shaders/embedding.metal");
337 sources.insert("embedding_gather_4bit".into(), embedding_src);
338 sources.insert("embedding_gather_6bit".into(), embedding_src);
339
340 let moe_gate_src: &'static str = include_str!("shaders/moe_gate.metal");
342 sources.insert("moe_gate".into(), moe_gate_src);
343
344 let moe_dispatch_src: &'static str = include_str!("shaders/moe_dispatch.metal");
346 sources.insert("fused_gelu_mul".into(), moe_dispatch_src);
347 sources.insert("moe_swiglu_fused".into(), moe_dispatch_src);
348 sources.insert("moe_swiglu_batch".into(), moe_dispatch_src);
349 sources.insert("moe_swiglu_seq".into(), moe_dispatch_src);
350 sources.insert("moe_accumulate".into(), moe_dispatch_src);
351 sources.insert("moe_weighted_sum".into(), moe_dispatch_src);
352 sources.insert("moe_weighted_sum_seq".into(), moe_dispatch_src);
353 sources.insert("zero_buffer".into(), moe_dispatch_src);
354 sources.insert("naive_matvec_f32".into(), moe_dispatch_src);
355 sources.insert("moe_gather_topk_weights".into(), moe_dispatch_src);
356 sources.insert("fused_gelu_mul_bf16".into(), moe_dispatch_src);
358 sources.insert("moe_swiglu_seq_bf16".into(), moe_dispatch_src);
359 sources.insert("moe_weighted_sum_seq_bf16_input".into(), moe_dispatch_src);
360 sources.insert(
362 "moe_weighted_sum_seq_backward_outputs_f32".into(),
363 moe_dispatch_src,
364 );
365 sources.insert(
366 "moe_weighted_sum_seq_backward_weights_f32".into(),
367 moe_dispatch_src,
368 );
369 sources.insert(
371 "moe_swiglu_seq_backward_f32".into(),
372 moe_dispatch_src,
373 );
374
375 let kv_cache_src: &'static str = include_str!("shaders/kv_cache_copy.metal");
377 sources.insert("kv_cache_copy_batch_f32".into(), kv_cache_src);
378 sources.insert("kv_cache_copy_batch_f32_to_f16".into(), kv_cache_src);
379 sources.insert("kv_cache_copy_seq_f32".into(), kv_cache_src);
380 sources.insert("kv_cache_copy_seq_f32_to_f16".into(), kv_cache_src);
381 sources.insert("kv_cache_copy_seq_f32_kv_dual".into(), kv_cache_src);
383 sources.insert("kv_cache_copy_seq_f32_to_f16_kv_dual".into(), kv_cache_src);
384 sources.insert("kv_cache_copy_batch_f32_kv_dual".into(), kv_cache_src);
386 sources.insert("kv_cache_copy_batch_f32_to_f16_kv_dual".into(), kv_cache_src);
387 sources.insert("kv_cache_copy_seq_bf16".into(), kv_cache_src);
389
390 let elementwise_src: &'static str = include_str!("shaders/elementwise.metal");
392 sources.insert("elementwise_add_f32".into(), elementwise_src);
393 sources.insert("elementwise_add_f16".into(), elementwise_src);
394 sources.insert("elementwise_mul_f32".into(), elementwise_src);
395 sources.insert("elementwise_mul_f16".into(), elementwise_src);
396 sources.insert("elementwise_add_bf16".into(), elementwise_src);
397 sources.insert("elementwise_mul_bf16".into(), elementwise_src);
398 sources.insert("cast_f16_to_f32".into(), elementwise_src);
399 sources.insert("cast_f32_to_f16".into(), elementwise_src);
400 sources.insert("cast_bf16_to_f32".into(), elementwise_src);
401 sources.insert("cast_f32_to_bf16".into(), elementwise_src);
402 sources.insert("scalar_mul_bf16".into(), elementwise_src);
403 sources.insert("scalar_mul_f32".into(), elementwise_src);
404 sources.insert("embedding_gather_scale_f32".into(), elementwise_src);
405 sources.insert("embedding_gather_scale_batch_f32".into(), elementwise_src);
406 sources.insert("permute_021_bf16".into(), elementwise_src);
407 sources.insert("transpose_last2_bf16".into(), elementwise_src);
408 sources.insert("transpose_last2_f16".into(), elementwise_src);
409 sources.insert("permute_021_f32".into(), elementwise_src);
410 sources.insert("permute_021_bf16_to_f32".into(), elementwise_src);
411 sources.insert("transpose_2d_f32".into(), elementwise_src);
412 sources.insert("transpose_2d_f16".into(), elementwise_src);
413
414 let sdpa_src: &'static str = include_str!("shaders/sdpa.metal");
416 sources.insert("sdpa".into(), sdpa_src);
417 sources.insert("sdpa_bf16".into(), sdpa_src);
418 let sdpa_sliding_src: &'static str = include_str!("shaders/sdpa_sliding.metal");
419 sources.insert("sdpa_sliding".into(), sdpa_sliding_src);
420 sources.insert("sdpa_sliding_bf16".into(), sdpa_sliding_src);
421
422 let flash_attn_prefill_src: &'static str =
427 include_str!("shaders/flash_attn_prefill.metal");
428 sources.insert(
430 "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskfloat32".into(),
431 flash_attn_prefill_src,
432 );
433 sources.insert(
434 "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
435 flash_attn_prefill_src,
436 );
437 sources.insert(
438 "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbfloat16".into(),
439 flash_attn_prefill_src,
440 );
441 sources.insert(
442 "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
443 flash_attn_prefill_src,
444 );
445 sources.insert(
446 "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskfloat16".into(),
447 flash_attn_prefill_src,
448 );
449 sources.insert(
450 "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
451 flash_attn_prefill_src,
452 );
453 sources.insert(
457 "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbfloat16".into(),
458 flash_attn_prefill_src,
459 );
460 sources.insert(
461 "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
462 flash_attn_prefill_src,
463 );
464 sources.insert(
465 "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskfloat16".into(),
466 flash_attn_prefill_src,
467 );
468 sources.insert(
469 "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
470 flash_attn_prefill_src,
471 );
472
473 let flash_attn_vec_src: &'static str =
476 include_str!("shaders/flash_attn_vec.metal");
477 sources.insert("flash_attn_vec_dk256".into(), flash_attn_vec_src);
478 sources.insert("flash_attn_vec_dk512".into(), flash_attn_vec_src);
479 sources.insert("flash_attn_vec_reduce_dk256".into(), flash_attn_vec_src);
480 sources.insert("flash_attn_vec_reduce_dk512".into(), flash_attn_vec_src);
481 sources.insert("flash_attn_vec_f16kv_dk256".into(), flash_attn_vec_src);
483 sources.insert("flash_attn_vec_f16kv_dk512".into(), flash_attn_vec_src);
484
485 let rope_src: &'static str = include_str!("shaders/rope.metal");
487 sources.insert("rope_f32".into(), rope_src);
488 sources.insert("rope_f16".into(), rope_src);
489 sources.insert("rope_bf16".into(), rope_src);
490 sources.insert("rope_neox_bf16".into(), rope_src);
491 sources.insert("rope_neox_f32".into(), rope_src);
492 let rms_norm_src: &'static str = include_str!("shaders/rms_norm.metal");
493 sources.insert("rms_norm_f32".into(), rms_norm_src);
494 sources.insert("rms_norm_f32_v2".into(), rms_norm_src);
498 sources.insert("rms_norm_no_scale_f32_v2".into(), rms_norm_src);
499 sources.insert("rms_norm_f16".into(), rms_norm_src);
500 sources.insert("rms_norm_bf16".into(), rms_norm_src);
501 sources.insert("rms_norm_no_scale_bf16".into(), rms_norm_src);
502 sources.insert("rms_norm_no_scale_f32".into(), rms_norm_src);
503 sources.insert("rms_norm_no_scale_f32_dual".into(), rms_norm_src);
504 sources.insert("rms_norm_f32_triple".into(), rms_norm_src);
505 sources.insert("fused_post_attn_triple_norm_f32".into(), rms_norm_src);
506 sources.insert("fused_post_attn_triple_norm_f32_v2".into(), rms_norm_src);
508 sources.insert("fused_post_ff_norm2_endlayer_f32".into(), rms_norm_src);
511 sources.insert("fused_post_ff_norm2_endlayer_f32_v2".into(), rms_norm_src);
514 sources.insert("fused_moe_wsum_post_ff_norm2_endlayer_f32_v2".into(), rms_norm_src);
517 sources.insert("rms_norm_no_scale_f32_dual_perm".into(), rms_norm_src);
518 sources.insert("rms_norm_mul_f32".into(), rms_norm_src);
520 sources.insert("rms_norm_mul_f16".into(), rms_norm_src);
521 sources.insert("rms_norm_mul_bf16".into(), rms_norm_src);
522 let l2_norm_src: &'static str = include_str!("shaders/l2_norm.metal");
524 sources.insert("l2_norm_f32".into(), l2_norm_src);
525 sources.insert("l2_norm_f16".into(), l2_norm_src);
526 sources.insert("l2_norm_bf16".into(), l2_norm_src);
527 sources.insert("l2_norm_scale_f32".into(), l2_norm_src);
529 let cumsum_src: &'static str = include_str!("shaders/cumsum.metal");
531 sources.insert("cumsum_f32".into(), cumsum_src);
532 sources.insert("cumsum_bf16".into(), cumsum_src);
533 let ssm_conv_src: &'static str = include_str!("shaders/ssm_conv.metal");
535 sources.insert("ssm_conv_forward_f32".into(), ssm_conv_src);
536 sources.insert("ssm_conv_forward_bf16".into(), ssm_conv_src);
537 sources.insert("ssm_conv_state_update_f32".into(), ssm_conv_src);
538 sources.insert("ssm_conv_state_update_bf16".into(), ssm_conv_src);
539 let tri_solve_src: &'static str = include_str!("shaders/tri_solve.metal");
541 sources.insert("tri_solve_lower_unit_f32".into(), tri_solve_src);
542 sources.insert("tri_solve_lower_unit_bf16".into(), tri_solve_src);
543 let rope_multi_src: &'static str = include_str!("shaders/rope_multi.metal");
545 sources.insert("rope_multi_f32".into(), rope_multi_src);
546 sources.insert("rope_multi_bf16".into(), rope_multi_src);
547 let gdn_src: &'static str = include_str!("shaders/gated_delta_net.metal");
549 sources.insert("gated_delta_net_f32".into(), gdn_src);
550 let gdn_decode_src: &'static str =
555 include_str!("shaders/gated_delta_net_decode.metal");
556 sources.insert("gated_delta_net_decode_f32_1".into(), gdn_decode_src);
557 sources.insert("gated_delta_net_decode_f32_2".into(), gdn_decode_src);
558 sources.insert("gated_delta_net_decode_f32_4".into(), gdn_decode_src);
559 let gdn_chunk_src: &'static str =
563 include_str!("shaders/gated_delta_net_chunk.metal");
564 sources.insert(
565 "gated_delta_net_chunk_inter_state_bf16".into(),
566 gdn_chunk_src,
567 );
568 let gdn_kkt_src: &'static str =
571 include_str!("shaders/gated_delta_net_kkt.metal");
572 sources.insert("gated_delta_net_kkt_bf16".into(), gdn_kkt_src);
573 let gdn_recompute_wu_src: &'static str =
577 include_str!("shaders/gated_delta_net_recompute_wu.metal");
578 sources.insert(
579 "gated_delta_net_recompute_wu_bf16".into(),
580 gdn_recompute_wu_src,
581 );
582 let gdn_chunk_o_src: &'static str =
585 include_str!("shaders/gated_delta_net_chunk_o.metal");
586 sources.insert("gated_delta_net_chunk_o_bf16".into(), gdn_chunk_o_src);
587 let chunk_local_cumsum_g_src: &'static str =
592 include_str!("shaders/chunk_local_cumsum_g.metal");
593 sources.insert(
594 "chunk_local_cumsum_g_f32".into(),
595 chunk_local_cumsum_g_src,
596 );
597 let chunk_tri_solve_invert_src: &'static str =
598 include_str!("shaders/chunk_gated_delta_rule_tri_solve_invert.metal");
599 sources.insert(
600 "chunk_tri_solve_invert_f32".into(),
601 chunk_tri_solve_invert_src,
602 );
603 let sigmoid_mul_src: &'static str = include_str!("shaders/sigmoid_mul.metal");
605 sources.insert("sigmoid_mul_f32".into(), sigmoid_mul_src);
606 sources.insert("sigmoid_mul_bf16".into(), sigmoid_mul_src);
607 let silu_mul_src: &'static str = include_str!("shaders/silu_mul.metal");
608 sources.insert("silu_mul_f32".into(), silu_mul_src);
609 let compute_g_beta_src: &'static str = include_str!("shaders/compute_g_beta.metal");
610 sources.insert("compute_g_beta_f32".into(), compute_g_beta_src);
611 let ssm_norm_gate_src: &'static str = include_str!("shaders/ssm_norm_gate.metal");
612 sources.insert("ssm_norm_gate_f32".into(), ssm_norm_gate_src);
613 let gelu_src: &'static str = include_str!("shaders/gelu.metal");
614 sources.insert("gelu_f32".into(), gelu_src);
615 sources.insert("gelu_f16".into(), gelu_src);
616 sources.insert("gelu_bf16".into(), gelu_src);
617 let softmax_src: &'static str = include_str!("shaders/softmax.metal");
618 sources.insert("softmax_f32".into(), softmax_src);
619 sources.insert("softmax_f16".into(), softmax_src);
620 sources.insert("softmax_bf16".into(), softmax_src);
621 let softmax_backward_src: &'static str =
622 include_str!("shaders/softmax_backward.metal");
623 sources.insert("softmax_backward_f32".into(), softmax_backward_src);
624 let log_elementwise_src: &'static str =
625 include_str!("shaders/log_elementwise.metal");
626 sources.insert("log_f32".into(), log_elementwise_src);
627 sources.insert("log_backward_f32".into(), log_elementwise_src);
628 let row_sum_src: &'static str = include_str!("shaders/row_sum.metal");
629 sources.insert("row_sum_f32".into(), row_sum_src);
630 sources.insert("row_sum_backward_f32".into(), row_sum_src);
631 let qdq_legacy_src: &'static str = include_str!("shaders/qdq_legacy.metal");
635 sources.insert("qdq_q4_0_f32".into(), qdq_legacy_src);
636 sources.insert("qdq_q8_0_f32".into(), qdq_legacy_src);
637 let rms_norm_backward_src: &'static str =
641 include_str!("shaders/rms_norm_backward.metal");
642 sources.insert(
643 "rms_norm_compute_rms_inv_f32".into(),
644 rms_norm_backward_src,
645 );
646 sources.insert("rms_norm_backward_dx_f32".into(), rms_norm_backward_src);
647 sources.insert("rms_norm_backward_dw_f32".into(), rms_norm_backward_src);
648 let slice_concat_2d_src: &'static str =
653 include_str!("shaders/slice_concat_2d.metal");
654 sources.insert("slice_2d_cols_f32".into(), slice_concat_2d_src);
655 sources.insert("copy_2d_cols_into_f32".into(), slice_concat_2d_src);
656 let silu_backward_src: &'static str =
659 include_str!("shaders/silu_backward.metal");
660 sources.insert("silu_f32".into(), silu_backward_src);
661 sources.insert("silu_backward_f32".into(), silu_backward_src);
662 let embedding_autograd_src: &'static str =
664 include_str!("shaders/embedding_autograd.metal");
665 sources.insert("embedding_lookup_f32".into(), embedding_autograd_src);
666 sources.insert(
667 "embedding_scatter_add_f32".into(),
668 embedding_autograd_src,
669 );
670 let adam_update_src: &'static str =
673 include_str!("shaders/adam_update.metal");
674 sources.insert("adam_update_f32".into(), adam_update_src);
675 let qdq_affine_src: &'static str =
679 include_str!("shaders/qdq_affine.metal");
680 sources.insert("qdq_affine_init_f32".into(), qdq_affine_src);
681 sources.insert("qdq_affine_forward_f32".into(), qdq_affine_src);
682 sources.insert(
683 "qdq_affine_backward_scales_f32".into(),
684 qdq_affine_src,
685 );
686 sources.insert(
687 "qdq_affine_backward_biases_f32".into(),
688 qdq_affine_src,
689 );
690 let qmm_affine_src: &'static str =
694 include_str!("shaders/qmm_affine.metal");
695 sources.insert("qmm_affine_t_f32".into(), qmm_affine_src);
696 let qmm_affine_tiled_src: &'static str =
700 include_str!("shaders/qmm_affine_tiled.metal");
701 sources.insert(
702 "qmm_affine_t_f32_tiled".into(),
703 qmm_affine_tiled_src,
704 );
705 let qmm_affine_simd_src: &'static str =
710 include_str!("shaders/qmm_affine_simd.metal");
711 sources.insert(
712 "qmm_affine_t_f32_simd".into(),
713 qmm_affine_simd_src,
714 );
715 let qmm_affine_simd4_src: &'static str =
720 include_str!("shaders/qmm_affine_simd4.metal");
721 sources.insert(
722 "qmm_affine_t_f32_simd4".into(),
723 qmm_affine_simd4_src,
724 );
725 let qmm_affine_simd4_gs64_src: &'static str =
729 include_str!("shaders/qmm_affine_simd4_gs64.metal");
730 sources.insert(
731 "qmm_affine_t_f32_simd4_gs64".into(),
732 qmm_affine_simd4_gs64_src,
733 );
734 let qmm_affine_t_packed_simd4_b4_src: &'static str =
738 include_str!("shaders/qmm_affine_t_packed_simd4_b4.metal");
739 sources.insert(
740 "qmm_affine_t_packed_simd4_b4".into(),
741 qmm_affine_t_packed_simd4_b4_src,
742 );
743 let conv1d_dwc_src: &'static str =
748 include_str!("shaders/conv1d_depthwise_causal.metal");
749 sources.insert(
750 "conv1d_depthwise_causal_forward_f32".into(),
751 conv1d_dwc_src,
752 );
753 sources.insert(
754 "conv1d_depthwise_causal_backward_dx_f32".into(),
755 conv1d_dwc_src,
756 );
757 sources.insert(
758 "conv1d_depthwise_causal_backward_dw_f32".into(),
759 conv1d_dwc_src,
760 );
761 let exp_src: &'static str =
764 include_str!("shaders/exp_elementwise.metal");
765 sources.insert("exp_f32".into(), exp_src);
766 sources.insert("exp_backward_f32".into(), exp_src);
767 let outer_src: &'static str =
771 include_str!("shaders/outer_product.metal");
772 sources.insert("outer_product_f32".into(), outer_src);
773 sources.insert("outer_product_backward_lhs_f32".into(), outer_src);
774 sources.insert("outer_product_backward_rhs_f32".into(), outer_src);
775 let taa_src: &'static str =
778 include_str!("shaders/take_along_axis.metal");
779 sources.insert("take_along_axis_f32".into(), taa_src);
780 sources.insert("take_along_axis_backward_f32".into(), taa_src);
781 let div_src: &'static str =
783 include_str!("shaders/divide_elementwise.metal");
784 sources.insert("divide_f32".into(), div_src);
785 sources.insert("divide_backward_f32".into(), div_src);
786 let sqrt_src: &'static str =
788 include_str!("shaders/sqrt_elementwise.metal");
789 sources.insert("sqrt_f32".into(), sqrt_src);
790 sources.insert("sqrt_backward_f32".into(), sqrt_src);
791 let softcap_src: &'static str = include_str!("shaders/softcap.metal");
792 sources.insert("softcap_f32".into(), softcap_src);
793 sources.insert("softcap_f16".into(), softcap_src);
794 sources.insert("softcap_bf16".into(), softcap_src);
795
796 let fused_norm_add_src: &'static str =
799 include_str!("shaders/fused_norm_add_bf16.metal");
800 sources.insert("fused_norm_add_bf16".into(), fused_norm_add_src);
801 sources.insert("fused_norm_add_no_weight_bf16".into(), fused_norm_add_src);
802
803 let fused_hnr_f32_src: &'static str =
805 include_str!("shaders/fused_head_norm_rope_f32.metal");
806 sources.insert("fused_head_norm_rope_f32".into(), fused_hnr_f32_src);
807 sources.insert("fused_head_norm_rope_f32_v2".into(), fused_hnr_f32_src);
811
812 let fused_hnr_bf16_src: &'static str =
815 include_str!("shaders/fused_head_norm_rope_bf16.metal");
816 sources.insert("fused_head_norm_rope_bf16".into(), fused_hnr_bf16_src);
817 sources.insert("fused_head_norm_rope_batch_bf16".into(), fused_hnr_bf16_src);
818
819 let fused_norm_add_f32_src: &'static str =
821 include_str!("shaders/fused_norm_add_f32.metal");
822 sources.insert("fused_norm_add_f32".into(), fused_norm_add_f32_src);
823 sources.insert("fused_norm_add_f32_v2".into(), fused_norm_add_f32_src);
828 sources.insert("fused_residual_norm_f32".into(), fused_norm_add_f32_src);
829 sources.insert("fused_residual_norm_scalar_f32".into(), fused_norm_add_f32_src);
830 sources.insert("fused_moe_routing_f32".into(), fused_norm_add_f32_src);
831 sources.insert("fused_moe_routing_f32_v2".into(), fused_norm_add_f32_src);
833 sources.insert("fused_moe_routing_batch_f32".into(), fused_norm_add_f32_src);
834 sources.insert("fused_norm_add_scalar_f32".into(), fused_norm_add_f32_src);
835 sources.insert("fused_moe_wsum_norm_add_f32".into(), fused_norm_add_f32_src);
836 sources.insert("fused_moe_wsum_dnorm_add_f32".into(), fused_norm_add_f32_src);
837
838 let argsort_src: &'static str = include_str!("shaders/argsort.metal");
840 sources.insert("argsort_desc_f32".into(), argsort_src);
841
842 let gather_src: &'static str = include_str!("shaders/gather.metal");
844 sources.insert("gather_f32".into(), gather_src);
845
846 let kv_cache_copy_src: &'static str =
848 include_str!("shaders/kv_cache_copy.metal");
849 sources.insert("kv_cache_copy".into(), kv_cache_copy_src);
850 sources.insert("kv_cache_copy_f32".into(), kv_cache_copy_src);
851
852 let copy_src: &'static str = include_str!("shaders/copy.metal");
854 sources.insert("strided_copy_f32".into(), copy_src);
855 sources.insert("offset_copy_f32".into(), copy_src);
856
857 let qkv_split_src: &'static str = include_str!("shaders/qkv_split.metal");
861 sources.insert("qkv_split_f32".into(), qkv_split_src);
862
863 let repeat_tiled_src: &'static str =
867 include_str!("shaders/repeat_tiled.metal");
868 sources.insert("repeat_tiled_f32".into(), repeat_tiled_src);
869
870 let dense_gemm_src: &'static str = include_str!("shaders/dense_gemm.metal");
872 sources.insert("dense_gemm_f16".into(), dense_gemm_src);
873 sources.insert("dense_matvec_f16".into(), dense_gemm_src);
874 sources.insert("dense_matvec_f16w_f32io".into(), dense_gemm_src);
875 sources.insert("dense_matvec_bf16w_f32io".into(), dense_gemm_src);
877 sources.insert("dense_matvec_f32".into(), dense_gemm_src);
879
880 let fwht_src: &'static str = include_str!("shaders/fwht_standalone.metal");
882 sources.insert("fwht_standalone_f32_d256".into(), fwht_src);
883 sources.insert("fwht_standalone_f32_d512".into(), fwht_src);
884 sources.insert("fwht_sign_premult_f32_d256".into(), fwht_src);
886 sources.insert("fwht_sign_premult_f32_d512".into(), fwht_src);
887 sources.insert("fwht_sign_undo_f32_d256".into(), fwht_src);
888 sources.insert("fwht_sign_undo_f32_d512".into(), fwht_src);
889
890 let hq_fast_src: &'static str = include_str!("shaders/hadamard_quantize_kv_fast.metal");
892 sources.insert("hadamard_quantize_kv_fast_d256".into(), hq_fast_src);
893 sources.insert("hadamard_quantize_kv_fast_d512".into(), hq_fast_src);
894 sources.insert("hadamard_quantize_kv_fast_dual_d256".into(), hq_fast_src);
896 sources.insert("hadamard_quantize_kv_fast_dual_d512".into(), hq_fast_src);
897 sources.insert("hadamard_quantize_kv_hb_d256".into(), hq_fast_src);
899 sources.insert("hadamard_quantize_kv_hb_d512".into(), hq_fast_src);
900 sources.insert("hadamard_quantize_kv_hb_dual_d256".into(), hq_fast_src);
902 sources.insert("hadamard_quantize_kv_hb_dual_d512".into(), hq_fast_src);
903 sources.insert("kv_quantize_v_no_fwht_d256".into(), hq_fast_src);
907 sources.insert("kv_quantize_v_no_fwht_d512".into(), hq_fast_src);
908 sources.insert("kv_copy_kf16_quantize_v_no_fwht_d256".into(), hq_fast_src);
912 sources.insert("kv_copy_kf16_quantize_v_no_fwht_d512".into(), hq_fast_src);
913
914 let tq_dq_src: &'static str = include_str!("shaders/tq_dequantize_kv.metal");
916 sources.insert("tq_dequantize_kv".into(), tq_dq_src);
917 sources.insert("tq_dequantize_hb_kv".into(), tq_dq_src);
919 sources.insert("tq_dequantize_hb_kv_seq".into(), tq_dq_src);
926
927 let tq_hb_src: &'static str = include_str!("shaders/flash_attn_vec_tq_hb.metal");
929 sources.insert("flash_attn_vec_tq_hb_dk256".into(), tq_hb_src);
930 sources.insert("flash_attn_vec_tq_hb_dk512".into(), tq_hb_src);
931
932 let reduce_undo_src: &'static str = include_str!("shaders/flash_attn_vec_reduce_tq_hb_undo.metal");
937 sources.insert("flash_attn_vec_reduce_tq_hb_undo_dk256".into(), reduce_undo_src);
938 sources.insert("flash_attn_vec_reduce_tq_hb_undo_dk512".into(), reduce_undo_src);
939
940 let hybrid_src: &'static str = include_str!("shaders/flash_attn_vec_hybrid.metal");
944 sources.insert("flash_attn_vec_hybrid_dk256".into(), hybrid_src);
945 sources.insert("flash_attn_vec_hybrid_dk512".into(), hybrid_src);
946
947 let peer_port_src: &'static str = include_str!("shaders/flash_attn_vec_peer_port_f16.metal");
950 sources.insert("flash_attn_vec_peer_port_f16_dk256_dv256".into(), peer_port_src);
951
952 let peer_port_reduce_src: &'static str =
955 include_str!("shaders/flash_attn_vec_peer_port_f16_reduce.metal");
956 sources.insert(
957 "flash_attn_vec_peer_port_f16_reduce_dv256_nwg32".into(),
958 peer_port_reduce_src,
959 );
960
961 let peer_port_nwg32_src: &'static str =
964 include_str!("shaders/flash_attn_vec_peer_port_f16_nwg32.metal");
965 sources.insert(
966 "flash_attn_vec_peer_port_f16_nwg32_dk256_dv256".into(),
967 peer_port_nwg32_src,
968 );
969
970 let argmax_src: &'static str = include_str!("shaders/argmax.metal");
972 sources.insert("argmax_f32".into(), argmax_src);
973 let softmax_sample_src: &'static str =
974 include_str!("shaders/softmax_sample.metal");
975 sources.insert("softmax_sample_f32".into(), softmax_sample_src);
976 let top_k_src: &'static str = include_str!("shaders/top_k.metal");
978 sources.insert("top_k_f32".into(), top_k_src);
979
980 let moe_stk_src: &'static str =
983 include_str!("shaders/moe_softmax_topk.metal");
984 sources.insert("moe_softmax_topk_f32".into(), moe_stk_src);
985 let moe_wr_src: &'static str =
986 include_str!("shaders/moe_weighted_reduce.metal");
987 sources.insert("moe_weighted_reduce_f32".into(), moe_wr_src);
988 let sdpa_decode_src: &'static str =
989 include_str!("shaders/sdpa_decode.metal");
990 sources.insert("sdpa_decode".into(), sdpa_decode_src);
991
992 Self {
993 cache: HashMap::new(),
994 sources,
995 }
996 }
997
998 pub fn register_source(&mut self, name: impl Into<String>, source: &'static str) {
1001 let name = name.into();
1002 self.cache.remove(&name);
1004 self.sources.insert(name, source);
1005 }
1006
1007 pub fn get_pipeline(
1019 &mut self,
1020 name: &str,
1021 device: &metal::DeviceRef,
1022 ) -> Result<&ComputePipelineState> {
1023 if !self.cache.contains_key(name) {
1024 let source = self.sources.get(name).ok_or_else(|| {
1026 MlxError::KernelNotFound(name.to_string())
1027 })?;
1028
1029 let compile_opts = metal::CompileOptions::new();
1030 let library = device
1031 .new_library_with_source(source, &compile_opts)
1032 .map_err(|msg| MlxError::ShaderCompilationError {
1033 name: name.to_string(),
1034 message: msg,
1035 })?;
1036
1037 let function = library
1038 .get_function(name, None)
1039 .map_err(|msg| MlxError::ShaderCompilationError {
1040 name: name.to_string(),
1041 message: msg,
1042 })?;
1043
1044 let descriptor = ComputePipelineDescriptor::new();
1054 descriptor.set_compute_function(Some(&function));
1055 descriptor.set_label(name);
1056 if std::env::var("HF2Q_PIPELINE_TG_MULT_HINT").ok().as_deref() == Some("1") {
1063 descriptor.set_thread_group_size_is_multiple_of_thread_execution_width(true);
1064 }
1065
1066 let pipeline = device
1067 .new_compute_pipeline_state(&descriptor)
1068 .map_err(|msg| MlxError::ShaderCompilationError {
1069 name: name.to_string(),
1070 message: msg,
1071 })?;
1072
1073 self.cache.insert(name.to_string(), pipeline);
1074 }
1075
1076 self.cache.get(name).ok_or_else(|| {
1079 MlxError::KernelNotFound(name.to_string())
1080 })
1081 }
1082
1083 pub fn get_pipeline_with_constants(
1105 &mut self,
1106 name: &str,
1107 device: &metal::DeviceRef,
1108 bool_constants: &[(usize, bool)],
1109 int_constants: &[(usize, i32)],
1110 ) -> Result<&ComputePipelineState> {
1111 let mut cache_key = name.to_string();
1116 for &(index, value) in bool_constants {
1117 cache_key.push('|');
1118 cache_key.push_str(&index.to_string());
1119 cache_key.push_str(if value { ":b1" } else { ":b0" });
1120 }
1121 for &(index, value) in int_constants {
1122 cache_key.push('|');
1123 cache_key.push_str(&index.to_string());
1124 cache_key.push(':');
1125 cache_key.push('i');
1126 cache_key.push_str(&value.to_string());
1127 }
1128
1129 if !self.cache.contains_key(&cache_key) {
1130 let source = self.sources.get(name).ok_or_else(|| {
1132 MlxError::KernelNotFound(name.to_string())
1133 })?;
1134
1135 let compile_opts = metal::CompileOptions::new();
1136 let library = device
1137 .new_library_with_source(source, &compile_opts)
1138 .map_err(|msg| MlxError::ShaderCompilationError {
1139 name: name.to_string(),
1140 message: msg,
1141 })?;
1142
1143 let fcv = FunctionConstantValues::new();
1148
1149 for &(index, value) in bool_constants {
1150 let v: u8 = if value { 1 } else { 0 };
1153 fcv.set_constant_value_at_index(
1154 (&v as *const u8).cast::<std::ffi::c_void>(),
1155 MTLDataType::Bool,
1156 index as u64,
1157 );
1158 }
1159
1160 for &(index, value) in int_constants {
1161 fcv.set_constant_value_at_index(
1165 (&value as *const i32).cast::<std::ffi::c_void>(),
1166 MTLDataType::Int,
1167 index as u64,
1168 );
1169 }
1170
1171 let function = library
1172 .get_function(name, Some(fcv))
1173 .map_err(|msg| MlxError::ShaderCompilationError {
1174 name: name.to_string(),
1175 message: msg,
1176 })?;
1177
1178 let descriptor = ComputePipelineDescriptor::new();
1185 descriptor.set_compute_function(Some(&function));
1186 descriptor.set_label(&cache_key);
1187 if std::env::var("HF2Q_PIPELINE_TG_MULT_HINT").ok().as_deref() == Some("1") {
1189 descriptor.set_thread_group_size_is_multiple_of_thread_execution_width(true);
1190 }
1191
1192 let pipeline = device
1193 .new_compute_pipeline_state(&descriptor)
1194 .map_err(|msg| MlxError::ShaderCompilationError {
1195 name: name.to_string(),
1196 message: msg,
1197 })?;
1198
1199 self.cache.insert(cache_key.clone(), pipeline);
1200 }
1201
1202 self.cache.get(&cache_key).ok_or_else(|| {
1203 MlxError::KernelNotFound(name.to_string())
1204 })
1205 }
1206
1207 pub fn get_pipeline_with_bool_constants(
1225 &mut self,
1226 name: &str,
1227 device: &metal::DeviceRef,
1228 bool_constants: &[(usize, bool)],
1229 ) -> Result<&ComputePipelineState> {
1230 self.get_pipeline_with_constants(name, device, bool_constants, &[])
1231 }
1232
1233 pub fn is_cached(&self, name: &str) -> bool {
1235 self.cache.contains_key(name)
1236 }
1237
1238 pub fn cached_count(&self) -> usize {
1240 self.cache.len()
1241 }
1242
1243 pub fn source_count(&self) -> usize {
1245 self.sources.len()
1246 }
1247}
1248
1249impl Default for KernelRegistry {
1250 fn default() -> Self {
1251 Self::new()
1252 }
1253}
1254
#[cfg(test)]
mod tests {
    use super::*;

    /// Kernel name of the int-function-constant test shader.
    const INT_FC_KERNEL: &str = "int_fc_test_kernel";

    /// Minimal shader that writes the int function constant at index 100.
    const INT_FC_TEST_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;

constant int test_N [[function_constant(100)]];

kernel void int_fc_test_kernel(
    device int* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]])
{
    if (tid == 0) {
        out[0] = test_N;
    }
}
"#;

    /// The system Metal device; these tests require real Metal hardware.
    fn metal_device() -> metal::Device {
        metal::Device::system_default()
            .expect("no Metal device — run on Apple Silicon or x86 Mac with Metal support")
    }

    /// Fetch the `test_N = n` specialisation and return the address of the
    /// underlying Metal pipeline object.
    ///
    /// The address comes from the `ComputePipelineStateRef` deref, which is the
    /// Objective-C object pointer. It is stable even if the registry's
    /// `HashMap` rehashes and moves the owning handle. The handle's own
    /// address (a map slot) would not be stable.
    fn int_fc_pipeline_ptr(
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
        n: i32,
        expect_msg: &str,
    ) -> *const metal::ComputePipelineStateRef {
        let pipeline: &metal::ComputePipelineStateRef = registry
            .get_pipeline_with_constants(INT_FC_KERNEL, device, &[], &[(100, n)])
            .expect(expect_msg);
        pipeline as *const metal::ComputePipelineStateRef
    }

    /// Label of the `test_N = n` specialisation.
    fn int_fc_label(
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
        n: i32,
        expect_msg: &str,
    ) -> String {
        registry
            .get_pipeline_with_constants(INT_FC_KERNEL, device, &[], &[(100, n)])
            .expect(expect_msg)
            .label()
            .to_string()
    }

    #[test]
    fn test_int_fc_distinct_pipelines_and_bool_compat() {
        let device = metal_device();
        let mut registry = KernelRegistry::new();
        registry.register_source(INT_FC_KERNEL, INT_FC_TEST_SHADER);

        let p4_ptr = int_fc_pipeline_ptr(&mut registry, &device, 4, "pipeline N=4 should compile");
        let count_after_n4 = registry.cached_count();

        let p8_ptr = int_fc_pipeline_ptr(&mut registry, &device, 8, "pipeline N=8 should compile");
        assert_eq!(
            registry.cached_count(),
            count_after_n4 + 1,
            "N=8 must produce a new cache entry"
        );
        assert_ne!(
            p4_ptr, p8_ptr,
            "N=4 and N=8 specialisations must be separate ComputePipelineState objects"
        );

        let p4_again_ptr = int_fc_pipeline_ptr(
            &mut registry,
            &device,
            4,
            "pipeline N=4 cache hit should succeed",
        );
        assert_eq!(
            registry.cached_count(),
            count_after_n4 + 1,
            "repeated N=4 call must be a cache hit, not a new entry"
        );
        assert_eq!(
            p4_ptr, p4_again_ptr,
            "repeated N=4 call must return the same pipeline pointer"
        );

        // The bool-only wrapper with an empty slice must still build and cache
        // a kernel that declares no function constants.
        const BARE_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void bare_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
    if (tid == 0) { out[0] = 42; }
}
"#;
        registry.register_source("bare_kernel", BARE_SHADER);

        let count_before_bool = registry.cached_count();
        let _bool_pipeline = registry
            .get_pipeline_with_bool_constants("bare_kernel", &device, &[])
            .expect("bool-constants wrapper with empty slice must succeed");
        assert_eq!(
            registry.cached_count(),
            count_before_bool + 1,
            "bool-constants wrapper must insert one new cache entry"
        );
    }

    #[test]
    fn test_pipeline_labels_propagate_for_mst() {
        let device = metal_device();
        let mut registry = KernelRegistry::new();
        registry.register_source(INT_FC_KERNEL, INT_FC_TEST_SHADER);

        const BARE_SHADER_LABEL_TEST: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void label_smoke_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
    if (tid == 0) { out[0] = 7; }
}
"#;
        registry.register_source("label_smoke_kernel", BARE_SHADER_LABEL_TEST);

        let plain_label = registry
            .get_pipeline("label_smoke_kernel", &device)
            .expect("plain pipeline must compile")
            .label()
            .to_string();
        assert_eq!(
            plain_label, "label_smoke_kernel",
            "get_pipeline must label the pipeline with the kernel name (xctrace MST attribution)"
        );

        let label_v7 = int_fc_label(&mut registry, &device, 7, "specialised pipeline must compile");
        assert_eq!(
            label_v7, "int_fc_test_kernel|100:i7",
            "get_pipeline_with_constants must label with the cache_key so each \
             specialisation is distinct in xctrace MST"
        );

        let label_v13 = int_fc_label(
            &mut registry,
            &device,
            13,
            "second specialised pipeline must compile",
        );
        assert_eq!(label_v13, "int_fc_test_kernel|100:i13");
        assert_ne!(
            label_v7, label_v13,
            "distinct constant values must yield distinct pipeline labels"
        );
    }
}
1474}