1use std::collections::HashMap;
9
10use metal::{ComputePipelineDescriptor, ComputePipelineState, FunctionConstantValues, MTLDataType};
11
12use crate::error::{MlxError, Result};
13
/// Registry of Metal compute kernels: shader source text plus a cache of the
/// compute pipelines that have already been built from it.
pub struct KernelRegistry {
    /// Built pipelines. Keys are either a plain kernel name (`get_pipeline`)
    /// or a kernel name followed by `|<index>:<value>` segments describing the
    /// function constants used (`get_pipeline_with_constants`).
    cache: HashMap<String, ComputePipelineState>,
    /// Shader source text keyed by the kernel function name that is looked up
    /// in the compiled library.
    sources: HashMap<String, &'static str>,
}
44
45impl KernelRegistry {
46 pub fn new() -> Self {
50 let mut sources = HashMap::new();
51
52 sources.insert(
54 "placeholder".into(),
55 include_str!("shaders/placeholder.metal"),
56 );
57 sources.insert(
58 "quantized_matmul".into(),
59 include_str!("shaders/quantized_matmul.metal"),
60 );
61 sources.insert(
62 "quantized_matmul_simd".into(),
63 include_str!("shaders/quantized_matmul.metal"),
64 );
65 sources.insert(
66 "quantized_matmul_simd_bf16".into(),
67 include_str!("shaders/quantized_matmul.metal"),
68 );
69 sources.insert(
70 "quantized_matmul_simd_bf16_expert".into(),
71 include_str!("shaders/quantized_matmul.metal"),
72 );
73
74 let ggml_src: &'static str =
76 include_str!("shaders/quantized_matmul_ggml.metal");
77 sources.insert("kernel_mul_mv_q4_0_f32".into(), ggml_src);
78 sources.insert("kernel_mul_mv_q8_0_f32".into(), ggml_src);
79 sources.insert("kernel_mul_mv_q6_K_f32".into(), ggml_src);
80 sources.insert("kernel_mul_mv_q6_K_f32_nr2".into(), ggml_src);
85 sources.insert("kernel_mul_mv_q5_1_f32".into(), ggml_src);
87 sources.insert("kernel_mul_mv_iq4_nl_f32".into(), ggml_src);
88 sources.insert("kernel_mul_mv_q4_K_f32".into(), ggml_src);
91 sources.insert("kernel_mul_mv_q5_K_f32".into(), ggml_src);
93
94 let ggml_mm_src: &'static str =
100 include_str!("shaders/quantized_matmul_mm.metal");
101 sources.insert("kernel_mul_mm_q4_0_f32".into(), ggml_mm_src);
102 sources.insert("kernel_mul_mm_q8_0_f32".into(), ggml_mm_src);
103 sources.insert("kernel_mul_mm_q6_K_f32".into(), ggml_mm_src);
104 sources.insert("kernel_mul_mm_q5_1_f32".into(), ggml_mm_src);
106 sources.insert("kernel_mul_mm_iq4_nl_f32".into(), ggml_mm_src);
107 sources.insert("kernel_mul_mm_q5_K_f32".into(), ggml_mm_src);
109 sources.insert("kernel_mul_mm_q4_K_f32".into(), ggml_mm_src);
111
112 let ggml_mm_tensor_src: &'static str =
123 include_str!("shaders/quantized_matmul_mm_tensor.metal");
124 sources.insert("kernel_mul_mm_q4_0_tensor_f32".into(), ggml_mm_tensor_src);
125 sources.insert("kernel_mul_mm_q4_0_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
126 sources.insert("kernel_mul_mm_q6_K_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
127 sources.insert("kernel_mul_mm_q8_0_tensor_f32".into(), ggml_mm_tensor_src);
128 sources.insert("kernel_mul_mm_q6_K_tensor_f32".into(), ggml_mm_tensor_src);
129 sources.insert("kernel_mul_mm_q5_1_tensor_f32".into(), ggml_mm_tensor_src);
131 sources.insert("kernel_mul_mm_iq4_nl_tensor_f32".into(), ggml_mm_tensor_src);
132 sources.insert("kernel_mul_mm_q5_K_tensor_f32".into(), ggml_mm_tensor_src);
134 sources.insert("kernel_mul_mm_q4_K_tensor_f32".into(), ggml_mm_tensor_src);
136 sources.insert("kernel_mul_mm_q8_0_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
137
138 let mul_mv_ext_src: &'static str = include_str!("shaders/mul_mv_ext.metal");
143 sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_2".into(), mul_mv_ext_src);
144 sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_3".into(), mul_mv_ext_src);
145 sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_4".into(), mul_mv_ext_src);
146 sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_5".into(), mul_mv_ext_src);
147 sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_2".into(), mul_mv_ext_src);
148 sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_3".into(), mul_mv_ext_src);
149 sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_4".into(), mul_mv_ext_src);
150 sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_5".into(), mul_mv_ext_src);
151 for r1 in [2, 3, 4, 5].iter() {
154 for ty in ["q4_0", "q8_0", "q4_K", "q5_K", "q6_K"].iter() {
155 let name = format!("kernel_mul_mv_ext_{ty}_f32_r1_{r1}");
156 sources.insert(name, mul_mv_ext_src);
157 }
158 }
159
160 let dense_mm_bf16_tensor_src: &'static str =
167 include_str!("shaders/dense_mm_bf16_tensor.metal");
168 sources.insert("hf2q_dense_mm_bf16_f32_tensor".into(), dense_mm_bf16_tensor_src);
169
170 let dense_mm_f32_f32_tensor_src: &'static str =
179 include_str!("shaders/dense_mm_f32_f32.metal");
180 sources.insert("hf2q_dense_mm_f32_f32_tensor".into(), dense_mm_f32_f32_tensor_src);
181
182 let dense_mm_f16_tensor_src: &'static str =
194 include_str!("shaders/dense_mm_f16_tensor.metal");
195 sources.insert("hf2q_dense_mm_f16_f32_tensor".into(), dense_mm_f16_tensor_src);
196
197 let dense_gemv_bf16_src: &'static str =
204 include_str!("shaders/dense_gemv_bf16.metal");
205 sources.insert("hf2q_dense_gemv_bf16_f32_4".into(), dense_gemv_bf16_src);
206
207 let scale_mask_softmax_src: &'static str =
213 include_str!("shaders/scale_mask_softmax.metal");
214 sources.insert("scale_mask_softmax_f32".into(), scale_mask_softmax_src);
215
216 sources.insert(
218 "quantized_matmul_id".into(),
219 include_str!("shaders/quantized_matmul_id.metal"),
220 );
221
222 let ggml_id_src: &'static str =
224 include_str!("shaders/quantized_matmul_id_ggml.metal");
225 sources.insert("kernel_mul_mv_id_q4_0_f32".into(), ggml_id_src);
226 sources.insert("kernel_mul_mv_id_q8_0_f32".into(), ggml_id_src);
227 sources.insert("kernel_mul_mv_id_q4_K_f32".into(), ggml_id_src);
230 sources.insert("kernel_mul_mv_id_q5_K_f32".into(), ggml_id_src);
231 sources.insert("kernel_mul_mv_id_q6_K_f32".into(), ggml_id_src);
232 sources.insert("kernel_mul_mv_id_q6_K_f32_nr2".into(), ggml_id_src);
236 sources.insert("kernel_mul_mv_id_q5_1_f32".into(), ggml_id_src);
238 sources.insert("kernel_mul_mv_id_iq4_nl_f32".into(), ggml_id_src);
239 sources.insert("kernel_mul_mv_id_q4_0_f32_swiglu".into(), ggml_id_src);
243
244 let ggml_id_mm_src: &'static str =
252 include_str!("shaders/quantized_matmul_id_mm.metal");
253 sources.insert("kernel_mul_mm_id_map0_ne20_1".into(), ggml_id_mm_src);
254 sources.insert("kernel_mul_mm_id_map0_ne20_8".into(), ggml_id_mm_src);
255 sources.insert("kernel_mul_mm_id_q4_0_f32".into(), ggml_id_mm_src);
256 sources.insert("kernel_mul_mm_id_q8_0_f32".into(), ggml_id_mm_src);
257 sources.insert("kernel_mul_mm_id_q6_K_f32".into(), ggml_id_mm_src);
258 sources.insert("kernel_mul_mm_id_q4_K_f32".into(), ggml_id_mm_src);
260 sources.insert("kernel_mul_mm_id_q5_1_f32".into(), ggml_id_mm_src);
262 sources.insert("kernel_mul_mm_id_iq4_nl_f32".into(), ggml_id_mm_src);
263 sources.insert("kernel_mul_mm_id_q5_K_f32".into(), ggml_id_mm_src);
265
266 let ggml_id_mm_tensor_src: &'static str =
272 include_str!("shaders/quantized_matmul_id_mm_tensor.metal");
273 sources.insert("kernel_mul_mm_id_q4_0_tensor_f32".into(), ggml_id_mm_tensor_src);
274 sources.insert("kernel_mul_mm_id_q8_0_tensor_f32".into(), ggml_id_mm_tensor_src);
275 sources.insert("kernel_mul_mm_id_q6_K_tensor_f32".into(), ggml_id_mm_tensor_src);
276 sources.insert("kernel_mul_mm_id_q4_K_tensor_f32".into(), ggml_id_mm_tensor_src);
278 sources.insert("kernel_mul_mm_id_q5_1_tensor_f32".into(), ggml_id_mm_tensor_src);
280 sources.insert("kernel_mul_mm_id_iq4_nl_tensor_f32".into(), ggml_id_mm_tensor_src);
281 sources.insert("kernel_mul_mm_id_q5_K_tensor_f32".into(), ggml_id_mm_tensor_src);
283
284 let embedding_src: &'static str = include_str!("shaders/embedding.metal");
286 sources.insert("embedding_gather_4bit".into(), embedding_src);
287 sources.insert("embedding_gather_6bit".into(), embedding_src);
288
289 let moe_gate_src: &'static str = include_str!("shaders/moe_gate.metal");
291 sources.insert("moe_gate".into(), moe_gate_src);
292
293 let moe_dispatch_src: &'static str = include_str!("shaders/moe_dispatch.metal");
295 sources.insert("fused_gelu_mul".into(), moe_dispatch_src);
296 sources.insert("moe_swiglu_fused".into(), moe_dispatch_src);
297 sources.insert("moe_swiglu_batch".into(), moe_dispatch_src);
298 sources.insert("moe_swiglu_seq".into(), moe_dispatch_src);
299 sources.insert("moe_accumulate".into(), moe_dispatch_src);
300 sources.insert("moe_weighted_sum".into(), moe_dispatch_src);
301 sources.insert("moe_weighted_sum_seq".into(), moe_dispatch_src);
302 sources.insert("zero_buffer".into(), moe_dispatch_src);
303 sources.insert("naive_matvec_f32".into(), moe_dispatch_src);
304 sources.insert("moe_gather_topk_weights".into(), moe_dispatch_src);
305 sources.insert("fused_gelu_mul_bf16".into(), moe_dispatch_src);
307 sources.insert("moe_swiglu_seq_bf16".into(), moe_dispatch_src);
308 sources.insert("moe_weighted_sum_seq_bf16_input".into(), moe_dispatch_src);
309 sources.insert(
311 "moe_weighted_sum_seq_backward_outputs_f32".into(),
312 moe_dispatch_src,
313 );
314 sources.insert(
315 "moe_weighted_sum_seq_backward_weights_f32".into(),
316 moe_dispatch_src,
317 );
318 sources.insert(
320 "moe_swiglu_seq_backward_f32".into(),
321 moe_dispatch_src,
322 );
323
324 let kv_cache_src: &'static str = include_str!("shaders/kv_cache_copy.metal");
326 sources.insert("kv_cache_copy_batch_f32".into(), kv_cache_src);
327 sources.insert("kv_cache_copy_batch_f32_to_f16".into(), kv_cache_src);
328 sources.insert("kv_cache_copy_seq_f32".into(), kv_cache_src);
329 sources.insert("kv_cache_copy_seq_f32_to_f16".into(), kv_cache_src);
330 sources.insert("kv_cache_copy_seq_f32_kv_dual".into(), kv_cache_src);
332 sources.insert("kv_cache_copy_seq_f32_to_f16_kv_dual".into(), kv_cache_src);
333 sources.insert("kv_cache_copy_batch_f32_kv_dual".into(), kv_cache_src);
335 sources.insert("kv_cache_copy_batch_f32_to_f16_kv_dual".into(), kv_cache_src);
336 sources.insert("kv_cache_copy_seq_bf16".into(), kv_cache_src);
338
339 let elementwise_src: &'static str = include_str!("shaders/elementwise.metal");
341 sources.insert("elementwise_add_f32".into(), elementwise_src);
342 sources.insert("elementwise_add_f16".into(), elementwise_src);
343 sources.insert("elementwise_mul_f32".into(), elementwise_src);
344 sources.insert("elementwise_mul_f16".into(), elementwise_src);
345 sources.insert("elementwise_add_bf16".into(), elementwise_src);
346 sources.insert("elementwise_mul_bf16".into(), elementwise_src);
347 sources.insert("cast_f16_to_f32".into(), elementwise_src);
348 sources.insert("cast_f32_to_f16".into(), elementwise_src);
349 sources.insert("cast_bf16_to_f32".into(), elementwise_src);
350 sources.insert("cast_f32_to_bf16".into(), elementwise_src);
351 sources.insert("scalar_mul_bf16".into(), elementwise_src);
352 sources.insert("scalar_mul_f32".into(), elementwise_src);
353 sources.insert("embedding_gather_scale_f32".into(), elementwise_src);
354 sources.insert("embedding_gather_scale_batch_f32".into(), elementwise_src);
355 sources.insert("permute_021_bf16".into(), elementwise_src);
356 sources.insert("transpose_last2_bf16".into(), elementwise_src);
357 sources.insert("transpose_last2_f16".into(), elementwise_src);
358 sources.insert("permute_021_f32".into(), elementwise_src);
359 sources.insert("permute_021_bf16_to_f32".into(), elementwise_src);
360 sources.insert("transpose_2d_f32".into(), elementwise_src);
361 sources.insert("transpose_2d_f16".into(), elementwise_src);
362
363 let sdpa_src: &'static str = include_str!("shaders/sdpa.metal");
365 sources.insert("sdpa".into(), sdpa_src);
366 sources.insert("sdpa_bf16".into(), sdpa_src);
367 let sdpa_sliding_src: &'static str = include_str!("shaders/sdpa_sliding.metal");
368 sources.insert("sdpa_sliding".into(), sdpa_sliding_src);
369 sources.insert("sdpa_sliding_bf16".into(), sdpa_sliding_src);
370
371 let flash_attn_prefill_src: &'static str =
376 include_str!("shaders/flash_attn_prefill.metal");
377 sources.insert(
379 "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskfloat32".into(),
380 flash_attn_prefill_src,
381 );
382 sources.insert(
383 "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
384 flash_attn_prefill_src,
385 );
386 sources.insert(
387 "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbfloat16".into(),
388 flash_attn_prefill_src,
389 );
390 sources.insert(
391 "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
392 flash_attn_prefill_src,
393 );
394 sources.insert(
395 "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskfloat16".into(),
396 flash_attn_prefill_src,
397 );
398 sources.insert(
399 "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
400 flash_attn_prefill_src,
401 );
402 sources.insert(
406 "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbfloat16".into(),
407 flash_attn_prefill_src,
408 );
409 sources.insert(
410 "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
411 flash_attn_prefill_src,
412 );
413 sources.insert(
414 "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskfloat16".into(),
415 flash_attn_prefill_src,
416 );
417 sources.insert(
418 "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
419 flash_attn_prefill_src,
420 );
421
422 let flash_attn_vec_src: &'static str =
425 include_str!("shaders/flash_attn_vec.metal");
426 sources.insert("flash_attn_vec_dk256".into(), flash_attn_vec_src);
427 sources.insert("flash_attn_vec_dk512".into(), flash_attn_vec_src);
428 sources.insert("flash_attn_vec_reduce_dk256".into(), flash_attn_vec_src);
429 sources.insert("flash_attn_vec_reduce_dk512".into(), flash_attn_vec_src);
430 sources.insert("flash_attn_vec_f16kv_dk256".into(), flash_attn_vec_src);
432 sources.insert("flash_attn_vec_f16kv_dk512".into(), flash_attn_vec_src);
433
434 let rope_src: &'static str = include_str!("shaders/rope.metal");
436 sources.insert("rope_f32".into(), rope_src);
437 sources.insert("rope_f16".into(), rope_src);
438 sources.insert("rope_bf16".into(), rope_src);
439 sources.insert("rope_neox_bf16".into(), rope_src);
440 sources.insert("rope_neox_f32".into(), rope_src);
441 let rms_norm_src: &'static str = include_str!("shaders/rms_norm.metal");
442 sources.insert("rms_norm_f32".into(), rms_norm_src);
443 sources.insert("rms_norm_f32_v2".into(), rms_norm_src);
447 sources.insert("rms_norm_no_scale_f32_v2".into(), rms_norm_src);
448 sources.insert("rms_norm_f16".into(), rms_norm_src);
449 sources.insert("rms_norm_bf16".into(), rms_norm_src);
450 sources.insert("rms_norm_no_scale_bf16".into(), rms_norm_src);
451 sources.insert("rms_norm_no_scale_f32".into(), rms_norm_src);
452 sources.insert("rms_norm_no_scale_f32_dual".into(), rms_norm_src);
453 sources.insert("rms_norm_f32_triple".into(), rms_norm_src);
454 sources.insert("fused_post_attn_triple_norm_f32".into(), rms_norm_src);
455 sources.insert("fused_post_ff_norm2_endlayer_f32".into(), rms_norm_src);
458 sources.insert("rms_norm_no_scale_f32_dual_perm".into(), rms_norm_src);
459 sources.insert("rms_norm_mul_f32".into(), rms_norm_src);
461 sources.insert("rms_norm_mul_f16".into(), rms_norm_src);
462 sources.insert("rms_norm_mul_bf16".into(), rms_norm_src);
463 let l2_norm_src: &'static str = include_str!("shaders/l2_norm.metal");
465 sources.insert("l2_norm_f32".into(), l2_norm_src);
466 sources.insert("l2_norm_f16".into(), l2_norm_src);
467 sources.insert("l2_norm_bf16".into(), l2_norm_src);
468 let cumsum_src: &'static str = include_str!("shaders/cumsum.metal");
470 sources.insert("cumsum_f32".into(), cumsum_src);
471 sources.insert("cumsum_bf16".into(), cumsum_src);
472 let ssm_conv_src: &'static str = include_str!("shaders/ssm_conv.metal");
474 sources.insert("ssm_conv_forward_f32".into(), ssm_conv_src);
475 sources.insert("ssm_conv_forward_bf16".into(), ssm_conv_src);
476 sources.insert("ssm_conv_state_update_f32".into(), ssm_conv_src);
477 sources.insert("ssm_conv_state_update_bf16".into(), ssm_conv_src);
478 let tri_solve_src: &'static str = include_str!("shaders/tri_solve.metal");
480 sources.insert("tri_solve_lower_unit_f32".into(), tri_solve_src);
481 sources.insert("tri_solve_lower_unit_bf16".into(), tri_solve_src);
482 let rope_multi_src: &'static str = include_str!("shaders/rope_multi.metal");
484 sources.insert("rope_multi_f32".into(), rope_multi_src);
485 sources.insert("rope_multi_bf16".into(), rope_multi_src);
486 let gdn_src: &'static str = include_str!("shaders/gated_delta_net.metal");
488 sources.insert("gated_delta_net_f32".into(), gdn_src);
489 let gdn_decode_src: &'static str =
494 include_str!("shaders/gated_delta_net_decode.metal");
495 sources.insert("gated_delta_net_decode_f32_1".into(), gdn_decode_src);
496 sources.insert("gated_delta_net_decode_f32_2".into(), gdn_decode_src);
497 sources.insert("gated_delta_net_decode_f32_4".into(), gdn_decode_src);
498 let gdn_chunk_src: &'static str =
502 include_str!("shaders/gated_delta_net_chunk.metal");
503 sources.insert(
504 "gated_delta_net_chunk_inter_state_bf16".into(),
505 gdn_chunk_src,
506 );
507 let gdn_kkt_src: &'static str =
510 include_str!("shaders/gated_delta_net_kkt.metal");
511 sources.insert("gated_delta_net_kkt_bf16".into(), gdn_kkt_src);
512 let gdn_recompute_wu_src: &'static str =
516 include_str!("shaders/gated_delta_net_recompute_wu.metal");
517 sources.insert(
518 "gated_delta_net_recompute_wu_bf16".into(),
519 gdn_recompute_wu_src,
520 );
521 let gdn_chunk_o_src: &'static str =
524 include_str!("shaders/gated_delta_net_chunk_o.metal");
525 sources.insert("gated_delta_net_chunk_o_bf16".into(), gdn_chunk_o_src);
526 let chunk_local_cumsum_g_src: &'static str =
531 include_str!("shaders/chunk_local_cumsum_g.metal");
532 sources.insert(
533 "chunk_local_cumsum_g_f32".into(),
534 chunk_local_cumsum_g_src,
535 );
536 let chunk_tri_solve_invert_src: &'static str =
537 include_str!("shaders/chunk_gated_delta_rule_tri_solve_invert.metal");
538 sources.insert(
539 "chunk_tri_solve_invert_f32".into(),
540 chunk_tri_solve_invert_src,
541 );
542 let sigmoid_mul_src: &'static str = include_str!("shaders/sigmoid_mul.metal");
544 sources.insert("sigmoid_mul_f32".into(), sigmoid_mul_src);
545 sources.insert("sigmoid_mul_bf16".into(), sigmoid_mul_src);
546 let silu_mul_src: &'static str = include_str!("shaders/silu_mul.metal");
547 sources.insert("silu_mul_f32".into(), silu_mul_src);
548 let compute_g_beta_src: &'static str = include_str!("shaders/compute_g_beta.metal");
549 sources.insert("compute_g_beta_f32".into(), compute_g_beta_src);
550 let ssm_norm_gate_src: &'static str = include_str!("shaders/ssm_norm_gate.metal");
551 sources.insert("ssm_norm_gate_f32".into(), ssm_norm_gate_src);
552 let gelu_src: &'static str = include_str!("shaders/gelu.metal");
553 sources.insert("gelu_f32".into(), gelu_src);
554 sources.insert("gelu_f16".into(), gelu_src);
555 sources.insert("gelu_bf16".into(), gelu_src);
556 let softmax_src: &'static str = include_str!("shaders/softmax.metal");
557 sources.insert("softmax_f32".into(), softmax_src);
558 sources.insert("softmax_f16".into(), softmax_src);
559 sources.insert("softmax_bf16".into(), softmax_src);
560 let softmax_backward_src: &'static str =
561 include_str!("shaders/softmax_backward.metal");
562 sources.insert("softmax_backward_f32".into(), softmax_backward_src);
563 let log_elementwise_src: &'static str =
564 include_str!("shaders/log_elementwise.metal");
565 sources.insert("log_f32".into(), log_elementwise_src);
566 sources.insert("log_backward_f32".into(), log_elementwise_src);
567 let row_sum_src: &'static str = include_str!("shaders/row_sum.metal");
568 sources.insert("row_sum_f32".into(), row_sum_src);
569 sources.insert("row_sum_backward_f32".into(), row_sum_src);
570 let qdq_legacy_src: &'static str = include_str!("shaders/qdq_legacy.metal");
574 sources.insert("qdq_q4_0_f32".into(), qdq_legacy_src);
575 sources.insert("qdq_q8_0_f32".into(), qdq_legacy_src);
576 let rms_norm_backward_src: &'static str =
580 include_str!("shaders/rms_norm_backward.metal");
581 sources.insert(
582 "rms_norm_compute_rms_inv_f32".into(),
583 rms_norm_backward_src,
584 );
585 sources.insert("rms_norm_backward_dx_f32".into(), rms_norm_backward_src);
586 sources.insert("rms_norm_backward_dw_f32".into(), rms_norm_backward_src);
587 let slice_concat_2d_src: &'static str =
592 include_str!("shaders/slice_concat_2d.metal");
593 sources.insert("slice_2d_cols_f32".into(), slice_concat_2d_src);
594 sources.insert("copy_2d_cols_into_f32".into(), slice_concat_2d_src);
595 let silu_backward_src: &'static str =
598 include_str!("shaders/silu_backward.metal");
599 sources.insert("silu_f32".into(), silu_backward_src);
600 sources.insert("silu_backward_f32".into(), silu_backward_src);
601 let embedding_autograd_src: &'static str =
603 include_str!("shaders/embedding_autograd.metal");
604 sources.insert("embedding_lookup_f32".into(), embedding_autograd_src);
605 sources.insert(
606 "embedding_scatter_add_f32".into(),
607 embedding_autograd_src,
608 );
609 let adam_update_src: &'static str =
612 include_str!("shaders/adam_update.metal");
613 sources.insert("adam_update_f32".into(), adam_update_src);
614 let qdq_affine_src: &'static str =
618 include_str!("shaders/qdq_affine.metal");
619 sources.insert("qdq_affine_init_f32".into(), qdq_affine_src);
620 sources.insert("qdq_affine_forward_f32".into(), qdq_affine_src);
621 sources.insert(
622 "qdq_affine_backward_scales_f32".into(),
623 qdq_affine_src,
624 );
625 sources.insert(
626 "qdq_affine_backward_biases_f32".into(),
627 qdq_affine_src,
628 );
629 let qmm_affine_src: &'static str =
633 include_str!("shaders/qmm_affine.metal");
634 sources.insert("qmm_affine_t_f32".into(), qmm_affine_src);
635 let qmm_affine_tiled_src: &'static str =
639 include_str!("shaders/qmm_affine_tiled.metal");
640 sources.insert(
641 "qmm_affine_t_f32_tiled".into(),
642 qmm_affine_tiled_src,
643 );
644 let qmm_affine_simd_src: &'static str =
649 include_str!("shaders/qmm_affine_simd.metal");
650 sources.insert(
651 "qmm_affine_t_f32_simd".into(),
652 qmm_affine_simd_src,
653 );
654 let qmm_affine_simd4_src: &'static str =
659 include_str!("shaders/qmm_affine_simd4.metal");
660 sources.insert(
661 "qmm_affine_t_f32_simd4".into(),
662 qmm_affine_simd4_src,
663 );
664 let qmm_affine_simd4_gs64_src: &'static str =
668 include_str!("shaders/qmm_affine_simd4_gs64.metal");
669 sources.insert(
670 "qmm_affine_t_f32_simd4_gs64".into(),
671 qmm_affine_simd4_gs64_src,
672 );
673 let qmm_affine_t_packed_simd4_b4_src: &'static str =
677 include_str!("shaders/qmm_affine_t_packed_simd4_b4.metal");
678 sources.insert(
679 "qmm_affine_t_packed_simd4_b4".into(),
680 qmm_affine_t_packed_simd4_b4_src,
681 );
682 let conv1d_dwc_src: &'static str =
687 include_str!("shaders/conv1d_depthwise_causal.metal");
688 sources.insert(
689 "conv1d_depthwise_causal_forward_f32".into(),
690 conv1d_dwc_src,
691 );
692 sources.insert(
693 "conv1d_depthwise_causal_backward_dx_f32".into(),
694 conv1d_dwc_src,
695 );
696 sources.insert(
697 "conv1d_depthwise_causal_backward_dw_f32".into(),
698 conv1d_dwc_src,
699 );
700 let exp_src: &'static str =
703 include_str!("shaders/exp_elementwise.metal");
704 sources.insert("exp_f32".into(), exp_src);
705 sources.insert("exp_backward_f32".into(), exp_src);
706 let outer_src: &'static str =
710 include_str!("shaders/outer_product.metal");
711 sources.insert("outer_product_f32".into(), outer_src);
712 sources.insert("outer_product_backward_lhs_f32".into(), outer_src);
713 sources.insert("outer_product_backward_rhs_f32".into(), outer_src);
714 let taa_src: &'static str =
717 include_str!("shaders/take_along_axis.metal");
718 sources.insert("take_along_axis_f32".into(), taa_src);
719 sources.insert("take_along_axis_backward_f32".into(), taa_src);
720 let div_src: &'static str =
722 include_str!("shaders/divide_elementwise.metal");
723 sources.insert("divide_f32".into(), div_src);
724 sources.insert("divide_backward_f32".into(), div_src);
725 let sqrt_src: &'static str =
727 include_str!("shaders/sqrt_elementwise.metal");
728 sources.insert("sqrt_f32".into(), sqrt_src);
729 sources.insert("sqrt_backward_f32".into(), sqrt_src);
730 let softcap_src: &'static str = include_str!("shaders/softcap.metal");
731 sources.insert("softcap_f32".into(), softcap_src);
732 sources.insert("softcap_f16".into(), softcap_src);
733 sources.insert("softcap_bf16".into(), softcap_src);
734
735 let fused_norm_add_src: &'static str =
738 include_str!("shaders/fused_norm_add_bf16.metal");
739 sources.insert("fused_norm_add_bf16".into(), fused_norm_add_src);
740 sources.insert("fused_norm_add_no_weight_bf16".into(), fused_norm_add_src);
741
742 let fused_hnr_f32_src: &'static str =
744 include_str!("shaders/fused_head_norm_rope_f32.metal");
745 sources.insert("fused_head_norm_rope_f32".into(), fused_hnr_f32_src);
746 sources.insert("fused_head_norm_rope_f32_v2".into(), fused_hnr_f32_src);
750
751 let fused_hnr_bf16_src: &'static str =
754 include_str!("shaders/fused_head_norm_rope_bf16.metal");
755 sources.insert("fused_head_norm_rope_bf16".into(), fused_hnr_bf16_src);
756 sources.insert("fused_head_norm_rope_batch_bf16".into(), fused_hnr_bf16_src);
757
758 let fused_norm_add_f32_src: &'static str =
760 include_str!("shaders/fused_norm_add_f32.metal");
761 sources.insert("fused_norm_add_f32".into(), fused_norm_add_f32_src);
762 sources.insert("fused_norm_add_f32_v2".into(), fused_norm_add_f32_src);
767 sources.insert("fused_residual_norm_f32".into(), fused_norm_add_f32_src);
768 sources.insert("fused_residual_norm_scalar_f32".into(), fused_norm_add_f32_src);
769 sources.insert("fused_moe_routing_f32".into(), fused_norm_add_f32_src);
770 sources.insert("fused_moe_routing_batch_f32".into(), fused_norm_add_f32_src);
771 sources.insert("fused_norm_add_scalar_f32".into(), fused_norm_add_f32_src);
772 sources.insert("fused_moe_wsum_norm_add_f32".into(), fused_norm_add_f32_src);
773 sources.insert("fused_moe_wsum_dnorm_add_f32".into(), fused_norm_add_f32_src);
774
775 let argsort_src: &'static str = include_str!("shaders/argsort.metal");
777 sources.insert("argsort_desc_f32".into(), argsort_src);
778
779 let gather_src: &'static str = include_str!("shaders/gather.metal");
781 sources.insert("gather_f32".into(), gather_src);
782
783 let kv_cache_copy_src: &'static str =
785 include_str!("shaders/kv_cache_copy.metal");
786 sources.insert("kv_cache_copy".into(), kv_cache_copy_src);
787 sources.insert("kv_cache_copy_f32".into(), kv_cache_copy_src);
788
789 let copy_src: &'static str = include_str!("shaders/copy.metal");
791 sources.insert("strided_copy_f32".into(), copy_src);
792 sources.insert("offset_copy_f32".into(), copy_src);
793
794 let qkv_split_src: &'static str = include_str!("shaders/qkv_split.metal");
798 sources.insert("qkv_split_f32".into(), qkv_split_src);
799
800 let repeat_tiled_src: &'static str =
804 include_str!("shaders/repeat_tiled.metal");
805 sources.insert("repeat_tiled_f32".into(), repeat_tiled_src);
806
807 let dense_gemm_src: &'static str = include_str!("shaders/dense_gemm.metal");
809 sources.insert("dense_gemm_f16".into(), dense_gemm_src);
810 sources.insert("dense_matvec_f16".into(), dense_gemm_src);
811 sources.insert("dense_matvec_f16w_f32io".into(), dense_gemm_src);
812 sources.insert("dense_matvec_bf16w_f32io".into(), dense_gemm_src);
814 sources.insert("dense_matvec_f32".into(), dense_gemm_src);
816
817 let fwht_src: &'static str = include_str!("shaders/fwht_standalone.metal");
819 sources.insert("fwht_standalone_f32_d256".into(), fwht_src);
820 sources.insert("fwht_standalone_f32_d512".into(), fwht_src);
821 sources.insert("fwht_sign_premult_f32_d256".into(), fwht_src);
823 sources.insert("fwht_sign_premult_f32_d512".into(), fwht_src);
824 sources.insert("fwht_sign_undo_f32_d256".into(), fwht_src);
825 sources.insert("fwht_sign_undo_f32_d512".into(), fwht_src);
826
827 let hq_fast_src: &'static str = include_str!("shaders/hadamard_quantize_kv_fast.metal");
829 sources.insert("hadamard_quantize_kv_fast_d256".into(), hq_fast_src);
830 sources.insert("hadamard_quantize_kv_fast_d512".into(), hq_fast_src);
831 sources.insert("hadamard_quantize_kv_hb_d256".into(), hq_fast_src);
833 sources.insert("hadamard_quantize_kv_hb_d512".into(), hq_fast_src);
834 sources.insert("hadamard_quantize_kv_hb_dual_d256".into(), hq_fast_src);
836 sources.insert("hadamard_quantize_kv_hb_dual_d512".into(), hq_fast_src);
837
838 let tq_dq_src: &'static str = include_str!("shaders/tq_dequantize_kv.metal");
840 sources.insert("tq_dequantize_kv".into(), tq_dq_src);
841 sources.insert("tq_dequantize_hb_kv".into(), tq_dq_src);
843 sources.insert("tq_dequantize_hb_kv_seq".into(), tq_dq_src);
850
851 let tq_hb_src: &'static str = include_str!("shaders/flash_attn_vec_tq_hb.metal");
853 sources.insert("flash_attn_vec_tq_hb_dk256".into(), tq_hb_src);
854 sources.insert("flash_attn_vec_tq_hb_dk512".into(), tq_hb_src);
855
856 let argmax_src: &'static str = include_str!("shaders/argmax.metal");
858 sources.insert("argmax_f32".into(), argmax_src);
859 let softmax_sample_src: &'static str =
860 include_str!("shaders/softmax_sample.metal");
861 sources.insert("softmax_sample_f32".into(), softmax_sample_src);
862 let top_k_src: &'static str = include_str!("shaders/top_k.metal");
864 sources.insert("top_k_f32".into(), top_k_src);
865
866 let moe_stk_src: &'static str =
869 include_str!("shaders/moe_softmax_topk.metal");
870 sources.insert("moe_softmax_topk_f32".into(), moe_stk_src);
871 let moe_wr_src: &'static str =
872 include_str!("shaders/moe_weighted_reduce.metal");
873 sources.insert("moe_weighted_reduce_f32".into(), moe_wr_src);
874 let sdpa_decode_src: &'static str =
875 include_str!("shaders/sdpa_decode.metal");
876 sources.insert("sdpa_decode".into(), sdpa_decode_src);
877
878 Self {
879 cache: HashMap::new(),
880 sources,
881 }
882 }
883
884 pub fn register_source(&mut self, name: impl Into<String>, source: &'static str) {
887 let name = name.into();
888 self.cache.remove(&name);
890 self.sources.insert(name, source);
891 }
892
893 pub fn get_pipeline(
905 &mut self,
906 name: &str,
907 device: &metal::DeviceRef,
908 ) -> Result<&ComputePipelineState> {
909 if !self.cache.contains_key(name) {
910 let source = self.sources.get(name).ok_or_else(|| {
912 MlxError::KernelNotFound(name.to_string())
913 })?;
914
915 let compile_opts = metal::CompileOptions::new();
916 let library = device
917 .new_library_with_source(source, &compile_opts)
918 .map_err(|msg| MlxError::ShaderCompilationError {
919 name: name.to_string(),
920 message: msg,
921 })?;
922
923 let function = library
924 .get_function(name, None)
925 .map_err(|msg| MlxError::ShaderCompilationError {
926 name: name.to_string(),
927 message: msg,
928 })?;
929
930 let descriptor = ComputePipelineDescriptor::new();
940 descriptor.set_compute_function(Some(&function));
941 descriptor.set_label(name);
942
943 let pipeline = device
944 .new_compute_pipeline_state(&descriptor)
945 .map_err(|msg| MlxError::ShaderCompilationError {
946 name: name.to_string(),
947 message: msg,
948 })?;
949
950 self.cache.insert(name.to_string(), pipeline);
951 }
952
953 self.cache.get(name).ok_or_else(|| {
956 MlxError::KernelNotFound(name.to_string())
957 })
958 }
959
960 pub fn get_pipeline_with_constants(
982 &mut self,
983 name: &str,
984 device: &metal::DeviceRef,
985 bool_constants: &[(usize, bool)],
986 int_constants: &[(usize, i32)],
987 ) -> Result<&ComputePipelineState> {
988 let mut cache_key = name.to_string();
993 for &(index, value) in bool_constants {
994 cache_key.push('|');
995 cache_key.push_str(&index.to_string());
996 cache_key.push_str(if value { ":b1" } else { ":b0" });
997 }
998 for &(index, value) in int_constants {
999 cache_key.push('|');
1000 cache_key.push_str(&index.to_string());
1001 cache_key.push(':');
1002 cache_key.push('i');
1003 cache_key.push_str(&value.to_string());
1004 }
1005
1006 if !self.cache.contains_key(&cache_key) {
1007 let source = self.sources.get(name).ok_or_else(|| {
1009 MlxError::KernelNotFound(name.to_string())
1010 })?;
1011
1012 let compile_opts = metal::CompileOptions::new();
1013 let library = device
1014 .new_library_with_source(source, &compile_opts)
1015 .map_err(|msg| MlxError::ShaderCompilationError {
1016 name: name.to_string(),
1017 message: msg,
1018 })?;
1019
1020 let fcv = FunctionConstantValues::new();
1025
1026 for &(index, value) in bool_constants {
1027 let v: u8 = if value { 1 } else { 0 };
1030 fcv.set_constant_value_at_index(
1031 (&v as *const u8).cast::<std::ffi::c_void>(),
1032 MTLDataType::Bool,
1033 index as u64,
1034 );
1035 }
1036
1037 for &(index, value) in int_constants {
1038 fcv.set_constant_value_at_index(
1042 (&value as *const i32).cast::<std::ffi::c_void>(),
1043 MTLDataType::Int,
1044 index as u64,
1045 );
1046 }
1047
1048 let function = library
1049 .get_function(name, Some(fcv))
1050 .map_err(|msg| MlxError::ShaderCompilationError {
1051 name: name.to_string(),
1052 message: msg,
1053 })?;
1054
1055 let descriptor = ComputePipelineDescriptor::new();
1062 descriptor.set_compute_function(Some(&function));
1063 descriptor.set_label(&cache_key);
1064
1065 let pipeline = device
1066 .new_compute_pipeline_state(&descriptor)
1067 .map_err(|msg| MlxError::ShaderCompilationError {
1068 name: name.to_string(),
1069 message: msg,
1070 })?;
1071
1072 self.cache.insert(cache_key.clone(), pipeline);
1073 }
1074
1075 self.cache.get(&cache_key).ok_or_else(|| {
1076 MlxError::KernelNotFound(name.to_string())
1077 })
1078 }
1079
1080 pub fn get_pipeline_with_bool_constants(
1098 &mut self,
1099 name: &str,
1100 device: &metal::DeviceRef,
1101 bool_constants: &[(usize, bool)],
1102 ) -> Result<&ComputePipelineState> {
1103 self.get_pipeline_with_constants(name, device, bool_constants, &[])
1104 }
1105
1106 pub fn is_cached(&self, name: &str) -> bool {
1108 self.cache.contains_key(name)
1109 }
1110
1111 pub fn cached_count(&self) -> usize {
1113 self.cache.len()
1114 }
1115
1116 pub fn source_count(&self) -> usize {
1118 self.sources.len()
1119 }
1120}
1121
1122impl Default for KernelRegistry {
1123 fn default() -> Self {
1124 Self::new()
1125 }
1126}
1127
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal kernel with a single `int` function constant at index 100; the
    /// kernel writes the constant to `out[0]`.
    const INT_FC_TEST_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;

constant int test_N [[function_constant(100)]];

kernel void int_fc_test_kernel(
    device int* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]])
{
    if (tid == 0) {
        out[0] = test_N;
    }
}
"#;

    /// Returns the system default Metal device, panicking when none exists.
    fn system_device() -> metal::Device {
        metal::Device::system_default()
            .expect("no Metal device — run on Apple Silicon or x86 Mac with Metal support")
    }

    /// Address used to compare pipeline identity in assertions.
    ///
    /// Casting the `&ComputePipelineState` handed out by the registry would
    /// yield the address of the HashMap value slot: always different for two
    /// keys (making an inequality check vacuous) and free to move when the map
    /// grows (making an equality check fragile). The deref'd
    /// `ComputePipelineStateRef` is compared instead.
    ///
    /// NOTE(review): assumes metal-rs's foreign-type layout, where the address
    /// of the `…Ref` value is the underlying Metal object pointer — confirm
    /// against the metal-rs / foreign-types documentation.
    fn metal_object_ptr(pipeline: &ComputePipelineState) -> *const metal::ComputePipelineStateRef {
        let object: &metal::ComputePipelineStateRef = pipeline;
        object
    }

    /// Distinct int-constant values must yield distinct cached pipelines, a
    /// repeated request must be a cache hit, and the bool-only wrapper must
    /// still work when given no constants at all.
    #[test]
    fn test_int_fc_distinct_pipelines_and_bool_compat() {
        let device = system_device();
        let mut registry = KernelRegistry::new();

        registry.register_source("int_fc_test_kernel", INT_FC_TEST_SHADER);

        // First specialisation: N = 4.
        let p4_ptr = metal_object_ptr(
            registry
                .get_pipeline_with_constants(
                    "int_fc_test_kernel",
                    &device,
                    &[],
                    &[(100, 4_i32)],
                )
                .expect("pipeline N=4 should compile"),
        );

        let count_after_n4 = registry.cached_count();

        // Second specialisation: N = 8 must be compiled and cached separately.
        let p8_ptr = metal_object_ptr(
            registry
                .get_pipeline_with_constants(
                    "int_fc_test_kernel",
                    &device,
                    &[],
                    &[(100, 8_i32)],
                )
                .expect("pipeline N=8 should compile"),
        );

        assert_eq!(
            registry.cached_count(),
            count_after_n4 + 1,
            "N=8 must produce a new cache entry"
        );

        assert_ne!(
            p4_ptr, p8_ptr,
            "N=4 and N=8 specialisations must be separate ComputePipelineState objects"
        );

        // Asking for N = 4 again must hit the cache and return the same object.
        let p4_again_ptr = metal_object_ptr(
            registry
                .get_pipeline_with_constants(
                    "int_fc_test_kernel",
                    &device,
                    &[],
                    &[(100, 4_i32)],
                )
                .expect("pipeline N=4 cache hit should succeed"),
        );

        assert_eq!(
            registry.cached_count(),
            count_after_n4 + 1,
            "repeated N=4 call must be a cache hit, not a new entry"
        );
        assert_eq!(
            p4_ptr, p4_again_ptr,
            "repeated N=4 call must return the same pipeline pointer"
        );

        // Kernel without any function constants, used to exercise the
        // bool-only wrapper with an empty constant list.
        const BARE_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void bare_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
    if (tid == 0) { out[0] = 42; }
}
"#;
        registry.register_source("bare_kernel", BARE_SHADER);

        let count_before_bool = registry.cached_count();
        let _bool_pipeline = registry
            .get_pipeline_with_bool_constants("bare_kernel", &device, &[])
            .expect("bool-constants wrapper with empty slice must succeed");

        assert_eq!(
            registry.cached_count(),
            count_before_bool + 1,
            "bool-constants wrapper must insert one new cache entry"
        );
    }

    /// Pipelines must carry a label: the kernel name for plain pipelines and
    /// the full cache key for function-constant specialisations.
    #[test]
    fn test_pipeline_labels_propagate_for_mst() {
        let device = system_device();
        let mut registry = KernelRegistry::new();

        registry.register_source("int_fc_test_kernel", INT_FC_TEST_SHADER);

        const BARE_SHADER_LABEL_TEST: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void label_smoke_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
    if (tid == 0) { out[0] = 7; }
}
"#;
        registry.register_source("label_smoke_kernel", BARE_SHADER_LABEL_TEST);

        // Plain pipeline: label is the kernel name.
        let plain_label = registry
            .get_pipeline("label_smoke_kernel", &device)
            .expect("plain pipeline must compile")
            .label()
            .to_string();
        assert_eq!(
            plain_label, "label_smoke_kernel",
            "get_pipeline must label the pipeline with the kernel name (xctrace MST attribution)"
        );

        // Specialised pipeline: label is the composite cache key.
        let label_v7 = registry
            .get_pipeline_with_constants(
                "int_fc_test_kernel",
                &device,
                &[],
                &[(100, 7_i32)],
            )
            .expect("specialised pipeline must compile")
            .label()
            .to_string();
        assert_eq!(
            label_v7, "int_fc_test_kernel|100:i7",
            "get_pipeline_with_constants must label with the cache_key so each \
             specialisation is distinct in xctrace MST"
        );

        // A different constant value must produce a different label.
        let label_v13 = registry
            .get_pipeline_with_constants(
                "int_fc_test_kernel",
                &device,
                &[],
                &[(100, 13_i32)],
            )
            .expect("second specialised pipeline must compile")
            .label()
            .to_string();
        assert_eq!(label_v13, "int_fc_test_kernel|100:i13");
        assert_ne!(
            label_v7, label_v13,
            "distinct constant values must yield distinct pipeline labels"
        );
    }
}