1use std::collections::HashMap;
9
10use metal::{ComputePipelineDescriptor, ComputePipelineState, FunctionConstantValues, MTLDataType};
11
12use crate::error::{MlxError, Result};
13
/// Registry of Metal compute kernels.
///
/// Maps kernel names to the shader source text they are compiled from and
/// keeps the compute pipeline states that have already been built so they
/// are not compiled again on later requests.
pub struct KernelRegistry {
    /// Pipelines that have already been built. Keyed by the kernel name, or —
    /// for pipelines built with function constants — by the kernel name
    /// followed by one `|index:value` suffix per constant.
    cache: HashMap<String, ComputePipelineState>,
    /// Shader source text for every registered kernel name.
    sources: HashMap<String, &'static str>,
}
44
45impl KernelRegistry {
46 pub fn new() -> Self {
50 let mut sources = HashMap::new();
51
52 sources.insert(
54 "placeholder".into(),
55 include_str!("shaders/placeholder.metal"),
56 );
57 sources.insert(
58 "quantized_matmul".into(),
59 include_str!("shaders/quantized_matmul.metal"),
60 );
61 sources.insert(
62 "quantized_matmul_simd".into(),
63 include_str!("shaders/quantized_matmul.metal"),
64 );
65 sources.insert(
66 "quantized_matmul_simd_bf16".into(),
67 include_str!("shaders/quantized_matmul.metal"),
68 );
69 sources.insert(
70 "quantized_matmul_simd_bf16_expert".into(),
71 include_str!("shaders/quantized_matmul.metal"),
72 );
73
74 let ggml_src: &'static str =
76 include_str!("shaders/quantized_matmul_ggml.metal");
77 sources.insert("kernel_mul_mv_q4_0_f32".into(), ggml_src);
78 sources.insert("kernel_mul_mv_q8_0_f32".into(), ggml_src);
79 sources.insert("kernel_mul_mv_q6_K_f32".into(), ggml_src);
80 sources.insert("kernel_mul_mv_q5_1_f32".into(), ggml_src);
82 sources.insert("kernel_mul_mv_iq4_nl_f32".into(), ggml_src);
83 sources.insert("kernel_mul_mv_q4_K_f32".into(), ggml_src);
86 sources.insert("kernel_mul_mv_q5_K_f32".into(), ggml_src);
88
89 let ggml_mm_src: &'static str =
95 include_str!("shaders/quantized_matmul_mm.metal");
96 sources.insert("kernel_mul_mm_q4_0_f32".into(), ggml_mm_src);
97 sources.insert("kernel_mul_mm_q8_0_f32".into(), ggml_mm_src);
98 sources.insert("kernel_mul_mm_q6_K_f32".into(), ggml_mm_src);
99 sources.insert("kernel_mul_mm_q5_1_f32".into(), ggml_mm_src);
101 sources.insert("kernel_mul_mm_iq4_nl_f32".into(), ggml_mm_src);
102 sources.insert("kernel_mul_mm_q5_K_f32".into(), ggml_mm_src);
104 sources.insert("kernel_mul_mm_q4_K_f32".into(), ggml_mm_src);
106
107 let ggml_mm_tensor_src: &'static str =
118 include_str!("shaders/quantized_matmul_mm_tensor.metal");
119 sources.insert("kernel_mul_mm_q4_0_tensor_f32".into(), ggml_mm_tensor_src);
120 sources.insert("kernel_mul_mm_q4_0_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
121 sources.insert("kernel_mul_mm_q6_K_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
122 sources.insert("kernel_mul_mm_q8_0_tensor_f32".into(), ggml_mm_tensor_src);
123 sources.insert("kernel_mul_mm_q6_K_tensor_f32".into(), ggml_mm_tensor_src);
124 sources.insert("kernel_mul_mm_q5_1_tensor_f32".into(), ggml_mm_tensor_src);
126 sources.insert("kernel_mul_mm_iq4_nl_tensor_f32".into(), ggml_mm_tensor_src);
127 sources.insert("kernel_mul_mm_q5_K_tensor_f32".into(), ggml_mm_tensor_src);
129 sources.insert("kernel_mul_mm_q4_K_tensor_f32".into(), ggml_mm_tensor_src);
131 sources.insert("kernel_mul_mm_q8_0_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
132
133 let mul_mv_ext_src: &'static str = include_str!("shaders/mul_mv_ext.metal");
138 sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_2".into(), mul_mv_ext_src);
139 sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_3".into(), mul_mv_ext_src);
140 sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_4".into(), mul_mv_ext_src);
141 sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_5".into(), mul_mv_ext_src);
142 sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_2".into(), mul_mv_ext_src);
143 sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_3".into(), mul_mv_ext_src);
144 sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_4".into(), mul_mv_ext_src);
145 sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_5".into(), mul_mv_ext_src);
146 for r1 in [2, 3, 4, 5].iter() {
149 for ty in ["q4_0", "q8_0", "q4_K", "q5_K", "q6_K"].iter() {
150 let name = format!("kernel_mul_mv_ext_{ty}_f32_r1_{r1}");
151 sources.insert(name, mul_mv_ext_src);
152 }
153 }
154
155 let dense_mm_bf16_tensor_src: &'static str =
162 include_str!("shaders/dense_mm_bf16_tensor.metal");
163 sources.insert("hf2q_dense_mm_bf16_f32_tensor".into(), dense_mm_bf16_tensor_src);
164
165 let dense_mm_f32_f32_tensor_src: &'static str =
174 include_str!("shaders/dense_mm_f32_f32.metal");
175 sources.insert("hf2q_dense_mm_f32_f32_tensor".into(), dense_mm_f32_f32_tensor_src);
176
177 let dense_mm_f16_tensor_src: &'static str =
189 include_str!("shaders/dense_mm_f16_tensor.metal");
190 sources.insert("hf2q_dense_mm_f16_f32_tensor".into(), dense_mm_f16_tensor_src);
191
192 let dense_gemv_bf16_src: &'static str =
199 include_str!("shaders/dense_gemv_bf16.metal");
200 sources.insert("hf2q_dense_gemv_bf16_f32_4".into(), dense_gemv_bf16_src);
201
202 let scale_mask_softmax_src: &'static str =
208 include_str!("shaders/scale_mask_softmax.metal");
209 sources.insert("scale_mask_softmax_f32".into(), scale_mask_softmax_src);
210
211 sources.insert(
213 "quantized_matmul_id".into(),
214 include_str!("shaders/quantized_matmul_id.metal"),
215 );
216
217 let ggml_id_src: &'static str =
219 include_str!("shaders/quantized_matmul_id_ggml.metal");
220 sources.insert("kernel_mul_mv_id_q4_0_f32".into(), ggml_id_src);
221 sources.insert("kernel_mul_mv_id_q8_0_f32".into(), ggml_id_src);
222 sources.insert("kernel_mul_mv_id_q4_K_f32".into(), ggml_id_src);
225 sources.insert("kernel_mul_mv_id_q5_K_f32".into(), ggml_id_src);
226 sources.insert("kernel_mul_mv_id_q6_K_f32".into(), ggml_id_src);
227 sources.insert("kernel_mul_mv_id_q5_1_f32".into(), ggml_id_src);
229 sources.insert("kernel_mul_mv_id_iq4_nl_f32".into(), ggml_id_src);
230 sources.insert("kernel_mul_mv_id_q4_0_f32_swiglu".into(), ggml_id_src);
234
235 let ggml_id_mm_src: &'static str =
243 include_str!("shaders/quantized_matmul_id_mm.metal");
244 sources.insert("kernel_mul_mm_id_map0_ne20_1".into(), ggml_id_mm_src);
245 sources.insert("kernel_mul_mm_id_map0_ne20_8".into(), ggml_id_mm_src);
246 sources.insert("kernel_mul_mm_id_q4_0_f32".into(), ggml_id_mm_src);
247 sources.insert("kernel_mul_mm_id_q8_0_f32".into(), ggml_id_mm_src);
248 sources.insert("kernel_mul_mm_id_q6_K_f32".into(), ggml_id_mm_src);
249 sources.insert("kernel_mul_mm_id_q4_K_f32".into(), ggml_id_mm_src);
251 sources.insert("kernel_mul_mm_id_q5_1_f32".into(), ggml_id_mm_src);
253 sources.insert("kernel_mul_mm_id_iq4_nl_f32".into(), ggml_id_mm_src);
254 sources.insert("kernel_mul_mm_id_q5_K_f32".into(), ggml_id_mm_src);
256
257 let ggml_id_mm_tensor_src: &'static str =
263 include_str!("shaders/quantized_matmul_id_mm_tensor.metal");
264 sources.insert("kernel_mul_mm_id_q4_0_tensor_f32".into(), ggml_id_mm_tensor_src);
265 sources.insert("kernel_mul_mm_id_q8_0_tensor_f32".into(), ggml_id_mm_tensor_src);
266 sources.insert("kernel_mul_mm_id_q6_K_tensor_f32".into(), ggml_id_mm_tensor_src);
267 sources.insert("kernel_mul_mm_id_q4_K_tensor_f32".into(), ggml_id_mm_tensor_src);
269 sources.insert("kernel_mul_mm_id_q5_1_tensor_f32".into(), ggml_id_mm_tensor_src);
271 sources.insert("kernel_mul_mm_id_iq4_nl_tensor_f32".into(), ggml_id_mm_tensor_src);
272 sources.insert("kernel_mul_mm_id_q5_K_tensor_f32".into(), ggml_id_mm_tensor_src);
274
275 let embedding_src: &'static str = include_str!("shaders/embedding.metal");
277 sources.insert("embedding_gather_4bit".into(), embedding_src);
278 sources.insert("embedding_gather_6bit".into(), embedding_src);
279
280 let moe_gate_src: &'static str = include_str!("shaders/moe_gate.metal");
282 sources.insert("moe_gate".into(), moe_gate_src);
283
284 let moe_dispatch_src: &'static str = include_str!("shaders/moe_dispatch.metal");
286 sources.insert("fused_gelu_mul".into(), moe_dispatch_src);
287 sources.insert("moe_swiglu_fused".into(), moe_dispatch_src);
288 sources.insert("moe_swiglu_batch".into(), moe_dispatch_src);
289 sources.insert("moe_swiglu_seq".into(), moe_dispatch_src);
290 sources.insert("moe_accumulate".into(), moe_dispatch_src);
291 sources.insert("moe_weighted_sum".into(), moe_dispatch_src);
292 sources.insert("moe_weighted_sum_seq".into(), moe_dispatch_src);
293 sources.insert("zero_buffer".into(), moe_dispatch_src);
294 sources.insert("naive_matvec_f32".into(), moe_dispatch_src);
295 sources.insert("moe_gather_topk_weights".into(), moe_dispatch_src);
296 sources.insert("fused_gelu_mul_bf16".into(), moe_dispatch_src);
298 sources.insert("moe_swiglu_seq_bf16".into(), moe_dispatch_src);
299 sources.insert("moe_weighted_sum_seq_bf16_input".into(), moe_dispatch_src);
300 sources.insert(
302 "moe_weighted_sum_seq_backward_outputs_f32".into(),
303 moe_dispatch_src,
304 );
305 sources.insert(
306 "moe_weighted_sum_seq_backward_weights_f32".into(),
307 moe_dispatch_src,
308 );
309 sources.insert(
311 "moe_swiglu_seq_backward_f32".into(),
312 moe_dispatch_src,
313 );
314
315 let kv_cache_src: &'static str = include_str!("shaders/kv_cache_copy.metal");
317 sources.insert("kv_cache_copy_batch_f32".into(), kv_cache_src);
318 sources.insert("kv_cache_copy_batch_f32_to_f16".into(), kv_cache_src);
319 sources.insert("kv_cache_copy_seq_f32".into(), kv_cache_src);
320 sources.insert("kv_cache_copy_seq_f32_to_f16".into(), kv_cache_src);
321 sources.insert("kv_cache_copy_seq_f32_kv_dual".into(), kv_cache_src);
323 sources.insert("kv_cache_copy_seq_f32_to_f16_kv_dual".into(), kv_cache_src);
324 sources.insert("kv_cache_copy_seq_bf16".into(), kv_cache_src);
326
327 let elementwise_src: &'static str = include_str!("shaders/elementwise.metal");
329 sources.insert("elementwise_add_f32".into(), elementwise_src);
330 sources.insert("elementwise_add_f16".into(), elementwise_src);
331 sources.insert("elementwise_mul_f32".into(), elementwise_src);
332 sources.insert("elementwise_mul_f16".into(), elementwise_src);
333 sources.insert("elementwise_add_bf16".into(), elementwise_src);
334 sources.insert("elementwise_mul_bf16".into(), elementwise_src);
335 sources.insert("cast_f16_to_f32".into(), elementwise_src);
336 sources.insert("cast_f32_to_f16".into(), elementwise_src);
337 sources.insert("cast_bf16_to_f32".into(), elementwise_src);
338 sources.insert("cast_f32_to_bf16".into(), elementwise_src);
339 sources.insert("scalar_mul_bf16".into(), elementwise_src);
340 sources.insert("scalar_mul_f32".into(), elementwise_src);
341 sources.insert("embedding_gather_scale_f32".into(), elementwise_src);
342 sources.insert("embedding_gather_scale_batch_f32".into(), elementwise_src);
343 sources.insert("permute_021_bf16".into(), elementwise_src);
344 sources.insert("transpose_last2_bf16".into(), elementwise_src);
345 sources.insert("transpose_last2_f16".into(), elementwise_src);
346 sources.insert("permute_021_f32".into(), elementwise_src);
347 sources.insert("permute_021_bf16_to_f32".into(), elementwise_src);
348 sources.insert("transpose_2d_f32".into(), elementwise_src);
349 sources.insert("transpose_2d_f16".into(), elementwise_src);
350
351 let sdpa_src: &'static str = include_str!("shaders/sdpa.metal");
353 sources.insert("sdpa".into(), sdpa_src);
354 sources.insert("sdpa_bf16".into(), sdpa_src);
355 let sdpa_sliding_src: &'static str = include_str!("shaders/sdpa_sliding.metal");
356 sources.insert("sdpa_sliding".into(), sdpa_sliding_src);
357 sources.insert("sdpa_sliding_bf16".into(), sdpa_sliding_src);
358
359 let flash_attn_prefill_src: &'static str =
364 include_str!("shaders/flash_attn_prefill.metal");
365 sources.insert(
367 "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskfloat32".into(),
368 flash_attn_prefill_src,
369 );
370 sources.insert(
371 "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
372 flash_attn_prefill_src,
373 );
374 sources.insert(
375 "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbfloat16".into(),
376 flash_attn_prefill_src,
377 );
378 sources.insert(
379 "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
380 flash_attn_prefill_src,
381 );
382 sources.insert(
383 "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskfloat16".into(),
384 flash_attn_prefill_src,
385 );
386 sources.insert(
387 "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
388 flash_attn_prefill_src,
389 );
390 sources.insert(
394 "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbfloat16".into(),
395 flash_attn_prefill_src,
396 );
397 sources.insert(
398 "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
399 flash_attn_prefill_src,
400 );
401 sources.insert(
402 "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskfloat16".into(),
403 flash_attn_prefill_src,
404 );
405 sources.insert(
406 "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
407 flash_attn_prefill_src,
408 );
409
410 let flash_attn_vec_src: &'static str =
413 include_str!("shaders/flash_attn_vec.metal");
414 sources.insert("flash_attn_vec_dk256".into(), flash_attn_vec_src);
415 sources.insert("flash_attn_vec_dk512".into(), flash_attn_vec_src);
416 sources.insert("flash_attn_vec_reduce_dk256".into(), flash_attn_vec_src);
417 sources.insert("flash_attn_vec_reduce_dk512".into(), flash_attn_vec_src);
418 sources.insert("flash_attn_vec_f16kv_dk256".into(), flash_attn_vec_src);
420 sources.insert("flash_attn_vec_f16kv_dk512".into(), flash_attn_vec_src);
421
422 let rope_src: &'static str = include_str!("shaders/rope.metal");
424 sources.insert("rope_f32".into(), rope_src);
425 sources.insert("rope_f16".into(), rope_src);
426 sources.insert("rope_bf16".into(), rope_src);
427 sources.insert("rope_neox_bf16".into(), rope_src);
428 sources.insert("rope_neox_f32".into(), rope_src);
429 let rms_norm_src: &'static str = include_str!("shaders/rms_norm.metal");
430 sources.insert("rms_norm_f32".into(), rms_norm_src);
431 sources.insert("rms_norm_f16".into(), rms_norm_src);
432 sources.insert("rms_norm_bf16".into(), rms_norm_src);
433 sources.insert("rms_norm_no_scale_bf16".into(), rms_norm_src);
434 sources.insert("rms_norm_no_scale_f32".into(), rms_norm_src);
435 sources.insert("rms_norm_no_scale_f32_dual".into(), rms_norm_src);
436 sources.insert("rms_norm_f32_triple".into(), rms_norm_src);
437 sources.insert("fused_post_attn_triple_norm_f32".into(), rms_norm_src);
438 sources.insert("rms_norm_no_scale_f32_dual_perm".into(), rms_norm_src);
439 sources.insert("rms_norm_mul_f32".into(), rms_norm_src);
441 sources.insert("rms_norm_mul_f16".into(), rms_norm_src);
442 sources.insert("rms_norm_mul_bf16".into(), rms_norm_src);
443 let l2_norm_src: &'static str = include_str!("shaders/l2_norm.metal");
445 sources.insert("l2_norm_f32".into(), l2_norm_src);
446 sources.insert("l2_norm_f16".into(), l2_norm_src);
447 sources.insert("l2_norm_bf16".into(), l2_norm_src);
448 let cumsum_src: &'static str = include_str!("shaders/cumsum.metal");
450 sources.insert("cumsum_f32".into(), cumsum_src);
451 sources.insert("cumsum_bf16".into(), cumsum_src);
452 let ssm_conv_src: &'static str = include_str!("shaders/ssm_conv.metal");
454 sources.insert("ssm_conv_forward_f32".into(), ssm_conv_src);
455 sources.insert("ssm_conv_forward_bf16".into(), ssm_conv_src);
456 sources.insert("ssm_conv_state_update_f32".into(), ssm_conv_src);
457 sources.insert("ssm_conv_state_update_bf16".into(), ssm_conv_src);
458 let tri_solve_src: &'static str = include_str!("shaders/tri_solve.metal");
460 sources.insert("tri_solve_lower_unit_f32".into(), tri_solve_src);
461 sources.insert("tri_solve_lower_unit_bf16".into(), tri_solve_src);
462 let rope_multi_src: &'static str = include_str!("shaders/rope_multi.metal");
464 sources.insert("rope_multi_f32".into(), rope_multi_src);
465 sources.insert("rope_multi_bf16".into(), rope_multi_src);
466 let gdn_src: &'static str = include_str!("shaders/gated_delta_net.metal");
468 sources.insert("gated_delta_net_f32".into(), gdn_src);
469 let gdn_decode_src: &'static str =
474 include_str!("shaders/gated_delta_net_decode.metal");
475 sources.insert("gated_delta_net_decode_f32_1".into(), gdn_decode_src);
476 sources.insert("gated_delta_net_decode_f32_2".into(), gdn_decode_src);
477 sources.insert("gated_delta_net_decode_f32_4".into(), gdn_decode_src);
478 let gdn_chunk_src: &'static str =
482 include_str!("shaders/gated_delta_net_chunk.metal");
483 sources.insert(
484 "gated_delta_net_chunk_inter_state_bf16".into(),
485 gdn_chunk_src,
486 );
487 let gdn_kkt_src: &'static str =
490 include_str!("shaders/gated_delta_net_kkt.metal");
491 sources.insert("gated_delta_net_kkt_bf16".into(), gdn_kkt_src);
492 let gdn_recompute_wu_src: &'static str =
496 include_str!("shaders/gated_delta_net_recompute_wu.metal");
497 sources.insert(
498 "gated_delta_net_recompute_wu_bf16".into(),
499 gdn_recompute_wu_src,
500 );
501 let gdn_chunk_o_src: &'static str =
504 include_str!("shaders/gated_delta_net_chunk_o.metal");
505 sources.insert("gated_delta_net_chunk_o_bf16".into(), gdn_chunk_o_src);
506 let chunk_local_cumsum_g_src: &'static str =
511 include_str!("shaders/chunk_local_cumsum_g.metal");
512 sources.insert(
513 "chunk_local_cumsum_g_f32".into(),
514 chunk_local_cumsum_g_src,
515 );
516 let chunk_tri_solve_invert_src: &'static str =
517 include_str!("shaders/chunk_gated_delta_rule_tri_solve_invert.metal");
518 sources.insert(
519 "chunk_tri_solve_invert_f32".into(),
520 chunk_tri_solve_invert_src,
521 );
522 let sigmoid_mul_src: &'static str = include_str!("shaders/sigmoid_mul.metal");
524 sources.insert("sigmoid_mul_f32".into(), sigmoid_mul_src);
525 sources.insert("sigmoid_mul_bf16".into(), sigmoid_mul_src);
526 let silu_mul_src: &'static str = include_str!("shaders/silu_mul.metal");
527 sources.insert("silu_mul_f32".into(), silu_mul_src);
528 let compute_g_beta_src: &'static str = include_str!("shaders/compute_g_beta.metal");
529 sources.insert("compute_g_beta_f32".into(), compute_g_beta_src);
530 let ssm_norm_gate_src: &'static str = include_str!("shaders/ssm_norm_gate.metal");
531 sources.insert("ssm_norm_gate_f32".into(), ssm_norm_gate_src);
532 let gelu_src: &'static str = include_str!("shaders/gelu.metal");
533 sources.insert("gelu_f32".into(), gelu_src);
534 sources.insert("gelu_f16".into(), gelu_src);
535 sources.insert("gelu_bf16".into(), gelu_src);
536 let softmax_src: &'static str = include_str!("shaders/softmax.metal");
537 sources.insert("softmax_f32".into(), softmax_src);
538 sources.insert("softmax_f16".into(), softmax_src);
539 sources.insert("softmax_bf16".into(), softmax_src);
540 let softmax_backward_src: &'static str =
541 include_str!("shaders/softmax_backward.metal");
542 sources.insert("softmax_backward_f32".into(), softmax_backward_src);
543 let log_elementwise_src: &'static str =
544 include_str!("shaders/log_elementwise.metal");
545 sources.insert("log_f32".into(), log_elementwise_src);
546 sources.insert("log_backward_f32".into(), log_elementwise_src);
547 let row_sum_src: &'static str = include_str!("shaders/row_sum.metal");
548 sources.insert("row_sum_f32".into(), row_sum_src);
549 sources.insert("row_sum_backward_f32".into(), row_sum_src);
550 let qdq_legacy_src: &'static str = include_str!("shaders/qdq_legacy.metal");
554 sources.insert("qdq_q4_0_f32".into(), qdq_legacy_src);
555 sources.insert("qdq_q8_0_f32".into(), qdq_legacy_src);
556 let rms_norm_backward_src: &'static str =
560 include_str!("shaders/rms_norm_backward.metal");
561 sources.insert(
562 "rms_norm_compute_rms_inv_f32".into(),
563 rms_norm_backward_src,
564 );
565 sources.insert("rms_norm_backward_dx_f32".into(), rms_norm_backward_src);
566 sources.insert("rms_norm_backward_dw_f32".into(), rms_norm_backward_src);
567 let slice_concat_2d_src: &'static str =
572 include_str!("shaders/slice_concat_2d.metal");
573 sources.insert("slice_2d_cols_f32".into(), slice_concat_2d_src);
574 sources.insert("copy_2d_cols_into_f32".into(), slice_concat_2d_src);
575 let silu_backward_src: &'static str =
578 include_str!("shaders/silu_backward.metal");
579 sources.insert("silu_f32".into(), silu_backward_src);
580 sources.insert("silu_backward_f32".into(), silu_backward_src);
581 let embedding_autograd_src: &'static str =
583 include_str!("shaders/embedding_autograd.metal");
584 sources.insert("embedding_lookup_f32".into(), embedding_autograd_src);
585 sources.insert(
586 "embedding_scatter_add_f32".into(),
587 embedding_autograd_src,
588 );
589 let adam_update_src: &'static str =
592 include_str!("shaders/adam_update.metal");
593 sources.insert("adam_update_f32".into(), adam_update_src);
594 let qdq_affine_src: &'static str =
598 include_str!("shaders/qdq_affine.metal");
599 sources.insert("qdq_affine_init_f32".into(), qdq_affine_src);
600 sources.insert("qdq_affine_forward_f32".into(), qdq_affine_src);
601 sources.insert(
602 "qdq_affine_backward_scales_f32".into(),
603 qdq_affine_src,
604 );
605 sources.insert(
606 "qdq_affine_backward_biases_f32".into(),
607 qdq_affine_src,
608 );
609 let qmm_affine_src: &'static str =
613 include_str!("shaders/qmm_affine.metal");
614 sources.insert("qmm_affine_t_f32".into(), qmm_affine_src);
615 let qmm_affine_tiled_src: &'static str =
619 include_str!("shaders/qmm_affine_tiled.metal");
620 sources.insert(
621 "qmm_affine_t_f32_tiled".into(),
622 qmm_affine_tiled_src,
623 );
624 let qmm_affine_simd_src: &'static str =
629 include_str!("shaders/qmm_affine_simd.metal");
630 sources.insert(
631 "qmm_affine_t_f32_simd".into(),
632 qmm_affine_simd_src,
633 );
634 let qmm_affine_simd4_src: &'static str =
639 include_str!("shaders/qmm_affine_simd4.metal");
640 sources.insert(
641 "qmm_affine_t_f32_simd4".into(),
642 qmm_affine_simd4_src,
643 );
644 let qmm_affine_simd4_gs64_src: &'static str =
648 include_str!("shaders/qmm_affine_simd4_gs64.metal");
649 sources.insert(
650 "qmm_affine_t_f32_simd4_gs64".into(),
651 qmm_affine_simd4_gs64_src,
652 );
653 let conv1d_dwc_src: &'static str =
658 include_str!("shaders/conv1d_depthwise_causal.metal");
659 sources.insert(
660 "conv1d_depthwise_causal_forward_f32".into(),
661 conv1d_dwc_src,
662 );
663 sources.insert(
664 "conv1d_depthwise_causal_backward_dx_f32".into(),
665 conv1d_dwc_src,
666 );
667 sources.insert(
668 "conv1d_depthwise_causal_backward_dw_f32".into(),
669 conv1d_dwc_src,
670 );
671 let exp_src: &'static str =
674 include_str!("shaders/exp_elementwise.metal");
675 sources.insert("exp_f32".into(), exp_src);
676 sources.insert("exp_backward_f32".into(), exp_src);
677 let outer_src: &'static str =
681 include_str!("shaders/outer_product.metal");
682 sources.insert("outer_product_f32".into(), outer_src);
683 sources.insert("outer_product_backward_lhs_f32".into(), outer_src);
684 sources.insert("outer_product_backward_rhs_f32".into(), outer_src);
685 let taa_src: &'static str =
688 include_str!("shaders/take_along_axis.metal");
689 sources.insert("take_along_axis_f32".into(), taa_src);
690 sources.insert("take_along_axis_backward_f32".into(), taa_src);
691 let div_src: &'static str =
693 include_str!("shaders/divide_elementwise.metal");
694 sources.insert("divide_f32".into(), div_src);
695 sources.insert("divide_backward_f32".into(), div_src);
696 let sqrt_src: &'static str =
698 include_str!("shaders/sqrt_elementwise.metal");
699 sources.insert("sqrt_f32".into(), sqrt_src);
700 sources.insert("sqrt_backward_f32".into(), sqrt_src);
701 let softcap_src: &'static str = include_str!("shaders/softcap.metal");
702 sources.insert("softcap_f32".into(), softcap_src);
703 sources.insert("softcap_f16".into(), softcap_src);
704 sources.insert("softcap_bf16".into(), softcap_src);
705
706 let fused_norm_add_src: &'static str =
709 include_str!("shaders/fused_norm_add_bf16.metal");
710 sources.insert("fused_norm_add_bf16".into(), fused_norm_add_src);
711 sources.insert("fused_norm_add_no_weight_bf16".into(), fused_norm_add_src);
712
713 let fused_hnr_f32_src: &'static str =
715 include_str!("shaders/fused_head_norm_rope_f32.metal");
716 sources.insert("fused_head_norm_rope_f32".into(), fused_hnr_f32_src);
717
718 let fused_hnr_bf16_src: &'static str =
721 include_str!("shaders/fused_head_norm_rope_bf16.metal");
722 sources.insert("fused_head_norm_rope_bf16".into(), fused_hnr_bf16_src);
723 sources.insert("fused_head_norm_rope_batch_bf16".into(), fused_hnr_bf16_src);
724
725 let fused_norm_add_f32_src: &'static str =
727 include_str!("shaders/fused_norm_add_f32.metal");
728 sources.insert("fused_norm_add_f32".into(), fused_norm_add_f32_src);
729 sources.insert("fused_residual_norm_f32".into(), fused_norm_add_f32_src);
730 sources.insert("fused_residual_norm_scalar_f32".into(), fused_norm_add_f32_src);
731 sources.insert("fused_moe_routing_f32".into(), fused_norm_add_f32_src);
732 sources.insert("fused_moe_routing_batch_f32".into(), fused_norm_add_f32_src);
733 sources.insert("fused_norm_add_scalar_f32".into(), fused_norm_add_f32_src);
734 sources.insert("fused_moe_wsum_norm_add_f32".into(), fused_norm_add_f32_src);
735 sources.insert("fused_moe_wsum_dnorm_add_f32".into(), fused_norm_add_f32_src);
736
737 let argsort_src: &'static str = include_str!("shaders/argsort.metal");
739 sources.insert("argsort_desc_f32".into(), argsort_src);
740
741 let gather_src: &'static str = include_str!("shaders/gather.metal");
743 sources.insert("gather_f32".into(), gather_src);
744
745 let kv_cache_copy_src: &'static str =
747 include_str!("shaders/kv_cache_copy.metal");
748 sources.insert("kv_cache_copy".into(), kv_cache_copy_src);
749 sources.insert("kv_cache_copy_f32".into(), kv_cache_copy_src);
750
751 let copy_src: &'static str = include_str!("shaders/copy.metal");
753 sources.insert("strided_copy_f32".into(), copy_src);
754 sources.insert("offset_copy_f32".into(), copy_src);
755
756 let qkv_split_src: &'static str = include_str!("shaders/qkv_split.metal");
760 sources.insert("qkv_split_f32".into(), qkv_split_src);
761
762 let repeat_tiled_src: &'static str =
766 include_str!("shaders/repeat_tiled.metal");
767 sources.insert("repeat_tiled_f32".into(), repeat_tiled_src);
768
769 let dense_gemm_src: &'static str = include_str!("shaders/dense_gemm.metal");
771 sources.insert("dense_gemm_f16".into(), dense_gemm_src);
772 sources.insert("dense_matvec_f16".into(), dense_gemm_src);
773 sources.insert("dense_matvec_f16w_f32io".into(), dense_gemm_src);
774 sources.insert("dense_matvec_bf16w_f32io".into(), dense_gemm_src);
776 sources.insert("dense_matvec_f32".into(), dense_gemm_src);
778
779 let fwht_src: &'static str = include_str!("shaders/fwht_standalone.metal");
781 sources.insert("fwht_standalone_f32_d256".into(), fwht_src);
782 sources.insert("fwht_standalone_f32_d512".into(), fwht_src);
783 sources.insert("fwht_sign_premult_f32_d256".into(), fwht_src);
785 sources.insert("fwht_sign_premult_f32_d512".into(), fwht_src);
786 sources.insert("fwht_sign_undo_f32_d256".into(), fwht_src);
787 sources.insert("fwht_sign_undo_f32_d512".into(), fwht_src);
788
789 let hq_fast_src: &'static str = include_str!("shaders/hadamard_quantize_kv_fast.metal");
791 sources.insert("hadamard_quantize_kv_fast_d256".into(), hq_fast_src);
792 sources.insert("hadamard_quantize_kv_fast_d512".into(), hq_fast_src);
793 sources.insert("hadamard_quantize_kv_hb_d256".into(), hq_fast_src);
795 sources.insert("hadamard_quantize_kv_hb_d512".into(), hq_fast_src);
796
797 let tq_dq_src: &'static str = include_str!("shaders/tq_dequantize_kv.metal");
799 sources.insert("tq_dequantize_kv".into(), tq_dq_src);
800 sources.insert("tq_dequantize_hb_kv".into(), tq_dq_src);
802
803 let tq_hb_src: &'static str = include_str!("shaders/flash_attn_vec_tq_hb.metal");
805 sources.insert("flash_attn_vec_tq_hb_dk256".into(), tq_hb_src);
806 sources.insert("flash_attn_vec_tq_hb_dk512".into(), tq_hb_src);
807
808 let argmax_src: &'static str = include_str!("shaders/argmax.metal");
810 sources.insert("argmax_f32".into(), argmax_src);
811 let softmax_sample_src: &'static str =
812 include_str!("shaders/softmax_sample.metal");
813 sources.insert("softmax_sample_f32".into(), softmax_sample_src);
814 let top_k_src: &'static str = include_str!("shaders/top_k.metal");
816 sources.insert("top_k_f32".into(), top_k_src);
817
818 let moe_stk_src: &'static str =
821 include_str!("shaders/moe_softmax_topk.metal");
822 sources.insert("moe_softmax_topk_f32".into(), moe_stk_src);
823 let moe_wr_src: &'static str =
824 include_str!("shaders/moe_weighted_reduce.metal");
825 sources.insert("moe_weighted_reduce_f32".into(), moe_wr_src);
826 let sdpa_decode_src: &'static str =
827 include_str!("shaders/sdpa_decode.metal");
828 sources.insert("sdpa_decode".into(), sdpa_decode_src);
829
830 Self {
831 cache: HashMap::new(),
832 sources,
833 }
834 }
835
836 pub fn register_source(&mut self, name: impl Into<String>, source: &'static str) {
839 let name = name.into();
840 self.cache.remove(&name);
842 self.sources.insert(name, source);
843 }
844
845 pub fn get_pipeline(
857 &mut self,
858 name: &str,
859 device: &metal::DeviceRef,
860 ) -> Result<&ComputePipelineState> {
861 if !self.cache.contains_key(name) {
862 let source = self.sources.get(name).ok_or_else(|| {
864 MlxError::KernelNotFound(name.to_string())
865 })?;
866
867 let compile_opts = metal::CompileOptions::new();
868 let library = device
869 .new_library_with_source(source, &compile_opts)
870 .map_err(|msg| MlxError::ShaderCompilationError {
871 name: name.to_string(),
872 message: msg,
873 })?;
874
875 let function = library
876 .get_function(name, None)
877 .map_err(|msg| MlxError::ShaderCompilationError {
878 name: name.to_string(),
879 message: msg,
880 })?;
881
882 let descriptor = ComputePipelineDescriptor::new();
892 descriptor.set_compute_function(Some(&function));
893 descriptor.set_label(name);
894
895 let pipeline = device
896 .new_compute_pipeline_state(&descriptor)
897 .map_err(|msg| MlxError::ShaderCompilationError {
898 name: name.to_string(),
899 message: msg,
900 })?;
901
902 self.cache.insert(name.to_string(), pipeline);
903 }
904
905 self.cache.get(name).ok_or_else(|| {
908 MlxError::KernelNotFound(name.to_string())
909 })
910 }
911
912 pub fn get_pipeline_with_constants(
934 &mut self,
935 name: &str,
936 device: &metal::DeviceRef,
937 bool_constants: &[(usize, bool)],
938 int_constants: &[(usize, i32)],
939 ) -> Result<&ComputePipelineState> {
940 let mut cache_key = name.to_string();
945 for &(index, value) in bool_constants {
946 cache_key.push('|');
947 cache_key.push_str(&index.to_string());
948 cache_key.push_str(if value { ":b1" } else { ":b0" });
949 }
950 for &(index, value) in int_constants {
951 cache_key.push('|');
952 cache_key.push_str(&index.to_string());
953 cache_key.push(':');
954 cache_key.push('i');
955 cache_key.push_str(&value.to_string());
956 }
957
958 if !self.cache.contains_key(&cache_key) {
959 let source = self.sources.get(name).ok_or_else(|| {
961 MlxError::KernelNotFound(name.to_string())
962 })?;
963
964 let compile_opts = metal::CompileOptions::new();
965 let library = device
966 .new_library_with_source(source, &compile_opts)
967 .map_err(|msg| MlxError::ShaderCompilationError {
968 name: name.to_string(),
969 message: msg,
970 })?;
971
972 let fcv = FunctionConstantValues::new();
977
978 for &(index, value) in bool_constants {
979 let v: u8 = if value { 1 } else { 0 };
982 fcv.set_constant_value_at_index(
983 (&v as *const u8).cast::<std::ffi::c_void>(),
984 MTLDataType::Bool,
985 index as u64,
986 );
987 }
988
989 for &(index, value) in int_constants {
990 fcv.set_constant_value_at_index(
994 (&value as *const i32).cast::<std::ffi::c_void>(),
995 MTLDataType::Int,
996 index as u64,
997 );
998 }
999
1000 let function = library
1001 .get_function(name, Some(fcv))
1002 .map_err(|msg| MlxError::ShaderCompilationError {
1003 name: name.to_string(),
1004 message: msg,
1005 })?;
1006
1007 let descriptor = ComputePipelineDescriptor::new();
1014 descriptor.set_compute_function(Some(&function));
1015 descriptor.set_label(&cache_key);
1016
1017 let pipeline = device
1018 .new_compute_pipeline_state(&descriptor)
1019 .map_err(|msg| MlxError::ShaderCompilationError {
1020 name: name.to_string(),
1021 message: msg,
1022 })?;
1023
1024 self.cache.insert(cache_key.clone(), pipeline);
1025 }
1026
1027 self.cache.get(&cache_key).ok_or_else(|| {
1028 MlxError::KernelNotFound(name.to_string())
1029 })
1030 }
1031
1032 pub fn get_pipeline_with_bool_constants(
1050 &mut self,
1051 name: &str,
1052 device: &metal::DeviceRef,
1053 bool_constants: &[(usize, bool)],
1054 ) -> Result<&ComputePipelineState> {
1055 self.get_pipeline_with_constants(name, device, bool_constants, &[])
1056 }
1057
1058 pub fn is_cached(&self, name: &str) -> bool {
1060 self.cache.contains_key(name)
1061 }
1062
1063 pub fn cached_count(&self) -> usize {
1065 self.cache.len()
1066 }
1067
1068 pub fn source_count(&self) -> usize {
1070 self.sources.len()
1071 }
1072}
1073
1074impl Default for KernelRegistry {
1075 fn default() -> Self {
1076 Self::new()
1077 }
1078}
1079
#[cfg(test)]
mod tests {
    use super::*;
    // Brings `as_ptr()` into scope so tests can inspect the underlying Metal
    // object pointer of a pipeline.
    use metal::foreign_types::ForeignTypeRef;

    /// Kernel that writes the integer function constant at index 100 to
    /// `out[0]`; used to exercise integer-constant specialisation.
    const INT_FC_TEST_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;

constant int test_N [[function_constant(100)]];

kernel void int_fc_test_kernel(
    device int* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]])
{
    if (tid == 0) {
        out[0] = test_N;
    }
}
"#;

    /// Address of the Metal pipeline object wrapped by `pipeline`.
    ///
    /// The registry hands out `&ComputePipelineState` references that point
    /// into its cache `HashMap`. Comparing those reference addresses only
    /// compares map slots: two live entries always differ, and a slot may move
    /// when the map grows. The underlying object pointer is independent of
    /// where the wrapper is stored, so it is the right identity to assert on.
    fn metal_object_addr(pipeline: &ComputePipelineState) -> usize {
        let object: &metal::ComputePipelineStateRef = pipeline;
        object.as_ptr() as usize
    }

    /// Integer function constants must produce one cached pipeline per
    /// distinct value, repeated requests must hit the cache, and the
    /// bool-only wrapper must still work with an empty constant list.
    #[test]
    fn test_int_fc_distinct_pipelines_and_bool_compat() {
        let device = metal::Device::system_default()
            .expect("no Metal device — run on Apple Silicon or x86 Mac with Metal support");

        let mut registry = KernelRegistry::new();

        registry.register_source("int_fc_test_kernel", INT_FC_TEST_SHADER);

        // First specialisation: N = 4.
        let p4_addr = metal_object_addr(
            registry
                .get_pipeline_with_constants(
                    "int_fc_test_kernel",
                    &device,
                    &[],
                    &[(100, 4_i32)],
                )
                .expect("pipeline N=4 should compile"),
        );

        let count_after_n4 = registry.cached_count();

        // Second specialisation: N = 8 must be compiled and cached separately.
        let p8_addr = metal_object_addr(
            registry
                .get_pipeline_with_constants(
                    "int_fc_test_kernel",
                    &device,
                    &[],
                    &[(100, 8_i32)],
                )
                .expect("pipeline N=8 should compile"),
        );

        assert_eq!(
            registry.cached_count(),
            count_after_n4 + 1,
            "N=8 must produce a new cache entry"
        );

        assert_ne!(
            p4_addr, p8_addr,
            "N=4 and N=8 specialisations must be separate ComputePipelineState objects"
        );

        // Asking for N = 4 again must return the already-cached pipeline.
        let p4_again_addr = metal_object_addr(
            registry
                .get_pipeline_with_constants(
                    "int_fc_test_kernel",
                    &device,
                    &[],
                    &[(100, 4_i32)],
                )
                .expect("pipeline N=4 cache hit should succeed"),
        );

        assert_eq!(
            registry.cached_count(),
            count_after_n4 + 1,
            "repeated N=4 call must be a cache hit, not a new entry"
        );
        assert_eq!(
            p4_addr, p4_again_addr,
            "repeated N=4 call must return the same pipeline pointer"
        );

        // Kernel without any function constants, for the bool-only wrapper.
        const BARE_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void bare_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
    if (tid == 0) { out[0] = 42; }
}
"#;
        registry.register_source("bare_kernel", BARE_SHADER);

        let count_before_bool = registry.cached_count();
        let _bool_pipeline = registry
            .get_pipeline_with_bool_constants("bare_kernel", &device, &[])
            .expect("bool-constants wrapper with empty slice must succeed");

        assert_eq!(
            registry.cached_count(),
            count_before_bool + 1,
            "bool-constants wrapper must insert one new cache entry"
        );
    }

    /// Pipeline labels must equal the kernel name for plain pipelines and the
    /// full cache key for specialised ones, so each specialisation can be told
    /// apart by its label.
    #[test]
    fn test_pipeline_labels_propagate_for_mst() {
        let device = metal::Device::system_default()
            .expect("no Metal device — run on Apple Silicon or x86 Mac with Metal support");

        let mut registry = KernelRegistry::new();

        registry.register_source("int_fc_test_kernel", INT_FC_TEST_SHADER);

        // Kernel without function constants, for the plain `get_pipeline` path.
        const BARE_SHADER_LABEL_TEST: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void label_smoke_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
    if (tid == 0) { out[0] = 7; }
}
"#;
        registry.register_source("label_smoke_kernel", BARE_SHADER_LABEL_TEST);

        let plain_label = registry
            .get_pipeline("label_smoke_kernel", &device)
            .expect("plain pipeline must compile")
            .label()
            .to_string();
        assert_eq!(
            plain_label, "label_smoke_kernel",
            "get_pipeline must label the pipeline with the kernel name (xctrace MST attribution)"
        );

        let label_v7 = registry
            .get_pipeline_with_constants(
                "int_fc_test_kernel",
                &device,
                &[],
                &[(100, 7_i32)],
            )
            .expect("specialised pipeline must compile")
            .label()
            .to_string();
        assert_eq!(
            label_v7, "int_fc_test_kernel|100:i7",
            "get_pipeline_with_constants must label with the cache_key so each \
             specialisation is distinct in xctrace MST"
        );

        let label_v13 = registry
            .get_pipeline_with_constants(
                "int_fc_test_kernel",
                &device,
                &[],
                &[(100, 13_i32)],
            )
            .expect("second specialised pipeline must compile")
            .label()
            .to_string();
        assert_eq!(label_v13, "int_fc_test_kernel|100:i13");
        assert_ne!(
            label_v7, label_v13,
            "distinct constant values must yield distinct pipeline labels"
        );
    }
}