1use std::collections::HashMap;
9
10use metal::{ComputePipelineDescriptor, ComputePipelineState, FunctionConstantValues, MTLDataType};
11
12use crate::error::{MlxError, Result};
13
14pub struct KernelRegistry {
37 cache: HashMap<String, ComputePipelineState>,
39 sources: HashMap<String, &'static str>,
43}
44
45impl KernelRegistry {
46 pub fn new() -> Self {
50 let mut sources = HashMap::new();
51
52 sources.insert(
54 "placeholder".into(),
55 include_str!("shaders/placeholder.metal"),
56 );
57 sources.insert(
58 "quantized_matmul".into(),
59 include_str!("shaders/quantized_matmul.metal"),
60 );
61 sources.insert(
62 "quantized_matmul_simd".into(),
63 include_str!("shaders/quantized_matmul.metal"),
64 );
65 sources.insert(
66 "quantized_matmul_simd_bf16".into(),
67 include_str!("shaders/quantized_matmul.metal"),
68 );
69 sources.insert(
70 "quantized_matmul_simd_bf16_expert".into(),
71 include_str!("shaders/quantized_matmul.metal"),
72 );
73
74 let ggml_src: &'static str =
76 include_str!("shaders/quantized_matmul_ggml.metal");
77 sources.insert("kernel_mul_mv_q4_0_f32".into(), ggml_src);
78 sources.insert("kernel_mul_mv_q8_0_f32".into(), ggml_src);
79 sources.insert("kernel_mul_mv_q6_K_f32".into(), ggml_src);
80
81 let ggml_mm_src: &'static str =
87 include_str!("shaders/quantized_matmul_mm.metal");
88 sources.insert("kernel_mul_mm_q4_0_f32".into(), ggml_mm_src);
89 sources.insert("kernel_mul_mm_q8_0_f32".into(), ggml_mm_src);
90 sources.insert("kernel_mul_mm_q6_K_f32".into(), ggml_mm_src);
91
92 let ggml_mm_tensor_src: &'static str =
103 include_str!("shaders/quantized_matmul_mm_tensor.metal");
104 sources.insert("kernel_mul_mm_q4_0_tensor_f32".into(), ggml_mm_tensor_src);
105 sources.insert("kernel_mul_mm_q4_0_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
106 sources.insert("kernel_mul_mm_q6_K_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
107 sources.insert("kernel_mul_mm_q8_0_tensor_f32".into(), ggml_mm_tensor_src);
108 sources.insert("kernel_mul_mm_q6_K_tensor_f32".into(), ggml_mm_tensor_src);
109
110 let dense_mm_bf16_tensor_src: &'static str =
117 include_str!("shaders/dense_mm_bf16_tensor.metal");
118 sources.insert("hf2q_dense_mm_bf16_f32_tensor".into(), dense_mm_bf16_tensor_src);
119
120 let dense_mm_f32_f32_tensor_src: &'static str =
129 include_str!("shaders/dense_mm_f32_f32.metal");
130 sources.insert("hf2q_dense_mm_f32_f32_tensor".into(), dense_mm_f32_f32_tensor_src);
131
132 let dense_mm_f16_tensor_src: &'static str =
144 include_str!("shaders/dense_mm_f16_tensor.metal");
145 sources.insert("hf2q_dense_mm_f16_f32_tensor".into(), dense_mm_f16_tensor_src);
146
147 let dense_gemv_bf16_src: &'static str =
154 include_str!("shaders/dense_gemv_bf16.metal");
155 sources.insert("hf2q_dense_gemv_bf16_f32_4".into(), dense_gemv_bf16_src);
156
157 let scale_mask_softmax_src: &'static str =
163 include_str!("shaders/scale_mask_softmax.metal");
164 sources.insert("scale_mask_softmax_f32".into(), scale_mask_softmax_src);
165
166 sources.insert(
168 "quantized_matmul_id".into(),
169 include_str!("shaders/quantized_matmul_id.metal"),
170 );
171
172 let ggml_id_src: &'static str =
174 include_str!("shaders/quantized_matmul_id_ggml.metal");
175 sources.insert("kernel_mul_mv_id_q4_0_f32".into(), ggml_id_src);
176 sources.insert("kernel_mul_mv_id_q8_0_f32".into(), ggml_id_src);
177 sources.insert("kernel_mul_mv_id_q5_K_f32".into(), ggml_id_src);
178 sources.insert("kernel_mul_mv_id_q6_K_f32".into(), ggml_id_src);
179 sources.insert("kernel_mul_mv_id_q4_0_f32_swiglu".into(), ggml_id_src);
183
184 let ggml_id_mm_src: &'static str =
192 include_str!("shaders/quantized_matmul_id_mm.metal");
193 sources.insert("kernel_mul_mm_id_map0_ne20_1".into(), ggml_id_mm_src);
194 sources.insert("kernel_mul_mm_id_map0_ne20_8".into(), ggml_id_mm_src);
195 sources.insert("kernel_mul_mm_id_q4_0_f32".into(), ggml_id_mm_src);
196 sources.insert("kernel_mul_mm_id_q8_0_f32".into(), ggml_id_mm_src);
197 sources.insert("kernel_mul_mm_id_q6_K_f32".into(), ggml_id_mm_src);
198
199 let ggml_id_mm_tensor_src: &'static str =
205 include_str!("shaders/quantized_matmul_id_mm_tensor.metal");
206 sources.insert("kernel_mul_mm_id_q4_0_tensor_f32".into(), ggml_id_mm_tensor_src);
207 sources.insert("kernel_mul_mm_id_q8_0_tensor_f32".into(), ggml_id_mm_tensor_src);
208 sources.insert("kernel_mul_mm_id_q6_K_tensor_f32".into(), ggml_id_mm_tensor_src);
209
210 let embedding_src: &'static str = include_str!("shaders/embedding.metal");
212 sources.insert("embedding_gather_4bit".into(), embedding_src);
213 sources.insert("embedding_gather_6bit".into(), embedding_src);
214
215 let moe_gate_src: &'static str = include_str!("shaders/moe_gate.metal");
217 sources.insert("moe_gate".into(), moe_gate_src);
218
219 let moe_dispatch_src: &'static str = include_str!("shaders/moe_dispatch.metal");
221 sources.insert("fused_gelu_mul".into(), moe_dispatch_src);
222 sources.insert("moe_swiglu_fused".into(), moe_dispatch_src);
223 sources.insert("moe_swiglu_batch".into(), moe_dispatch_src);
224 sources.insert("moe_swiglu_seq".into(), moe_dispatch_src);
225 sources.insert("moe_accumulate".into(), moe_dispatch_src);
226 sources.insert("moe_weighted_sum".into(), moe_dispatch_src);
227 sources.insert("moe_weighted_sum_seq".into(), moe_dispatch_src);
228 sources.insert("zero_buffer".into(), moe_dispatch_src);
229 sources.insert("naive_matvec_f32".into(), moe_dispatch_src);
230 sources.insert("moe_gather_topk_weights".into(), moe_dispatch_src);
231 sources.insert("fused_gelu_mul_bf16".into(), moe_dispatch_src);
233 sources.insert("moe_swiglu_seq_bf16".into(), moe_dispatch_src);
234 sources.insert("moe_weighted_sum_seq_bf16_input".into(), moe_dispatch_src);
235
236 let kv_cache_src: &'static str = include_str!("shaders/kv_cache_copy.metal");
238 sources.insert("kv_cache_copy_batch_f32".into(), kv_cache_src);
239 sources.insert("kv_cache_copy_batch_f32_to_f16".into(), kv_cache_src);
240 sources.insert("kv_cache_copy_seq_f32".into(), kv_cache_src);
241 sources.insert("kv_cache_copy_seq_f32_to_f16".into(), kv_cache_src);
242 sources.insert("kv_cache_copy_seq_f32_kv_dual".into(), kv_cache_src);
244 sources.insert("kv_cache_copy_seq_f32_to_f16_kv_dual".into(), kv_cache_src);
245 sources.insert("kv_cache_copy_seq_bf16".into(), kv_cache_src);
247
248 let elementwise_src: &'static str = include_str!("shaders/elementwise.metal");
250 sources.insert("elementwise_add_f32".into(), elementwise_src);
251 sources.insert("elementwise_add_f16".into(), elementwise_src);
252 sources.insert("elementwise_mul_f32".into(), elementwise_src);
253 sources.insert("elementwise_mul_f16".into(), elementwise_src);
254 sources.insert("elementwise_add_bf16".into(), elementwise_src);
255 sources.insert("elementwise_mul_bf16".into(), elementwise_src);
256 sources.insert("cast_f16_to_f32".into(), elementwise_src);
257 sources.insert("cast_f32_to_f16".into(), elementwise_src);
258 sources.insert("cast_bf16_to_f32".into(), elementwise_src);
259 sources.insert("cast_f32_to_bf16".into(), elementwise_src);
260 sources.insert("scalar_mul_bf16".into(), elementwise_src);
261 sources.insert("scalar_mul_f32".into(), elementwise_src);
262 sources.insert("embedding_gather_scale_f32".into(), elementwise_src);
263 sources.insert("embedding_gather_scale_batch_f32".into(), elementwise_src);
264 sources.insert("permute_021_bf16".into(), elementwise_src);
265 sources.insert("transpose_last2_bf16".into(), elementwise_src);
266 sources.insert("transpose_last2_f16".into(), elementwise_src);
267 sources.insert("permute_021_f32".into(), elementwise_src);
268 sources.insert("permute_021_bf16_to_f32".into(), elementwise_src);
269 sources.insert("transpose_2d_f32".into(), elementwise_src);
270 sources.insert("transpose_2d_f16".into(), elementwise_src);
271
272 let sdpa_src: &'static str = include_str!("shaders/sdpa.metal");
274 sources.insert("sdpa".into(), sdpa_src);
275 sources.insert("sdpa_bf16".into(), sdpa_src);
276 let sdpa_sliding_src: &'static str = include_str!("shaders/sdpa_sliding.metal");
277 sources.insert("sdpa_sliding".into(), sdpa_sliding_src);
278 sources.insert("sdpa_sliding_bf16".into(), sdpa_sliding_src);
279
280 let flash_attn_prefill_src: &'static str =
285 include_str!("shaders/flash_attn_prefill.metal");
286 sources.insert(
288 "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskfloat32".into(),
289 flash_attn_prefill_src,
290 );
291 sources.insert(
292 "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
293 flash_attn_prefill_src,
294 );
295 sources.insert(
296 "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbfloat16".into(),
297 flash_attn_prefill_src,
298 );
299 sources.insert(
300 "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
301 flash_attn_prefill_src,
302 );
303 sources.insert(
304 "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskfloat16".into(),
305 flash_attn_prefill_src,
306 );
307 sources.insert(
308 "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
309 flash_attn_prefill_src,
310 );
311 sources.insert(
315 "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbfloat16".into(),
316 flash_attn_prefill_src,
317 );
318 sources.insert(
319 "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
320 flash_attn_prefill_src,
321 );
322 sources.insert(
323 "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskfloat16".into(),
324 flash_attn_prefill_src,
325 );
326 sources.insert(
327 "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
328 flash_attn_prefill_src,
329 );
330
331 let flash_attn_vec_src: &'static str =
334 include_str!("shaders/flash_attn_vec.metal");
335 sources.insert("flash_attn_vec_dk256".into(), flash_attn_vec_src);
336 sources.insert("flash_attn_vec_dk512".into(), flash_attn_vec_src);
337 sources.insert("flash_attn_vec_reduce_dk256".into(), flash_attn_vec_src);
338 sources.insert("flash_attn_vec_reduce_dk512".into(), flash_attn_vec_src);
339 sources.insert("flash_attn_vec_f16kv_dk256".into(), flash_attn_vec_src);
341 sources.insert("flash_attn_vec_f16kv_dk512".into(), flash_attn_vec_src);
342
343 let rope_src: &'static str = include_str!("shaders/rope.metal");
345 sources.insert("rope_f32".into(), rope_src);
346 sources.insert("rope_f16".into(), rope_src);
347 sources.insert("rope_bf16".into(), rope_src);
348 sources.insert("rope_neox_bf16".into(), rope_src);
349 sources.insert("rope_neox_f32".into(), rope_src);
350 let rms_norm_src: &'static str = include_str!("shaders/rms_norm.metal");
351 sources.insert("rms_norm_f32".into(), rms_norm_src);
352 sources.insert("rms_norm_f16".into(), rms_norm_src);
353 sources.insert("rms_norm_bf16".into(), rms_norm_src);
354 sources.insert("rms_norm_no_scale_bf16".into(), rms_norm_src);
355 sources.insert("rms_norm_no_scale_f32".into(), rms_norm_src);
356 sources.insert("rms_norm_no_scale_f32_dual".into(), rms_norm_src);
357 sources.insert("rms_norm_f32_triple".into(), rms_norm_src);
358 sources.insert("fused_post_attn_triple_norm_f32".into(), rms_norm_src);
359 sources.insert("rms_norm_no_scale_f32_dual_perm".into(), rms_norm_src);
360 sources.insert("rms_norm_mul_f32".into(), rms_norm_src);
362 sources.insert("rms_norm_mul_f16".into(), rms_norm_src);
363 sources.insert("rms_norm_mul_bf16".into(), rms_norm_src);
364 let l2_norm_src: &'static str = include_str!("shaders/l2_norm.metal");
366 sources.insert("l2_norm_f32".into(), l2_norm_src);
367 sources.insert("l2_norm_f16".into(), l2_norm_src);
368 sources.insert("l2_norm_bf16".into(), l2_norm_src);
369 let cumsum_src: &'static str = include_str!("shaders/cumsum.metal");
371 sources.insert("cumsum_f32".into(), cumsum_src);
372 sources.insert("cumsum_bf16".into(), cumsum_src);
373 let ssm_conv_src: &'static str = include_str!("shaders/ssm_conv.metal");
375 sources.insert("ssm_conv_forward_f32".into(), ssm_conv_src);
376 sources.insert("ssm_conv_forward_bf16".into(), ssm_conv_src);
377 sources.insert("ssm_conv_state_update_f32".into(), ssm_conv_src);
378 sources.insert("ssm_conv_state_update_bf16".into(), ssm_conv_src);
379 let tri_solve_src: &'static str = include_str!("shaders/tri_solve.metal");
381 sources.insert("tri_solve_lower_unit_f32".into(), tri_solve_src);
382 sources.insert("tri_solve_lower_unit_bf16".into(), tri_solve_src);
383 let rope_multi_src: &'static str = include_str!("shaders/rope_multi.metal");
385 sources.insert("rope_multi_f32".into(), rope_multi_src);
386 sources.insert("rope_multi_bf16".into(), rope_multi_src);
387 let gdn_src: &'static str = include_str!("shaders/gated_delta_net.metal");
389 sources.insert("gated_delta_net_f32".into(), gdn_src);
390 let gdn_decode_src: &'static str =
395 include_str!("shaders/gated_delta_net_decode.metal");
396 sources.insert("gated_delta_net_decode_f32_1".into(), gdn_decode_src);
397 sources.insert("gated_delta_net_decode_f32_2".into(), gdn_decode_src);
398 sources.insert("gated_delta_net_decode_f32_4".into(), gdn_decode_src);
399 let gdn_chunk_src: &'static str =
403 include_str!("shaders/gated_delta_net_chunk.metal");
404 sources.insert(
405 "gated_delta_net_chunk_inter_state_bf16".into(),
406 gdn_chunk_src,
407 );
408 let gdn_kkt_src: &'static str =
411 include_str!("shaders/gated_delta_net_kkt.metal");
412 sources.insert("gated_delta_net_kkt_bf16".into(), gdn_kkt_src);
413 let gdn_recompute_wu_src: &'static str =
417 include_str!("shaders/gated_delta_net_recompute_wu.metal");
418 sources.insert(
419 "gated_delta_net_recompute_wu_bf16".into(),
420 gdn_recompute_wu_src,
421 );
422 let gdn_chunk_o_src: &'static str =
425 include_str!("shaders/gated_delta_net_chunk_o.metal");
426 sources.insert("gated_delta_net_chunk_o_bf16".into(), gdn_chunk_o_src);
427 let chunk_local_cumsum_g_src: &'static str =
432 include_str!("shaders/chunk_local_cumsum_g.metal");
433 sources.insert(
434 "chunk_local_cumsum_g_f32".into(),
435 chunk_local_cumsum_g_src,
436 );
437 let chunk_tri_solve_invert_src: &'static str =
438 include_str!("shaders/chunk_gated_delta_rule_tri_solve_invert.metal");
439 sources.insert(
440 "chunk_tri_solve_invert_f32".into(),
441 chunk_tri_solve_invert_src,
442 );
443 let sigmoid_mul_src: &'static str = include_str!("shaders/sigmoid_mul.metal");
445 sources.insert("sigmoid_mul_f32".into(), sigmoid_mul_src);
446 sources.insert("sigmoid_mul_bf16".into(), sigmoid_mul_src);
447 let silu_mul_src: &'static str = include_str!("shaders/silu_mul.metal");
448 sources.insert("silu_mul_f32".into(), silu_mul_src);
449 let compute_g_beta_src: &'static str = include_str!("shaders/compute_g_beta.metal");
450 sources.insert("compute_g_beta_f32".into(), compute_g_beta_src);
451 let ssm_norm_gate_src: &'static str = include_str!("shaders/ssm_norm_gate.metal");
452 sources.insert("ssm_norm_gate_f32".into(), ssm_norm_gate_src);
453 let gelu_src: &'static str = include_str!("shaders/gelu.metal");
454 sources.insert("gelu_f32".into(), gelu_src);
455 sources.insert("gelu_f16".into(), gelu_src);
456 sources.insert("gelu_bf16".into(), gelu_src);
457 let softmax_src: &'static str = include_str!("shaders/softmax.metal");
458 sources.insert("softmax_f32".into(), softmax_src);
459 sources.insert("softmax_f16".into(), softmax_src);
460 sources.insert("softmax_bf16".into(), softmax_src);
461 let softcap_src: &'static str = include_str!("shaders/softcap.metal");
462 sources.insert("softcap_f32".into(), softcap_src);
463 sources.insert("softcap_f16".into(), softcap_src);
464 sources.insert("softcap_bf16".into(), softcap_src);
465
466 let fused_norm_add_src: &'static str =
469 include_str!("shaders/fused_norm_add_bf16.metal");
470 sources.insert("fused_norm_add_bf16".into(), fused_norm_add_src);
471 sources.insert("fused_norm_add_no_weight_bf16".into(), fused_norm_add_src);
472
473 let fused_hnr_f32_src: &'static str =
475 include_str!("shaders/fused_head_norm_rope_f32.metal");
476 sources.insert("fused_head_norm_rope_f32".into(), fused_hnr_f32_src);
477
478 let fused_hnr_bf16_src: &'static str =
481 include_str!("shaders/fused_head_norm_rope_bf16.metal");
482 sources.insert("fused_head_norm_rope_bf16".into(), fused_hnr_bf16_src);
483 sources.insert("fused_head_norm_rope_batch_bf16".into(), fused_hnr_bf16_src);
484
485 let fused_norm_add_f32_src: &'static str =
487 include_str!("shaders/fused_norm_add_f32.metal");
488 sources.insert("fused_norm_add_f32".into(), fused_norm_add_f32_src);
489 sources.insert("fused_residual_norm_f32".into(), fused_norm_add_f32_src);
490 sources.insert("fused_residual_norm_scalar_f32".into(), fused_norm_add_f32_src);
491 sources.insert("fused_moe_routing_f32".into(), fused_norm_add_f32_src);
492 sources.insert("fused_moe_routing_batch_f32".into(), fused_norm_add_f32_src);
493 sources.insert("fused_norm_add_scalar_f32".into(), fused_norm_add_f32_src);
494 sources.insert("fused_moe_wsum_norm_add_f32".into(), fused_norm_add_f32_src);
495 sources.insert("fused_moe_wsum_dnorm_add_f32".into(), fused_norm_add_f32_src);
496
497 let argsort_src: &'static str = include_str!("shaders/argsort.metal");
499 sources.insert("argsort_desc_f32".into(), argsort_src);
500
501 let gather_src: &'static str = include_str!("shaders/gather.metal");
503 sources.insert("gather_f32".into(), gather_src);
504
505 let kv_cache_copy_src: &'static str =
507 include_str!("shaders/kv_cache_copy.metal");
508 sources.insert("kv_cache_copy".into(), kv_cache_copy_src);
509 sources.insert("kv_cache_copy_f32".into(), kv_cache_copy_src);
510
511 let copy_src: &'static str = include_str!("shaders/copy.metal");
513 sources.insert("strided_copy_f32".into(), copy_src);
514 sources.insert("offset_copy_f32".into(), copy_src);
515
516 let qkv_split_src: &'static str = include_str!("shaders/qkv_split.metal");
520 sources.insert("qkv_split_f32".into(), qkv_split_src);
521
522 let repeat_tiled_src: &'static str =
526 include_str!("shaders/repeat_tiled.metal");
527 sources.insert("repeat_tiled_f32".into(), repeat_tiled_src);
528
529 let dense_gemm_src: &'static str = include_str!("shaders/dense_gemm.metal");
531 sources.insert("dense_gemm_f16".into(), dense_gemm_src);
532 sources.insert("dense_matvec_f16".into(), dense_gemm_src);
533 sources.insert("dense_matvec_f16w_f32io".into(), dense_gemm_src);
534 sources.insert("dense_matvec_bf16w_f32io".into(), dense_gemm_src);
536 sources.insert("dense_matvec_f32".into(), dense_gemm_src);
538
539 let fwht_src: &'static str = include_str!("shaders/fwht_standalone.metal");
541 sources.insert("fwht_standalone_f32_d256".into(), fwht_src);
542 sources.insert("fwht_standalone_f32_d512".into(), fwht_src);
543 sources.insert("fwht_sign_premult_f32_d256".into(), fwht_src);
545 sources.insert("fwht_sign_premult_f32_d512".into(), fwht_src);
546 sources.insert("fwht_sign_undo_f32_d256".into(), fwht_src);
547 sources.insert("fwht_sign_undo_f32_d512".into(), fwht_src);
548
549 let hq_fast_src: &'static str = include_str!("shaders/hadamard_quantize_kv_fast.metal");
551 sources.insert("hadamard_quantize_kv_fast_d256".into(), hq_fast_src);
552 sources.insert("hadamard_quantize_kv_fast_d512".into(), hq_fast_src);
553 sources.insert("hadamard_quantize_kv_hb_d256".into(), hq_fast_src);
555 sources.insert("hadamard_quantize_kv_hb_d512".into(), hq_fast_src);
556
557 let tq_dq_src: &'static str = include_str!("shaders/tq_dequantize_kv.metal");
559 sources.insert("tq_dequantize_kv".into(), tq_dq_src);
560 sources.insert("tq_dequantize_hb_kv".into(), tq_dq_src);
562
563 let tq_hb_src: &'static str = include_str!("shaders/flash_attn_vec_tq_hb.metal");
565 sources.insert("flash_attn_vec_tq_hb_dk256".into(), tq_hb_src);
566 sources.insert("flash_attn_vec_tq_hb_dk512".into(), tq_hb_src);
567
568 let argmax_src: &'static str = include_str!("shaders/argmax.metal");
570 sources.insert("argmax_f32".into(), argmax_src);
571 let softmax_sample_src: &'static str =
572 include_str!("shaders/softmax_sample.metal");
573 sources.insert("softmax_sample_f32".into(), softmax_sample_src);
574 let top_k_src: &'static str = include_str!("shaders/top_k.metal");
576 sources.insert("top_k_f32".into(), top_k_src);
577
578 let moe_stk_src: &'static str =
581 include_str!("shaders/moe_softmax_topk.metal");
582 sources.insert("moe_softmax_topk_f32".into(), moe_stk_src);
583 let moe_wr_src: &'static str =
584 include_str!("shaders/moe_weighted_reduce.metal");
585 sources.insert("moe_weighted_reduce_f32".into(), moe_wr_src);
586 let sdpa_decode_src: &'static str =
587 include_str!("shaders/sdpa_decode.metal");
588 sources.insert("sdpa_decode".into(), sdpa_decode_src);
589
590 Self {
591 cache: HashMap::new(),
592 sources,
593 }
594 }
595
596 pub fn register_source(&mut self, name: impl Into<String>, source: &'static str) {
599 let name = name.into();
600 self.cache.remove(&name);
602 self.sources.insert(name, source);
603 }
604
605 pub fn get_pipeline(
617 &mut self,
618 name: &str,
619 device: &metal::DeviceRef,
620 ) -> Result<&ComputePipelineState> {
621 if !self.cache.contains_key(name) {
622 let source = self.sources.get(name).ok_or_else(|| {
624 MlxError::KernelNotFound(name.to_string())
625 })?;
626
627 let compile_opts = metal::CompileOptions::new();
628 let library = device
629 .new_library_with_source(source, &compile_opts)
630 .map_err(|msg| MlxError::ShaderCompilationError {
631 name: name.to_string(),
632 message: msg,
633 })?;
634
635 let function = library
636 .get_function(name, None)
637 .map_err(|msg| MlxError::ShaderCompilationError {
638 name: name.to_string(),
639 message: msg,
640 })?;
641
642 let descriptor = ComputePipelineDescriptor::new();
652 descriptor.set_compute_function(Some(&function));
653 descriptor.set_label(name);
654
655 let pipeline = device
656 .new_compute_pipeline_state(&descriptor)
657 .map_err(|msg| MlxError::ShaderCompilationError {
658 name: name.to_string(),
659 message: msg,
660 })?;
661
662 self.cache.insert(name.to_string(), pipeline);
663 }
664
665 self.cache.get(name).ok_or_else(|| {
668 MlxError::KernelNotFound(name.to_string())
669 })
670 }
671
672 pub fn get_pipeline_with_constants(
694 &mut self,
695 name: &str,
696 device: &metal::DeviceRef,
697 bool_constants: &[(usize, bool)],
698 int_constants: &[(usize, i32)],
699 ) -> Result<&ComputePipelineState> {
700 let mut cache_key = name.to_string();
705 for &(index, value) in bool_constants {
706 cache_key.push('|');
707 cache_key.push_str(&index.to_string());
708 cache_key.push_str(if value { ":b1" } else { ":b0" });
709 }
710 for &(index, value) in int_constants {
711 cache_key.push('|');
712 cache_key.push_str(&index.to_string());
713 cache_key.push(':');
714 cache_key.push('i');
715 cache_key.push_str(&value.to_string());
716 }
717
718 if !self.cache.contains_key(&cache_key) {
719 let source = self.sources.get(name).ok_or_else(|| {
721 MlxError::KernelNotFound(name.to_string())
722 })?;
723
724 let compile_opts = metal::CompileOptions::new();
725 let library = device
726 .new_library_with_source(source, &compile_opts)
727 .map_err(|msg| MlxError::ShaderCompilationError {
728 name: name.to_string(),
729 message: msg,
730 })?;
731
732 let fcv = FunctionConstantValues::new();
737
738 for &(index, value) in bool_constants {
739 let v: u8 = if value { 1 } else { 0 };
742 fcv.set_constant_value_at_index(
743 (&v as *const u8).cast::<std::ffi::c_void>(),
744 MTLDataType::Bool,
745 index as u64,
746 );
747 }
748
749 for &(index, value) in int_constants {
750 fcv.set_constant_value_at_index(
754 (&value as *const i32).cast::<std::ffi::c_void>(),
755 MTLDataType::Int,
756 index as u64,
757 );
758 }
759
760 let function = library
761 .get_function(name, Some(fcv))
762 .map_err(|msg| MlxError::ShaderCompilationError {
763 name: name.to_string(),
764 message: msg,
765 })?;
766
767 let descriptor = ComputePipelineDescriptor::new();
774 descriptor.set_compute_function(Some(&function));
775 descriptor.set_label(&cache_key);
776
777 let pipeline = device
778 .new_compute_pipeline_state(&descriptor)
779 .map_err(|msg| MlxError::ShaderCompilationError {
780 name: name.to_string(),
781 message: msg,
782 })?;
783
784 self.cache.insert(cache_key.clone(), pipeline);
785 }
786
787 self.cache.get(&cache_key).ok_or_else(|| {
788 MlxError::KernelNotFound(name.to_string())
789 })
790 }
791
792 pub fn get_pipeline_with_bool_constants(
810 &mut self,
811 name: &str,
812 device: &metal::DeviceRef,
813 bool_constants: &[(usize, bool)],
814 ) -> Result<&ComputePipelineState> {
815 self.get_pipeline_with_constants(name, device, bool_constants, &[])
816 }
817
818 pub fn is_cached(&self, name: &str) -> bool {
820 self.cache.contains_key(name)
821 }
822
823 pub fn cached_count(&self) -> usize {
825 self.cache.len()
826 }
827
828 pub fn source_count(&self) -> usize {
830 self.sources.len()
831 }
832}
833
834impl Default for KernelRegistry {
835 fn default() -> Self {
836 Self::new()
837 }
838}
839
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal kernel that reads one int function constant (index 100) and
    /// writes it to `out[0]`.
    const INT_FC_TEST_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;

constant int test_N [[function_constant(100)]];

kernel void int_fc_test_kernel(
    device int* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]])
{
    if (tid == 0) {
        out[0] = test_N;
    }
}
"#;

    /// Fetch the `int_fc_test_kernel` pipeline specialised with constant
    /// 100 = `n` and return the address of the cached pipeline entry.
    fn int_fc_pipeline_ptr(
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
        n: i32,
        failure: &str,
    ) -> *const ComputePipelineState {
        let pipeline = registry
            .get_pipeline_with_constants("int_fc_test_kernel", device, &[], &[(100, n)])
            .expect(failure);
        pipeline as *const ComputePipelineState
    }

    /// Fetch the `int_fc_test_kernel` pipeline specialised with constant
    /// 100 = `n` and return its label.
    fn int_fc_pipeline_label(
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
        n: i32,
        failure: &str,
    ) -> String {
        registry
            .get_pipeline_with_constants("int_fc_test_kernel", device, &[], &[(100, n)])
            .expect(failure)
            .label()
            .to_string()
    }

    /// Distinct int constant values must compile to distinct cache entries,
    /// a repeated value must hit the cache, and the bool-only wrapper must
    /// still work with an empty constant list.
    #[test]
    fn test_int_fc_distinct_pipelines_and_bool_compat() {
        let device = metal::Device::system_default()
            .expect("no Metal device — run on Apple Silicon or x86 Mac with Metal support");

        let mut registry = KernelRegistry::new();
        registry.register_source("int_fc_test_kernel", INT_FC_TEST_SHADER);

        // NOTE(review): the pointers compared below are addresses of values
        // stored inside the registry's HashMap. They stay comparable only as
        // long as the map does not reallocate between the calls — confirm
        // this still holds if more insertions are added to this test.
        let first_n4 =
            int_fc_pipeline_ptr(&mut registry, &device, 4_i32, "pipeline N=4 should compile");
        let entries_after_n4 = registry.cached_count();

        let first_n8 =
            int_fc_pipeline_ptr(&mut registry, &device, 8_i32, "pipeline N=8 should compile");

        assert_eq!(
            registry.cached_count(),
            entries_after_n4 + 1,
            "N=8 must produce a new cache entry"
        );
        assert_ne!(
            first_n4, first_n8,
            "N=4 and N=8 specialisations must be separate ComputePipelineState objects"
        );

        let second_n4 = int_fc_pipeline_ptr(
            &mut registry,
            &device,
            4_i32,
            "pipeline N=4 cache hit should succeed",
        );

        assert_eq!(
            registry.cached_count(),
            entries_after_n4 + 1,
            "repeated N=4 call must be a cache hit, not a new entry"
        );
        assert_eq!(
            first_n4, second_n4,
            "repeated N=4 call must return the same pipeline pointer"
        );

        // A kernel without any function constants, used to exercise the
        // bool-only wrapper with an empty constant list.
        const BARE_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void bare_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
    if (tid == 0) { out[0] = 42; }
}
"#;
        registry.register_source("bare_kernel", BARE_SHADER);

        let entries_before_bool = registry.cached_count();
        let _bool_pipeline = registry
            .get_pipeline_with_bool_constants("bare_kernel", &device, &[])
            .expect("bool-constants wrapper with empty slice must succeed");

        assert_eq!(
            registry.cached_count(),
            entries_before_bool + 1,
            "bool-constants wrapper must insert one new cache entry"
        );
    }

    /// Pipelines must carry a label: the kernel name for plain pipelines and
    /// the full cache key for specialised ones.
    #[test]
    fn test_pipeline_labels_propagate_for_mst() {
        let device = metal::Device::system_default()
            .expect("no Metal device — run on Apple Silicon or x86 Mac with Metal support");

        let mut registry = KernelRegistry::new();
        registry.register_source("int_fc_test_kernel", INT_FC_TEST_SHADER);

        const BARE_SHADER_LABEL_TEST: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void label_smoke_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
    if (tid == 0) { out[0] = 7; }
}
"#;
        registry.register_source("label_smoke_kernel", BARE_SHADER_LABEL_TEST);

        let plain_label = registry
            .get_pipeline("label_smoke_kernel", &device)
            .expect("plain pipeline must compile")
            .label()
            .to_string();
        assert_eq!(
            plain_label, "label_smoke_kernel",
            "get_pipeline must label the pipeline with the kernel name (xctrace MST attribution)"
        );

        let label_v7 = int_fc_pipeline_label(
            &mut registry,
            &device,
            7_i32,
            "specialised pipeline must compile",
        );
        assert_eq!(
            label_v7, "int_fc_test_kernel|100:i7",
            "get_pipeline_with_constants must label with the cache_key so each \
             specialisation is distinct in xctrace MST"
        );

        let label_v13 = int_fc_pipeline_label(
            &mut registry,
            &device,
            13_i32,
            "second specialised pipeline must compile",
        );
        assert_eq!(label_v13, "int_fc_test_kernel|100:i13");
        assert_ne!(
            label_v7, label_v13,
            "distinct constant values must yield distinct pipeline labels"
        );
    }
}