1use std::collections::HashMap;
9
10use metal::{ComputePipelineDescriptor, ComputePipelineState, FunctionConstantValues, MTLDataType};
11
12use crate::error::{MlxError, Result};
13
14pub struct KernelRegistry {
37 cache: HashMap<String, ComputePipelineState>,
39 sources: HashMap<String, &'static str>,
43}
44
45impl KernelRegistry {
46 pub fn new() -> Self {
50 let mut sources = HashMap::new();
51
52 sources.insert(
54 "placeholder".into(),
55 include_str!("shaders/placeholder.metal"),
56 );
57 sources.insert(
58 "quantized_matmul".into(),
59 include_str!("shaders/quantized_matmul.metal"),
60 );
61 sources.insert(
62 "quantized_matmul_simd".into(),
63 include_str!("shaders/quantized_matmul.metal"),
64 );
65 sources.insert(
66 "quantized_matmul_simd_bf16".into(),
67 include_str!("shaders/quantized_matmul.metal"),
68 );
69 sources.insert(
70 "quantized_matmul_simd_bf16_expert".into(),
71 include_str!("shaders/quantized_matmul.metal"),
72 );
73
74 let ggml_src: &'static str =
76 include_str!("shaders/quantized_matmul_ggml.metal");
77 sources.insert("kernel_mul_mv_q4_0_f32".into(), ggml_src);
78 sources.insert("kernel_mul_mv_q8_0_f32".into(), ggml_src);
79 sources.insert("kernel_mul_mv_q6_K_f32".into(), ggml_src);
80 sources.insert("kernel_mul_mv_q5_1_f32".into(), ggml_src);
82 sources.insert("kernel_mul_mv_iq4_nl_f32".into(), ggml_src);
83 sources.insert("kernel_mul_mv_q4_K_f32".into(), ggml_src);
86 sources.insert("kernel_mul_mv_q5_K_f32".into(), ggml_src);
88
89 let ggml_mm_src: &'static str =
95 include_str!("shaders/quantized_matmul_mm.metal");
96 sources.insert("kernel_mul_mm_q4_0_f32".into(), ggml_mm_src);
97 sources.insert("kernel_mul_mm_q8_0_f32".into(), ggml_mm_src);
98 sources.insert("kernel_mul_mm_q6_K_f32".into(), ggml_mm_src);
99 sources.insert("kernel_mul_mm_q5_1_f32".into(), ggml_mm_src);
101 sources.insert("kernel_mul_mm_iq4_nl_f32".into(), ggml_mm_src);
102 sources.insert("kernel_mul_mm_q5_K_f32".into(), ggml_mm_src);
104 sources.insert("kernel_mul_mm_q4_K_f32".into(), ggml_mm_src);
106
107 let ggml_mm_tensor_src: &'static str =
118 include_str!("shaders/quantized_matmul_mm_tensor.metal");
119 sources.insert("kernel_mul_mm_q4_0_tensor_f32".into(), ggml_mm_tensor_src);
120 sources.insert("kernel_mul_mm_q4_0_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
121 sources.insert("kernel_mul_mm_q6_K_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
122 sources.insert("kernel_mul_mm_q8_0_tensor_f32".into(), ggml_mm_tensor_src);
123 sources.insert("kernel_mul_mm_q6_K_tensor_f32".into(), ggml_mm_tensor_src);
124 sources.insert("kernel_mul_mm_q5_1_tensor_f32".into(), ggml_mm_tensor_src);
126 sources.insert("kernel_mul_mm_iq4_nl_tensor_f32".into(), ggml_mm_tensor_src);
127 sources.insert("kernel_mul_mm_q5_K_tensor_f32".into(), ggml_mm_tensor_src);
129 sources.insert("kernel_mul_mm_q4_K_tensor_f32".into(), ggml_mm_tensor_src);
131 sources.insert("kernel_mul_mm_q8_0_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
132
133 let mul_mv_ext_src: &'static str = include_str!("shaders/mul_mv_ext.metal");
138 sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_2".into(), mul_mv_ext_src);
139 sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_3".into(), mul_mv_ext_src);
140 sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_4".into(), mul_mv_ext_src);
141 sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_5".into(), mul_mv_ext_src);
142 sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_2".into(), mul_mv_ext_src);
143 sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_3".into(), mul_mv_ext_src);
144 sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_4".into(), mul_mv_ext_src);
145 sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_5".into(), mul_mv_ext_src);
146 for r1 in [2, 3, 4, 5].iter() {
149 for ty in ["q4_0", "q8_0", "q4_K", "q5_K", "q6_K"].iter() {
150 let name = format!("kernel_mul_mv_ext_{ty}_f32_r1_{r1}");
151 sources.insert(name, mul_mv_ext_src);
152 }
153 }
154
155 let dense_mm_bf16_tensor_src: &'static str =
162 include_str!("shaders/dense_mm_bf16_tensor.metal");
163 sources.insert("hf2q_dense_mm_bf16_f32_tensor".into(), dense_mm_bf16_tensor_src);
164
165 let dense_mm_f32_f32_tensor_src: &'static str =
174 include_str!("shaders/dense_mm_f32_f32.metal");
175 sources.insert("hf2q_dense_mm_f32_f32_tensor".into(), dense_mm_f32_f32_tensor_src);
176
177 let dense_mm_f16_tensor_src: &'static str =
189 include_str!("shaders/dense_mm_f16_tensor.metal");
190 sources.insert("hf2q_dense_mm_f16_f32_tensor".into(), dense_mm_f16_tensor_src);
191
192 let dense_gemv_bf16_src: &'static str =
199 include_str!("shaders/dense_gemv_bf16.metal");
200 sources.insert("hf2q_dense_gemv_bf16_f32_4".into(), dense_gemv_bf16_src);
201
202 let scale_mask_softmax_src: &'static str =
208 include_str!("shaders/scale_mask_softmax.metal");
209 sources.insert("scale_mask_softmax_f32".into(), scale_mask_softmax_src);
210
211 sources.insert(
213 "quantized_matmul_id".into(),
214 include_str!("shaders/quantized_matmul_id.metal"),
215 );
216
217 let ggml_id_src: &'static str =
219 include_str!("shaders/quantized_matmul_id_ggml.metal");
220 sources.insert("kernel_mul_mv_id_q4_0_f32".into(), ggml_id_src);
221 sources.insert("kernel_mul_mv_id_q8_0_f32".into(), ggml_id_src);
222 sources.insert("kernel_mul_mv_id_q4_K_f32".into(), ggml_id_src);
225 sources.insert("kernel_mul_mv_id_q5_K_f32".into(), ggml_id_src);
226 sources.insert("kernel_mul_mv_id_q6_K_f32".into(), ggml_id_src);
227 sources.insert("kernel_mul_mv_id_q5_1_f32".into(), ggml_id_src);
229 sources.insert("kernel_mul_mv_id_iq4_nl_f32".into(), ggml_id_src);
230 sources.insert("kernel_mul_mv_id_q4_0_f32_swiglu".into(), ggml_id_src);
234
235 let ggml_id_mm_src: &'static str =
243 include_str!("shaders/quantized_matmul_id_mm.metal");
244 sources.insert("kernel_mul_mm_id_map0_ne20_1".into(), ggml_id_mm_src);
245 sources.insert("kernel_mul_mm_id_map0_ne20_8".into(), ggml_id_mm_src);
246 sources.insert("kernel_mul_mm_id_q4_0_f32".into(), ggml_id_mm_src);
247 sources.insert("kernel_mul_mm_id_q8_0_f32".into(), ggml_id_mm_src);
248 sources.insert("kernel_mul_mm_id_q6_K_f32".into(), ggml_id_mm_src);
249 sources.insert("kernel_mul_mm_id_q4_K_f32".into(), ggml_id_mm_src);
251 sources.insert("kernel_mul_mm_id_q5_1_f32".into(), ggml_id_mm_src);
253 sources.insert("kernel_mul_mm_id_iq4_nl_f32".into(), ggml_id_mm_src);
254 sources.insert("kernel_mul_mm_id_q5_K_f32".into(), ggml_id_mm_src);
256
257 let ggml_id_mm_tensor_src: &'static str =
263 include_str!("shaders/quantized_matmul_id_mm_tensor.metal");
264 sources.insert("kernel_mul_mm_id_q4_0_tensor_f32".into(), ggml_id_mm_tensor_src);
265 sources.insert("kernel_mul_mm_id_q8_0_tensor_f32".into(), ggml_id_mm_tensor_src);
266 sources.insert("kernel_mul_mm_id_q6_K_tensor_f32".into(), ggml_id_mm_tensor_src);
267 sources.insert("kernel_mul_mm_id_q4_K_tensor_f32".into(), ggml_id_mm_tensor_src);
269 sources.insert("kernel_mul_mm_id_q5_1_tensor_f32".into(), ggml_id_mm_tensor_src);
271 sources.insert("kernel_mul_mm_id_iq4_nl_tensor_f32".into(), ggml_id_mm_tensor_src);
272 sources.insert("kernel_mul_mm_id_q5_K_tensor_f32".into(), ggml_id_mm_tensor_src);
274
275 let embedding_src: &'static str = include_str!("shaders/embedding.metal");
277 sources.insert("embedding_gather_4bit".into(), embedding_src);
278 sources.insert("embedding_gather_6bit".into(), embedding_src);
279
280 let moe_gate_src: &'static str = include_str!("shaders/moe_gate.metal");
282 sources.insert("moe_gate".into(), moe_gate_src);
283
284 let moe_dispatch_src: &'static str = include_str!("shaders/moe_dispatch.metal");
286 sources.insert("fused_gelu_mul".into(), moe_dispatch_src);
287 sources.insert("moe_swiglu_fused".into(), moe_dispatch_src);
288 sources.insert("moe_swiglu_batch".into(), moe_dispatch_src);
289 sources.insert("moe_swiglu_seq".into(), moe_dispatch_src);
290 sources.insert("moe_accumulate".into(), moe_dispatch_src);
291 sources.insert("moe_weighted_sum".into(), moe_dispatch_src);
292 sources.insert("moe_weighted_sum_seq".into(), moe_dispatch_src);
293 sources.insert("zero_buffer".into(), moe_dispatch_src);
294 sources.insert("naive_matvec_f32".into(), moe_dispatch_src);
295 sources.insert("moe_gather_topk_weights".into(), moe_dispatch_src);
296 sources.insert("fused_gelu_mul_bf16".into(), moe_dispatch_src);
298 sources.insert("moe_swiglu_seq_bf16".into(), moe_dispatch_src);
299 sources.insert("moe_weighted_sum_seq_bf16_input".into(), moe_dispatch_src);
300 sources.insert(
302 "moe_weighted_sum_seq_backward_outputs_f32".into(),
303 moe_dispatch_src,
304 );
305 sources.insert(
306 "moe_weighted_sum_seq_backward_weights_f32".into(),
307 moe_dispatch_src,
308 );
309 sources.insert(
311 "moe_swiglu_seq_backward_f32".into(),
312 moe_dispatch_src,
313 );
314
315 let kv_cache_src: &'static str = include_str!("shaders/kv_cache_copy.metal");
317 sources.insert("kv_cache_copy_batch_f32".into(), kv_cache_src);
318 sources.insert("kv_cache_copy_batch_f32_to_f16".into(), kv_cache_src);
319 sources.insert("kv_cache_copy_seq_f32".into(), kv_cache_src);
320 sources.insert("kv_cache_copy_seq_f32_to_f16".into(), kv_cache_src);
321 sources.insert("kv_cache_copy_seq_f32_kv_dual".into(), kv_cache_src);
323 sources.insert("kv_cache_copy_seq_f32_to_f16_kv_dual".into(), kv_cache_src);
324 sources.insert("kv_cache_copy_seq_bf16".into(), kv_cache_src);
326
327 let elementwise_src: &'static str = include_str!("shaders/elementwise.metal");
329 sources.insert("elementwise_add_f32".into(), elementwise_src);
330 sources.insert("elementwise_add_f16".into(), elementwise_src);
331 sources.insert("elementwise_mul_f32".into(), elementwise_src);
332 sources.insert("elementwise_mul_f16".into(), elementwise_src);
333 sources.insert("elementwise_add_bf16".into(), elementwise_src);
334 sources.insert("elementwise_mul_bf16".into(), elementwise_src);
335 sources.insert("cast_f16_to_f32".into(), elementwise_src);
336 sources.insert("cast_f32_to_f16".into(), elementwise_src);
337 sources.insert("cast_bf16_to_f32".into(), elementwise_src);
338 sources.insert("cast_f32_to_bf16".into(), elementwise_src);
339 sources.insert("scalar_mul_bf16".into(), elementwise_src);
340 sources.insert("scalar_mul_f32".into(), elementwise_src);
341 sources.insert("embedding_gather_scale_f32".into(), elementwise_src);
342 sources.insert("embedding_gather_scale_batch_f32".into(), elementwise_src);
343 sources.insert("permute_021_bf16".into(), elementwise_src);
344 sources.insert("transpose_last2_bf16".into(), elementwise_src);
345 sources.insert("transpose_last2_f16".into(), elementwise_src);
346 sources.insert("permute_021_f32".into(), elementwise_src);
347 sources.insert("permute_021_bf16_to_f32".into(), elementwise_src);
348 sources.insert("transpose_2d_f32".into(), elementwise_src);
349 sources.insert("transpose_2d_f16".into(), elementwise_src);
350
351 let sdpa_src: &'static str = include_str!("shaders/sdpa.metal");
353 sources.insert("sdpa".into(), sdpa_src);
354 sources.insert("sdpa_bf16".into(), sdpa_src);
355 let sdpa_sliding_src: &'static str = include_str!("shaders/sdpa_sliding.metal");
356 sources.insert("sdpa_sliding".into(), sdpa_sliding_src);
357 sources.insert("sdpa_sliding_bf16".into(), sdpa_sliding_src);
358
359 let flash_attn_prefill_src: &'static str =
364 include_str!("shaders/flash_attn_prefill.metal");
365 sources.insert(
367 "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskfloat32".into(),
368 flash_attn_prefill_src,
369 );
370 sources.insert(
371 "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
372 flash_attn_prefill_src,
373 );
374 sources.insert(
375 "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbfloat16".into(),
376 flash_attn_prefill_src,
377 );
378 sources.insert(
379 "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
380 flash_attn_prefill_src,
381 );
382 sources.insert(
383 "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskfloat16".into(),
384 flash_attn_prefill_src,
385 );
386 sources.insert(
387 "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
388 flash_attn_prefill_src,
389 );
390 sources.insert(
394 "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbfloat16".into(),
395 flash_attn_prefill_src,
396 );
397 sources.insert(
398 "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
399 flash_attn_prefill_src,
400 );
401 sources.insert(
402 "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskfloat16".into(),
403 flash_attn_prefill_src,
404 );
405 sources.insert(
406 "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
407 flash_attn_prefill_src,
408 );
409
410 let flash_attn_vec_src: &'static str =
413 include_str!("shaders/flash_attn_vec.metal");
414 sources.insert("flash_attn_vec_dk256".into(), flash_attn_vec_src);
415 sources.insert("flash_attn_vec_dk512".into(), flash_attn_vec_src);
416 sources.insert("flash_attn_vec_reduce_dk256".into(), flash_attn_vec_src);
417 sources.insert("flash_attn_vec_reduce_dk512".into(), flash_attn_vec_src);
418 sources.insert("flash_attn_vec_f16kv_dk256".into(), flash_attn_vec_src);
420 sources.insert("flash_attn_vec_f16kv_dk512".into(), flash_attn_vec_src);
421
422 let rope_src: &'static str = include_str!("shaders/rope.metal");
424 sources.insert("rope_f32".into(), rope_src);
425 sources.insert("rope_f16".into(), rope_src);
426 sources.insert("rope_bf16".into(), rope_src);
427 sources.insert("rope_neox_bf16".into(), rope_src);
428 sources.insert("rope_neox_f32".into(), rope_src);
429 let rms_norm_src: &'static str = include_str!("shaders/rms_norm.metal");
430 sources.insert("rms_norm_f32".into(), rms_norm_src);
431 sources.insert("rms_norm_f16".into(), rms_norm_src);
432 sources.insert("rms_norm_bf16".into(), rms_norm_src);
433 sources.insert("rms_norm_no_scale_bf16".into(), rms_norm_src);
434 sources.insert("rms_norm_no_scale_f32".into(), rms_norm_src);
435 sources.insert("rms_norm_no_scale_f32_dual".into(), rms_norm_src);
436 sources.insert("rms_norm_f32_triple".into(), rms_norm_src);
437 sources.insert("fused_post_attn_triple_norm_f32".into(), rms_norm_src);
438 sources.insert("rms_norm_no_scale_f32_dual_perm".into(), rms_norm_src);
439 sources.insert("rms_norm_mul_f32".into(), rms_norm_src);
441 sources.insert("rms_norm_mul_f16".into(), rms_norm_src);
442 sources.insert("rms_norm_mul_bf16".into(), rms_norm_src);
443 let l2_norm_src: &'static str = include_str!("shaders/l2_norm.metal");
445 sources.insert("l2_norm_f32".into(), l2_norm_src);
446 sources.insert("l2_norm_f16".into(), l2_norm_src);
447 sources.insert("l2_norm_bf16".into(), l2_norm_src);
448 let cumsum_src: &'static str = include_str!("shaders/cumsum.metal");
450 sources.insert("cumsum_f32".into(), cumsum_src);
451 sources.insert("cumsum_bf16".into(), cumsum_src);
452 let ssm_conv_src: &'static str = include_str!("shaders/ssm_conv.metal");
454 sources.insert("ssm_conv_forward_f32".into(), ssm_conv_src);
455 sources.insert("ssm_conv_forward_bf16".into(), ssm_conv_src);
456 sources.insert("ssm_conv_state_update_f32".into(), ssm_conv_src);
457 sources.insert("ssm_conv_state_update_bf16".into(), ssm_conv_src);
458 let tri_solve_src: &'static str = include_str!("shaders/tri_solve.metal");
460 sources.insert("tri_solve_lower_unit_f32".into(), tri_solve_src);
461 sources.insert("tri_solve_lower_unit_bf16".into(), tri_solve_src);
462 let rope_multi_src: &'static str = include_str!("shaders/rope_multi.metal");
464 sources.insert("rope_multi_f32".into(), rope_multi_src);
465 sources.insert("rope_multi_bf16".into(), rope_multi_src);
466 let gdn_src: &'static str = include_str!("shaders/gated_delta_net.metal");
468 sources.insert("gated_delta_net_f32".into(), gdn_src);
469 let gdn_decode_src: &'static str =
474 include_str!("shaders/gated_delta_net_decode.metal");
475 sources.insert("gated_delta_net_decode_f32_1".into(), gdn_decode_src);
476 sources.insert("gated_delta_net_decode_f32_2".into(), gdn_decode_src);
477 sources.insert("gated_delta_net_decode_f32_4".into(), gdn_decode_src);
478 let gdn_chunk_src: &'static str =
482 include_str!("shaders/gated_delta_net_chunk.metal");
483 sources.insert(
484 "gated_delta_net_chunk_inter_state_bf16".into(),
485 gdn_chunk_src,
486 );
487 let gdn_kkt_src: &'static str =
490 include_str!("shaders/gated_delta_net_kkt.metal");
491 sources.insert("gated_delta_net_kkt_bf16".into(), gdn_kkt_src);
492 let gdn_recompute_wu_src: &'static str =
496 include_str!("shaders/gated_delta_net_recompute_wu.metal");
497 sources.insert(
498 "gated_delta_net_recompute_wu_bf16".into(),
499 gdn_recompute_wu_src,
500 );
501 let gdn_chunk_o_src: &'static str =
504 include_str!("shaders/gated_delta_net_chunk_o.metal");
505 sources.insert("gated_delta_net_chunk_o_bf16".into(), gdn_chunk_o_src);
506 let chunk_local_cumsum_g_src: &'static str =
511 include_str!("shaders/chunk_local_cumsum_g.metal");
512 sources.insert(
513 "chunk_local_cumsum_g_f32".into(),
514 chunk_local_cumsum_g_src,
515 );
516 let chunk_tri_solve_invert_src: &'static str =
517 include_str!("shaders/chunk_gated_delta_rule_tri_solve_invert.metal");
518 sources.insert(
519 "chunk_tri_solve_invert_f32".into(),
520 chunk_tri_solve_invert_src,
521 );
522 let sigmoid_mul_src: &'static str = include_str!("shaders/sigmoid_mul.metal");
524 sources.insert("sigmoid_mul_f32".into(), sigmoid_mul_src);
525 sources.insert("sigmoid_mul_bf16".into(), sigmoid_mul_src);
526 let silu_mul_src: &'static str = include_str!("shaders/silu_mul.metal");
527 sources.insert("silu_mul_f32".into(), silu_mul_src);
528 let compute_g_beta_src: &'static str = include_str!("shaders/compute_g_beta.metal");
529 sources.insert("compute_g_beta_f32".into(), compute_g_beta_src);
530 let ssm_norm_gate_src: &'static str = include_str!("shaders/ssm_norm_gate.metal");
531 sources.insert("ssm_norm_gate_f32".into(), ssm_norm_gate_src);
532 let gelu_src: &'static str = include_str!("shaders/gelu.metal");
533 sources.insert("gelu_f32".into(), gelu_src);
534 sources.insert("gelu_f16".into(), gelu_src);
535 sources.insert("gelu_bf16".into(), gelu_src);
536 let softmax_src: &'static str = include_str!("shaders/softmax.metal");
537 sources.insert("softmax_f32".into(), softmax_src);
538 sources.insert("softmax_f16".into(), softmax_src);
539 sources.insert("softmax_bf16".into(), softmax_src);
540 let softmax_backward_src: &'static str =
541 include_str!("shaders/softmax_backward.metal");
542 sources.insert("softmax_backward_f32".into(), softmax_backward_src);
543 let log_elementwise_src: &'static str =
544 include_str!("shaders/log_elementwise.metal");
545 sources.insert("log_f32".into(), log_elementwise_src);
546 sources.insert("log_backward_f32".into(), log_elementwise_src);
547 let row_sum_src: &'static str = include_str!("shaders/row_sum.metal");
548 sources.insert("row_sum_f32".into(), row_sum_src);
549 sources.insert("row_sum_backward_f32".into(), row_sum_src);
550 let qdq_legacy_src: &'static str = include_str!("shaders/qdq_legacy.metal");
554 sources.insert("qdq_q4_0_f32".into(), qdq_legacy_src);
555 sources.insert("qdq_q8_0_f32".into(), qdq_legacy_src);
556 let rms_norm_backward_src: &'static str =
560 include_str!("shaders/rms_norm_backward.metal");
561 sources.insert(
562 "rms_norm_compute_rms_inv_f32".into(),
563 rms_norm_backward_src,
564 );
565 sources.insert("rms_norm_backward_dx_f32".into(), rms_norm_backward_src);
566 sources.insert("rms_norm_backward_dw_f32".into(), rms_norm_backward_src);
567 let slice_concat_2d_src: &'static str =
572 include_str!("shaders/slice_concat_2d.metal");
573 sources.insert("slice_2d_cols_f32".into(), slice_concat_2d_src);
574 sources.insert("copy_2d_cols_into_f32".into(), slice_concat_2d_src);
575 let silu_backward_src: &'static str =
578 include_str!("shaders/silu_backward.metal");
579 sources.insert("silu_f32".into(), silu_backward_src);
580 sources.insert("silu_backward_f32".into(), silu_backward_src);
581 let embedding_autograd_src: &'static str =
583 include_str!("shaders/embedding_autograd.metal");
584 sources.insert("embedding_lookup_f32".into(), embedding_autograd_src);
585 sources.insert(
586 "embedding_scatter_add_f32".into(),
587 embedding_autograd_src,
588 );
589 let adam_update_src: &'static str =
592 include_str!("shaders/adam_update.metal");
593 sources.insert("adam_update_f32".into(), adam_update_src);
594 let qdq_affine_src: &'static str =
598 include_str!("shaders/qdq_affine.metal");
599 sources.insert("qdq_affine_init_f32".into(), qdq_affine_src);
600 sources.insert("qdq_affine_forward_f32".into(), qdq_affine_src);
601 sources.insert(
602 "qdq_affine_backward_scales_f32".into(),
603 qdq_affine_src,
604 );
605 sources.insert(
606 "qdq_affine_backward_biases_f32".into(),
607 qdq_affine_src,
608 );
609 let qmm_affine_src: &'static str =
613 include_str!("shaders/qmm_affine.metal");
614 sources.insert("qmm_affine_t_f32".into(), qmm_affine_src);
615 let qmm_affine_tiled_src: &'static str =
619 include_str!("shaders/qmm_affine_tiled.metal");
620 sources.insert(
621 "qmm_affine_t_f32_tiled".into(),
622 qmm_affine_tiled_src,
623 );
624 let qmm_affine_simd_src: &'static str =
629 include_str!("shaders/qmm_affine_simd.metal");
630 sources.insert(
631 "qmm_affine_t_f32_simd".into(),
632 qmm_affine_simd_src,
633 );
634 let qmm_affine_simd4_src: &'static str =
639 include_str!("shaders/qmm_affine_simd4.metal");
640 sources.insert(
641 "qmm_affine_t_f32_simd4".into(),
642 qmm_affine_simd4_src,
643 );
644 let qmm_affine_simd4_gs64_src: &'static str =
648 include_str!("shaders/qmm_affine_simd4_gs64.metal");
649 sources.insert(
650 "qmm_affine_t_f32_simd4_gs64".into(),
651 qmm_affine_simd4_gs64_src,
652 );
653 let qmm_affine_t_packed_simd4_b4_src: &'static str =
657 include_str!("shaders/qmm_affine_t_packed_simd4_b4.metal");
658 sources.insert(
659 "qmm_affine_t_packed_simd4_b4".into(),
660 qmm_affine_t_packed_simd4_b4_src,
661 );
662 let conv1d_dwc_src: &'static str =
667 include_str!("shaders/conv1d_depthwise_causal.metal");
668 sources.insert(
669 "conv1d_depthwise_causal_forward_f32".into(),
670 conv1d_dwc_src,
671 );
672 sources.insert(
673 "conv1d_depthwise_causal_backward_dx_f32".into(),
674 conv1d_dwc_src,
675 );
676 sources.insert(
677 "conv1d_depthwise_causal_backward_dw_f32".into(),
678 conv1d_dwc_src,
679 );
680 let exp_src: &'static str =
683 include_str!("shaders/exp_elementwise.metal");
684 sources.insert("exp_f32".into(), exp_src);
685 sources.insert("exp_backward_f32".into(), exp_src);
686 let outer_src: &'static str =
690 include_str!("shaders/outer_product.metal");
691 sources.insert("outer_product_f32".into(), outer_src);
692 sources.insert("outer_product_backward_lhs_f32".into(), outer_src);
693 sources.insert("outer_product_backward_rhs_f32".into(), outer_src);
694 let taa_src: &'static str =
697 include_str!("shaders/take_along_axis.metal");
698 sources.insert("take_along_axis_f32".into(), taa_src);
699 sources.insert("take_along_axis_backward_f32".into(), taa_src);
700 let div_src: &'static str =
702 include_str!("shaders/divide_elementwise.metal");
703 sources.insert("divide_f32".into(), div_src);
704 sources.insert("divide_backward_f32".into(), div_src);
705 let sqrt_src: &'static str =
707 include_str!("shaders/sqrt_elementwise.metal");
708 sources.insert("sqrt_f32".into(), sqrt_src);
709 sources.insert("sqrt_backward_f32".into(), sqrt_src);
710 let softcap_src: &'static str = include_str!("shaders/softcap.metal");
711 sources.insert("softcap_f32".into(), softcap_src);
712 sources.insert("softcap_f16".into(), softcap_src);
713 sources.insert("softcap_bf16".into(), softcap_src);
714
715 let fused_norm_add_src: &'static str =
718 include_str!("shaders/fused_norm_add_bf16.metal");
719 sources.insert("fused_norm_add_bf16".into(), fused_norm_add_src);
720 sources.insert("fused_norm_add_no_weight_bf16".into(), fused_norm_add_src);
721
722 let fused_hnr_f32_src: &'static str =
724 include_str!("shaders/fused_head_norm_rope_f32.metal");
725 sources.insert("fused_head_norm_rope_f32".into(), fused_hnr_f32_src);
726
727 let fused_hnr_bf16_src: &'static str =
730 include_str!("shaders/fused_head_norm_rope_bf16.metal");
731 sources.insert("fused_head_norm_rope_bf16".into(), fused_hnr_bf16_src);
732 sources.insert("fused_head_norm_rope_batch_bf16".into(), fused_hnr_bf16_src);
733
734 let fused_norm_add_f32_src: &'static str =
736 include_str!("shaders/fused_norm_add_f32.metal");
737 sources.insert("fused_norm_add_f32".into(), fused_norm_add_f32_src);
738 sources.insert("fused_residual_norm_f32".into(), fused_norm_add_f32_src);
739 sources.insert("fused_residual_norm_scalar_f32".into(), fused_norm_add_f32_src);
740 sources.insert("fused_moe_routing_f32".into(), fused_norm_add_f32_src);
741 sources.insert("fused_moe_routing_batch_f32".into(), fused_norm_add_f32_src);
742 sources.insert("fused_norm_add_scalar_f32".into(), fused_norm_add_f32_src);
743 sources.insert("fused_moe_wsum_norm_add_f32".into(), fused_norm_add_f32_src);
744 sources.insert("fused_moe_wsum_dnorm_add_f32".into(), fused_norm_add_f32_src);
745
746 let argsort_src: &'static str = include_str!("shaders/argsort.metal");
748 sources.insert("argsort_desc_f32".into(), argsort_src);
749
750 let gather_src: &'static str = include_str!("shaders/gather.metal");
752 sources.insert("gather_f32".into(), gather_src);
753
754 let kv_cache_copy_src: &'static str =
756 include_str!("shaders/kv_cache_copy.metal");
757 sources.insert("kv_cache_copy".into(), kv_cache_copy_src);
758 sources.insert("kv_cache_copy_f32".into(), kv_cache_copy_src);
759
760 let copy_src: &'static str = include_str!("shaders/copy.metal");
762 sources.insert("strided_copy_f32".into(), copy_src);
763 sources.insert("offset_copy_f32".into(), copy_src);
764
765 let qkv_split_src: &'static str = include_str!("shaders/qkv_split.metal");
769 sources.insert("qkv_split_f32".into(), qkv_split_src);
770
771 let repeat_tiled_src: &'static str =
775 include_str!("shaders/repeat_tiled.metal");
776 sources.insert("repeat_tiled_f32".into(), repeat_tiled_src);
777
778 let dense_gemm_src: &'static str = include_str!("shaders/dense_gemm.metal");
780 sources.insert("dense_gemm_f16".into(), dense_gemm_src);
781 sources.insert("dense_matvec_f16".into(), dense_gemm_src);
782 sources.insert("dense_matvec_f16w_f32io".into(), dense_gemm_src);
783 sources.insert("dense_matvec_bf16w_f32io".into(), dense_gemm_src);
785 sources.insert("dense_matvec_f32".into(), dense_gemm_src);
787
788 let fwht_src: &'static str = include_str!("shaders/fwht_standalone.metal");
790 sources.insert("fwht_standalone_f32_d256".into(), fwht_src);
791 sources.insert("fwht_standalone_f32_d512".into(), fwht_src);
792 sources.insert("fwht_sign_premult_f32_d256".into(), fwht_src);
794 sources.insert("fwht_sign_premult_f32_d512".into(), fwht_src);
795 sources.insert("fwht_sign_undo_f32_d256".into(), fwht_src);
796 sources.insert("fwht_sign_undo_f32_d512".into(), fwht_src);
797
798 let hq_fast_src: &'static str = include_str!("shaders/hadamard_quantize_kv_fast.metal");
800 sources.insert("hadamard_quantize_kv_fast_d256".into(), hq_fast_src);
801 sources.insert("hadamard_quantize_kv_fast_d512".into(), hq_fast_src);
802 sources.insert("hadamard_quantize_kv_hb_d256".into(), hq_fast_src);
804 sources.insert("hadamard_quantize_kv_hb_d512".into(), hq_fast_src);
805
806 let tq_dq_src: &'static str = include_str!("shaders/tq_dequantize_kv.metal");
808 sources.insert("tq_dequantize_kv".into(), tq_dq_src);
809 sources.insert("tq_dequantize_hb_kv".into(), tq_dq_src);
811
812 let tq_hb_src: &'static str = include_str!("shaders/flash_attn_vec_tq_hb.metal");
814 sources.insert("flash_attn_vec_tq_hb_dk256".into(), tq_hb_src);
815 sources.insert("flash_attn_vec_tq_hb_dk512".into(), tq_hb_src);
816
817 let argmax_src: &'static str = include_str!("shaders/argmax.metal");
819 sources.insert("argmax_f32".into(), argmax_src);
820 let softmax_sample_src: &'static str =
821 include_str!("shaders/softmax_sample.metal");
822 sources.insert("softmax_sample_f32".into(), softmax_sample_src);
823 let top_k_src: &'static str = include_str!("shaders/top_k.metal");
825 sources.insert("top_k_f32".into(), top_k_src);
826
827 let moe_stk_src: &'static str =
830 include_str!("shaders/moe_softmax_topk.metal");
831 sources.insert("moe_softmax_topk_f32".into(), moe_stk_src);
832 let moe_wr_src: &'static str =
833 include_str!("shaders/moe_weighted_reduce.metal");
834 sources.insert("moe_weighted_reduce_f32".into(), moe_wr_src);
835 let sdpa_decode_src: &'static str =
836 include_str!("shaders/sdpa_decode.metal");
837 sources.insert("sdpa_decode".into(), sdpa_decode_src);
838
839 Self {
840 cache: HashMap::new(),
841 sources,
842 }
843 }
844
845 pub fn register_source(&mut self, name: impl Into<String>, source: &'static str) {
848 let name = name.into();
849 self.cache.remove(&name);
851 self.sources.insert(name, source);
852 }
853
854 pub fn get_pipeline(
866 &mut self,
867 name: &str,
868 device: &metal::DeviceRef,
869 ) -> Result<&ComputePipelineState> {
870 if !self.cache.contains_key(name) {
871 let source = self.sources.get(name).ok_or_else(|| {
873 MlxError::KernelNotFound(name.to_string())
874 })?;
875
876 let compile_opts = metal::CompileOptions::new();
877 let library = device
878 .new_library_with_source(source, &compile_opts)
879 .map_err(|msg| MlxError::ShaderCompilationError {
880 name: name.to_string(),
881 message: msg,
882 })?;
883
884 let function = library
885 .get_function(name, None)
886 .map_err(|msg| MlxError::ShaderCompilationError {
887 name: name.to_string(),
888 message: msg,
889 })?;
890
891 let descriptor = ComputePipelineDescriptor::new();
901 descriptor.set_compute_function(Some(&function));
902 descriptor.set_label(name);
903
904 let pipeline = device
905 .new_compute_pipeline_state(&descriptor)
906 .map_err(|msg| MlxError::ShaderCompilationError {
907 name: name.to_string(),
908 message: msg,
909 })?;
910
911 self.cache.insert(name.to_string(), pipeline);
912 }
913
914 self.cache.get(name).ok_or_else(|| {
917 MlxError::KernelNotFound(name.to_string())
918 })
919 }
920
921 pub fn get_pipeline_with_constants(
943 &mut self,
944 name: &str,
945 device: &metal::DeviceRef,
946 bool_constants: &[(usize, bool)],
947 int_constants: &[(usize, i32)],
948 ) -> Result<&ComputePipelineState> {
949 let mut cache_key = name.to_string();
954 for &(index, value) in bool_constants {
955 cache_key.push('|');
956 cache_key.push_str(&index.to_string());
957 cache_key.push_str(if value { ":b1" } else { ":b0" });
958 }
959 for &(index, value) in int_constants {
960 cache_key.push('|');
961 cache_key.push_str(&index.to_string());
962 cache_key.push(':');
963 cache_key.push('i');
964 cache_key.push_str(&value.to_string());
965 }
966
967 if !self.cache.contains_key(&cache_key) {
968 let source = self.sources.get(name).ok_or_else(|| {
970 MlxError::KernelNotFound(name.to_string())
971 })?;
972
973 let compile_opts = metal::CompileOptions::new();
974 let library = device
975 .new_library_with_source(source, &compile_opts)
976 .map_err(|msg| MlxError::ShaderCompilationError {
977 name: name.to_string(),
978 message: msg,
979 })?;
980
981 let fcv = FunctionConstantValues::new();
986
987 for &(index, value) in bool_constants {
988 let v: u8 = if value { 1 } else { 0 };
991 fcv.set_constant_value_at_index(
992 (&v as *const u8).cast::<std::ffi::c_void>(),
993 MTLDataType::Bool,
994 index as u64,
995 );
996 }
997
998 for &(index, value) in int_constants {
999 fcv.set_constant_value_at_index(
1003 (&value as *const i32).cast::<std::ffi::c_void>(),
1004 MTLDataType::Int,
1005 index as u64,
1006 );
1007 }
1008
1009 let function = library
1010 .get_function(name, Some(fcv))
1011 .map_err(|msg| MlxError::ShaderCompilationError {
1012 name: name.to_string(),
1013 message: msg,
1014 })?;
1015
1016 let descriptor = ComputePipelineDescriptor::new();
1023 descriptor.set_compute_function(Some(&function));
1024 descriptor.set_label(&cache_key);
1025
1026 let pipeline = device
1027 .new_compute_pipeline_state(&descriptor)
1028 .map_err(|msg| MlxError::ShaderCompilationError {
1029 name: name.to_string(),
1030 message: msg,
1031 })?;
1032
1033 self.cache.insert(cache_key.clone(), pipeline);
1034 }
1035
1036 self.cache.get(&cache_key).ok_or_else(|| {
1037 MlxError::KernelNotFound(name.to_string())
1038 })
1039 }
1040
1041 pub fn get_pipeline_with_bool_constants(
1059 &mut self,
1060 name: &str,
1061 device: &metal::DeviceRef,
1062 bool_constants: &[(usize, bool)],
1063 ) -> Result<&ComputePipelineState> {
1064 self.get_pipeline_with_constants(name, device, bool_constants, &[])
1065 }
1066
1067 pub fn is_cached(&self, name: &str) -> bool {
1069 self.cache.contains_key(name)
1070 }
1071
1072 pub fn cached_count(&self) -> usize {
1074 self.cache.len()
1075 }
1076
1077 pub fn source_count(&self) -> usize {
1079 self.sources.len()
1080 }
1081}
1082
1083impl Default for KernelRegistry {
1084 fn default() -> Self {
1085 Self::new()
1086 }
1087}
1088
#[cfg(test)]
mod tests {
    use super::*;

    /// Entry-point name of the kernel defined in [`INT_FC_TEST_SHADER`].
    const INT_FC_KERNEL: &str = "int_fc_test_kernel";

    /// Function-constant index declared by [`INT_FC_TEST_SHADER`].
    const INT_FC_INDEX: usize = 100;

    /// Minimal kernel with a single `int` function constant at index 100;
    /// thread 0 writes the constant's value to `out[0]`.
    const INT_FC_TEST_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;

constant int test_N [[function_constant(100)]];

kernel void int_fc_test_kernel(
    device int* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]])
{
    if (tid == 0) {
        out[0] = test_N;
    }
}
"#;

    /// Kernel without any function constants, used to exercise the
    /// bool-constants wrapper with an empty constant list.
    const BARE_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void bare_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
    if (tid == 0) { out[0] = 42; }
}
"#;

    /// Kernel without any function constants, used to check the label set by
    /// the plain `get_pipeline` path.
    const BARE_SHADER_LABEL_TEST: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void label_smoke_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
    if (tid == 0) { out[0] = 7; }
}
"#;

    /// Acquires the system default Metal device, panicking when none exists.
    #[track_caller]
    fn metal_device() -> metal::Device {
        metal::Device::system_default()
            .expect("no Metal device — run on Apple Silicon or x86 Mac with Metal support")
    }

    /// Requests the `int_fc_test_kernel` pipeline specialised with
    /// `test_N = n`, panicking with `failure` if the registry returns an error.
    #[track_caller]
    fn int_fc_pipeline<'r>(
        registry: &'r mut KernelRegistry,
        device: &metal::DeviceRef,
        n: i32,
        failure: &str,
    ) -> &'r ComputePipelineState {
        registry
            .get_pipeline_with_constants(INT_FC_KERNEL, device, &[], &[(INT_FC_INDEX, n)])
            .expect(failure)
    }

    /// Distinct integer function-constant values must yield distinct cache
    /// entries, repeated requests must hit the cache, and the bool-only
    /// wrapper must keep working with an empty constant list.
    #[test]
    fn test_int_fc_distinct_pipelines_and_bool_compat() {
        let device = metal_device();
        let mut registry = KernelRegistry::new();
        registry.register_source(INT_FC_KERNEL, INT_FC_TEST_SHADER);

        // NOTE(review): the pointers compared below are addresses of the value
        // slots inside the registry's `HashMap`, not of the underlying Metal
        // objects. The equality check therefore assumes the map does not
        // reallocate between the two N=4 lookups — confirm this is the
        // intended identity check.
        let n4_first: *const ComputePipelineState =
            int_fc_pipeline(&mut registry, &device, 4, "pipeline N=4 should compile");

        // Baseline taken after the first specialisation is cached.
        let entries_with_n4 = registry.cached_count();

        let n8_first: *const ComputePipelineState =
            int_fc_pipeline(&mut registry, &device, 8, "pipeline N=8 should compile");

        assert_eq!(
            registry.cached_count(),
            entries_with_n4 + 1,
            "N=8 must produce a new cache entry"
        );
        assert_ne!(
            n4_first, n8_first,
            "N=4 and N=8 specialisations must be separate ComputePipelineState objects"
        );

        // Asking for N=4 again must be served from the cache.
        let n4_second: *const ComputePipelineState = int_fc_pipeline(
            &mut registry,
            &device,
            4,
            "pipeline N=4 cache hit should succeed",
        );

        assert_eq!(
            registry.cached_count(),
            entries_with_n4 + 1,
            "repeated N=4 call must be a cache hit, not a new entry"
        );
        assert_eq!(
            n4_first, n4_second,
            "repeated N=4 call must return the same pipeline pointer"
        );

        // Bool-only wrapper with no constants at all.
        registry.register_source("bare_kernel", BARE_SHADER);
        let entries_before_bare = registry.cached_count();

        let _ = registry
            .get_pipeline_with_bool_constants("bare_kernel", &device, &[])
            .expect("bool-constants wrapper with empty slice must succeed");

        assert_eq!(
            registry.cached_count(),
            entries_before_bare + 1,
            "bool-constants wrapper must insert one new cache entry"
        );
    }

    /// Pipelines must carry a label: the kernel name for the plain path and
    /// the full cache key for specialised pipelines, so each specialisation
    /// can be told apart by its label.
    #[test]
    fn test_pipeline_labels_propagate_for_mst() {
        let device = metal_device();
        let mut registry = KernelRegistry::new();
        registry.register_source(INT_FC_KERNEL, INT_FC_TEST_SHADER);
        registry.register_source("label_smoke_kernel", BARE_SHADER_LABEL_TEST);

        // Plain path: the label is the bare kernel name.
        let plain_label = registry
            .get_pipeline("label_smoke_kernel", &device)
            .expect("plain pipeline must compile")
            .label()
            .to_string();
        assert_eq!(
            plain_label, "label_smoke_kernel",
            "get_pipeline must label the pipeline with the kernel name (xctrace MST attribution)"
        );

        // Specialised path: the label is the cache key, including the constant.
        let label_v7 = int_fc_pipeline(
            &mut registry,
            &device,
            7,
            "specialised pipeline must compile",
        )
        .label()
        .to_string();
        assert_eq!(
            label_v7, "int_fc_test_kernel|100:i7",
            "get_pipeline_with_constants must label with the cache_key so each \
             specialisation is distinct in xctrace MST"
        );

        let label_v13 = int_fc_pipeline(
            &mut registry,
            &device,
            13,
            "second specialised pipeline must compile",
        )
        .label()
        .to_string();
        assert_eq!(label_v13, "int_fc_test_kernel|100:i13");
        assert_ne!(
            label_v7, label_v13,
            "distinct constant values must yield distinct pipeline labels"
        );
    }
}