1use std::collections::HashMap;
9
10use metal::{ComputePipelineState, FunctionConstantValues, MTLDataType};
11
12use crate::error::{MlxError, Result};
13
/// Registry of Metal compute kernels.
///
/// Holds the MSL source text registered for each kernel name and caches the
/// `ComputePipelineState` objects built from those sources.
pub struct KernelRegistry {
    /// Built pipelines. The key is the kernel name, optionally followed by a
    /// `|index:value` suffix per function constant for specialised variants.
    cache: HashMap<String, ComputePipelineState>,
    /// MSL source text looked up by kernel name when a pipeline has to be built.
    sources: HashMap<String, &'static str>,
}
44
45impl KernelRegistry {
46 pub fn new() -> Self {
50 let mut sources = HashMap::new();
51
52 sources.insert(
54 "placeholder".into(),
55 include_str!("shaders/placeholder.metal"),
56 );
57 sources.insert(
58 "quantized_matmul".into(),
59 include_str!("shaders/quantized_matmul.metal"),
60 );
61 sources.insert(
62 "quantized_matmul_simd".into(),
63 include_str!("shaders/quantized_matmul.metal"),
64 );
65 sources.insert(
66 "quantized_matmul_simd_bf16".into(),
67 include_str!("shaders/quantized_matmul.metal"),
68 );
69 sources.insert(
70 "quantized_matmul_simd_bf16_expert".into(),
71 include_str!("shaders/quantized_matmul.metal"),
72 );
73
74 let ggml_src: &'static str =
76 include_str!("shaders/quantized_matmul_ggml.metal");
77 sources.insert("kernel_mul_mv_q4_0_f32".into(), ggml_src);
78 sources.insert("kernel_mul_mv_q8_0_f32".into(), ggml_src);
79 sources.insert("kernel_mul_mv_q6_K_f32".into(), ggml_src);
80
81 let ggml_mm_src: &'static str =
87 include_str!("shaders/quantized_matmul_mm.metal");
88 sources.insert("kernel_mul_mm_q4_0_f32".into(), ggml_mm_src);
89 sources.insert("kernel_mul_mm_q8_0_f32".into(), ggml_mm_src);
90 sources.insert("kernel_mul_mm_q6_K_f32".into(), ggml_mm_src);
91
92 let ggml_mm_tensor_src: &'static str =
103 include_str!("shaders/quantized_matmul_mm_tensor.metal");
104 sources.insert("kernel_mul_mm_q4_0_tensor_f32".into(), ggml_mm_tensor_src);
105 sources.insert("kernel_mul_mm_q4_0_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
106 sources.insert("kernel_mul_mm_q6_K_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
107 sources.insert("kernel_mul_mm_q8_0_tensor_f32".into(), ggml_mm_tensor_src);
108 sources.insert("kernel_mul_mm_q6_K_tensor_f32".into(), ggml_mm_tensor_src);
109
110 let dense_mm_bf16_tensor_src: &'static str =
117 include_str!("shaders/dense_mm_bf16_tensor.metal");
118 sources.insert("hf2q_dense_mm_bf16_f32_tensor".into(), dense_mm_bf16_tensor_src);
119
120 let dense_gemv_bf16_src: &'static str =
127 include_str!("shaders/dense_gemv_bf16.metal");
128 sources.insert("hf2q_dense_gemv_bf16_f32_4".into(), dense_gemv_bf16_src);
129
130 let scale_mask_softmax_src: &'static str =
136 include_str!("shaders/scale_mask_softmax.metal");
137 sources.insert("scale_mask_softmax_f32".into(), scale_mask_softmax_src);
138
139 sources.insert(
141 "quantized_matmul_id".into(),
142 include_str!("shaders/quantized_matmul_id.metal"),
143 );
144
145 let ggml_id_src: &'static str =
147 include_str!("shaders/quantized_matmul_id_ggml.metal");
148 sources.insert("kernel_mul_mv_id_q4_0_f32".into(), ggml_id_src);
149 sources.insert("kernel_mul_mv_id_q8_0_f32".into(), ggml_id_src);
150 sources.insert("kernel_mul_mv_id_q5_K_f32".into(), ggml_id_src);
151 sources.insert("kernel_mul_mv_id_q6_K_f32".into(), ggml_id_src);
152
153 let ggml_id_mm_src: &'static str =
161 include_str!("shaders/quantized_matmul_id_mm.metal");
162 sources.insert("kernel_mul_mm_id_map0_ne20_1".into(), ggml_id_mm_src);
163 sources.insert("kernel_mul_mm_id_map0_ne20_8".into(), ggml_id_mm_src);
164 sources.insert("kernel_mul_mm_id_q4_0_f32".into(), ggml_id_mm_src);
165 sources.insert("kernel_mul_mm_id_q8_0_f32".into(), ggml_id_mm_src);
166 sources.insert("kernel_mul_mm_id_q6_K_f32".into(), ggml_id_mm_src);
167
168 let ggml_id_mm_tensor_src: &'static str =
174 include_str!("shaders/quantized_matmul_id_mm_tensor.metal");
175 sources.insert("kernel_mul_mm_id_q4_0_tensor_f32".into(), ggml_id_mm_tensor_src);
176 sources.insert("kernel_mul_mm_id_q8_0_tensor_f32".into(), ggml_id_mm_tensor_src);
177 sources.insert("kernel_mul_mm_id_q6_K_tensor_f32".into(), ggml_id_mm_tensor_src);
178
179 let embedding_src: &'static str = include_str!("shaders/embedding.metal");
181 sources.insert("embedding_gather_4bit".into(), embedding_src);
182 sources.insert("embedding_gather_6bit".into(), embedding_src);
183
184 let moe_gate_src: &'static str = include_str!("shaders/moe_gate.metal");
186 sources.insert("moe_gate".into(), moe_gate_src);
187
188 let moe_dispatch_src: &'static str = include_str!("shaders/moe_dispatch.metal");
190 sources.insert("fused_gelu_mul".into(), moe_dispatch_src);
191 sources.insert("moe_swiglu_fused".into(), moe_dispatch_src);
192 sources.insert("moe_swiglu_batch".into(), moe_dispatch_src);
193 sources.insert("moe_swiglu_seq".into(), moe_dispatch_src);
194 sources.insert("moe_accumulate".into(), moe_dispatch_src);
195 sources.insert("moe_weighted_sum".into(), moe_dispatch_src);
196 sources.insert("moe_weighted_sum_seq".into(), moe_dispatch_src);
197 sources.insert("zero_buffer".into(), moe_dispatch_src);
198 sources.insert("naive_matvec_f32".into(), moe_dispatch_src);
199 sources.insert("moe_gather_topk_weights".into(), moe_dispatch_src);
200 sources.insert("fused_gelu_mul_bf16".into(), moe_dispatch_src);
202 sources.insert("moe_swiglu_seq_bf16".into(), moe_dispatch_src);
203 sources.insert("moe_weighted_sum_seq_bf16_input".into(), moe_dispatch_src);
204
205 let kv_cache_src: &'static str = include_str!("shaders/kv_cache_copy.metal");
207 sources.insert("kv_cache_copy_batch_f32".into(), kv_cache_src);
208 sources.insert("kv_cache_copy_batch_f32_to_f16".into(), kv_cache_src);
209 sources.insert("kv_cache_copy_seq_f32".into(), kv_cache_src);
210 sources.insert("kv_cache_copy_seq_f32_to_f16".into(), kv_cache_src);
211 sources.insert("kv_cache_copy_seq_f32_kv_dual".into(), kv_cache_src);
213 sources.insert("kv_cache_copy_seq_f32_to_f16_kv_dual".into(), kv_cache_src);
214 sources.insert("kv_cache_copy_seq_bf16".into(), kv_cache_src);
216
217 let elementwise_src: &'static str = include_str!("shaders/elementwise.metal");
219 sources.insert("elementwise_add_f32".into(), elementwise_src);
220 sources.insert("elementwise_add_f16".into(), elementwise_src);
221 sources.insert("elementwise_mul_f32".into(), elementwise_src);
222 sources.insert("elementwise_mul_f16".into(), elementwise_src);
223 sources.insert("elementwise_add_bf16".into(), elementwise_src);
224 sources.insert("elementwise_mul_bf16".into(), elementwise_src);
225 sources.insert("cast_f16_to_f32".into(), elementwise_src);
226 sources.insert("cast_f32_to_f16".into(), elementwise_src);
227 sources.insert("cast_bf16_to_f32".into(), elementwise_src);
228 sources.insert("cast_f32_to_bf16".into(), elementwise_src);
229 sources.insert("scalar_mul_bf16".into(), elementwise_src);
230 sources.insert("scalar_mul_f32".into(), elementwise_src);
231 sources.insert("embedding_gather_scale_f32".into(), elementwise_src);
232 sources.insert("embedding_gather_scale_batch_f32".into(), elementwise_src);
233 sources.insert("permute_021_bf16".into(), elementwise_src);
234 sources.insert("transpose_last2_bf16".into(), elementwise_src);
235 sources.insert("permute_021_f32".into(), elementwise_src);
236 sources.insert("permute_021_bf16_to_f32".into(), elementwise_src);
237 sources.insert("transpose_2d_f32".into(), elementwise_src);
238 sources.insert("transpose_2d_f16".into(), elementwise_src);
239
240 let sdpa_src: &'static str = include_str!("shaders/sdpa.metal");
242 sources.insert("sdpa".into(), sdpa_src);
243 sources.insert("sdpa_bf16".into(), sdpa_src);
244 let sdpa_sliding_src: &'static str = include_str!("shaders/sdpa_sliding.metal");
245 sources.insert("sdpa_sliding".into(), sdpa_sliding_src);
246 sources.insert("sdpa_sliding_bf16".into(), sdpa_sliding_src);
247
248 let flash_attn_prefill_src: &'static str =
253 include_str!("shaders/flash_attn_prefill.metal");
254 sources.insert(
256 "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskfloat32".into(),
257 flash_attn_prefill_src,
258 );
259 sources.insert(
260 "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
261 flash_attn_prefill_src,
262 );
263 sources.insert(
264 "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbfloat16".into(),
265 flash_attn_prefill_src,
266 );
267 sources.insert(
268 "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
269 flash_attn_prefill_src,
270 );
271 sources.insert(
272 "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskfloat16".into(),
273 flash_attn_prefill_src,
274 );
275 sources.insert(
276 "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
277 flash_attn_prefill_src,
278 );
279 sources.insert(
283 "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbfloat16".into(),
284 flash_attn_prefill_src,
285 );
286 sources.insert(
287 "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
288 flash_attn_prefill_src,
289 );
290 sources.insert(
291 "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskfloat16".into(),
292 flash_attn_prefill_src,
293 );
294 sources.insert(
295 "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
296 flash_attn_prefill_src,
297 );
298
299 let flash_attn_vec_src: &'static str =
302 include_str!("shaders/flash_attn_vec.metal");
303 sources.insert("flash_attn_vec_dk256".into(), flash_attn_vec_src);
304 sources.insert("flash_attn_vec_dk512".into(), flash_attn_vec_src);
305 sources.insert("flash_attn_vec_reduce_dk256".into(), flash_attn_vec_src);
306 sources.insert("flash_attn_vec_reduce_dk512".into(), flash_attn_vec_src);
307 sources.insert("flash_attn_vec_f16kv_dk256".into(), flash_attn_vec_src);
309 sources.insert("flash_attn_vec_f16kv_dk512".into(), flash_attn_vec_src);
310
311 let rope_src: &'static str = include_str!("shaders/rope.metal");
313 sources.insert("rope_f32".into(), rope_src);
314 sources.insert("rope_f16".into(), rope_src);
315 sources.insert("rope_bf16".into(), rope_src);
316 sources.insert("rope_neox_bf16".into(), rope_src);
317 sources.insert("rope_neox_f32".into(), rope_src);
318 let rms_norm_src: &'static str = include_str!("shaders/rms_norm.metal");
319 sources.insert("rms_norm_f32".into(), rms_norm_src);
320 sources.insert("rms_norm_f16".into(), rms_norm_src);
321 sources.insert("rms_norm_bf16".into(), rms_norm_src);
322 sources.insert("rms_norm_no_scale_bf16".into(), rms_norm_src);
323 sources.insert("rms_norm_no_scale_f32".into(), rms_norm_src);
324 sources.insert("rms_norm_no_scale_f32_dual".into(), rms_norm_src);
325 sources.insert("rms_norm_f32_triple".into(), rms_norm_src);
326 sources.insert("fused_post_attn_triple_norm_f32".into(), rms_norm_src);
327 sources.insert("rms_norm_no_scale_f32_dual_perm".into(), rms_norm_src);
328 sources.insert("rms_norm_mul_f32".into(), rms_norm_src);
330 sources.insert("rms_norm_mul_f16".into(), rms_norm_src);
331 sources.insert("rms_norm_mul_bf16".into(), rms_norm_src);
332 let l2_norm_src: &'static str = include_str!("shaders/l2_norm.metal");
334 sources.insert("l2_norm_f32".into(), l2_norm_src);
335 sources.insert("l2_norm_f16".into(), l2_norm_src);
336 sources.insert("l2_norm_bf16".into(), l2_norm_src);
337 let cumsum_src: &'static str = include_str!("shaders/cumsum.metal");
339 sources.insert("cumsum_f32".into(), cumsum_src);
340 sources.insert("cumsum_bf16".into(), cumsum_src);
341 let ssm_conv_src: &'static str = include_str!("shaders/ssm_conv.metal");
343 sources.insert("ssm_conv_forward_f32".into(), ssm_conv_src);
344 sources.insert("ssm_conv_forward_bf16".into(), ssm_conv_src);
345 sources.insert("ssm_conv_state_update_f32".into(), ssm_conv_src);
346 sources.insert("ssm_conv_state_update_bf16".into(), ssm_conv_src);
347 let tri_solve_src: &'static str = include_str!("shaders/tri_solve.metal");
349 sources.insert("tri_solve_lower_unit_f32".into(), tri_solve_src);
350 sources.insert("tri_solve_lower_unit_bf16".into(), tri_solve_src);
351 let rope_multi_src: &'static str = include_str!("shaders/rope_multi.metal");
353 sources.insert("rope_multi_f32".into(), rope_multi_src);
354 sources.insert("rope_multi_bf16".into(), rope_multi_src);
355 let gdn_src: &'static str = include_str!("shaders/gated_delta_net.metal");
357 sources.insert("gated_delta_net_f32".into(), gdn_src);
358 let sigmoid_mul_src: &'static str = include_str!("shaders/sigmoid_mul.metal");
360 sources.insert("sigmoid_mul_f32".into(), sigmoid_mul_src);
361 sources.insert("sigmoid_mul_bf16".into(), sigmoid_mul_src);
362 let silu_mul_src: &'static str = include_str!("shaders/silu_mul.metal");
363 sources.insert("silu_mul_f32".into(), silu_mul_src);
364 let compute_g_beta_src: &'static str = include_str!("shaders/compute_g_beta.metal");
365 sources.insert("compute_g_beta_f32".into(), compute_g_beta_src);
366 let ssm_norm_gate_src: &'static str = include_str!("shaders/ssm_norm_gate.metal");
367 sources.insert("ssm_norm_gate_f32".into(), ssm_norm_gate_src);
368 let gelu_src: &'static str = include_str!("shaders/gelu.metal");
369 sources.insert("gelu_f32".into(), gelu_src);
370 sources.insert("gelu_f16".into(), gelu_src);
371 sources.insert("gelu_bf16".into(), gelu_src);
372 let softmax_src: &'static str = include_str!("shaders/softmax.metal");
373 sources.insert("softmax_f32".into(), softmax_src);
374 sources.insert("softmax_f16".into(), softmax_src);
375 sources.insert("softmax_bf16".into(), softmax_src);
376 let softcap_src: &'static str = include_str!("shaders/softcap.metal");
377 sources.insert("softcap_f32".into(), softcap_src);
378 sources.insert("softcap_f16".into(), softcap_src);
379 sources.insert("softcap_bf16".into(), softcap_src);
380
381 let fused_norm_add_src: &'static str =
384 include_str!("shaders/fused_norm_add_bf16.metal");
385 sources.insert("fused_norm_add_bf16".into(), fused_norm_add_src);
386 sources.insert("fused_norm_add_no_weight_bf16".into(), fused_norm_add_src);
387
388 let fused_hnr_f32_src: &'static str =
390 include_str!("shaders/fused_head_norm_rope_f32.metal");
391 sources.insert("fused_head_norm_rope_f32".into(), fused_hnr_f32_src);
392
393 let fused_hnr_bf16_src: &'static str =
396 include_str!("shaders/fused_head_norm_rope_bf16.metal");
397 sources.insert("fused_head_norm_rope_bf16".into(), fused_hnr_bf16_src);
398 sources.insert("fused_head_norm_rope_batch_bf16".into(), fused_hnr_bf16_src);
399
400 let fused_norm_add_f32_src: &'static str =
402 include_str!("shaders/fused_norm_add_f32.metal");
403 sources.insert("fused_norm_add_f32".into(), fused_norm_add_f32_src);
404 sources.insert("fused_residual_norm_f32".into(), fused_norm_add_f32_src);
405 sources.insert("fused_residual_norm_scalar_f32".into(), fused_norm_add_f32_src);
406 sources.insert("fused_moe_routing_f32".into(), fused_norm_add_f32_src);
407 sources.insert("fused_moe_routing_batch_f32".into(), fused_norm_add_f32_src);
408 sources.insert("fused_norm_add_scalar_f32".into(), fused_norm_add_f32_src);
409 sources.insert("fused_moe_wsum_norm_add_f32".into(), fused_norm_add_f32_src);
410 sources.insert("fused_moe_wsum_dnorm_add_f32".into(), fused_norm_add_f32_src);
411
412 let argsort_src: &'static str = include_str!("shaders/argsort.metal");
414 sources.insert("argsort_desc_f32".into(), argsort_src);
415
416 let gather_src: &'static str = include_str!("shaders/gather.metal");
418 sources.insert("gather_f32".into(), gather_src);
419
420 let kv_cache_copy_src: &'static str =
422 include_str!("shaders/kv_cache_copy.metal");
423 sources.insert("kv_cache_copy".into(), kv_cache_copy_src);
424 sources.insert("kv_cache_copy_f32".into(), kv_cache_copy_src);
425
426 let copy_src: &'static str = include_str!("shaders/copy.metal");
428 sources.insert("strided_copy_f32".into(), copy_src);
429 sources.insert("offset_copy_f32".into(), copy_src);
430
431 let dense_gemm_src: &'static str = include_str!("shaders/dense_gemm.metal");
433 sources.insert("dense_gemm_f16".into(), dense_gemm_src);
434 sources.insert("dense_matvec_f16".into(), dense_gemm_src);
435 sources.insert("dense_matvec_f16w_f32io".into(), dense_gemm_src);
436 sources.insert("dense_matvec_bf16w_f32io".into(), dense_gemm_src);
438 sources.insert("dense_matvec_f32".into(), dense_gemm_src);
440
441 let fwht_src: &'static str = include_str!("shaders/fwht_standalone.metal");
443 sources.insert("fwht_standalone_f32_d256".into(), fwht_src);
444 sources.insert("fwht_standalone_f32_d512".into(), fwht_src);
445 sources.insert("fwht_sign_premult_f32_d256".into(), fwht_src);
447 sources.insert("fwht_sign_premult_f32_d512".into(), fwht_src);
448 sources.insert("fwht_sign_undo_f32_d256".into(), fwht_src);
449 sources.insert("fwht_sign_undo_f32_d512".into(), fwht_src);
450
451 let hq_fast_src: &'static str = include_str!("shaders/hadamard_quantize_kv_fast.metal");
453 sources.insert("hadamard_quantize_kv_fast_d256".into(), hq_fast_src);
454 sources.insert("hadamard_quantize_kv_fast_d512".into(), hq_fast_src);
455 sources.insert("hadamard_quantize_kv_hb_d256".into(), hq_fast_src);
457 sources.insert("hadamard_quantize_kv_hb_d512".into(), hq_fast_src);
458
459 let tq_dq_src: &'static str = include_str!("shaders/tq_dequantize_kv.metal");
461 sources.insert("tq_dequantize_kv".into(), tq_dq_src);
462 sources.insert("tq_dequantize_hb_kv".into(), tq_dq_src);
464
465 let tq_hb_src: &'static str = include_str!("shaders/flash_attn_vec_tq_hb.metal");
467 sources.insert("flash_attn_vec_tq_hb_dk256".into(), tq_hb_src);
468 sources.insert("flash_attn_vec_tq_hb_dk512".into(), tq_hb_src);
469
470 let argmax_src: &'static str = include_str!("shaders/argmax.metal");
472 sources.insert("argmax_f32".into(), argmax_src);
473 let softmax_sample_src: &'static str =
474 include_str!("shaders/softmax_sample.metal");
475 sources.insert("softmax_sample_f32".into(), softmax_sample_src);
476 let top_k_src: &'static str = include_str!("shaders/top_k.metal");
478 sources.insert("top_k_f32".into(), top_k_src);
479
480 let moe_stk_src: &'static str =
483 include_str!("shaders/moe_softmax_topk.metal");
484 sources.insert("moe_softmax_topk_f32".into(), moe_stk_src);
485 let moe_wr_src: &'static str =
486 include_str!("shaders/moe_weighted_reduce.metal");
487 sources.insert("moe_weighted_reduce_f32".into(), moe_wr_src);
488 let sdpa_decode_src: &'static str =
489 include_str!("shaders/sdpa_decode.metal");
490 sources.insert("sdpa_decode".into(), sdpa_decode_src);
491
492 Self {
493 cache: HashMap::new(),
494 sources,
495 }
496 }
497
498 pub fn register_source(&mut self, name: impl Into<String>, source: &'static str) {
501 let name = name.into();
502 self.cache.remove(&name);
504 self.sources.insert(name, source);
505 }
506
507 pub fn get_pipeline(
519 &mut self,
520 name: &str,
521 device: &metal::DeviceRef,
522 ) -> Result<&ComputePipelineState> {
523 if !self.cache.contains_key(name) {
524 let source = self.sources.get(name).ok_or_else(|| {
526 MlxError::KernelNotFound(name.to_string())
527 })?;
528
529 let compile_opts = metal::CompileOptions::new();
530 let library = device
531 .new_library_with_source(source, &compile_opts)
532 .map_err(|msg| MlxError::ShaderCompilationError {
533 name: name.to_string(),
534 message: msg,
535 })?;
536
537 let function = library
538 .get_function(name, None)
539 .map_err(|msg| MlxError::ShaderCompilationError {
540 name: name.to_string(),
541 message: msg,
542 })?;
543
544 let pipeline = device
545 .new_compute_pipeline_state_with_function(&function)
546 .map_err(|msg| MlxError::ShaderCompilationError {
547 name: name.to_string(),
548 message: msg,
549 })?;
550
551 self.cache.insert(name.to_string(), pipeline);
552 }
553
554 self.cache.get(name).ok_or_else(|| {
557 MlxError::KernelNotFound(name.to_string())
558 })
559 }
560
561 pub fn get_pipeline_with_constants(
583 &mut self,
584 name: &str,
585 device: &metal::DeviceRef,
586 bool_constants: &[(usize, bool)],
587 int_constants: &[(usize, i32)],
588 ) -> Result<&ComputePipelineState> {
589 let mut cache_key = name.to_string();
594 for &(index, value) in bool_constants {
595 cache_key.push('|');
596 cache_key.push_str(&index.to_string());
597 cache_key.push_str(if value { ":b1" } else { ":b0" });
598 }
599 for &(index, value) in int_constants {
600 cache_key.push('|');
601 cache_key.push_str(&index.to_string());
602 cache_key.push(':');
603 cache_key.push('i');
604 cache_key.push_str(&value.to_string());
605 }
606
607 if !self.cache.contains_key(&cache_key) {
608 let source = self.sources.get(name).ok_or_else(|| {
610 MlxError::KernelNotFound(name.to_string())
611 })?;
612
613 let compile_opts = metal::CompileOptions::new();
614 let library = device
615 .new_library_with_source(source, &compile_opts)
616 .map_err(|msg| MlxError::ShaderCompilationError {
617 name: name.to_string(),
618 message: msg,
619 })?;
620
621 let fcv = FunctionConstantValues::new();
626
627 for &(index, value) in bool_constants {
628 let v: u8 = if value { 1 } else { 0 };
631 fcv.set_constant_value_at_index(
632 (&v as *const u8).cast::<std::ffi::c_void>(),
633 MTLDataType::Bool,
634 index as u64,
635 );
636 }
637
638 for &(index, value) in int_constants {
639 fcv.set_constant_value_at_index(
643 (&value as *const i32).cast::<std::ffi::c_void>(),
644 MTLDataType::Int,
645 index as u64,
646 );
647 }
648
649 let function = library
650 .get_function(name, Some(fcv))
651 .map_err(|msg| MlxError::ShaderCompilationError {
652 name: name.to_string(),
653 message: msg,
654 })?;
655
656 let pipeline = device
657 .new_compute_pipeline_state_with_function(&function)
658 .map_err(|msg| MlxError::ShaderCompilationError {
659 name: name.to_string(),
660 message: msg,
661 })?;
662
663 self.cache.insert(cache_key.clone(), pipeline);
664 }
665
666 self.cache.get(&cache_key).ok_or_else(|| {
667 MlxError::KernelNotFound(name.to_string())
668 })
669 }
670
671 pub fn get_pipeline_with_bool_constants(
689 &mut self,
690 name: &str,
691 device: &metal::DeviceRef,
692 bool_constants: &[(usize, bool)],
693 ) -> Result<&ComputePipelineState> {
694 self.get_pipeline_with_constants(name, device, bool_constants, &[])
695 }
696
697 pub fn is_cached(&self, name: &str) -> bool {
699 self.cache.contains_key(name)
700 }
701
702 pub fn cached_count(&self) -> usize {
704 self.cache.len()
705 }
706
707 pub fn source_count(&self) -> usize {
709 self.sources.len()
710 }
711}
712
713impl Default for KernelRegistry {
714 fn default() -> Self {
715 Self::new()
716 }
717}
718
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal MSL kernel whose output is the `int` function constant declared
    /// at index 100, so different constant values give different
    /// specialisations of the same function.
    const INT_FC_TEST_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;

constant int test_N [[function_constant(100)]];

kernel void int_fc_test_kernel(
    device int* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]])
{
    if (tid == 0) {
        out[0] = test_N;
    }
}
"#;

    /// Checks that integer function constants yield one cached pipeline per
    /// distinct value, that a repeated request is a cache hit, and that the
    /// bool-only wrapper works with an empty constant list.
    ///
    /// Needs a real Metal device; the test panics via `expect` without one.
    #[test]
    fn test_int_fc_distinct_pipelines_and_bool_compat() {
        let device = metal::Device::system_default()
            .expect("no Metal device — run on Apple Silicon or x86 Mac with Metal support");

        let mut registry = KernelRegistry::new();

        registry.register_source("int_fc_test_kernel", INT_FC_TEST_SHADER);

        // First specialisation: constant 100 = 4. Only the address is kept
        // (raw pointer) so the registry can be borrowed mutably again below.
        let p4_ptr = registry
            .get_pipeline_with_constants(
                "int_fc_test_kernel",
                &device,
                &[], &[(100, 4_i32)], )
            .expect("pipeline N=4 should compile") as *const _;

        // Baseline for the cache-size assertions that follow.
        let count_after_n4 = registry.cached_count();

        // Second specialisation: same kernel, constant 100 = 8.
        let p8_ptr = registry
            .get_pipeline_with_constants(
                "int_fc_test_kernel",
                &device,
                &[],
                &[(100, 8_i32)],
            )
            .expect("pipeline N=8 should compile") as *const _;

        assert_eq!(
            registry.cached_count(),
            count_after_n4 + 1,
            "N=8 must produce a new cache entry"
        );

        assert_ne!(
            p4_ptr, p8_ptr,
            "N=4 and N=8 specialisations must be separate ComputePipelineState objects"
        );

        // Requesting N=4 again must neither grow the cache nor change the
        // address of the stored pipeline.
        let p4_again_ptr = registry
            .get_pipeline_with_constants(
                "int_fc_test_kernel",
                &device,
                &[],
                &[(100, 4_i32)],
            )
            .expect("pipeline N=4 cache hit should succeed") as *const _;

        assert_eq!(
            registry.cached_count(),
            count_after_n4 + 1,
            "repeated N=4 call must be a cache hit, not a new entry"
        );
        assert_eq!(
            p4_ptr, p4_again_ptr,
            "repeated N=4 call must return the same pipeline pointer"
        );

        // Kernel without any function constants, used to exercise the
        // bool-only wrapper with an empty constant list.
        const BARE_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void bare_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
    if (tid == 0) { out[0] = 42; }
}
"#;
        registry.register_source("bare_kernel", BARE_SHADER);

        let count_before_bool = registry.cached_count();
        let _bool_pipeline = registry
            .get_pipeline_with_bool_constants("bare_kernel", &device, &[])
            .expect("bool-constants wrapper with empty slice must succeed");

        assert_eq!(
            registry.cached_count(),
            count_before_bool + 1,
            "bool-constants wrapper must insert one new cache entry"
        );
    }
}