Skip to main content

oxillama_quant/
dispatch.rs

1//! Runtime kernel selection and dispatch.
2//!
3//! Selects the best available [`QuantKernel`] implementation for a given
4//! quantization type based on compile-time feature flags and runtime
5//! CPU feature detection.
6
7use std::collections::HashMap;
8use std::sync::{Arc, Mutex, OnceLock};
9
10use oxillama_gguf::GgufTensorType;
11
12use crate::error::{QuantError, QuantResult};
13use crate::reference::{
14    Bf16Ref, F16Ref, F32Ref, Iq1MRef, Iq1SRef, Iq2SRef, Iq2XsRef, Iq2XxsRef, Iq3SRef, Iq3XxsRef,
15    Iq4NlRef, Iq4XsRef, Q1_0G128Ref, Q2KRef, Q3KRef, Q4KRef, Q4_0Ref, Q4_1Ref, Q5KRef, Q5_0Ref,
16    Q5_1Ref, Q6KRef, Q8KRef, Q8_0Ref, Q8_1Ref, Tq1_0Ref, Tq2_0Ref,
17};
18#[cfg(any(
19    all(feature = "simd-avx512", target_arch = "x86_64"),
20    all(feature = "simd-avx2", target_arch = "x86_64"),
21    all(feature = "simd-neon", target_arch = "aarch64"),
22))]
23use crate::simd;
24use crate::simd::float_gemm::{Bf16OxiblasKernel, F16OxiblasKernel, F32OxiblasKernel};
25use crate::traits::QuantKernel;
26
/// Snapshot of the SIMD instruction-set extensions found on the running CPU.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SimdCapabilities {
    /// x86_64 AVX2 support.
    pub avx2: bool,
    /// x86_64 AVX-512F support.
    pub avx512f: bool,
    /// x86_64 FMA support (usually paired with AVX2).
    pub fma: bool,
    /// ARM NEON support.
    pub neon: bool,
}

impl SimdCapabilities {
    /// Probe the running CPU and record which SIMD extensions it offers.
    ///
    /// The x86 flags are queried at runtime on x86_64 and are `false` on
    /// every other architecture. `neon` is decided at compile time.
    pub fn detect() -> Self {
        #[cfg(target_arch = "x86_64")]
        let (avx2, avx512f, fma) = (
            std::arch::is_x86_feature_detected!("avx2"),
            std::arch::is_x86_feature_detected!("avx512f"),
            std::arch::is_x86_feature_detected!("fma"),
        );
        #[cfg(not(target_arch = "x86_64"))]
        let (avx2, avx512f, fma) = (false, false, false);

        Self {
            avx2,
            avx512f,
            fma,
            // NEON is mandatory on aarch64, so the target alone decides it.
            neon: cfg!(target_arch = "aarch64"),
        }
    }

    /// Human-readable name of the highest SIMD tier these flags describe.
    ///
    /// Precedence is AVX-512, then AVX2 (reported as `"AVX2+FMA"` when FMA is
    /// also set), then NEON, and finally `"scalar"` when nothing is set.
    pub fn best_tier(&self) -> &'static str {
        match (self.avx512f, self.avx2, self.fma, self.neon) {
            (true, _, _, _) => "AVX-512",
            (false, true, true, _) => "AVX2+FMA",
            (false, true, false, _) => "AVX2",
            (false, false, _, true) => "NEON",
            (false, false, _, false) => "scalar",
        }
    }
}
107
108/// Kernel dispatcher — selects and caches the best kernel for each quant type.
109///
110/// The dispatcher checks (in order):
111/// 1. Platform-specific SIMD kernels (AVX-512, AVX2, NEON) if features enabled.
112/// 2. Portable SIMD kernels.
113/// 3. Reference (naive) scalar kernels.
114#[derive(Debug)]
115pub struct KernelDispatcher {
116    /// Detected CPU SIMD capabilities.
117    pub capabilities: SimdCapabilities,
118}
119
120impl Default for KernelDispatcher {
121    fn default() -> Self {
122        Self::new()
123    }
124}
125
126impl KernelDispatcher {
127    /// Create a new dispatcher with runtime CPU feature detection.
128    pub fn new() -> Self {
129        Self {
130            capabilities: SimdCapabilities::detect(),
131        }
132    }
133
134    /// Get the best available kernel for the given quantization type.
135    ///
136    /// Currently returns reference (scalar) kernels for all types.
137    /// When SIMD features are enabled and the CPU supports them,
138    /// this will return optimized SIMD kernels instead.
139    ///
140    /// Returns an error if the quantization type is not yet implemented.
141    pub fn get_kernel(&self, tensor_type: GgufTensorType) -> QuantResult<Box<dyn QuantKernel>> {
142        // 1. AVX-512 path
143        #[cfg(all(feature = "simd-avx512", target_arch = "x86_64"))]
144        if simd::cached_capabilities().avx512f {
145            match tensor_type {
146                GgufTensorType::Q4_0 => return Ok(Box::new(simd::avx512::Q4_0Avx512)),
147                GgufTensorType::Q4_1 => return Ok(Box::new(simd::avx512::Q4_1Avx512)),
148                GgufTensorType::Q8_0 => return Ok(Box::new(simd::avx512::Q8_0Avx512)),
149                GgufTensorType::Q8_1 => return Ok(Box::new(simd::avx512::Q8_1Avx512)),
150                GgufTensorType::Q2K => return Ok(Box::new(simd::avx512::Q2_KAvx512)),
151                GgufTensorType::Q3K => return Ok(Box::new(simd::avx512::Q3_KAvx512)),
152                GgufTensorType::Q4K => return Ok(Box::new(simd::avx512::Q4_KAvx512)),
153                GgufTensorType::Q1_0G128 => return Ok(Box::new(simd::avx512::Q1_0G128Avx512)),
154                GgufTensorType::Q5K => return Ok(Box::new(simd::avx512::Q5_KAvx512)),
155                GgufTensorType::Q5_1 => return Ok(Box::new(simd::avx512::Q5_1Avx512)),
156                GgufTensorType::Q6K => return Ok(Box::new(simd::avx512::Q6_KAvx512)),
157                GgufTensorType::Tq1_0 => return Ok(Box::new(simd::avx512::Tq1_0Avx512)),
158                GgufTensorType::Tq2_0 => return Ok(Box::new(simd::avx512::Tq2_0Avx512)),
159                GgufTensorType::Q5_0 => return Ok(Box::new(simd::avx512::Q5_0Avx512)),
160                GgufTensorType::Q8K => return Ok(Box::new(simd::avx512::Q8_KAvx512)),
161                GgufTensorType::Iq2Xxs => return Ok(Box::new(simd::avx512::Iq2XxsAvx512)),
162                GgufTensorType::Iq2Xs => return Ok(Box::new(simd::avx512::Iq2XsAvx512)),
163                GgufTensorType::Iq3S => return Ok(Box::new(simd::avx512::Iq3SAvx512)),
164                GgufTensorType::Iq4Xs => return Ok(Box::new(simd::avx512::Iq4XsAvx512)),
165                _ => {}
166            }
167        }
168
169        // 2. AVX2 path
170        #[cfg(all(feature = "simd-avx2", target_arch = "x86_64"))]
171        if simd::cached_capabilities().avx2 {
172            match tensor_type {
173                GgufTensorType::Q4_0 => return Ok(Box::new(simd::avx2::Q4_0Avx2)),
174                GgufTensorType::Q5_0 => return Ok(Box::new(simd::avx2::Q5_0Avx2)),
175                GgufTensorType::Q8_0 => return Ok(Box::new(simd::avx2::Q8_0Avx2)),
176                GgufTensorType::Q4K => return Ok(Box::new(simd::avx2::Q4_KAvx2)),
177                GgufTensorType::Q5K => return Ok(Box::new(simd::avx2::Q5_KAvx2)),
178                GgufTensorType::Q6K => return Ok(Box::new(simd::avx2::Q6_KAvx2)),
179                GgufTensorType::Q1_0G128 => return Ok(Box::new(simd::avx2::Q1_0G128Avx2)),
180                GgufTensorType::Q2K => return Ok(Box::new(simd::avx2::Q2_KAvx2)),
181                GgufTensorType::Q3K => return Ok(Box::new(simd::avx2::Q3_KAvx2)),
182                GgufTensorType::Q4_1 => return Ok(Box::new(simd::avx2::Q4_1Avx2)),
183                GgufTensorType::Q5_1 => return Ok(Box::new(simd::avx2::Q5_1Avx2)),
184                GgufTensorType::Q8_1 => return Ok(Box::new(simd::avx2::Q8_1Avx2)),
185                GgufTensorType::Iq1S => return Ok(Box::new(simd::avx2::Iq1SAvx2)),
186                GgufTensorType::Iq1M => return Ok(Box::new(simd::avx2::Iq1MAvx2)),
187                GgufTensorType::Iq2Xs => return Ok(Box::new(simd::avx2::Iq2XsAvx2)),
188                GgufTensorType::Iq2Xxs => return Ok(Box::new(simd::avx2::Iq2XxsAvx2)),
189                GgufTensorType::Iq2S => return Ok(Box::new(simd::avx2::Iq2SAvx2)),
190                GgufTensorType::Iq3Xxs => return Ok(Box::new(simd::avx2::Iq3XxsAvx2)),
191                GgufTensorType::Iq3S => return Ok(Box::new(simd::avx2::Iq3SAvx2)),
192                GgufTensorType::Iq4Nl => return Ok(Box::new(simd::avx2::Iq4NlAvx2)),
193                GgufTensorType::Iq4Xs => return Ok(Box::new(simd::avx2::Iq4XsAvx2)),
194                GgufTensorType::Q8K => return Ok(Box::new(simd::avx2::Q8_KAvx2)),
195                GgufTensorType::Tq1_0 => return Ok(Box::new(simd::avx2::Tq1_0Avx2)),
196                GgufTensorType::Tq2_0 => return Ok(Box::new(simd::avx2::Tq2_0Avx2)),
197                _ => {}
198            }
199        }
200
201        // 3. NEON path
202        #[cfg(all(feature = "simd-neon", target_arch = "aarch64"))]
203        if simd::cached_capabilities().neon {
204            match tensor_type {
205                GgufTensorType::Q4_0 => return Ok(Box::new(simd::neon::Q4_0Neon)),
206                GgufTensorType::Q4_1 => return Ok(Box::new(simd::neon::Q4_1Neon)),
207                GgufTensorType::Q5_0 => return Ok(Box::new(simd::neon::Q5_0Neon)),
208                GgufTensorType::Q5_1 => return Ok(Box::new(simd::neon::Q5_1Neon)),
209                GgufTensorType::Q8_0 => return Ok(Box::new(simd::neon::Q8_0Neon)),
210                GgufTensorType::Q8_1 => return Ok(Box::new(simd::neon::Q8_1Neon)),
211                GgufTensorType::Q2K => return Ok(Box::new(simd::neon::Q2_KNeon)),
212                GgufTensorType::Q3K => return Ok(Box::new(simd::neon::Q3_KNeon)),
213                GgufTensorType::Q4K => return Ok(Box::new(simd::neon::Q4_KNeon)),
214                GgufTensorType::Q5K => return Ok(Box::new(simd::neon::Q5_KNeon)),
215                GgufTensorType::Q1_0G128 => return Ok(Box::new(simd::neon::Q1_0G128Neon)),
216                GgufTensorType::Q6K => return Ok(Box::new(simd::neon::Q6_KNeon)),
217                GgufTensorType::Q8K => return Ok(Box::new(simd::neon::Q8_KNeon)),
218                GgufTensorType::Iq1S => return Ok(Box::new(simd::neon::Iq1SNeon)),
219                GgufTensorType::Iq1M => return Ok(Box::new(simd::neon::Iq1MNeon)),
220                GgufTensorType::Iq2S => return Ok(Box::new(simd::neon::Iq2SNeon)),
221                GgufTensorType::Iq2Xxs => return Ok(Box::new(simd::neon::Iq2XxsNeon)),
222                GgufTensorType::Iq2Xs => return Ok(Box::new(simd::neon::Iq2XsNeon)),
223                GgufTensorType::Iq3Xxs => return Ok(Box::new(simd::neon::Iq3XxsNeon)),
224                GgufTensorType::Iq3S => return Ok(Box::new(simd::neon::Iq3SNeon)),
225                GgufTensorType::Iq4Xs => return Ok(Box::new(simd::neon::Iq4XsNeon)),
226                GgufTensorType::Iq4Nl => return Ok(Box::new(simd::neon::Iq4NlNeon)),
227                GgufTensorType::Tq1_0 => return Ok(Box::new(simd::neon::Tq1_0Neon)),
228                GgufTensorType::Tq2_0 => return Ok(Box::new(simd::neon::Tq2_0Neon)),
229                _ => {}
230            }
231        }
232
233        // 4. oxiblas-backed float kernels (always available, no CPU feature gate needed)
234        match tensor_type {
235            GgufTensorType::F32 => return Ok(Box::new(F32OxiblasKernel)),
236            GgufTensorType::F16 => return Ok(Box::new(F16OxiblasKernel)),
237            GgufTensorType::Bf16 => return Ok(Box::new(Bf16OxiblasKernel)),
238            _ => {}
239        }
240
241        // 5. Scalar reference fallback
242        self.get_reference_kernel(tensor_type)
243    }
244
245    /// Get the reference (scalar) kernel for a given type.
246    fn get_reference_kernel(
247        &self,
248        tensor_type: GgufTensorType,
249    ) -> QuantResult<Box<dyn QuantKernel>> {
250        match tensor_type {
251            GgufTensorType::F32 => Ok(Box::new(F32Ref)),
252            GgufTensorType::F16 => Ok(Box::new(F16Ref)),
253            GgufTensorType::Bf16 => Ok(Box::new(Bf16Ref)),
254            GgufTensorType::Q4_0 => Ok(Box::new(Q4_0Ref)),
255            GgufTensorType::Q4_1 => Ok(Box::new(Q4_1Ref)),
256            GgufTensorType::Q5_0 => Ok(Box::new(Q5_0Ref)),
257            GgufTensorType::Q5_1 => Ok(Box::new(Q5_1Ref)),
258            GgufTensorType::Q8_0 => Ok(Box::new(Q8_0Ref)),
259            GgufTensorType::Q8_1 => Ok(Box::new(Q8_1Ref)),
260            GgufTensorType::Q2K => Ok(Box::new(Q2KRef)),
261            GgufTensorType::Q3K => Ok(Box::new(Q3KRef)),
262            GgufTensorType::Q4K => Ok(Box::new(Q4KRef)),
263            GgufTensorType::Q5K => Ok(Box::new(Q5KRef)),
264            GgufTensorType::Q6K => Ok(Box::new(Q6KRef)),
265            GgufTensorType::Q8K => Ok(Box::new(Q8KRef)),
266            GgufTensorType::Q1_0G128 => Ok(Box::new(Q1_0G128Ref)),
267            GgufTensorType::Iq1S => Ok(Box::new(Iq1SRef)),
268            GgufTensorType::Iq1M => Ok(Box::new(Iq1MRef)),
269            GgufTensorType::Iq2Xxs => Ok(Box::new(Iq2XxsRef)),
270            GgufTensorType::Iq2Xs => Ok(Box::new(Iq2XsRef)),
271            GgufTensorType::Iq2S => Ok(Box::new(Iq2SRef)),
272            GgufTensorType::Iq3Xxs => Ok(Box::new(Iq3XxsRef)),
273            GgufTensorType::Iq3S => Ok(Box::new(Iq3SRef)),
274            GgufTensorType::Iq4Nl => Ok(Box::new(Iq4NlRef)),
275            GgufTensorType::Iq4Xs => Ok(Box::new(Iq4XsRef)),
276            GgufTensorType::Tq1_0 => Ok(Box::new(Tq1_0Ref)),
277            GgufTensorType::Tq2_0 => Ok(Box::new(Tq2_0Ref)),
278            _ => Err(QuantError::UnsupportedType {
279                quant_type: tensor_type.name().to_string(),
280            }),
281        }
282    }
283
284    /// Check if a kernel is available for the given quantization type.
285    pub fn is_supported(&self, tensor_type: GgufTensorType) -> bool {
286        matches!(
287            tensor_type,
288            GgufTensorType::F32
289                | GgufTensorType::F16
290                | GgufTensorType::Bf16
291                | GgufTensorType::Q4_0
292                | GgufTensorType::Q4_1
293                | GgufTensorType::Q5_0
294                | GgufTensorType::Q5_1
295                | GgufTensorType::Q8_0
296                | GgufTensorType::Q8_1
297                | GgufTensorType::Q2K
298                | GgufTensorType::Q3K
299                | GgufTensorType::Q4K
300                | GgufTensorType::Q5K
301                | GgufTensorType::Q6K
302                | GgufTensorType::Q8K
303                | GgufTensorType::Q1_0G128
304                | GgufTensorType::Iq1S
305                | GgufTensorType::Iq1M
306                | GgufTensorType::Iq2Xxs
307                | GgufTensorType::Iq2Xs
308                | GgufTensorType::Iq2S
309                | GgufTensorType::Iq3Xxs
310                | GgufTensorType::Iq3S
311                | GgufTensorType::Iq4Nl
312                | GgufTensorType::Iq4Xs
313                | GgufTensorType::Tq1_0
314                | GgufTensorType::Tq2_0
315        )
316    }
317
318    /// List all currently supported quantization types.
319    pub fn supported_types(&self) -> Vec<GgufTensorType> {
320        vec![
321            GgufTensorType::F32,
322            GgufTensorType::F16,
323            GgufTensorType::Bf16,
324            GgufTensorType::Q2K,
325            GgufTensorType::Q3K,
326            GgufTensorType::Q4_0,
327            GgufTensorType::Q4_1,
328            GgufTensorType::Q4K,
329            GgufTensorType::Q5_0,
330            GgufTensorType::Q5_1,
331            GgufTensorType::Q5K,
332            GgufTensorType::Q6K,
333            GgufTensorType::Q8_0,
334            GgufTensorType::Q8_1,
335            GgufTensorType::Q8K,
336            GgufTensorType::Q1_0G128,
337            GgufTensorType::Iq1S,
338            GgufTensorType::Iq1M,
339            GgufTensorType::Iq2Xxs,
340            GgufTensorType::Iq2Xs,
341            GgufTensorType::Iq2S,
342            GgufTensorType::Iq3Xxs,
343            GgufTensorType::Iq3S,
344            GgufTensorType::Iq4Nl,
345            GgufTensorType::Iq4Xs,
346            GgufTensorType::Tq1_0,
347            GgufTensorType::Tq2_0,
348        ]
349    }
350}
351
352/// Cached kernel dispatcher — singleton per process.
353///
354/// Returns `Arc<dyn QuantKernel>` instead of `Box<dyn QuantKernel>`,
355/// allowing zero-allocation kernel lookup after the first call for each type.
356pub struct CachedDispatcher {
357    inner: KernelDispatcher,
358    cache: Mutex<HashMap<GgufTensorType, Arc<dyn QuantKernel>>>,
359}
360
361impl CachedDispatcher {
362    /// Create a new cached dispatcher with runtime CPU feature detection.
363    pub fn new() -> Self {
364        Self {
365            inner: KernelDispatcher::new(),
366            cache: Mutex::new(HashMap::new()),
367        }
368    }
369
370    /// Get or create a cached kernel for the given tensor type.
371    ///
372    /// The first call for each type allocates; subsequent calls return
373    /// a clone of the `Arc` (cheap reference count bump).
374    pub fn get_kernel(&self, tensor_type: GgufTensorType) -> QuantResult<Arc<dyn QuantKernel>> {
375        // Fast path: check if already cached
376        {
377            let cache = self.cache.lock().map_err(|_| QuantError::Internal {
378                message: "kernel cache lock poisoned".to_string(),
379            })?;
380            if let Some(kernel) = cache.get(&tensor_type) {
381                return Ok(Arc::clone(kernel));
382            }
383        }
384
385        // Slow path: create and cache
386        let kernel: Arc<dyn QuantKernel> = self.inner.get_kernel(tensor_type)?.into();
387        let mut cache = self.cache.lock().map_err(|_| QuantError::Internal {
388            message: "kernel cache lock poisoned".to_string(),
389        })?;
390        cache
391            .entry(tensor_type)
392            .or_insert_with(|| Arc::clone(&kernel));
393        Ok(kernel)
394    }
395
396    /// Access the underlying capabilities.
397    pub fn capabilities(&self) -> &SimdCapabilities {
398        &self.inner.capabilities
399    }
400
401    /// Check if a type is supported (delegates to inner).
402    pub fn is_supported(&self, tensor_type: GgufTensorType) -> bool {
403        self.inner.is_supported(tensor_type)
404    }
405
406    /// List supported types (delegates to inner).
407    pub fn supported_types(&self) -> Vec<GgufTensorType> {
408        self.inner.supported_types()
409    }
410}
411
412impl Default for CachedDispatcher {
413    fn default() -> Self {
414        Self::new()
415    }
416}
417
418/// Process-global cached dispatcher instance.
419static GLOBAL_DISPATCHER: OnceLock<CachedDispatcher> = OnceLock::new();
420
421/// Get the global cached dispatcher singleton.
422///
423/// The dispatcher is initialized on first call with runtime SIMD detection.
424/// Subsequent calls return a reference to the same instance.
425pub fn global_dispatcher() -> &'static CachedDispatcher {
426    GLOBAL_DISPATCHER.get_or_init(CachedDispatcher::new)
427}
428
#[cfg(test)]
mod tests {
    use super::*;

    /// Detection must not panic and must always yield a named tier.
    #[test]
    fn test_simd_detection() {
        let tier = SimdCapabilities::detect().best_tier();
        assert!(!tier.is_empty());
    }

    /// Every advertised type must actually resolve to a kernel.
    #[test]
    fn test_dispatcher_returns_all_supported() {
        let dispatcher = KernelDispatcher::new();
        for tensor_type in dispatcher.supported_types() {
            assert!(
                dispatcher.get_kernel(tensor_type).is_ok(),
                "failed to get kernel for {:?}",
                tensor_type
            );
        }
    }

    /// Spot-check the supported/unsupported boundary.
    #[test]
    fn test_dispatcher_unsupported() {
        let dispatcher = KernelDispatcher::new();
        // Ternary (TQ) and IQ types are supported.
        for supported in [
            GgufTensorType::Tq1_0,
            GgufTensorType::Tq2_0,
            GgufTensorType::Iq1S,
            GgufTensorType::Iq4Nl,
        ] {
            assert!(dispatcher.is_supported(supported));
        }
        // Experimental packed types are not supported.
        assert!(!dispatcher.is_supported(GgufTensorType::Q4_0_4_4));
    }

    /// Two lookups of one type must share a single cached kernel.
    #[test]
    fn test_cached_dispatcher_returns_same_kernel() {
        let dispatcher = CachedDispatcher::new();
        let first = dispatcher.get_kernel(GgufTensorType::Q4_0).expect("k1");
        let second = dispatcher.get_kernel(GgufTensorType::Q4_0).expect("k2");
        assert_eq!(first.name(), second.name());
        assert!(
            Arc::ptr_eq(&first, &second),
            "second call should return cached Arc"
        );
    }

    /// The global accessor must always hand back the same instance.
    #[test]
    fn test_global_dispatcher_singleton() {
        assert!(
            std::ptr::eq(global_dispatcher(), global_dispatcher()),
            "global_dispatcher should return same instance"
        );
    }

    /// The cached path must serve every advertised type as well.
    #[test]
    fn test_cached_dispatcher_all_types() {
        let dispatcher = CachedDispatcher::new();
        for t in dispatcher.supported_types() {
            assert!(
                dispatcher.get_kernel(t).is_ok(),
                "cached dispatch failed for {:?}",
                t
            );
        }
    }
}