1use std::collections::HashMap;
9
10use metal::{ComputePipelineState, FunctionConstantValues, MTLDataType};
11
12use crate::error::{MlxError, Result};
13
14pub struct KernelRegistry {
37 cache: HashMap<String, ComputePipelineState>,
39 sources: HashMap<String, &'static str>,
43}
44
45impl KernelRegistry {
46 pub fn new() -> Self {
50 let mut sources = HashMap::new();
51
52 sources.insert(
54 "placeholder".into(),
55 include_str!("shaders/placeholder.metal"),
56 );
57 sources.insert(
58 "quantized_matmul".into(),
59 include_str!("shaders/quantized_matmul.metal"),
60 );
61 sources.insert(
62 "quantized_matmul_simd".into(),
63 include_str!("shaders/quantized_matmul.metal"),
64 );
65 sources.insert(
66 "quantized_matmul_simd_bf16".into(),
67 include_str!("shaders/quantized_matmul.metal"),
68 );
69 sources.insert(
70 "quantized_matmul_simd_bf16_expert".into(),
71 include_str!("shaders/quantized_matmul.metal"),
72 );
73
74 let ggml_src: &'static str =
76 include_str!("shaders/quantized_matmul_ggml.metal");
77 sources.insert("kernel_mul_mv_q4_0_f32".into(), ggml_src);
78 sources.insert("kernel_mul_mv_q8_0_f32".into(), ggml_src);
79 sources.insert("kernel_mul_mv_q6_K_f32".into(), ggml_src);
80
81 let ggml_mm_src: &'static str =
87 include_str!("shaders/quantized_matmul_mm.metal");
88 sources.insert("kernel_mul_mm_q4_0_f32".into(), ggml_mm_src);
89 sources.insert("kernel_mul_mm_q8_0_f32".into(), ggml_mm_src);
90 sources.insert("kernel_mul_mm_q6_K_f32".into(), ggml_mm_src);
91
92 let ggml_mm_tensor_src: &'static str =
103 include_str!("shaders/quantized_matmul_mm_tensor.metal");
104 sources.insert("kernel_mul_mm_q4_0_tensor_f32".into(), ggml_mm_tensor_src);
105 sources.insert("kernel_mul_mm_q4_0_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
106 sources.insert("kernel_mul_mm_q6_K_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
107 sources.insert("kernel_mul_mm_q8_0_tensor_f32".into(), ggml_mm_tensor_src);
108 sources.insert("kernel_mul_mm_q6_K_tensor_f32".into(), ggml_mm_tensor_src);
109
110 let dense_mm_bf16_tensor_src: &'static str =
117 include_str!("shaders/dense_mm_bf16_tensor.metal");
118 sources.insert("hf2q_dense_mm_bf16_f32_tensor".into(), dense_mm_bf16_tensor_src);
119
120 let dense_mm_f32_f32_tensor_src: &'static str =
129 include_str!("shaders/dense_mm_f32_f32.metal");
130 sources.insert("hf2q_dense_mm_f32_f32_tensor".into(), dense_mm_f32_f32_tensor_src);
131
132 let dense_mm_f16_tensor_src: &'static str =
144 include_str!("shaders/dense_mm_f16_tensor.metal");
145 sources.insert("hf2q_dense_mm_f16_f32_tensor".into(), dense_mm_f16_tensor_src);
146
147 let dense_gemv_bf16_src: &'static str =
154 include_str!("shaders/dense_gemv_bf16.metal");
155 sources.insert("hf2q_dense_gemv_bf16_f32_4".into(), dense_gemv_bf16_src);
156
157 let scale_mask_softmax_src: &'static str =
163 include_str!("shaders/scale_mask_softmax.metal");
164 sources.insert("scale_mask_softmax_f32".into(), scale_mask_softmax_src);
165
166 sources.insert(
168 "quantized_matmul_id".into(),
169 include_str!("shaders/quantized_matmul_id.metal"),
170 );
171
172 let ggml_id_src: &'static str =
174 include_str!("shaders/quantized_matmul_id_ggml.metal");
175 sources.insert("kernel_mul_mv_id_q4_0_f32".into(), ggml_id_src);
176 sources.insert("kernel_mul_mv_id_q8_0_f32".into(), ggml_id_src);
177 sources.insert("kernel_mul_mv_id_q5_K_f32".into(), ggml_id_src);
178 sources.insert("kernel_mul_mv_id_q6_K_f32".into(), ggml_id_src);
179
180 let ggml_id_mm_src: &'static str =
188 include_str!("shaders/quantized_matmul_id_mm.metal");
189 sources.insert("kernel_mul_mm_id_map0_ne20_1".into(), ggml_id_mm_src);
190 sources.insert("kernel_mul_mm_id_map0_ne20_8".into(), ggml_id_mm_src);
191 sources.insert("kernel_mul_mm_id_q4_0_f32".into(), ggml_id_mm_src);
192 sources.insert("kernel_mul_mm_id_q8_0_f32".into(), ggml_id_mm_src);
193 sources.insert("kernel_mul_mm_id_q6_K_f32".into(), ggml_id_mm_src);
194
195 let ggml_id_mm_tensor_src: &'static str =
201 include_str!("shaders/quantized_matmul_id_mm_tensor.metal");
202 sources.insert("kernel_mul_mm_id_q4_0_tensor_f32".into(), ggml_id_mm_tensor_src);
203 sources.insert("kernel_mul_mm_id_q8_0_tensor_f32".into(), ggml_id_mm_tensor_src);
204 sources.insert("kernel_mul_mm_id_q6_K_tensor_f32".into(), ggml_id_mm_tensor_src);
205
206 let embedding_src: &'static str = include_str!("shaders/embedding.metal");
208 sources.insert("embedding_gather_4bit".into(), embedding_src);
209 sources.insert("embedding_gather_6bit".into(), embedding_src);
210
211 let moe_gate_src: &'static str = include_str!("shaders/moe_gate.metal");
213 sources.insert("moe_gate".into(), moe_gate_src);
214
215 let moe_dispatch_src: &'static str = include_str!("shaders/moe_dispatch.metal");
217 sources.insert("fused_gelu_mul".into(), moe_dispatch_src);
218 sources.insert("moe_swiglu_fused".into(), moe_dispatch_src);
219 sources.insert("moe_swiglu_batch".into(), moe_dispatch_src);
220 sources.insert("moe_swiglu_seq".into(), moe_dispatch_src);
221 sources.insert("moe_accumulate".into(), moe_dispatch_src);
222 sources.insert("moe_weighted_sum".into(), moe_dispatch_src);
223 sources.insert("moe_weighted_sum_seq".into(), moe_dispatch_src);
224 sources.insert("zero_buffer".into(), moe_dispatch_src);
225 sources.insert("naive_matvec_f32".into(), moe_dispatch_src);
226 sources.insert("moe_gather_topk_weights".into(), moe_dispatch_src);
227 sources.insert("fused_gelu_mul_bf16".into(), moe_dispatch_src);
229 sources.insert("moe_swiglu_seq_bf16".into(), moe_dispatch_src);
230 sources.insert("moe_weighted_sum_seq_bf16_input".into(), moe_dispatch_src);
231
232 let kv_cache_src: &'static str = include_str!("shaders/kv_cache_copy.metal");
234 sources.insert("kv_cache_copy_batch_f32".into(), kv_cache_src);
235 sources.insert("kv_cache_copy_batch_f32_to_f16".into(), kv_cache_src);
236 sources.insert("kv_cache_copy_seq_f32".into(), kv_cache_src);
237 sources.insert("kv_cache_copy_seq_f32_to_f16".into(), kv_cache_src);
238 sources.insert("kv_cache_copy_seq_f32_kv_dual".into(), kv_cache_src);
240 sources.insert("kv_cache_copy_seq_f32_to_f16_kv_dual".into(), kv_cache_src);
241 sources.insert("kv_cache_copy_seq_bf16".into(), kv_cache_src);
243
244 let elementwise_src: &'static str = include_str!("shaders/elementwise.metal");
246 sources.insert("elementwise_add_f32".into(), elementwise_src);
247 sources.insert("elementwise_add_f16".into(), elementwise_src);
248 sources.insert("elementwise_mul_f32".into(), elementwise_src);
249 sources.insert("elementwise_mul_f16".into(), elementwise_src);
250 sources.insert("elementwise_add_bf16".into(), elementwise_src);
251 sources.insert("elementwise_mul_bf16".into(), elementwise_src);
252 sources.insert("cast_f16_to_f32".into(), elementwise_src);
253 sources.insert("cast_f32_to_f16".into(), elementwise_src);
254 sources.insert("cast_bf16_to_f32".into(), elementwise_src);
255 sources.insert("cast_f32_to_bf16".into(), elementwise_src);
256 sources.insert("scalar_mul_bf16".into(), elementwise_src);
257 sources.insert("scalar_mul_f32".into(), elementwise_src);
258 sources.insert("embedding_gather_scale_f32".into(), elementwise_src);
259 sources.insert("embedding_gather_scale_batch_f32".into(), elementwise_src);
260 sources.insert("permute_021_bf16".into(), elementwise_src);
261 sources.insert("transpose_last2_bf16".into(), elementwise_src);
262 sources.insert("transpose_last2_f16".into(), elementwise_src);
263 sources.insert("permute_021_f32".into(), elementwise_src);
264 sources.insert("permute_021_bf16_to_f32".into(), elementwise_src);
265 sources.insert("transpose_2d_f32".into(), elementwise_src);
266 sources.insert("transpose_2d_f16".into(), elementwise_src);
267
268 let sdpa_src: &'static str = include_str!("shaders/sdpa.metal");
270 sources.insert("sdpa".into(), sdpa_src);
271 sources.insert("sdpa_bf16".into(), sdpa_src);
272 let sdpa_sliding_src: &'static str = include_str!("shaders/sdpa_sliding.metal");
273 sources.insert("sdpa_sliding".into(), sdpa_sliding_src);
274 sources.insert("sdpa_sliding_bf16".into(), sdpa_sliding_src);
275
276 let flash_attn_prefill_src: &'static str =
281 include_str!("shaders/flash_attn_prefill.metal");
282 sources.insert(
284 "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskfloat32".into(),
285 flash_attn_prefill_src,
286 );
287 sources.insert(
288 "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
289 flash_attn_prefill_src,
290 );
291 sources.insert(
292 "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbfloat16".into(),
293 flash_attn_prefill_src,
294 );
295 sources.insert(
296 "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
297 flash_attn_prefill_src,
298 );
299 sources.insert(
300 "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskfloat16".into(),
301 flash_attn_prefill_src,
302 );
303 sources.insert(
304 "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
305 flash_attn_prefill_src,
306 );
307 sources.insert(
311 "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbfloat16".into(),
312 flash_attn_prefill_src,
313 );
314 sources.insert(
315 "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
316 flash_attn_prefill_src,
317 );
318 sources.insert(
319 "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskfloat16".into(),
320 flash_attn_prefill_src,
321 );
322 sources.insert(
323 "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
324 flash_attn_prefill_src,
325 );
326
327 let flash_attn_vec_src: &'static str =
330 include_str!("shaders/flash_attn_vec.metal");
331 sources.insert("flash_attn_vec_dk256".into(), flash_attn_vec_src);
332 sources.insert("flash_attn_vec_dk512".into(), flash_attn_vec_src);
333 sources.insert("flash_attn_vec_reduce_dk256".into(), flash_attn_vec_src);
334 sources.insert("flash_attn_vec_reduce_dk512".into(), flash_attn_vec_src);
335 sources.insert("flash_attn_vec_f16kv_dk256".into(), flash_attn_vec_src);
337 sources.insert("flash_attn_vec_f16kv_dk512".into(), flash_attn_vec_src);
338
339 let rope_src: &'static str = include_str!("shaders/rope.metal");
341 sources.insert("rope_f32".into(), rope_src);
342 sources.insert("rope_f16".into(), rope_src);
343 sources.insert("rope_bf16".into(), rope_src);
344 sources.insert("rope_neox_bf16".into(), rope_src);
345 sources.insert("rope_neox_f32".into(), rope_src);
346 let rms_norm_src: &'static str = include_str!("shaders/rms_norm.metal");
347 sources.insert("rms_norm_f32".into(), rms_norm_src);
348 sources.insert("rms_norm_f16".into(), rms_norm_src);
349 sources.insert("rms_norm_bf16".into(), rms_norm_src);
350 sources.insert("rms_norm_no_scale_bf16".into(), rms_norm_src);
351 sources.insert("rms_norm_no_scale_f32".into(), rms_norm_src);
352 sources.insert("rms_norm_no_scale_f32_dual".into(), rms_norm_src);
353 sources.insert("rms_norm_f32_triple".into(), rms_norm_src);
354 sources.insert("fused_post_attn_triple_norm_f32".into(), rms_norm_src);
355 sources.insert("rms_norm_no_scale_f32_dual_perm".into(), rms_norm_src);
356 sources.insert("rms_norm_mul_f32".into(), rms_norm_src);
358 sources.insert("rms_norm_mul_f16".into(), rms_norm_src);
359 sources.insert("rms_norm_mul_bf16".into(), rms_norm_src);
360 let l2_norm_src: &'static str = include_str!("shaders/l2_norm.metal");
362 sources.insert("l2_norm_f32".into(), l2_norm_src);
363 sources.insert("l2_norm_f16".into(), l2_norm_src);
364 sources.insert("l2_norm_bf16".into(), l2_norm_src);
365 let cumsum_src: &'static str = include_str!("shaders/cumsum.metal");
367 sources.insert("cumsum_f32".into(), cumsum_src);
368 sources.insert("cumsum_bf16".into(), cumsum_src);
369 let ssm_conv_src: &'static str = include_str!("shaders/ssm_conv.metal");
371 sources.insert("ssm_conv_forward_f32".into(), ssm_conv_src);
372 sources.insert("ssm_conv_forward_bf16".into(), ssm_conv_src);
373 sources.insert("ssm_conv_state_update_f32".into(), ssm_conv_src);
374 sources.insert("ssm_conv_state_update_bf16".into(), ssm_conv_src);
375 let tri_solve_src: &'static str = include_str!("shaders/tri_solve.metal");
377 sources.insert("tri_solve_lower_unit_f32".into(), tri_solve_src);
378 sources.insert("tri_solve_lower_unit_bf16".into(), tri_solve_src);
379 let rope_multi_src: &'static str = include_str!("shaders/rope_multi.metal");
381 sources.insert("rope_multi_f32".into(), rope_multi_src);
382 sources.insert("rope_multi_bf16".into(), rope_multi_src);
383 let gdn_src: &'static str = include_str!("shaders/gated_delta_net.metal");
385 sources.insert("gated_delta_net_f32".into(), gdn_src);
386 let sigmoid_mul_src: &'static str = include_str!("shaders/sigmoid_mul.metal");
388 sources.insert("sigmoid_mul_f32".into(), sigmoid_mul_src);
389 sources.insert("sigmoid_mul_bf16".into(), sigmoid_mul_src);
390 let silu_mul_src: &'static str = include_str!("shaders/silu_mul.metal");
391 sources.insert("silu_mul_f32".into(), silu_mul_src);
392 let compute_g_beta_src: &'static str = include_str!("shaders/compute_g_beta.metal");
393 sources.insert("compute_g_beta_f32".into(), compute_g_beta_src);
394 let ssm_norm_gate_src: &'static str = include_str!("shaders/ssm_norm_gate.metal");
395 sources.insert("ssm_norm_gate_f32".into(), ssm_norm_gate_src);
396 let gelu_src: &'static str = include_str!("shaders/gelu.metal");
397 sources.insert("gelu_f32".into(), gelu_src);
398 sources.insert("gelu_f16".into(), gelu_src);
399 sources.insert("gelu_bf16".into(), gelu_src);
400 let softmax_src: &'static str = include_str!("shaders/softmax.metal");
401 sources.insert("softmax_f32".into(), softmax_src);
402 sources.insert("softmax_f16".into(), softmax_src);
403 sources.insert("softmax_bf16".into(), softmax_src);
404 let softcap_src: &'static str = include_str!("shaders/softcap.metal");
405 sources.insert("softcap_f32".into(), softcap_src);
406 sources.insert("softcap_f16".into(), softcap_src);
407 sources.insert("softcap_bf16".into(), softcap_src);
408
409 let fused_norm_add_src: &'static str =
412 include_str!("shaders/fused_norm_add_bf16.metal");
413 sources.insert("fused_norm_add_bf16".into(), fused_norm_add_src);
414 sources.insert("fused_norm_add_no_weight_bf16".into(), fused_norm_add_src);
415
416 let fused_hnr_f32_src: &'static str =
418 include_str!("shaders/fused_head_norm_rope_f32.metal");
419 sources.insert("fused_head_norm_rope_f32".into(), fused_hnr_f32_src);
420
421 let fused_hnr_bf16_src: &'static str =
424 include_str!("shaders/fused_head_norm_rope_bf16.metal");
425 sources.insert("fused_head_norm_rope_bf16".into(), fused_hnr_bf16_src);
426 sources.insert("fused_head_norm_rope_batch_bf16".into(), fused_hnr_bf16_src);
427
428 let fused_norm_add_f32_src: &'static str =
430 include_str!("shaders/fused_norm_add_f32.metal");
431 sources.insert("fused_norm_add_f32".into(), fused_norm_add_f32_src);
432 sources.insert("fused_residual_norm_f32".into(), fused_norm_add_f32_src);
433 sources.insert("fused_residual_norm_scalar_f32".into(), fused_norm_add_f32_src);
434 sources.insert("fused_moe_routing_f32".into(), fused_norm_add_f32_src);
435 sources.insert("fused_moe_routing_batch_f32".into(), fused_norm_add_f32_src);
436 sources.insert("fused_norm_add_scalar_f32".into(), fused_norm_add_f32_src);
437 sources.insert("fused_moe_wsum_norm_add_f32".into(), fused_norm_add_f32_src);
438 sources.insert("fused_moe_wsum_dnorm_add_f32".into(), fused_norm_add_f32_src);
439
440 let argsort_src: &'static str = include_str!("shaders/argsort.metal");
442 sources.insert("argsort_desc_f32".into(), argsort_src);
443
444 let gather_src: &'static str = include_str!("shaders/gather.metal");
446 sources.insert("gather_f32".into(), gather_src);
447
448 let kv_cache_copy_src: &'static str =
450 include_str!("shaders/kv_cache_copy.metal");
451 sources.insert("kv_cache_copy".into(), kv_cache_copy_src);
452 sources.insert("kv_cache_copy_f32".into(), kv_cache_copy_src);
453
454 let copy_src: &'static str = include_str!("shaders/copy.metal");
456 sources.insert("strided_copy_f32".into(), copy_src);
457 sources.insert("offset_copy_f32".into(), copy_src);
458
459 let dense_gemm_src: &'static str = include_str!("shaders/dense_gemm.metal");
461 sources.insert("dense_gemm_f16".into(), dense_gemm_src);
462 sources.insert("dense_matvec_f16".into(), dense_gemm_src);
463 sources.insert("dense_matvec_f16w_f32io".into(), dense_gemm_src);
464 sources.insert("dense_matvec_bf16w_f32io".into(), dense_gemm_src);
466 sources.insert("dense_matvec_f32".into(), dense_gemm_src);
468
469 let fwht_src: &'static str = include_str!("shaders/fwht_standalone.metal");
471 sources.insert("fwht_standalone_f32_d256".into(), fwht_src);
472 sources.insert("fwht_standalone_f32_d512".into(), fwht_src);
473 sources.insert("fwht_sign_premult_f32_d256".into(), fwht_src);
475 sources.insert("fwht_sign_premult_f32_d512".into(), fwht_src);
476 sources.insert("fwht_sign_undo_f32_d256".into(), fwht_src);
477 sources.insert("fwht_sign_undo_f32_d512".into(), fwht_src);
478
479 let hq_fast_src: &'static str = include_str!("shaders/hadamard_quantize_kv_fast.metal");
481 sources.insert("hadamard_quantize_kv_fast_d256".into(), hq_fast_src);
482 sources.insert("hadamard_quantize_kv_fast_d512".into(), hq_fast_src);
483 sources.insert("hadamard_quantize_kv_hb_d256".into(), hq_fast_src);
485 sources.insert("hadamard_quantize_kv_hb_d512".into(), hq_fast_src);
486
487 let tq_dq_src: &'static str = include_str!("shaders/tq_dequantize_kv.metal");
489 sources.insert("tq_dequantize_kv".into(), tq_dq_src);
490 sources.insert("tq_dequantize_hb_kv".into(), tq_dq_src);
492
493 let tq_hb_src: &'static str = include_str!("shaders/flash_attn_vec_tq_hb.metal");
495 sources.insert("flash_attn_vec_tq_hb_dk256".into(), tq_hb_src);
496 sources.insert("flash_attn_vec_tq_hb_dk512".into(), tq_hb_src);
497
498 let argmax_src: &'static str = include_str!("shaders/argmax.metal");
500 sources.insert("argmax_f32".into(), argmax_src);
501 let softmax_sample_src: &'static str =
502 include_str!("shaders/softmax_sample.metal");
503 sources.insert("softmax_sample_f32".into(), softmax_sample_src);
504 let top_k_src: &'static str = include_str!("shaders/top_k.metal");
506 sources.insert("top_k_f32".into(), top_k_src);
507
508 let moe_stk_src: &'static str =
511 include_str!("shaders/moe_softmax_topk.metal");
512 sources.insert("moe_softmax_topk_f32".into(), moe_stk_src);
513 let moe_wr_src: &'static str =
514 include_str!("shaders/moe_weighted_reduce.metal");
515 sources.insert("moe_weighted_reduce_f32".into(), moe_wr_src);
516 let sdpa_decode_src: &'static str =
517 include_str!("shaders/sdpa_decode.metal");
518 sources.insert("sdpa_decode".into(), sdpa_decode_src);
519
520 Self {
521 cache: HashMap::new(),
522 sources,
523 }
524 }
525
526 pub fn register_source(&mut self, name: impl Into<String>, source: &'static str) {
529 let name = name.into();
530 self.cache.remove(&name);
532 self.sources.insert(name, source);
533 }
534
535 pub fn get_pipeline(
547 &mut self,
548 name: &str,
549 device: &metal::DeviceRef,
550 ) -> Result<&ComputePipelineState> {
551 if !self.cache.contains_key(name) {
552 let source = self.sources.get(name).ok_or_else(|| {
554 MlxError::KernelNotFound(name.to_string())
555 })?;
556
557 let compile_opts = metal::CompileOptions::new();
558 let library = device
559 .new_library_with_source(source, &compile_opts)
560 .map_err(|msg| MlxError::ShaderCompilationError {
561 name: name.to_string(),
562 message: msg,
563 })?;
564
565 let function = library
566 .get_function(name, None)
567 .map_err(|msg| MlxError::ShaderCompilationError {
568 name: name.to_string(),
569 message: msg,
570 })?;
571
572 let pipeline = device
573 .new_compute_pipeline_state_with_function(&function)
574 .map_err(|msg| MlxError::ShaderCompilationError {
575 name: name.to_string(),
576 message: msg,
577 })?;
578
579 self.cache.insert(name.to_string(), pipeline);
580 }
581
582 self.cache.get(name).ok_or_else(|| {
585 MlxError::KernelNotFound(name.to_string())
586 })
587 }
588
589 pub fn get_pipeline_with_constants(
611 &mut self,
612 name: &str,
613 device: &metal::DeviceRef,
614 bool_constants: &[(usize, bool)],
615 int_constants: &[(usize, i32)],
616 ) -> Result<&ComputePipelineState> {
617 let mut cache_key = name.to_string();
622 for &(index, value) in bool_constants {
623 cache_key.push('|');
624 cache_key.push_str(&index.to_string());
625 cache_key.push_str(if value { ":b1" } else { ":b0" });
626 }
627 for &(index, value) in int_constants {
628 cache_key.push('|');
629 cache_key.push_str(&index.to_string());
630 cache_key.push(':');
631 cache_key.push('i');
632 cache_key.push_str(&value.to_string());
633 }
634
635 if !self.cache.contains_key(&cache_key) {
636 let source = self.sources.get(name).ok_or_else(|| {
638 MlxError::KernelNotFound(name.to_string())
639 })?;
640
641 let compile_opts = metal::CompileOptions::new();
642 let library = device
643 .new_library_with_source(source, &compile_opts)
644 .map_err(|msg| MlxError::ShaderCompilationError {
645 name: name.to_string(),
646 message: msg,
647 })?;
648
649 let fcv = FunctionConstantValues::new();
654
655 for &(index, value) in bool_constants {
656 let v: u8 = if value { 1 } else { 0 };
659 fcv.set_constant_value_at_index(
660 (&v as *const u8).cast::<std::ffi::c_void>(),
661 MTLDataType::Bool,
662 index as u64,
663 );
664 }
665
666 for &(index, value) in int_constants {
667 fcv.set_constant_value_at_index(
671 (&value as *const i32).cast::<std::ffi::c_void>(),
672 MTLDataType::Int,
673 index as u64,
674 );
675 }
676
677 let function = library
678 .get_function(name, Some(fcv))
679 .map_err(|msg| MlxError::ShaderCompilationError {
680 name: name.to_string(),
681 message: msg,
682 })?;
683
684 let pipeline = device
685 .new_compute_pipeline_state_with_function(&function)
686 .map_err(|msg| MlxError::ShaderCompilationError {
687 name: name.to_string(),
688 message: msg,
689 })?;
690
691 self.cache.insert(cache_key.clone(), pipeline);
692 }
693
694 self.cache.get(&cache_key).ok_or_else(|| {
695 MlxError::KernelNotFound(name.to_string())
696 })
697 }
698
699 pub fn get_pipeline_with_bool_constants(
717 &mut self,
718 name: &str,
719 device: &metal::DeviceRef,
720 bool_constants: &[(usize, bool)],
721 ) -> Result<&ComputePipelineState> {
722 self.get_pipeline_with_constants(name, device, bool_constants, &[])
723 }
724
725 pub fn is_cached(&self, name: &str) -> bool {
727 self.cache.contains_key(name)
728 }
729
730 pub fn cached_count(&self) -> usize {
732 self.cache.len()
733 }
734
735 pub fn source_count(&self) -> usize {
737 self.sources.len()
738 }
739}
740
741impl Default for KernelRegistry {
742 fn default() -> Self {
743 Self::new()
744 }
745}
746
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal kernel whose output depends on an int function constant
    // (index 100), so different constant values need different pipelines.
    const INT_FC_TEST_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;

constant int test_N [[function_constant(100)]];

kernel void int_fc_test_kernel(
    device int* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]])
{
    if (tid == 0) {
        out[0] = test_N;
    }
}
"#;

    /// Identity of the Metal object behind a cached pipeline handle.
    ///
    /// The address of the `ComputePipelineState` value itself is only the
    /// address of a slot inside the registry's `HashMap`: two different slots
    /// always differ, and a slot may move when the map grows. Going through
    /// the `Deref` target yields a pointer that identifies the pipeline
    /// object regardless of where the handle is stored.
    fn metal_object_ptr(
        pipeline: &metal::ComputePipelineState,
    ) -> *const metal::ComputePipelineStateRef {
        let object: &metal::ComputePipelineStateRef = pipeline;
        object
    }

    /// Int function constants must produce one cache entry per distinct
    /// value, repeated requests must hit the cache, and the bool-only wrapper
    /// must keep working with an empty constant list.
    #[test]
    fn test_int_fc_distinct_pipelines_and_bool_compat() {
        let device = metal::Device::system_default()
            .expect("no Metal device — run on Apple Silicon or x86 Mac with Metal support");

        let mut registry = KernelRegistry::new();

        registry.register_source("int_fc_test_kernel", INT_FC_TEST_SHADER);

        // First specialisation: N = 4.
        let p4_ptr = metal_object_ptr(
            registry
                .get_pipeline_with_constants(
                    "int_fc_test_kernel",
                    &device,
                    &[],
                    &[(100, 4_i32)],
                )
                .expect("pipeline N=4 should compile"),
        );

        let count_after_n4 = registry.cached_count();

        // Second specialisation: N = 8 must be built and cached separately.
        let p8_ptr = metal_object_ptr(
            registry
                .get_pipeline_with_constants(
                    "int_fc_test_kernel",
                    &device,
                    &[],
                    &[(100, 8_i32)],
                )
                .expect("pipeline N=8 should compile"),
        );

        assert_eq!(
            registry.cached_count(),
            count_after_n4 + 1,
            "N=8 must produce a new cache entry"
        );

        assert_ne!(
            p4_ptr, p8_ptr,
            "N=4 and N=8 specialisations must be separate ComputePipelineState objects"
        );

        // Asking for N = 4 again must reuse the pipeline built above.
        let p4_again_ptr = metal_object_ptr(
            registry
                .get_pipeline_with_constants(
                    "int_fc_test_kernel",
                    &device,
                    &[],
                    &[(100, 4_i32)],
                )
                .expect("pipeline N=4 cache hit should succeed"),
        );

        assert_eq!(
            registry.cached_count(),
            count_after_n4 + 1,
            "repeated N=4 call must be a cache hit, not a new entry"
        );
        assert_eq!(
            p4_ptr, p4_again_ptr,
            "repeated N=4 call must return the same pipeline pointer"
        );

        // Kernel without any function constants, for the bool-only wrapper.
        const BARE_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void bare_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
    if (tid == 0) { out[0] = 42; }
}
"#;
        registry.register_source("bare_kernel", BARE_SHADER);

        let count_before_bool = registry.cached_count();
        let _bool_pipeline = registry
            .get_pipeline_with_bool_constants("bare_kernel", &device, &[])
            .expect("bool-constants wrapper with empty slice must succeed");

        assert_eq!(
            registry.cached_count(),
            count_before_bool + 1,
            "bool-constants wrapper must insert one new cache entry"
        );
    }
}