1use std::collections::HashMap;
9
10use metal::{ComputePipelineState, FunctionConstantValues, MTLDataType};
11
12use crate::error::{MlxError, Result};
13
14pub struct KernelRegistry {
37 cache: HashMap<String, ComputePipelineState>,
39 sources: HashMap<String, &'static str>,
43}
44
45impl KernelRegistry {
46 pub fn new() -> Self {
50 let mut sources = HashMap::new();
51
52 sources.insert(
54 "placeholder".into(),
55 include_str!("shaders/placeholder.metal"),
56 );
57 sources.insert(
58 "quantized_matmul".into(),
59 include_str!("shaders/quantized_matmul.metal"),
60 );
61 sources.insert(
62 "quantized_matmul_simd".into(),
63 include_str!("shaders/quantized_matmul.metal"),
64 );
65 sources.insert(
66 "quantized_matmul_simd_bf16".into(),
67 include_str!("shaders/quantized_matmul.metal"),
68 );
69 sources.insert(
70 "quantized_matmul_simd_bf16_expert".into(),
71 include_str!("shaders/quantized_matmul.metal"),
72 );
73
74 let ggml_src: &'static str =
76 include_str!("shaders/quantized_matmul_ggml.metal");
77 sources.insert("kernel_mul_mv_q4_0_f32".into(), ggml_src);
78 sources.insert("kernel_mul_mv_q8_0_f32".into(), ggml_src);
79 sources.insert("kernel_mul_mv_q6_K_f32".into(), ggml_src);
80
81 let ggml_mm_src: &'static str =
87 include_str!("shaders/quantized_matmul_mm.metal");
88 sources.insert("kernel_mul_mm_q4_0_f32".into(), ggml_mm_src);
89 sources.insert("kernel_mul_mm_q8_0_f32".into(), ggml_mm_src);
90 sources.insert("kernel_mul_mm_q6_K_f32".into(), ggml_mm_src);
91
92 let ggml_mm_tensor_src: &'static str =
103 include_str!("shaders/quantized_matmul_mm_tensor.metal");
104 sources.insert("kernel_mul_mm_q4_0_tensor_f32".into(), ggml_mm_tensor_src);
105 sources.insert("kernel_mul_mm_q4_0_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
106 sources.insert("kernel_mul_mm_q6_K_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
107 sources.insert("kernel_mul_mm_q8_0_tensor_f32".into(), ggml_mm_tensor_src);
108 sources.insert("kernel_mul_mm_q6_K_tensor_f32".into(), ggml_mm_tensor_src);
109
110 let dense_mm_bf16_tensor_src: &'static str =
117 include_str!("shaders/dense_mm_bf16_tensor.metal");
118 sources.insert("hf2q_dense_mm_bf16_f32_tensor".into(), dense_mm_bf16_tensor_src);
119
120 let dense_mm_f32_f32_tensor_src: &'static str =
129 include_str!("shaders/dense_mm_f32_f32.metal");
130 sources.insert("hf2q_dense_mm_f32_f32_tensor".into(), dense_mm_f32_f32_tensor_src);
131
132 let dense_gemv_bf16_src: &'static str =
139 include_str!("shaders/dense_gemv_bf16.metal");
140 sources.insert("hf2q_dense_gemv_bf16_f32_4".into(), dense_gemv_bf16_src);
141
142 let scale_mask_softmax_src: &'static str =
148 include_str!("shaders/scale_mask_softmax.metal");
149 sources.insert("scale_mask_softmax_f32".into(), scale_mask_softmax_src);
150
151 sources.insert(
153 "quantized_matmul_id".into(),
154 include_str!("shaders/quantized_matmul_id.metal"),
155 );
156
157 let ggml_id_src: &'static str =
159 include_str!("shaders/quantized_matmul_id_ggml.metal");
160 sources.insert("kernel_mul_mv_id_q4_0_f32".into(), ggml_id_src);
161 sources.insert("kernel_mul_mv_id_q8_0_f32".into(), ggml_id_src);
162 sources.insert("kernel_mul_mv_id_q5_K_f32".into(), ggml_id_src);
163 sources.insert("kernel_mul_mv_id_q6_K_f32".into(), ggml_id_src);
164
165 let ggml_id_mm_src: &'static str =
173 include_str!("shaders/quantized_matmul_id_mm.metal");
174 sources.insert("kernel_mul_mm_id_map0_ne20_1".into(), ggml_id_mm_src);
175 sources.insert("kernel_mul_mm_id_map0_ne20_8".into(), ggml_id_mm_src);
176 sources.insert("kernel_mul_mm_id_q4_0_f32".into(), ggml_id_mm_src);
177 sources.insert("kernel_mul_mm_id_q8_0_f32".into(), ggml_id_mm_src);
178 sources.insert("kernel_mul_mm_id_q6_K_f32".into(), ggml_id_mm_src);
179
180 let ggml_id_mm_tensor_src: &'static str =
186 include_str!("shaders/quantized_matmul_id_mm_tensor.metal");
187 sources.insert("kernel_mul_mm_id_q4_0_tensor_f32".into(), ggml_id_mm_tensor_src);
188 sources.insert("kernel_mul_mm_id_q8_0_tensor_f32".into(), ggml_id_mm_tensor_src);
189 sources.insert("kernel_mul_mm_id_q6_K_tensor_f32".into(), ggml_id_mm_tensor_src);
190
191 let embedding_src: &'static str = include_str!("shaders/embedding.metal");
193 sources.insert("embedding_gather_4bit".into(), embedding_src);
194 sources.insert("embedding_gather_6bit".into(), embedding_src);
195
196 let moe_gate_src: &'static str = include_str!("shaders/moe_gate.metal");
198 sources.insert("moe_gate".into(), moe_gate_src);
199
200 let moe_dispatch_src: &'static str = include_str!("shaders/moe_dispatch.metal");
202 sources.insert("fused_gelu_mul".into(), moe_dispatch_src);
203 sources.insert("moe_swiglu_fused".into(), moe_dispatch_src);
204 sources.insert("moe_swiglu_batch".into(), moe_dispatch_src);
205 sources.insert("moe_swiglu_seq".into(), moe_dispatch_src);
206 sources.insert("moe_accumulate".into(), moe_dispatch_src);
207 sources.insert("moe_weighted_sum".into(), moe_dispatch_src);
208 sources.insert("moe_weighted_sum_seq".into(), moe_dispatch_src);
209 sources.insert("zero_buffer".into(), moe_dispatch_src);
210 sources.insert("naive_matvec_f32".into(), moe_dispatch_src);
211 sources.insert("moe_gather_topk_weights".into(), moe_dispatch_src);
212 sources.insert("fused_gelu_mul_bf16".into(), moe_dispatch_src);
214 sources.insert("moe_swiglu_seq_bf16".into(), moe_dispatch_src);
215 sources.insert("moe_weighted_sum_seq_bf16_input".into(), moe_dispatch_src);
216
217 let kv_cache_src: &'static str = include_str!("shaders/kv_cache_copy.metal");
219 sources.insert("kv_cache_copy_batch_f32".into(), kv_cache_src);
220 sources.insert("kv_cache_copy_batch_f32_to_f16".into(), kv_cache_src);
221 sources.insert("kv_cache_copy_seq_f32".into(), kv_cache_src);
222 sources.insert("kv_cache_copy_seq_f32_to_f16".into(), kv_cache_src);
223 sources.insert("kv_cache_copy_seq_f32_kv_dual".into(), kv_cache_src);
225 sources.insert("kv_cache_copy_seq_f32_to_f16_kv_dual".into(), kv_cache_src);
226 sources.insert("kv_cache_copy_seq_bf16".into(), kv_cache_src);
228
229 let elementwise_src: &'static str = include_str!("shaders/elementwise.metal");
231 sources.insert("elementwise_add_f32".into(), elementwise_src);
232 sources.insert("elementwise_add_f16".into(), elementwise_src);
233 sources.insert("elementwise_mul_f32".into(), elementwise_src);
234 sources.insert("elementwise_mul_f16".into(), elementwise_src);
235 sources.insert("elementwise_add_bf16".into(), elementwise_src);
236 sources.insert("elementwise_mul_bf16".into(), elementwise_src);
237 sources.insert("cast_f16_to_f32".into(), elementwise_src);
238 sources.insert("cast_f32_to_f16".into(), elementwise_src);
239 sources.insert("cast_bf16_to_f32".into(), elementwise_src);
240 sources.insert("cast_f32_to_bf16".into(), elementwise_src);
241 sources.insert("scalar_mul_bf16".into(), elementwise_src);
242 sources.insert("scalar_mul_f32".into(), elementwise_src);
243 sources.insert("embedding_gather_scale_f32".into(), elementwise_src);
244 sources.insert("embedding_gather_scale_batch_f32".into(), elementwise_src);
245 sources.insert("permute_021_bf16".into(), elementwise_src);
246 sources.insert("transpose_last2_bf16".into(), elementwise_src);
247 sources.insert("permute_021_f32".into(), elementwise_src);
248 sources.insert("permute_021_bf16_to_f32".into(), elementwise_src);
249 sources.insert("transpose_2d_f32".into(), elementwise_src);
250 sources.insert("transpose_2d_f16".into(), elementwise_src);
251
252 let sdpa_src: &'static str = include_str!("shaders/sdpa.metal");
254 sources.insert("sdpa".into(), sdpa_src);
255 sources.insert("sdpa_bf16".into(), sdpa_src);
256 let sdpa_sliding_src: &'static str = include_str!("shaders/sdpa_sliding.metal");
257 sources.insert("sdpa_sliding".into(), sdpa_sliding_src);
258 sources.insert("sdpa_sliding_bf16".into(), sdpa_sliding_src);
259
260 let flash_attn_prefill_src: &'static str =
265 include_str!("shaders/flash_attn_prefill.metal");
266 sources.insert(
268 "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskfloat32".into(),
269 flash_attn_prefill_src,
270 );
271 sources.insert(
272 "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
273 flash_attn_prefill_src,
274 );
275 sources.insert(
276 "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbfloat16".into(),
277 flash_attn_prefill_src,
278 );
279 sources.insert(
280 "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
281 flash_attn_prefill_src,
282 );
283 sources.insert(
284 "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskfloat16".into(),
285 flash_attn_prefill_src,
286 );
287 sources.insert(
288 "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
289 flash_attn_prefill_src,
290 );
291 sources.insert(
295 "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbfloat16".into(),
296 flash_attn_prefill_src,
297 );
298 sources.insert(
299 "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
300 flash_attn_prefill_src,
301 );
302 sources.insert(
303 "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskfloat16".into(),
304 flash_attn_prefill_src,
305 );
306 sources.insert(
307 "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
308 flash_attn_prefill_src,
309 );
310
311 let flash_attn_vec_src: &'static str =
314 include_str!("shaders/flash_attn_vec.metal");
315 sources.insert("flash_attn_vec_dk256".into(), flash_attn_vec_src);
316 sources.insert("flash_attn_vec_dk512".into(), flash_attn_vec_src);
317 sources.insert("flash_attn_vec_reduce_dk256".into(), flash_attn_vec_src);
318 sources.insert("flash_attn_vec_reduce_dk512".into(), flash_attn_vec_src);
319 sources.insert("flash_attn_vec_f16kv_dk256".into(), flash_attn_vec_src);
321 sources.insert("flash_attn_vec_f16kv_dk512".into(), flash_attn_vec_src);
322
323 let rope_src: &'static str = include_str!("shaders/rope.metal");
325 sources.insert("rope_f32".into(), rope_src);
326 sources.insert("rope_f16".into(), rope_src);
327 sources.insert("rope_bf16".into(), rope_src);
328 sources.insert("rope_neox_bf16".into(), rope_src);
329 sources.insert("rope_neox_f32".into(), rope_src);
330 let rms_norm_src: &'static str = include_str!("shaders/rms_norm.metal");
331 sources.insert("rms_norm_f32".into(), rms_norm_src);
332 sources.insert("rms_norm_f16".into(), rms_norm_src);
333 sources.insert("rms_norm_bf16".into(), rms_norm_src);
334 sources.insert("rms_norm_no_scale_bf16".into(), rms_norm_src);
335 sources.insert("rms_norm_no_scale_f32".into(), rms_norm_src);
336 sources.insert("rms_norm_no_scale_f32_dual".into(), rms_norm_src);
337 sources.insert("rms_norm_f32_triple".into(), rms_norm_src);
338 sources.insert("fused_post_attn_triple_norm_f32".into(), rms_norm_src);
339 sources.insert("rms_norm_no_scale_f32_dual_perm".into(), rms_norm_src);
340 sources.insert("rms_norm_mul_f32".into(), rms_norm_src);
342 sources.insert("rms_norm_mul_f16".into(), rms_norm_src);
343 sources.insert("rms_norm_mul_bf16".into(), rms_norm_src);
344 let l2_norm_src: &'static str = include_str!("shaders/l2_norm.metal");
346 sources.insert("l2_norm_f32".into(), l2_norm_src);
347 sources.insert("l2_norm_f16".into(), l2_norm_src);
348 sources.insert("l2_norm_bf16".into(), l2_norm_src);
349 let cumsum_src: &'static str = include_str!("shaders/cumsum.metal");
351 sources.insert("cumsum_f32".into(), cumsum_src);
352 sources.insert("cumsum_bf16".into(), cumsum_src);
353 let ssm_conv_src: &'static str = include_str!("shaders/ssm_conv.metal");
355 sources.insert("ssm_conv_forward_f32".into(), ssm_conv_src);
356 sources.insert("ssm_conv_forward_bf16".into(), ssm_conv_src);
357 sources.insert("ssm_conv_state_update_f32".into(), ssm_conv_src);
358 sources.insert("ssm_conv_state_update_bf16".into(), ssm_conv_src);
359 let tri_solve_src: &'static str = include_str!("shaders/tri_solve.metal");
361 sources.insert("tri_solve_lower_unit_f32".into(), tri_solve_src);
362 sources.insert("tri_solve_lower_unit_bf16".into(), tri_solve_src);
363 let rope_multi_src: &'static str = include_str!("shaders/rope_multi.metal");
365 sources.insert("rope_multi_f32".into(), rope_multi_src);
366 sources.insert("rope_multi_bf16".into(), rope_multi_src);
367 let gdn_src: &'static str = include_str!("shaders/gated_delta_net.metal");
369 sources.insert("gated_delta_net_f32".into(), gdn_src);
370 let sigmoid_mul_src: &'static str = include_str!("shaders/sigmoid_mul.metal");
372 sources.insert("sigmoid_mul_f32".into(), sigmoid_mul_src);
373 sources.insert("sigmoid_mul_bf16".into(), sigmoid_mul_src);
374 let silu_mul_src: &'static str = include_str!("shaders/silu_mul.metal");
375 sources.insert("silu_mul_f32".into(), silu_mul_src);
376 let compute_g_beta_src: &'static str = include_str!("shaders/compute_g_beta.metal");
377 sources.insert("compute_g_beta_f32".into(), compute_g_beta_src);
378 let ssm_norm_gate_src: &'static str = include_str!("shaders/ssm_norm_gate.metal");
379 sources.insert("ssm_norm_gate_f32".into(), ssm_norm_gate_src);
380 let gelu_src: &'static str = include_str!("shaders/gelu.metal");
381 sources.insert("gelu_f32".into(), gelu_src);
382 sources.insert("gelu_f16".into(), gelu_src);
383 sources.insert("gelu_bf16".into(), gelu_src);
384 let softmax_src: &'static str = include_str!("shaders/softmax.metal");
385 sources.insert("softmax_f32".into(), softmax_src);
386 sources.insert("softmax_f16".into(), softmax_src);
387 sources.insert("softmax_bf16".into(), softmax_src);
388 let softcap_src: &'static str = include_str!("shaders/softcap.metal");
389 sources.insert("softcap_f32".into(), softcap_src);
390 sources.insert("softcap_f16".into(), softcap_src);
391 sources.insert("softcap_bf16".into(), softcap_src);
392
393 let fused_norm_add_src: &'static str =
396 include_str!("shaders/fused_norm_add_bf16.metal");
397 sources.insert("fused_norm_add_bf16".into(), fused_norm_add_src);
398 sources.insert("fused_norm_add_no_weight_bf16".into(), fused_norm_add_src);
399
400 let fused_hnr_f32_src: &'static str =
402 include_str!("shaders/fused_head_norm_rope_f32.metal");
403 sources.insert("fused_head_norm_rope_f32".into(), fused_hnr_f32_src);
404
405 let fused_hnr_bf16_src: &'static str =
408 include_str!("shaders/fused_head_norm_rope_bf16.metal");
409 sources.insert("fused_head_norm_rope_bf16".into(), fused_hnr_bf16_src);
410 sources.insert("fused_head_norm_rope_batch_bf16".into(), fused_hnr_bf16_src);
411
412 let fused_norm_add_f32_src: &'static str =
414 include_str!("shaders/fused_norm_add_f32.metal");
415 sources.insert("fused_norm_add_f32".into(), fused_norm_add_f32_src);
416 sources.insert("fused_residual_norm_f32".into(), fused_norm_add_f32_src);
417 sources.insert("fused_residual_norm_scalar_f32".into(), fused_norm_add_f32_src);
418 sources.insert("fused_moe_routing_f32".into(), fused_norm_add_f32_src);
419 sources.insert("fused_moe_routing_batch_f32".into(), fused_norm_add_f32_src);
420 sources.insert("fused_norm_add_scalar_f32".into(), fused_norm_add_f32_src);
421 sources.insert("fused_moe_wsum_norm_add_f32".into(), fused_norm_add_f32_src);
422 sources.insert("fused_moe_wsum_dnorm_add_f32".into(), fused_norm_add_f32_src);
423
424 let argsort_src: &'static str = include_str!("shaders/argsort.metal");
426 sources.insert("argsort_desc_f32".into(), argsort_src);
427
428 let gather_src: &'static str = include_str!("shaders/gather.metal");
430 sources.insert("gather_f32".into(), gather_src);
431
432 let kv_cache_copy_src: &'static str =
434 include_str!("shaders/kv_cache_copy.metal");
435 sources.insert("kv_cache_copy".into(), kv_cache_copy_src);
436 sources.insert("kv_cache_copy_f32".into(), kv_cache_copy_src);
437
438 let copy_src: &'static str = include_str!("shaders/copy.metal");
440 sources.insert("strided_copy_f32".into(), copy_src);
441 sources.insert("offset_copy_f32".into(), copy_src);
442
443 let dense_gemm_src: &'static str = include_str!("shaders/dense_gemm.metal");
445 sources.insert("dense_gemm_f16".into(), dense_gemm_src);
446 sources.insert("dense_matvec_f16".into(), dense_gemm_src);
447 sources.insert("dense_matvec_f16w_f32io".into(), dense_gemm_src);
448 sources.insert("dense_matvec_bf16w_f32io".into(), dense_gemm_src);
450 sources.insert("dense_matvec_f32".into(), dense_gemm_src);
452
453 let fwht_src: &'static str = include_str!("shaders/fwht_standalone.metal");
455 sources.insert("fwht_standalone_f32_d256".into(), fwht_src);
456 sources.insert("fwht_standalone_f32_d512".into(), fwht_src);
457 sources.insert("fwht_sign_premult_f32_d256".into(), fwht_src);
459 sources.insert("fwht_sign_premult_f32_d512".into(), fwht_src);
460 sources.insert("fwht_sign_undo_f32_d256".into(), fwht_src);
461 sources.insert("fwht_sign_undo_f32_d512".into(), fwht_src);
462
463 let hq_fast_src: &'static str = include_str!("shaders/hadamard_quantize_kv_fast.metal");
465 sources.insert("hadamard_quantize_kv_fast_d256".into(), hq_fast_src);
466 sources.insert("hadamard_quantize_kv_fast_d512".into(), hq_fast_src);
467 sources.insert("hadamard_quantize_kv_hb_d256".into(), hq_fast_src);
469 sources.insert("hadamard_quantize_kv_hb_d512".into(), hq_fast_src);
470
471 let tq_dq_src: &'static str = include_str!("shaders/tq_dequantize_kv.metal");
473 sources.insert("tq_dequantize_kv".into(), tq_dq_src);
474 sources.insert("tq_dequantize_hb_kv".into(), tq_dq_src);
476
477 let tq_hb_src: &'static str = include_str!("shaders/flash_attn_vec_tq_hb.metal");
479 sources.insert("flash_attn_vec_tq_hb_dk256".into(), tq_hb_src);
480 sources.insert("flash_attn_vec_tq_hb_dk512".into(), tq_hb_src);
481
482 let argmax_src: &'static str = include_str!("shaders/argmax.metal");
484 sources.insert("argmax_f32".into(), argmax_src);
485 let softmax_sample_src: &'static str =
486 include_str!("shaders/softmax_sample.metal");
487 sources.insert("softmax_sample_f32".into(), softmax_sample_src);
488 let top_k_src: &'static str = include_str!("shaders/top_k.metal");
490 sources.insert("top_k_f32".into(), top_k_src);
491
492 let moe_stk_src: &'static str =
495 include_str!("shaders/moe_softmax_topk.metal");
496 sources.insert("moe_softmax_topk_f32".into(), moe_stk_src);
497 let moe_wr_src: &'static str =
498 include_str!("shaders/moe_weighted_reduce.metal");
499 sources.insert("moe_weighted_reduce_f32".into(), moe_wr_src);
500 let sdpa_decode_src: &'static str =
501 include_str!("shaders/sdpa_decode.metal");
502 sources.insert("sdpa_decode".into(), sdpa_decode_src);
503
504 Self {
505 cache: HashMap::new(),
506 sources,
507 }
508 }
509
510 pub fn register_source(&mut self, name: impl Into<String>, source: &'static str) {
513 let name = name.into();
514 self.cache.remove(&name);
516 self.sources.insert(name, source);
517 }
518
519 pub fn get_pipeline(
531 &mut self,
532 name: &str,
533 device: &metal::DeviceRef,
534 ) -> Result<&ComputePipelineState> {
535 if !self.cache.contains_key(name) {
536 let source = self.sources.get(name).ok_or_else(|| {
538 MlxError::KernelNotFound(name.to_string())
539 })?;
540
541 let compile_opts = metal::CompileOptions::new();
542 let library = device
543 .new_library_with_source(source, &compile_opts)
544 .map_err(|msg| MlxError::ShaderCompilationError {
545 name: name.to_string(),
546 message: msg,
547 })?;
548
549 let function = library
550 .get_function(name, None)
551 .map_err(|msg| MlxError::ShaderCompilationError {
552 name: name.to_string(),
553 message: msg,
554 })?;
555
556 let pipeline = device
557 .new_compute_pipeline_state_with_function(&function)
558 .map_err(|msg| MlxError::ShaderCompilationError {
559 name: name.to_string(),
560 message: msg,
561 })?;
562
563 self.cache.insert(name.to_string(), pipeline);
564 }
565
566 self.cache.get(name).ok_or_else(|| {
569 MlxError::KernelNotFound(name.to_string())
570 })
571 }
572
573 pub fn get_pipeline_with_constants(
595 &mut self,
596 name: &str,
597 device: &metal::DeviceRef,
598 bool_constants: &[(usize, bool)],
599 int_constants: &[(usize, i32)],
600 ) -> Result<&ComputePipelineState> {
601 let mut cache_key = name.to_string();
606 for &(index, value) in bool_constants {
607 cache_key.push('|');
608 cache_key.push_str(&index.to_string());
609 cache_key.push_str(if value { ":b1" } else { ":b0" });
610 }
611 for &(index, value) in int_constants {
612 cache_key.push('|');
613 cache_key.push_str(&index.to_string());
614 cache_key.push(':');
615 cache_key.push('i');
616 cache_key.push_str(&value.to_string());
617 }
618
619 if !self.cache.contains_key(&cache_key) {
620 let source = self.sources.get(name).ok_or_else(|| {
622 MlxError::KernelNotFound(name.to_string())
623 })?;
624
625 let compile_opts = metal::CompileOptions::new();
626 let library = device
627 .new_library_with_source(source, &compile_opts)
628 .map_err(|msg| MlxError::ShaderCompilationError {
629 name: name.to_string(),
630 message: msg,
631 })?;
632
633 let fcv = FunctionConstantValues::new();
638
639 for &(index, value) in bool_constants {
640 let v: u8 = if value { 1 } else { 0 };
643 fcv.set_constant_value_at_index(
644 (&v as *const u8).cast::<std::ffi::c_void>(),
645 MTLDataType::Bool,
646 index as u64,
647 );
648 }
649
650 for &(index, value) in int_constants {
651 fcv.set_constant_value_at_index(
655 (&value as *const i32).cast::<std::ffi::c_void>(),
656 MTLDataType::Int,
657 index as u64,
658 );
659 }
660
661 let function = library
662 .get_function(name, Some(fcv))
663 .map_err(|msg| MlxError::ShaderCompilationError {
664 name: name.to_string(),
665 message: msg,
666 })?;
667
668 let pipeline = device
669 .new_compute_pipeline_state_with_function(&function)
670 .map_err(|msg| MlxError::ShaderCompilationError {
671 name: name.to_string(),
672 message: msg,
673 })?;
674
675 self.cache.insert(cache_key.clone(), pipeline);
676 }
677
678 self.cache.get(&cache_key).ok_or_else(|| {
679 MlxError::KernelNotFound(name.to_string())
680 })
681 }
682
683 pub fn get_pipeline_with_bool_constants(
701 &mut self,
702 name: &str,
703 device: &metal::DeviceRef,
704 bool_constants: &[(usize, bool)],
705 ) -> Result<&ComputePipelineState> {
706 self.get_pipeline_with_constants(name, device, bool_constants, &[])
707 }
708
709 pub fn is_cached(&self, name: &str) -> bool {
711 self.cache.contains_key(name)
712 }
713
714 pub fn cached_count(&self) -> usize {
716 self.cache.len()
717 }
718
719 pub fn source_count(&self) -> usize {
721 self.sources.len()
722 }
723}
724
725impl Default for KernelRegistry {
726 fn default() -> Self {
727 Self::new()
728 }
729}
730
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal kernel with a single `int` function constant (`test_N`, index 100) that
    // writes the specialised value to `out[0]`. The test below only compiles it; it is
    // never dispatched.
    const INT_FC_TEST_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;

constant int test_N [[function_constant(100)]];

kernel void int_fc_test_kernel(
    device int* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]])
{
    if (tid == 0) {
        out[0] = test_N;
    }
}
"#;

    #[test]
    // Needs a real Metal device. Checks that int function constants are part of the
    // pipeline cache key (N=4 and N=8 get separate entries, a repeated N=4 request is a
    // cache hit), and that the bool-only wrapper still compiles and caches a pipeline
    // when given an empty constant slice.
    fn test_int_fc_distinct_pipelines_and_bool_compat() {
        let device = metal::Device::system_default()
            .expect("no Metal device — run on Apple Silicon or x86 Mac with Metal support");

        let mut registry = KernelRegistry::new();

        // The registry looks the Metal function up by the registered name, so the source
        // is registered under the kernel function's own name.
        registry.register_source("int_fc_test_kernel", INT_FC_TEST_SHADER);

        // Each returned `&ComputePipelineState` mutably borrows `registry`, so it is
        // immediately turned into a raw pointer that is only compared, never dereferenced.
        // NOTE(review): these are addresses of entries inside the registry's HashMap, not
        // of the underlying Metal pipeline objects. The identity checks below hold only
        // while the map does not reallocate between calls — confirm if more inserts are
        // added before the comparisons.
        let p4_ptr = registry
            .get_pipeline_with_constants(
                "int_fc_test_kernel",
                &device,
                &[],             // no bool constants
                &[(100, 4_i32)], // test_N = 4
            )
            .expect("pipeline N=4 should compile") as *const _;

        // Baseline for the cache-size assertions below.
        let count_after_n4 = registry.cached_count();

        let p8_ptr = registry
            .get_pipeline_with_constants(
                "int_fc_test_kernel",
                &device,
                &[],
                &[(100, 8_i32)],
            )
            .expect("pipeline N=8 should compile") as *const _;

        // A different int constant value must produce a different cache key.
        assert_eq!(
            registry.cached_count(),
            count_after_n4 + 1,
            "N=8 must produce a new cache entry"
        );

        assert_ne!(
            p4_ptr, p8_ptr,
            "N=4 and N=8 specialisations must be separate ComputePipelineState objects"
        );

        // Requesting N=4 again must reuse the cached pipeline rather than compile anew.
        let p4_again_ptr = registry
            .get_pipeline_with_constants(
                "int_fc_test_kernel",
                &device,
                &[],
                &[(100, 4_i32)],
            )
            .expect("pipeline N=4 cache hit should succeed") as *const _;

        assert_eq!(
            registry.cached_count(),
            count_after_n4 + 1,
            "repeated N=4 call must be a cache hit, not a new entry"
        );
        assert_eq!(
            p4_ptr, p4_again_ptr,
            "repeated N=4 call must return the same pipeline pointer"
        );

        // Kernel without any function constants, used to check the bool-only wrapper with
        // an empty constant slice.
        const BARE_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void bare_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
    if (tid == 0) { out[0] = 42; }
}
"#;
        registry.register_source("bare_kernel", BARE_SHADER);

        let count_before_bool = registry.cached_count();
        let _bool_pipeline = registry
            .get_pipeline_with_bool_constants("bare_kernel", &device, &[])
            .expect("bool-constants wrapper with empty slice must succeed");

        assert_eq!(
            registry.cached_count(),
            count_before_bool + 1,
            "bool-constants wrapper must insert one new cache entry"
        );
    }
}