1pub mod buffer;
29pub mod context;
30pub mod error;
31pub mod kernels;
32
33pub use context::GpuContext;
34pub use context::GpuDeviceInfo;
35pub use error::{GpuError, GpuResult};
36pub use kernels::sampling::SamplingKernel;
37pub use kernels::{
38 batched_gemv_f32, supports_f16, BatchedGemvConfig, BatchedGpuKernel, F16AccumulatorConfig,
39 FusedAttentionKernel, GpuKernel, Iq1MGpuKernel, Iq1SGpuKernel, Iq2SGpuKernel, Iq2XsGpuKernel,
40 Iq2XxsGpuKernel, Iq3SGpuKernel, Iq3XxsGpuKernel, Iq4NlGpuKernel, Iq4XsGpuKernel,
41 Q1_0_G128GpuKernel, Q2_KGpuKernel, Q3_KGpuKernel, Q4_0GpuKernel, Q4_1GpuKernel, Q4_KGpuKernel,
42 Q5_0GpuKernel, Q5_1GpuKernel, Q5_KGpuKernel, Q6_KGpuKernel, Q8_0GpuKernel, Q8_1GpuKernel,
43 Q8_KGpuKernel, TiledGemmKernel, Tq1_0GpuKernel, Tq2_0GpuKernel,
44};
45#[cfg(any(feature = "gpu", test))]
46pub use kernels::{dequant_q4_0_to_f16, dequant_q8_0_to_f16};
47#[cfg(feature = "gpu")]
48pub use kernels::{f16_gemv, upload_f16};
49
50use oxillama_gguf::GgufTensorType;
51
52pub struct GpuDispatcher {
59 ctx: Option<GpuContext>,
60}
61
62impl GpuDispatcher {
63 pub fn new() -> Self {
66 Self {
67 ctx: GpuContext::try_init(),
68 }
69 }
70
71 pub fn has_gpu(&self) -> bool {
73 self.ctx.is_some()
74 }
75
76 pub fn get_kernel(&self, tensor_type: GgufTensorType) -> Option<Box<dyn GpuKernel>> {
80 self.ctx.as_ref()?;
82
83 match tensor_type {
84 GgufTensorType::Q2K => Some(Box::new(Q2_KGpuKernel)),
85 GgufTensorType::Q3K => Some(Box::new(Q3_KGpuKernel)),
86 GgufTensorType::Q4_0 => Some(Box::new(Q4_0GpuKernel)),
87 GgufTensorType::Q4_1 => Some(Box::new(Q4_1GpuKernel)),
88 GgufTensorType::Q4K => Some(Box::new(Q4_KGpuKernel)),
89 GgufTensorType::Q5_0 => Some(Box::new(Q5_0GpuKernel)),
90 GgufTensorType::Q5_1 => Some(Box::new(Q5_1GpuKernel)),
91 GgufTensorType::Q5K => Some(Box::new(Q5_KGpuKernel)),
92 GgufTensorType::Q6K => Some(Box::new(Q6_KGpuKernel)),
93 GgufTensorType::Q8_0 => Some(Box::new(Q8_0GpuKernel)),
94 GgufTensorType::Q8_1 => Some(Box::new(Q8_1GpuKernel)),
95 GgufTensorType::Q8K => Some(Box::new(Q8_KGpuKernel)),
96 GgufTensorType::Q1_0G128 => Some(Box::new(Q1_0_G128GpuKernel)),
97 GgufTensorType::Iq4Xs => Some(Box::new(Iq4XsGpuKernel)),
98 GgufTensorType::Iq2Xxs => Some(Box::new(Iq2XxsGpuKernel)),
99 GgufTensorType::Iq2S => Some(Box::new(Iq2SGpuKernel)),
100 GgufTensorType::Iq2Xs => Some(Box::new(Iq2XsGpuKernel)),
101 GgufTensorType::Iq3Xxs => Some(Box::new(Iq3XxsGpuKernel)),
102 GgufTensorType::Iq3S => Some(Box::new(Iq3SGpuKernel)),
103 GgufTensorType::Iq1S => Some(Box::new(Iq1SGpuKernel)),
104 GgufTensorType::Iq1M => Some(Box::new(Iq1MGpuKernel)),
105 GgufTensorType::Iq4Nl => Some(Box::new(Iq4NlGpuKernel)),
106 GgufTensorType::Tq1_0 => Some(Box::new(Tq1_0GpuKernel)),
107 GgufTensorType::Tq2_0 => Some(Box::new(Tq2_0GpuKernel)),
108 _ => None,
109 }
110 }
111
112 pub fn context(&self) -> Option<&GpuContext> {
114 self.ctx.as_ref()
115 }
116
117 pub fn with_device_name(name: &str) -> Self {
120 Self {
121 ctx: GpuContext::try_init_with_name(name),
122 }
123 }
124
125 pub fn with_device_index(index: usize) -> Self {
128 Self {
129 ctx: GpuContext::try_init_with_index(index),
130 }
131 }
132
133 pub fn enumerate_devices() -> Vec<GpuDeviceInfo> {
135 GpuContext::enumerate_devices()
136 }
137}
138
139impl Default for GpuDispatcher {
140 fn default() -> Self {
141 Self::new()
142 }
143}
144
145#[cfg(test)]
148mod tests {
149 use super::*;
150
151 #[test]
154 fn test_gpu_context_try_init_no_crash() {
155 let _ctx = GpuContext::try_init();
157 }
158
159 #[test]
160 fn test_gpu_dispatcher_new_no_crash() {
161 let dispatcher = GpuDispatcher::new();
162 let _ = dispatcher.has_gpu();
164 }
165
166 #[test]
167 fn test_gpu_dispatcher_default_no_crash() {
168 let _dispatcher = GpuDispatcher::default();
169 }
170
171 #[test]
172 fn test_gpu_dispatcher_no_kernel_for_f32() {
173 let dispatcher = GpuDispatcher::new();
174 let kernel = dispatcher.get_kernel(GgufTensorType::F32);
175 assert!(kernel.is_none(), "F32 should not have a GPU kernel");
176 }
177
178 #[test]
179 fn test_gpu_dispatcher_kernel_for_q4k_when_gpu() {
180 let dispatcher = GpuDispatcher::new();
181 let kernel = dispatcher.get_kernel(GgufTensorType::Q4K);
182 if dispatcher.has_gpu() {
183 assert!(
184 kernel.is_some(),
185 "Q4K should have a GPU kernel when GPU is present"
186 );
187 } else {
188 assert!(kernel.is_none(), "Q4K should not have a kernel without GPU");
189 }
190 }
191
192 #[test]
193 fn test_gpu_dispatcher_kernel_for_q5k_when_gpu() {
194 let dispatcher = GpuDispatcher::new();
195 let kernel = dispatcher.get_kernel(GgufTensorType::Q5K);
196 if dispatcher.has_gpu() {
197 assert!(
198 kernel.is_some(),
199 "Q5K should have a GPU kernel when GPU is present"
200 );
201 } else {
202 assert!(kernel.is_none(), "Q5K should not have a kernel without GPU");
203 }
204 }
205
206 #[test]
207 fn test_gpu_dispatcher_kernel_for_q6k_when_gpu() {
208 let dispatcher = GpuDispatcher::new();
209 let kernel = dispatcher.get_kernel(GgufTensorType::Q6K);
210 if dispatcher.has_gpu() {
211 assert!(
212 kernel.is_some(),
213 "Q6K should have a GPU kernel when GPU is present"
214 );
215 } else {
216 assert!(kernel.is_none(), "Q6K should not have a kernel without GPU");
217 }
218 }
219
220 #[test]
221 fn test_gpu_dispatcher_kernel_for_q2k_when_gpu() {
222 let dispatcher = GpuDispatcher::new();
223 let kernel = dispatcher.get_kernel(GgufTensorType::Q2K);
224 if dispatcher.has_gpu() {
225 assert!(
226 kernel.is_some(),
227 "Q2K should have a GPU kernel when GPU is present"
228 );
229 } else {
230 assert!(kernel.is_none(), "Q2K should not have a kernel without GPU");
231 }
232 }
233
234 #[test]
235 fn test_gpu_dispatcher_kernel_for_q3k_when_gpu() {
236 let dispatcher = GpuDispatcher::new();
237 let kernel = dispatcher.get_kernel(GgufTensorType::Q3K);
238 if dispatcher.has_gpu() {
239 assert!(
240 kernel.is_some(),
241 "Q3K should have a GPU kernel when GPU is present"
242 );
243 } else {
244 assert!(kernel.is_none(), "Q3K should not have a kernel without GPU");
245 }
246 }
247
248 #[test]
249 fn test_gpu_dispatcher_kernel_for_q8k_when_gpu() {
250 let dispatcher = GpuDispatcher::new();
251 let kernel = dispatcher.get_kernel(GgufTensorType::Q8K);
252 if dispatcher.has_gpu() {
253 assert!(
254 kernel.is_some(),
255 "Q8K should have a GPU kernel when GPU is present"
256 );
257 } else {
258 assert!(kernel.is_none(), "Q8K should not have a kernel without GPU");
259 }
260 }
261
262 #[test]
263 fn test_gpu_dispatcher_kernel_for_iq4xs_when_gpu() {
264 let dispatcher = GpuDispatcher::new();
265 let kernel = dispatcher.get_kernel(GgufTensorType::Iq4Xs);
266 if dispatcher.has_gpu() {
267 assert!(
268 kernel.is_some(),
269 "Iq4Xs should have a GPU kernel when GPU is present"
270 );
271 } else {
272 assert!(
273 kernel.is_none(),
274 "Iq4Xs should not have a kernel without GPU"
275 );
276 }
277 }
278
279 #[test]
280 fn test_gpu_dispatcher_kernel_for_iq2xxs_when_gpu() {
281 let dispatcher = GpuDispatcher::new();
282 let kernel = dispatcher.get_kernel(GgufTensorType::Iq2Xxs);
283 if dispatcher.has_gpu() {
284 assert!(
285 kernel.is_some(),
286 "Iq2Xxs should have a GPU kernel when GPU is present"
287 );
288 } else {
289 assert!(
290 kernel.is_none(),
291 "Iq2Xxs should not have a kernel without GPU"
292 );
293 }
294 }
295
296 #[test]
297 fn test_gpu_dispatcher_kernel_for_iq2s_when_gpu() {
298 let dispatcher = GpuDispatcher::new();
299 let kernel = dispatcher.get_kernel(GgufTensorType::Iq2S);
300 if dispatcher.has_gpu() {
301 assert!(
302 kernel.is_some(),
303 "Iq2S should have a GPU kernel when GPU is present"
304 );
305 } else {
306 assert!(
307 kernel.is_none(),
308 "Iq2S should not have a kernel without GPU"
309 );
310 }
311 }
312
313 #[test]
314 fn test_gpu_dispatcher_kernel_for_iq3xxs_when_gpu() {
315 let dispatcher = GpuDispatcher::new();
316 let kernel = dispatcher.get_kernel(GgufTensorType::Iq3Xxs);
317 if dispatcher.has_gpu() {
318 assert!(
319 kernel.is_some(),
320 "Iq3Xxs should have a GPU kernel when GPU is present"
321 );
322 } else {
323 assert!(
324 kernel.is_none(),
325 "Iq3Xxs should not have a kernel without GPU"
326 );
327 }
328 }
329
330 #[test]
331 fn test_gpu_dispatcher_kernel_for_iq3s_when_gpu() {
332 let dispatcher = GpuDispatcher::new();
333 let kernel = dispatcher.get_kernel(GgufTensorType::Iq3S);
334 if dispatcher.has_gpu() {
335 assert!(
336 kernel.is_some(),
337 "Iq3S should have a GPU kernel when GPU is present"
338 );
339 } else {
340 assert!(
341 kernel.is_none(),
342 "Iq3S should not have a kernel without GPU"
343 );
344 }
345 }
346
347 #[test]
348 fn test_gpu_dispatcher_kernel_for_q4_1_when_gpu() {
349 let dispatcher = GpuDispatcher::new();
350 let kernel = dispatcher.get_kernel(GgufTensorType::Q4_1);
351 if dispatcher.has_gpu() {
352 assert!(
353 kernel.is_some(),
354 "Q4_1 should have a GPU kernel when GPU is present"
355 );
356 } else {
357 assert!(
358 kernel.is_none(),
359 "Q4_1 should not have a kernel without GPU"
360 );
361 }
362 }
363
364 #[test]
365 fn test_gpu_dispatcher_kernel_for_q5_0_when_gpu() {
366 let dispatcher = GpuDispatcher::new();
367 let kernel = dispatcher.get_kernel(GgufTensorType::Q5_0);
368 if dispatcher.has_gpu() {
369 assert!(
370 kernel.is_some(),
371 "Q5_0 should have a GPU kernel when GPU is present"
372 );
373 } else {
374 assert!(
375 kernel.is_none(),
376 "Q5_0 should not have a kernel without GPU"
377 );
378 }
379 }
380
381 #[test]
382 fn test_gpu_dispatcher_kernel_for_q5_1_when_gpu() {
383 let dispatcher = GpuDispatcher::new();
384 let kernel = dispatcher.get_kernel(GgufTensorType::Q5_1);
385 if dispatcher.has_gpu() {
386 assert!(
387 kernel.is_some(),
388 "Q5_1 should have a GPU kernel when GPU is present"
389 );
390 } else {
391 assert!(
392 kernel.is_none(),
393 "Q5_1 should not have a kernel without GPU"
394 );
395 }
396 }
397
398 #[test]
399 fn test_gpu_dispatcher_kernel_for_q8_1_when_gpu() {
400 let dispatcher = GpuDispatcher::new();
401 let kernel = dispatcher.get_kernel(GgufTensorType::Q8_1);
402 if dispatcher.has_gpu() {
403 assert!(
404 kernel.is_some(),
405 "Q8_1 should have a GPU kernel when GPU is present"
406 );
407 } else {
408 assert!(
409 kernel.is_none(),
410 "Q8_1 should not have a kernel without GPU"
411 );
412 }
413 }
414
415 #[test]
416 fn test_gpu_error_display() {
417 let e = GpuError::NoAdapter;
418 assert!(!e.to_string().is_empty(), "error message must not be empty");
419 }
420
421 #[test]
422 fn test_gpu_error_buffer_size() {
423 let e = GpuError::BufferSize {
424 expected: 32,
425 got: 16,
426 };
427 let msg = e.to_string();
428 assert!(msg.contains("32"), "message should mention expected=32");
429 assert!(msg.contains("16"), "message should mention got=16");
430 }
431
432 #[test]
433 fn test_gpu_error_device_request() {
434 let e = GpuError::DeviceRequest("timeout".to_owned());
435 assert!(e.to_string().contains("timeout"));
436 }
437
438 #[test]
439 fn test_gpu_error_unsupported_type() {
440 let e = GpuError::UnsupportedType {
441 name: "Q6K".to_owned(),
442 };
443 assert!(e.to_string().contains("Q6K"));
444 }
445
446 #[test]
447 fn test_gpu_error_shader_compilation() {
448 let e = GpuError::ShaderCompilation {
449 detail: "parse error".to_owned(),
450 };
451 assert!(e.to_string().contains("parse error"));
452 }
453
454 #[test]
455 fn test_gpu_error_buffer_map() {
456 let e = GpuError::BufferMap {
457 detail: "lost".to_owned(),
458 };
459 assert!(e.to_string().contains("lost"));
460 }
461
462 #[test]
466 fn test_gpu_dispatcher_kernels_when_gpu_present() {
467 let dispatcher = GpuDispatcher::new();
468 if !dispatcher.has_gpu() {
469 return; }
471 assert!(
472 dispatcher.get_kernel(GgufTensorType::Q4_0).is_some(),
473 "Q4_0 kernel must be available when GPU is present"
474 );
475 assert!(
476 dispatcher.get_kernel(GgufTensorType::Q8_0).is_some(),
477 "Q8_0 kernel must be available when GPU is present"
478 );
479 }
480
481 #[cfg(feature = "gpu")]
484 #[test]
485 fn test_gpu_gemv_q4_0_matches_cpu() {
486 use crate::kernels::q4_0::Q4_0GpuKernel;
487
488 let ctx = match GpuContext::try_init() {
489 Some(c) => c,
490 None => return, };
492
493 let make_block = |scale: f32, first_nibble: u8| -> Vec<u8> {
497 let mut nibbles = [0x88u8; 16];
498 nibbles[0] = first_nibble; let mut block = Vec::with_capacity(18);
500 let d_bits = half::f16::from_f32(scale).to_bits();
501 block.extend_from_slice(&d_bits.to_le_bytes());
502 block.extend_from_slice(&nibbles);
503 block
504 };
505
506 let mut weight_bytes = Vec::new();
509 weight_bytes.extend_from_slice(&make_block(1.0, 0x8A)); weight_bytes.extend_from_slice(&make_block(0.5, 0x86)); let mut input = vec![1.0f32; 32];
514 input[0] = 3.0;
515
516 let expected = [6.0f32, -3.0f32];
519
520 let mut output = vec![0.0f32; 2];
521 let kernel = Q4_0GpuKernel;
522 kernel
523 .gemv(&ctx, &weight_bytes, &input, &mut output, 2, 32)
524 .expect("Q4_0 GPU GEMV");
525
526 for (i, (&got, &want)) in output.iter().zip(expected.iter()).enumerate() {
527 assert!(
528 (got - want).abs() < 1e-3,
529 "row {i}: got {got}, expected {want}"
530 );
531 }
532 }
533
534 #[cfg(feature = "gpu")]
536 #[test]
537 fn test_gpu_gemv_q8_0_matches_cpu() {
538 use crate::kernels::q8_0::Q8_0GpuKernel;
539
540 let ctx = match GpuContext::try_init() {
541 Some(c) => c,
542 None => return,
543 };
544
545 let make_block = |scale: f32, first_val: i8| -> Vec<u8> {
546 let mut vals = [0i8; 32];
547 vals[0] = first_val;
548 let mut block = Vec::with_capacity(34);
549 let d_bits = half::f16::from_f32(scale).to_bits();
550 block.extend_from_slice(&d_bits.to_le_bytes());
551 for &v in &vals {
552 block.push(v as u8);
553 }
554 block
555 };
556
557 let mut weight_bytes = Vec::new();
560 weight_bytes.extend_from_slice(&make_block(2.0, 3));
561 weight_bytes.extend_from_slice(&make_block(1.0, -4));
562
563 let mut input = vec![0.0f32; 32];
564 input[0] = 1.5;
565
566 let expected = [9.0f32, -6.0f32];
568
569 let mut output = vec![0.0f32; 2];
570 let kernel = Q8_0GpuKernel;
571 kernel
572 .gemv(&ctx, &weight_bytes, &input, &mut output, 2, 32)
573 .expect("Q8_0 GPU GEMV");
574
575 for (i, (&got, &want)) in output.iter().zip(expected.iter()).enumerate() {
576 assert!(
577 (got - want).abs() < 1e-3,
578 "row {i}: got {got}, expected {want}"
579 );
580 }
581 }
582
583 #[test]
586 fn test_gpu_dispatcher_kernel_for_q1_0_g128_when_gpu() {
587 let dispatcher = GpuDispatcher::new();
588 let kernel = dispatcher.get_kernel(GgufTensorType::Q1_0G128);
589 if dispatcher.has_gpu() {
590 assert!(
591 kernel.is_some(),
592 "Q1_0G128 should have a GPU kernel when GPU is present"
593 );
594 } else {
595 assert!(
596 kernel.is_none(),
597 "Q1_0G128 should not have a kernel without GPU"
598 );
599 }
600 }
601
602 #[cfg(feature = "gpu")]
604 #[test]
605 fn test_gpu_gemv_q1_0_g128_matches_cpu() {
606 use crate::kernels::q1_0_g128::Q1_0_G128GpuKernel;
607
608 let ctx = match GpuContext::try_init() {
609 Some(c) => c,
610 None => return, };
612
613 let make_block = |scale: f32, sign_bits: &[u8; 16]| -> Vec<u8> {
614 let mut block = Vec::with_capacity(18);
615 let d_bits = half::f16::from_f32(scale).to_bits();
616 block.extend_from_slice(&d_bits.to_le_bytes());
617 block.extend_from_slice(sign_bits);
618 block
619 };
620
621 let mut weight_bytes = Vec::new();
624 weight_bytes.extend_from_slice(&make_block(2.0, &[0xFF; 16]));
625 weight_bytes.extend_from_slice(&make_block(1.0, &[0x00; 16]));
626
627 let input = vec![1.0f32; 128];
629
630 let expected = [256.0f32, -128.0f32];
633
634 let mut output = vec![0.0f32; 2];
635 let kernel = Q1_0_G128GpuKernel;
636 kernel
637 .gemv(&ctx, &weight_bytes, &input, &mut output, 2, 128)
638 .expect("Q1_0_G128 GPU GEMV");
639
640 for (i, (&got, &want)) in output.iter().zip(expected.iter()).enumerate() {
641 assert!(
642 (got - want).abs() < 1e-1,
643 "row {i}: got {got}, expected {want}"
644 );
645 }
646 }
647
648 #[test]
651 fn test_enumerate_devices_no_panic() {
652 let devices = GpuDispatcher::enumerate_devices();
653 let _ = devices.len();
655 }
656
657 #[test]
658 fn test_enumerate_devices_from_context_no_panic() {
659 let devices = GpuContext::enumerate_devices();
660 let _ = devices.len();
661 }
662
663 #[test]
664 fn test_try_init_with_name_nonexistent_returns_none() {
665 let ctx = GpuContext::try_init_with_name("__nonexistent_gpu_xyz_999__");
666 assert!(ctx.is_none(), "Non-matching name pattern must return None");
667 }
668
669 #[test]
670 fn test_try_init_with_index_out_of_bounds_returns_none() {
671 let ctx = GpuContext::try_init_with_index(9999);
672 assert!(ctx.is_none(), "Out-of-bounds index must return None");
673 }
674
675 #[test]
676 fn test_dispatcher_with_device_name_nonexistent() {
677 let dispatcher = GpuDispatcher::with_device_name("__nonexistent_gpu_xyz_999__");
678 assert!(
679 !dispatcher.has_gpu(),
680 "Non-matching device name must yield no GPU"
681 );
682 }
683
684 #[test]
685 fn test_dispatcher_with_device_index_out_of_bounds() {
686 let dispatcher = GpuDispatcher::with_device_index(9999);
687 assert!(
688 !dispatcher.has_gpu(),
689 "Out-of-bounds index must yield no GPU"
690 );
691 }
692
693 #[test]
694 fn test_gpu_device_info_debug() {
695 let info = GpuDeviceInfo {
696 name: "Test GPU".to_owned(),
697 backend: "Vulkan".to_owned(),
698 device_type: "DiscreteGpu".to_owned(),
699 };
700 let debug_str = format!("{info:?}");
701 assert!(debug_str.contains("Test GPU"));
702 assert!(debug_str.contains("Vulkan"));
703 }
704
705 #[test]
706 fn test_gpu_device_info_clone() {
707 let info = GpuDeviceInfo {
708 name: "GPU".to_owned(),
709 backend: "Metal".to_owned(),
710 device_type: "IntegratedGpu".to_owned(),
711 };
712 let cloned = info.clone();
713 assert_eq!(cloned.name, info.name);
714 assert_eq!(cloned.backend, info.backend);
715 assert_eq!(cloned.device_type, info.device_type);
716 }
717}