Skip to main content

oxillama_gpu/
lib.rs

1//! # oxillama-gpu
2//!
3//! Optional wgpu-based GPU compute backend for OxiLLaMa.
4//!
5//! ## Feature flags
6//!
7//! | Feature | Description | Default |
8//! |---------|-------------|---------|
9//! | `gpu`   | Enable wgpu device init, buffer helpers, and WGSL shaders | No |
10//!
11//! When `gpu` is **disabled** (the default) this crate still compiles and all
12//! public types are available.  [`GpuContext::try_init`] returns `None` and
13//! [`GpuDispatcher::has_gpu`] returns `false`.
14//!
15//! ## Quick start
16//!
17//! ```rust
18//! use oxillama_gpu::{GpuDispatcher, GpuContext};
19//!
20//! let dispatcher = GpuDispatcher::new();
21//! if dispatcher.has_gpu() {
22//!     println!("GPU available — will use hardware acceleration");
23//! } else {
24//!     println!("No GPU — CPU fallback active");
25//! }
26//! ```
27
28pub mod buffer;
29pub mod context;
30pub mod error;
31pub mod kernels;
32
33pub use context::GpuContext;
34pub use context::GpuDeviceInfo;
35pub use error::{GpuError, GpuResult};
36pub use kernels::sampling::SamplingKernel;
37pub use kernels::{
38    batched_gemv_f32, supports_f16, BatchedGemvConfig, BatchedGpuKernel, F16AccumulatorConfig,
39    FusedAttentionKernel, GpuKernel, Iq1MGpuKernel, Iq1SGpuKernel, Iq2SGpuKernel, Iq2XsGpuKernel,
40    Iq2XxsGpuKernel, Iq3SGpuKernel, Iq3XxsGpuKernel, Iq4NlGpuKernel, Iq4XsGpuKernel,
41    Q1_0_G128GpuKernel, Q2_KGpuKernel, Q3_KGpuKernel, Q4_0GpuKernel, Q4_1GpuKernel, Q4_KGpuKernel,
42    Q5_0GpuKernel, Q5_1GpuKernel, Q5_KGpuKernel, Q6_KGpuKernel, Q8_0GpuKernel, Q8_1GpuKernel,
43    Q8_KGpuKernel, TiledGemmKernel, Tq1_0GpuKernel, Tq2_0GpuKernel,
44};
45#[cfg(any(feature = "gpu", test))]
46pub use kernels::{dequant_q4_0_to_f16, dequant_q8_0_to_f16};
47#[cfg(feature = "gpu")]
48pub use kernels::{f16_gemv, upload_f16};
49
50use oxillama_gguf::GgufTensorType;
51
52/// Central dispatcher that holds an optional [`GpuContext`] and vends
53/// GPU kernels for supported tensor types.
54///
55/// Construct with [`GpuDispatcher::new`].  The dispatcher performs GPU
56/// initialisation exactly once at construction time.  Kernel dispatch is
57/// then `O(1)` (a simple `match`).
58pub struct GpuDispatcher {
59    ctx: Option<GpuContext>,
60}
61
62impl GpuDispatcher {
63    /// Create a new dispatcher.  Attempts to initialise a GPU context; stores
64    /// `None` if no GPU is available or the `gpu` feature is disabled.
65    pub fn new() -> Self {
66        Self {
67            ctx: GpuContext::try_init(),
68        }
69    }
70
71    /// Returns `true` if a GPU context was successfully initialised.
72    pub fn has_gpu(&self) -> bool {
73        self.ctx.is_some()
74    }
75
76    /// Return a GPU kernel for the given tensor type, or `None` if:
77    /// - No GPU is available (`has_gpu() == false`).
78    /// - The tensor type has no GPU kernel implementation.
79    pub fn get_kernel(&self, tensor_type: GgufTensorType) -> Option<Box<dyn GpuKernel>> {
80        // No context → no kernel.
81        self.ctx.as_ref()?;
82
83        match tensor_type {
84            GgufTensorType::Q2K => Some(Box::new(Q2_KGpuKernel)),
85            GgufTensorType::Q3K => Some(Box::new(Q3_KGpuKernel)),
86            GgufTensorType::Q4_0 => Some(Box::new(Q4_0GpuKernel)),
87            GgufTensorType::Q4_1 => Some(Box::new(Q4_1GpuKernel)),
88            GgufTensorType::Q4K => Some(Box::new(Q4_KGpuKernel)),
89            GgufTensorType::Q5_0 => Some(Box::new(Q5_0GpuKernel)),
90            GgufTensorType::Q5_1 => Some(Box::new(Q5_1GpuKernel)),
91            GgufTensorType::Q5K => Some(Box::new(Q5_KGpuKernel)),
92            GgufTensorType::Q6K => Some(Box::new(Q6_KGpuKernel)),
93            GgufTensorType::Q8_0 => Some(Box::new(Q8_0GpuKernel)),
94            GgufTensorType::Q8_1 => Some(Box::new(Q8_1GpuKernel)),
95            GgufTensorType::Q8K => Some(Box::new(Q8_KGpuKernel)),
96            GgufTensorType::Q1_0G128 => Some(Box::new(Q1_0_G128GpuKernel)),
97            GgufTensorType::Iq4Xs => Some(Box::new(Iq4XsGpuKernel)),
98            GgufTensorType::Iq2Xxs => Some(Box::new(Iq2XxsGpuKernel)),
99            GgufTensorType::Iq2S => Some(Box::new(Iq2SGpuKernel)),
100            GgufTensorType::Iq2Xs => Some(Box::new(Iq2XsGpuKernel)),
101            GgufTensorType::Iq3Xxs => Some(Box::new(Iq3XxsGpuKernel)),
102            GgufTensorType::Iq3S => Some(Box::new(Iq3SGpuKernel)),
103            GgufTensorType::Iq1S => Some(Box::new(Iq1SGpuKernel)),
104            GgufTensorType::Iq1M => Some(Box::new(Iq1MGpuKernel)),
105            GgufTensorType::Iq4Nl => Some(Box::new(Iq4NlGpuKernel)),
106            GgufTensorType::Tq1_0 => Some(Box::new(Tq1_0GpuKernel)),
107            GgufTensorType::Tq2_0 => Some(Box::new(Tq2_0GpuKernel)),
108            _ => None,
109        }
110    }
111
112    /// Return a reference to the underlying [`GpuContext`], if one exists.
113    pub fn context(&self) -> Option<&GpuContext> {
114        self.ctx.as_ref()
115    }
116
117    /// Create a dispatcher selecting a GPU adapter by name substring
118    /// (case-insensitive).  Falls back to no-GPU if no adapter matches.
119    pub fn with_device_name(name: &str) -> Self {
120        Self {
121            ctx: GpuContext::try_init_with_name(name),
122        }
123    }
124
125    /// Create a dispatcher selecting a GPU adapter by index.
126    /// Falls back to no-GPU if the index is out of bounds.
127    pub fn with_device_index(index: usize) -> Self {
128        Self {
129            ctx: GpuContext::try_init_with_index(index),
130        }
131    }
132
133    /// Enumerate available GPU adapters.
134    pub fn enumerate_devices() -> Vec<GpuDeviceInfo> {
135        GpuContext::enumerate_devices()
136    }
137}
138
139impl Default for GpuDispatcher {
140    fn default() -> Self {
141        Self::new()
142    }
143}
144
145// ─── Tests ────────────────────────────────────────────────────────────────────
146
#[cfg(test)]
mod tests {
    use super::*;

    // ── Shared helpers ───────────────────────────────────────────────────────

    /// Asserts that `get_kernel(tensor_type)` returns a kernel exactly when
    /// the dispatcher reports a GPU.
    ///
    /// `label` names the tensor type in the failure message.  The helper is
    /// `#[track_caller]` so a failure is reported at the calling test.
    #[track_caller]
    fn assert_kernel_tracks_gpu(tensor_type: GgufTensorType, label: &str) {
        let dispatcher = GpuDispatcher::new();
        let kernel = dispatcher.get_kernel(tensor_type);
        if dispatcher.has_gpu() {
            assert!(
                kernel.is_some(),
                "{label} should have a GPU kernel when GPU is present"
            );
        } else {
            assert!(
                kernel.is_none(),
                "{label} should not have a kernel without GPU"
            );
        }
    }

    /// Asserts that every row of `output` is within `tolerance` (absolute)
    /// of the matching row of `expected`.
    #[cfg(feature = "gpu")]
    #[track_caller]
    fn assert_rows_close(output: &[f32], expected: &[f32], tolerance: f32) {
        assert_eq!(output.len(), expected.len(), "row count mismatch");
        for (i, (&got, &want)) in output.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - want).abs() < tolerance,
                "row {i}: got {got}, expected {want}"
            );
        }
    }

    // ── Basic smoke tests (always run, even without GPU) ─────────────────────

    #[test]
    fn test_gpu_context_try_init_no_crash() {
        // Must not panic regardless of whether a GPU is present.
        let _ctx = GpuContext::try_init();
    }

    #[test]
    fn test_gpu_dispatcher_new_no_crash() {
        let dispatcher = GpuDispatcher::new();
        // has_gpu() may be false in CI — that is fine.
        let _ = dispatcher.has_gpu();
    }

    #[test]
    fn test_gpu_dispatcher_default_no_crash() {
        let _dispatcher = GpuDispatcher::default();
    }

    #[test]
    fn test_gpu_dispatcher_no_kernel_for_f32() {
        let dispatcher = GpuDispatcher::new();
        let kernel = dispatcher.get_kernel(GgufTensorType::F32);
        assert!(kernel.is_none(), "F32 should not have a GPU kernel");
    }

    // ── Per-type dispatch: a kernel is returned iff a GPU is present ─────────

    #[test]
    fn test_gpu_dispatcher_kernel_for_q4k_when_gpu() {
        assert_kernel_tracks_gpu(GgufTensorType::Q4K, "Q4K");
    }

    #[test]
    fn test_gpu_dispatcher_kernel_for_q5k_when_gpu() {
        assert_kernel_tracks_gpu(GgufTensorType::Q5K, "Q5K");
    }

    #[test]
    fn test_gpu_dispatcher_kernel_for_q6k_when_gpu() {
        assert_kernel_tracks_gpu(GgufTensorType::Q6K, "Q6K");
    }

    #[test]
    fn test_gpu_dispatcher_kernel_for_q2k_when_gpu() {
        assert_kernel_tracks_gpu(GgufTensorType::Q2K, "Q2K");
    }

    #[test]
    fn test_gpu_dispatcher_kernel_for_q3k_when_gpu() {
        assert_kernel_tracks_gpu(GgufTensorType::Q3K, "Q3K");
    }

    #[test]
    fn test_gpu_dispatcher_kernel_for_q8k_when_gpu() {
        assert_kernel_tracks_gpu(GgufTensorType::Q8K, "Q8K");
    }

    #[test]
    fn test_gpu_dispatcher_kernel_for_iq4xs_when_gpu() {
        assert_kernel_tracks_gpu(GgufTensorType::Iq4Xs, "Iq4Xs");
    }

    #[test]
    fn test_gpu_dispatcher_kernel_for_iq2xxs_when_gpu() {
        assert_kernel_tracks_gpu(GgufTensorType::Iq2Xxs, "Iq2Xxs");
    }

    #[test]
    fn test_gpu_dispatcher_kernel_for_iq2s_when_gpu() {
        assert_kernel_tracks_gpu(GgufTensorType::Iq2S, "Iq2S");
    }

    #[test]
    fn test_gpu_dispatcher_kernel_for_iq3xxs_when_gpu() {
        assert_kernel_tracks_gpu(GgufTensorType::Iq3Xxs, "Iq3Xxs");
    }

    #[test]
    fn test_gpu_dispatcher_kernel_for_iq3s_when_gpu() {
        assert_kernel_tracks_gpu(GgufTensorType::Iq3S, "Iq3S");
    }

    #[test]
    fn test_gpu_dispatcher_kernel_for_q4_1_when_gpu() {
        assert_kernel_tracks_gpu(GgufTensorType::Q4_1, "Q4_1");
    }

    #[test]
    fn test_gpu_dispatcher_kernel_for_q5_0_when_gpu() {
        assert_kernel_tracks_gpu(GgufTensorType::Q5_0, "Q5_0");
    }

    #[test]
    fn test_gpu_dispatcher_kernel_for_q5_1_when_gpu() {
        assert_kernel_tracks_gpu(GgufTensorType::Q5_1, "Q5_1");
    }

    #[test]
    fn test_gpu_dispatcher_kernel_for_q8_1_when_gpu() {
        assert_kernel_tracks_gpu(GgufTensorType::Q8_1, "Q8_1");
    }

    // The remaining `get_kernel` arms previously had no dispatch test.

    #[test]
    fn test_gpu_dispatcher_kernel_for_iq2xs_when_gpu() {
        assert_kernel_tracks_gpu(GgufTensorType::Iq2Xs, "Iq2Xs");
    }

    #[test]
    fn test_gpu_dispatcher_kernel_for_iq1s_when_gpu() {
        assert_kernel_tracks_gpu(GgufTensorType::Iq1S, "Iq1S");
    }

    #[test]
    fn test_gpu_dispatcher_kernel_for_iq1m_when_gpu() {
        assert_kernel_tracks_gpu(GgufTensorType::Iq1M, "Iq1M");
    }

    #[test]
    fn test_gpu_dispatcher_kernel_for_iq4nl_when_gpu() {
        assert_kernel_tracks_gpu(GgufTensorType::Iq4Nl, "Iq4Nl");
    }

    #[test]
    fn test_gpu_dispatcher_kernel_for_tq1_0_when_gpu() {
        assert_kernel_tracks_gpu(GgufTensorType::Tq1_0, "Tq1_0");
    }

    #[test]
    fn test_gpu_dispatcher_kernel_for_tq2_0_when_gpu() {
        assert_kernel_tracks_gpu(GgufTensorType::Tq2_0, "Tq2_0");
    }

    // ── Error formatting ─────────────────────────────────────────────────────

    #[test]
    fn test_gpu_error_display() {
        let e = GpuError::NoAdapter;
        assert!(!e.to_string().is_empty(), "error message must not be empty");
    }

    #[test]
    fn test_gpu_error_buffer_size() {
        let e = GpuError::BufferSize {
            expected: 32,
            got: 16,
        };
        let msg = e.to_string();
        assert!(msg.contains("32"), "message should mention expected=32");
        assert!(msg.contains("16"), "message should mention got=16");
    }

    #[test]
    fn test_gpu_error_device_request() {
        let e = GpuError::DeviceRequest("timeout".to_owned());
        assert!(e.to_string().contains("timeout"));
    }

    #[test]
    fn test_gpu_error_unsupported_type() {
        let e = GpuError::UnsupportedType {
            name: "Q6K".to_owned(),
        };
        assert!(e.to_string().contains("Q6K"));
    }

    #[test]
    fn test_gpu_error_shader_compilation() {
        let e = GpuError::ShaderCompilation {
            detail: "parse error".to_owned(),
        };
        assert!(e.to_string().contains("parse error"));
    }

    #[test]
    fn test_gpu_error_buffer_map() {
        let e = GpuError::BufferMap {
            detail: "lost".to_owned(),
        };
        assert!(e.to_string().contains("lost"));
    }

    // ── GPU-available tests (skip gracefully when no GPU) ────────────────────

    /// When a GPU is available, Q4_0 and Q8_0 kernels must be returned.
    #[test]
    fn test_gpu_dispatcher_kernels_when_gpu_present() {
        let dispatcher = GpuDispatcher::new();
        if !dispatcher.has_gpu() {
            return; // CI — no GPU
        }
        assert!(
            dispatcher.get_kernel(GgufTensorType::Q4_0).is_some(),
            "Q4_0 kernel must be available when GPU is present"
        );
        assert!(
            dispatcher.get_kernel(GgufTensorType::Q8_0).is_some(),
            "Q8_0 kernel must be available when GPU is present"
        );
    }

    /// Full end-to-end Q4_0 GEMV: GPU result must match the hand-computed CPU
    /// reference to within 1e-3 absolute tolerance.
    #[cfg(feature = "gpu")]
    #[test]
    fn test_gpu_gemv_q4_0_matches_cpu() {
        use crate::kernels::q4_0::Q4_0GpuKernel;

        let ctx = match GpuContext::try_init() {
            Some(c) => c,
            None => return, // skip if no GPU
        };

        // Two Q4_0 blocks (rows=2, cols=32), each 2 scale bytes + 16 data bytes.
        // Every data byte defaults to 0x88: both nibbles are 8, which decodes
        // to 0 after the -8 bias.  Only byte 0 is overridden per row, so each
        // row has a single non-zero weight.
        let make_block = |scale: f32, first_byte: u8| -> Vec<u8> {
            let mut nibbles = [0x88u8; 16];
            nibbles[0] = first_byte;
            let mut block = Vec::with_capacity(18);
            let d_bits = half::f16::from_f32(scale).to_bits();
            block.extend_from_slice(&d_bits.to_le_bytes());
            block.extend_from_slice(&nibbles);
            block
        };

        // Row 0: scale=1.0, byte 0 = 0x8A → lo=0xA (10-8=+2), hi=0x8 (0)
        // Row 1: scale=0.5, byte 0 = 0x86 → lo=0x6 (6-8=-2),  hi=0x8 (0)
        let mut weight_bytes = Vec::new();
        weight_bytes.extend_from_slice(&make_block(1.0, 0x8A));
        weight_bytes.extend_from_slice(&make_block(0.5, 0x86));

        // input: all 1.0 except index 0 = 3.0
        let mut input = vec![1.0f32; 32];
        input[0] = 3.0;

        // CPU reference (scale * lo * input[0]; every other weight is 0):
        //   row0 = 1.0 * 2 * 3.0 = 6.0
        //   row1 = 0.5 * -2 * 3.0 = -3.0
        let expected = [6.0f32, -3.0f32];

        let mut output = vec![0.0f32; 2];
        let kernel = Q4_0GpuKernel;
        kernel
            .gemv(&ctx, &weight_bytes, &input, &mut output, 2, 32)
            .expect("Q4_0 GPU GEMV");

        assert_rows_close(&output, &expected, 1e-3);
    }

    /// Full end-to-end Q8_0 GEMV: GPU result must match the hand-computed CPU
    /// reference to within 1e-3 absolute tolerance.
    #[cfg(feature = "gpu")]
    #[test]
    fn test_gpu_gemv_q8_0_matches_cpu() {
        use crate::kernels::q8_0::Q8_0GpuKernel;

        let ctx = match GpuContext::try_init() {
            Some(c) => c,
            None => return, // skip if no GPU
        };

        // One Q8_0 block: 2 scale bytes + 32 signed quants, only q[0] non-zero.
        let make_block = |scale: f32, first_val: i8| -> Vec<u8> {
            let mut vals = [0i8; 32];
            vals[0] = first_val;
            let mut block = Vec::with_capacity(34);
            let d_bits = half::f16::from_f32(scale).to_bits();
            block.extend_from_slice(&d_bits.to_le_bytes());
            // `as u8` keeps the two's-complement bit pattern of each quant.
            block.extend(vals.iter().map(|&v| v as u8));
            block
        };

        // Row 0: scale=2.0, q[0]=3  → weight[0][0] = 6.0
        // Row 1: scale=1.0, q[0]=-4 → weight[1][0] = -4.0
        let mut weight_bytes = Vec::new();
        weight_bytes.extend_from_slice(&make_block(2.0, 3));
        weight_bytes.extend_from_slice(&make_block(1.0, -4));

        let mut input = vec![0.0f32; 32];
        input[0] = 1.5;

        // row0 = 6.0*1.5 = 9.0; row1 = -4.0*1.5 = -6.0
        let expected = [9.0f32, -6.0f32];

        let mut output = vec![0.0f32; 2];
        let kernel = Q8_0GpuKernel;
        kernel
            .gemv(&ctx, &weight_bytes, &input, &mut output, 2, 32)
            .expect("Q8_0 GPU GEMV");

        assert_rows_close(&output, &expected, 1e-3);
    }

    // ── Q1_0_G128 GPU tests ─────────────────────────────────────────────────

    #[test]
    fn test_gpu_dispatcher_kernel_for_q1_0_g128_when_gpu() {
        assert_kernel_tracks_gpu(GgufTensorType::Q1_0G128, "Q1_0G128");
    }

    /// Full end-to-end Q1_0_G128 GEMV: GPU result must match the hand-computed
    /// CPU reference to within 1e-1 absolute tolerance.
    #[cfg(feature = "gpu")]
    #[test]
    fn test_gpu_gemv_q1_0_g128_matches_cpu() {
        use crate::kernels::q1_0_g128::Q1_0_G128GpuKernel;

        let ctx = match GpuContext::try_init() {
            Some(c) => c,
            None => return, // skip if no GPU
        };

        // One block: 2 scale bytes + 16 bytes of sign bits (128 weights).
        let make_block = |scale: f32, sign_bits: &[u8; 16]| -> Vec<u8> {
            let mut block = Vec::with_capacity(18);
            let d_bits = half::f16::from_f32(scale).to_bits();
            block.extend_from_slice(&d_bits.to_le_bytes());
            block.extend_from_slice(sign_bits);
            block
        };

        // Row 0: scale=2.0, all bits=1 → all weights = +2.0
        // Row 1: scale=1.0, all bits=0 → all weights = -1.0
        let mut weight_bytes = Vec::new();
        weight_bytes.extend_from_slice(&make_block(2.0, &[0xFF; 16]));
        weight_bytes.extend_from_slice(&make_block(1.0, &[0x00; 16]));

        // input: all 1.0
        let input = vec![1.0f32; 128];

        // row0 = sum(2.0 * 1.0) for 128 weights = 256.0
        // row1 = sum(-1.0 * 1.0) for 128 weights = -128.0
        let expected = [256.0f32, -128.0f32];

        let mut output = vec![0.0f32; 2];
        let kernel = Q1_0_G128GpuKernel;
        kernel
            .gemv(&ctx, &weight_bytes, &input, &mut output, 2, 128)
            .expect("Q1_0_G128 GPU GEMV");

        assert_rows_close(&output, &expected, 1e-1);
    }

    // ── Device selection tests ───────────────────────────────────────────────

    #[test]
    fn test_enumerate_devices_no_panic() {
        let devices = GpuDispatcher::enumerate_devices();
        // May be empty in CI — just checking it doesn't panic.
        let _ = devices.len();
    }

    #[test]
    fn test_enumerate_devices_from_context_no_panic() {
        let devices = GpuContext::enumerate_devices();
        let _ = devices.len();
    }

    #[test]
    fn test_try_init_with_name_nonexistent_returns_none() {
        let ctx = GpuContext::try_init_with_name("__nonexistent_gpu_xyz_999__");
        assert!(ctx.is_none(), "Non-matching name pattern must return None");
    }

    #[test]
    fn test_try_init_with_index_out_of_bounds_returns_none() {
        let ctx = GpuContext::try_init_with_index(9999);
        assert!(ctx.is_none(), "Out-of-bounds index must return None");
    }

    #[test]
    fn test_dispatcher_with_device_name_nonexistent() {
        let dispatcher = GpuDispatcher::with_device_name("__nonexistent_gpu_xyz_999__");
        assert!(
            !dispatcher.has_gpu(),
            "Non-matching device name must yield no GPU"
        );
    }

    #[test]
    fn test_dispatcher_with_device_index_out_of_bounds() {
        let dispatcher = GpuDispatcher::with_device_index(9999);
        assert!(
            !dispatcher.has_gpu(),
            "Out-of-bounds index must yield no GPU"
        );
    }

    #[test]
    fn test_gpu_device_info_debug() {
        let info = GpuDeviceInfo {
            name: "Test GPU".to_owned(),
            backend: "Vulkan".to_owned(),
            device_type: "DiscreteGpu".to_owned(),
        };
        let debug_str = format!("{info:?}");
        assert!(debug_str.contains("Test GPU"));
        assert!(debug_str.contains("Vulkan"));
    }

    #[test]
    fn test_gpu_device_info_clone() {
        let info = GpuDeviceInfo {
            name: "GPU".to_owned(),
            backend: "Metal".to_owned(),
            device_type: "IntegratedGpu".to_owned(),
        };
        let cloned = info.clone();
        assert_eq!(cloned.name, info.name);
        assert_eq!(cloned.backend, info.backend);
        assert_eq!(cloned.device_type, info.device_type);
    }
}