//! [`LevelZeroBackend`] — the main entry point for the oxicuda-levelzero crate.
//!
//! Implements the [`ComputeBackend`] trait from `oxicuda-backend` using
//! Intel's Level Zero API for GPU compute on Linux and Windows.
6use std::sync::Arc;
7
8use oxicuda_backend::{
9    BackendError, BackendResult, BackendTranspose, BinaryOp, ComputeBackend, ReduceOp, UnaryOp,
10};
11
12use crate::{device::LevelZeroDevice, memory::LevelZeroMemoryManager};
13
14// ─── Backend struct ───────────────────────────────────────────────────────────
15
/// Intel Level Zero GPU compute backend.
///
/// On Linux and Windows this selects the first Intel GPU via the Level Zero
/// loader library (`libze_loader.so` / `ze_loader.dll`) and allocates device
/// memory through the Level Zero memory model.
///
/// On macOS every operation returns [`BackendError::DeviceError`] wrapping
/// [`crate::error::LevelZeroError::UnsupportedPlatform`].
///
/// # Lifecycle
///
/// 1. `LevelZeroBackend::new()` — create an uninitialised backend.
/// 2. `init()` — load the Level Zero driver and select a GPU.
/// 3. Use `alloc`, `copy_htod`, compute ops, `copy_dtoh`, `free`.
/// 4. `synchronize()` — wait for all pending GPU work to finish.
#[derive(Debug)]
pub struct LevelZeroBackend {
    // Selected GPU device; populated by `init()`, `None` before that.
    device: Option<Arc<LevelZeroDevice>>,
    // Memory manager bound to `device`; also populated by `init()`.
    memory: Option<Arc<LevelZeroMemoryManager>>,
    // Set to true exactly once, when `init()` succeeds.
    initialized: bool,
}
37
38impl LevelZeroBackend {
39    /// Create a new, uninitialised Level Zero backend.
40    pub fn new() -> Self {
41        Self {
42            device: None,
43            memory: None,
44            initialized: false,
45        }
46    }
47
48    /// Return an error if the backend has not been initialised yet.
49    fn check_init(&self) -> BackendResult<()> {
50        if self.initialized {
51            Ok(())
52        } else {
53            Err(BackendError::NotInitialized)
54        }
55    }
56
57    /// Convenience accessor: get the memory manager or return `NotInitialized`.
58    fn memory(&self) -> BackendResult<&Arc<LevelZeroMemoryManager>> {
59        self.memory.as_ref().ok_or(BackendError::NotInitialized)
60    }
61}
62
63impl Default for LevelZeroBackend {
64    fn default() -> Self {
65        Self::new()
66    }
67}
68
69// ─── ComputeBackend impl ──────────────────────────────────────────────────────
70
71impl ComputeBackend for LevelZeroBackend {
72    fn name(&self) -> &str {
73        "level-zero"
74    }
75
76    fn init(&mut self) -> BackendResult<()> {
77        if self.initialized {
78            return Ok(());
79        }
80        match LevelZeroDevice::new() {
81            Ok(dev) => {
82                let dev = Arc::new(dev);
83                tracing::info!("Level Zero backend initialised on: {}", dev.name());
84                let memory = LevelZeroMemoryManager::new(Arc::clone(&dev));
85                self.device = Some(dev);
86                self.memory = Some(Arc::new(memory));
87                self.initialized = true;
88                Ok(())
89            }
90            Err(e) => Err(BackendError::from(e)),
91        }
92    }
93
94    fn is_initialized(&self) -> bool {
95        self.initialized
96    }
97
98    // ── Compute operations ────────────────────────────────────────────────────
99
100    fn gemm(
101        &self,
102        _trans_a: BackendTranspose,
103        _trans_b: BackendTranspose,
104        m: usize,
105        n: usize,
106        k: usize,
107        alpha: f64,
108        a_ptr: u64,
109        _lda: usize,
110        b_ptr: u64,
111        _ldb: usize,
112        beta: f64,
113        c_ptr: u64,
114        _ldc: usize,
115    ) -> BackendResult<()> {
116        self.check_init()?;
117        if m == 0 || n == 0 || k == 0 {
118            return Ok(());
119        }
120        self.dispatch_gemm(m, n, k, alpha as f32, a_ptr, b_ptr, beta as f32, c_ptr)
121    }
122
123    fn conv2d_forward(
124        &self,
125        input_ptr: u64,
126        input_shape: &[usize],
127        filter_ptr: u64,
128        filter_shape: &[usize],
129        output_ptr: u64,
130        output_shape: &[usize],
131        stride: &[usize],
132        padding: &[usize],
133    ) -> BackendResult<()> {
134        self.check_init()?;
135
136        if input_shape.len() != 4 {
137            return Err(BackendError::InvalidArgument(
138                "input_shape must have 4 elements (NCHW)".into(),
139            ));
140        }
141        if filter_shape.len() != 4 {
142            return Err(BackendError::InvalidArgument(
143                "filter_shape must have 4 elements (KCFHFW)".into(),
144            ));
145        }
146        if output_shape.len() != 4 {
147            return Err(BackendError::InvalidArgument(
148                "output_shape must have 4 elements (NKOhOw)".into(),
149            ));
150        }
151        if stride.len() != 2 {
152            return Err(BackendError::InvalidArgument(
153                "stride must have 2 elements [sh, sw]".into(),
154            ));
155        }
156        if padding.len() != 2 {
157            return Err(BackendError::InvalidArgument(
158                "padding must have 2 elements [ph, pw]".into(),
159            ));
160        }
161
162        let n = input_shape[0];
163        let c_in = input_shape[1];
164        let h_in = input_shape[2];
165        let w_in = input_shape[3];
166        let k_out = filter_shape[0];
167        let fh = filter_shape[2];
168        let fw = filter_shape[3];
169        let o_h = output_shape[2];
170        let o_w = output_shape[3];
171        let stride_h = stride[0];
172        let stride_w = stride[1];
173        let pad_h = padding[0];
174        let pad_w = padding[1];
175
176        // CPU fallback: copy input + filter from device
177        let in_len = n * c_in * h_in * w_in;
178        let flt_len = k_out * c_in * fh * fw;
179        let out_len = n * k_out * o_h * o_w;
180
181        let mut in_bytes = vec![0u8; in_len * 4];
182        self.copy_dtoh(&mut in_bytes, input_ptr)?;
183        let inp: Vec<f32> = in_bytes
184            .chunks_exact(4)
185            .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
186            .collect();
187
188        let mut flt_bytes = vec![0u8; flt_len * 4];
189        self.copy_dtoh(&mut flt_bytes, filter_ptr)?;
190        let flt: Vec<f32> = flt_bytes
191            .chunks_exact(4)
192            .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
193            .collect();
194
195        // NCHW convolution
196        let mut out = vec![0.0f32; out_len];
197        for b_idx in 0..n {
198            for kf in 0..k_out {
199                for oy in 0..o_h {
200                    for ox in 0..o_w {
201                        let mut acc = 0.0f32;
202                        for ci in 0..c_in {
203                            for fy in 0..fh {
204                                for fx in 0..fw {
205                                    let iy = (oy * stride_h + fy) as isize - pad_h as isize;
206                                    let ix = (ox * stride_w + fx) as isize - pad_w as isize;
207                                    if iy >= 0
208                                        && (iy as usize) < h_in
209                                        && ix >= 0
210                                        && (ix as usize) < w_in
211                                    {
212                                        let iy = iy as usize;
213                                        let ix = ix as usize;
214                                        acc += inp[((b_idx * c_in + ci) * h_in + iy) * w_in + ix]
215                                            * flt[((kf * c_in + ci) * fh + fy) * fw + fx];
216                                    }
217                                }
218                            }
219                        }
220                        out[((b_idx * k_out + kf) * o_h + oy) * o_w + ox] = acc;
221                    }
222                }
223            }
224        }
225
226        let out_bytes: Vec<u8> = out.iter().flat_map(|f| f.to_ne_bytes()).collect();
227        self.copy_htod(output_ptr, &out_bytes)
228    }
229
230    fn attention(
231        &self,
232        q_ptr: u64,
233        k_ptr: u64,
234        v_ptr: u64,
235        o_ptr: u64,
236        batch: usize,
237        heads: usize,
238        seq_q: usize,
239        seq_kv: usize,
240        head_dim: usize,
241        scale: f64,
242        causal: bool,
243    ) -> BackendResult<()> {
244        self.check_init()?;
245
246        if seq_q == 0 || seq_kv == 0 || head_dim == 0 {
247            return Err(BackendError::InvalidArgument(
248                "seq_q, seq_kv, and head_dim must all be > 0".into(),
249            ));
250        }
251        if scale <= 0.0 || !scale.is_finite() {
252            return Err(BackendError::InvalidArgument(format!(
253                "scale must be a positive finite number, got {scale}"
254            )));
255        }
256
257        let batch_heads = batch * heads;
258        let scale_f32 = scale as f32;
259
260        // CPU fallback: copy Q, K, V from device
261        let q_len = batch_heads * seq_q * head_dim;
262        let kv_len = batch_heads * seq_kv * head_dim;
263
264        let mut q_bytes = vec![0u8; q_len * 4];
265        self.copy_dtoh(&mut q_bytes, q_ptr)?;
266        let q: Vec<f32> = q_bytes
267            .chunks_exact(4)
268            .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
269            .collect();
270
271        let mut k_bytes = vec![0u8; kv_len * 4];
272        self.copy_dtoh(&mut k_bytes, k_ptr)?;
273        let k: Vec<f32> = k_bytes
274            .chunks_exact(4)
275            .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
276            .collect();
277
278        let mut v_bytes = vec![0u8; kv_len * 4];
279        self.copy_dtoh(&mut v_bytes, v_ptr)?;
280        let v: Vec<f32> = v_bytes
281            .chunks_exact(4)
282            .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
283            .collect();
284
285        // Numerically-stable scaled dot-product attention
286        let mut output = vec![0.0f32; q_len];
287
288        for bh in 0..batch_heads {
289            for sq in 0..seq_q {
290                let q_off = (bh * seq_q + sq) * head_dim;
291                let o_off = q_off;
292
293                // Pass 1: find max score
294                let mut max_score = f32::NEG_INFINITY;
295                for sk in 0..seq_kv {
296                    if causal && sk > sq {
297                        continue;
298                    }
299                    let k_off = (bh * seq_kv + sk) * head_dim;
300                    let mut dot = 0.0f32;
301                    for d in 0..head_dim {
302                        dot += q[q_off + d] * k[k_off + d];
303                    }
304                    let score = dot * scale_f32;
305                    if score > max_score {
306                        max_score = score;
307                    }
308                }
309
310                if max_score == f32::NEG_INFINITY {
311                    max_score = 0.0;
312                }
313
314                // Pass 2: accumulate exp-weighted V
315                let mut sum_exp = 0.0f32;
316                for sk in 0..seq_kv {
317                    if causal && sk > sq {
318                        continue;
319                    }
320                    let k_off = (bh * seq_kv + sk) * head_dim;
321                    let v_off = (bh * seq_kv + sk) * head_dim;
322                    let mut dot = 0.0f32;
323                    for d in 0..head_dim {
324                        dot += q[q_off + d] * k[k_off + d];
325                    }
326                    let w = (dot * scale_f32 - max_score).exp();
327                    sum_exp += w;
328                    for d in 0..head_dim {
329                        output[o_off + d] += w * v[v_off + d];
330                    }
331                }
332
333                // Normalize
334                if sum_exp > 0.0 {
335                    for d in 0..head_dim {
336                        output[o_off + d] /= sum_exp;
337                    }
338                }
339            }
340        }
341
342        let o_bytes: Vec<u8> = output.iter().flat_map(|f| f.to_ne_bytes()).collect();
343        self.copy_htod(o_ptr, &o_bytes)
344    }
345
346    fn reduce(
347        &self,
348        op: ReduceOp,
349        input_ptr: u64,
350        output_ptr: u64,
351        shape: &[usize],
352        axis: usize,
353    ) -> BackendResult<()> {
354        self.check_init()?;
355
356        if shape.is_empty() {
357            return Err(BackendError::InvalidArgument(
358                "shape must not be empty".into(),
359            ));
360        }
361        if axis >= shape.len() {
362            return Err(BackendError::InvalidArgument(format!(
363                "axis {axis} is out of bounds for shape of length {}",
364                shape.len()
365            )));
366        }
367
368        self.dispatch_reduce(op, input_ptr, output_ptr, shape, axis)
369    }
370
371    fn unary(&self, op: UnaryOp, input_ptr: u64, output_ptr: u64, n: usize) -> BackendResult<()> {
372        self.check_init()?;
373        if n == 0 {
374            return Ok(());
375        }
376        self.dispatch_unary(op, input_ptr, output_ptr, n)
377    }
378
379    fn binary(
380        &self,
381        op: BinaryOp,
382        a_ptr: u64,
383        b_ptr: u64,
384        output_ptr: u64,
385        n: usize,
386    ) -> BackendResult<()> {
387        self.check_init()?;
388        if n == 0 {
389            return Ok(());
390        }
391        self.dispatch_binary(op, a_ptr, b_ptr, output_ptr, n)
392    }
393
394    // ── Synchronisation ───────────────────────────────────────────────────────
395
396    fn synchronize(&self) -> BackendResult<()> {
397        self.check_init()?;
398
399        #[cfg(any(target_os = "linux", target_os = "windows"))]
400        {
401            if let Some(dev) = &self.device {
402                let api = &dev.api;
403                let queue = dev.queue;
404                // SAFETY: `queue` is a valid command queue handle and the
405                // backend is initialized.  u64::MAX means "wait indefinitely".
406                let rc = unsafe { (api.ze_command_queue_synchronize)(queue, u64::MAX) };
407                if rc != 0 {
408                    return Err(BackendError::DeviceError(format!(
409                        "zeCommandQueueSynchronize failed: 0x{rc:08x}"
410                    )));
411                }
412            }
413        }
414
415        Ok(())
416    }
417
418    // ── Memory management ─────────────────────────────────────────────────────
419
420    fn alloc(&self, bytes: usize) -> BackendResult<u64> {
421        self.check_init()?;
422        if bytes == 0 {
423            return Err(BackendError::InvalidArgument(
424                "cannot allocate 0 bytes".into(),
425            ));
426        }
427        self.memory()?.alloc(bytes).map_err(BackendError::from)
428    }
429
430    fn free(&self, ptr: u64) -> BackendResult<()> {
431        self.check_init()?;
432        self.memory()?.free(ptr).map_err(BackendError::from)
433    }
434
435    fn copy_htod(&self, dst: u64, src: &[u8]) -> BackendResult<()> {
436        self.check_init()?;
437        if src.is_empty() {
438            return Ok(());
439        }
440        self.memory()?
441            .copy_to_device(dst, src)
442            .map_err(BackendError::from)
443    }
444
445    fn copy_dtoh(&self, dst: &mut [u8], src: u64) -> BackendResult<()> {
446        self.check_init()?;
447        if dst.is_empty() {
448            return Ok(());
449        }
450        self.memory()?
451            .copy_from_device(dst, src)
452            .map_err(BackendError::from)
453    }
454}
455
456// ─── Dispatch helpers ────────────────────────────────────────────────────────
457
/// Workgroup size matching the SPIR-V LocalSize declaration.
///
/// The dispatch helpers divide the element count by this (`div_ceil`) to get
/// the workgroup count, so it must stay in sync with the generated shaders.
const WORKGROUP_SIZE: u32 = crate::spirv::WORKGROUP_SIZE;
460
/// A kernel argument value for the Level Zero dispatch pipeline.
///
/// Only constructed by the dispatch helpers, which is why the whole enum is
/// dead code on platforms without a Level Zero loader.
#[cfg_attr(not(any(target_os = "linux", target_os = "windows")), allow(dead_code))]
enum KernelArg {
    /// Buffer handle — resolved to a raw device pointer at dispatch time.
    Buffer(u64),
    /// 32-bit unsigned integer scalar.
    U32(u32),
    /// 32-bit float scalar.
    F32(f32),
}
471
impl LevelZeroBackend {
    /// Dispatch a SPIR-V compute kernel via Level Zero.
    ///
    /// 1. Build a Level Zero module from `spv_words`.
    /// 2. Create a kernel named `"main"` from the module.
    /// 3. Set group size and kernel arguments.
    /// 4. Append a launch to a command list, execute, and wait.
    /// 5. Clean up all Level Zero objects.
    ///
    /// Synchronous: the queue is drained before this function returns, so the
    /// temporary module/kernel/list can be destroyed immediately.
    ///
    /// NOTE(review): every error path repeats the destroy calls by hand; a
    /// drop-guard wrapper around module/kernel/list would make a missed
    /// cleanup on a future early return impossible — worth a follow-up.
    fn run_kernel(
        &self,
        spv_words: &[u32],
        args: &[KernelArg],
        workgroups: u32,
    ) -> BackendResult<()> {
        #[cfg(any(target_os = "linux", target_os = "windows"))]
        {
            use std::ffi::c_void;

            use crate::device::{
                ZE_MODULE_FORMAT_IL_SPIRV, ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC,
                ZE_STRUCTURE_TYPE_KERNEL_DESC, ZE_STRUCTURE_TYPE_MODULE_DESC, ZeCommandListDesc,
                ZeGroupCount, ZeKernelDesc, ZeKernelHandle, ZeModuleDesc, ZeModuleHandle,
            };

            let device = self.device.as_ref().ok_or(BackendError::NotInitialized)?;
            let memory = self.memory()?;
            let api = &device.api;
            let context = device.context;
            let dev_handle = device.device;
            let queue = device.queue;

            // ── 1. SPIR-V words → bytes (native endianness, as the loader expects) ──
            let spv_bytes: Vec<u8> = spv_words.iter().flat_map(|w| w.to_ne_bytes()).collect();

            // ── 2. Create module ──
            let module_desc = ZeModuleDesc {
                stype: ZE_STRUCTURE_TYPE_MODULE_DESC,
                p_next: std::ptr::null(),
                format: ZE_MODULE_FORMAT_IL_SPIRV,
                input_size: spv_bytes.len(),
                p_input_module: spv_bytes.as_ptr(),
                p_build_flags: std::ptr::null(),
                p_constants: std::ptr::null(),
            };
            let mut module: ZeModuleHandle = std::ptr::null_mut();
            // SAFETY: `module_desc` points at valid SPIR-V bytes that outlive
            // the call; `context`/`dev_handle` are live handles owned by `device`.
            let rc = unsafe {
                (api.ze_module_create)(
                    context,
                    dev_handle,
                    &module_desc,
                    &mut module as *mut ZeModuleHandle,
                    std::ptr::null_mut(),
                )
            };
            if rc != 0 {
                return Err(BackendError::DeviceError(format!(
                    "zeModuleCreate failed: 0x{rc:08x}"
                )));
            }

            // ── 3. Create kernel (entry point is always named "main") ──
            let kernel_name = b"main\0";
            let kernel_desc = ZeKernelDesc {
                stype: ZE_STRUCTURE_TYPE_KERNEL_DESC,
                p_next: std::ptr::null(),
                flags: 0,
                p_kernel_name: kernel_name.as_ptr(),
            };
            let mut kernel: ZeKernelHandle = std::ptr::null_mut();
            // SAFETY: `module` was just created successfully; `kernel_desc`
            // points at a NUL-terminated name with static lifetime.
            let rc = unsafe {
                (api.ze_kernel_create)(module, &kernel_desc, &mut kernel as *mut ZeKernelHandle)
            };
            if rc != 0 {
                unsafe { (api.ze_module_destroy)(module) };
                return Err(BackendError::DeviceError(format!(
                    "zeKernelCreate failed: 0x{rc:08x}"
                )));
            }

            // ── 4. Set group size (1-D launch: WORKGROUP_SIZE × 1 × 1) ──
            let rc = unsafe { (api.ze_kernel_set_group_size)(kernel, WORKGROUP_SIZE, 1, 1) };
            if rc != 0 {
                unsafe {
                    (api.ze_kernel_destroy)(kernel);
                    (api.ze_module_destroy)(module);
                }
                return Err(BackendError::DeviceError(format!(
                    "zeKernelSetGroupSize failed: 0x{rc:08x}"
                )));
            }

            // ── 5. Set kernel arguments (argument index = position in `args`) ──
            for (idx, arg) in args.iter().enumerate() {
                let rc = match arg {
                    KernelArg::Buffer(handle) => {
                        // Resolve the opaque buffer handle to a raw device
                        // pointer; clean up before propagating a bad handle.
                        let dev_ptr = memory.device_ptr(*handle).map_err(|e| {
                            unsafe {
                                (api.ze_kernel_destroy)(kernel);
                                (api.ze_module_destroy)(module);
                            }
                            BackendError::from(e)
                        })?;
                        // SAFETY: the value is copied by the driver during the
                        // call, so passing the address of the local is fine.
                        unsafe {
                            (api.ze_kernel_set_argument_value)(
                                kernel,
                                idx as u32,
                                std::mem::size_of::<*mut c_void>(),
                                &dev_ptr as *const *mut c_void as *const c_void,
                            )
                        }
                    }
                    KernelArg::U32(val) => unsafe {
                        (api.ze_kernel_set_argument_value)(
                            kernel,
                            idx as u32,
                            std::mem::size_of::<u32>(),
                            val as *const u32 as *const c_void,
                        )
                    },
                    KernelArg::F32(val) => unsafe {
                        (api.ze_kernel_set_argument_value)(
                            kernel,
                            idx as u32,
                            std::mem::size_of::<f32>(),
                            val as *const f32 as *const c_void,
                        )
                    },
                };
                if rc != 0 {
                    unsafe {
                        (api.ze_kernel_destroy)(kernel);
                        (api.ze_module_destroy)(module);
                    }
                    return Err(BackendError::DeviceError(format!(
                        "zeKernelSetArgumentValue(arg={idx}) failed: 0x{rc:08x}"
                    )));
                }
            }

            // ── 6. Create command list ──
            let list_desc = ZeCommandListDesc {
                stype: ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC,
                p_next: std::ptr::null(),
                command_queue_group_ordinal: 0,
                flags: 0,
            };
            let mut list = std::ptr::null_mut();
            let rc =
                unsafe { (api.ze_command_list_create)(context, dev_handle, &list_desc, &mut list) };
            if rc != 0 {
                unsafe {
                    (api.ze_kernel_destroy)(kernel);
                    (api.ze_module_destroy)(module);
                }
                return Err(BackendError::DeviceError(format!(
                    "zeCommandListCreate failed: 0x{rc:08x}"
                )));
            }

            // ── 7. Append launch kernel (1-D grid of `workgroups` groups) ──
            let group_count = ZeGroupCount {
                group_count_x: workgroups,
                group_count_y: 1,
                group_count_z: 1,
            };
            let rc = unsafe {
                (api.ze_command_list_append_launch_kernel)(
                    list,
                    kernel,
                    &group_count,
                    0,
                    0,
                    std::ptr::null(),
                )
            };
            if rc != 0 {
                unsafe {
                    (api.ze_command_list_destroy)(list);
                    (api.ze_kernel_destroy)(kernel);
                    (api.ze_module_destroy)(module);
                }
                return Err(BackendError::DeviceError(format!(
                    "zeCommandListAppendLaunchKernel failed: 0x{rc:08x}"
                )));
            }

            // ── 8. Close + execute + wait ──
            let rc = unsafe { (api.ze_command_list_close)(list) };
            if rc != 0 {
                unsafe {
                    (api.ze_command_list_destroy)(list);
                    (api.ze_kernel_destroy)(kernel);
                    (api.ze_module_destroy)(module);
                }
                return Err(BackendError::DeviceError(format!(
                    "zeCommandListClose failed: 0x{rc:08x}"
                )));
            }

            let rc = unsafe { (api.ze_command_queue_execute_command_lists)(queue, 1, &list, 0) };
            if rc != 0 {
                unsafe {
                    (api.ze_command_list_destroy)(list);
                    (api.ze_kernel_destroy)(kernel);
                    (api.ze_module_destroy)(module);
                }
                return Err(BackendError::DeviceError(format!(
                    "zeCommandQueueExecuteCommandLists failed: 0x{rc:08x}"
                )));
            }

            // SAFETY: `queue` is live; u64::MAX means "wait indefinitely".
            let rc = unsafe { (api.ze_command_queue_synchronize)(queue, u64::MAX) };
            if rc != 0 {
                unsafe {
                    (api.ze_command_list_destroy)(list);
                    (api.ze_kernel_destroy)(kernel);
                    (api.ze_module_destroy)(module);
                }
                return Err(BackendError::DeviceError(format!(
                    "zeCommandQueueSynchronize failed: 0x{rc:08x}"
                )));
            }

            // ── 9. Clean up (safe now: the queue has drained) ──
            unsafe {
                (api.ze_command_list_destroy)(list);
                (api.ze_kernel_destroy)(kernel);
                (api.ze_module_destroy)(module);
            }

            Ok(())
        }

        #[cfg(not(any(target_os = "linux", target_os = "windows")))]
        {
            let _ = (spv_words, args, workgroups);
            Err(BackendError::DeviceError(
                "Level Zero requires Linux or Windows".into(),
            ))
        }
    }

    /// Launch the unary elementwise kernel for `op` over `n` f32 elements.
    ///
    /// Shader argument order: (input, output, count).
    fn dispatch_unary(
        &self,
        op: UnaryOp,
        input_ptr: u64,
        output_ptr: u64,
        n: usize,
    ) -> BackendResult<()> {
        let spv = crate::spirv::unary_compute_shader(op);
        let args = [
            KernelArg::Buffer(input_ptr),
            KernelArg::Buffer(output_ptr),
            KernelArg::U32(n as u32),
        ];
        // One thread per element, rounded up to whole workgroups.
        self.run_kernel(&spv, &args, (n as u32).div_ceil(WORKGROUP_SIZE))
    }

    /// Launch the binary elementwise kernel for `op` over `n` f32 elements.
    ///
    /// Shader argument order: (a, b, output, count).
    fn dispatch_binary(
        &self,
        op: BinaryOp,
        a_ptr: u64,
        b_ptr: u64,
        output_ptr: u64,
        n: usize,
    ) -> BackendResult<()> {
        let spv = crate::spirv::binary_compute_shader(op);
        let args = [
            KernelArg::Buffer(a_ptr),
            KernelArg::Buffer(b_ptr),
            KernelArg::Buffer(output_ptr),
            KernelArg::U32(n as u32),
        ];
        self.run_kernel(&spv, &args, (n as u32).div_ceil(WORKGROUP_SIZE))
    }

    /// Launch the reduction kernel for `op`, reducing `shape` along `axis`.
    ///
    /// The tensor is decomposed as outer × reduce × inner; one output element
    /// per (outer, inner) pair. Caller (`reduce`) guarantees
    /// `axis < shape.len()`, so the slicing below cannot panic.
    fn dispatch_reduce(
        &self,
        op: ReduceOp,
        input_ptr: u64,
        output_ptr: u64,
        shape: &[usize],
        axis: usize,
    ) -> BackendResult<()> {
        // `.max(1)` makes the empty prefix/suffix products behave as a
        // dimension of size 1 (reducing the first or last axis).
        let outer_size: usize = shape[..axis].iter().product::<usize>().max(1);
        let reduce_size = shape[axis];
        let inner_size: usize = shape[axis + 1..].iter().product::<usize>().max(1);

        let spv = crate::spirv::reduce_compute_shader(op);
        let total_output = (outer_size * inner_size) as u32;
        let args = [
            KernelArg::Buffer(input_ptr),
            KernelArg::Buffer(output_ptr),
            KernelArg::U32(outer_size as u32),
            KernelArg::U32(reduce_size as u32),
            KernelArg::U32(inner_size as u32),
        ];
        self.run_kernel(&spv, &args, total_output.div_ceil(WORKGROUP_SIZE))
    }

    /// Launch the GEMM kernel: one thread per output element of the m×n C.
    ///
    /// Shader argument order: (a, b, c, m, n, k, alpha, beta). The shader is
    /// assumed to expect packed row-major f32 matrices — TODO confirm against
    /// `crate::spirv::gemm_compute_shader`.
    #[allow(clippy::too_many_arguments)]
    fn dispatch_gemm(
        &self,
        m: usize,
        n: usize,
        k: usize,
        alpha: f32,
        a_ptr: u64,
        b_ptr: u64,
        beta: f32,
        c_ptr: u64,
    ) -> BackendResult<()> {
        let spv = crate::spirv::gemm_compute_shader();
        let total = (m * n) as u32;
        let args = [
            KernelArg::Buffer(a_ptr),
            KernelArg::Buffer(b_ptr),
            KernelArg::Buffer(c_ptr),
            KernelArg::U32(m as u32),
            KernelArg::U32(n as u32),
            KernelArg::U32(k as u32),
            KernelArg::F32(alpha),
            KernelArg::F32(beta),
        ];
        self.run_kernel(&spv, &args, total.div_ceil(WORKGROUP_SIZE))
    }
}
799
800// ─── Tests ───────────────────────────────────────────────────────────────────
801
802#[cfg(test)]
803mod tests {
804    use super::*;
805    use oxicuda_backend::{BackendTranspose, BinaryOp, ComputeBackend, ReduceOp, UnaryOp};
806
807    // ── Construction ──────────────────────────────────────────────────────────
808
809    #[test]
810    fn level_zero_backend_new_uninitialized() {
811        let b = LevelZeroBackend::new();
812        assert!(!b.is_initialized());
813    }
814
815    #[test]
816    fn level_zero_backend_name() {
817        let b = LevelZeroBackend::new();
818        assert_eq!(b.name(), "level-zero");
819    }
820
821    #[test]
822    fn level_zero_backend_default() {
823        let b = LevelZeroBackend::default();
824        assert!(!b.is_initialized());
825        assert_eq!(b.name(), "level-zero");
826    }
827
828    #[test]
829    fn backend_debug_impl() {
830        let b = LevelZeroBackend::new();
831        let s = format!("{b:?}");
832        assert!(s.contains("LevelZeroBackend"));
833    }
834
835    // ── Object-safety smoke test ──────────────────────────────────────────────
836
837    #[test]
838    fn backend_object_safe() {
839        let b: Box<dyn ComputeBackend> = Box::new(LevelZeroBackend::new());
840        assert_eq!(b.name(), "level-zero");
841    }
842
843    // ── Not-initialized guards ────────────────────────────────────────────────
844
845    #[test]
846    fn backend_not_initialized_gemm() {
847        let b = LevelZeroBackend::new();
848        let result = b.gemm(
849            BackendTranspose::NoTrans,
850            BackendTranspose::NoTrans,
851            4,
852            4,
853            4,
854            1.0,
855            0,
856            4,
857            0,
858            4,
859            0.0,
860            0,
861            4,
862        );
863        assert_eq!(result, Err(BackendError::NotInitialized));
864    }
865
866    #[test]
867    fn backend_not_initialized_alloc() {
868        let b = LevelZeroBackend::new();
869        assert_eq!(b.alloc(1024), Err(BackendError::NotInitialized));
870    }
871
872    #[test]
873    fn backend_not_initialized_synchronize() {
874        let b = LevelZeroBackend::new();
875        assert_eq!(b.synchronize(), Err(BackendError::NotInitialized));
876    }
877
878    #[test]
879    fn backend_not_initialized_free() {
880        let b = LevelZeroBackend::new();
881        assert_eq!(b.free(1), Err(BackendError::NotInitialized));
882    }
883
884    #[test]
885    fn backend_not_initialized_copy_htod() {
886        let b = LevelZeroBackend::new();
887        assert_eq!(b.copy_htod(1, b"hello"), Err(BackendError::NotInitialized));
888    }
889
890    #[test]
891    fn backend_not_initialized_copy_dtoh() {
892        let b = LevelZeroBackend::new();
893        let mut buf = [0u8; 4];
894        assert_eq!(b.copy_dtoh(&mut buf, 1), Err(BackendError::NotInitialized));
895    }
896
897    // ── Helper: try to get an initialised backend (skip if no GPU or no loader) ─
898
899    fn try_init() -> Option<LevelZeroBackend> {
900        let mut b = LevelZeroBackend::new();
901        match b.init() {
902            Ok(()) => Some(b),
903            Err(_) => None,
904        }
905    }
906
907    // ── Graceful init failure ─────────────────────────────────────────────────
908
909    #[test]
910    fn init_graceful_failure() {
911        // Verify that init() returns a Result and never panics.
912        let mut b = LevelZeroBackend::new();
913        let _result = b.init();
914        // Ok or Err — both are acceptable.
915    }
916
917    // ── Zero-size / trivial-OK paths (post-init) ──────────────────────────────
918
919    #[test]
920    fn alloc_zero_bytes_error() {
921        let Some(b) = try_init() else {
922            return;
923        };
924        assert_eq!(
925            b.alloc(0),
926            Err(BackendError::InvalidArgument(
927                "cannot allocate 0 bytes".into()
928            ))
929        );
930    }
931
932    #[test]
933    fn copy_htod_empty_noop() {
934        let Some(b) = try_init() else {
935            return;
936        };
937        assert_eq!(b.copy_htod(0, &[]), Ok(()));
938    }
939
940    #[test]
941    fn copy_dtoh_empty_noop() {
942        let Some(b) = try_init() else {
943            return;
944        };
945        assert_eq!(b.copy_dtoh(&mut [], 0), Ok(()));
946    }
947
948    #[test]
949    fn gemm_zero_dims_noop() {
950        let Some(b) = try_init() else {
951            return;
952        };
953        assert_eq!(
954            b.gemm(
955                BackendTranspose::NoTrans,
956                BackendTranspose::NoTrans,
957                0,
958                0,
959                0,
960                1.0,
961                0,
962                1,
963                0,
964                1,
965                0.0,
966                0,
967                1
968            ),
969            Ok(())
970        );
971    }
972
973    #[test]
974    fn unary_zero_n_noop() {
975        let Some(b) = try_init() else {
976            return;
977        };
978        assert_eq!(b.unary(UnaryOp::Relu, 0, 0, 0), Ok(()));
979    }
980
981    #[test]
982    fn binary_zero_n_noop() {
983        let Some(b) = try_init() else {
984            return;
985        };
986        assert_eq!(b.binary(BinaryOp::Add, 0, 0, 0, 0), Ok(()));
987    }
988
989    #[test]
990    fn synchronize_after_init() {
991        let Some(b) = try_init() else {
992            return;
993        };
994        assert_eq!(b.synchronize(), Ok(()));
995    }
996
997    // ── Argument validation (post-init) ───────────────────────────────────────
998
999    #[test]
1000    fn reduce_empty_shape_error() {
1001        let Some(b) = try_init() else {
1002            return;
1003        };
1004        assert_eq!(
1005            b.reduce(ReduceOp::Sum, 0, 0, &[], 0),
1006            Err(BackendError::InvalidArgument(
1007                "shape must not be empty".into()
1008            ))
1009        );
1010    }
1011
1012    #[test]
1013    fn reduce_axis_out_of_bounds_error() {
1014        let Some(b) = try_init() else {
1015            return;
1016        };
1017        assert_eq!(
1018            b.reduce(ReduceOp::Sum, 0, 0, &[4, 4], 5),
1019            Err(BackendError::InvalidArgument(
1020                "axis 5 is out of bounds for shape of length 2".into()
1021            ))
1022        );
1023    }
1024
1025    #[test]
1026    fn attention_zero_seq_error() {
1027        let Some(b) = try_init() else {
1028            return;
1029        };
1030        assert_eq!(
1031            b.attention(0, 0, 0, 0, 1, 1, 0, 8, 64, 0.125, false),
1032            Err(BackendError::InvalidArgument(
1033                "seq_q, seq_kv, and head_dim must all be > 0".into()
1034            ))
1035        );
1036    }
1037
1038    #[test]
1039    fn attention_invalid_scale_error() {
1040        let Some(b) = try_init() else {
1041            return;
1042        };
1043        assert_eq!(
1044            b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, 0.0, false),
1045            Err(BackendError::InvalidArgument(
1046                "scale must be a positive finite number, got 0".into()
1047            ))
1048        );
1049        assert_eq!(
1050            b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, -1.0, false),
1051            Err(BackendError::InvalidArgument(
1052                "scale must be a positive finite number, got -1".into()
1053            ))
1054        );
1055        assert!(
1056            b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, f64::INFINITY, false)
1057                .is_err()
1058        );
1059    }
1060
1061    #[test]
1062    fn conv2d_wrong_input_rank() {
1063        let Some(b) = try_init() else {
1064            return;
1065        };
1066        assert_eq!(
1067            b.conv2d_forward(
1068                0,
1069                &[1, 3, 32],
1070                0,
1071                &[16, 3, 3, 3],
1072                0,
1073                &[1, 16, 30, 30],
1074                &[1, 1],
1075                &[0, 0]
1076            ),
1077            Err(BackendError::InvalidArgument(
1078                "input_shape must have 4 elements (NCHW)".into()
1079            ))
1080        );
1081    }
1082
1083    #[test]
1084    fn conv2d_wrong_filter_rank() {
1085        let Some(b) = try_init() else {
1086            return;
1087        };
1088        assert_eq!(
1089            b.conv2d_forward(
1090                0,
1091                &[1, 3, 32, 32],
1092                0,
1093                &[16, 3, 3],
1094                0,
1095                &[1, 16, 30, 30],
1096                &[1, 1],
1097                &[0, 0]
1098            ),
1099            Err(BackendError::InvalidArgument(
1100                "filter_shape must have 4 elements (KCFHFW)".into()
1101            ))
1102        );
1103    }
1104
1105    // ── Init is idempotent ────────────────────────────────────────────────────
1106
1107    #[test]
1108    fn init_idempotent() {
1109        let Some(mut b) = try_init() else {
1110            return;
1111        };
1112        assert_eq!(b.init(), Ok(()));
1113        assert!(b.is_initialized());
1114    }
1115
1116    // ── alloc/free/copy roundtrip ─────────────────────────────────────────────
1117
1118    #[test]
1119    fn alloc_copy_roundtrip() {
1120        let Some(b) = try_init() else {
1121            return;
1122        };
1123        let src: Vec<u8> = (0u8..64).collect();
1124        let handle = match b.alloc(src.len()) {
1125            Ok(h) => h,
1126            Err(_) => return,
1127        };
1128        b.copy_htod(handle, &src).expect("copy_htod");
1129        let mut dst = vec![0u8; src.len()];
1130        b.copy_dtoh(&mut dst, handle).expect("copy_dtoh");
1131        assert_eq!(src, dst);
1132        b.free(handle).expect("free");
1133    }
1134
1135    // ── Double init is a no-op ────────────────────────────────────────────────
1136
1137    #[test]
1138    fn double_init_is_noop() {
1139        let Some(mut b) = try_init() else {
1140            return;
1141        };
1142        let first = b.is_initialized();
1143        let _ = b.init();
1144        assert_eq!(first, b.is_initialized());
1145    }
1146
1147    // ── alloc and free basic ──────────────────────────────────────────────────
1148
1149    #[test]
1150    fn alloc_and_free_basic() {
1151        let Some(b) = try_init() else {
1152            return;
1153        };
1154        match b.alloc(128) {
1155            Ok(handle) => {
1156                assert!(handle > 0);
1157                b.free(handle).expect("free should succeed");
1158            }
1159            Err(_) => {
1160                // Allocation failure is acceptable in environments without GPU.
1161            }
1162        }
1163    }
1164
1165    // ── Conv2D correctness tests ──────────────────────────────────────────────
1166
1167    /// Helper: allocate device memory, store f32 data, return handle.
1168    fn upload_f32(b: &LevelZeroBackend, data: &[f32]) -> Option<u64> {
1169        let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_ne_bytes()).collect();
1170        let handle = b.alloc(bytes.len()).ok()?;
1171        b.copy_htod(handle, &bytes).ok()?;
1172        Some(handle)
1173    }
1174
1175    /// Helper: download f32 data from device.
1176    fn download_f32(b: &LevelZeroBackend, handle: u64, len: usize) -> Option<Vec<f32>> {
1177        let mut bytes = vec![0u8; len * 4];
1178        b.copy_dtoh(&mut bytes, handle).ok()?;
1179        Some(
1180            bytes
1181                .chunks_exact(4)
1182                .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
1183                .collect(),
1184        )
1185    }
1186
    #[test]
    fn l0_conv2d_identity_1x1() {
        // A 1x1 all-ones filter must reproduce the input exactly.
        let Some(b) = try_init() else {
            return;
        };
        // 1x1x4x4 input, 1x1x1x1 filter = identity (filter=[1.0])
        let input: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let filter = vec![1.0f32];
        let output_len = 16;

        // Any upload/alloc failure means no usable device — skip silently.
        // NOTE(review): these early returns leak already-uploaded handles;
        // tolerable in a test process that exits shortly afterwards.
        let Some(in_h) = upload_f32(&b, &input) else {
            return;
        };
        let Some(flt_h) = upload_f32(&b, &filter) else {
            return;
        };
        let Some(out_h) = b.alloc(output_len * 4).ok() else {
            return;
        };

        // NCHW input/output, KCFHFW filter, stride 1, no padding.
        let result = b.conv2d_forward(
            in_h,
            &[1, 1, 4, 4],
            flt_h,
            &[1, 1, 1, 1],
            out_h,
            &[1, 1, 4, 4],
            &[1, 1],
            &[0, 0],
        );
        assert!(result.is_ok(), "conv2d_forward failed: {result:?}");

        // Element-wise compare against the input (identity expectation).
        if let Some(out) = download_f32(&b, out_h, output_len) {
            for (i, &val) in out.iter().enumerate() {
                assert!(
                    (val - input[i]).abs() < 1e-5,
                    "mismatch at {i}: expected {}, got {val}",
                    input[i]
                );
            }
        }

        // Best-effort cleanup; free errors are irrelevant to the assertion.
        let _ = b.free(in_h);
        let _ = b.free(flt_h);
        let _ = b.free(out_h);
    }
1233
    #[test]
    fn l0_conv2d_3x3_basic() {
        // Valid (unpadded) 3x3 convolution: each output is a 3x3 window sum.
        let Some(b) = try_init() else {
            return;
        };
        // 1x1x4x4 input, 1x1x3x3 filter, stride=1, pad=0 → 1x1x2x2 output
        let input: Vec<f32> = (0..16).map(|i| i as f32).collect();
        // All-ones 3x3 filter: output[oy,ox] = sum of 3x3 window
        let filter = vec![1.0f32; 9];
        let output_len = 4;

        // Upload/alloc failures mean no usable device — skip silently.
        let Some(in_h) = upload_f32(&b, &input) else {
            return;
        };
        let Some(flt_h) = upload_f32(&b, &filter) else {
            return;
        };
        let Some(out_h) = b.alloc(output_len * 4).ok() else {
            return;
        };

        // NCHW input/output, KCFHFW filter, stride 1, no padding.
        let result = b.conv2d_forward(
            in_h,
            &[1, 1, 4, 4],
            flt_h,
            &[1, 1, 3, 3],
            out_h,
            &[1, 1, 2, 2],
            &[1, 1],
            &[0, 0],
        );
        assert!(result.is_ok());

        // Expected:
        // out[0,0] = 0+1+2+4+5+6+8+9+10 = 45
        // out[0,1] = 1+2+3+5+6+7+9+10+11 = 54
        // out[1,0] = 4+5+6+8+9+10+12+13+14 = 81
        // out[1,1] = 5+6+7+9+10+11+13+14+15 = 90
        let expected = [45.0f32, 54.0, 81.0, 90.0];
        if let Some(out) = download_f32(&b, out_h, output_len) {
            for (i, &val) in out.iter().enumerate() {
                assert!(
                    (val - expected[i]).abs() < 1e-4,
                    "mismatch at {i}: expected {}, got {val}",
                    expected[i]
                );
            }
        }

        // Best-effort cleanup; free errors are irrelevant to the assertion.
        let _ = b.free(in_h);
        let _ = b.free(flt_h);
        let _ = b.free(out_h);
    }
1287
    #[test]
    fn l0_conv2d_with_padding() {
        // "Same" convolution: pad=1 keeps the 3x3 spatial size; padded
        // positions contribute zero, so edge outputs sum fewer elements.
        let Some(b) = try_init() else {
            return;
        };
        // 1x1x3x3 input, 1x1x3x3 filter (all ones), stride=1, pad=1 → 1x1x3x3 output
        let input: Vec<f32> = (1..=9).map(|i| i as f32).collect();
        let filter = vec![1.0f32; 9];
        let output_len = 9;

        // Upload/alloc failures mean no usable device — skip silently.
        let Some(in_h) = upload_f32(&b, &input) else {
            return;
        };
        let Some(flt_h) = upload_f32(&b, &filter) else {
            return;
        };
        let Some(out_h) = b.alloc(output_len * 4).ok() else {
            return;
        };

        // NCHW input/output, KCFHFW filter, stride 1, padding 1 on H and W.
        let result = b.conv2d_forward(
            in_h,
            &[1, 1, 3, 3],
            flt_h,
            &[1, 1, 3, 3],
            out_h,
            &[1, 1, 3, 3],
            &[1, 1],
            &[1, 1],
        );
        assert!(result.is_ok());

        // Center element: sum of all 9 = 45
        if let Some(out) = download_f32(&b, out_h, output_len) {
            assert!(
                (out[4] - 45.0).abs() < 1e-4,
                "center expected 45, got {}",
                out[4]
            );
            // Corner [0,0]: sum of [1,2,4,5] = 12
            assert!(
                (out[0] - 12.0).abs() < 1e-4,
                "corner expected 12, got {}",
                out[0]
            );
        }

        // Best-effort cleanup; free errors are irrelevant to the assertion.
        let _ = b.free(in_h);
        let _ = b.free(flt_h);
        let _ = b.free(out_h);
    }
1339
1340    // ── Attention correctness tests ───────────────────────────────────────────
1341
    #[test]
    fn l0_attention_uniform() {
        // With Q=K=0 every QK^T score is 0, softmax is uniform, and each
        // output row is the plain average of the V rows.
        let Some(b) = try_init() else {
            return;
        };
        // batch=1, heads=1, seq_q=2, seq_kv=2, head_dim=2
        // Q=K=all zeros → uniform attention → O = mean(V)
        let seq_q = 2;
        let seq_kv = 2;
        let head_dim = 2;
        let q = vec![0.0f32; seq_q * head_dim];
        let k = vec![0.0f32; seq_kv * head_dim];
        let v = vec![1.0f32, 2.0, 3.0, 4.0]; // V[0]=[1,2], V[1]=[3,4]
        let o_len = seq_q * head_dim;

        // Upload/alloc failures mean no usable device — skip silently.
        let Some(q_h) = upload_f32(&b, &q) else {
            return;
        };
        let Some(k_h) = upload_f32(&b, &k) else {
            return;
        };
        let Some(v_h) = upload_f32(&b, &v) else {
            return;
        };
        let Some(o_h) = b.alloc(o_len * 4).ok() else {
            return;
        };
        // Zero out output
        let zeros = vec![0u8; o_len * 4];
        let _ = b.copy_htod(o_h, &zeros);

        // Standard 1/sqrt(d) scaling; irrelevant here since all scores are 0.
        let scale = 1.0 / (head_dim as f64).sqrt();
        let result = b.attention(
            q_h, k_h, v_h, o_h, 1, 1, seq_q, seq_kv, head_dim, scale, false,
        );
        assert!(result.is_ok(), "attention failed: {result:?}");

        // With uniform attention weights, output = mean(V rows)
        // mean = [(1+3)/2, (2+4)/2] = [2, 3]
        if let Some(out) = download_f32(&b, o_h, o_len) {
            // Both query positions should get the same result
            for sq_idx in 0..seq_q {
                let off = sq_idx * head_dim;
                assert!(
                    (out[off] - 2.0).abs() < 1e-4,
                    "q{sq_idx}[0] expected 2.0, got {}",
                    out[off]
                );
                assert!(
                    (out[off + 1] - 3.0).abs() < 1e-4,
                    "q{sq_idx}[1] expected 3.0, got {}",
                    out[off + 1]
                );
            }
        }

        // Best-effort cleanup; free errors are irrelevant to the assertion.
        let _ = b.free(q_h);
        let _ = b.free(k_h);
        let _ = b.free(v_h);
        let _ = b.free(o_h);
    }
1403
    #[test]
    fn l0_attention_causal() {
        // Causal masking: query i may only attend to keys 0..=i. With zero
        // Q/K the visible scores are uniform, so each output row is the
        // average of the visible V rows.
        let Some(b) = try_init() else {
            return;
        };
        // batch=1, heads=1, seq_q=2, seq_kv=2, head_dim=2, causal=true
        let seq_q = 2;
        let seq_kv = 2;
        let head_dim = 2;
        let q = vec![0.0f32; seq_q * head_dim];
        let k = vec![0.0f32; seq_kv * head_dim];
        let v = vec![1.0f32, 2.0, 3.0, 4.0];
        let o_len = seq_q * head_dim;

        // Upload/alloc failures mean no usable device — skip silently.
        let Some(q_h) = upload_f32(&b, &q) else {
            return;
        };
        let Some(k_h) = upload_f32(&b, &k) else {
            return;
        };
        let Some(v_h) = upload_f32(&b, &v) else {
            return;
        };
        let Some(o_h) = b.alloc(o_len * 4).ok() else {
            return;
        };
        // Zero the output buffer so stale device memory can't mask a failure.
        let zeros = vec![0u8; o_len * 4];
        let _ = b.copy_htod(o_h, &zeros);

        let scale = 1.0 / (head_dim as f64).sqrt();
        let result = b.attention(
            q_h, k_h, v_h, o_h, 1, 1, seq_q, seq_kv, head_dim, scale, true,
        );
        assert!(result.is_ok());

        if let Some(out) = download_f32(&b, o_h, o_len) {
            // q=0 (causal: can only attend to k=0): output = V[0] = [1, 2]
            assert!(
                (out[0] - 1.0).abs() < 1e-4,
                "q0[0] expected 1.0, got {}",
                out[0]
            );
            assert!(
                (out[1] - 2.0).abs() < 1e-4,
                "q0[1] expected 2.0, got {}",
                out[1]
            );
            // q=1 (can attend to k=0,1): output = mean(V) = [2, 3]
            assert!(
                (out[2] - 2.0).abs() < 1e-4,
                "q1[0] expected 2.0, got {}",
                out[2]
            );
            assert!(
                (out[3] - 3.0).abs() < 1e-4,
                "q1[1] expected 3.0, got {}",
                out[3]
            );
        }

        // Best-effort cleanup; free errors are irrelevant to the assertion.
        let _ = b.free(q_h);
        let _ = b.free(k_h);
        let _ = b.free(v_h);
        let _ = b.free(o_h);
    }
1469
    #[test]
    fn l0_attention_dominant_key() {
        // Softmax sanity check: when one key's score dwarfs the rest, the
        // output should be (almost) exactly that key's V row.
        let Some(b) = try_init() else {
            return;
        };
        // One key has a very large dot product → attention should be concentrated on it
        let seq_q = 1;
        let seq_kv = 3;
        let head_dim = 2;
        // Q = [10, 0]
        let q = vec![10.0f32, 0.0];
        // K = [[10, 0], [0, 0], [0, 0]]  → dot(Q,K[0]) = 100, others = 0
        let k = vec![10.0f32, 0.0, 0.0, 0.0, 0.0, 0.0];
        let v = vec![1.0f32, 0.0, 0.0, 1.0, 0.0, 0.0]; // V[0]=[1,0], V[1]=[0,1], V[2]=[0,0]
        let o_len = seq_q * head_dim;

        // Upload/alloc failures mean no usable device — skip silently.
        let Some(q_h) = upload_f32(&b, &q) else {
            return;
        };
        let Some(k_h) = upload_f32(&b, &k) else {
            return;
        };
        let Some(v_h) = upload_f32(&b, &v) else {
            return;
        };
        let Some(o_h) = b.alloc(o_len * 4).ok() else {
            return;
        };
        // Zero the output buffer so stale device memory can't mask a failure.
        let zeros = vec![0u8; o_len * 4];
        let _ = b.copy_htod(o_h, &zeros);

        // scale=1 keeps the raw score gap of 100 (softmax weight ≈ 1 - 2e-44).
        let scale = 1.0;
        let result = b.attention(
            q_h, k_h, v_h, o_h, 1, 1, seq_q, seq_kv, head_dim, scale, false,
        );
        assert!(result.is_ok());

        if let Some(out) = download_f32(&b, o_h, o_len) {
            // With dot=100*scale=100 for K[0] vs 0 for others,
            // softmax should heavily favour V[0]=[1,0]
            assert!(out[0] > 0.99, "expected output[0] ≈ 1.0, got {}", out[0]);
            assert!(out[1] < 0.01, "expected output[1] ≈ 0.0, got {}", out[1]);
        }

        // Best-effort cleanup; free errors are irrelevant to the assertion.
        let _ = b.free(q_h);
        let _ = b.free(k_h);
        let _ = b.free(v_h);
        let _ = b.free(o_h);
1518    }
1519}