1use std::collections::HashMap;
9
10use metal::{ComputePipelineDescriptor, ComputePipelineState, FunctionConstantValues, MTLDataType};
11
12use crate::error::{MlxError, Result};
13
14pub struct KernelRegistry {
37 cache: HashMap<String, ComputePipelineState>,
39 sources: HashMap<String, &'static str>,
43}
44
45impl KernelRegistry {
46 pub fn new() -> Self {
50 let mut sources = HashMap::new();
51
52 sources.insert(
54 "placeholder".into(),
55 include_str!("shaders/placeholder.metal"),
56 );
57 sources.insert(
58 "quantized_matmul".into(),
59 include_str!("shaders/quantized_matmul.metal"),
60 );
61 sources.insert(
62 "quantized_matmul_simd".into(),
63 include_str!("shaders/quantized_matmul.metal"),
64 );
65 sources.insert(
66 "quantized_matmul_simd_bf16".into(),
67 include_str!("shaders/quantized_matmul.metal"),
68 );
69 sources.insert(
70 "quantized_matmul_simd_bf16_expert".into(),
71 include_str!("shaders/quantized_matmul.metal"),
72 );
73
74 let ggml_src: &'static str =
76 include_str!("shaders/quantized_matmul_ggml.metal");
77 sources.insert("kernel_mul_mv_q4_0_f32".into(), ggml_src);
78 sources.insert("kernel_mul_mv_q8_0_f32".into(), ggml_src);
79 sources.insert("kernel_mul_mv_q6_K_f32".into(), ggml_src);
80 sources.insert("kernel_mul_mv_q4_K_f32".into(), ggml_src);
83
84 let ggml_mm_src: &'static str =
90 include_str!("shaders/quantized_matmul_mm.metal");
91 sources.insert("kernel_mul_mm_q4_0_f32".into(), ggml_mm_src);
92 sources.insert("kernel_mul_mm_q8_0_f32".into(), ggml_mm_src);
93 sources.insert("kernel_mul_mm_q6_K_f32".into(), ggml_mm_src);
94
95 let ggml_mm_tensor_src: &'static str =
106 include_str!("shaders/quantized_matmul_mm_tensor.metal");
107 sources.insert("kernel_mul_mm_q4_0_tensor_f32".into(), ggml_mm_tensor_src);
108 sources.insert("kernel_mul_mm_q4_0_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
109 sources.insert("kernel_mul_mm_q6_K_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
110 sources.insert("kernel_mul_mm_q8_0_tensor_f32".into(), ggml_mm_tensor_src);
111 sources.insert("kernel_mul_mm_q6_K_tensor_f32".into(), ggml_mm_tensor_src);
112
113 let dense_mm_bf16_tensor_src: &'static str =
120 include_str!("shaders/dense_mm_bf16_tensor.metal");
121 sources.insert("hf2q_dense_mm_bf16_f32_tensor".into(), dense_mm_bf16_tensor_src);
122
123 let dense_mm_f32_f32_tensor_src: &'static str =
132 include_str!("shaders/dense_mm_f32_f32.metal");
133 sources.insert("hf2q_dense_mm_f32_f32_tensor".into(), dense_mm_f32_f32_tensor_src);
134
135 let dense_mm_f16_tensor_src: &'static str =
147 include_str!("shaders/dense_mm_f16_tensor.metal");
148 sources.insert("hf2q_dense_mm_f16_f32_tensor".into(), dense_mm_f16_tensor_src);
149
150 let dense_gemv_bf16_src: &'static str =
157 include_str!("shaders/dense_gemv_bf16.metal");
158 sources.insert("hf2q_dense_gemv_bf16_f32_4".into(), dense_gemv_bf16_src);
159
160 let scale_mask_softmax_src: &'static str =
166 include_str!("shaders/scale_mask_softmax.metal");
167 sources.insert("scale_mask_softmax_f32".into(), scale_mask_softmax_src);
168
169 sources.insert(
171 "quantized_matmul_id".into(),
172 include_str!("shaders/quantized_matmul_id.metal"),
173 );
174
175 let ggml_id_src: &'static str =
177 include_str!("shaders/quantized_matmul_id_ggml.metal");
178 sources.insert("kernel_mul_mv_id_q4_0_f32".into(), ggml_id_src);
179 sources.insert("kernel_mul_mv_id_q8_0_f32".into(), ggml_id_src);
180 sources.insert("kernel_mul_mv_id_q4_K_f32".into(), ggml_id_src);
183 sources.insert("kernel_mul_mv_id_q5_K_f32".into(), ggml_id_src);
184 sources.insert("kernel_mul_mv_id_q6_K_f32".into(), ggml_id_src);
185 sources.insert("kernel_mul_mv_id_q4_0_f32_swiglu".into(), ggml_id_src);
189
190 let ggml_id_mm_src: &'static str =
198 include_str!("shaders/quantized_matmul_id_mm.metal");
199 sources.insert("kernel_mul_mm_id_map0_ne20_1".into(), ggml_id_mm_src);
200 sources.insert("kernel_mul_mm_id_map0_ne20_8".into(), ggml_id_mm_src);
201 sources.insert("kernel_mul_mm_id_q4_0_f32".into(), ggml_id_mm_src);
202 sources.insert("kernel_mul_mm_id_q8_0_f32".into(), ggml_id_mm_src);
203 sources.insert("kernel_mul_mm_id_q6_K_f32".into(), ggml_id_mm_src);
204 sources.insert("kernel_mul_mm_id_q4_K_f32".into(), ggml_id_mm_src);
206
207 let ggml_id_mm_tensor_src: &'static str =
213 include_str!("shaders/quantized_matmul_id_mm_tensor.metal");
214 sources.insert("kernel_mul_mm_id_q4_0_tensor_f32".into(), ggml_id_mm_tensor_src);
215 sources.insert("kernel_mul_mm_id_q8_0_tensor_f32".into(), ggml_id_mm_tensor_src);
216 sources.insert("kernel_mul_mm_id_q6_K_tensor_f32".into(), ggml_id_mm_tensor_src);
217 sources.insert("kernel_mul_mm_id_q4_K_tensor_f32".into(), ggml_id_mm_tensor_src);
219
220 let embedding_src: &'static str = include_str!("shaders/embedding.metal");
222 sources.insert("embedding_gather_4bit".into(), embedding_src);
223 sources.insert("embedding_gather_6bit".into(), embedding_src);
224
225 let moe_gate_src: &'static str = include_str!("shaders/moe_gate.metal");
227 sources.insert("moe_gate".into(), moe_gate_src);
228
229 let moe_dispatch_src: &'static str = include_str!("shaders/moe_dispatch.metal");
231 sources.insert("fused_gelu_mul".into(), moe_dispatch_src);
232 sources.insert("moe_swiglu_fused".into(), moe_dispatch_src);
233 sources.insert("moe_swiglu_batch".into(), moe_dispatch_src);
234 sources.insert("moe_swiglu_seq".into(), moe_dispatch_src);
235 sources.insert("moe_accumulate".into(), moe_dispatch_src);
236 sources.insert("moe_weighted_sum".into(), moe_dispatch_src);
237 sources.insert("moe_weighted_sum_seq".into(), moe_dispatch_src);
238 sources.insert("zero_buffer".into(), moe_dispatch_src);
239 sources.insert("naive_matvec_f32".into(), moe_dispatch_src);
240 sources.insert("moe_gather_topk_weights".into(), moe_dispatch_src);
241 sources.insert("fused_gelu_mul_bf16".into(), moe_dispatch_src);
243 sources.insert("moe_swiglu_seq_bf16".into(), moe_dispatch_src);
244 sources.insert("moe_weighted_sum_seq_bf16_input".into(), moe_dispatch_src);
245
246 let kv_cache_src: &'static str = include_str!("shaders/kv_cache_copy.metal");
248 sources.insert("kv_cache_copy_batch_f32".into(), kv_cache_src);
249 sources.insert("kv_cache_copy_batch_f32_to_f16".into(), kv_cache_src);
250 sources.insert("kv_cache_copy_seq_f32".into(), kv_cache_src);
251 sources.insert("kv_cache_copy_seq_f32_to_f16".into(), kv_cache_src);
252 sources.insert("kv_cache_copy_seq_f32_kv_dual".into(), kv_cache_src);
254 sources.insert("kv_cache_copy_seq_f32_to_f16_kv_dual".into(), kv_cache_src);
255 sources.insert("kv_cache_copy_seq_bf16".into(), kv_cache_src);
257
258 let elementwise_src: &'static str = include_str!("shaders/elementwise.metal");
260 sources.insert("elementwise_add_f32".into(), elementwise_src);
261 sources.insert("elementwise_add_f16".into(), elementwise_src);
262 sources.insert("elementwise_mul_f32".into(), elementwise_src);
263 sources.insert("elementwise_mul_f16".into(), elementwise_src);
264 sources.insert("elementwise_add_bf16".into(), elementwise_src);
265 sources.insert("elementwise_mul_bf16".into(), elementwise_src);
266 sources.insert("cast_f16_to_f32".into(), elementwise_src);
267 sources.insert("cast_f32_to_f16".into(), elementwise_src);
268 sources.insert("cast_bf16_to_f32".into(), elementwise_src);
269 sources.insert("cast_f32_to_bf16".into(), elementwise_src);
270 sources.insert("scalar_mul_bf16".into(), elementwise_src);
271 sources.insert("scalar_mul_f32".into(), elementwise_src);
272 sources.insert("embedding_gather_scale_f32".into(), elementwise_src);
273 sources.insert("embedding_gather_scale_batch_f32".into(), elementwise_src);
274 sources.insert("permute_021_bf16".into(), elementwise_src);
275 sources.insert("transpose_last2_bf16".into(), elementwise_src);
276 sources.insert("transpose_last2_f16".into(), elementwise_src);
277 sources.insert("permute_021_f32".into(), elementwise_src);
278 sources.insert("permute_021_bf16_to_f32".into(), elementwise_src);
279 sources.insert("transpose_2d_f32".into(), elementwise_src);
280 sources.insert("transpose_2d_f16".into(), elementwise_src);
281
282 let sdpa_src: &'static str = include_str!("shaders/sdpa.metal");
284 sources.insert("sdpa".into(), sdpa_src);
285 sources.insert("sdpa_bf16".into(), sdpa_src);
286 let sdpa_sliding_src: &'static str = include_str!("shaders/sdpa_sliding.metal");
287 sources.insert("sdpa_sliding".into(), sdpa_sliding_src);
288 sources.insert("sdpa_sliding_bf16".into(), sdpa_sliding_src);
289
290 let flash_attn_prefill_src: &'static str =
295 include_str!("shaders/flash_attn_prefill.metal");
296 sources.insert(
298 "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskfloat32".into(),
299 flash_attn_prefill_src,
300 );
301 sources.insert(
302 "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
303 flash_attn_prefill_src,
304 );
305 sources.insert(
306 "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbfloat16".into(),
307 flash_attn_prefill_src,
308 );
309 sources.insert(
310 "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
311 flash_attn_prefill_src,
312 );
313 sources.insert(
314 "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskfloat16".into(),
315 flash_attn_prefill_src,
316 );
317 sources.insert(
318 "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
319 flash_attn_prefill_src,
320 );
321 sources.insert(
325 "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbfloat16".into(),
326 flash_attn_prefill_src,
327 );
328 sources.insert(
329 "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
330 flash_attn_prefill_src,
331 );
332 sources.insert(
333 "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskfloat16".into(),
334 flash_attn_prefill_src,
335 );
336 sources.insert(
337 "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
338 flash_attn_prefill_src,
339 );
340
341 let flash_attn_vec_src: &'static str =
344 include_str!("shaders/flash_attn_vec.metal");
345 sources.insert("flash_attn_vec_dk256".into(), flash_attn_vec_src);
346 sources.insert("flash_attn_vec_dk512".into(), flash_attn_vec_src);
347 sources.insert("flash_attn_vec_reduce_dk256".into(), flash_attn_vec_src);
348 sources.insert("flash_attn_vec_reduce_dk512".into(), flash_attn_vec_src);
349 sources.insert("flash_attn_vec_f16kv_dk256".into(), flash_attn_vec_src);
351 sources.insert("flash_attn_vec_f16kv_dk512".into(), flash_attn_vec_src);
352
353 let rope_src: &'static str = include_str!("shaders/rope.metal");
355 sources.insert("rope_f32".into(), rope_src);
356 sources.insert("rope_f16".into(), rope_src);
357 sources.insert("rope_bf16".into(), rope_src);
358 sources.insert("rope_neox_bf16".into(), rope_src);
359 sources.insert("rope_neox_f32".into(), rope_src);
360 let rms_norm_src: &'static str = include_str!("shaders/rms_norm.metal");
361 sources.insert("rms_norm_f32".into(), rms_norm_src);
362 sources.insert("rms_norm_f16".into(), rms_norm_src);
363 sources.insert("rms_norm_bf16".into(), rms_norm_src);
364 sources.insert("rms_norm_no_scale_bf16".into(), rms_norm_src);
365 sources.insert("rms_norm_no_scale_f32".into(), rms_norm_src);
366 sources.insert("rms_norm_no_scale_f32_dual".into(), rms_norm_src);
367 sources.insert("rms_norm_f32_triple".into(), rms_norm_src);
368 sources.insert("fused_post_attn_triple_norm_f32".into(), rms_norm_src);
369 sources.insert("rms_norm_no_scale_f32_dual_perm".into(), rms_norm_src);
370 sources.insert("rms_norm_mul_f32".into(), rms_norm_src);
372 sources.insert("rms_norm_mul_f16".into(), rms_norm_src);
373 sources.insert("rms_norm_mul_bf16".into(), rms_norm_src);
374 let l2_norm_src: &'static str = include_str!("shaders/l2_norm.metal");
376 sources.insert("l2_norm_f32".into(), l2_norm_src);
377 sources.insert("l2_norm_f16".into(), l2_norm_src);
378 sources.insert("l2_norm_bf16".into(), l2_norm_src);
379 let cumsum_src: &'static str = include_str!("shaders/cumsum.metal");
381 sources.insert("cumsum_f32".into(), cumsum_src);
382 sources.insert("cumsum_bf16".into(), cumsum_src);
383 let ssm_conv_src: &'static str = include_str!("shaders/ssm_conv.metal");
385 sources.insert("ssm_conv_forward_f32".into(), ssm_conv_src);
386 sources.insert("ssm_conv_forward_bf16".into(), ssm_conv_src);
387 sources.insert("ssm_conv_state_update_f32".into(), ssm_conv_src);
388 sources.insert("ssm_conv_state_update_bf16".into(), ssm_conv_src);
389 let tri_solve_src: &'static str = include_str!("shaders/tri_solve.metal");
391 sources.insert("tri_solve_lower_unit_f32".into(), tri_solve_src);
392 sources.insert("tri_solve_lower_unit_bf16".into(), tri_solve_src);
393 let rope_multi_src: &'static str = include_str!("shaders/rope_multi.metal");
395 sources.insert("rope_multi_f32".into(), rope_multi_src);
396 sources.insert("rope_multi_bf16".into(), rope_multi_src);
397 let gdn_src: &'static str = include_str!("shaders/gated_delta_net.metal");
399 sources.insert("gated_delta_net_f32".into(), gdn_src);
400 let gdn_decode_src: &'static str =
405 include_str!("shaders/gated_delta_net_decode.metal");
406 sources.insert("gated_delta_net_decode_f32_1".into(), gdn_decode_src);
407 sources.insert("gated_delta_net_decode_f32_2".into(), gdn_decode_src);
408 sources.insert("gated_delta_net_decode_f32_4".into(), gdn_decode_src);
409 let gdn_chunk_src: &'static str =
413 include_str!("shaders/gated_delta_net_chunk.metal");
414 sources.insert(
415 "gated_delta_net_chunk_inter_state_bf16".into(),
416 gdn_chunk_src,
417 );
418 let gdn_kkt_src: &'static str =
421 include_str!("shaders/gated_delta_net_kkt.metal");
422 sources.insert("gated_delta_net_kkt_bf16".into(), gdn_kkt_src);
423 let gdn_recompute_wu_src: &'static str =
427 include_str!("shaders/gated_delta_net_recompute_wu.metal");
428 sources.insert(
429 "gated_delta_net_recompute_wu_bf16".into(),
430 gdn_recompute_wu_src,
431 );
432 let gdn_chunk_o_src: &'static str =
435 include_str!("shaders/gated_delta_net_chunk_o.metal");
436 sources.insert("gated_delta_net_chunk_o_bf16".into(), gdn_chunk_o_src);
437 let chunk_local_cumsum_g_src: &'static str =
442 include_str!("shaders/chunk_local_cumsum_g.metal");
443 sources.insert(
444 "chunk_local_cumsum_g_f32".into(),
445 chunk_local_cumsum_g_src,
446 );
447 let chunk_tri_solve_invert_src: &'static str =
448 include_str!("shaders/chunk_gated_delta_rule_tri_solve_invert.metal");
449 sources.insert(
450 "chunk_tri_solve_invert_f32".into(),
451 chunk_tri_solve_invert_src,
452 );
453 let sigmoid_mul_src: &'static str = include_str!("shaders/sigmoid_mul.metal");
455 sources.insert("sigmoid_mul_f32".into(), sigmoid_mul_src);
456 sources.insert("sigmoid_mul_bf16".into(), sigmoid_mul_src);
457 let silu_mul_src: &'static str = include_str!("shaders/silu_mul.metal");
458 sources.insert("silu_mul_f32".into(), silu_mul_src);
459 let compute_g_beta_src: &'static str = include_str!("shaders/compute_g_beta.metal");
460 sources.insert("compute_g_beta_f32".into(), compute_g_beta_src);
461 let ssm_norm_gate_src: &'static str = include_str!("shaders/ssm_norm_gate.metal");
462 sources.insert("ssm_norm_gate_f32".into(), ssm_norm_gate_src);
463 let gelu_src: &'static str = include_str!("shaders/gelu.metal");
464 sources.insert("gelu_f32".into(), gelu_src);
465 sources.insert("gelu_f16".into(), gelu_src);
466 sources.insert("gelu_bf16".into(), gelu_src);
467 let softmax_src: &'static str = include_str!("shaders/softmax.metal");
468 sources.insert("softmax_f32".into(), softmax_src);
469 sources.insert("softmax_f16".into(), softmax_src);
470 sources.insert("softmax_bf16".into(), softmax_src);
471 let softmax_backward_src: &'static str =
472 include_str!("shaders/softmax_backward.metal");
473 sources.insert("softmax_backward_f32".into(), softmax_backward_src);
474 let log_elementwise_src: &'static str =
475 include_str!("shaders/log_elementwise.metal");
476 sources.insert("log_f32".into(), log_elementwise_src);
477 sources.insert("log_backward_f32".into(), log_elementwise_src);
478 let row_sum_src: &'static str = include_str!("shaders/row_sum.metal");
479 sources.insert("row_sum_f32".into(), row_sum_src);
480 sources.insert("row_sum_backward_f32".into(), row_sum_src);
481 let qdq_legacy_src: &'static str = include_str!("shaders/qdq_legacy.metal");
485 sources.insert("qdq_q4_0_f32".into(), qdq_legacy_src);
486 sources.insert("qdq_q8_0_f32".into(), qdq_legacy_src);
487 let rms_norm_backward_src: &'static str =
491 include_str!("shaders/rms_norm_backward.metal");
492 sources.insert(
493 "rms_norm_compute_rms_inv_f32".into(),
494 rms_norm_backward_src,
495 );
496 sources.insert("rms_norm_backward_dx_f32".into(), rms_norm_backward_src);
497 sources.insert("rms_norm_backward_dw_f32".into(), rms_norm_backward_src);
498 let slice_concat_2d_src: &'static str =
503 include_str!("shaders/slice_concat_2d.metal");
504 sources.insert("slice_2d_cols_f32".into(), slice_concat_2d_src);
505 sources.insert("copy_2d_cols_into_f32".into(), slice_concat_2d_src);
506 let silu_backward_src: &'static str =
509 include_str!("shaders/silu_backward.metal");
510 sources.insert("silu_f32".into(), silu_backward_src);
511 sources.insert("silu_backward_f32".into(), silu_backward_src);
512 let embedding_autograd_src: &'static str =
514 include_str!("shaders/embedding_autograd.metal");
515 sources.insert("embedding_lookup_f32".into(), embedding_autograd_src);
516 sources.insert(
517 "embedding_scatter_add_f32".into(),
518 embedding_autograd_src,
519 );
520 let adam_update_src: &'static str =
523 include_str!("shaders/adam_update.metal");
524 sources.insert("adam_update_f32".into(), adam_update_src);
525 let qdq_affine_src: &'static str =
529 include_str!("shaders/qdq_affine.metal");
530 sources.insert("qdq_affine_init_f32".into(), qdq_affine_src);
531 sources.insert("qdq_affine_forward_f32".into(), qdq_affine_src);
532 sources.insert(
533 "qdq_affine_backward_scales_f32".into(),
534 qdq_affine_src,
535 );
536 sources.insert(
537 "qdq_affine_backward_biases_f32".into(),
538 qdq_affine_src,
539 );
540 let qmm_affine_src: &'static str =
544 include_str!("shaders/qmm_affine.metal");
545 sources.insert("qmm_affine_t_f32".into(), qmm_affine_src);
546 let qmm_affine_tiled_src: &'static str =
550 include_str!("shaders/qmm_affine_tiled.metal");
551 sources.insert(
552 "qmm_affine_t_f32_tiled".into(),
553 qmm_affine_tiled_src,
554 );
555 let softcap_src: &'static str = include_str!("shaders/softcap.metal");
556 sources.insert("softcap_f32".into(), softcap_src);
557 sources.insert("softcap_f16".into(), softcap_src);
558 sources.insert("softcap_bf16".into(), softcap_src);
559
560 let fused_norm_add_src: &'static str =
563 include_str!("shaders/fused_norm_add_bf16.metal");
564 sources.insert("fused_norm_add_bf16".into(), fused_norm_add_src);
565 sources.insert("fused_norm_add_no_weight_bf16".into(), fused_norm_add_src);
566
567 let fused_hnr_f32_src: &'static str =
569 include_str!("shaders/fused_head_norm_rope_f32.metal");
570 sources.insert("fused_head_norm_rope_f32".into(), fused_hnr_f32_src);
571
572 let fused_hnr_bf16_src: &'static str =
575 include_str!("shaders/fused_head_norm_rope_bf16.metal");
576 sources.insert("fused_head_norm_rope_bf16".into(), fused_hnr_bf16_src);
577 sources.insert("fused_head_norm_rope_batch_bf16".into(), fused_hnr_bf16_src);
578
579 let fused_norm_add_f32_src: &'static str =
581 include_str!("shaders/fused_norm_add_f32.metal");
582 sources.insert("fused_norm_add_f32".into(), fused_norm_add_f32_src);
583 sources.insert("fused_residual_norm_f32".into(), fused_norm_add_f32_src);
584 sources.insert("fused_residual_norm_scalar_f32".into(), fused_norm_add_f32_src);
585 sources.insert("fused_moe_routing_f32".into(), fused_norm_add_f32_src);
586 sources.insert("fused_moe_routing_batch_f32".into(), fused_norm_add_f32_src);
587 sources.insert("fused_norm_add_scalar_f32".into(), fused_norm_add_f32_src);
588 sources.insert("fused_moe_wsum_norm_add_f32".into(), fused_norm_add_f32_src);
589 sources.insert("fused_moe_wsum_dnorm_add_f32".into(), fused_norm_add_f32_src);
590
591 let argsort_src: &'static str = include_str!("shaders/argsort.metal");
593 sources.insert("argsort_desc_f32".into(), argsort_src);
594
595 let gather_src: &'static str = include_str!("shaders/gather.metal");
597 sources.insert("gather_f32".into(), gather_src);
598
599 let kv_cache_copy_src: &'static str =
601 include_str!("shaders/kv_cache_copy.metal");
602 sources.insert("kv_cache_copy".into(), kv_cache_copy_src);
603 sources.insert("kv_cache_copy_f32".into(), kv_cache_copy_src);
604
605 let copy_src: &'static str = include_str!("shaders/copy.metal");
607 sources.insert("strided_copy_f32".into(), copy_src);
608 sources.insert("offset_copy_f32".into(), copy_src);
609
610 let qkv_split_src: &'static str = include_str!("shaders/qkv_split.metal");
614 sources.insert("qkv_split_f32".into(), qkv_split_src);
615
616 let repeat_tiled_src: &'static str =
620 include_str!("shaders/repeat_tiled.metal");
621 sources.insert("repeat_tiled_f32".into(), repeat_tiled_src);
622
623 let dense_gemm_src: &'static str = include_str!("shaders/dense_gemm.metal");
625 sources.insert("dense_gemm_f16".into(), dense_gemm_src);
626 sources.insert("dense_matvec_f16".into(), dense_gemm_src);
627 sources.insert("dense_matvec_f16w_f32io".into(), dense_gemm_src);
628 sources.insert("dense_matvec_bf16w_f32io".into(), dense_gemm_src);
630 sources.insert("dense_matvec_f32".into(), dense_gemm_src);
632
633 let fwht_src: &'static str = include_str!("shaders/fwht_standalone.metal");
635 sources.insert("fwht_standalone_f32_d256".into(), fwht_src);
636 sources.insert("fwht_standalone_f32_d512".into(), fwht_src);
637 sources.insert("fwht_sign_premult_f32_d256".into(), fwht_src);
639 sources.insert("fwht_sign_premult_f32_d512".into(), fwht_src);
640 sources.insert("fwht_sign_undo_f32_d256".into(), fwht_src);
641 sources.insert("fwht_sign_undo_f32_d512".into(), fwht_src);
642
643 let hq_fast_src: &'static str = include_str!("shaders/hadamard_quantize_kv_fast.metal");
645 sources.insert("hadamard_quantize_kv_fast_d256".into(), hq_fast_src);
646 sources.insert("hadamard_quantize_kv_fast_d512".into(), hq_fast_src);
647 sources.insert("hadamard_quantize_kv_hb_d256".into(), hq_fast_src);
649 sources.insert("hadamard_quantize_kv_hb_d512".into(), hq_fast_src);
650
651 let tq_dq_src: &'static str = include_str!("shaders/tq_dequantize_kv.metal");
653 sources.insert("tq_dequantize_kv".into(), tq_dq_src);
654 sources.insert("tq_dequantize_hb_kv".into(), tq_dq_src);
656
657 let tq_hb_src: &'static str = include_str!("shaders/flash_attn_vec_tq_hb.metal");
659 sources.insert("flash_attn_vec_tq_hb_dk256".into(), tq_hb_src);
660 sources.insert("flash_attn_vec_tq_hb_dk512".into(), tq_hb_src);
661
662 let argmax_src: &'static str = include_str!("shaders/argmax.metal");
664 sources.insert("argmax_f32".into(), argmax_src);
665 let softmax_sample_src: &'static str =
666 include_str!("shaders/softmax_sample.metal");
667 sources.insert("softmax_sample_f32".into(), softmax_sample_src);
668 let top_k_src: &'static str = include_str!("shaders/top_k.metal");
670 sources.insert("top_k_f32".into(), top_k_src);
671
672 let moe_stk_src: &'static str =
675 include_str!("shaders/moe_softmax_topk.metal");
676 sources.insert("moe_softmax_topk_f32".into(), moe_stk_src);
677 let moe_wr_src: &'static str =
678 include_str!("shaders/moe_weighted_reduce.metal");
679 sources.insert("moe_weighted_reduce_f32".into(), moe_wr_src);
680 let sdpa_decode_src: &'static str =
681 include_str!("shaders/sdpa_decode.metal");
682 sources.insert("sdpa_decode".into(), sdpa_decode_src);
683
684 Self {
685 cache: HashMap::new(),
686 sources,
687 }
688 }
689
690 pub fn register_source(&mut self, name: impl Into<String>, source: &'static str) {
693 let name = name.into();
694 self.cache.remove(&name);
696 self.sources.insert(name, source);
697 }
698
699 pub fn get_pipeline(
711 &mut self,
712 name: &str,
713 device: &metal::DeviceRef,
714 ) -> Result<&ComputePipelineState> {
715 if !self.cache.contains_key(name) {
716 let source = self.sources.get(name).ok_or_else(|| {
718 MlxError::KernelNotFound(name.to_string())
719 })?;
720
721 let compile_opts = metal::CompileOptions::new();
722 let library = device
723 .new_library_with_source(source, &compile_opts)
724 .map_err(|msg| MlxError::ShaderCompilationError {
725 name: name.to_string(),
726 message: msg,
727 })?;
728
729 let function = library
730 .get_function(name, None)
731 .map_err(|msg| MlxError::ShaderCompilationError {
732 name: name.to_string(),
733 message: msg,
734 })?;
735
736 let descriptor = ComputePipelineDescriptor::new();
746 descriptor.set_compute_function(Some(&function));
747 descriptor.set_label(name);
748
749 let pipeline = device
750 .new_compute_pipeline_state(&descriptor)
751 .map_err(|msg| MlxError::ShaderCompilationError {
752 name: name.to_string(),
753 message: msg,
754 })?;
755
756 self.cache.insert(name.to_string(), pipeline);
757 }
758
759 self.cache.get(name).ok_or_else(|| {
762 MlxError::KernelNotFound(name.to_string())
763 })
764 }
765
766 pub fn get_pipeline_with_constants(
788 &mut self,
789 name: &str,
790 device: &metal::DeviceRef,
791 bool_constants: &[(usize, bool)],
792 int_constants: &[(usize, i32)],
793 ) -> Result<&ComputePipelineState> {
794 let mut cache_key = name.to_string();
799 for &(index, value) in bool_constants {
800 cache_key.push('|');
801 cache_key.push_str(&index.to_string());
802 cache_key.push_str(if value { ":b1" } else { ":b0" });
803 }
804 for &(index, value) in int_constants {
805 cache_key.push('|');
806 cache_key.push_str(&index.to_string());
807 cache_key.push(':');
808 cache_key.push('i');
809 cache_key.push_str(&value.to_string());
810 }
811
812 if !self.cache.contains_key(&cache_key) {
813 let source = self.sources.get(name).ok_or_else(|| {
815 MlxError::KernelNotFound(name.to_string())
816 })?;
817
818 let compile_opts = metal::CompileOptions::new();
819 let library = device
820 .new_library_with_source(source, &compile_opts)
821 .map_err(|msg| MlxError::ShaderCompilationError {
822 name: name.to_string(),
823 message: msg,
824 })?;
825
826 let fcv = FunctionConstantValues::new();
831
832 for &(index, value) in bool_constants {
833 let v: u8 = if value { 1 } else { 0 };
836 fcv.set_constant_value_at_index(
837 (&v as *const u8).cast::<std::ffi::c_void>(),
838 MTLDataType::Bool,
839 index as u64,
840 );
841 }
842
843 for &(index, value) in int_constants {
844 fcv.set_constant_value_at_index(
848 (&value as *const i32).cast::<std::ffi::c_void>(),
849 MTLDataType::Int,
850 index as u64,
851 );
852 }
853
854 let function = library
855 .get_function(name, Some(fcv))
856 .map_err(|msg| MlxError::ShaderCompilationError {
857 name: name.to_string(),
858 message: msg,
859 })?;
860
861 let descriptor = ComputePipelineDescriptor::new();
868 descriptor.set_compute_function(Some(&function));
869 descriptor.set_label(&cache_key);
870
871 let pipeline = device
872 .new_compute_pipeline_state(&descriptor)
873 .map_err(|msg| MlxError::ShaderCompilationError {
874 name: name.to_string(),
875 message: msg,
876 })?;
877
878 self.cache.insert(cache_key.clone(), pipeline);
879 }
880
881 self.cache.get(&cache_key).ok_or_else(|| {
882 MlxError::KernelNotFound(name.to_string())
883 })
884 }
885
886 pub fn get_pipeline_with_bool_constants(
904 &mut self,
905 name: &str,
906 device: &metal::DeviceRef,
907 bool_constants: &[(usize, bool)],
908 ) -> Result<&ComputePipelineState> {
909 self.get_pipeline_with_constants(name, device, bool_constants, &[])
910 }
911
912 pub fn is_cached(&self, name: &str) -> bool {
914 self.cache.contains_key(name)
915 }
916
917 pub fn cached_count(&self) -> usize {
919 self.cache.len()
920 }
921
922 pub fn source_count(&self) -> usize {
924 self.sources.len()
925 }
926}
927
928impl Default for KernelRegistry {
929 fn default() -> Self {
930 Self::new()
931 }
932}
933
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal kernel that writes the `int` function constant at index 100 to
    /// `out[0]`, so each constant value needs its own specialisation.
    const INT_FC_TEST_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;

constant int test_N [[function_constant(100)]];

kernel void int_fc_test_kernel(
 device int* out [[buffer(0)]],
 uint tid [[thread_position_in_grid]])
{
 if (tid == 0) {
 out[0] = test_N;
 }
}
"#;

    /// Returns the system default Metal device, panicking with a hint if absent.
    fn system_device() -> metal::Device {
        metal::Device::system_default()
            .expect("no Metal device — run on Apple Silicon or x86 Mac with Metal support")
    }

    /// Fetches the `int_fc_test_kernel` specialisation for `test_N = n` and
    /// returns the address of the underlying Metal pipeline object.
    ///
    /// The Metal object address is compared, not the address of the cache
    /// slot, because the cache's `HashMap` may move its values when it grows.
    fn int_fc_pipeline_object(
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
        n: i32,
        failure: &str,
    ) -> *const metal::ComputePipelineStateRef {
        let pipeline: &metal::ComputePipelineStateRef = registry
            .get_pipeline_with_constants("int_fc_test_kernel", device, &[], &[(100, n)])
            .expect(failure);
        pipeline as *const _
    }

    /// Returns the label of the `int_fc_test_kernel` specialisation for `test_N = n`.
    fn int_fc_label(
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
        n: i32,
        failure: &str,
    ) -> String {
        registry
            .get_pipeline_with_constants("int_fc_test_kernel", device, &[], &[(100, n)])
            .expect(failure)
            .label()
            .to_string()
    }

    #[test]
    fn test_int_fc_distinct_pipelines_and_bool_compat() {
        let device = system_device();
        let mut registry = KernelRegistry::new();
        registry.register_source("int_fc_test_kernel", INT_FC_TEST_SHADER);

        // N = 4 populates the cache.
        let first_n4 =
            int_fc_pipeline_object(&mut registry, &device, 4, "pipeline N=4 should compile");
        let entries_with_n4 = registry.cached_count();

        // A different constant value must compile a second, distinct pipeline.
        let n8 = int_fc_pipeline_object(&mut registry, &device, 8, "pipeline N=8 should compile");
        assert_eq!(
            registry.cached_count(),
            entries_with_n4 + 1,
            "N=8 must produce a new cache entry"
        );
        assert_ne!(
            first_n4, n8,
            "N=4 and N=8 specialisations must be separate ComputePipelineState objects"
        );

        // Asking for N = 4 again must be served from the cache.
        let second_n4 = int_fc_pipeline_object(
            &mut registry,
            &device,
            4,
            "pipeline N=4 cache hit should succeed",
        );
        assert_eq!(
            registry.cached_count(),
            entries_with_n4 + 1,
            "repeated N=4 call must be a cache hit, not a new entry"
        );
        assert_eq!(
            first_n4, second_n4,
            "repeated N=4 call must return the same pipeline pointer"
        );

        // The bool-only wrapper with no constants still compiles and caches.
        const BARE_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void bare_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
 if (tid == 0) { out[0] = 42; }
}
"#;
        registry.register_source("bare_kernel", BARE_SHADER);

        let entries_before_bool = registry.cached_count();
        registry
            .get_pipeline_with_bool_constants("bare_kernel", &device, &[])
            .expect("bool-constants wrapper with empty slice must succeed");
        assert_eq!(
            registry.cached_count(),
            entries_before_bool + 1,
            "bool-constants wrapper must insert one new cache entry"
        );
    }

    #[test]
    fn test_pipeline_labels_propagate_for_mst() {
        let device = system_device();
        let mut registry = KernelRegistry::new();
        registry.register_source("int_fc_test_kernel", INT_FC_TEST_SHADER);

        const BARE_SHADER_LABEL_TEST: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void label_smoke_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
 if (tid == 0) { out[0] = 7; }
}
"#;
        registry.register_source("label_smoke_kernel", BARE_SHADER_LABEL_TEST);

        // Plain pipelines are labelled with the bare kernel name.
        let plain_label = registry
            .get_pipeline("label_smoke_kernel", &device)
            .expect("plain pipeline must compile")
            .label()
            .to_string();
        assert_eq!(
            plain_label, "label_smoke_kernel",
            "get_pipeline must label the pipeline with the kernel name (xctrace MST attribution)"
        );

        // Specialisations are labelled with their full cache key.
        let label_v7 = int_fc_label(&mut registry, &device, 7, "specialised pipeline must compile");
        assert_eq!(
            label_v7, "int_fc_test_kernel|100:i7",
            "get_pipeline_with_constants must label with the cache_key so each \
             specialisation is distinct in xctrace MST"
        );

        let label_v13 = int_fc_label(
            &mut registry,
            &device,
            13,
            "second specialised pipeline must compile",
        );
        assert_eq!(label_v13, "int_fc_test_kernel|100:i13");
        assert_ne!(
            label_v7, label_v13,
            "distinct constant values must yield distinct pipeline labels"
        );
    }
}