1use std::collections::HashMap;
9
10use metal::{ComputePipelineDescriptor, ComputePipelineState, FunctionConstantValues, MTLDataType};
11
12use crate::error::{MlxError, Result};
13
14pub struct KernelRegistry {
37 cache: HashMap<String, ComputePipelineState>,
39 sources: HashMap<String, &'static str>,
43}
44
45impl KernelRegistry {
46 pub fn new() -> Self {
50 let mut sources = HashMap::new();
51
52 sources.insert(
54 "placeholder".into(),
55 include_str!("shaders/placeholder.metal"),
56 );
57 sources.insert(
58 "quantized_matmul".into(),
59 include_str!("shaders/quantized_matmul.metal"),
60 );
61 sources.insert(
62 "quantized_matmul_simd".into(),
63 include_str!("shaders/quantized_matmul.metal"),
64 );
65 sources.insert(
66 "quantized_matmul_simd_bf16".into(),
67 include_str!("shaders/quantized_matmul.metal"),
68 );
69 sources.insert(
70 "quantized_matmul_simd_bf16_expert".into(),
71 include_str!("shaders/quantized_matmul.metal"),
72 );
73
74 let ggml_src: &'static str =
76 include_str!("shaders/quantized_matmul_ggml.metal");
77 sources.insert("kernel_mul_mv_q4_0_f32".into(), ggml_src);
78 sources.insert("kernel_mul_mv_q8_0_f32".into(), ggml_src);
79 sources.insert("kernel_mul_mv_q6_K_f32".into(), ggml_src);
80
81 let ggml_mm_src: &'static str =
87 include_str!("shaders/quantized_matmul_mm.metal");
88 sources.insert("kernel_mul_mm_q4_0_f32".into(), ggml_mm_src);
89 sources.insert("kernel_mul_mm_q8_0_f32".into(), ggml_mm_src);
90 sources.insert("kernel_mul_mm_q6_K_f32".into(), ggml_mm_src);
91
92 let ggml_mm_tensor_src: &'static str =
103 include_str!("shaders/quantized_matmul_mm_tensor.metal");
104 sources.insert("kernel_mul_mm_q4_0_tensor_f32".into(), ggml_mm_tensor_src);
105 sources.insert("kernel_mul_mm_q4_0_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
106 sources.insert("kernel_mul_mm_q6_K_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
107 sources.insert("kernel_mul_mm_q8_0_tensor_f32".into(), ggml_mm_tensor_src);
108 sources.insert("kernel_mul_mm_q6_K_tensor_f32".into(), ggml_mm_tensor_src);
109
110 let dense_mm_bf16_tensor_src: &'static str =
117 include_str!("shaders/dense_mm_bf16_tensor.metal");
118 sources.insert("hf2q_dense_mm_bf16_f32_tensor".into(), dense_mm_bf16_tensor_src);
119
120 let dense_mm_f32_f32_tensor_src: &'static str =
129 include_str!("shaders/dense_mm_f32_f32.metal");
130 sources.insert("hf2q_dense_mm_f32_f32_tensor".into(), dense_mm_f32_f32_tensor_src);
131
132 let dense_mm_f16_tensor_src: &'static str =
144 include_str!("shaders/dense_mm_f16_tensor.metal");
145 sources.insert("hf2q_dense_mm_f16_f32_tensor".into(), dense_mm_f16_tensor_src);
146
147 let dense_gemv_bf16_src: &'static str =
154 include_str!("shaders/dense_gemv_bf16.metal");
155 sources.insert("hf2q_dense_gemv_bf16_f32_4".into(), dense_gemv_bf16_src);
156
157 let scale_mask_softmax_src: &'static str =
163 include_str!("shaders/scale_mask_softmax.metal");
164 sources.insert("scale_mask_softmax_f32".into(), scale_mask_softmax_src);
165
166 sources.insert(
168 "quantized_matmul_id".into(),
169 include_str!("shaders/quantized_matmul_id.metal"),
170 );
171
172 let ggml_id_src: &'static str =
174 include_str!("shaders/quantized_matmul_id_ggml.metal");
175 sources.insert("kernel_mul_mv_id_q4_0_f32".into(), ggml_id_src);
176 sources.insert("kernel_mul_mv_id_q8_0_f32".into(), ggml_id_src);
177 sources.insert("kernel_mul_mv_id_q5_K_f32".into(), ggml_id_src);
178 sources.insert("kernel_mul_mv_id_q6_K_f32".into(), ggml_id_src);
179 sources.insert("kernel_mul_mv_id_q4_0_f32_swiglu".into(), ggml_id_src);
183
184 let ggml_id_mm_src: &'static str =
192 include_str!("shaders/quantized_matmul_id_mm.metal");
193 sources.insert("kernel_mul_mm_id_map0_ne20_1".into(), ggml_id_mm_src);
194 sources.insert("kernel_mul_mm_id_map0_ne20_8".into(), ggml_id_mm_src);
195 sources.insert("kernel_mul_mm_id_q4_0_f32".into(), ggml_id_mm_src);
196 sources.insert("kernel_mul_mm_id_q8_0_f32".into(), ggml_id_mm_src);
197 sources.insert("kernel_mul_mm_id_q6_K_f32".into(), ggml_id_mm_src);
198
199 let ggml_id_mm_tensor_src: &'static str =
205 include_str!("shaders/quantized_matmul_id_mm_tensor.metal");
206 sources.insert("kernel_mul_mm_id_q4_0_tensor_f32".into(), ggml_id_mm_tensor_src);
207 sources.insert("kernel_mul_mm_id_q8_0_tensor_f32".into(), ggml_id_mm_tensor_src);
208 sources.insert("kernel_mul_mm_id_q6_K_tensor_f32".into(), ggml_id_mm_tensor_src);
209
210 let embedding_src: &'static str = include_str!("shaders/embedding.metal");
212 sources.insert("embedding_gather_4bit".into(), embedding_src);
213 sources.insert("embedding_gather_6bit".into(), embedding_src);
214
215 let moe_gate_src: &'static str = include_str!("shaders/moe_gate.metal");
217 sources.insert("moe_gate".into(), moe_gate_src);
218
219 let moe_dispatch_src: &'static str = include_str!("shaders/moe_dispatch.metal");
221 sources.insert("fused_gelu_mul".into(), moe_dispatch_src);
222 sources.insert("moe_swiglu_fused".into(), moe_dispatch_src);
223 sources.insert("moe_swiglu_batch".into(), moe_dispatch_src);
224 sources.insert("moe_swiglu_seq".into(), moe_dispatch_src);
225 sources.insert("moe_accumulate".into(), moe_dispatch_src);
226 sources.insert("moe_weighted_sum".into(), moe_dispatch_src);
227 sources.insert("moe_weighted_sum_seq".into(), moe_dispatch_src);
228 sources.insert("zero_buffer".into(), moe_dispatch_src);
229 sources.insert("naive_matvec_f32".into(), moe_dispatch_src);
230 sources.insert("moe_gather_topk_weights".into(), moe_dispatch_src);
231 sources.insert("fused_gelu_mul_bf16".into(), moe_dispatch_src);
233 sources.insert("moe_swiglu_seq_bf16".into(), moe_dispatch_src);
234 sources.insert("moe_weighted_sum_seq_bf16_input".into(), moe_dispatch_src);
235
236 let kv_cache_src: &'static str = include_str!("shaders/kv_cache_copy.metal");
238 sources.insert("kv_cache_copy_batch_f32".into(), kv_cache_src);
239 sources.insert("kv_cache_copy_batch_f32_to_f16".into(), kv_cache_src);
240 sources.insert("kv_cache_copy_seq_f32".into(), kv_cache_src);
241 sources.insert("kv_cache_copy_seq_f32_to_f16".into(), kv_cache_src);
242 sources.insert("kv_cache_copy_seq_f32_kv_dual".into(), kv_cache_src);
244 sources.insert("kv_cache_copy_seq_f32_to_f16_kv_dual".into(), kv_cache_src);
245 sources.insert("kv_cache_copy_seq_bf16".into(), kv_cache_src);
247
248 let elementwise_src: &'static str = include_str!("shaders/elementwise.metal");
250 sources.insert("elementwise_add_f32".into(), elementwise_src);
251 sources.insert("elementwise_add_f16".into(), elementwise_src);
252 sources.insert("elementwise_mul_f32".into(), elementwise_src);
253 sources.insert("elementwise_mul_f16".into(), elementwise_src);
254 sources.insert("elementwise_add_bf16".into(), elementwise_src);
255 sources.insert("elementwise_mul_bf16".into(), elementwise_src);
256 sources.insert("cast_f16_to_f32".into(), elementwise_src);
257 sources.insert("cast_f32_to_f16".into(), elementwise_src);
258 sources.insert("cast_bf16_to_f32".into(), elementwise_src);
259 sources.insert("cast_f32_to_bf16".into(), elementwise_src);
260 sources.insert("scalar_mul_bf16".into(), elementwise_src);
261 sources.insert("scalar_mul_f32".into(), elementwise_src);
262 sources.insert("embedding_gather_scale_f32".into(), elementwise_src);
263 sources.insert("embedding_gather_scale_batch_f32".into(), elementwise_src);
264 sources.insert("permute_021_bf16".into(), elementwise_src);
265 sources.insert("transpose_last2_bf16".into(), elementwise_src);
266 sources.insert("transpose_last2_f16".into(), elementwise_src);
267 sources.insert("permute_021_f32".into(), elementwise_src);
268 sources.insert("permute_021_bf16_to_f32".into(), elementwise_src);
269 sources.insert("transpose_2d_f32".into(), elementwise_src);
270 sources.insert("transpose_2d_f16".into(), elementwise_src);
271
272 let sdpa_src: &'static str = include_str!("shaders/sdpa.metal");
274 sources.insert("sdpa".into(), sdpa_src);
275 sources.insert("sdpa_bf16".into(), sdpa_src);
276 let sdpa_sliding_src: &'static str = include_str!("shaders/sdpa_sliding.metal");
277 sources.insert("sdpa_sliding".into(), sdpa_sliding_src);
278 sources.insert("sdpa_sliding_bf16".into(), sdpa_sliding_src);
279
280 let flash_attn_prefill_src: &'static str =
285 include_str!("shaders/flash_attn_prefill.metal");
286 sources.insert(
288 "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskfloat32".into(),
289 flash_attn_prefill_src,
290 );
291 sources.insert(
292 "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
293 flash_attn_prefill_src,
294 );
295 sources.insert(
296 "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbfloat16".into(),
297 flash_attn_prefill_src,
298 );
299 sources.insert(
300 "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
301 flash_attn_prefill_src,
302 );
303 sources.insert(
304 "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskfloat16".into(),
305 flash_attn_prefill_src,
306 );
307 sources.insert(
308 "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
309 flash_attn_prefill_src,
310 );
311 sources.insert(
315 "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbfloat16".into(),
316 flash_attn_prefill_src,
317 );
318 sources.insert(
319 "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
320 flash_attn_prefill_src,
321 );
322 sources.insert(
323 "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskfloat16".into(),
324 flash_attn_prefill_src,
325 );
326 sources.insert(
327 "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
328 flash_attn_prefill_src,
329 );
330
331 let flash_attn_vec_src: &'static str =
334 include_str!("shaders/flash_attn_vec.metal");
335 sources.insert("flash_attn_vec_dk256".into(), flash_attn_vec_src);
336 sources.insert("flash_attn_vec_dk512".into(), flash_attn_vec_src);
337 sources.insert("flash_attn_vec_reduce_dk256".into(), flash_attn_vec_src);
338 sources.insert("flash_attn_vec_reduce_dk512".into(), flash_attn_vec_src);
339 sources.insert("flash_attn_vec_f16kv_dk256".into(), flash_attn_vec_src);
341 sources.insert("flash_attn_vec_f16kv_dk512".into(), flash_attn_vec_src);
342
343 let rope_src: &'static str = include_str!("shaders/rope.metal");
345 sources.insert("rope_f32".into(), rope_src);
346 sources.insert("rope_f16".into(), rope_src);
347 sources.insert("rope_bf16".into(), rope_src);
348 sources.insert("rope_neox_bf16".into(), rope_src);
349 sources.insert("rope_neox_f32".into(), rope_src);
350 let rms_norm_src: &'static str = include_str!("shaders/rms_norm.metal");
351 sources.insert("rms_norm_f32".into(), rms_norm_src);
352 sources.insert("rms_norm_f16".into(), rms_norm_src);
353 sources.insert("rms_norm_bf16".into(), rms_norm_src);
354 sources.insert("rms_norm_no_scale_bf16".into(), rms_norm_src);
355 sources.insert("rms_norm_no_scale_f32".into(), rms_norm_src);
356 sources.insert("rms_norm_no_scale_f32_dual".into(), rms_norm_src);
357 sources.insert("rms_norm_f32_triple".into(), rms_norm_src);
358 sources.insert("fused_post_attn_triple_norm_f32".into(), rms_norm_src);
359 sources.insert("rms_norm_no_scale_f32_dual_perm".into(), rms_norm_src);
360 sources.insert("rms_norm_mul_f32".into(), rms_norm_src);
362 sources.insert("rms_norm_mul_f16".into(), rms_norm_src);
363 sources.insert("rms_norm_mul_bf16".into(), rms_norm_src);
364 let l2_norm_src: &'static str = include_str!("shaders/l2_norm.metal");
366 sources.insert("l2_norm_f32".into(), l2_norm_src);
367 sources.insert("l2_norm_f16".into(), l2_norm_src);
368 sources.insert("l2_norm_bf16".into(), l2_norm_src);
369 let cumsum_src: &'static str = include_str!("shaders/cumsum.metal");
371 sources.insert("cumsum_f32".into(), cumsum_src);
372 sources.insert("cumsum_bf16".into(), cumsum_src);
373 let ssm_conv_src: &'static str = include_str!("shaders/ssm_conv.metal");
375 sources.insert("ssm_conv_forward_f32".into(), ssm_conv_src);
376 sources.insert("ssm_conv_forward_bf16".into(), ssm_conv_src);
377 sources.insert("ssm_conv_state_update_f32".into(), ssm_conv_src);
378 sources.insert("ssm_conv_state_update_bf16".into(), ssm_conv_src);
379 let tri_solve_src: &'static str = include_str!("shaders/tri_solve.metal");
381 sources.insert("tri_solve_lower_unit_f32".into(), tri_solve_src);
382 sources.insert("tri_solve_lower_unit_bf16".into(), tri_solve_src);
383 let rope_multi_src: &'static str = include_str!("shaders/rope_multi.metal");
385 sources.insert("rope_multi_f32".into(), rope_multi_src);
386 sources.insert("rope_multi_bf16".into(), rope_multi_src);
387 let gdn_src: &'static str = include_str!("shaders/gated_delta_net.metal");
389 sources.insert("gated_delta_net_f32".into(), gdn_src);
390 let gdn_chunk_src: &'static str =
394 include_str!("shaders/gated_delta_net_chunk.metal");
395 sources.insert(
396 "gated_delta_net_chunk_inter_state_bf16".into(),
397 gdn_chunk_src,
398 );
399 let gdn_kkt_src: &'static str =
402 include_str!("shaders/gated_delta_net_kkt.metal");
403 sources.insert("gated_delta_net_kkt_bf16".into(), gdn_kkt_src);
404 let gdn_recompute_wu_src: &'static str =
408 include_str!("shaders/gated_delta_net_recompute_wu.metal");
409 sources.insert(
410 "gated_delta_net_recompute_wu_bf16".into(),
411 gdn_recompute_wu_src,
412 );
413 let gdn_chunk_o_src: &'static str =
416 include_str!("shaders/gated_delta_net_chunk_o.metal");
417 sources.insert("gated_delta_net_chunk_o_bf16".into(), gdn_chunk_o_src);
418 let chunk_local_cumsum_g_src: &'static str =
423 include_str!("shaders/chunk_local_cumsum_g.metal");
424 sources.insert(
425 "chunk_local_cumsum_g_f32".into(),
426 chunk_local_cumsum_g_src,
427 );
428 let chunk_tri_solve_invert_src: &'static str =
429 include_str!("shaders/chunk_gated_delta_rule_tri_solve_invert.metal");
430 sources.insert(
431 "chunk_tri_solve_invert_f32".into(),
432 chunk_tri_solve_invert_src,
433 );
434 let sigmoid_mul_src: &'static str = include_str!("shaders/sigmoid_mul.metal");
436 sources.insert("sigmoid_mul_f32".into(), sigmoid_mul_src);
437 sources.insert("sigmoid_mul_bf16".into(), sigmoid_mul_src);
438 let silu_mul_src: &'static str = include_str!("shaders/silu_mul.metal");
439 sources.insert("silu_mul_f32".into(), silu_mul_src);
440 let compute_g_beta_src: &'static str = include_str!("shaders/compute_g_beta.metal");
441 sources.insert("compute_g_beta_f32".into(), compute_g_beta_src);
442 let ssm_norm_gate_src: &'static str = include_str!("shaders/ssm_norm_gate.metal");
443 sources.insert("ssm_norm_gate_f32".into(), ssm_norm_gate_src);
444 let gelu_src: &'static str = include_str!("shaders/gelu.metal");
445 sources.insert("gelu_f32".into(), gelu_src);
446 sources.insert("gelu_f16".into(), gelu_src);
447 sources.insert("gelu_bf16".into(), gelu_src);
448 let softmax_src: &'static str = include_str!("shaders/softmax.metal");
449 sources.insert("softmax_f32".into(), softmax_src);
450 sources.insert("softmax_f16".into(), softmax_src);
451 sources.insert("softmax_bf16".into(), softmax_src);
452 let softcap_src: &'static str = include_str!("shaders/softcap.metal");
453 sources.insert("softcap_f32".into(), softcap_src);
454 sources.insert("softcap_f16".into(), softcap_src);
455 sources.insert("softcap_bf16".into(), softcap_src);
456
457 let fused_norm_add_src: &'static str =
460 include_str!("shaders/fused_norm_add_bf16.metal");
461 sources.insert("fused_norm_add_bf16".into(), fused_norm_add_src);
462 sources.insert("fused_norm_add_no_weight_bf16".into(), fused_norm_add_src);
463
464 let fused_hnr_f32_src: &'static str =
466 include_str!("shaders/fused_head_norm_rope_f32.metal");
467 sources.insert("fused_head_norm_rope_f32".into(), fused_hnr_f32_src);
468
469 let fused_hnr_bf16_src: &'static str =
472 include_str!("shaders/fused_head_norm_rope_bf16.metal");
473 sources.insert("fused_head_norm_rope_bf16".into(), fused_hnr_bf16_src);
474 sources.insert("fused_head_norm_rope_batch_bf16".into(), fused_hnr_bf16_src);
475
476 let fused_norm_add_f32_src: &'static str =
478 include_str!("shaders/fused_norm_add_f32.metal");
479 sources.insert("fused_norm_add_f32".into(), fused_norm_add_f32_src);
480 sources.insert("fused_residual_norm_f32".into(), fused_norm_add_f32_src);
481 sources.insert("fused_residual_norm_scalar_f32".into(), fused_norm_add_f32_src);
482 sources.insert("fused_moe_routing_f32".into(), fused_norm_add_f32_src);
483 sources.insert("fused_moe_routing_batch_f32".into(), fused_norm_add_f32_src);
484 sources.insert("fused_norm_add_scalar_f32".into(), fused_norm_add_f32_src);
485 sources.insert("fused_moe_wsum_norm_add_f32".into(), fused_norm_add_f32_src);
486 sources.insert("fused_moe_wsum_dnorm_add_f32".into(), fused_norm_add_f32_src);
487
488 let argsort_src: &'static str = include_str!("shaders/argsort.metal");
490 sources.insert("argsort_desc_f32".into(), argsort_src);
491
492 let gather_src: &'static str = include_str!("shaders/gather.metal");
494 sources.insert("gather_f32".into(), gather_src);
495
496 let kv_cache_copy_src: &'static str =
498 include_str!("shaders/kv_cache_copy.metal");
499 sources.insert("kv_cache_copy".into(), kv_cache_copy_src);
500 sources.insert("kv_cache_copy_f32".into(), kv_cache_copy_src);
501
502 let copy_src: &'static str = include_str!("shaders/copy.metal");
504 sources.insert("strided_copy_f32".into(), copy_src);
505 sources.insert("offset_copy_f32".into(), copy_src);
506
507 let qkv_split_src: &'static str = include_str!("shaders/qkv_split.metal");
511 sources.insert("qkv_split_f32".into(), qkv_split_src);
512
513 let repeat_tiled_src: &'static str =
517 include_str!("shaders/repeat_tiled.metal");
518 sources.insert("repeat_tiled_f32".into(), repeat_tiled_src);
519
520 let dense_gemm_src: &'static str = include_str!("shaders/dense_gemm.metal");
522 sources.insert("dense_gemm_f16".into(), dense_gemm_src);
523 sources.insert("dense_matvec_f16".into(), dense_gemm_src);
524 sources.insert("dense_matvec_f16w_f32io".into(), dense_gemm_src);
525 sources.insert("dense_matvec_bf16w_f32io".into(), dense_gemm_src);
527 sources.insert("dense_matvec_f32".into(), dense_gemm_src);
529
530 let fwht_src: &'static str = include_str!("shaders/fwht_standalone.metal");
532 sources.insert("fwht_standalone_f32_d256".into(), fwht_src);
533 sources.insert("fwht_standalone_f32_d512".into(), fwht_src);
534 sources.insert("fwht_sign_premult_f32_d256".into(), fwht_src);
536 sources.insert("fwht_sign_premult_f32_d512".into(), fwht_src);
537 sources.insert("fwht_sign_undo_f32_d256".into(), fwht_src);
538 sources.insert("fwht_sign_undo_f32_d512".into(), fwht_src);
539
540 let hq_fast_src: &'static str = include_str!("shaders/hadamard_quantize_kv_fast.metal");
542 sources.insert("hadamard_quantize_kv_fast_d256".into(), hq_fast_src);
543 sources.insert("hadamard_quantize_kv_fast_d512".into(), hq_fast_src);
544 sources.insert("hadamard_quantize_kv_hb_d256".into(), hq_fast_src);
546 sources.insert("hadamard_quantize_kv_hb_d512".into(), hq_fast_src);
547
548 let tq_dq_src: &'static str = include_str!("shaders/tq_dequantize_kv.metal");
550 sources.insert("tq_dequantize_kv".into(), tq_dq_src);
551 sources.insert("tq_dequantize_hb_kv".into(), tq_dq_src);
553
554 let tq_hb_src: &'static str = include_str!("shaders/flash_attn_vec_tq_hb.metal");
556 sources.insert("flash_attn_vec_tq_hb_dk256".into(), tq_hb_src);
557 sources.insert("flash_attn_vec_tq_hb_dk512".into(), tq_hb_src);
558
559 let argmax_src: &'static str = include_str!("shaders/argmax.metal");
561 sources.insert("argmax_f32".into(), argmax_src);
562 let softmax_sample_src: &'static str =
563 include_str!("shaders/softmax_sample.metal");
564 sources.insert("softmax_sample_f32".into(), softmax_sample_src);
565 let top_k_src: &'static str = include_str!("shaders/top_k.metal");
567 sources.insert("top_k_f32".into(), top_k_src);
568
569 let moe_stk_src: &'static str =
572 include_str!("shaders/moe_softmax_topk.metal");
573 sources.insert("moe_softmax_topk_f32".into(), moe_stk_src);
574 let moe_wr_src: &'static str =
575 include_str!("shaders/moe_weighted_reduce.metal");
576 sources.insert("moe_weighted_reduce_f32".into(), moe_wr_src);
577 let sdpa_decode_src: &'static str =
578 include_str!("shaders/sdpa_decode.metal");
579 sources.insert("sdpa_decode".into(), sdpa_decode_src);
580
581 Self {
582 cache: HashMap::new(),
583 sources,
584 }
585 }
586
587 pub fn register_source(&mut self, name: impl Into<String>, source: &'static str) {
590 let name = name.into();
591 self.cache.remove(&name);
593 self.sources.insert(name, source);
594 }
595
596 pub fn get_pipeline(
608 &mut self,
609 name: &str,
610 device: &metal::DeviceRef,
611 ) -> Result<&ComputePipelineState> {
612 if !self.cache.contains_key(name) {
613 let source = self.sources.get(name).ok_or_else(|| {
615 MlxError::KernelNotFound(name.to_string())
616 })?;
617
618 let compile_opts = metal::CompileOptions::new();
619 let library = device
620 .new_library_with_source(source, &compile_opts)
621 .map_err(|msg| MlxError::ShaderCompilationError {
622 name: name.to_string(),
623 message: msg,
624 })?;
625
626 let function = library
627 .get_function(name, None)
628 .map_err(|msg| MlxError::ShaderCompilationError {
629 name: name.to_string(),
630 message: msg,
631 })?;
632
633 let descriptor = ComputePipelineDescriptor::new();
643 descriptor.set_compute_function(Some(&function));
644 descriptor.set_label(name);
645
646 let pipeline = device
647 .new_compute_pipeline_state(&descriptor)
648 .map_err(|msg| MlxError::ShaderCompilationError {
649 name: name.to_string(),
650 message: msg,
651 })?;
652
653 self.cache.insert(name.to_string(), pipeline);
654 }
655
656 self.cache.get(name).ok_or_else(|| {
659 MlxError::KernelNotFound(name.to_string())
660 })
661 }
662
663 pub fn get_pipeline_with_constants(
685 &mut self,
686 name: &str,
687 device: &metal::DeviceRef,
688 bool_constants: &[(usize, bool)],
689 int_constants: &[(usize, i32)],
690 ) -> Result<&ComputePipelineState> {
691 let mut cache_key = name.to_string();
696 for &(index, value) in bool_constants {
697 cache_key.push('|');
698 cache_key.push_str(&index.to_string());
699 cache_key.push_str(if value { ":b1" } else { ":b0" });
700 }
701 for &(index, value) in int_constants {
702 cache_key.push('|');
703 cache_key.push_str(&index.to_string());
704 cache_key.push(':');
705 cache_key.push('i');
706 cache_key.push_str(&value.to_string());
707 }
708
709 if !self.cache.contains_key(&cache_key) {
710 let source = self.sources.get(name).ok_or_else(|| {
712 MlxError::KernelNotFound(name.to_string())
713 })?;
714
715 let compile_opts = metal::CompileOptions::new();
716 let library = device
717 .new_library_with_source(source, &compile_opts)
718 .map_err(|msg| MlxError::ShaderCompilationError {
719 name: name.to_string(),
720 message: msg,
721 })?;
722
723 let fcv = FunctionConstantValues::new();
728
729 for &(index, value) in bool_constants {
730 let v: u8 = if value { 1 } else { 0 };
733 fcv.set_constant_value_at_index(
734 (&v as *const u8).cast::<std::ffi::c_void>(),
735 MTLDataType::Bool,
736 index as u64,
737 );
738 }
739
740 for &(index, value) in int_constants {
741 fcv.set_constant_value_at_index(
745 (&value as *const i32).cast::<std::ffi::c_void>(),
746 MTLDataType::Int,
747 index as u64,
748 );
749 }
750
751 let function = library
752 .get_function(name, Some(fcv))
753 .map_err(|msg| MlxError::ShaderCompilationError {
754 name: name.to_string(),
755 message: msg,
756 })?;
757
758 let descriptor = ComputePipelineDescriptor::new();
765 descriptor.set_compute_function(Some(&function));
766 descriptor.set_label(&cache_key);
767
768 let pipeline = device
769 .new_compute_pipeline_state(&descriptor)
770 .map_err(|msg| MlxError::ShaderCompilationError {
771 name: name.to_string(),
772 message: msg,
773 })?;
774
775 self.cache.insert(cache_key.clone(), pipeline);
776 }
777
778 self.cache.get(&cache_key).ok_or_else(|| {
779 MlxError::KernelNotFound(name.to_string())
780 })
781 }
782
783 pub fn get_pipeline_with_bool_constants(
801 &mut self,
802 name: &str,
803 device: &metal::DeviceRef,
804 bool_constants: &[(usize, bool)],
805 ) -> Result<&ComputePipelineState> {
806 self.get_pipeline_with_constants(name, device, bool_constants, &[])
807 }
808
809 pub fn is_cached(&self, name: &str) -> bool {
811 self.cache.contains_key(name)
812 }
813
814 pub fn cached_count(&self) -> usize {
816 self.cache.len()
817 }
818
819 pub fn source_count(&self) -> usize {
821 self.sources.len()
822 }
823}
824
825impl Default for KernelRegistry {
826 fn default() -> Self {
827 Self::new()
828 }
829}
830
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal kernel that writes the int function constant at index 100 to
    /// its output buffer.
    const INT_FC_TEST_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;

constant int test_N [[function_constant(100)]];

kernel void int_fc_test_kernel(
    device int* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]])
{
    if (tid == 0) {
        out[0] = test_N;
    }
}
"#;

    /// Panic message used when the host exposes no Metal device.
    const NO_DEVICE_MSG: &str =
        "no Metal device — run on Apple Silicon or x86 Mac with Metal support";

    /// Looks up `int_fc_test_kernel` with int constant 100 set to `n` and
    /// returns the address of the cached pipeline entry.
    fn int_fc_pipeline_ptr(
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
        n: i32,
        failure_msg: &str,
    ) -> *const ComputePipelineState {
        let pipeline = registry
            .get_pipeline_with_constants("int_fc_test_kernel", device, &[], &[(100, n)])
            .expect(failure_msg);
        pipeline as *const ComputePipelineState
    }

    /// Looks up `int_fc_test_kernel` with int constant 100 set to `n` and
    /// returns the label of the resulting pipeline.
    fn int_fc_pipeline_label(
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
        n: i32,
        failure_msg: &str,
    ) -> String {
        registry
            .get_pipeline_with_constants("int_fc_test_kernel", device, &[], &[(100, n)])
            .expect(failure_msg)
            .label()
            .to_string()
    }

    /// Distinct int constants give distinct cache entries, repeated lookups
    /// hit the cache, and the bool-only wrapper still works with no constants.
    #[test]
    fn test_int_fc_distinct_pipelines_and_bool_compat() {
        let device = metal::Device::system_default().expect(NO_DEVICE_MSG);
        let mut registry = KernelRegistry::new();
        registry.register_source("int_fc_test_kernel", INT_FC_TEST_SHADER);

        let first_n4 =
            int_fc_pipeline_ptr(&mut registry, &device, 4, "pipeline N=4 should compile");
        let entries_with_n4 = registry.cached_count();

        let first_n8 =
            int_fc_pipeline_ptr(&mut registry, &device, 8, "pipeline N=8 should compile");
        assert_eq!(
            registry.cached_count(),
            entries_with_n4 + 1,
            "N=8 must produce a new cache entry"
        );
        assert_ne!(
            first_n4, first_n8,
            "N=4 and N=8 specialisations must be separate ComputePipelineState objects"
        );

        let second_n4 = int_fc_pipeline_ptr(
            &mut registry,
            &device,
            4,
            "pipeline N=4 cache hit should succeed",
        );
        assert_eq!(
            registry.cached_count(),
            entries_with_n4 + 1,
            "repeated N=4 call must be a cache hit, not a new entry"
        );
        assert_eq!(
            first_n4, second_n4,
            "repeated N=4 call must return the same pipeline pointer"
        );

        // A kernel without any function constants, fetched through the
        // bool-only wrapper with an empty constant list.
        const BARE_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void bare_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
    if (tid == 0) { out[0] = 42; }
}
"#;
        registry.register_source("bare_kernel", BARE_SHADER);

        let entries_before_bool = registry.cached_count();
        let _bool_pipeline = registry
            .get_pipeline_with_bool_constants("bare_kernel", &device, &[])
            .expect("bool-constants wrapper with empty slice must succeed");
        let entries_after_bool = registry.cached_count();

        assert_eq!(
            entries_after_bool,
            entries_before_bool + 1,
            "bool-constants wrapper must insert one new cache entry"
        );
    }

    /// Pipelines carry a label: the kernel name for plain pipelines and the
    /// full cache key for constant-specialised ones.
    #[test]
    fn test_pipeline_labels_propagate_for_mst() {
        let device = metal::Device::system_default().expect(NO_DEVICE_MSG);
        let mut registry = KernelRegistry::new();
        registry.register_source("int_fc_test_kernel", INT_FC_TEST_SHADER);

        const BARE_SHADER_LABEL_TEST: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void label_smoke_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
    if (tid == 0) { out[0] = 7; }
}
"#;
        registry.register_source("label_smoke_kernel", BARE_SHADER_LABEL_TEST);

        let plain_pipeline = registry
            .get_pipeline("label_smoke_kernel", &device)
            .expect("plain pipeline must compile");
        let plain_label = plain_pipeline.label().to_string();
        assert_eq!(
            plain_label, "label_smoke_kernel",
            "get_pipeline must label the pipeline with the kernel name (xctrace MST attribution)"
        );

        let label_v7 = int_fc_pipeline_label(
            &mut registry,
            &device,
            7,
            "specialised pipeline must compile",
        );
        assert_eq!(
            label_v7, "int_fc_test_kernel|100:i7",
            "get_pipeline_with_constants must label with the cache_key so each \
             specialisation is distinct in xctrace MST"
        );

        let label_v13 = int_fc_pipeline_label(
            &mut registry,
            &device,
            13,
            "second specialised pipeline must compile",
        );
        assert_eq!(label_v13, "int_fc_test_kernel|100:i13");
        assert_ne!(
            label_v7, label_v13,
            "distinct constant values must yield distinct pipeline labels"
        );
    }
}