1use std::collections::HashMap;
8use std::sync::{Arc, Mutex, OnceLock};
9
10use oxillama_gguf::GgufTensorType;
11
12use crate::error::{QuantError, QuantResult};
13use crate::reference::{
14 Bf16Ref, F16Ref, F32Ref, Iq1MRef, Iq1SRef, Iq2SRef, Iq2XsRef, Iq2XxsRef, Iq3SRef, Iq3XxsRef,
15 Iq4NlRef, Iq4XsRef, Q1_0G128Ref, Q2KRef, Q3KRef, Q4KRef, Q4_0Ref, Q4_1Ref, Q5KRef, Q5_0Ref,
16 Q5_1Ref, Q6KRef, Q8KRef, Q8_0Ref, Q8_1Ref, Tq1_0Ref, Tq2_0Ref,
17};
18#[cfg(any(
19 all(feature = "simd-avx512", target_arch = "x86_64"),
20 all(feature = "simd-avx2", target_arch = "x86_64"),
21 all(feature = "simd-neon", target_arch = "aarch64"),
22))]
23use crate::simd;
24use crate::simd::float_gemm::{Bf16OxiblasKernel, F16OxiblasKernel, F32OxiblasKernel};
25use crate::traits::QuantKernel;
26
/// SIMD instruction-set extensions relevant to kernel selection.
///
/// Normally filled in by [`SimdCapabilities::detect`]. The `Default` value has
/// every flag `false`, which corresponds to the `"scalar"` tier. That makes it
/// handy for tests or for describing a baseline machine.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct SimdCapabilities {
    /// x86-64 AVX2 is available.
    pub avx2: bool,
    /// x86-64 AVX-512 Foundation is available.
    pub avx512f: bool,
    /// x86-64 FMA (fused multiply-add) is available.
    pub fma: bool,
    /// AArch64 NEON (Advanced SIMD) is available.
    pub neon: bool,
}
39
40impl SimdCapabilities {
41 pub fn detect() -> Self {
43 Self {
44 avx2: Self::detect_avx2(),
45 avx512f: Self::detect_avx512f(),
46 fma: Self::detect_fma(),
47 neon: Self::detect_neon(),
48 }
49 }
50
51 pub fn best_tier(&self) -> &'static str {
53 if self.avx512f {
54 "AVX-512"
55 } else if self.avx2 && self.fma {
56 "AVX2+FMA"
57 } else if self.avx2 {
58 "AVX2"
59 } else if self.neon {
60 "NEON"
61 } else {
62 "scalar"
63 }
64 }
65
66 #[cfg(target_arch = "x86_64")]
67 fn detect_avx2() -> bool {
68 std::arch::is_x86_feature_detected!("avx2")
69 }
70
71 #[cfg(not(target_arch = "x86_64"))]
72 fn detect_avx2() -> bool {
73 false
74 }
75
76 #[cfg(target_arch = "x86_64")]
77 fn detect_avx512f() -> bool {
78 std::arch::is_x86_feature_detected!("avx512f")
79 }
80
81 #[cfg(not(target_arch = "x86_64"))]
82 fn detect_avx512f() -> bool {
83 false
84 }
85
86 #[cfg(target_arch = "x86_64")]
87 fn detect_fma() -> bool {
88 std::arch::is_x86_feature_detected!("fma")
89 }
90
91 #[cfg(not(target_arch = "x86_64"))]
92 fn detect_fma() -> bool {
93 false
94 }
95
96 #[cfg(target_arch = "aarch64")]
97 fn detect_neon() -> bool {
98 true
100 }
101
102 #[cfg(not(target_arch = "aarch64"))]
103 fn detect_neon() -> bool {
104 false
105 }
106}
107
108#[derive(Debug)]
115pub struct KernelDispatcher {
116 pub capabilities: SimdCapabilities,
118}
119
120impl Default for KernelDispatcher {
121 fn default() -> Self {
122 Self::new()
123 }
124}
125
126impl KernelDispatcher {
127 pub fn new() -> Self {
129 Self {
130 capabilities: SimdCapabilities::detect(),
131 }
132 }
133
134 pub fn get_kernel(&self, tensor_type: GgufTensorType) -> QuantResult<Box<dyn QuantKernel>> {
142 #[cfg(all(feature = "simd-avx512", target_arch = "x86_64"))]
144 if simd::cached_capabilities().avx512f {
145 match tensor_type {
146 GgufTensorType::Q4_0 => return Ok(Box::new(simd::avx512::Q4_0Avx512)),
147 GgufTensorType::Q4_1 => return Ok(Box::new(simd::avx512::Q4_1Avx512)),
148 GgufTensorType::Q8_0 => return Ok(Box::new(simd::avx512::Q8_0Avx512)),
149 GgufTensorType::Q8_1 => return Ok(Box::new(simd::avx512::Q8_1Avx512)),
150 GgufTensorType::Q2K => return Ok(Box::new(simd::avx512::Q2_KAvx512)),
151 GgufTensorType::Q3K => return Ok(Box::new(simd::avx512::Q3_KAvx512)),
152 GgufTensorType::Q4K => return Ok(Box::new(simd::avx512::Q4_KAvx512)),
153 GgufTensorType::Q1_0G128 => return Ok(Box::new(simd::avx512::Q1_0G128Avx512)),
154 GgufTensorType::Q5K => return Ok(Box::new(simd::avx512::Q5_KAvx512)),
155 GgufTensorType::Q5_1 => return Ok(Box::new(simd::avx512::Q5_1Avx512)),
156 GgufTensorType::Q6K => return Ok(Box::new(simd::avx512::Q6_KAvx512)),
157 GgufTensorType::Tq1_0 => return Ok(Box::new(simd::avx512::Tq1_0Avx512)),
158 GgufTensorType::Tq2_0 => return Ok(Box::new(simd::avx512::Tq2_0Avx512)),
159 GgufTensorType::Q5_0 => return Ok(Box::new(simd::avx512::Q5_0Avx512)),
160 GgufTensorType::Q8K => return Ok(Box::new(simd::avx512::Q8_KAvx512)),
161 GgufTensorType::Iq2Xxs => return Ok(Box::new(simd::avx512::Iq2XxsAvx512)),
162 GgufTensorType::Iq2Xs => return Ok(Box::new(simd::avx512::Iq2XsAvx512)),
163 GgufTensorType::Iq3S => return Ok(Box::new(simd::avx512::Iq3SAvx512)),
164 GgufTensorType::Iq4Xs => return Ok(Box::new(simd::avx512::Iq4XsAvx512)),
165 _ => {}
166 }
167 }
168
169 #[cfg(all(feature = "simd-avx2", target_arch = "x86_64"))]
171 if simd::cached_capabilities().avx2 {
172 match tensor_type {
173 GgufTensorType::Q4_0 => return Ok(Box::new(simd::avx2::Q4_0Avx2)),
174 GgufTensorType::Q5_0 => return Ok(Box::new(simd::avx2::Q5_0Avx2)),
175 GgufTensorType::Q8_0 => return Ok(Box::new(simd::avx2::Q8_0Avx2)),
176 GgufTensorType::Q4K => return Ok(Box::new(simd::avx2::Q4_KAvx2)),
177 GgufTensorType::Q5K => return Ok(Box::new(simd::avx2::Q5_KAvx2)),
178 GgufTensorType::Q6K => return Ok(Box::new(simd::avx2::Q6_KAvx2)),
179 GgufTensorType::Q1_0G128 => return Ok(Box::new(simd::avx2::Q1_0G128Avx2)),
180 GgufTensorType::Q2K => return Ok(Box::new(simd::avx2::Q2_KAvx2)),
181 GgufTensorType::Q3K => return Ok(Box::new(simd::avx2::Q3_KAvx2)),
182 GgufTensorType::Q4_1 => return Ok(Box::new(simd::avx2::Q4_1Avx2)),
183 GgufTensorType::Q5_1 => return Ok(Box::new(simd::avx2::Q5_1Avx2)),
184 GgufTensorType::Q8_1 => return Ok(Box::new(simd::avx2::Q8_1Avx2)),
185 GgufTensorType::Iq1S => return Ok(Box::new(simd::avx2::Iq1SAvx2)),
186 GgufTensorType::Iq1M => return Ok(Box::new(simd::avx2::Iq1MAvx2)),
187 GgufTensorType::Iq2Xs => return Ok(Box::new(simd::avx2::Iq2XsAvx2)),
188 GgufTensorType::Iq2Xxs => return Ok(Box::new(simd::avx2::Iq2XxsAvx2)),
189 GgufTensorType::Iq2S => return Ok(Box::new(simd::avx2::Iq2SAvx2)),
190 GgufTensorType::Iq3Xxs => return Ok(Box::new(simd::avx2::Iq3XxsAvx2)),
191 GgufTensorType::Iq3S => return Ok(Box::new(simd::avx2::Iq3SAvx2)),
192 GgufTensorType::Iq4Nl => return Ok(Box::new(simd::avx2::Iq4NlAvx2)),
193 GgufTensorType::Iq4Xs => return Ok(Box::new(simd::avx2::Iq4XsAvx2)),
194 GgufTensorType::Q8K => return Ok(Box::new(simd::avx2::Q8_KAvx2)),
195 GgufTensorType::Tq1_0 => return Ok(Box::new(simd::avx2::Tq1_0Avx2)),
196 GgufTensorType::Tq2_0 => return Ok(Box::new(simd::avx2::Tq2_0Avx2)),
197 _ => {}
198 }
199 }
200
201 #[cfg(all(feature = "simd-neon", target_arch = "aarch64"))]
203 if simd::cached_capabilities().neon {
204 match tensor_type {
205 GgufTensorType::Q4_0 => return Ok(Box::new(simd::neon::Q4_0Neon)),
206 GgufTensorType::Q4_1 => return Ok(Box::new(simd::neon::Q4_1Neon)),
207 GgufTensorType::Q5_0 => return Ok(Box::new(simd::neon::Q5_0Neon)),
208 GgufTensorType::Q5_1 => return Ok(Box::new(simd::neon::Q5_1Neon)),
209 GgufTensorType::Q8_0 => return Ok(Box::new(simd::neon::Q8_0Neon)),
210 GgufTensorType::Q8_1 => return Ok(Box::new(simd::neon::Q8_1Neon)),
211 GgufTensorType::Q2K => return Ok(Box::new(simd::neon::Q2_KNeon)),
212 GgufTensorType::Q3K => return Ok(Box::new(simd::neon::Q3_KNeon)),
213 GgufTensorType::Q4K => return Ok(Box::new(simd::neon::Q4_KNeon)),
214 GgufTensorType::Q5K => return Ok(Box::new(simd::neon::Q5_KNeon)),
215 GgufTensorType::Q1_0G128 => return Ok(Box::new(simd::neon::Q1_0G128Neon)),
216 GgufTensorType::Q6K => return Ok(Box::new(simd::neon::Q6_KNeon)),
217 GgufTensorType::Q8K => return Ok(Box::new(simd::neon::Q8_KNeon)),
218 GgufTensorType::Iq1S => return Ok(Box::new(simd::neon::Iq1SNeon)),
219 GgufTensorType::Iq1M => return Ok(Box::new(simd::neon::Iq1MNeon)),
220 GgufTensorType::Iq2S => return Ok(Box::new(simd::neon::Iq2SNeon)),
221 GgufTensorType::Iq2Xxs => return Ok(Box::new(simd::neon::Iq2XxsNeon)),
222 GgufTensorType::Iq2Xs => return Ok(Box::new(simd::neon::Iq2XsNeon)),
223 GgufTensorType::Iq3Xxs => return Ok(Box::new(simd::neon::Iq3XxsNeon)),
224 GgufTensorType::Iq3S => return Ok(Box::new(simd::neon::Iq3SNeon)),
225 GgufTensorType::Iq4Xs => return Ok(Box::new(simd::neon::Iq4XsNeon)),
226 GgufTensorType::Iq4Nl => return Ok(Box::new(simd::neon::Iq4NlNeon)),
227 GgufTensorType::Tq1_0 => return Ok(Box::new(simd::neon::Tq1_0Neon)),
228 GgufTensorType::Tq2_0 => return Ok(Box::new(simd::neon::Tq2_0Neon)),
229 _ => {}
230 }
231 }
232
233 match tensor_type {
235 GgufTensorType::F32 => return Ok(Box::new(F32OxiblasKernel)),
236 GgufTensorType::F16 => return Ok(Box::new(F16OxiblasKernel)),
237 GgufTensorType::Bf16 => return Ok(Box::new(Bf16OxiblasKernel)),
238 _ => {}
239 }
240
241 self.get_reference_kernel(tensor_type)
243 }
244
245 fn get_reference_kernel(
247 &self,
248 tensor_type: GgufTensorType,
249 ) -> QuantResult<Box<dyn QuantKernel>> {
250 match tensor_type {
251 GgufTensorType::F32 => Ok(Box::new(F32Ref)),
252 GgufTensorType::F16 => Ok(Box::new(F16Ref)),
253 GgufTensorType::Bf16 => Ok(Box::new(Bf16Ref)),
254 GgufTensorType::Q4_0 => Ok(Box::new(Q4_0Ref)),
255 GgufTensorType::Q4_1 => Ok(Box::new(Q4_1Ref)),
256 GgufTensorType::Q5_0 => Ok(Box::new(Q5_0Ref)),
257 GgufTensorType::Q5_1 => Ok(Box::new(Q5_1Ref)),
258 GgufTensorType::Q8_0 => Ok(Box::new(Q8_0Ref)),
259 GgufTensorType::Q8_1 => Ok(Box::new(Q8_1Ref)),
260 GgufTensorType::Q2K => Ok(Box::new(Q2KRef)),
261 GgufTensorType::Q3K => Ok(Box::new(Q3KRef)),
262 GgufTensorType::Q4K => Ok(Box::new(Q4KRef)),
263 GgufTensorType::Q5K => Ok(Box::new(Q5KRef)),
264 GgufTensorType::Q6K => Ok(Box::new(Q6KRef)),
265 GgufTensorType::Q8K => Ok(Box::new(Q8KRef)),
266 GgufTensorType::Q1_0G128 => Ok(Box::new(Q1_0G128Ref)),
267 GgufTensorType::Iq1S => Ok(Box::new(Iq1SRef)),
268 GgufTensorType::Iq1M => Ok(Box::new(Iq1MRef)),
269 GgufTensorType::Iq2Xxs => Ok(Box::new(Iq2XxsRef)),
270 GgufTensorType::Iq2Xs => Ok(Box::new(Iq2XsRef)),
271 GgufTensorType::Iq2S => Ok(Box::new(Iq2SRef)),
272 GgufTensorType::Iq3Xxs => Ok(Box::new(Iq3XxsRef)),
273 GgufTensorType::Iq3S => Ok(Box::new(Iq3SRef)),
274 GgufTensorType::Iq4Nl => Ok(Box::new(Iq4NlRef)),
275 GgufTensorType::Iq4Xs => Ok(Box::new(Iq4XsRef)),
276 GgufTensorType::Tq1_0 => Ok(Box::new(Tq1_0Ref)),
277 GgufTensorType::Tq2_0 => Ok(Box::new(Tq2_0Ref)),
278 _ => Err(QuantError::UnsupportedType {
279 quant_type: tensor_type.name().to_string(),
280 }),
281 }
282 }
283
284 pub fn is_supported(&self, tensor_type: GgufTensorType) -> bool {
286 matches!(
287 tensor_type,
288 GgufTensorType::F32
289 | GgufTensorType::F16
290 | GgufTensorType::Bf16
291 | GgufTensorType::Q4_0
292 | GgufTensorType::Q4_1
293 | GgufTensorType::Q5_0
294 | GgufTensorType::Q5_1
295 | GgufTensorType::Q8_0
296 | GgufTensorType::Q8_1
297 | GgufTensorType::Q2K
298 | GgufTensorType::Q3K
299 | GgufTensorType::Q4K
300 | GgufTensorType::Q5K
301 | GgufTensorType::Q6K
302 | GgufTensorType::Q8K
303 | GgufTensorType::Q1_0G128
304 | GgufTensorType::Iq1S
305 | GgufTensorType::Iq1M
306 | GgufTensorType::Iq2Xxs
307 | GgufTensorType::Iq2Xs
308 | GgufTensorType::Iq2S
309 | GgufTensorType::Iq3Xxs
310 | GgufTensorType::Iq3S
311 | GgufTensorType::Iq4Nl
312 | GgufTensorType::Iq4Xs
313 | GgufTensorType::Tq1_0
314 | GgufTensorType::Tq2_0
315 )
316 }
317
318 pub fn supported_types(&self) -> Vec<GgufTensorType> {
320 vec![
321 GgufTensorType::F32,
322 GgufTensorType::F16,
323 GgufTensorType::Bf16,
324 GgufTensorType::Q2K,
325 GgufTensorType::Q3K,
326 GgufTensorType::Q4_0,
327 GgufTensorType::Q4_1,
328 GgufTensorType::Q4K,
329 GgufTensorType::Q5_0,
330 GgufTensorType::Q5_1,
331 GgufTensorType::Q5K,
332 GgufTensorType::Q6K,
333 GgufTensorType::Q8_0,
334 GgufTensorType::Q8_1,
335 GgufTensorType::Q8K,
336 GgufTensorType::Q1_0G128,
337 GgufTensorType::Iq1S,
338 GgufTensorType::Iq1M,
339 GgufTensorType::Iq2Xxs,
340 GgufTensorType::Iq2Xs,
341 GgufTensorType::Iq2S,
342 GgufTensorType::Iq3Xxs,
343 GgufTensorType::Iq3S,
344 GgufTensorType::Iq4Nl,
345 GgufTensorType::Iq4Xs,
346 GgufTensorType::Tq1_0,
347 GgufTensorType::Tq2_0,
348 ]
349 }
350}
351
352pub struct CachedDispatcher {
357 inner: KernelDispatcher,
358 cache: Mutex<HashMap<GgufTensorType, Arc<dyn QuantKernel>>>,
359}
360
361impl CachedDispatcher {
362 pub fn new() -> Self {
364 Self {
365 inner: KernelDispatcher::new(),
366 cache: Mutex::new(HashMap::new()),
367 }
368 }
369
370 pub fn get_kernel(&self, tensor_type: GgufTensorType) -> QuantResult<Arc<dyn QuantKernel>> {
375 {
377 let cache = self.cache.lock().map_err(|_| QuantError::Internal {
378 message: "kernel cache lock poisoned".to_string(),
379 })?;
380 if let Some(kernel) = cache.get(&tensor_type) {
381 return Ok(Arc::clone(kernel));
382 }
383 }
384
385 let kernel: Arc<dyn QuantKernel> = self.inner.get_kernel(tensor_type)?.into();
387 let mut cache = self.cache.lock().map_err(|_| QuantError::Internal {
388 message: "kernel cache lock poisoned".to_string(),
389 })?;
390 cache
391 .entry(tensor_type)
392 .or_insert_with(|| Arc::clone(&kernel));
393 Ok(kernel)
394 }
395
396 pub fn capabilities(&self) -> &SimdCapabilities {
398 &self.inner.capabilities
399 }
400
401 pub fn is_supported(&self, tensor_type: GgufTensorType) -> bool {
403 self.inner.is_supported(tensor_type)
404 }
405
406 pub fn supported_types(&self) -> Vec<GgufTensorType> {
408 self.inner.supported_types()
409 }
410}
411
412impl Default for CachedDispatcher {
413 fn default() -> Self {
414 Self::new()
415 }
416}
417
418static GLOBAL_DISPATCHER: OnceLock<CachedDispatcher> = OnceLock::new();
420
421pub fn global_dispatcher() -> &'static CachedDispatcher {
426 GLOBAL_DISPATCHER.get_or_init(CachedDispatcher::new)
427}
428
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simd_detection() {
        let caps = SimdCapabilities::detect();
        let tier = caps.best_tier();
        // Whatever the host, some tier name (at worst "scalar") is reported.
        assert!(!tier.is_empty());
    }

    #[test]
    fn test_dispatcher_returns_all_supported() {
        let dispatcher = KernelDispatcher::new();
        for tensor_type in dispatcher.supported_types() {
            let kernel = dispatcher.get_kernel(tensor_type);
            assert!(kernel.is_ok(), "failed to get kernel for {:?}", tensor_type);
        }
    }

    #[test]
    fn test_dispatcher_unsupported() {
        let dispatcher = KernelDispatcher::new();
        // Ternary types are supported.
        assert!(dispatcher.is_supported(GgufTensorType::Tq1_0));
        assert!(dispatcher.is_supported(GgufTensorType::Tq2_0));
        // Q4_0_4_4 is not in the supported set.
        assert!(!dispatcher.is_supported(GgufTensorType::Q4_0_4_4));
        // IQ-family types are supported.
        assert!(dispatcher.is_supported(GgufTensorType::Iq1S));
        assert!(dispatcher.is_supported(GgufTensorType::Iq4Nl));
    }

    #[test]
    fn test_is_supported_agrees_with_supported_types() {
        let dispatcher = KernelDispatcher::new();
        let listed = dispatcher.supported_types();
        for t in &listed {
            assert!(
                dispatcher.is_supported(*t),
                "{:?} is listed by supported_types but is_supported says no",
                t
            );
        }
        assert!(!listed.contains(&GgufTensorType::Q4_0_4_4));
    }

    #[test]
    fn test_unsupported_type_is_rejected() {
        let dispatcher = KernelDispatcher::new();
        assert!(matches!(
            dispatcher.get_kernel(GgufTensorType::Q4_0_4_4),
            Err(QuantError::UnsupportedType { .. })
        ));
        let cached = CachedDispatcher::new();
        assert!(cached.get_kernel(GgufTensorType::Q4_0_4_4).is_err());
    }

    #[test]
    fn test_cached_dispatcher_returns_same_kernel() {
        let dispatcher = CachedDispatcher::new();
        let k1 = dispatcher.get_kernel(GgufTensorType::Q4_0).expect("k1");
        let k2 = dispatcher.get_kernel(GgufTensorType::Q4_0).expect("k2");
        assert_eq!(k1.name(), k2.name());
        assert!(
            Arc::ptr_eq(&k1, &k2),
            "second call should return cached Arc"
        );
    }

    #[test]
    fn test_global_dispatcher_singleton() {
        let d1 = global_dispatcher();
        let d2 = global_dispatcher();
        assert!(
            std::ptr::eq(d1, d2),
            "global_dispatcher should return same instance"
        );
    }

    #[test]
    fn test_cached_dispatcher_all_types() {
        let dispatcher = CachedDispatcher::new();
        for t in dispatcher.supported_types() {
            let k = dispatcher.get_kernel(t);
            assert!(k.is_ok(), "cached dispatch failed for {:?}", t);
        }
    }
}