Skip to main content

mlx_native/
kernel_registry.rs

//! [`KernelRegistry`] — lazy compilation and caching of Metal compute pipelines.
//!
//! MSL shader source is embedded at compile time via `include_str!`.  On the
//! first request for a kernel, its source is compiled into a Metal library, the
//! named function is extracted, and a `ComputePipelineState` is created and
//! cached.  Subsequent requests for that kernel return the cached pipeline.
7
8use std::collections::HashMap;
9
10use metal::ComputePipelineState;
11
12use crate::error::{MlxError, Result};
13
14/// Registry that lazily compiles and caches Metal compute pipelines from
15/// embedded MSL source.
16///
17/// # Usage
18///
19/// ```ignore
20/// let mut registry = KernelRegistry::new();
21/// let pipeline = registry.get_pipeline("elementwise_add", device.metal_device())?;
22/// encoder.encode(&pipeline, &buffers, grid, tg);
23/// ```
24///
25/// # Thread Safety
26///
27/// `KernelRegistry` is **not** `Sync` by default (it uses `&mut self` for
28/// `get_pipeline` to allow mutable cache insertion).  If you need concurrent
29/// access, wrap it in a `Mutex` or use one registry per thread.
30pub struct KernelRegistry {
31    /// Cached pipelines keyed by kernel function name.
32    cache: HashMap<String, ComputePipelineState>,
33    /// MSL source text keyed by kernel function name.
34    ///
35    /// Populated at construction time with all embedded shader sources.
36    sources: HashMap<String, &'static str>,
37}
38
39impl KernelRegistry {
40    /// Create a new registry with all embedded shader sources pre-registered.
41    ///
42    /// No compilation happens here — shaders are compiled lazily on first use.
43    pub fn new() -> Self {
44        let mut sources = HashMap::new();
45
46        // Register embedded shader sources.
47        sources.insert(
48            "placeholder".into(),
49            include_str!("shaders/placeholder.metal"),
50        );
51        sources.insert(
52            "quantized_matmul".into(),
53            include_str!("shaders/quantized_matmul.metal"),
54        );
55        sources.insert(
56            "quantized_matmul_simd".into(),
57            include_str!("shaders/quantized_matmul.metal"),
58        );
59        sources.insert(
60            "quantized_matmul_simd_bf16".into(),
61            include_str!("shaders/quantized_matmul.metal"),
62        );
63        sources.insert(
64            "quantized_matmul_simd_bf16_expert".into(),
65            include_str!("shaders/quantized_matmul.metal"),
66        );
67
68        // GGML block-format quantized mat-vec kernels (ADR-006 Phase 3)
69        let ggml_src: &'static str =
70            include_str!("shaders/quantized_matmul_ggml.metal");
71        sources.insert("kernel_mul_mv_q4_0_f32".into(), ggml_src);
72        sources.insert("kernel_mul_mv_q8_0_f32".into(), ggml_src);
73        sources.insert("kernel_mul_mv_q6_K_f32".into(), ggml_src);
74
75        // Expert-routed (MoE) quantized matmul kernel (Story 2.1)
76        sources.insert(
77            "quantized_matmul_id".into(),
78            include_str!("shaders/quantized_matmul_id.metal"),
79        );
80
81        // Expert-routed (MoE) GGML block-format quantized matmul kernels
82        let ggml_id_src: &'static str =
83            include_str!("shaders/quantized_matmul_id_ggml.metal");
84        sources.insert("kernel_mul_mv_id_q4_0_f32".into(), ggml_id_src);
85        sources.insert("kernel_mul_mv_id_q8_0_f32".into(), ggml_id_src);
86        sources.insert("kernel_mul_mv_id_q6_K_f32".into(), ggml_id_src);
87
88        // Embedding kernels (Story 1.5)
89        let embedding_src: &'static str = include_str!("shaders/embedding.metal");
90        sources.insert("embedding_gather_4bit".into(), embedding_src);
91        sources.insert("embedding_gather_6bit".into(), embedding_src);
92
93        // MoE gate kernel (Story 1.5)
94        let moe_gate_src: &'static str = include_str!("shaders/moe_gate.metal");
95        sources.insert("moe_gate".into(), moe_gate_src);
96
97        // MoE dispatch kernels (Story 1.5)
98        let moe_dispatch_src: &'static str = include_str!("shaders/moe_dispatch.metal");
99        sources.insert("fused_gelu_mul".into(), moe_dispatch_src);
100        sources.insert("moe_swiglu_fused".into(), moe_dispatch_src);
101        sources.insert("moe_swiglu_batch".into(), moe_dispatch_src);
102        sources.insert("moe_swiglu_seq".into(), moe_dispatch_src);
103        sources.insert("moe_accumulate".into(), moe_dispatch_src);
104        sources.insert("moe_weighted_sum".into(), moe_dispatch_src);
105        sources.insert("moe_weighted_sum_seq".into(), moe_dispatch_src);
106        sources.insert("zero_buffer".into(), moe_dispatch_src);
107        sources.insert("naive_matvec_f32".into(), moe_dispatch_src);
108        sources.insert("moe_gather_topk_weights".into(), moe_dispatch_src);
109
110        // Batched KV cache copy kernels
111        let kv_cache_src: &'static str = include_str!("shaders/kv_cache_copy.metal");
112        sources.insert("kv_cache_copy_batch_f32".into(), kv_cache_src);
113        sources.insert("kv_cache_copy_batch_f32_to_f16".into(), kv_cache_src);
114        sources.insert("kv_cache_copy_seq_f32".into(), kv_cache_src);
115        sources.insert("kv_cache_copy_seq_f32_to_f16".into(), kv_cache_src);
116
117        // Elementwise and transpose kernels (Story 1.5)
118        let elementwise_src: &'static str = include_str!("shaders/elementwise.metal");
119        sources.insert("elementwise_add_f32".into(), elementwise_src);
120        sources.insert("elementwise_add_f16".into(), elementwise_src);
121        sources.insert("elementwise_mul_f32".into(), elementwise_src);
122        sources.insert("elementwise_mul_f16".into(), elementwise_src);
123        sources.insert("elementwise_add_bf16".into(), elementwise_src);
124        sources.insert("elementwise_mul_bf16".into(), elementwise_src);
125        sources.insert("cast_f16_to_f32".into(), elementwise_src);
126        sources.insert("cast_f32_to_f16".into(), elementwise_src);
127        sources.insert("cast_bf16_to_f32".into(), elementwise_src);
128        sources.insert("cast_f32_to_bf16".into(), elementwise_src);
129        sources.insert("scalar_mul_bf16".into(), elementwise_src);
130        sources.insert("scalar_mul_f32".into(), elementwise_src);
131        sources.insert("embedding_gather_scale_f32".into(), elementwise_src);
132        sources.insert("embedding_gather_scale_batch_f32".into(), elementwise_src);
133        sources.insert("permute_021_bf16".into(), elementwise_src);
134        sources.insert("permute_021_f32".into(), elementwise_src);
135        sources.insert("transpose_2d_f32".into(), elementwise_src);
136        sources.insert("transpose_2d_f16".into(), elementwise_src);
137
138        // Attention kernels (Story 1.3)
139        let sdpa_src: &'static str = include_str!("shaders/sdpa.metal");
140        sources.insert("sdpa".into(), sdpa_src);
141        sources.insert("sdpa_bf16".into(), sdpa_src);
142        let sdpa_sliding_src: &'static str = include_str!("shaders/sdpa_sliding.metal");
143        sources.insert("sdpa_sliding".into(), sdpa_sliding_src);
144        sources.insert("sdpa_sliding_bf16".into(), sdpa_sliding_src);
145
146        // Flash attention vector kernels — SIMD-vectorized decode-path SDPA
147        // (ported from llama.cpp flash_attn_ext_vec)
148        let flash_attn_vec_src: &'static str =
149            include_str!("shaders/flash_attn_vec.metal");
150        sources.insert("flash_attn_vec_dk256".into(), flash_attn_vec_src);
151        sources.insert("flash_attn_vec_dk512".into(), flash_attn_vec_src);
152        sources.insert("flash_attn_vec_reduce_dk256".into(), flash_attn_vec_src);
153        sources.insert("flash_attn_vec_reduce_dk512".into(), flash_attn_vec_src);
154        // F16 KV variants (Phase 4a)
155        sources.insert("flash_attn_vec_f16kv_dk256".into(), flash_attn_vec_src);
156        sources.insert("flash_attn_vec_f16kv_dk512".into(), flash_attn_vec_src);
157
158        // RoPE, normalization, activation kernels (Story 1.4)
159        let rope_src: &'static str = include_str!("shaders/rope.metal");
160        sources.insert("rope_f32".into(), rope_src);
161        sources.insert("rope_f16".into(), rope_src);
162        sources.insert("rope_bf16".into(), rope_src);
163        sources.insert("rope_neox_bf16".into(), rope_src);
164        sources.insert("rope_neox_f32".into(), rope_src);
165        let rms_norm_src: &'static str = include_str!("shaders/rms_norm.metal");
166        sources.insert("rms_norm_f32".into(), rms_norm_src);
167        sources.insert("rms_norm_f16".into(), rms_norm_src);
168        sources.insert("rms_norm_bf16".into(), rms_norm_src);
169        sources.insert("rms_norm_no_scale_bf16".into(), rms_norm_src);
170        sources.insert("rms_norm_no_scale_f32".into(), rms_norm_src);
171        // Fused RMS norm + elementwise multiply kernels (Phase 4e.2)
172        sources.insert("rms_norm_mul_f32".into(), rms_norm_src);
173        sources.insert("rms_norm_mul_f16".into(), rms_norm_src);
174        sources.insert("rms_norm_mul_bf16".into(), rms_norm_src);
175        let gelu_src: &'static str = include_str!("shaders/gelu.metal");
176        sources.insert("gelu_f32".into(), gelu_src);
177        sources.insert("gelu_f16".into(), gelu_src);
178        sources.insert("gelu_bf16".into(), gelu_src);
179        let softmax_src: &'static str = include_str!("shaders/softmax.metal");
180        sources.insert("softmax_f32".into(), softmax_src);
181        sources.insert("softmax_f16".into(), softmax_src);
182        sources.insert("softmax_bf16".into(), softmax_src);
183        let softcap_src: &'static str = include_str!("shaders/softcap.metal");
184        sources.insert("softcap_f32".into(), softcap_src);
185        sources.insert("softcap_f16".into(), softcap_src);
186        sources.insert("softcap_bf16".into(), softcap_src);
187
188        // Fused norm-add kernels — Gemma4 post-attention / post-FFN ordering:
189        //   normed = rms_norm(input, weight, eps);  output = residual + normed
190        let fused_norm_add_src: &'static str =
191            include_str!("shaders/fused_norm_add_bf16.metal");
192        sources.insert("fused_norm_add_bf16".into(), fused_norm_add_src);
193        sources.insert("fused_norm_add_no_weight_bf16".into(), fused_norm_add_src);
194
195        // Fused head-norm + RoPE f32 kernel — replaces separate rms_norm + rope_neox_f32
196        let fused_hnr_f32_src: &'static str =
197            include_str!("shaders/fused_head_norm_rope_f32.metal");
198        sources.insert("fused_head_norm_rope_f32".into(), fused_hnr_f32_src);
199
200        // Fused norm-add f32 kernels — post-attention / post-FFN / end-of-layer
201        let fused_norm_add_f32_src: &'static str =
202            include_str!("shaders/fused_norm_add_f32.metal");
203        sources.insert("fused_norm_add_f32".into(), fused_norm_add_f32_src);
204        sources.insert("fused_residual_norm_f32".into(), fused_norm_add_f32_src);
205        sources.insert("fused_residual_norm_scalar_f32".into(), fused_norm_add_f32_src);
206        sources.insert("fused_moe_routing_f32".into(), fused_norm_add_f32_src);
207        sources.insert("fused_moe_routing_batch_f32".into(), fused_norm_add_f32_src);
208        sources.insert("fused_norm_add_scalar_f32".into(), fused_norm_add_f32_src);
209
210        // Argsort kernel (Story 2.3) — MoE top-K routing
211        let argsort_src: &'static str = include_str!("shaders/argsort.metal");
212        sources.insert("argsort_desc_f32".into(), argsort_src);
213
214        // Gather / index_select kernel (Story 2.4)
215        let gather_src: &'static str = include_str!("shaders/gather.metal");
216        sources.insert("gather_f32".into(), gather_src);
217
218        // F32 KV cache copy kernel (Session merge S1+S2)
219        let kv_cache_copy_src: &'static str =
220            include_str!("shaders/kv_cache_copy.metal");
221        sources.insert("kv_cache_copy".into(), kv_cache_copy_src);
222        sources.insert("kv_cache_copy_f32".into(), kv_cache_copy_src);
223
224        // Strided copy kernel (Story 2.5)
225        let copy_src: &'static str = include_str!("shaders/copy.metal");
226        sources.insert("strided_copy_f32".into(), copy_src);
227        sources.insert("offset_copy_f32".into(), copy_src);
228
229        // Dense F16 GEMM kernel (Story 2.6) — lm_head projection
230        let dense_gemm_src: &'static str = include_str!("shaders/dense_gemm.metal");
231        sources.insert("dense_gemm_f16".into(), dense_gemm_src);
232        sources.insert("dense_matvec_f16".into(), dense_gemm_src);
233        sources.insert("dense_matvec_f16w_f32io".into(), dense_gemm_src);
234
235        // Standalone FWHT for TurboQuant pre/post-rotation (SIMD shuffle, zero barriers)
236        let fwht_src: &'static str = include_str!("shaders/fwht_standalone.metal");
237        sources.insert("fwht_standalone_f32_d256".into(), fwht_src);
238        sources.insert("fwht_standalone_f32_d512".into(), fwht_src);
239
240        // Fast Hadamard quantize (SIMD shuffle, zero barriers)
241        let hq_fast_src: &'static str = include_str!("shaders/hadamard_quantize_kv_fast.metal");
242        sources.insert("hadamard_quantize_kv_fast_d256".into(), hq_fast_src);
243        sources.insert("hadamard_quantize_kv_fast_d512".into(), hq_fast_src);
244
245        // GPU sampling kernels — eliminate logits readback (Phase 6)
246        let argmax_src: &'static str = include_str!("shaders/argmax.metal");
247        sources.insert("argmax_f32".into(), argmax_src);
248        let softmax_sample_src: &'static str =
249            include_str!("shaders/softmax_sample.metal");
250        sources.insert("softmax_sample_f32".into(), softmax_sample_src);
251        // Top-K kernel for Q8 rerank: avoids full-logits readback.
252        let top_k_src: &'static str = include_str!("shaders/top_k.metal");
253        sources.insert("top_k_f32".into(), top_k_src);
254
255        Self {
256            cache: HashMap::new(),
257            sources,
258        }
259    }
260
261    /// Register a shader source at runtime (useful for testing and dynamic
262    /// kernel generation).
263    pub fn register_source(&mut self, name: impl Into<String>, source: &'static str) {
264        let name = name.into();
265        // Invalidate any cached pipeline for this name since the source changed.
266        self.cache.remove(&name);
267        self.sources.insert(name, source);
268    }
269
270    /// Get a compiled compute pipeline for the named kernel function.
271    ///
272    /// On first call for a given name, this compiles the MSL source into a
273    /// Metal library, extracts the named function, and creates a
274    /// `ComputePipelineState`.  Subsequent calls return the cached pipeline.
275    ///
276    /// # Errors
277    ///
278    /// * `MlxError::KernelNotFound` — no source registered for this name.
279    /// * `MlxError::ShaderCompilationError` — MSL compilation or pipeline
280    ///   creation failed.
281    pub fn get_pipeline(
282        &mut self,
283        name: &str,
284        device: &metal::DeviceRef,
285    ) -> Result<&ComputePipelineState> {
286        if !self.cache.contains_key(name) {
287            // Slow path: compile the shader.
288            let source = self.sources.get(name).ok_or_else(|| {
289                MlxError::KernelNotFound(name.to_string())
290            })?;
291
292            let compile_opts = metal::CompileOptions::new();
293            let library = device
294                .new_library_with_source(source, &compile_opts)
295                .map_err(|msg| MlxError::ShaderCompilationError {
296                    name: name.to_string(),
297                    message: msg,
298                })?;
299
300            let function = library
301                .get_function(name, None)
302                .map_err(|msg| MlxError::ShaderCompilationError {
303                    name: name.to_string(),
304                    message: msg,
305                })?;
306
307            let pipeline = device
308                .new_compute_pipeline_state_with_function(&function)
309                .map_err(|msg| MlxError::ShaderCompilationError {
310                    name: name.to_string(),
311                    message: msg,
312                })?;
313
314            self.cache.insert(name.to_string(), pipeline);
315        }
316
317        // At this point the pipeline is guaranteed to be in the cache.
318        // We use `ok_or_else` instead of `expect` to satisfy the no-panic policy.
319        self.cache.get(name).ok_or_else(|| {
320            MlxError::KernelNotFound(name.to_string())
321        })
322    }
323
324    /// Check if a pipeline for the given name is already compiled and cached.
325    pub fn is_cached(&self, name: &str) -> bool {
326        self.cache.contains_key(name)
327    }
328
329    /// Number of compiled pipelines currently in the cache.
330    pub fn cached_count(&self) -> usize {
331        self.cache.len()
332    }
333
334    /// Number of registered shader sources.
335    pub fn source_count(&self) -> usize {
336        self.sources.len()
337    }
338}
339
340impl Default for KernelRegistry {
341    fn default() -> Self {
342        Self::new()
343    }
344}