1use std::collections::HashMap;
9
10use metal::{ComputePipelineState, FunctionConstantValues, MTLDataType};
11
12use crate::error::{MlxError, Result};
13
14pub struct KernelRegistry {
37 cache: HashMap<String, ComputePipelineState>,
39 sources: HashMap<String, &'static str>,
43}
44
45impl KernelRegistry {
46 pub fn new() -> Self {
50 let mut sources = HashMap::new();
51
52 sources.insert(
54 "placeholder".into(),
55 include_str!("shaders/placeholder.metal"),
56 );
57 sources.insert(
58 "quantized_matmul".into(),
59 include_str!("shaders/quantized_matmul.metal"),
60 );
61 sources.insert(
62 "quantized_matmul_simd".into(),
63 include_str!("shaders/quantized_matmul.metal"),
64 );
65 sources.insert(
66 "quantized_matmul_simd_bf16".into(),
67 include_str!("shaders/quantized_matmul.metal"),
68 );
69 sources.insert(
70 "quantized_matmul_simd_bf16_expert".into(),
71 include_str!("shaders/quantized_matmul.metal"),
72 );
73
74 let ggml_src: &'static str =
76 include_str!("shaders/quantized_matmul_ggml.metal");
77 sources.insert("kernel_mul_mv_q4_0_f32".into(), ggml_src);
78 sources.insert("kernel_mul_mv_q8_0_f32".into(), ggml_src);
79 sources.insert("kernel_mul_mv_q6_K_f32".into(), ggml_src);
80
81 let ggml_mm_src: &'static str =
87 include_str!("shaders/quantized_matmul_mm.metal");
88 sources.insert("kernel_mul_mm_q4_0_f32".into(), ggml_mm_src);
89 sources.insert("kernel_mul_mm_q8_0_f32".into(), ggml_mm_src);
90 sources.insert("kernel_mul_mm_q6_K_f32".into(), ggml_mm_src);
91
92 let ggml_mm_tensor_src: &'static str =
103 include_str!("shaders/quantized_matmul_mm_tensor.metal");
104 sources.insert("kernel_mul_mm_q4_0_tensor_f32".into(), ggml_mm_tensor_src);
105 sources.insert("kernel_mul_mm_q4_0_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
106 sources.insert("kernel_mul_mm_q6_K_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
107 sources.insert("kernel_mul_mm_q8_0_tensor_f32".into(), ggml_mm_tensor_src);
108 sources.insert("kernel_mul_mm_q6_K_tensor_f32".into(), ggml_mm_tensor_src);
109
110 let dense_mm_bf16_tensor_src: &'static str =
117 include_str!("shaders/dense_mm_bf16_tensor.metal");
118 sources.insert("hf2q_dense_mm_bf16_f32_tensor".into(), dense_mm_bf16_tensor_src);
119
120 let scale_mask_softmax_src: &'static str =
126 include_str!("shaders/scale_mask_softmax.metal");
127 sources.insert("scale_mask_softmax_f32".into(), scale_mask_softmax_src);
128
129 sources.insert(
131 "quantized_matmul_id".into(),
132 include_str!("shaders/quantized_matmul_id.metal"),
133 );
134
135 let ggml_id_src: &'static str =
137 include_str!("shaders/quantized_matmul_id_ggml.metal");
138 sources.insert("kernel_mul_mv_id_q4_0_f32".into(), ggml_id_src);
139 sources.insert("kernel_mul_mv_id_q8_0_f32".into(), ggml_id_src);
140 sources.insert("kernel_mul_mv_id_q6_K_f32".into(), ggml_id_src);
141
142 let ggml_id_mm_src: &'static str =
150 include_str!("shaders/quantized_matmul_id_mm.metal");
151 sources.insert("kernel_mul_mm_id_map0_ne20_1".into(), ggml_id_mm_src);
152 sources.insert("kernel_mul_mm_id_map0_ne20_8".into(), ggml_id_mm_src);
153 sources.insert("kernel_mul_mm_id_q4_0_f32".into(), ggml_id_mm_src);
154 sources.insert("kernel_mul_mm_id_q8_0_f32".into(), ggml_id_mm_src);
155 sources.insert("kernel_mul_mm_id_q6_K_f32".into(), ggml_id_mm_src);
156
157 let ggml_id_mm_tensor_src: &'static str =
163 include_str!("shaders/quantized_matmul_id_mm_tensor.metal");
164 sources.insert("kernel_mul_mm_id_q4_0_tensor_f32".into(), ggml_id_mm_tensor_src);
165 sources.insert("kernel_mul_mm_id_q8_0_tensor_f32".into(), ggml_id_mm_tensor_src);
166 sources.insert("kernel_mul_mm_id_q6_K_tensor_f32".into(), ggml_id_mm_tensor_src);
167
168 let embedding_src: &'static str = include_str!("shaders/embedding.metal");
170 sources.insert("embedding_gather_4bit".into(), embedding_src);
171 sources.insert("embedding_gather_6bit".into(), embedding_src);
172
173 let moe_gate_src: &'static str = include_str!("shaders/moe_gate.metal");
175 sources.insert("moe_gate".into(), moe_gate_src);
176
177 let moe_dispatch_src: &'static str = include_str!("shaders/moe_dispatch.metal");
179 sources.insert("fused_gelu_mul".into(), moe_dispatch_src);
180 sources.insert("moe_swiglu_fused".into(), moe_dispatch_src);
181 sources.insert("moe_swiglu_batch".into(), moe_dispatch_src);
182 sources.insert("moe_swiglu_seq".into(), moe_dispatch_src);
183 sources.insert("moe_accumulate".into(), moe_dispatch_src);
184 sources.insert("moe_weighted_sum".into(), moe_dispatch_src);
185 sources.insert("moe_weighted_sum_seq".into(), moe_dispatch_src);
186 sources.insert("zero_buffer".into(), moe_dispatch_src);
187 sources.insert("naive_matvec_f32".into(), moe_dispatch_src);
188 sources.insert("moe_gather_topk_weights".into(), moe_dispatch_src);
189 sources.insert("fused_gelu_mul_bf16".into(), moe_dispatch_src);
191 sources.insert("moe_swiglu_seq_bf16".into(), moe_dispatch_src);
192 sources.insert("moe_weighted_sum_seq_bf16_input".into(), moe_dispatch_src);
193
194 let kv_cache_src: &'static str = include_str!("shaders/kv_cache_copy.metal");
196 sources.insert("kv_cache_copy_batch_f32".into(), kv_cache_src);
197 sources.insert("kv_cache_copy_batch_f32_to_f16".into(), kv_cache_src);
198 sources.insert("kv_cache_copy_seq_f32".into(), kv_cache_src);
199 sources.insert("kv_cache_copy_seq_f32_to_f16".into(), kv_cache_src);
200 sources.insert("kv_cache_copy_seq_f32_kv_dual".into(), kv_cache_src);
202 sources.insert("kv_cache_copy_seq_f32_to_f16_kv_dual".into(), kv_cache_src);
203 sources.insert("kv_cache_copy_seq_bf16".into(), kv_cache_src);
205
206 let elementwise_src: &'static str = include_str!("shaders/elementwise.metal");
208 sources.insert("elementwise_add_f32".into(), elementwise_src);
209 sources.insert("elementwise_add_f16".into(), elementwise_src);
210 sources.insert("elementwise_mul_f32".into(), elementwise_src);
211 sources.insert("elementwise_mul_f16".into(), elementwise_src);
212 sources.insert("elementwise_add_bf16".into(), elementwise_src);
213 sources.insert("elementwise_mul_bf16".into(), elementwise_src);
214 sources.insert("cast_f16_to_f32".into(), elementwise_src);
215 sources.insert("cast_f32_to_f16".into(), elementwise_src);
216 sources.insert("cast_bf16_to_f32".into(), elementwise_src);
217 sources.insert("cast_f32_to_bf16".into(), elementwise_src);
218 sources.insert("scalar_mul_bf16".into(), elementwise_src);
219 sources.insert("scalar_mul_f32".into(), elementwise_src);
220 sources.insert("embedding_gather_scale_f32".into(), elementwise_src);
221 sources.insert("embedding_gather_scale_batch_f32".into(), elementwise_src);
222 sources.insert("permute_021_bf16".into(), elementwise_src);
223 sources.insert("transpose_last2_bf16".into(), elementwise_src);
224 sources.insert("permute_021_f32".into(), elementwise_src);
225 sources.insert("permute_021_bf16_to_f32".into(), elementwise_src);
226 sources.insert("transpose_2d_f32".into(), elementwise_src);
227 sources.insert("transpose_2d_f16".into(), elementwise_src);
228
229 let sdpa_src: &'static str = include_str!("shaders/sdpa.metal");
231 sources.insert("sdpa".into(), sdpa_src);
232 sources.insert("sdpa_bf16".into(), sdpa_src);
233 let sdpa_sliding_src: &'static str = include_str!("shaders/sdpa_sliding.metal");
234 sources.insert("sdpa_sliding".into(), sdpa_sliding_src);
235 sources.insert("sdpa_sliding_bf16".into(), sdpa_sliding_src);
236
237 let flash_attn_prefill_src: &'static str =
242 include_str!("shaders/flash_attn_prefill.metal");
243 sources.insert(
245 "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskfloat32".into(),
246 flash_attn_prefill_src,
247 );
248 sources.insert(
249 "steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
250 flash_attn_prefill_src,
251 );
252 sources.insert(
253 "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbfloat16".into(),
254 flash_attn_prefill_src,
255 );
256 sources.insert(
257 "steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
258 flash_attn_prefill_src,
259 );
260 sources.insert(
261 "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskfloat16".into(),
262 flash_attn_prefill_src,
263 );
264 sources.insert(
265 "steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
266 flash_attn_prefill_src,
267 );
268 sources.insert(
272 "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbfloat16".into(),
273 flash_attn_prefill_src,
274 );
275 sources.insert(
276 "steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
277 flash_attn_prefill_src,
278 );
279 sources.insert(
280 "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskfloat16".into(),
281 flash_attn_prefill_src,
282 );
283 sources.insert(
284 "steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
285 flash_attn_prefill_src,
286 );
287
288 let flash_attn_vec_src: &'static str =
291 include_str!("shaders/flash_attn_vec.metal");
292 sources.insert("flash_attn_vec_dk256".into(), flash_attn_vec_src);
293 sources.insert("flash_attn_vec_dk512".into(), flash_attn_vec_src);
294 sources.insert("flash_attn_vec_reduce_dk256".into(), flash_attn_vec_src);
295 sources.insert("flash_attn_vec_reduce_dk512".into(), flash_attn_vec_src);
296 sources.insert("flash_attn_vec_f16kv_dk256".into(), flash_attn_vec_src);
298 sources.insert("flash_attn_vec_f16kv_dk512".into(), flash_attn_vec_src);
299
300 let rope_src: &'static str = include_str!("shaders/rope.metal");
302 sources.insert("rope_f32".into(), rope_src);
303 sources.insert("rope_f16".into(), rope_src);
304 sources.insert("rope_bf16".into(), rope_src);
305 sources.insert("rope_neox_bf16".into(), rope_src);
306 sources.insert("rope_neox_f32".into(), rope_src);
307 let rms_norm_src: &'static str = include_str!("shaders/rms_norm.metal");
308 sources.insert("rms_norm_f32".into(), rms_norm_src);
309 sources.insert("rms_norm_f16".into(), rms_norm_src);
310 sources.insert("rms_norm_bf16".into(), rms_norm_src);
311 sources.insert("rms_norm_no_scale_bf16".into(), rms_norm_src);
312 sources.insert("rms_norm_no_scale_f32".into(), rms_norm_src);
313 sources.insert("rms_norm_no_scale_f32_dual".into(), rms_norm_src);
314 sources.insert("rms_norm_f32_triple".into(), rms_norm_src);
315 sources.insert("fused_post_attn_triple_norm_f32".into(), rms_norm_src);
316 sources.insert("rms_norm_no_scale_f32_dual_perm".into(), rms_norm_src);
317 sources.insert("rms_norm_mul_f32".into(), rms_norm_src);
319 sources.insert("rms_norm_mul_f16".into(), rms_norm_src);
320 sources.insert("rms_norm_mul_bf16".into(), rms_norm_src);
321 let l2_norm_src: &'static str = include_str!("shaders/l2_norm.metal");
323 sources.insert("l2_norm_f32".into(), l2_norm_src);
324 sources.insert("l2_norm_f16".into(), l2_norm_src);
325 sources.insert("l2_norm_bf16".into(), l2_norm_src);
326 let cumsum_src: &'static str = include_str!("shaders/cumsum.metal");
328 sources.insert("cumsum_f32".into(), cumsum_src);
329 sources.insert("cumsum_bf16".into(), cumsum_src);
330 let ssm_conv_src: &'static str = include_str!("shaders/ssm_conv.metal");
332 sources.insert("ssm_conv_forward_f32".into(), ssm_conv_src);
333 sources.insert("ssm_conv_forward_bf16".into(), ssm_conv_src);
334 sources.insert("ssm_conv_state_update_f32".into(), ssm_conv_src);
335 sources.insert("ssm_conv_state_update_bf16".into(), ssm_conv_src);
336 let tri_solve_src: &'static str = include_str!("shaders/tri_solve.metal");
338 sources.insert("tri_solve_lower_unit_f32".into(), tri_solve_src);
339 sources.insert("tri_solve_lower_unit_bf16".into(), tri_solve_src);
340 let rope_multi_src: &'static str = include_str!("shaders/rope_multi.metal");
342 sources.insert("rope_multi_f32".into(), rope_multi_src);
343 sources.insert("rope_multi_bf16".into(), rope_multi_src);
344 let gdn_src: &'static str = include_str!("shaders/gated_delta_net.metal");
346 sources.insert("gated_delta_net_f32".into(), gdn_src);
347 let sigmoid_mul_src: &'static str = include_str!("shaders/sigmoid_mul.metal");
349 sources.insert("sigmoid_mul_f32".into(), sigmoid_mul_src);
350 sources.insert("sigmoid_mul_bf16".into(), sigmoid_mul_src);
351 let gelu_src: &'static str = include_str!("shaders/gelu.metal");
352 sources.insert("gelu_f32".into(), gelu_src);
353 sources.insert("gelu_f16".into(), gelu_src);
354 sources.insert("gelu_bf16".into(), gelu_src);
355 let softmax_src: &'static str = include_str!("shaders/softmax.metal");
356 sources.insert("softmax_f32".into(), softmax_src);
357 sources.insert("softmax_f16".into(), softmax_src);
358 sources.insert("softmax_bf16".into(), softmax_src);
359 let softcap_src: &'static str = include_str!("shaders/softcap.metal");
360 sources.insert("softcap_f32".into(), softcap_src);
361 sources.insert("softcap_f16".into(), softcap_src);
362 sources.insert("softcap_bf16".into(), softcap_src);
363
364 let fused_norm_add_src: &'static str =
367 include_str!("shaders/fused_norm_add_bf16.metal");
368 sources.insert("fused_norm_add_bf16".into(), fused_norm_add_src);
369 sources.insert("fused_norm_add_no_weight_bf16".into(), fused_norm_add_src);
370
371 let fused_hnr_f32_src: &'static str =
373 include_str!("shaders/fused_head_norm_rope_f32.metal");
374 sources.insert("fused_head_norm_rope_f32".into(), fused_hnr_f32_src);
375
376 let fused_hnr_bf16_src: &'static str =
379 include_str!("shaders/fused_head_norm_rope_bf16.metal");
380 sources.insert("fused_head_norm_rope_bf16".into(), fused_hnr_bf16_src);
381 sources.insert("fused_head_norm_rope_batch_bf16".into(), fused_hnr_bf16_src);
382
383 let fused_norm_add_f32_src: &'static str =
385 include_str!("shaders/fused_norm_add_f32.metal");
386 sources.insert("fused_norm_add_f32".into(), fused_norm_add_f32_src);
387 sources.insert("fused_residual_norm_f32".into(), fused_norm_add_f32_src);
388 sources.insert("fused_residual_norm_scalar_f32".into(), fused_norm_add_f32_src);
389 sources.insert("fused_moe_routing_f32".into(), fused_norm_add_f32_src);
390 sources.insert("fused_moe_routing_batch_f32".into(), fused_norm_add_f32_src);
391 sources.insert("fused_norm_add_scalar_f32".into(), fused_norm_add_f32_src);
392 sources.insert("fused_moe_wsum_norm_add_f32".into(), fused_norm_add_f32_src);
393 sources.insert("fused_moe_wsum_dnorm_add_f32".into(), fused_norm_add_f32_src);
394
395 let argsort_src: &'static str = include_str!("shaders/argsort.metal");
397 sources.insert("argsort_desc_f32".into(), argsort_src);
398
399 let gather_src: &'static str = include_str!("shaders/gather.metal");
401 sources.insert("gather_f32".into(), gather_src);
402
403 let kv_cache_copy_src: &'static str =
405 include_str!("shaders/kv_cache_copy.metal");
406 sources.insert("kv_cache_copy".into(), kv_cache_copy_src);
407 sources.insert("kv_cache_copy_f32".into(), kv_cache_copy_src);
408
409 let copy_src: &'static str = include_str!("shaders/copy.metal");
411 sources.insert("strided_copy_f32".into(), copy_src);
412 sources.insert("offset_copy_f32".into(), copy_src);
413
414 let dense_gemm_src: &'static str = include_str!("shaders/dense_gemm.metal");
416 sources.insert("dense_gemm_f16".into(), dense_gemm_src);
417 sources.insert("dense_matvec_f16".into(), dense_gemm_src);
418 sources.insert("dense_matvec_f16w_f32io".into(), dense_gemm_src);
419
420 let fwht_src: &'static str = include_str!("shaders/fwht_standalone.metal");
422 sources.insert("fwht_standalone_f32_d256".into(), fwht_src);
423 sources.insert("fwht_standalone_f32_d512".into(), fwht_src);
424 sources.insert("fwht_sign_premult_f32_d256".into(), fwht_src);
426 sources.insert("fwht_sign_premult_f32_d512".into(), fwht_src);
427 sources.insert("fwht_sign_undo_f32_d256".into(), fwht_src);
428 sources.insert("fwht_sign_undo_f32_d512".into(), fwht_src);
429
430 let hq_fast_src: &'static str = include_str!("shaders/hadamard_quantize_kv_fast.metal");
432 sources.insert("hadamard_quantize_kv_fast_d256".into(), hq_fast_src);
433 sources.insert("hadamard_quantize_kv_fast_d512".into(), hq_fast_src);
434 sources.insert("hadamard_quantize_kv_hb_d256".into(), hq_fast_src);
436 sources.insert("hadamard_quantize_kv_hb_d512".into(), hq_fast_src);
437
438 let tq_dq_src: &'static str = include_str!("shaders/tq_dequantize_kv.metal");
440 sources.insert("tq_dequantize_kv".into(), tq_dq_src);
441 sources.insert("tq_dequantize_hb_kv".into(), tq_dq_src);
443
444 let tq_hb_src: &'static str = include_str!("shaders/flash_attn_vec_tq_hb.metal");
446 sources.insert("flash_attn_vec_tq_hb_dk256".into(), tq_hb_src);
447 sources.insert("flash_attn_vec_tq_hb_dk512".into(), tq_hb_src);
448
449 let argmax_src: &'static str = include_str!("shaders/argmax.metal");
451 sources.insert("argmax_f32".into(), argmax_src);
452 let softmax_sample_src: &'static str =
453 include_str!("shaders/softmax_sample.metal");
454 sources.insert("softmax_sample_f32".into(), softmax_sample_src);
455 let top_k_src: &'static str = include_str!("shaders/top_k.metal");
457 sources.insert("top_k_f32".into(), top_k_src);
458
459 Self {
460 cache: HashMap::new(),
461 sources,
462 }
463 }
464
465 pub fn register_source(&mut self, name: impl Into<String>, source: &'static str) {
468 let name = name.into();
469 self.cache.remove(&name);
471 self.sources.insert(name, source);
472 }
473
474 pub fn get_pipeline(
486 &mut self,
487 name: &str,
488 device: &metal::DeviceRef,
489 ) -> Result<&ComputePipelineState> {
490 if !self.cache.contains_key(name) {
491 let source = self.sources.get(name).ok_or_else(|| {
493 MlxError::KernelNotFound(name.to_string())
494 })?;
495
496 let compile_opts = metal::CompileOptions::new();
497 let library = device
498 .new_library_with_source(source, &compile_opts)
499 .map_err(|msg| MlxError::ShaderCompilationError {
500 name: name.to_string(),
501 message: msg,
502 })?;
503
504 let function = library
505 .get_function(name, None)
506 .map_err(|msg| MlxError::ShaderCompilationError {
507 name: name.to_string(),
508 message: msg,
509 })?;
510
511 let pipeline = device
512 .new_compute_pipeline_state_with_function(&function)
513 .map_err(|msg| MlxError::ShaderCompilationError {
514 name: name.to_string(),
515 message: msg,
516 })?;
517
518 self.cache.insert(name.to_string(), pipeline);
519 }
520
521 self.cache.get(name).ok_or_else(|| {
524 MlxError::KernelNotFound(name.to_string())
525 })
526 }
527
528 pub fn get_pipeline_with_constants(
550 &mut self,
551 name: &str,
552 device: &metal::DeviceRef,
553 bool_constants: &[(usize, bool)],
554 int_constants: &[(usize, i32)],
555 ) -> Result<&ComputePipelineState> {
556 let mut cache_key = name.to_string();
561 for &(index, value) in bool_constants {
562 cache_key.push('|');
563 cache_key.push_str(&index.to_string());
564 cache_key.push_str(if value { ":b1" } else { ":b0" });
565 }
566 for &(index, value) in int_constants {
567 cache_key.push('|');
568 cache_key.push_str(&index.to_string());
569 cache_key.push(':');
570 cache_key.push('i');
571 cache_key.push_str(&value.to_string());
572 }
573
574 if !self.cache.contains_key(&cache_key) {
575 let source = self.sources.get(name).ok_or_else(|| {
577 MlxError::KernelNotFound(name.to_string())
578 })?;
579
580 let compile_opts = metal::CompileOptions::new();
581 let library = device
582 .new_library_with_source(source, &compile_opts)
583 .map_err(|msg| MlxError::ShaderCompilationError {
584 name: name.to_string(),
585 message: msg,
586 })?;
587
588 let fcv = FunctionConstantValues::new();
593
594 for &(index, value) in bool_constants {
595 let v: u8 = if value { 1 } else { 0 };
598 fcv.set_constant_value_at_index(
599 (&v as *const u8).cast::<std::ffi::c_void>(),
600 MTLDataType::Bool,
601 index as u64,
602 );
603 }
604
605 for &(index, value) in int_constants {
606 fcv.set_constant_value_at_index(
610 (&value as *const i32).cast::<std::ffi::c_void>(),
611 MTLDataType::Int,
612 index as u64,
613 );
614 }
615
616 let function = library
617 .get_function(name, Some(fcv))
618 .map_err(|msg| MlxError::ShaderCompilationError {
619 name: name.to_string(),
620 message: msg,
621 })?;
622
623 let pipeline = device
624 .new_compute_pipeline_state_with_function(&function)
625 .map_err(|msg| MlxError::ShaderCompilationError {
626 name: name.to_string(),
627 message: msg,
628 })?;
629
630 self.cache.insert(cache_key.clone(), pipeline);
631 }
632
633 self.cache.get(&cache_key).ok_or_else(|| {
634 MlxError::KernelNotFound(name.to_string())
635 })
636 }
637
638 pub fn get_pipeline_with_bool_constants(
656 &mut self,
657 name: &str,
658 device: &metal::DeviceRef,
659 bool_constants: &[(usize, bool)],
660 ) -> Result<&ComputePipelineState> {
661 self.get_pipeline_with_constants(name, device, bool_constants, &[])
662 }
663
664 pub fn is_cached(&self, name: &str) -> bool {
666 self.cache.contains_key(name)
667 }
668
669 pub fn cached_count(&self) -> usize {
671 self.cache.len()
672 }
673
674 pub fn source_count(&self) -> usize {
676 self.sources.len()
677 }
678}
679
680impl Default for KernelRegistry {
681 fn default() -> Self {
682 Self::new()
683 }
684}
685
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal kernel that writes the `int` function constant at index 100 into `out[0]`.
    const INT_FC_TEST_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;

constant int test_N [[function_constant(100)]];

kernel void int_fc_test_kernel(
    device int* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]])
{
    if (tid == 0) {
        out[0] = test_N;
    }
}
"#;

    /// Address of the Metal pipeline object behind `pipeline`.
    ///
    /// The `&ComputePipelineState` handed out by the registry points at a slot inside its
    /// `HashMap`, and that slot moves whenever the map reallocates. The underlying
    /// Objective-C object does not move, so its address is a stable identity for
    /// "same pipeline / different pipeline" assertions.
    fn pipeline_identity(
        pipeline: &ComputePipelineState,
    ) -> *const metal::ComputePipelineStateRef {
        &**pipeline
    }

    /// Distinct int function-constant values must yield distinct, individually cached
    /// pipelines. The bool-only wrapper must still accept an empty constant list.
    #[test]
    fn test_int_fc_distinct_pipelines_and_bool_compat() {
        let device = metal::Device::system_default()
            .expect("no Metal device — run on Apple Silicon or x86 Mac with Metal support");

        let mut registry = KernelRegistry::new();
        registry.register_source("int_fc_test_kernel", INT_FC_TEST_SHADER);

        let p4 = pipeline_identity(
            registry
                .get_pipeline_with_constants("int_fc_test_kernel", &device, &[], &[(100, 4_i32)])
                .expect("pipeline N=4 should compile"),
        );
        let count_after_n4 = registry.cached_count();

        let p8 = pipeline_identity(
            registry
                .get_pipeline_with_constants("int_fc_test_kernel", &device, &[], &[(100, 8_i32)])
                .expect("pipeline N=8 should compile"),
        );

        assert_eq!(
            registry.cached_count(),
            count_after_n4 + 1,
            "N=8 must produce a new cache entry"
        );
        assert_ne!(
            p4, p8,
            "N=4 and N=8 specialisations must be separate ComputePipelineState objects"
        );

        let p4_again = pipeline_identity(
            registry
                .get_pipeline_with_constants("int_fc_test_kernel", &device, &[], &[(100, 4_i32)])
                .expect("pipeline N=4 cache hit should succeed"),
        );

        assert_eq!(
            registry.cached_count(),
            count_after_n4 + 1,
            "repeated N=4 call must be a cache hit, not a new entry"
        );
        assert_eq!(
            p4, p4_again,
            "repeated N=4 call must return the same pipeline pointer"
        );

        // A kernel without any function constants must still work through the bool wrapper.
        const BARE_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void bare_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
    if (tid == 0) { out[0] = 42; }
}
"#;
        registry.register_source("bare_kernel", BARE_SHADER);

        let count_before_bool = registry.cached_count();
        let _bool_pipeline = registry
            .get_pipeline_with_bool_constants("bare_kernel", &device, &[])
            .expect("bool-constants wrapper with empty slice must succeed");

        assert_eq!(
            registry.cached_count(),
            count_before_bool + 1,
            "bool-constants wrapper must insert one new cache entry"
        );
    }
}