1use std::collections::HashMap;
9
10use metal::{ComputePipelineDescriptor, ComputePipelineState, FunctionConstantValues, MTLDataType};
11
12use crate::error::{MlxError, Result};
13
14pub struct KernelRegistry {
37 cache: HashMap<String, ComputePipelineState>,
39 sources: HashMap<String, &'static str>,
43}
44
45impl KernelRegistry {
46 pub fn new() -> Self {
50 let mut sources = HashMap::new();
51
52 sources.insert(
54 "placeholder".into(),
55 include_str!("shaders/placeholder.metal"),
56 );
57 sources.insert(
58 "quantized_matmul".into(),
59 include_str!("shaders/quantized_matmul.metal"),
60 );
61 sources.insert(
62 "quantized_matmul_simd".into(),
63 include_str!("shaders/quantized_matmul.metal"),
64 );
65 sources.insert(
66 "quantized_matmul_simd_bf16".into(),
67 include_str!("shaders/quantized_matmul.metal"),
68 );
69 sources.insert(
70 "quantized_matmul_simd_bf16_expert".into(),
71 include_str!("shaders/quantized_matmul.metal"),
72 );
73
74 let ggml_src: &'static str =
76 include_str!("shaders/quantized_matmul_ggml.metal");
77 sources.insert("kernel_mul_mv_q4_0_f32".into(), ggml_src);
78 sources.insert("kernel_mul_mv_q8_0_f32".into(), ggml_src);
79 sources.insert("kernel_mul_mv_q6_K_f32".into(), ggml_src);
80 sources.insert("kernel_mul_mv_q4_K_f32".into(), ggml_src);
83
84 let ggml_mm_src: &'static str =
90 include_str!("shaders/quantized_matmul_mm.metal");
91 sources.insert("kernel_mul_mm_q4_0_f32".into(), ggml_mm_src);
92 sources.insert("kernel_mul_mm_q8_0_f32".into(), ggml_mm_src);
93 sources.insert("kernel_mul_mm_q6_K_f32".into(), ggml_mm_src);
94
95 let ggml_mm_tensor_src: &'static str =
106 include_str!("shaders/quantized_matmul_mm_tensor.metal");
107 sources.insert("kernel_mul_mm_q4_0_tensor_f32".into(), ggml_mm_tensor_src);
108 sources.insert("kernel_mul_mm_q4_0_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
109 sources.insert("kernel_mul_mm_q6_K_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
110 sources.insert("kernel_mul_mm_q8_0_tensor_f32".into(), ggml_mm_tensor_src);
111 sources.insert("kernel_mul_mm_q6_K_tensor_f32".into(), ggml_mm_tensor_src);
112
113 let dense_mm_bf16_tensor_src: &'static str =
120 include_str!("shaders/dense_mm_bf16_tensor.metal");
121 sources.insert("hf2q_dense_mm_bf16_f32_tensor".into(), dense_mm_bf16_tensor_src);
122
123 let dense_mm_f32_f32_tensor_src: &'static str =
132 include_str!("shaders/dense_mm_f32_f32.metal");
133 sources.insert("hf2q_dense_mm_f32_f32_tensor".into(), dense_mm_f32_f32_tensor_src);
134
135 let dense_mm_f16_tensor_src: &'static str =
147 include_str!("shaders/dense_mm_f16_tensor.metal");
148 sources.insert("hf2q_dense_mm_f16_f32_tensor".into(), dense_mm_f16_tensor_src);
149
150 let dense_gemv_bf16_src: &'static str =
157 include_str!("shaders/dense_gemv_bf16.metal");
158 sources.insert("hf2q_dense_gemv_bf16_f32_4".into(), dense_gemv_bf16_src);
159
160 let scale_mask_softmax_src: &'static str =
166 include_str!("shaders/scale_mask_softmax.metal");
167 sources.insert("scale_mask_softmax_f32".into(), scale_mask_softmax_src);
168
169 sources.insert(
171 "quantized_matmul_id".into(),
172 include_str!("shaders/quantized_matmul_id.metal"),
173 );
174
175 let ggml_id_src: &'static str =
177 include_str!("shaders/quantized_matmul_id_ggml.metal");
178 sources.insert("kernel_mul_mv_id_q4_0_f32".into(), ggml_id_src);
179 sources.insert("kernel_mul_mv_id_q8_0_f32".into(), ggml_id_src);
180 sources.insert("kernel_mul_mv_id_q4_K_f32".into(), ggml_id_src);
183 sources.insert("kernel_mul_mv_id_q5_K_f32".into(), ggml_id_src);
184 sources.insert("kernel_mul_mv_id_q6_K_f32".into(), ggml_id_src);
185 sources.insert("kernel_mul_mv_id_q4_0_f32_swiglu".into(), ggml_id_src);
189
190 let ggml_id_mm_src: &'static str =
198 include_str!("shaders/quantized_matmul_id_mm.metal");
199 sources.insert("kernel_mul_mm_id_map0_ne20_1".into(), ggml_id_mm_src);
200 sources.insert("kernel_mul_mm_id_map0_ne20_8".into(), ggml_id_mm_src);
201 sources.insert("kernel_mul_mm_id_q4_0_f32".into(), ggml_id_mm_src);
202 sources.insert("kernel_mul_mm_id_q8_0_f32".into(), ggml_id_mm_src);
203 sources.insert("kernel_mul_mm_id_q6_K_f32".into(), ggml_id_mm_src);
204 sources.insert("kernel_mul_mm_id_q4_K_f32".into(), ggml_id_mm_src);
206
207 let ggml_id_mm_tensor_src: &'static str =
213 include_str!("shaders/quantized_matmul_id_mm_tensor.metal");
214 sources.insert("kernel_mul_mm_id_q4_0_tensor_f32".into(), ggml_id_mm_tensor_src);
215 sources.insert("kernel_mul_mm_id_q8_0_tensor_f32".into(), ggml_id_mm_tensor_src);
216 sources.insert("kernel_mul_mm_id_q6_K_tensor_f32".into(), ggml_id_mm_tensor_src);
217 sources.insert("kernel_mul_mm_id_q4_K_tensor_f32".into(), ggml_id_mm_tensor_src);
219
220 let embedding_src: &'static str = include_str!("shaders/embedding.metal");
222 sources.insert("embedding_gather_4bit".into(), embedding_src);
223 sources.insert("embedding_gather_6bit".into(), embedding_src);
224
225 let moe_gate_src: &'static str = include_str!("shaders/moe_gate.metal");
227 sources.insert("moe_gate".into(), moe_gate_src);
228
229 let moe_dispatch_src: &'static str = include_str!("shaders/moe_dispatch.metal");
231 sources.insert("fused_gelu_mul".into(), moe_dispatch_src);
232 sources.insert("moe_swiglu_fused".into(), moe_dispatch_src);
233 sources.insert("moe_swiglu_batch".into(), moe_dispatch_src);
234 sources.insert("moe_swiglu_seq".into(), moe_dispatch_src);
235 sources.insert("moe_accumulate".into(), moe_dispatch_src);
236 sources.insert("moe_weighted_sum".into(), moe_dispatch_src);
237 sources.insert("moe_weighted_sum_seq".into(), moe_dispatch_src);
238 sources.insert("zero_buffer".into(), moe_dispatch_src);
239 sources.insert("naive_matvec_f32".into(), moe_dispatch_src);
240 sources.insert("moe_gather_topk_weights".into(), moe_dispatch_src);
241 sources.insert("fused_gelu_mul_bf16".into(), moe_dispatch_src);
243 sources.insert("moe_swiglu_seq_bf16".into(), moe_dispatch_src);
244 sources.insert("moe_weighted_sum_seq_bf16_input".into(), moe_dispatch_src);
245
246 let kv_cache_src: &'static str = include_str!("shaders/kv_cache_copy.metal");
248 sources.insert("kv_cache_copy_batch_f32".into(), kv_cache_src);
249 sources.insert("kv_cache_copy_batch_f32_to_f16".into(), kv_cache_src);
250 sources.insert("kv_cache_copy_seq_f32".into(), kv_cache_src);
251 sources.insert("kv_cache_copy_seq_f32_to_f16".into(), kv_cache_src);
252 sources.insert("kv_cache_copy_seq_f32_kv_dual".into(), kv_cache_src);
254 sources.insert("kv_cache_copy_seq_f32_to_f16_kv_dual".into(), kv_cache_src);
255 sources.insert("kv_cache_copy_seq_bf16".into(), kv_cache_src);
257
258 let elementwise_src: &'static str = include_str!("shaders/elementwise.metal");
260 sources.insert("elementwise_add_f32".into(), elementwise_src);
261 sources.insert("elementwise_add_f16".into(), elementwise_src);
262 sources.insert("elementwise_mul_f32".into(), elementwise_src);
263 sources.insert("elementwise_mul_f16".into(), elementwise_src);
264 sources.insert("elementwise_add_bf16".into(), elementwise_src);
265 sources.insert("elementwise_mul_bf16".into(), elementwise_src);
266 sources.insert("cast_f16_to_f32".into(), elementwise_src);
267 sources.insert("cast_f32_to_f16".into(), elementwise_src);
268 sources.insert("cast_bf16_to_f32".into(), elementwise_src);
269 sources.insert("cast_f32_to_bf16".into(), elementwise_src);
270 sources.insert("scalar_mul_bf16".into(), elementwise_src);
271 sources.insert("scalar_mul_f32".into(), elementwise_src);
272 sources.insert("embedding_gather_scale_f32".into(), elementwise_src);
273 sources.insert("embedding_gather_scale_batch_f32".into(), elementwise_src);
274 sources.insert("permute_021_bf16".into(), elementwise_src);
275 sources.insert("transpose_last2_bf16".into(), elementwise_src);
276 sources.insert("transpose_last2_f16".into(), elementwise_src);
277 sources.insert("permute_021_f32".into(), elementwise_src);
278 sources.insert("permute_021_bf16_to_f32".into(), elementwise_src);
279 sources.insert("transpose_2d_f32".into(), elementwise_src);
280 sources.insert("transpose_2d_f16".into(), elementwise_src);
281
282 let sdpa_src: &'static str = include_str!("shaders/sdpa.metal");
284 sources.insert("sdpa".into(), sdpa_src);
285 sources.insert("sdpa_bf16".into(), sdpa_src);
286 let sdpa_sliding_src: &'static str = include_str!("shaders/sdpa_sliding.metal");
287 sources.insert("sdpa_sliding".into(), sdpa_sliding_src);
288 sources.insert("sdpa_sliding_bf16".into(), sdpa_sliding_src);
289
290 let flash_attn_prefill_src: &'static str =
295 include_str!("shaders/flash_attn_prefill.metal");
296 sources.insert(
298 "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskfloat32".into(),
299 flash_attn_prefill_src,
300 );
301 sources.insert(
302 "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
303 flash_attn_prefill_src,
304 );
305 sources.insert(
306 "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbfloat16".into(),
307 flash_attn_prefill_src,
308 );
309 sources.insert(
310 "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
311 flash_attn_prefill_src,
312 );
313 sources.insert(
314 "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskfloat16".into(),
315 flash_attn_prefill_src,
316 );
317 sources.insert(
318 "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
319 flash_attn_prefill_src,
320 );
321 sources.insert(
325 "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbfloat16".into(),
326 flash_attn_prefill_src,
327 );
328 sources.insert(
329 "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
330 flash_attn_prefill_src,
331 );
332 sources.insert(
333 "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskfloat16".into(),
334 flash_attn_prefill_src,
335 );
336 sources.insert(
337 "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
338 flash_attn_prefill_src,
339 );
340
341 let flash_attn_vec_src: &'static str =
344 include_str!("shaders/flash_attn_vec.metal");
345 sources.insert("flash_attn_vec_dk256".into(), flash_attn_vec_src);
346 sources.insert("flash_attn_vec_dk512".into(), flash_attn_vec_src);
347 sources.insert("flash_attn_vec_reduce_dk256".into(), flash_attn_vec_src);
348 sources.insert("flash_attn_vec_reduce_dk512".into(), flash_attn_vec_src);
349 sources.insert("flash_attn_vec_f16kv_dk256".into(), flash_attn_vec_src);
351 sources.insert("flash_attn_vec_f16kv_dk512".into(), flash_attn_vec_src);
352
353 let rope_src: &'static str = include_str!("shaders/rope.metal");
355 sources.insert("rope_f32".into(), rope_src);
356 sources.insert("rope_f16".into(), rope_src);
357 sources.insert("rope_bf16".into(), rope_src);
358 sources.insert("rope_neox_bf16".into(), rope_src);
359 sources.insert("rope_neox_f32".into(), rope_src);
360 let rms_norm_src: &'static str = include_str!("shaders/rms_norm.metal");
361 sources.insert("rms_norm_f32".into(), rms_norm_src);
362 sources.insert("rms_norm_f16".into(), rms_norm_src);
363 sources.insert("rms_norm_bf16".into(), rms_norm_src);
364 sources.insert("rms_norm_no_scale_bf16".into(), rms_norm_src);
365 sources.insert("rms_norm_no_scale_f32".into(), rms_norm_src);
366 sources.insert("rms_norm_no_scale_f32_dual".into(), rms_norm_src);
367 sources.insert("rms_norm_f32_triple".into(), rms_norm_src);
368 sources.insert("fused_post_attn_triple_norm_f32".into(), rms_norm_src);
369 sources.insert("rms_norm_no_scale_f32_dual_perm".into(), rms_norm_src);
370 sources.insert("rms_norm_mul_f32".into(), rms_norm_src);
372 sources.insert("rms_norm_mul_f16".into(), rms_norm_src);
373 sources.insert("rms_norm_mul_bf16".into(), rms_norm_src);
374 let l2_norm_src: &'static str = include_str!("shaders/l2_norm.metal");
376 sources.insert("l2_norm_f32".into(), l2_norm_src);
377 sources.insert("l2_norm_f16".into(), l2_norm_src);
378 sources.insert("l2_norm_bf16".into(), l2_norm_src);
379 let cumsum_src: &'static str = include_str!("shaders/cumsum.metal");
381 sources.insert("cumsum_f32".into(), cumsum_src);
382 sources.insert("cumsum_bf16".into(), cumsum_src);
383 let ssm_conv_src: &'static str = include_str!("shaders/ssm_conv.metal");
385 sources.insert("ssm_conv_forward_f32".into(), ssm_conv_src);
386 sources.insert("ssm_conv_forward_bf16".into(), ssm_conv_src);
387 sources.insert("ssm_conv_state_update_f32".into(), ssm_conv_src);
388 sources.insert("ssm_conv_state_update_bf16".into(), ssm_conv_src);
389 let tri_solve_src: &'static str = include_str!("shaders/tri_solve.metal");
391 sources.insert("tri_solve_lower_unit_f32".into(), tri_solve_src);
392 sources.insert("tri_solve_lower_unit_bf16".into(), tri_solve_src);
393 let rope_multi_src: &'static str = include_str!("shaders/rope_multi.metal");
395 sources.insert("rope_multi_f32".into(), rope_multi_src);
396 sources.insert("rope_multi_bf16".into(), rope_multi_src);
397 let gdn_src: &'static str = include_str!("shaders/gated_delta_net.metal");
399 sources.insert("gated_delta_net_f32".into(), gdn_src);
400 let gdn_decode_src: &'static str =
405 include_str!("shaders/gated_delta_net_decode.metal");
406 sources.insert("gated_delta_net_decode_f32_1".into(), gdn_decode_src);
407 sources.insert("gated_delta_net_decode_f32_2".into(), gdn_decode_src);
408 sources.insert("gated_delta_net_decode_f32_4".into(), gdn_decode_src);
409 let gdn_chunk_src: &'static str =
413 include_str!("shaders/gated_delta_net_chunk.metal");
414 sources.insert(
415 "gated_delta_net_chunk_inter_state_bf16".into(),
416 gdn_chunk_src,
417 );
418 let gdn_kkt_src: &'static str =
421 include_str!("shaders/gated_delta_net_kkt.metal");
422 sources.insert("gated_delta_net_kkt_bf16".into(), gdn_kkt_src);
423 let gdn_recompute_wu_src: &'static str =
427 include_str!("shaders/gated_delta_net_recompute_wu.metal");
428 sources.insert(
429 "gated_delta_net_recompute_wu_bf16".into(),
430 gdn_recompute_wu_src,
431 );
432 let gdn_chunk_o_src: &'static str =
435 include_str!("shaders/gated_delta_net_chunk_o.metal");
436 sources.insert("gated_delta_net_chunk_o_bf16".into(), gdn_chunk_o_src);
437 let chunk_local_cumsum_g_src: &'static str =
442 include_str!("shaders/chunk_local_cumsum_g.metal");
443 sources.insert(
444 "chunk_local_cumsum_g_f32".into(),
445 chunk_local_cumsum_g_src,
446 );
447 let chunk_tri_solve_invert_src: &'static str =
448 include_str!("shaders/chunk_gated_delta_rule_tri_solve_invert.metal");
449 sources.insert(
450 "chunk_tri_solve_invert_f32".into(),
451 chunk_tri_solve_invert_src,
452 );
453 let sigmoid_mul_src: &'static str = include_str!("shaders/sigmoid_mul.metal");
455 sources.insert("sigmoid_mul_f32".into(), sigmoid_mul_src);
456 sources.insert("sigmoid_mul_bf16".into(), sigmoid_mul_src);
457 let silu_mul_src: &'static str = include_str!("shaders/silu_mul.metal");
458 sources.insert("silu_mul_f32".into(), silu_mul_src);
459 let compute_g_beta_src: &'static str = include_str!("shaders/compute_g_beta.metal");
460 sources.insert("compute_g_beta_f32".into(), compute_g_beta_src);
461 let ssm_norm_gate_src: &'static str = include_str!("shaders/ssm_norm_gate.metal");
462 sources.insert("ssm_norm_gate_f32".into(), ssm_norm_gate_src);
463 let gelu_src: &'static str = include_str!("shaders/gelu.metal");
464 sources.insert("gelu_f32".into(), gelu_src);
465 sources.insert("gelu_f16".into(), gelu_src);
466 sources.insert("gelu_bf16".into(), gelu_src);
467 let softmax_src: &'static str = include_str!("shaders/softmax.metal");
468 sources.insert("softmax_f32".into(), softmax_src);
469 sources.insert("softmax_f16".into(), softmax_src);
470 sources.insert("softmax_bf16".into(), softmax_src);
471 let softcap_src: &'static str = include_str!("shaders/softcap.metal");
472 sources.insert("softcap_f32".into(), softcap_src);
473 sources.insert("softcap_f16".into(), softcap_src);
474 sources.insert("softcap_bf16".into(), softcap_src);
475
476 let fused_norm_add_src: &'static str =
479 include_str!("shaders/fused_norm_add_bf16.metal");
480 sources.insert("fused_norm_add_bf16".into(), fused_norm_add_src);
481 sources.insert("fused_norm_add_no_weight_bf16".into(), fused_norm_add_src);
482
483 let fused_hnr_f32_src: &'static str =
485 include_str!("shaders/fused_head_norm_rope_f32.metal");
486 sources.insert("fused_head_norm_rope_f32".into(), fused_hnr_f32_src);
487
488 let fused_hnr_bf16_src: &'static str =
491 include_str!("shaders/fused_head_norm_rope_bf16.metal");
492 sources.insert("fused_head_norm_rope_bf16".into(), fused_hnr_bf16_src);
493 sources.insert("fused_head_norm_rope_batch_bf16".into(), fused_hnr_bf16_src);
494
495 let fused_norm_add_f32_src: &'static str =
497 include_str!("shaders/fused_norm_add_f32.metal");
498 sources.insert("fused_norm_add_f32".into(), fused_norm_add_f32_src);
499 sources.insert("fused_residual_norm_f32".into(), fused_norm_add_f32_src);
500 sources.insert("fused_residual_norm_scalar_f32".into(), fused_norm_add_f32_src);
501 sources.insert("fused_moe_routing_f32".into(), fused_norm_add_f32_src);
502 sources.insert("fused_moe_routing_batch_f32".into(), fused_norm_add_f32_src);
503 sources.insert("fused_norm_add_scalar_f32".into(), fused_norm_add_f32_src);
504 sources.insert("fused_moe_wsum_norm_add_f32".into(), fused_norm_add_f32_src);
505 sources.insert("fused_moe_wsum_dnorm_add_f32".into(), fused_norm_add_f32_src);
506
507 let argsort_src: &'static str = include_str!("shaders/argsort.metal");
509 sources.insert("argsort_desc_f32".into(), argsort_src);
510
511 let gather_src: &'static str = include_str!("shaders/gather.metal");
513 sources.insert("gather_f32".into(), gather_src);
514
515 let kv_cache_copy_src: &'static str =
517 include_str!("shaders/kv_cache_copy.metal");
518 sources.insert("kv_cache_copy".into(), kv_cache_copy_src);
519 sources.insert("kv_cache_copy_f32".into(), kv_cache_copy_src);
520
521 let copy_src: &'static str = include_str!("shaders/copy.metal");
523 sources.insert("strided_copy_f32".into(), copy_src);
524 sources.insert("offset_copy_f32".into(), copy_src);
525
526 let qkv_split_src: &'static str = include_str!("shaders/qkv_split.metal");
530 sources.insert("qkv_split_f32".into(), qkv_split_src);
531
532 let repeat_tiled_src: &'static str =
536 include_str!("shaders/repeat_tiled.metal");
537 sources.insert("repeat_tiled_f32".into(), repeat_tiled_src);
538
539 let dense_gemm_src: &'static str = include_str!("shaders/dense_gemm.metal");
541 sources.insert("dense_gemm_f16".into(), dense_gemm_src);
542 sources.insert("dense_matvec_f16".into(), dense_gemm_src);
543 sources.insert("dense_matvec_f16w_f32io".into(), dense_gemm_src);
544 sources.insert("dense_matvec_bf16w_f32io".into(), dense_gemm_src);
546 sources.insert("dense_matvec_f32".into(), dense_gemm_src);
548
549 let fwht_src: &'static str = include_str!("shaders/fwht_standalone.metal");
551 sources.insert("fwht_standalone_f32_d256".into(), fwht_src);
552 sources.insert("fwht_standalone_f32_d512".into(), fwht_src);
553 sources.insert("fwht_sign_premult_f32_d256".into(), fwht_src);
555 sources.insert("fwht_sign_premult_f32_d512".into(), fwht_src);
556 sources.insert("fwht_sign_undo_f32_d256".into(), fwht_src);
557 sources.insert("fwht_sign_undo_f32_d512".into(), fwht_src);
558
559 let hq_fast_src: &'static str = include_str!("shaders/hadamard_quantize_kv_fast.metal");
561 sources.insert("hadamard_quantize_kv_fast_d256".into(), hq_fast_src);
562 sources.insert("hadamard_quantize_kv_fast_d512".into(), hq_fast_src);
563 sources.insert("hadamard_quantize_kv_hb_d256".into(), hq_fast_src);
565 sources.insert("hadamard_quantize_kv_hb_d512".into(), hq_fast_src);
566
567 let tq_dq_src: &'static str = include_str!("shaders/tq_dequantize_kv.metal");
569 sources.insert("tq_dequantize_kv".into(), tq_dq_src);
570 sources.insert("tq_dequantize_hb_kv".into(), tq_dq_src);
572
573 let tq_hb_src: &'static str = include_str!("shaders/flash_attn_vec_tq_hb.metal");
575 sources.insert("flash_attn_vec_tq_hb_dk256".into(), tq_hb_src);
576 sources.insert("flash_attn_vec_tq_hb_dk512".into(), tq_hb_src);
577
578 let argmax_src: &'static str = include_str!("shaders/argmax.metal");
580 sources.insert("argmax_f32".into(), argmax_src);
581 let softmax_sample_src: &'static str =
582 include_str!("shaders/softmax_sample.metal");
583 sources.insert("softmax_sample_f32".into(), softmax_sample_src);
584 let top_k_src: &'static str = include_str!("shaders/top_k.metal");
586 sources.insert("top_k_f32".into(), top_k_src);
587
588 let moe_stk_src: &'static str =
591 include_str!("shaders/moe_softmax_topk.metal");
592 sources.insert("moe_softmax_topk_f32".into(), moe_stk_src);
593 let moe_wr_src: &'static str =
594 include_str!("shaders/moe_weighted_reduce.metal");
595 sources.insert("moe_weighted_reduce_f32".into(), moe_wr_src);
596 let sdpa_decode_src: &'static str =
597 include_str!("shaders/sdpa_decode.metal");
598 sources.insert("sdpa_decode".into(), sdpa_decode_src);
599
600 Self {
601 cache: HashMap::new(),
602 sources,
603 }
604 }
605
606 pub fn register_source(&mut self, name: impl Into<String>, source: &'static str) {
609 let name = name.into();
610 self.cache.remove(&name);
612 self.sources.insert(name, source);
613 }
614
615 pub fn get_pipeline(
627 &mut self,
628 name: &str,
629 device: &metal::DeviceRef,
630 ) -> Result<&ComputePipelineState> {
631 if !self.cache.contains_key(name) {
632 let source = self.sources.get(name).ok_or_else(|| {
634 MlxError::KernelNotFound(name.to_string())
635 })?;
636
637 let compile_opts = metal::CompileOptions::new();
638 let library = device
639 .new_library_with_source(source, &compile_opts)
640 .map_err(|msg| MlxError::ShaderCompilationError {
641 name: name.to_string(),
642 message: msg,
643 })?;
644
645 let function = library
646 .get_function(name, None)
647 .map_err(|msg| MlxError::ShaderCompilationError {
648 name: name.to_string(),
649 message: msg,
650 })?;
651
652 let descriptor = ComputePipelineDescriptor::new();
662 descriptor.set_compute_function(Some(&function));
663 descriptor.set_label(name);
664
665 let pipeline = device
666 .new_compute_pipeline_state(&descriptor)
667 .map_err(|msg| MlxError::ShaderCompilationError {
668 name: name.to_string(),
669 message: msg,
670 })?;
671
672 self.cache.insert(name.to_string(), pipeline);
673 }
674
675 self.cache.get(name).ok_or_else(|| {
678 MlxError::KernelNotFound(name.to_string())
679 })
680 }
681
682 pub fn get_pipeline_with_constants(
704 &mut self,
705 name: &str,
706 device: &metal::DeviceRef,
707 bool_constants: &[(usize, bool)],
708 int_constants: &[(usize, i32)],
709 ) -> Result<&ComputePipelineState> {
710 let mut cache_key = name.to_string();
715 for &(index, value) in bool_constants {
716 cache_key.push('|');
717 cache_key.push_str(&index.to_string());
718 cache_key.push_str(if value { ":b1" } else { ":b0" });
719 }
720 for &(index, value) in int_constants {
721 cache_key.push('|');
722 cache_key.push_str(&index.to_string());
723 cache_key.push(':');
724 cache_key.push('i');
725 cache_key.push_str(&value.to_string());
726 }
727
728 if !self.cache.contains_key(&cache_key) {
729 let source = self.sources.get(name).ok_or_else(|| {
731 MlxError::KernelNotFound(name.to_string())
732 })?;
733
734 let compile_opts = metal::CompileOptions::new();
735 let library = device
736 .new_library_with_source(source, &compile_opts)
737 .map_err(|msg| MlxError::ShaderCompilationError {
738 name: name.to_string(),
739 message: msg,
740 })?;
741
742 let fcv = FunctionConstantValues::new();
747
748 for &(index, value) in bool_constants {
749 let v: u8 = if value { 1 } else { 0 };
752 fcv.set_constant_value_at_index(
753 (&v as *const u8).cast::<std::ffi::c_void>(),
754 MTLDataType::Bool,
755 index as u64,
756 );
757 }
758
759 for &(index, value) in int_constants {
760 fcv.set_constant_value_at_index(
764 (&value as *const i32).cast::<std::ffi::c_void>(),
765 MTLDataType::Int,
766 index as u64,
767 );
768 }
769
770 let function = library
771 .get_function(name, Some(fcv))
772 .map_err(|msg| MlxError::ShaderCompilationError {
773 name: name.to_string(),
774 message: msg,
775 })?;
776
777 let descriptor = ComputePipelineDescriptor::new();
784 descriptor.set_compute_function(Some(&function));
785 descriptor.set_label(&cache_key);
786
787 let pipeline = device
788 .new_compute_pipeline_state(&descriptor)
789 .map_err(|msg| MlxError::ShaderCompilationError {
790 name: name.to_string(),
791 message: msg,
792 })?;
793
794 self.cache.insert(cache_key.clone(), pipeline);
795 }
796
797 self.cache.get(&cache_key).ok_or_else(|| {
798 MlxError::KernelNotFound(name.to_string())
799 })
800 }
801
802 pub fn get_pipeline_with_bool_constants(
820 &mut self,
821 name: &str,
822 device: &metal::DeviceRef,
823 bool_constants: &[(usize, bool)],
824 ) -> Result<&ComputePipelineState> {
825 self.get_pipeline_with_constants(name, device, bool_constants, &[])
826 }
827
828 pub fn is_cached(&self, name: &str) -> bool {
830 self.cache.contains_key(name)
831 }
832
833 pub fn cached_count(&self) -> usize {
835 self.cache.len()
836 }
837
838 pub fn source_count(&self) -> usize {
840 self.sources.len()
841 }
842}
843
844impl Default for KernelRegistry {
845 fn default() -> Self {
846 Self::new()
847 }
848}
849
#[cfg(test)]
mod tests {
    use super::*;

    /// Kernel that writes the int function constant at index 100 to `out[0]`.
    const INT_FC_TEST_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;

constant int test_N [[function_constant(100)]];

kernel void int_fc_test_kernel(
    device int* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]])
{
    if (tid == 0) {
        out[0] = test_N;
    }
}
"#;

    /// Kernel without function constants, used to exercise the bool wrapper.
    const BARE_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void bare_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
    if (tid == 0) { out[0] = 42; }
}
"#;

    /// Kernel without function constants, used by the label test.
    const BARE_SHADER_LABEL_TEST: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void label_smoke_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
    if (tid == 0) { out[0] = 7; }
}
"#;

    /// Returns the system default Metal device, panicking when there is none.
    fn default_device() -> metal::Device {
        metal::Device::system_default()
            .expect("no Metal device — run on Apple Silicon or x86 Mac with Metal support")
    }

    /// Requests `int_fc_test_kernel` specialised with constant 100 = `n` and
    /// returns the address of the cached pipeline value.
    ///
    /// NOTE(review): this is the address of the value stored inside the
    /// registry's `HashMap`, not of the underlying Metal object. Two entries
    /// always have different addresses, and an address is only stable while
    /// the map does not reallocate.
    fn specialised_slot(
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
        n: i32,
        failure: &str,
    ) -> *const ComputePipelineState {
        let pipeline = registry
            .get_pipeline_with_constants("int_fc_test_kernel", device, &[], &[(100, n)])
            .expect(failure);
        pipeline as *const ComputePipelineState
    }

    /// Requests `int_fc_test_kernel` specialised with constant 100 = `n` and
    /// returns the label of the resulting pipeline.
    fn specialised_label(
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
        n: i32,
        failure: &str,
    ) -> String {
        registry
            .get_pipeline_with_constants("int_fc_test_kernel", device, &[], &[(100, n)])
            .expect(failure)
            .label()
            .to_string()
    }

    /// Different int constants create separate cache entries, a repeated
    /// request is a cache hit, and the bool-only wrapper works with an empty
    /// constant list.
    #[test]
    fn test_int_fc_distinct_pipelines_and_bool_compat() {
        let device = default_device();
        let mut registry = KernelRegistry::new();
        registry.register_source("int_fc_test_kernel", INT_FC_TEST_SHADER);

        // First specialisation: N = 4.
        let p4_ptr = specialised_slot(&mut registry, &device, 4, "pipeline N=4 should compile");
        let count_after_n4 = registry.cached_count();

        // Second specialisation: N = 8 must add exactly one entry.
        let p8_ptr = specialised_slot(&mut registry, &device, 8, "pipeline N=8 should compile");
        assert_eq!(
            registry.cached_count(),
            count_after_n4 + 1,
            "N=8 must produce a new cache entry"
        );
        assert_ne!(
            p4_ptr, p8_ptr,
            "N=4 and N=8 specialisations must be separate ComputePipelineState objects"
        );

        // Asking for N = 4 again must reuse the existing entry.
        let p4_again_ptr = specialised_slot(
            &mut registry,
            &device,
            4,
            "pipeline N=4 cache hit should succeed",
        );
        assert_eq!(
            registry.cached_count(),
            count_after_n4 + 1,
            "repeated N=4 call must be a cache hit, not a new entry"
        );
        assert_eq!(
            p4_ptr, p4_again_ptr,
            "repeated N=4 call must return the same pipeline pointer"
        );

        // The bool-only wrapper with no constants still compiles and caches.
        registry.register_source("bare_kernel", BARE_SHADER);
        let count_before_bool = registry.cached_count();
        registry
            .get_pipeline_with_bool_constants("bare_kernel", &device, &[])
            .expect("bool-constants wrapper with empty slice must succeed");
        assert_eq!(
            registry.cached_count(),
            count_before_bool + 1,
            "bool-constants wrapper must insert one new cache entry"
        );
    }

    /// Pipelines carry a label: the kernel name for plain pipelines and the
    /// full cache key for specialised ones.
    #[test]
    fn test_pipeline_labels_propagate_for_mst() {
        let device = default_device();
        let mut registry = KernelRegistry::new();
        registry.register_source("int_fc_test_kernel", INT_FC_TEST_SHADER);
        registry.register_source("label_smoke_kernel", BARE_SHADER_LABEL_TEST);

        let plain_label = registry
            .get_pipeline("label_smoke_kernel", &device)
            .expect("plain pipeline must compile")
            .label()
            .to_string();
        assert_eq!(
            plain_label, "label_smoke_kernel",
            "get_pipeline must label the pipeline with the kernel name (xctrace MST attribution)"
        );

        let label_v7 = specialised_label(
            &mut registry,
            &device,
            7,
            "specialised pipeline must compile",
        );
        assert_eq!(
            label_v7, "int_fc_test_kernel|100:i7",
            "get_pipeline_with_constants must label with the cache_key so each \
             specialisation is distinct in xctrace MST"
        );

        let label_v13 = specialised_label(
            &mut registry,
            &device,
            13,
            "second specialised pipeline must compile",
        );
        assert_eq!(label_v13, "int_fc_test_kernel|100:i13");
        assert_ne!(
            label_v7, label_v13,
            "distinct constant values must yield distinct pipeline labels"
        );
    }
}