1use std::collections::HashMap;
9
10use metal::ComputePipelineState;
11
12use crate::error::{MlxError, Result};
13
14pub struct KernelRegistry {
31 cache: HashMap<String, ComputePipelineState>,
33 sources: HashMap<String, &'static str>,
37}
38
39impl KernelRegistry {
40 pub fn new() -> Self {
44 let mut sources = HashMap::new();
45
46 sources.insert(
48 "placeholder".into(),
49 include_str!("shaders/placeholder.metal"),
50 );
51 sources.insert(
52 "quantized_matmul".into(),
53 include_str!("shaders/quantized_matmul.metal"),
54 );
55 sources.insert(
56 "quantized_matmul_simd".into(),
57 include_str!("shaders/quantized_matmul.metal"),
58 );
59 sources.insert(
60 "quantized_matmul_simd_bf16".into(),
61 include_str!("shaders/quantized_matmul.metal"),
62 );
63 sources.insert(
64 "quantized_matmul_simd_bf16_expert".into(),
65 include_str!("shaders/quantized_matmul.metal"),
66 );
67
68 let ggml_src: &'static str =
70 include_str!("shaders/quantized_matmul_ggml.metal");
71 sources.insert("kernel_mul_mv_q4_0_f32".into(), ggml_src);
72 sources.insert("kernel_mul_mv_q8_0_f32".into(), ggml_src);
73 sources.insert("kernel_mul_mv_q6_K_f32".into(), ggml_src);
74
75 sources.insert(
77 "quantized_matmul_id".into(),
78 include_str!("shaders/quantized_matmul_id.metal"),
79 );
80
81 let ggml_id_src: &'static str =
83 include_str!("shaders/quantized_matmul_id_ggml.metal");
84 sources.insert("kernel_mul_mv_id_q4_0_f32".into(), ggml_id_src);
85 sources.insert("kernel_mul_mv_id_q8_0_f32".into(), ggml_id_src);
86 sources.insert("kernel_mul_mv_id_q6_K_f32".into(), ggml_id_src);
87
88 let embedding_src: &'static str = include_str!("shaders/embedding.metal");
90 sources.insert("embedding_gather_4bit".into(), embedding_src);
91 sources.insert("embedding_gather_6bit".into(), embedding_src);
92
93 let moe_gate_src: &'static str = include_str!("shaders/moe_gate.metal");
95 sources.insert("moe_gate".into(), moe_gate_src);
96
97 let moe_dispatch_src: &'static str = include_str!("shaders/moe_dispatch.metal");
99 sources.insert("fused_gelu_mul".into(), moe_dispatch_src);
100 sources.insert("moe_swiglu_fused".into(), moe_dispatch_src);
101 sources.insert("moe_swiglu_batch".into(), moe_dispatch_src);
102 sources.insert("moe_accumulate".into(), moe_dispatch_src);
103 sources.insert("moe_weighted_sum".into(), moe_dispatch_src);
104 sources.insert("zero_buffer".into(), moe_dispatch_src);
105 sources.insert("naive_matvec_f32".into(), moe_dispatch_src);
106 sources.insert("moe_gather_topk_weights".into(), moe_dispatch_src);
107
108 let kv_cache_src: &'static str = include_str!("shaders/kv_cache_copy.metal");
110 sources.insert("kv_cache_copy_batch_f32".into(), kv_cache_src);
111 sources.insert("kv_cache_copy_batch_f32_to_f16".into(), kv_cache_src);
112
113 let elementwise_src: &'static str = include_str!("shaders/elementwise.metal");
115 sources.insert("elementwise_add_f32".into(), elementwise_src);
116 sources.insert("elementwise_add_f16".into(), elementwise_src);
117 sources.insert("elementwise_mul_f32".into(), elementwise_src);
118 sources.insert("elementwise_mul_f16".into(), elementwise_src);
119 sources.insert("elementwise_add_bf16".into(), elementwise_src);
120 sources.insert("elementwise_mul_bf16".into(), elementwise_src);
121 sources.insert("cast_f16_to_f32".into(), elementwise_src);
122 sources.insert("cast_f32_to_f16".into(), elementwise_src);
123 sources.insert("cast_bf16_to_f32".into(), elementwise_src);
124 sources.insert("cast_f32_to_bf16".into(), elementwise_src);
125 sources.insert("scalar_mul_bf16".into(), elementwise_src);
126 sources.insert("scalar_mul_f32".into(), elementwise_src);
127 sources.insert("embedding_gather_scale_f32".into(), elementwise_src);
128 sources.insert("permute_021_bf16".into(), elementwise_src);
129 sources.insert("transpose_2d_f32".into(), elementwise_src);
130 sources.insert("transpose_2d_f16".into(), elementwise_src);
131
132 let sdpa_src: &'static str = include_str!("shaders/sdpa.metal");
134 sources.insert("sdpa".into(), sdpa_src);
135 sources.insert("sdpa_bf16".into(), sdpa_src);
136 let sdpa_sliding_src: &'static str = include_str!("shaders/sdpa_sliding.metal");
137 sources.insert("sdpa_sliding".into(), sdpa_sliding_src);
138 sources.insert("sdpa_sliding_bf16".into(), sdpa_sliding_src);
139
140 let flash_attn_vec_src: &'static str =
143 include_str!("shaders/flash_attn_vec.metal");
144 sources.insert("flash_attn_vec_dk256".into(), flash_attn_vec_src);
145 sources.insert("flash_attn_vec_dk512".into(), flash_attn_vec_src);
146 sources.insert("flash_attn_vec_reduce_dk256".into(), flash_attn_vec_src);
147 sources.insert("flash_attn_vec_reduce_dk512".into(), flash_attn_vec_src);
148 sources.insert("flash_attn_vec_f16kv_dk256".into(), flash_attn_vec_src);
150 sources.insert("flash_attn_vec_f16kv_dk512".into(), flash_attn_vec_src);
151
152 let rope_src: &'static str = include_str!("shaders/rope.metal");
154 sources.insert("rope_f32".into(), rope_src);
155 sources.insert("rope_f16".into(), rope_src);
156 sources.insert("rope_bf16".into(), rope_src);
157 sources.insert("rope_neox_bf16".into(), rope_src);
158 sources.insert("rope_neox_f32".into(), rope_src);
159 let rms_norm_src: &'static str = include_str!("shaders/rms_norm.metal");
160 sources.insert("rms_norm_f32".into(), rms_norm_src);
161 sources.insert("rms_norm_f16".into(), rms_norm_src);
162 sources.insert("rms_norm_bf16".into(), rms_norm_src);
163 sources.insert("rms_norm_no_scale_bf16".into(), rms_norm_src);
164 sources.insert("rms_norm_no_scale_f32".into(), rms_norm_src);
165 sources.insert("rms_norm_mul_f32".into(), rms_norm_src);
167 sources.insert("rms_norm_mul_f16".into(), rms_norm_src);
168 sources.insert("rms_norm_mul_bf16".into(), rms_norm_src);
169 let gelu_src: &'static str = include_str!("shaders/gelu.metal");
170 sources.insert("gelu_f32".into(), gelu_src);
171 sources.insert("gelu_f16".into(), gelu_src);
172 sources.insert("gelu_bf16".into(), gelu_src);
173 let softmax_src: &'static str = include_str!("shaders/softmax.metal");
174 sources.insert("softmax_f32".into(), softmax_src);
175 sources.insert("softmax_f16".into(), softmax_src);
176 sources.insert("softmax_bf16".into(), softmax_src);
177 let softcap_src: &'static str = include_str!("shaders/softcap.metal");
178 sources.insert("softcap_f32".into(), softcap_src);
179 sources.insert("softcap_f16".into(), softcap_src);
180 sources.insert("softcap_bf16".into(), softcap_src);
181
182 let fused_norm_add_src: &'static str =
185 include_str!("shaders/fused_norm_add_bf16.metal");
186 sources.insert("fused_norm_add_bf16".into(), fused_norm_add_src);
187 sources.insert("fused_norm_add_no_weight_bf16".into(), fused_norm_add_src);
188
189 let fused_hnr_f32_src: &'static str =
191 include_str!("shaders/fused_head_norm_rope_f32.metal");
192 sources.insert("fused_head_norm_rope_f32".into(), fused_hnr_f32_src);
193
194 let fused_norm_add_f32_src: &'static str =
196 include_str!("shaders/fused_norm_add_f32.metal");
197 sources.insert("fused_norm_add_f32".into(), fused_norm_add_f32_src);
198 sources.insert("fused_residual_norm_f32".into(), fused_norm_add_f32_src);
199 sources.insert("fused_residual_norm_scalar_f32".into(), fused_norm_add_f32_src);
200 sources.insert("fused_moe_routing_f32".into(), fused_norm_add_f32_src);
201 sources.insert("fused_norm_add_scalar_f32".into(), fused_norm_add_f32_src);
202
203 let argsort_src: &'static str = include_str!("shaders/argsort.metal");
205 sources.insert("argsort_desc_f32".into(), argsort_src);
206
207 let gather_src: &'static str = include_str!("shaders/gather.metal");
209 sources.insert("gather_f32".into(), gather_src);
210
211 let kv_cache_copy_src: &'static str =
213 include_str!("shaders/kv_cache_copy.metal");
214 sources.insert("kv_cache_copy".into(), kv_cache_copy_src);
215 sources.insert("kv_cache_copy_f32".into(), kv_cache_copy_src);
216
217 let copy_src: &'static str = include_str!("shaders/copy.metal");
219 sources.insert("strided_copy_f32".into(), copy_src);
220
221 let dense_gemm_src: &'static str = include_str!("shaders/dense_gemm.metal");
223 sources.insert("dense_gemm_f16".into(), dense_gemm_src);
224 sources.insert("dense_matvec_f16".into(), dense_gemm_src);
225 sources.insert("dense_matvec_f16w_f32io".into(), dense_gemm_src);
226
227 let fwht_src: &'static str = include_str!("shaders/fwht_standalone.metal");
229 sources.insert("fwht_standalone_f32_d256".into(), fwht_src);
230 sources.insert("fwht_standalone_f32_d512".into(), fwht_src);
231
232 let hq_fast_src: &'static str = include_str!("shaders/hadamard_quantize_kv_fast.metal");
234 sources.insert("hadamard_quantize_kv_fast_d256".into(), hq_fast_src);
235 sources.insert("hadamard_quantize_kv_fast_d512".into(), hq_fast_src);
236
237 let argmax_src: &'static str = include_str!("shaders/argmax.metal");
239 sources.insert("argmax_f32".into(), argmax_src);
240 let softmax_sample_src: &'static str =
241 include_str!("shaders/softmax_sample.metal");
242 sources.insert("softmax_sample_f32".into(), softmax_sample_src);
243
244 Self {
245 cache: HashMap::new(),
246 sources,
247 }
248 }
249
250 pub fn register_source(&mut self, name: impl Into<String>, source: &'static str) {
253 let name = name.into();
254 self.cache.remove(&name);
256 self.sources.insert(name, source);
257 }
258
259 pub fn get_pipeline(
271 &mut self,
272 name: &str,
273 device: &metal::DeviceRef,
274 ) -> Result<&ComputePipelineState> {
275 if !self.cache.contains_key(name) {
276 let source = self.sources.get(name).ok_or_else(|| {
278 MlxError::KernelNotFound(name.to_string())
279 })?;
280
281 let compile_opts = metal::CompileOptions::new();
282 let library = device
283 .new_library_with_source(source, &compile_opts)
284 .map_err(|msg| MlxError::ShaderCompilationError {
285 name: name.to_string(),
286 message: msg,
287 })?;
288
289 let function = library
290 .get_function(name, None)
291 .map_err(|msg| MlxError::ShaderCompilationError {
292 name: name.to_string(),
293 message: msg,
294 })?;
295
296 let pipeline = device
297 .new_compute_pipeline_state_with_function(&function)
298 .map_err(|msg| MlxError::ShaderCompilationError {
299 name: name.to_string(),
300 message: msg,
301 })?;
302
303 self.cache.insert(name.to_string(), pipeline);
304 }
305
306 self.cache.get(name).ok_or_else(|| {
309 MlxError::KernelNotFound(name.to_string())
310 })
311 }
312
313 pub fn is_cached(&self, name: &str) -> bool {
315 self.cache.contains_key(name)
316 }
317
318 pub fn cached_count(&self) -> usize {
320 self.cache.len()
321 }
322
323 pub fn source_count(&self) -> usize {
325 self.sources.len()
326 }
327}
328
329impl Default for KernelRegistry {
330 fn default() -> Self {
331 Self::new()
332 }
333}