1use std::collections::HashMap;
9
10use metal::ComputePipelineState;
11
12use crate::error::{MlxError, Result};
13
14pub struct KernelRegistry {
31 cache: HashMap<String, ComputePipelineState>,
33 sources: HashMap<String, &'static str>,
37}
38
39impl KernelRegistry {
40 pub fn new() -> Self {
44 let mut sources = HashMap::new();
45
46 sources.insert(
48 "placeholder".into(),
49 include_str!("shaders/placeholder.metal"),
50 );
51 sources.insert(
52 "quantized_matmul".into(),
53 include_str!("shaders/quantized_matmul.metal"),
54 );
55 sources.insert(
56 "quantized_matmul_simd".into(),
57 include_str!("shaders/quantized_matmul.metal"),
58 );
59 sources.insert(
60 "quantized_matmul_simd_bf16".into(),
61 include_str!("shaders/quantized_matmul.metal"),
62 );
63 sources.insert(
64 "quantized_matmul_simd_bf16_expert".into(),
65 include_str!("shaders/quantized_matmul.metal"),
66 );
67
68 let ggml_src: &'static str =
70 include_str!("shaders/quantized_matmul_ggml.metal");
71 sources.insert("kernel_mul_mv_q4_0_f32".into(), ggml_src);
72 sources.insert("kernel_mul_mv_q8_0_f32".into(), ggml_src);
73 sources.insert("kernel_mul_mv_q6_K_f32".into(), ggml_src);
74
75 sources.insert(
77 "quantized_matmul_id".into(),
78 include_str!("shaders/quantized_matmul_id.metal"),
79 );
80
81 let ggml_id_src: &'static str =
83 include_str!("shaders/quantized_matmul_id_ggml.metal");
84 sources.insert("kernel_mul_mv_id_q4_0_f32".into(), ggml_id_src);
85 sources.insert("kernel_mul_mv_id_q8_0_f32".into(), ggml_id_src);
86 sources.insert("kernel_mul_mv_id_q6_K_f32".into(), ggml_id_src);
87
88 let embedding_src: &'static str = include_str!("shaders/embedding.metal");
90 sources.insert("embedding_gather_4bit".into(), embedding_src);
91 sources.insert("embedding_gather_6bit".into(), embedding_src);
92
93 let moe_gate_src: &'static str = include_str!("shaders/moe_gate.metal");
95 sources.insert("moe_gate".into(), moe_gate_src);
96
97 let moe_dispatch_src: &'static str = include_str!("shaders/moe_dispatch.metal");
99 sources.insert("fused_gelu_mul".into(), moe_dispatch_src);
100 sources.insert("moe_swiglu_fused".into(), moe_dispatch_src);
101 sources.insert("moe_swiglu_batch".into(), moe_dispatch_src);
102 sources.insert("moe_swiglu_seq".into(), moe_dispatch_src);
103 sources.insert("moe_accumulate".into(), moe_dispatch_src);
104 sources.insert("moe_weighted_sum".into(), moe_dispatch_src);
105 sources.insert("moe_weighted_sum_seq".into(), moe_dispatch_src);
106 sources.insert("zero_buffer".into(), moe_dispatch_src);
107 sources.insert("naive_matvec_f32".into(), moe_dispatch_src);
108 sources.insert("moe_gather_topk_weights".into(), moe_dispatch_src);
109
110 let kv_cache_src: &'static str = include_str!("shaders/kv_cache_copy.metal");
112 sources.insert("kv_cache_copy_batch_f32".into(), kv_cache_src);
113 sources.insert("kv_cache_copy_batch_f32_to_f16".into(), kv_cache_src);
114 sources.insert("kv_cache_copy_seq_f32".into(), kv_cache_src);
115 sources.insert("kv_cache_copy_seq_f32_to_f16".into(), kv_cache_src);
116
117 let elementwise_src: &'static str = include_str!("shaders/elementwise.metal");
119 sources.insert("elementwise_add_f32".into(), elementwise_src);
120 sources.insert("elementwise_add_f16".into(), elementwise_src);
121 sources.insert("elementwise_mul_f32".into(), elementwise_src);
122 sources.insert("elementwise_mul_f16".into(), elementwise_src);
123 sources.insert("elementwise_add_bf16".into(), elementwise_src);
124 sources.insert("elementwise_mul_bf16".into(), elementwise_src);
125 sources.insert("cast_f16_to_f32".into(), elementwise_src);
126 sources.insert("cast_f32_to_f16".into(), elementwise_src);
127 sources.insert("cast_bf16_to_f32".into(), elementwise_src);
128 sources.insert("cast_f32_to_bf16".into(), elementwise_src);
129 sources.insert("scalar_mul_bf16".into(), elementwise_src);
130 sources.insert("scalar_mul_f32".into(), elementwise_src);
131 sources.insert("embedding_gather_scale_f32".into(), elementwise_src);
132 sources.insert("embedding_gather_scale_batch_f32".into(), elementwise_src);
133 sources.insert("permute_021_bf16".into(), elementwise_src);
134 sources.insert("permute_021_f32".into(), elementwise_src);
135 sources.insert("transpose_2d_f32".into(), elementwise_src);
136 sources.insert("transpose_2d_f16".into(), elementwise_src);
137
138 let sdpa_src: &'static str = include_str!("shaders/sdpa.metal");
140 sources.insert("sdpa".into(), sdpa_src);
141 sources.insert("sdpa_bf16".into(), sdpa_src);
142 let sdpa_sliding_src: &'static str = include_str!("shaders/sdpa_sliding.metal");
143 sources.insert("sdpa_sliding".into(), sdpa_sliding_src);
144 sources.insert("sdpa_sliding_bf16".into(), sdpa_sliding_src);
145
146 let flash_attn_vec_src: &'static str =
149 include_str!("shaders/flash_attn_vec.metal");
150 sources.insert("flash_attn_vec_dk256".into(), flash_attn_vec_src);
151 sources.insert("flash_attn_vec_dk512".into(), flash_attn_vec_src);
152 sources.insert("flash_attn_vec_reduce_dk256".into(), flash_attn_vec_src);
153 sources.insert("flash_attn_vec_reduce_dk512".into(), flash_attn_vec_src);
154 sources.insert("flash_attn_vec_f16kv_dk256".into(), flash_attn_vec_src);
156 sources.insert("flash_attn_vec_f16kv_dk512".into(), flash_attn_vec_src);
157
158 let rope_src: &'static str = include_str!("shaders/rope.metal");
160 sources.insert("rope_f32".into(), rope_src);
161 sources.insert("rope_f16".into(), rope_src);
162 sources.insert("rope_bf16".into(), rope_src);
163 sources.insert("rope_neox_bf16".into(), rope_src);
164 sources.insert("rope_neox_f32".into(), rope_src);
165 let rms_norm_src: &'static str = include_str!("shaders/rms_norm.metal");
166 sources.insert("rms_norm_f32".into(), rms_norm_src);
167 sources.insert("rms_norm_f16".into(), rms_norm_src);
168 sources.insert("rms_norm_bf16".into(), rms_norm_src);
169 sources.insert("rms_norm_no_scale_bf16".into(), rms_norm_src);
170 sources.insert("rms_norm_no_scale_f32".into(), rms_norm_src);
171 sources.insert("rms_norm_mul_f32".into(), rms_norm_src);
173 sources.insert("rms_norm_mul_f16".into(), rms_norm_src);
174 sources.insert("rms_norm_mul_bf16".into(), rms_norm_src);
175 let gelu_src: &'static str = include_str!("shaders/gelu.metal");
176 sources.insert("gelu_f32".into(), gelu_src);
177 sources.insert("gelu_f16".into(), gelu_src);
178 sources.insert("gelu_bf16".into(), gelu_src);
179 let softmax_src: &'static str = include_str!("shaders/softmax.metal");
180 sources.insert("softmax_f32".into(), softmax_src);
181 sources.insert("softmax_f16".into(), softmax_src);
182 sources.insert("softmax_bf16".into(), softmax_src);
183 let softcap_src: &'static str = include_str!("shaders/softcap.metal");
184 sources.insert("softcap_f32".into(), softcap_src);
185 sources.insert("softcap_f16".into(), softcap_src);
186 sources.insert("softcap_bf16".into(), softcap_src);
187
188 let fused_norm_add_src: &'static str =
191 include_str!("shaders/fused_norm_add_bf16.metal");
192 sources.insert("fused_norm_add_bf16".into(), fused_norm_add_src);
193 sources.insert("fused_norm_add_no_weight_bf16".into(), fused_norm_add_src);
194
195 let fused_hnr_f32_src: &'static str =
197 include_str!("shaders/fused_head_norm_rope_f32.metal");
198 sources.insert("fused_head_norm_rope_f32".into(), fused_hnr_f32_src);
199
200 let fused_norm_add_f32_src: &'static str =
202 include_str!("shaders/fused_norm_add_f32.metal");
203 sources.insert("fused_norm_add_f32".into(), fused_norm_add_f32_src);
204 sources.insert("fused_residual_norm_f32".into(), fused_norm_add_f32_src);
205 sources.insert("fused_residual_norm_scalar_f32".into(), fused_norm_add_f32_src);
206 sources.insert("fused_moe_routing_f32".into(), fused_norm_add_f32_src);
207 sources.insert("fused_moe_routing_batch_f32".into(), fused_norm_add_f32_src);
208 sources.insert("fused_norm_add_scalar_f32".into(), fused_norm_add_f32_src);
209
210 let argsort_src: &'static str = include_str!("shaders/argsort.metal");
212 sources.insert("argsort_desc_f32".into(), argsort_src);
213
214 let gather_src: &'static str = include_str!("shaders/gather.metal");
216 sources.insert("gather_f32".into(), gather_src);
217
218 let kv_cache_copy_src: &'static str =
220 include_str!("shaders/kv_cache_copy.metal");
221 sources.insert("kv_cache_copy".into(), kv_cache_copy_src);
222 sources.insert("kv_cache_copy_f32".into(), kv_cache_copy_src);
223
224 let copy_src: &'static str = include_str!("shaders/copy.metal");
226 sources.insert("strided_copy_f32".into(), copy_src);
227 sources.insert("offset_copy_f32".into(), copy_src);
228
229 let dense_gemm_src: &'static str = include_str!("shaders/dense_gemm.metal");
231 sources.insert("dense_gemm_f16".into(), dense_gemm_src);
232 sources.insert("dense_matvec_f16".into(), dense_gemm_src);
233 sources.insert("dense_matvec_f16w_f32io".into(), dense_gemm_src);
234
235 let fwht_src: &'static str = include_str!("shaders/fwht_standalone.metal");
237 sources.insert("fwht_standalone_f32_d256".into(), fwht_src);
238 sources.insert("fwht_standalone_f32_d512".into(), fwht_src);
239
240 let hq_fast_src: &'static str = include_str!("shaders/hadamard_quantize_kv_fast.metal");
242 sources.insert("hadamard_quantize_kv_fast_d256".into(), hq_fast_src);
243 sources.insert("hadamard_quantize_kv_fast_d512".into(), hq_fast_src);
244
245 let argmax_src: &'static str = include_str!("shaders/argmax.metal");
247 sources.insert("argmax_f32".into(), argmax_src);
248 let softmax_sample_src: &'static str =
249 include_str!("shaders/softmax_sample.metal");
250 sources.insert("softmax_sample_f32".into(), softmax_sample_src);
251 let top_k_src: &'static str = include_str!("shaders/top_k.metal");
253 sources.insert("top_k_f32".into(), top_k_src);
254
255 Self {
256 cache: HashMap::new(),
257 sources,
258 }
259 }
260
261 pub fn register_source(&mut self, name: impl Into<String>, source: &'static str) {
264 let name = name.into();
265 self.cache.remove(&name);
267 self.sources.insert(name, source);
268 }
269
270 pub fn get_pipeline(
282 &mut self,
283 name: &str,
284 device: &metal::DeviceRef,
285 ) -> Result<&ComputePipelineState> {
286 if !self.cache.contains_key(name) {
287 let source = self.sources.get(name).ok_or_else(|| {
289 MlxError::KernelNotFound(name.to_string())
290 })?;
291
292 let compile_opts = metal::CompileOptions::new();
293 let library = device
294 .new_library_with_source(source, &compile_opts)
295 .map_err(|msg| MlxError::ShaderCompilationError {
296 name: name.to_string(),
297 message: msg,
298 })?;
299
300 let function = library
301 .get_function(name, None)
302 .map_err(|msg| MlxError::ShaderCompilationError {
303 name: name.to_string(),
304 message: msg,
305 })?;
306
307 let pipeline = device
308 .new_compute_pipeline_state_with_function(&function)
309 .map_err(|msg| MlxError::ShaderCompilationError {
310 name: name.to_string(),
311 message: msg,
312 })?;
313
314 self.cache.insert(name.to_string(), pipeline);
315 }
316
317 self.cache.get(name).ok_or_else(|| {
320 MlxError::KernelNotFound(name.to_string())
321 })
322 }
323
324 pub fn is_cached(&self, name: &str) -> bool {
326 self.cache.contains_key(name)
327 }
328
329 pub fn cached_count(&self) -> usize {
331 self.cache.len()
332 }
333
334 pub fn source_count(&self) -> usize {
336 self.sources.len()
337 }
338}
339
340impl Default for KernelRegistry {
341 fn default() -> Self {
342 Self::new()
343 }
344}